Skip to content

Commit 05bd81f

Browse files
authored
Disable verifier in main pass manager pipeline (#269)
* Disable verifier in main pass manager pipeline * verify after pm * disable verifier between enzyme and arith-raise * cleanup
1 parent 9d6cc5d commit 05bd81f

File tree

3 files changed

+49
-38
lines changed

3 files changed

+49
-38
lines changed

src/Compiler.jl

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -244,14 +244,27 @@ const opt_passes::String = join(
244244
',',
245245
)
246246

247-
function run_pass_pipeline!(mod, pass_pipeline)
247+
function run_pass_pipeline!(mod, pass_pipeline; enable_verifier=true)
248248
pm = MLIR.IR.PassManager()
249+
MLIR.IR.enable_verifier!(pm, enable_verifier)
249250
opm = MLIR.IR.OpPassManager(pm)
250251
MLIR.IR.add_pipeline!(opm, pass_pipeline)
251252
MLIR.IR.run!(pm, mod)
252253
return mod
253254
end
254255

256+
# helper for debug purposes: String -> Text
257+
function run_pass_pipeline_on_source(source, pass_pipeline; enable_verifier=true)
258+
ctx = MLIR.IR.Context(Reactant.registry[], false)
259+
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
260+
MLIR.IR.context!(ctx) do
261+
mod = parse(MLIR.IR.Module, source)
262+
run_pass_pipeline!(mod, pass_pipeline; enable_verifier)
263+
MLIR.IR.verifyall(MLIR.IR.Operation(mod); debug=true)
264+
Text(repr(mod))
265+
end
266+
end
267+
255268
function compile_mlir(f, args; kwargs...)
256269
ctx = MLIR.IR.Context(Reactant.registry[], false)
257270
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid
@@ -280,15 +293,12 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
280293
optimize isa Bool && (optimize = ifelse(optimize, :all, :none))
281294

282295
if optimize === :all
296+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
297+
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
283298
run_pass_pipeline!(
284299
mod,
285300
join(
286301
[
287-
opt_passes,
288-
"enzyme-batch",
289-
opt_passes,
290-
"enzyme",
291-
"arith-raise{stablehlo=true}",
292302
"canonicalize",
293303
"remove-unnecessary-enzyme-ops",
294304
"enzyme-simplify-math",
@@ -298,28 +308,22 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
298308
),
299309
)
300310
elseif optimize === :only_enzyme
311+
run_pass_pipeline!(mod, "enzyme-batch")
312+
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
301313
run_pass_pipeline!(
302314
mod,
303315
join(
304-
[
305-
"enzyme-batch",
306-
"enzyme",
307-
"arith-raise{stablehlo=true}",
308-
"canonicalize",
309-
"remove-unnecessary-enzyme-ops",
310-
"enzyme-simplify-math",
311-
],
316+
["canonicalize", "remove-unnecessary-enzyme-ops", "enzyme-simplify-math"],
312317
',',
313318
),
314319
)
315320
elseif optimize === :after_enzyme
321+
run_pass_pipeline!(mod, "enzyme-batch")
322+
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
316323
run_pass_pipeline!(
317324
mod,
318325
join(
319326
[
320-
"enzyme-batch",
321-
"enzyme",
322-
"arith-raise{stablehlo=true}",
323327
"canonicalize",
324328
"remove-unnecessary-enzyme-ops",
325329
"enzyme-simplify-math",
@@ -329,21 +333,10 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
329333
),
330334
)
331335
elseif optimize === :before_enzyme
336+
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes]))
337+
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
332338
run_pass_pipeline!(
333-
mod,
334-
join(
335-
[
336-
opt_passes,
337-
"enzyme-batch",
338-
opt_passes,
339-
"enzyme",
340-
"arith-raise{stablehlo=true}",
341-
"canonicalize",
342-
"remove-unnecessary-enzyme-ops",
343-
"enzyme-simplify-math",
344-
],
345-
',',
346-
),
339+
mod, "canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math"
347340
)
348341
elseif optimize !== :none
349342
error("Invalid optimize option: $(Meta.quot(optimize))")

src/mlir/IR/IR.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,15 @@ Base.String(str::API.MlirIdentifier) = String(API.mlirIdentifierStr(str))
100100
### Utils
101101

102102
function visit(f, op)
103+
all_ok = true
103104
for region in RegionIterator(op)
104105
for block in BlockIterator(region)
105106
for op in OperationIterator(block)
106-
f(op)
107+
all_ok &= f(op)
107108
end
108109
end
109110
end
111+
return all_ok
110112
end
111113

112114
"""
@@ -115,13 +117,20 @@ end
115117
Prints the operations which could not be verified.
116118
"""
117119
function verifyall(operation::Operation; debug=false)
118-
io = IOContext(stdout, :debug => debug)
120+
io = IOBuffer()
119121
visit(operation) do op
120-
if !verify(op)
121-
show(io, op)
122+
ok = verifyall(op; debug)
123+
if !ok || !verify(op)
124+
if ok
125+
show(IOContext(io, :debug => debug), op)
126+
error(String(take!(io)))
127+
end
128+
false
129+
else
130+
true
122131
end
123132
end
124133
end
125-
verifyall(module_::IR.Module) = verifyall(Operation(module_))
134+
verifyall(module_::IR.Module; debug=false) = verifyall(Operation(module_); debug)
126135

127136
end # module IR

src/mlir/IR/Pass.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,17 @@ Base.convert(::Core.Type{API.MlirPassManager}, pass::PassManager) = pass.pass
4040
4141
Enable mlir-print-ir-after-all.
4242
"""
43-
function enable_ir_printing!(pm)
44-
API.mlirPassManagerEnableIRPrinting(pm)
43+
function enable_ir_printing!(
44+
pm;
45+
before_all=false,
46+
after_all=false,
47+
module_scope=false,
48+
after_only_on_change=false,
49+
after_only_on_failure=false,
50+
)
51+
API.mlirPassManagerEnableIRPrinting(
52+
pm, before_all, after_all, module_scope, after_only_on_change, after_only_on_failure
53+
)
4554
return pm
4655
end
4756

0 commit comments

Comments
 (0)