Skip to content

Commit 319a3e5

Browse files
avik-palgithub-actions[bot]wsmoses
authored
feat: support vector mode AD (#519)
* feat: support vector mode AD * Update test/autodiff.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update Interpreter.jl * test: new JLL * fix: codegen for vector mode * fix: bump version --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: William Moses <[email protected]>
1 parent 1722fcb commit 319a3e5

File tree

4 files changed

+33
-13
lines changed

4 files changed

+33
-13
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ PythonCall = "0.9"
8787
Random = "1.10"
8888
Random123 = "1.7"
8989
ReactantCore = "0.1.9"
90-
Reactant_jll = "0.0.151"
90+
Reactant_jll = "0.0.152"
9191
ScopedValues = "1.3.0"
9292
Scratch = "1.2"
9393
Sockets = "1.10"

src/Interpreter.jl

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,11 @@ function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse)
192192
predims = size(x.val)
193193
cval = MLIR.IR.result(
194194
MLIR.Dialects.stablehlo.concatenate(
195-
[Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
195+
[
196+
TracedUtils.get_mlir_data(Ops.reshape(v, Int64[1, predims...])) for
197+
v in x.dval
198+
];
199+
dimension=Int64(0),
196200
),
197201
)
198202
tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
@@ -244,12 +248,6 @@ function overload_autodiff(
244248
width = Enzyme.same_or_one(1, args...)
245249
if width == 0
246250
throw(ErrorException("Cannot differentiate with a batch size of 0"))
247-
elseif width != 1
248-
throw(
249-
ErrorException(
250-
"EnzymeMLIR does not presently support width=$width, please rewrite your code to not use BatchDuplicated and/or call gradient(; chunk=1)",
251-
),
252-
)
253251
end
254252

255253
primf = f.val
@@ -389,9 +387,10 @@ function overload_autodiff(
389387
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
390388
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
391389
res = (reverse ? MLIR.Dialects.enzyme.autodiff : MLIR.Dialects.enzyme.fwddiff)(
392-
[TracedUtils.transpose_val(v) for v in ad_inputs];
390+
[TracedUtils.transpose_val(v; keep_first_intact=width > 1) for v in ad_inputs];
393391
outputs=outtys,
394392
fn=fname,
393+
width,
395394
activity=MLIR.IR.Attribute([act_attr(a) for a in activity]),
396395
ret_activity=MLIR.IR.Attribute([act_attr(a) for a in ret_activity]),
397396
)
@@ -434,8 +433,11 @@ function overload_autodiff(
434433
push!(starts, 0)
435434
push!(limits, v)
436435
end
437-
sval = Ops.slice(sval, starts, limits)
438-
TracedUtils.set!(dresult[i], path[2:end], sval)
436+
sval = Ops.slice(TracedRArray(tval), starts, limits)
437+
sval = Ops.reshape(sval, collect(Int64, sz))
438+
TracedUtils.set!(
439+
dresult[i], path[2:end], TracedUtils.get_mlir_data(sval)
440+
)
439441
end
440442
end
441443
residx += 1

src/TracedUtils.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,15 @@ Base.@nospecializeinfer function transpose_ty(
186186
end
187187

188188
Base.@nospecializeinfer function transpose_val(
189-
@nospecialize(val::MLIR.IR.Value)
189+
@nospecialize(val::MLIR.IR.Value); keep_first_intact::Bool=false
190190
)::MLIR.IR.Value
191191
val_size = size(MLIR.IR.type(val))
192192
val_size == () && return val
193-
attr = MLIR.IR.DenseArrayAttribute(Int64[reverse(0:(length(val_size) - 1))...])
193+
if keep_first_intact
194+
attr = MLIR.IR.DenseArrayAttribute(Int64[0, reverse(1:(length(val_size) - 1))...])
195+
else
196+
attr = MLIR.IR.DenseArrayAttribute(Int64[reverse(0:(length(val_size) - 1))...])
197+
end
194198
return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1)
195199
end
196200

test/autodiff.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,17 @@ end
156156
hlo = @code_hlo optimize = false Enzyme.onehot(x)
157157
@test !contains("stablehlo.constant", repr(hlo))
158158
end
159+
160+
fn(x) = sum(abs2, x)
161+
vector_forward_ad(x) = Enzyme.autodiff(Forward, fn, BatchDuplicated(x, Enzyme.onehot(x)))
162+
163+
@testset "Vector Mode AD" begin
164+
x = Reactant.to_rarray(reshape(collect(Float32, 1:4), 2, 2))
165+
res = @jit vector_forward_ad(x)
166+
res_enz = vector_forward_ad(Array(x))
167+
168+
@test res[1][1] res_enz[1][1]
169+
@test res[1][2] res_enz[1][2]
170+
@test res[1][3] res_enz[1][3]
171+
@test res[1][4] res_enz[1][4]
172+
end

0 commit comments

Comments
 (0)