Skip to content

Commit f7e361e

Browse files
Pipeline for nested enzyme differentiation (#452)
* Pipeline for enzyme * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Nested AD * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Compiler.jl * Update src/Compiler.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Project.toml * Update autodiff.jl * Update test/autodiff.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update autodiff.jl * fixbug --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 327d252 commit f7e361e

File tree

4 files changed

+44
-12
lines changed

4 files changed

+44
-12
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ PythonCall = "0.9"
6262
Random = "1.10"
6363
Random123 = "1.7"
6464
ReactantCore = "0.1.3"
65-
Reactant_jll = "0.0.34"
65+
Reactant_jll = "0.0.36"
6666
Scratch = "1.2"
6767
SpecialFunctions = "2"
6868
Statistics = "1.10"

src/Compiler.jl

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,9 @@ function create_result(
112112
return Meta.quot(tocopy)
113113
end
114114

115-
const opt_passes::String = join(
115+
# Optimization passes via transform dialect
116+
const transform_passes::String = join(
116117
[
117-
"inline{default-pipeline=canonicalize max-iterations=4}",
118-
"canonicalize,cse",
119-
"canonicalize",
120118
"enzyme-hlo-generate-td{" *
121119
join(
122120
[
@@ -273,9 +271,22 @@ const opt_passes::String = join(
273271
"transform-interpreter",
274272
"enzyme-hlo-remove-transform",
275273
],
276-
',',
274+
",",
275+
)
276+
277+
# Optimization passes which apply to an individual function
278+
const func_passes::String = join(
279+
["canonicalize,cse", "canonicalize", transform_passes], ","
277280
)
278281

282+
const opt_passes::String = join(
283+
["inline{default-pipeline=canonicalize max-iterations=4}", func_passes], ','
284+
)
285+
286+
# TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
287+
# However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
288+
const enzyme_pass::String = "enzyme{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"}"
289+
279290
function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
280291
pm = MLIR.IR.PassManager()
281292
MLIR.IR.enable_verifier!(pm, enable_verifier)
@@ -335,7 +346,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
335346
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
336347
if optimize === :all
337348
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
338-
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
349+
run_pass_pipeline!(
350+
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
351+
)
339352
run_pass_pipeline!(
340353
mod,
341354
join(
@@ -351,7 +364,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
351364
)
352365
elseif optimize === :before_kernel
353366
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
354-
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
367+
run_pass_pipeline!(
368+
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
369+
)
355370
run_pass_pipeline!(
356371
mod,
357372
join(
@@ -381,7 +396,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
381396
)
382397
elseif optimize === :only_enzyme
383398
run_pass_pipeline!(mod, "enzyme-batch")
384-
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
399+
run_pass_pipeline!(
400+
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
401+
)
385402
run_pass_pipeline!(
386403
mod,
387404
join(
@@ -391,7 +408,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
391408
)
392409
elseif optimize === :after_enzyme
393410
run_pass_pipeline!(mod, "enzyme-batch")
394-
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
411+
run_pass_pipeline!(
412+
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
413+
)
395414
run_pass_pipeline!(
396415
mod,
397416
join(
@@ -407,7 +426,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
407426
)
408427
elseif optimize === :before_enzyme
409428
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
410-
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
429+
run_pass_pipeline!(
430+
mod, "$enzyme_pass,arith-raise{stablehlo=true}"; enable_verifier=false
431+
)
411432
run_pass_pipeline!(
412433
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
413434
)

src/Reactant.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ include("Overlay.jl")
213213

214214
function Enzyme.make_zero(
215215
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
216-
)::RT where {copy_if_inactive,RT<:RArray}
216+
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
217217
if haskey(seen, prev)
218218
return seen[prev]
219219
end

test/autodiff.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,3 +120,14 @@ end
120120
@test stret.st2 x .+ 1
121121
@test stret.st1 === stret.st2
122122
end
123+
124+
@testset "Nested AD" begin
125+
x = ConcreteRNumber(3.1)
126+
f(x) = x * x * x * x
127+
df(x) = Enzyme.gradient(Reverse, f, x)[1]
128+
res1 = @jit df(x)
129+
@test res1 4 * 3.1^3
130+
ddf(x) = Enzyme.gradient(Reverse, df, x)[1]
131+
res2 = @jit ddf(x)
132+
@test res2 4 * 3 * 3.1^2
133+
end

0 commit comments

Comments
 (0)