Hybrid overloading + SCT; not pure SCT? #24

@oxinabox

@ChrisRackauckas assured me that ForwardDiff2
was a pure source code transformation based approach,
rather than an overloading based approach.

My understanding was that it is primarily an overloading based approach that uses limited source code transformation to make it possible to extend the overloads with frules.

Now, it's a blurry line between the two.
But the place I would cut it is that a pure source code transform approach never calls the original function passing in a special overloaded type. Not even as a fallback.
It always calls a function that it created.
This means it never has method errors,
since it created the function it's going to call.

julia> using ForwardDiff2: D
[ Info: Precompiling ForwardDiff2 [994df76e-a4c1-5e1f-bd5c-23b9b5303d4f]

julia> sq(x) = x^2
sq (generic function with 1 method)

julia> D(sq)(10)*1.0
20.0

julia> sq2(x::Float64) = x^2
sq2 (generic function with 1 method)

julia> D(sq2)(10)*1.0
ERROR: MethodError: no method matching sq2(::ForwardDiff2.Dual{ForwardDiff2.Tag{Nothing},Int64,Float64})
Closest candidates are:
  sq(::Float64) at REPL[2]:1
Stacktrace:
 [1] call at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/context.jl:447 [inlined]
 [2] fallback at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/context.jl:445 [inlined]
 [3] _overdub_fallback at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/overdub.jl:486 [inlined]
 [4] _frule_overdub2 at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:145 [inlined]
 [5] alternative at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:186 [inlined]
 [6] #47 at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/api.jl:61 [inlined]
 [7] overdub(::Cassette.Context{nametype(DualContext),Nothing,Nothing,getfield(ForwardDiff2, Symbol("##PassType#371")),Nothing,Cassette.DisableHooks}, ::getfield(ForwardDiff2, Symbol("##47#49")){D{Int64,typeof(sq)},Float64}) at /Users/oxinabox/.julia/packages/Cassette/kbN4l/src/overdub.jl:0
 [8] dualrun(::Function) at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/dual_context.jl:192
 [9] *(::D{Int64,typeof(sq)}, ::Float64) at /Users/oxinabox/JuliaEnvs/ForwardDiff2.jl/src/api.jl:59
 [10] top-level scope at REPL[3]:1

To describe the function a pure source code transform would generate:
I am going to write it without Dual types, but you can do it with Dual types; it just means allocating an extra slot in the transformed code and putting the partials into that slot.

We started with:
sq2(x::Float64) = x^2

The new code is:

function do_forward_mode(sq2, x, dsq2, dx)
    sub1_args = (Base.literal_pow, x, 2, Zero(), dx, Zero())
    sub1_res = frule(sub1_args...)
    if sub1_res === nothing
        # call the generated function to make itself
        return do_forward_mode(sub1_args...)
    else
        return sub1_res
    end
end

Doing a more complicated one:

f(x) = g(x)*h(x)
function do_forward_mode(f, x, df, dx)
    subg_args = (g, x, Zero(), dx)
    subg_res = frule(subg_args...)
    if subg_res === nothing
        # call the generated function to make itself
        subg_res = do_forward_mode(subg_args...)
    end

    subh_args = (h, x, Zero(), dx)
    subh_res = frule(subh_args...)
    if subh_res === nothing
        # call the generated function to make itself
        subh_res = do_forward_mode(subh_args...)
    end

    # Result of `g` is being used, need to split the results
    subg_values, subg_partials = subg_res

    # Result of `h` is being used, need to split the results
    subh_values, subh_partials = subh_res

    submul_args = (*, subg_values..., subh_values..., Zero(), subg_partials..., subh_partials...)
    submul_res = frule(submul_args...)
    if submul_res === nothing
        # call the generated function to make itself
        return do_forward_mode(submul_args...)
    end
    return submul_res
end

So one can see that this kind of transform is complete.
We never call the original function.
We never perform any overloaded operations.
We only call frule and our transform-generating function: do_forward_mode.
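
To make that concrete, here is a rough, hand-written sketch of what the generated code could look like for f(x) = g(x)*h(x). Everything in it is a stand-in I made up for illustration: `Zero`, `frule`, `rule_or_recurse` and `do_forward_mode` are defined locally and are not the ChainRules or ForwardDiff2 definitions, and a real implementation would generate the `do_forward_mode` methods rather than have them written by hand. Note that `h` is concretely typed and it still works, because `h` itself is never called.

```julia
# Hypothetical stand-ins, defined here only for illustration.
struct Zero end   # marks "no derivative", e.g. for the function object itself

# Calling convention from the issue: (f, args..., df, dargs...).
# Returning `nothing` means "no rule for this call, recurse instead".
frule(args...) = nothing
frule(::typeof(sin), x, ::Zero, dx) = (sin(x), cos(x) * dx)
frule(::typeof(*), a, b, ::Zero, da, db) = (a * b, a * db + da * b)

g(x) = sin(x)
h(x::Float64) = x * x      # concretely typed on purpose
f(x) = g(x) * h(x)

# Try the rule first; if there is none, call the transformed function.
function rule_or_recurse(args...)
    res = frule(args...)
    return res === nothing ? do_forward_mode(args...) : res
end

# Hand-written versions of what the transform would generate.
do_forward_mode(::typeof(g), x, dg, dx) = rule_or_recurse(sin, x, Zero(), dx)
do_forward_mode(::typeof(h), x, dh, dx) = rule_or_recurse(*, x, x, Zero(), dx, dx)
function do_forward_mode(::typeof(f), x, df, dx)
    g_val, g_partial = rule_or_recurse(g, x, Zero(), dx)
    h_val, h_partial = rule_or_recurse(h, x, Zero(), dx)
    return rule_or_recurse(*, g_val, h_val, Zero(), g_partial, h_partial)
end

value, partial = do_forward_mode(f, 2.0, Zero(), 1.0)
```

Only plain `Float64` values ever flow through `sin` and `*` here, so there is nothing for a type signature to reject.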


I believe ForwardDiff2 is meant to be using this approach,
and that bugs have slipped in and caused it not to.

To make such bugs harder to slip through tests,
I suggest that Dual stop subtyping Number and stop defining operations like + and *.
That way we know that it is correctly doing the rewrites.
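
A small sketch of why that would help, using a made-up `OpaqueDual` type (hypothetical, not the actual ForwardDiff2 `Dual`): if the type is not a `Number` and has no arithmetic defined, any place where the transform falls back to calling the original code with a dual fails loudly instead of silently working through overloads.

```julia
# Hypothetical type for illustration; deliberately not `<: Number`
# and with no `+` or `*` methods defined.
struct OpaqueDual{V,P}
    value::V
    partial::P
end

# If a dual ever leaks into ordinary arithmetic, it is a MethodError.
threw = try
    OpaqueDual(10.0, 1.0) * 1.0
    false
catch err
    err isa MethodError
end
```

A test suite that differentiates ordinary functions with such a type would then catch every missed rewrite.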

Alternatively, this may be intentional.
The hybrid overloading + source code transform ForwardDiff2 is using right now
works pretty well if everything is either a Number or an Array.

But we might need another forward mode AD package that is pure source to source, to easily handle concretely typed structs and code that is otherwise hostile to the overloading approach.
