From 82c416802430a5a10de6b3ec8f4960ba6098553d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Sep 2025 10:28:07 -0400 Subject: [PATCH 1/5] feat: enable new auto-batching passes --- Project.toml | 4 ++++ src/CompileOptions.jl | 6 ++++++ src/Compiler.jl | 34 ++++++++++++++++++++++++++++------ 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/Project.toml b/Project.toml index 71d0b6e670..3707893f48 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,11 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "] +<<<<<<< HEAD version = "0.2.166" +======= +version = "0.2.165" +>>>>>>> 6bc46b1dc (feat: enable new auto-batching passes) [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index ad3aae51a2..e8cac78be6 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -157,6 +157,8 @@ Fine-grained control over the compilation options for the Reactant compiler. `false` by default. - `disable_licm_optimization_passes`: Disables the Loop Invariant Code Motion (LICM) optimization passes. This is `false` by default. + - `disable_auto_batching_passes`: Disables the auto-batching optimization passes. This + is `false` by default. """ struct CompileOptions optimization_passes::Union{Symbol,String} @@ -185,6 +187,7 @@ struct CompileOptions disable_scatter_gather_optimization_passes::Bool disable_pad_optimization_passes::Bool disable_licm_optimization_passes::Bool + disable_auto_batching_passes::Bool end function CompileOptions(; @@ -208,6 +211,7 @@ function CompileOptions(; disable_scatter_gather_optimization_passes::Bool=false, disable_pad_optimization_passes::Bool=false, disable_licm_optimization_passes::Bool=false, + disable_auto_batching_passes::Bool=false, ) optimization_passes isa Bool && (optimization_passes = ifelse(optimization_passes, :all, :none)) @@ -256,6 +260,7 @@ function CompileOptions(; disable_scatter_gather_optimization_passes, disable_pad_optimization_passes, disable_licm_optimization_passes, + disable_auto_batching_passes, ) end @@ -297,6 +302,7 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt compile_options.disable_scatter_gather_optimization_passes, compile_options.disable_pad_optimization_passes, compile_options.disable_licm_optimization_passes, + compile_options.disable_auto_batching_passes, ) end diff --git a/src/Compiler.jl b/src/Compiler.jl index 942b2bdfb8..970e2d4bf3 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -903,14 +903,36 @@ function optimization_passes( "self_add_to_convolution_like($(Int(backend == "tpu")))", "self_mul_to_convolution_like($(Int(backend == "tpu")))", "subtract_multiply_const_to_add_mul_const", - "concat_insert_dim_dot_general", - "concat_insert_dim_gather", - "concat_insert_dim_iota", - "concat_insert_dim_reduce", - "concat_insert_dim_sort", - "concat_insert_dim_reduce_window", ] + if !compile_options.disable_auto_batching_passes + append!( + transform_passes_list, + [ + "trivial_reduce_window_to_reduce_op", + "add_reduce_slice_fusion", + "mul_reduce_slice_fusion", + "min_reduce_slice_fusion", + "max_reduce_slice_fusion", + "concat_insert_dim_dot_general", + "concat_insert_dim_gather", + "concat_insert_dim_iota", + "concat_insert_dim_reduce", + "concat_insert_dim_sort", + "concat_insert_dim_reduce_window", + "dot_general_slice_to_batch", + "gather_slice_to_batch", + "iota_slice_to_batch", + "reduce_slice_to_batch", + "sort_slice_to_batch", + "transpose_slice_to_batch", + "broadcastindim_slice_to_batch", + "reducewindow_slice_to_batch", + "elementwise_slice_to_batch", + ], + ) + end + if !compile_options.disable_licm_optimization_passes append!( transform_passes_list, From dde5a3ed15e36313e7bccd9b4217e0d43616941e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 20 Sep 2025 11:35:57 -0400 Subject: [PATCH 2/5] fix: move passes --- src/Compiler.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/Compiler.jl b/src/Compiler.jl index 970e2d4bf3..1226b7757d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -903,13 +903,15 @@ function optimization_passes( "self_add_to_convolution_like($(Int(backend == "tpu")))", "self_mul_to_convolution_like($(Int(backend == "tpu")))", "subtract_multiply_const_to_add_mul_const", + "trivial_reduce_window_to_reduce_op", + "dot_general_add_distributive_simplify", + "dot_general_subtract_distributive_simplify", ] if !compile_options.disable_auto_batching_passes append!( transform_passes_list, [ - "trivial_reduce_window_to_reduce_op", "add_reduce_slice_fusion", "mul_reduce_slice_fusion", "min_reduce_slice_fusion", @@ -1069,11 +1071,6 @@ function optimization_passes( "const_prop_through_barrier<16>", "concat_const_prop<1>($max_constant_threshold)", "dynamic_update_slice_const_prop($max_constant_threshold)", - "add_reduce_slice_fusion", - "mul_reduce_slice_fusion", - "min_reduce_slice_fusion", - "max_reduce_slice_fusion", - "trivial_reduce_window_to_reduce_op", ], ) From c3b11000eecaea72203ded079afd09957459c105 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Sep 2025 08:12:49 -0400 Subject: [PATCH 3/5] fix: bad rebase --- Project.toml | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/Project.toml b/Project.toml index 3707893f48..b7830203dc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "] -<<<<<<< HEAD -version = "0.2.166" -======= -version = "0.2.165" ->>>>>>> 6bc46b1dc (feat: enable new auto-batching passes) +version = "0.2.167" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" From 70aa85723c6360cd8ba75b3ff80972c87d1b74bb Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Sep 2025 10:15:39 -0500 Subject: [PATCH 4/5] test: run acos/acosh with fp32 --- test/ops.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ops.jl b/test/ops.jl index 6be96585d8..8a4e086aa4 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -790,12 +790,12 @@ end end @testset "acos" begin - x = Reactant.to_rarray([-1.0, 0.0, 1.0]) + x = Reactant.to_rarray(Float32[-1.0, 0.0, 1.0]) @test acos.(Array(x)) ≈ @jit Ops.acos(x) end @testset "acosh" begin - x = Reactant.to_rarray([1.0, 10.0]) + x = Reactant.to_rarray(Float32[1.0, 10.0]) @test acosh.(Array(x)) ≈ @jit Ops.acosh(x) end From 7588c647cc12c79f3708503fef36e05242cbd25e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Wed, 24 Sep 2025 14:06:33 -0500 Subject: [PATCH 5/5] test: mark acos and acosh tests broken on tpu --- test/ops.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/ops.jl b/test/ops.jl index 8a4e086aa4..104316bd21 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -791,12 +791,12 @@ end @testset "acos" begin x = Reactant.to_rarray(Float32[-1.0, 0.0, 1.0]) - @test acos.(Array(x)) ≈ @jit Ops.acos(x) + @test acos.(Array(x)) ≈ @jit(Ops.acos(x)) broken = RunningOnTPU end @testset "acosh" begin x = Reactant.to_rarray(Float32[1.0, 10.0]) - @test acosh.(Array(x)) ≈ @jit Ops.acosh(x) + @test acosh.(Array(x)) ≈ @jit(Ops.acosh(x)) broken = RunningOnTPU end @testset "asin" begin