@@ -112,11 +112,9 @@ function create_result(
112
112
return Meta. quot (tocopy)
113
113
end
114
114
115
- const opt_passes:: String = join (
115
+ # Optimization passes via transform dialect
116
+ const transform_passes:: String = join (
116
117
[
117
- " inline{default-pipeline=canonicalize max-iterations=4}" ,
118
- " canonicalize,cse" ,
119
- " canonicalize" ,
120
118
" enzyme-hlo-generate-td{" *
121
119
join (
122
120
[
@@ -273,9 +271,22 @@ const opt_passes::String = join(
273
271
" transform-interpreter" ,
274
272
" enzyme-hlo-remove-transform" ,
275
273
],
276
- ' ,' ,
274
+ " ," ,
275
+ )
276
+
277
+ # Optimization passes which apply to an individual function
278
+ const func_passes:: String = join (
279
+ [" canonicalize,cse" , " canonicalize" , transform_passes], " ,"
277
280
)
278
281
282
+ const opt_passes:: String = join (
283
+ [" inline{default-pipeline=canonicalize max-iterations=4}" , func_passes], ' ,'
284
+ )
285
+
286
+ # TODO we want to be able to run the more advanced passes via transform dialect as an enzyme intermediate
287
+ # However, this errs as we cannot attach the transform with to the funcop itself [as we run a functionpass].
288
+ const enzyme_pass:: String = " enzyme{postpasses=\" arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\" }"
289
+
279
290
function run_pass_pipeline! (mod, pass_pipeline; enable_verifier= true )
280
291
pm = MLIR. IR. PassManager ()
281
292
MLIR. IR. enable_verifier! (pm, enable_verifier)
@@ -335,7 +346,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
335
346
kern = " lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[]) }"
336
347
if optimize === :all
337
348
run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes], " ," ))
338
- run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
349
+ run_pass_pipeline! (
350
+ mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
351
+ )
339
352
run_pass_pipeline! (
340
353
mod,
341
354
join (
@@ -351,7 +364,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
351
364
)
352
365
elseif optimize === :before_kernel
353
366
run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes], " ," ))
354
- run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
367
+ run_pass_pipeline! (
368
+ mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
369
+ )
355
370
run_pass_pipeline! (
356
371
mod,
357
372
join (
@@ -381,7 +396,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
381
396
)
382
397
elseif optimize === :only_enzyme
383
398
run_pass_pipeline! (mod, " enzyme-batch" )
384
- run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
399
+ run_pass_pipeline! (
400
+ mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
401
+ )
385
402
run_pass_pipeline! (
386
403
mod,
387
404
join (
@@ -391,7 +408,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
391
408
)
392
409
elseif optimize === :after_enzyme
393
410
run_pass_pipeline! (mod, " enzyme-batch" )
394
- run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
411
+ run_pass_pipeline! (
412
+ mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
413
+ )
395
414
run_pass_pipeline! (
396
415
mod,
397
416
join (
@@ -407,7 +426,9 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
407
426
)
408
427
elseif optimize === :before_enzyme
409
428
run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes], " ," ))
410
- run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
429
+ run_pass_pipeline! (
430
+ mod, " $enzyme_pass ,arith-raise{stablehlo=true}" ; enable_verifier= false
431
+ )
411
432
run_pass_pipeline! (
412
433
mod, " canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math," * kern
413
434
)
0 commit comments