diff --git a/Project.toml b/Project.toml index 71d0b6e670..b7830203dc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Reactant" uuid = "3c362404-f566-11ee-1572-e11a4b42c853" authors = ["William Moses ", "Valentin Churavy ", "Sergio Sánchez Ramírez ", "Paul Berg ", "Avik Pal ", "Mosè Giordano "] -version = "0.2.166" +version = "0.2.167" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/CompileOptions.jl b/src/CompileOptions.jl index ad3aae51a2..e8cac78be6 100644 --- a/src/CompileOptions.jl +++ b/src/CompileOptions.jl @@ -157,6 +157,8 @@ Fine-grained control over the compilation options for the Reactant compiler. `false` by default. - `disable_licm_optimization_passes`: Disables the Loop Invariant Code Motion (LICM) optimization passes. This is `false` by default. + - `disable_auto_batching_passes`: Disables the auto-batching optimization passes. This + is `false` by default. """ struct CompileOptions optimization_passes::Union{Symbol,String} @@ -185,6 +187,7 @@ struct CompileOptions disable_scatter_gather_optimization_passes::Bool disable_pad_optimization_passes::Bool disable_licm_optimization_passes::Bool + disable_auto_batching_passes::Bool end function CompileOptions(; @@ -208,6 +211,7 @@ function CompileOptions(; disable_scatter_gather_optimization_passes::Bool=false, disable_pad_optimization_passes::Bool=false, disable_licm_optimization_passes::Bool=false, + disable_auto_batching_passes::Bool=false, ) optimization_passes isa Bool && (optimization_passes = ifelse(optimization_passes, :all, :none)) @@ -256,6 +260,7 @@ function CompileOptions(; disable_scatter_gather_optimization_passes, disable_pad_optimization_passes, disable_licm_optimization_passes, + disable_auto_batching_passes, ) end @@ -297,6 +302,7 @@ function __compile_options_with_reversed_propagation(compile_options::CompileOpt compile_options.disable_scatter_gather_optimization_passes, compile_options.disable_pad_optimization_passes, compile_options.disable_licm_optimization_passes, + compile_options.disable_auto_batching_passes, ) end diff --git a/src/Compiler.jl b/src/Compiler.jl index 942b2bdfb8..1226b7757d 100644 --- a/src/Compiler.jl +++ b/src/Compiler.jl @@ -903,14 +903,38 @@ function optimization_passes( "self_add_to_convolution_like($(Int(backend == "tpu")))", "self_mul_to_convolution_like($(Int(backend == "tpu")))", "subtract_multiply_const_to_add_mul_const", - "concat_insert_dim_dot_general", - "concat_insert_dim_gather", - "concat_insert_dim_iota", - "concat_insert_dim_reduce", - "concat_insert_dim_sort", - "concat_insert_dim_reduce_window", + "trivial_reduce_window_to_reduce_op", + "dot_general_add_distributive_simplify", + "dot_general_subtract_distributive_simplify", ] + if !compile_options.disable_auto_batching_passes + append!( + transform_passes_list, + [ + "add_reduce_slice_fusion", + "mul_reduce_slice_fusion", + "min_reduce_slice_fusion", + "max_reduce_slice_fusion", + "concat_insert_dim_dot_general", + "concat_insert_dim_gather", + "concat_insert_dim_iota", + "concat_insert_dim_reduce", + "concat_insert_dim_sort", + "concat_insert_dim_reduce_window", + "dot_general_slice_to_batch", + "gather_slice_to_batch", + "iota_slice_to_batch", + "reduce_slice_to_batch", + "sort_slice_to_batch", + "transpose_slice_to_batch", + "broadcastindim_slice_to_batch", + "reducewindow_slice_to_batch", + "elementwise_slice_to_batch", + ], + ) + end + if !compile_options.disable_licm_optimization_passes append!( transform_passes_list, @@ -1047,11 +1071,6 @@ function optimization_passes( "const_prop_through_barrier<16>", "concat_const_prop<1>($max_constant_threshold)", "dynamic_update_slice_const_prop($max_constant_threshold)", - "add_reduce_slice_fusion", - "mul_reduce_slice_fusion", - "min_reduce_slice_fusion", - "max_reduce_slice_fusion", - "trivial_reduce_window_to_reduce_op", ], ) diff --git a/test/ops.jl b/test/ops.jl index 6be96585d8..104316bd21 100644 --- a/test/ops.jl +++ b/test/ops.jl @@ -790,13 +790,13 @@ end end @testset "acos" begin - x = Reactant.to_rarray([-1.0, 0.0, 1.0]) - @test acos.(Array(x)) ≈ @jit Ops.acos(x) + x = Reactant.to_rarray(Float32[-1.0, 0.0, 1.0]) + @test acos.(Array(x)) ≈ @jit(Ops.acos(x)) broken = RunningOnTPU end @testset "acosh" begin - x = Reactant.to_rarray([1.0, 10.0]) - @test acosh.(Array(x)) ≈ @jit Ops.acosh(x) + x = Reactant.to_rarray(Float32[1.0, 10.0]) + @test acosh.(Array(x)) ≈ @jit(Ops.acosh(x)) broken = RunningOnTPU end @testset "asin" begin