@ChrisRackauckas assured me that ForwardDiff2 was a pure source-code-transformation-based approach, rather than an overloading-based approach.
My thought was that it was primarily an overloading-based approach that uses limited source code transformation to make it possible to extend the overloads with `frule`s.
Now, it's a blurry line between the two.
But the place I would cut it is this: a pure source code transform approach never calls the original function with a special overloaded type passed in, not even as a fallback.
It always calls a function that it created.
This means it never has method errors, since it created the function it is going to call.
```julia
julia> using ForwardDiff2: D
[ Info: Precompiling ForwardDiff2 [994df76e-a4c1-5e1f-bd5c-23b9b5303d4f]

julia> sq(x) = x^2
sq (generic function with 1 method)

julia> D(sq)(10)*1.0
20.0

julia> sq2(x::Float64) = x^2
sq2 (generic function with 1 method)

julia> D(sq2)(10)*1.0
ERROR: MethodError: no method matching sq2(::ForwardDiff2.Dual{ForwardDiff2.Tag{Nothing},Int64,Float64})
Closest candidates are:
  sq(::Float64) at REPL[2]:1
Stacktrace:
 [1] call at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/context.jl:447 [inlined]
 [2] fallback at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/context.jl:445 [inlined]
 [3] _overdub_fallback at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/overdub.jl:486 [inlined]
 [4] _frule_overdub2 at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:145 [inlined]
 [5] alternative at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:186 [inlined]
 [6] #47 at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/api.jl:61 [inlined]
 [7] overdub(::Cassette.Context{nametype(DualContext),Nothing,Nothing,getfield(ForwardDiff2, Symbol("##PassType#371")),Nothing,Cassette.DisableHooks}, ::getfield(ForwardDiff2, Symbol("##47#49")){D{Int64,typeof(sq)},Float64}) at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/overdub.jl:0
 [8] dualrun(::Function) at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:192
 [9] *(::D{Int64,typeof(sq)}, ::Float64) at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/api.jl:59
 [10] top-level scope at REPL[3]:1
```
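The same failure mode can be reproduced without ForwardDiff2 or Cassette. Below is a minimal sketch of a purely overloading-based forward mode, using a hypothetical `ToyDual` type and `derivative` helper (neither is ForwardDiff2's API). Because the original function is called with the special type, the untyped `sq` works, while `sq2(x::Float64)` has no method that accepts the dual.

```julia
# Toy overloading-based forward mode (hypothetical; not ForwardDiff2's code).
struct ToyDual <: Number
    value::Float64
    partial::Float64
end

# Overloads for the operations that `x^2` needs.
Base.:*(a::ToyDual, b::ToyDual) =
    ToyDual(a.value * b.value, a.partial * b.value + a.value * b.partial)
Base.literal_pow(::typeof(^), x::ToyDual, ::Val{2}) = x * x

# Overloading: call the ORIGINAL function, passing in the special type.
derivative(f, x::Real) = f(ToyDual(float(x), 1.0)).partial

sq(x) = x^2
sq2(x::Float64) = x^2

derivative(sq, 10)  # 20.0

# sq2 only accepts Float64, so dispatch on ToyDual fails.
err = try
    derivative(sq2, 10)
    nothing
catch e
    e
end
```

Here `err` ends up holding a `MethodError`, which is exactly the class of error a pure source transform cannot produce.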
To describe the rewritten function, I am going to write it without `Dual` types.
You can do it with `Dual` types; it just means allocating an extra slot in the transformed code and putting the values into that slot.
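We started with:
```julia
sq2(x::Float64) = x^2
```
The new code is:
```julia
# `x^2` with a literal exponent lowers to `Base.literal_pow(^, x, Val(2))`.
# Argument layout: (function, primal args..., function tangent, tangent args...)
function do_forward_mode(sq2, x, dsq2, dx)
    sub1_args = (Base.literal_pow, ^, x, Val(2), Zero(), Zero(), dx, Zero())
    sub1_res = frule(sub1_args...)
    if sub1_res === nothing
        # No rule: call the generated function to make a transformed `literal_pow`
        return do_forward_mode(sub1_args...)
    else
        return sub1_res
    end
end
```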
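For comparison, here is a runnable sketch of the transformed `sq2`, assuming a hypothetical minimal rule table. `Zero` and `frule` below are toy stand-ins that follow the calling convention used above (`frule` returns `(value, tangent)`, or `nothing` when it has no rule); they are not the ChainRules API. The point is that `sq2` itself is never called, so its `Float64` restriction never matters.

```julia
# Hypothetical stand-in for ChainRulesCore's `Zero()` tangent.
struct Zero end

# Toy rule table: frule(f, args..., df, dargs...) returns (value, tangent),
# or `nothing` when no rule is known.
frule(f, args...) = nothing

# Rule for Base.literal_pow(^, x, Val(p)):  d(x^p) = p * x^(p-1) * dx
function frule(::typeof(Base.literal_pow), ::typeof(^), x::Real, ::Val{p},
               dself, dpow, dx, dval) where {p}
    return (x^p, p * x^(p - 1) * dx)
end

sq2(x::Float64) = x^2  # the original function; it is never called below

# Hand-written version of what the transform would emit for `sq2`.
function do_forward_mode(::typeof(sq2), x, dsq2, dx)
    sub1_args = (Base.literal_pow, ^, x, Val(2), Zero(), Zero(), dx, Zero())
    sub1_res = frule(sub1_args...)
    if sub1_res === nothing
        # A real transform would generate and call a method for literal_pow here.
        error("no frule for Base.literal_pow")
    end
    return sub1_res
end

do_forward_mode(sq2, 10.0, Zero(), 1.0)  # (100.0, 20.0)
```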
Doing a more complicated one:
```julia
f(x) = g(x)*h(x)
```
```julia
function do_forward_mode(f, x, df, dx)
    subg_args = (g, x, Zero(), dx)
    subg_res = frule(subg_args...)
    if subg_res === nothing
        # No rule for `g`: call the generated function to make a transformed `g`
        subg_res = do_forward_mode(subg_args...)
    end
    subh_args = (h, x, Zero(), dx)
    subh_res = frule(subh_args...)
    if subh_res === nothing
        # No rule for `h`: call the generated function to make a transformed `h`
        subh_res = do_forward_mode(subh_args...)
    end
    # Results of `g` and `h` are being used, so split each into value and partial
    subg_value, subg_partial = subg_res
    subh_value, subh_partial = subh_res
    submul_args = (*, subg_value, subh_value, Zero(), subg_partial, subh_partial)
    submul_res = frule(submul_args...)
    if submul_res === nothing
        # No rule for `*`: call the generated function to make a transformed `*`
        return do_forward_mode(submul_args...)
    end
    return submul_res
end
```
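A runnable version of the `f(x) = g(x)*h(x)` case, under the same toy assumptions (hypothetical `Zero` and `frule`, not ChainRules). Here `g` and `h` are hypothetically bound to `sin` and `cos` so that rules exist for them; the catch-all `do_forward_mode` method only raises an error at the point where a real implementation would generate a transformed method.

```julia
# Hypothetical stand-in for ChainRulesCore's `Zero()` tangent.
struct Zero end

# Toy rule table (not the ChainRules API): frule(f, args..., df, dargs...)
# returns (value, tangent), or `nothing` when no rule is known.
frule(f, args...) = nothing
frule(::typeof(sin), x, dself, dx) = (sin(x), cos(x) * dx)
frule(::typeof(cos), x, dself, dx) = (cos(x), -sin(x) * dx)
frule(::typeof(*), a, b, dself, da, db) = (a * b, da * b + a * db)

# Assumption for illustration: g and h are functions that already have rules.
const g = sin
const h = cos
f(x) = g(x) * h(x)

# Stand-in for the generated fallback. A real transform would build a method
# by rewriting the body of `fn`; this sketch only reports that none exists.
do_forward_mode(fn, args...) = error("no frule for $fn and no generated method")

# Hand-written version of what the transform would emit for `f`.
function do_forward_mode(::typeof(f), x, df, dx)
    subg_args = (g, x, Zero(), dx)
    subg_res = frule(subg_args...)
    if subg_res === nothing
        subg_res = do_forward_mode(subg_args...)
    end
    subh_args = (h, x, Zero(), dx)
    subh_res = frule(subh_args...)
    if subh_res === nothing
        subh_res = do_forward_mode(subh_args...)
    end
    subg_value, subg_partial = subg_res
    subh_value, subh_partial = subh_res
    submul_args = (*, subg_value, subh_value, Zero(), subg_partial, subh_partial)
    submul_res = frule(submul_args...)
    if submul_res === nothing
        return do_forward_mode(submul_args...)
    end
    return submul_res
end

value, partial = do_forward_mode(f, 0.3, Zero(), 1.0)
```

Since d/dx sin(x)cos(x) = cos(2x), `partial` should match `cos(0.6)` up to rounding, and `f` itself is never invoked.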
So one can see that this kind of transform is complete.
We never call the original function.
We never perform any overloaded operations.
We only call `frule` and our transform-generating function, `do_forward_mode`.
I believe ForwardDiff2 is meant to be using this approach, and that bugs have slipped in that cause it not to.
To make such bugs harder to slip through tests, I suggest that `Dual` stop subtyping `Number` and stop defining operations like `+` and `*`.
That way we know that it is correctly doing the rewrites.
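A sketch of what that test setup could look like, assuming a hypothetical `OpaqueDual` that neither subtypes `Number` nor defines any arithmetic. If a call is ever reached without being rewritten, nothing can dispatch on the dual value, so the missed rewrite surfaces as a `MethodError` instead of silently falling back to overloading.

```julia
# Hypothetical dual type for testing: deliberately NOT a subtype of Number,
# and with no +, *, ^ or other arithmetic defined on it.
struct OpaqueDual
    value::Float64
    partial::Float64
end

sq(x) = x^2

# Simulate a missed rewrite: the original `sq` gets called with the dual.
# No arithmetic method accepts OpaqueDual, so this fails loudly.
missed = try
    sq(OpaqueDual(3.0, 1.0))
    nothing
catch e
    e
end
```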
Alternatively, this may be intentional.
The hybrid overloading + source code transform approach ForwardDiff2 is using right now works pretty well if everything is either a `Number` or an `Array`.
But we might need another forward-mode AD package that is pure source-to-source, to easily handle concretely typed structs and code that is otherwise hostile to the overloading approach.
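To illustrate the concretely-typed-struct problem, here is a minimal sketch with a hypothetical `ScalarDual <: Number` and a `Point` whose fields are declared `Float64`. Under overloading, the dual value has to pass through the `Point` constructor, which cannot store it; a pure source-to-source transform could instead keep plain `Float64`s in `Point` and carry the tangents alongside.

```julia
# Hypothetical Number-subtyping dual, as an overloading-based AD would use.
struct ScalarDual <: Number
    value::Float64
    partial::Float64
end
Base.:*(a::Real, d::ScalarDual) = ScalarDual(a * d.value, a * d.partial)

# A concretely typed struct: its fields can only hold Float64.
struct Point
    x::Float64
    y::Float64
end

norm2(p::Point) = p.x^2 + p.y^2
measure(t) = norm2(Point(t, 2t))

measure(1.0)  # 5.0 with plain numbers

# With overloading, the dual reaches the Point constructor, which must
# convert it to Float64; no such conversion exists, so this throws.
overload_failure = try
    measure(ScalarDual(1.0, 1.0))
    nothing
catch e
    e
end
```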