diff --git a/Project.toml b/Project.toml index 240038e..6860184 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ADTypes" uuid = "47edcb42-4c32-4615-8424-f2b9edc5f35b" authors = ["Vaibhav Dixit , Guillaume Dalle and contributors"] -version = "1.17.0" +version = "1.18.0" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/ext/ADTypesConstructionBaseExt.jl b/ext/ADTypesConstructionBaseExt.jl index 883dd7f..9700213 100644 --- a/ext/ADTypesConstructionBaseExt.jl +++ b/ext/ADTypesConstructionBaseExt.jl @@ -3,12 +3,14 @@ module ADTypesConstructionBaseExt using ADTypes: AutoEnzyme, AutoForwardDiff, AutoPolyesterForwardDiff using ConstructionBase: ConstructionBase -struct InternalAutoEnzymeReconstructor{A} end +struct InternalAutoEnzymeReconstructor{A, C} end -InternalAutoEnzymeReconstructor{A}(mode::M) where {M, A} = AutoEnzyme{M, A}(mode) +function InternalAutoEnzymeReconstructor{A, C}(mode::M) where {M, A, C} + AutoEnzyme{M, A, C}(mode) +end -function ConstructionBase.constructorof(::Type{<:AutoEnzyme{M, A}}) where {M, A} - return InternalAutoEnzymeReconstructor{A} +function ConstructionBase.constructorof(::Type{<:AutoEnzyme{M, A, C}}) where {M, A, C} + return InternalAutoEnzymeReconstructor{A, C} end function ConstructionBase.constructorof(::Type{<:AutoForwardDiff{chunksize}}) where {chunksize} diff --git a/src/dense.jl b/src/dense.jl index 38072e6..19b232e 100644 --- a/src/dense.jl +++ b/src/dense.jl @@ -39,7 +39,7 @@ struct AutoDiffractor <: AbstractADType end mode(::AutoDiffractor) = ForwardOrReverseMode() """ - AutoEnzyme{M,A} + AutoEnzyme{M,A,C} Struct used to select the [Enzyme.jl](https://github.com/EnzymeAD/Enzyme.jl) backend for automatic differentiation. @@ -47,38 +47,57 @@ Defined by [ADTypes.jl](https://github.com/SciML/ADTypes.jl). # Constructors - AutoEnzyme(; mode::M=nothing, function_annotation::Type{A}=Nothing) + AutoEnzyme(; + mode::Union{EnzymeCore.Mode,Nothing}=nothing, + function_annotation::Type{<:Union{EnzymeCore.Annotation,Nothing}}=Nothing, + chunksize::Union{Int,Float64,Nothing}=nothing, + ) -# Type parameters + - `mode::M` determines the autodiff mode (forward or reverse). It can be: - - `A` determines how the function `f` to differentiate is passed to Enzyme. It can be: + + a mode object from EnzymeCore.jl, like `EnzymeCore.Forward` or `EnzymeCore.Reverse` (possibly modified with additional settings like runtime activity) + + `nothing` to choose the best mode automatically - + a subtype of `EnzymeCore.Annotation` (like `EnzymeCore.Const` or `EnzymeCore.Duplicated`) to enforce a given annotation - + `Nothing` to simply pass `f` and let Enzyme choose the most appropriate annotation + - `A=function_annotation` determines how the function `f` to differentiate is passed to Enzyme. It can be: -# Fields + + a subtype of `EnzymeCore.Annotation` (like `EnzymeCore.Const` or `EnzymeCore.Duplicated`) to enforce a given annotation - - `mode::M` determines the autodiff mode (forward or reverse). It can be: + + `Nothing` (the type, not the object) to simply pass `f` and let Enzyme choose the most appropriate annotation + - `C=chunksize` determines the number of derivatives evaluated simultaneously when computing operators like a Jacobian or a forward-mode gradient. It can be: - + an object subtyping `EnzymeCore.Mode` (like `EnzymeCore.Forward` or `EnzymeCore.Reverse`) if a specific mode is required - + `nothing` to choose the best mode automatically + + a positive `Int` to fix a constant chunk size + + `Inf` to pick the maximum chunk size, corresponding to the array length + + `nothing` to choose a good chunk size automatically """ -struct AutoEnzyme{M, A} <: AbstractADType +struct AutoEnzyme{M, A, C} <: AbstractADType mode::M + + function AutoEnzyme{M, A, C}(mode::M) where {M, A, C} + @assert C isa Union{Nothing, Int, Float64} + if C isa Int + @assert C > 0 + elseif C isa Float64 + @assert C == Inf + end + return new{M, A, C}(mode) + end end function AutoEnzyme(; - mode::M = nothing, function_annotation::Type{A} = Nothing) where {M, A} - return AutoEnzyme{M, A}(mode) + mode::M = nothing, + function_annotation::Type{A} = Nothing, + chunksize::Union{Nothing, Int, Float64} = nothing +) where {M, A} + return AutoEnzyme{M, A, chunksize}(mode) end mode(::AutoEnzyme) = ForwardOrReverseMode() # specialized in the extension -function Base.show(io::IO, backend::AutoEnzyme{M, A}) where {M, A} +function Base.show(io::IO, backend::AutoEnzyme{M, A, C}) where {M, A, C} print(io, AutoEnzyme, "(") - !isnothing(backend.mode) && print(io, "mode=", repr(backend.mode; context = io)) - !isnothing(backend.mode) && !(A <: Nothing) && print(io, ", ") - !(A <: Nothing) && print(io, "function_annotation=", repr(A; context = io)) + !isnothing(backend.mode) && print(io, "mode=", repr(backend.mode; context = io), ", ") + !(A <: Nothing) && print(io, "function_annotation=", repr(A; context = io), ", ") + !(C === nothing) && print(io, "chunksize=", repr(C; context = io)) print(io, ")") end diff --git a/test/dense.jl b/test/dense.jl index 307403d..022c081 100644 --- a/test/dense.jl +++ b/test/dense.jl @@ -25,12 +25,15 @@ end @test mode(ad) isa ForwardOrReverseMode end +get_chunksize(::AutoEnzyme{M, A, C}) where {M, A, C} = C + @testset "AutoEnzyme" begin ad = AutoEnzyme() @test ad isa AbstractADType @test ad isa AutoEnzyme{Nothing, Nothing} @test mode(ad) isa ForwardOrReverseMode @test ad.mode === nothing + @test get_chunksize(ad) === nothing ad = AutoEnzyme(; mode = EnzymeCore.Forward) @test ad isa AbstractADType @@ -50,6 +53,22 @@ end @test ad isa AutoEnzyme{typeof(EnzymeCore.Reverse), EnzymeCore.Duplicated} @test mode(ad) isa ReverseMode @test ad.mode == EnzymeCore.Reverse + + ad = AutoEnzyme(; chunksize = nothing) + @test get_chunksize(ad) === nothing + + ad = AutoEnzyme(; chunksize = 3) + @test get_chunksize(ad) == 3 + + ad = AutoEnzyme(; chunksize = Inf) + @test get_chunksize(ad) == Inf + + ad = AutoEnzyme(; chunksize = 3) + @test get_chunksize(ad) == 3 + + @test_throws TypeError AutoEnzyme(; chunksize = :big) + @test_throws AssertionError AutoEnzyme(; chunksize = 0) + @test_throws AssertionError AutoEnzyme(; chunksize = 1.3) end @testset "AutoFastDifferentiation" begin diff --git a/test/misc.jl b/test/misc.jl index 95c73bc..30ac9ad 100644 --- a/test/misc.jl +++ b/test/misc.jl @@ -34,6 +34,7 @@ for backend in [ ADTypes.AutoEnzyme(mode = :forward), ADTypes.AutoEnzyme(function_annotation = Val{:forward}), ADTypes.AutoEnzyme(mode = :reverse, function_annotation = Val{:duplicated}), + ADTypes.AutoEnzyme(chunksize = 2), ADTypes.AutoFastDifferentiation(), ADTypes.AutoFiniteDiff(), ADTypes.AutoFiniteDiff(fdtype = :fd, fdjtype = :fdj, fdhtype = :fdh),