@@ -244,14 +244,27 @@ const opt_passes::String = join(
244
244
' ,' ,
245
245
)
246
246
247
- function run_pass_pipeline! (mod, pass_pipeline)
247
+ function run_pass_pipeline! (mod, pass_pipeline; enable_verifier = true )
248
248
pm = MLIR. IR. PassManager ()
249
+ MLIR. IR. enable_verifier! (pm, enable_verifier)
249
250
opm = MLIR. IR. OpPassManager (pm)
250
251
MLIR. IR. add_pipeline! (opm, pass_pipeline)
251
252
MLIR. IR. run! (pm, mod)
252
253
return mod
253
254
end
254
255
256
+ # helper for debug purposes: String -> Text
257
+ function run_pass_pipeline_on_source (source, pass_pipeline; enable_verifier= true )
258
+ ctx = MLIR. IR. Context (Reactant. registry[], false )
259
+ @ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
260
+ MLIR. IR. context! (ctx) do
261
+ mod = parse (MLIR. IR. Module, source)
262
+ run_pass_pipeline! (mod, pass_pipeline; enable_verifier)
263
+ MLIR. IR. verifyall (MLIR. IR. Operation (mod); debug= true )
264
+ Text (repr (mod))
265
+ end
266
+ end
267
+
255
268
function compile_mlir (f, args; kwargs... )
256
269
ctx = MLIR. IR. Context (Reactant. registry[], false )
257
270
@ccall MLIR. API. mlir_c. RegisterDialects (ctx:: MLIR.API.MlirContext ):: Cvoid
@@ -280,15 +293,12 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
280
293
optimize isa Bool && (optimize = ifelse (optimize, :all , :none ))
281
294
282
295
if optimize === :all
296
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes], " ," ))
297
+ run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
283
298
run_pass_pipeline! (
284
299
mod,
285
300
join (
286
301
[
287
- opt_passes,
288
- " enzyme-batch" ,
289
- opt_passes,
290
- " enzyme" ,
291
- " arith-raise{stablehlo=true}" ,
292
302
" canonicalize" ,
293
303
" remove-unnecessary-enzyme-ops" ,
294
304
" enzyme-simplify-math" ,
@@ -298,28 +308,22 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
298
308
),
299
309
)
300
310
elseif optimize === :only_enzyme
311
+ run_pass_pipeline! (mod, " enzyme-batch" )
312
+ run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
301
313
run_pass_pipeline! (
302
314
mod,
303
315
join (
304
- [
305
- " enzyme-batch" ,
306
- " enzyme" ,
307
- " arith-raise{stablehlo=true}" ,
308
- " canonicalize" ,
309
- " remove-unnecessary-enzyme-ops" ,
310
- " enzyme-simplify-math" ,
311
- ],
316
+ [" canonicalize" , " remove-unnecessary-enzyme-ops" , " enzyme-simplify-math" ],
312
317
' ,' ,
313
318
),
314
319
)
315
320
elseif optimize === :after_enzyme
321
+ run_pass_pipeline! (mod, " enzyme-batch" )
322
+ run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
316
323
run_pass_pipeline! (
317
324
mod,
318
325
join (
319
326
[
320
- " enzyme-batch" ,
321
- " enzyme" ,
322
- " arith-raise{stablehlo=true}" ,
323
327
" canonicalize" ,
324
328
" remove-unnecessary-enzyme-ops" ,
325
329
" enzyme-simplify-math" ,
@@ -329,21 +333,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
329
333
),
330
334
)
331
335
elseif optimize === :before_enzyme
336
+ run_pass_pipeline! (mod, join ([opt_passes, " enzyme-batch" , opt_passes]))
337
+ run_pass_pipeline! (mod, " enzyme,arith-raise{stablehlo=true}" ; enable_verifier= false )
332
338
run_pass_pipeline! (
333
- mod,
334
- join (
335
- [
336
- opt_passes,
337
- " enzyme-batch" ,
338
- opt_passes,
339
- " enzyme" ,
340
- " arith-raise{stablehlo=true}" ,
341
- " canonicalize" ,
342
- " remove-unnecessary-enzyme-ops" ,
343
- " enzyme-simplify-math" ,
344
- ],
345
- ' ,' ,
346
- ),
339
+ mod, " canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math"
347
340
)
348
341
elseif optimize != = :none
349
342
error (" Invalid optimize option: $(Meta. quot (optimize)) " )
0 commit comments