Skip to content

Commit 79752b7

Browse files
authored
feat: enable new auto-batching passes (#1690)
* feat: enable new auto-batching passes
* fix: move passes
* fix: bad rebase
* test: run acos/acosh with fp32
* test: mark acos and acosh tests broken on tpu
1 parent ee012b7 commit 79752b7

File tree

4 files changed

+41
-16
lines changed

4 files changed

+41
-16
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Reactant"
22
uuid = "3c362404-f566-11ee-1572-e11a4b42c853"
33
authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>", "Mosè Giordano <[email protected]>"]
4-
version = "0.2.166"
4+
version = "0.2.167"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/CompileOptions.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,8 @@ Fine-grained control over the compilation options for the Reactant compiler.
157157
`false` by default.
158158
- `disable_licm_optimization_passes`: Disables the Loop Invariant Code Motion (LICM)
159159
optimization passes. This is `false` by default.
160+
- `disable_auto_batching_passes`: Disables the auto-batching optimization passes. This
161+
is `false` by default.
160162
"""
161163
struct CompileOptions
162164
optimization_passes::Union{Symbol,String}
@@ -185,6 +187,7 @@ struct CompileOptions
185187
disable_scatter_gather_optimization_passes::Bool
186188
disable_pad_optimization_passes::Bool
187189
disable_licm_optimization_passes::Bool
190+
disable_auto_batching_passes::Bool
188191
end
189192

190193
function CompileOptions(;
@@ -208,6 +211,7 @@ function CompileOptions(;
208211
disable_scatter_gather_optimization_passes::Bool=false,
209212
disable_pad_optimization_passes::Bool=false,
210213
disable_licm_optimization_passes::Bool=false,
214+
disable_auto_batching_passes::Bool=false,
211215
)
212216
optimization_passes isa Bool &&
213217
(optimization_passes = ifelse(optimization_passes, :all, :none))
@@ -256,6 +260,7 @@ function CompileOptions(;
256260
disable_scatter_gather_optimization_passes,
257261
disable_pad_optimization_passes,
258262
disable_licm_optimization_passes,
263+
disable_auto_batching_passes,
259264
)
260265
end
261266

@@ -297,6 +302,7 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt
297302
compile_options.disable_scatter_gather_optimization_passes,
298303
compile_options.disable_pad_optimization_passes,
299304
compile_options.disable_licm_optimization_passes,
305+
compile_options.disable_auto_batching_passes,
300306
)
301307
end
302308

src/Compiler.jl

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -903,14 +903,38 @@ function optimization_passes(
903903
"self_add_to_convolution_like($(Int(backend == "tpu")))",
904904
"self_mul_to_convolution_like($(Int(backend == "tpu")))",
905905
"subtract_multiply_const_to_add_mul_const",
906-
"concat_insert_dim_dot_general",
907-
"concat_insert_dim_gather",
908-
"concat_insert_dim_iota",
909-
"concat_insert_dim_reduce",
910-
"concat_insert_dim_sort",
911-
"concat_insert_dim_reduce_window",
906+
"trivial_reduce_window_to_reduce_op",
907+
"dot_general_add_distributive_simplify",
908+
"dot_general_subtract_distributive_simplify",
912909
]
913910

911+
if !compile_options.disable_auto_batching_passes
912+
append!(
913+
transform_passes_list,
914+
[
915+
"add_reduce_slice_fusion",
916+
"mul_reduce_slice_fusion",
917+
"min_reduce_slice_fusion",
918+
"max_reduce_slice_fusion",
919+
"concat_insert_dim_dot_general",
920+
"concat_insert_dim_gather",
921+
"concat_insert_dim_iota",
922+
"concat_insert_dim_reduce",
923+
"concat_insert_dim_sort",
924+
"concat_insert_dim_reduce_window",
925+
"dot_general_slice_to_batch",
926+
"gather_slice_to_batch",
927+
"iota_slice_to_batch",
928+
"reduce_slice_to_batch",
929+
"sort_slice_to_batch",
930+
"transpose_slice_to_batch",
931+
"broadcastindim_slice_to_batch",
932+
"reducewindow_slice_to_batch",
933+
"elementwise_slice_to_batch",
934+
],
935+
)
936+
end
937+
914938
if !compile_options.disable_licm_optimization_passes
915939
append!(
916940
transform_passes_list,
@@ -1047,11 +1071,6 @@ function optimization_passes(
10471071
"const_prop_through_barrier<16>",
10481072
"concat_const_prop<1>($max_constant_threshold)",
10491073
"dynamic_update_slice_const_prop($max_constant_threshold)",
1050-
"add_reduce_slice_fusion",
1051-
"mul_reduce_slice_fusion",
1052-
"min_reduce_slice_fusion",
1053-
"max_reduce_slice_fusion",
1054-
"trivial_reduce_window_to_reduce_op",
10551074
],
10561075
)
10571076

test/ops.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -790,13 +790,13 @@ end
790790
end
791791

792792
@testset "acos" begin
793-
x = Reactant.to_rarray([-1.0, 0.0, 1.0])
794-
@test acos.(Array(x)) ≈ @jit Ops.acos(x)
793+
x = Reactant.to_rarray(Float32[-1.0, 0.0, 1.0])
794+
@test acos.(Array(x)) ≈ @jit(Ops.acos(x)) broken = RunningOnTPU
795795
end
796796

797797
@testset "acosh" begin
798-
x = Reactant.to_rarray([1.0, 10.0])
799-
@test acosh.(Array(x)) ≈ @jit Ops.acosh(x)
798+
x = Reactant.to_rarray(Float32[1.0, 10.0])
799+
@test acosh.(Array(x)) ≈ @jit(Ops.acosh(x)) broken = RunningOnTPU
800800
end
801801

802802
@testset "asin" begin

0 commit comments

Comments
 (0)