@@ -192,7 +192,11 @@ function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse)
192
192
predims = size (x. val)
193
193
cval = MLIR. IR. result (
194
194
MLIR. Dialects. stablehlo. concatenate (
195
- [Ops. reshape (v, Int64[1 , predims... ]) for v in x. dval]; dimension= Int64 (0 )
195
+ [
196
+ TracedUtils. get_mlir_data (Ops. reshape (v, Int64[1 , predims... ])) for
197
+ v in x. dval
198
+ ];
199
+ dimension= Int64 (0 ),
196
200
),
197
201
)
198
202
tval = TracedRArray {ET,length(predims) + 1} ((), cval, (length (x. dval), predims... ))
@@ -244,12 +248,6 @@ function overload_autodiff(
244
248
width = Enzyme. same_or_one (1 , args... )
245
249
if width == 0
246
250
throw (ErrorException (" Cannot differentiate with a batch size of 0" ))
247
- elseif width != 1
248
- throw (
249
- ErrorException (
250
- " EnzymeMLIR does not presently support width=$width , please rewrite your code to not use BatchDuplicated and/or call gradient(; chunk=1)" ,
251
- ),
252
- )
253
251
end
254
252
255
253
primf = f. val
@@ -389,9 +387,10 @@ function overload_autodiff(
389
387
fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
390
388
fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
391
389
res = (reverse ? MLIR. Dialects. enzyme. autodiff : MLIR. Dialects. enzyme. fwddiff)(
392
- [TracedUtils. transpose_val (v) for v in ad_inputs];
390
+ [TracedUtils. transpose_val (v; keep_first_intact = width > 1 ) for v in ad_inputs];
393
391
outputs= outtys,
394
392
fn= fname,
393
+ width,
395
394
activity= MLIR. IR. Attribute ([act_attr (a) for a in activity]),
396
395
ret_activity= MLIR. IR. Attribute ([act_attr (a) for a in ret_activity]),
397
396
)
@@ -434,8 +433,11 @@ function overload_autodiff(
434
433
push! (starts, 0 )
435
434
push! (limits, v)
436
435
end
437
- sval = Ops. slice (sval, starts, limits)
438
- TracedUtils. set! (dresult[i], path[2 : end ], sval)
436
+ sval = Ops. slice (TracedRArray (tval), starts, limits)
437
+ sval = Ops. reshape (sval, collect (Int64, sz))
438
+ TracedUtils. set! (
439
+ dresult[i], path[2 : end ], TracedUtils. get_mlir_data (sval)
440
+ )
439
441
end
440
442
end
441
443
residx += 1
0 commit comments