Skip to content

Commit 5b58b64

Browse files
authored
Add option to raise before ad/batching (#1204)
1 parent ac07d5f commit 5b58b64

File tree

1 file changed

+64
-2
lines changed

1 file changed

+64
-2
lines changed

src/Compiler.jl

Lines changed: 64 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1115,6 +1115,7 @@ function compile_mlir!(
11151115
assert_nonallocating::Bool=false,
11161116
backend="gpu",
11171117
raise::Union{Bool,String}=false,
1118+
raise_first::Bool=false,
11181119
# TODO: allow more fine-grained options to control the donation of specific arguments
11191120
donated_args::Symbol=:auto, # :auto | :none
11201121
optimize_then_pad::Bool=true,
@@ -1267,6 +1268,21 @@ function compile_mlir!(
12671268
run_pass_pipeline!(
12681269
mod,
12691270
join(
1271+
raise_first ?
1272+
[
1273+
opt_passes,
1274+
kern,
1275+
raise_passes,
1276+
"enzyme-batch",
1277+
opt_passes2,
1278+
enzyme_pass,
1279+
opt_passes2,
1280+
"canonicalize",
1281+
"remove-unnecessary-enzyme-ops",
1282+
"enzyme-simplify-math",
1283+
opt_passes2,
1284+
jit,
1285+
] :
12701286
[
12711287
opt_passes,
12721288
"enzyme-batch",
@@ -1289,6 +1305,10 @@ function compile_mlir!(
12891305
run_pass_pipeline!(
12901306
mod,
12911307
join(
1308+
raise_first ?
1309+
[
1310+
opt_passes,
1311+
] :
12921312
[
12931313
opt_passes,
12941314
"enzyme-batch",
@@ -1308,6 +1328,20 @@ function compile_mlir!(
13081328
run_pass_pipeline!(
13091329
mod,
13101330
join(
1331+
raise_first ?
1332+
[
1333+
opt_passes,
1334+
kern,
1335+
raise_passes,
1336+
"enzyme-batch",
1337+
opt_passes2,
1338+
enzyme_pass,
1339+
opt_passes2,
1340+
"canonicalize",
1341+
"remove-unnecessary-enzyme-ops",
1342+
"enzyme-simplify-math",
1343+
opt_passes2,
1344+
] :
13111345
[
13121346
opt_passes,
13131347
"enzyme-batch",
@@ -1329,7 +1363,9 @@ function compile_mlir!(
13291363
run_pass_pipeline!(
13301364
mod,
13311365
join(
1332-
[
1366+
raise_first ? [
1367+
opt_passes
1368+
] : [
13331369
opt_passes,
13341370
"enzyme-batch",
13351371
opt_passes2,
@@ -1383,7 +1419,18 @@ function compile_mlir!(
13831419
run_pass_pipeline!(
13841420
mod,
13851421
join(
1422+
raise_first ?
13861423
[
1424+
kern,
1425+
raise_passes,
1426+
"enzyme-batch",
1427+
enzyme_pass,
1428+
"canonicalize",
1429+
"remove-unnecessary-enzyme-ops",
1430+
"enzyme-simplify-math",
1431+
opt_passes2,
1432+
jit,
1433+
] : [
13871434
"enzyme-batch",
13881435
enzyme_pass,
13891436
"canonicalize",
@@ -1401,7 +1448,17 @@ function compile_mlir!(
14011448
elseif optimize === :before_enzyme
14021449
run_pass_pipeline!(
14031450
mod,
1404-
join(
1451+
join( raise_first ?
1452+
[
1453+
opt_passes,
1454+
kern,
1455+
raise_passes,
1456+
"enzyme-batch",
1457+
opt_passes2,
1458+
enzyme_pass,
1459+
"canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math",
1460+
jit,
1461+
] :
14051462
[
14061463
opt_passes,
14071464
"enzyme-batch",
@@ -1823,6 +1880,7 @@ macro code_hlo(args...)
18231880
:no_nan => false,
18241881
:client => nothing,
18251882
:raise => false,
1883+
:raise_first => false,
18261884
:shardy_passes => :(:to_mhlo_shardings),
18271885
:assert_nonallocating => false,
18281886
:donated_args => :(:auto),
@@ -1857,6 +1915,7 @@ macro code_mhlo(args...)
18571915
:no_nan => false,
18581916
:client => nothing,
18591917
:raise => false,
1918+
:raise_first => false,
18601919
:shardy_passes => :(:to_mhlo_shardings),
18611920
:assert_nonallocating => false,
18621921
:donated_args => :(:auto),
@@ -1891,6 +1950,7 @@ macro code_xla(args...)
18911950
:no_nan => false,
18921951
:client => nothing,
18931952
:raise => false,
1953+
:raise_first => false,
18941954
:shardy_passes => :(:to_mhlo_shardings),
18951955
:assert_nonallocating => false,
18961956
:donated_args => :(:auto),
@@ -1924,6 +1984,7 @@ macro compile(args...)
19241984
:no_nan => false,
19251985
:client => nothing,
19261986
:raise => false,
1987+
:raise_first => false,
19271988
:shardy_passes => :(:to_mhlo_shardings),
19281989
:assert_nonallocating => false,
19291990
:serializable => false,
@@ -1948,6 +2009,7 @@ macro jit(args...)
19482009
:no_nan => false,
19492010
:client => nothing,
19502011
:raise => false,
2012+
:raise_first => false,
19512013
:shardy_passes => :(:to_mhlo_shardings),
19522014
:assert_nonallocating => false,
19532015
:donated_args => :(:auto),

0 commit comments

Comments
 (0)