Improved ForwardDiff.jl Stacktraces With Package Tags


December 19 2021 in Uncategorized | Tags: | Author: Christopher Rackauckas

You may have seen some hilariously long stacktraces when using ForwardDiff. In the latest releases of OrdinaryDiffEq.jl we have fixed this, and the fix is rather safe. I want to take a second to describe some of the technical details so that others can copy this technique.

The reason for the long stacktraces is the tag parameter. The Dual number type is given by Dual{T,V,N} where V is an element type (usually Float64), N is a chunk size (some integer), and T is the tag. What the tag does is prevent perturbation confusion by erroring if two incompatible dual numbers try to interact. The key requirement for it to prevent perturbation confusion is for the tag type to be unique in the context of the user. For example, if the user is differentiating f, and then differentiating x->derivative(f,x), you want the tag to be different in those two scenarios. For this reason ForwardDiff.jl’s default tag when doing the operation derivative(f,x) is ForwardDiff.Tag{typeof(f),eltype(x)}. That would then make the outer call in derivative(x->derivative(f,x),x) generate the tag ForwardDiff.Tag{typeof(x->derivative(f,x)),eltype(x)}, and then if Duals ever collide incorrectly it will error appropriately. Beautiful.
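To make the default tag concrete, here is a minimal sketch that records the type ForwardDiff.jl hands to a function during derivative. The names f, probe, and seen are made up for illustration; only ForwardDiff.derivative comes from the library.

```julia
using ForwardDiff

f(x) = x^2 + 3x

# Record the type of the input that ForwardDiff passes into the function.
seen = Ref{Any}(nothing)
function probe(x)
    seen[] = typeof(x)
    return f(x)
end

d = ForwardDiff.derivative(probe, 1.0)
println(d)       # derivative of x^2 + 3x at x = 1.0
println(seen[])  # a Dual whose tag parameter mentions typeof(probe) and Float64

# Nesting adds a second, different tag built from the outer closure's type.
d2 = ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), 1.0)
println(d2)      # second derivative of x^2 + 3x
```

Even for this tiny function the printed Dual type already spells out the function's type. Swap probe for a closure over a big parameter struct and you can guess what the stacktrace looks like.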

But you might start to see where the stacktrace length is coming from. Every time you have a differentiation context you slap another typeof(f) tag over that, and closures or callable structs with lots of parameters can then start to show their full types, type parameters and all, inside the tag.

However, suppose one defines a tag type in the package, like OrdinaryDiffEq.PackageTag, and uses it for the operations in the package, i.e. ForwardDiff.Tag{OrdinaryDiffEq.PackageTag,eltype(x)}. Then the only way it could clash with code from outside the package is if the user specifically chooses OrdinaryDiffEq.PackageTag as the tag. Note that the module is important: it will not clash with a struct PackageTag end defined in the REPL. Okay fine, someone can adversarially clash and force their computations to be incorrect, but there were a million ways to do that before using a package tag. It also would not clash with nested automatic differentiation, since the second derivative would have eltype(x) = Dual{ForwardDiff.Tag{OrdinaryDiffEq.PackageTag,V},V,N}, which makes the outer tag a different type (which is why the eltype is important). Thus the only way for this to be unsafe with perturbation confusion is if you are in your own package mixing first order AD calls, say threading two first order calls and using the same buffer between them in a non-thread-safe way.
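Here is a small sketch of why the eltype keeps nested differentiation safe. MyPackageTag is a hypothetical stand-in for a tag struct defined inside your package; the tags are built with ForwardDiff.Tag and then compared.

```julia
using ForwardDiff

# Hypothetical package tag; in practice this would live in your package's module.
struct MyPackageTag end

# First-order tag: the element type is a plain Float64.
T1 = typeof(ForwardDiff.Tag(MyPackageTag(), Float64))

# In a nested (second-order) context the element type is itself a Dual
# that carries the first tag.
D1 = ForwardDiff.Dual{T1,Float64,1}
T2 = typeof(ForwardDiff.Tag(MyPackageTag(), D1))

println(T1)
println(T2)
println(T1 === T2)  # the two tag types differ because their element types differ
```

Same struct, two different tag types, so the inner and outer derivative cannot be confused with each other.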

Basically, per-package tags are safe. ForwardDiff.jl cannot do this because it cannot (safely) define a struct in every module that uses ForwardDiff.jl, so in its differentiation context it uses something that is safe but often verbose (typeof(f)). Thus you probably want to use a per-package tag once a library is stable. How to do it is shown in this PR. The tag needs to be constructed with the function ForwardDiff.Tag, because that function advances an internal tag counter which ForwardDiff.jl uses to order tags. This looks like:

T = typeof(ForwardDiff.Tag(OrdinaryDiffEqTag(),eltype(x)))

Note that T needs to be the typeof of the Tag, i.e. the DataType, not the instance. Then you can directly define Dual numbers like:

xdual = ForwardDiff.Dual{T,eltype(df),1}(x,ForwardDiff.Partials((1.0,)))

Notice that the second part of the Dual must be a ForwardDiff.Partials object defined on a tuple of size matching the chunksize N (here N=1). If you do any of this incorrectly, you will get tag ordering errors. If you do this correctly, your stacktraces are now shorter.
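Putting the pieces together, here is a minimal end-to-end sketch of a scalar derivative with a package tag. MyPackageTag and f are hypothetical names for illustration, and this is not the code from OrdinaryDiffEq.jl, just the same pattern in a self-contained form.

```julia
using ForwardDiff

# Hypothetical package tag; define this once inside your package's module.
struct MyPackageTag end

f(x) = sin(x) * x

x = 0.5

# Build the tag through the ForwardDiff.Tag function, then take its type.
T = typeof(ForwardDiff.Tag(MyPackageTag(), typeof(x)))

# Seed a Dual with chunk size N = 1: the value x and a single partial of 1.0.
xdual = ForwardDiff.Dual{T,typeof(x),1}(x, ForwardDiff.Partials((1.0,)))

ydual = f(xdual)
val = ForwardDiff.value(ydual)        # f(x)
der = ForwardDiff.partials(ydual, 1)  # f'(x)

println(typeof(xdual))  # the tag shows MyPackageTag instead of typeof(f)
println((val, der))
```

The Dual type that shows up in an error message now mentions the short tag struct rather than the full type of whatever function was being differentiated.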
