Skip to content

Commit 055028b

Browse files
committed
feat: enable new auto-batching passes
1 parent b6ec9a1 commit 055028b

File tree

3 files changed

+35
-8
lines changed

3 files changed

+35
-8
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.164"
4+
version = "0.2.165"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
@@ -103,7 +103,7 @@ PythonCall = "0.9.25"
103103
Random = "1.10"
104104
Random123 = "1.7"
105105
ReactantCore = "0.1.16"
106-
Reactant_jll = "0.0.240"
106+
Reactant_jll = "0.0.241"
107107
ScopedValues = "1.3.0"
108108
Scratch = "1.2"
109109
Sockets = "1.10"

src/CompileOptions.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ Fine-grained control over the compilation options for the Reactant compiler.
157157
`false` by default.
158158
- `disable_licm_optimization_passes`: Disables the Loop Invariant Code Motion (LICM)
159159
optimization passes. This is `false` by default.
160+
- `disable_auto_batching_passes`: Disables the auto-batching optimization passes. This
161+
is `false` by default.
160162
"""
161163
struct CompileOptions
162164
optimization_passes::Union{Symbol,String}
@@ -185,6 +187,7 @@ struct CompileOptions
185187
disable_scatter_gather_optimization_passes::Bool
186188
disable_pad_optimization_passes::Bool
187189
disable_licm_optimization_passes::Bool
190+
disable_auto_batching_passes::Bool
188191
end
189192

190193
function CompileOptions(;
@@ -208,6 +211,7 @@ function CompileOptions(;
208211
disable_scatter_gather_optimization_passes::Bool=false,
209212
disable_pad_optimization_passes::Bool=false,
210213
disable_licm_optimization_passes::Bool=false,
214+
disable_auto_batching_passes::Bool=false,
211215
)
212216
optimization_passes isa Bool &&
213217
(optimization_passes = ifelse(optimization_passes, :all, :none))
@@ -256,6 +260,7 @@ function CompileOptions(;
256260
disable_scatter_gather_optimization_passes,
257261
disable_pad_optimization_passes,
258262
disable_licm_optimization_passes,
263+
disable_auto_batching_passes,
259264
)
260265
end
261266

src/Compiler.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -903,14 +903,36 @@ function optimization_passes(
903903
"self_add_to_convolution_like($(Int(backend == "tpu")))",
904904
"self_mul_to_convolution_like($(Int(backend == "tpu")))",
905905
"subtract_multiply_const_to_add_mul_const",
906-
"concat_insert_dim_dot_general",
907-
"concat_insert_dim_gather",
908-
"concat_insert_dim_iota",
909-
"concat_insert_dim_reduce",
910-
"concat_insert_dim_sort",
911-
"concat_insert_dim_reduce_window",
912906
]
913907

908+
if !compile_options.disable_auto_batching_passes
909+
append!(
910+
transform_passes_list,
911+
[
912+
"trivial_reduce_window_to_reduce_op",
913+
"add_reduce_slice_fusion",
914+
"mul_reduce_slice_fusion",
915+
"min_reduce_slice_fusion",
916+
"max_reduce_slice_fusion",
917+
"concat_insert_dim_dot_general",
918+
"concat_insert_dim_gather",
919+
"concat_insert_dim_iota",
920+
"concat_insert_dim_reduce",
921+
"concat_insert_dim_sort",
922+
"concat_insert_dim_reduce_window",
923+
"dot_general_slice_to_batch",
924+
"gather_slice_to_batch",
925+
"iota_slice_to_batch",
926+
"reduce_slice_to_batch",
927+
"sort_slice_to_batch",
928+
"transpose_slice_to_batch",
929+
"broadcastindim_slice_to_batch",
930+
"reducewindow_slice_to_batch",
931+
"elementwise_slice_to_batch",
932+
],
933+
)
934+
end
935+
914936
if !compile_options.disable_licm_optimization_passes
915937
append!(
916938
transform_passes_list,

0 commit comments

Comments
 (0)