Conversation

@avik-pal (Collaborator) commented Sep 20, 2025

@avik-pal requested a review from wsmoses on September 20, 2025 14:28
@avik-pal force-pushed the ap/autobatching_passes branch from 055028b to afb256b on September 20, 2025 14:35
@avik-pal (Collaborator, Author) commented:

Needs some upstream fixes before merging.

@avik-pal (Collaborator, Author) commented Sep 20, 2025:

module @reactant_gradient attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func private @"Const{typeof(sum_squares)}(Main.sum_squares)_autodiff"(%arg0: tensor<3xf32>) -> (tensor<f32>, tensor<3xf32>) attributes {enzymexla.memory_effects = []} {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %1 = stablehlo.convert %cst : tensor<f32>
    %2 = stablehlo.abs %0 : tensor<3xf32>
    %3 = stablehlo.multiply %2, %2 : tensor<3xf32>
    %4 = stablehlo.reduce(%3 init: %1) applies stablehlo.add across dimensions = [0] : (tensor<3xf32>, tensor<f32>) -> tensor<f32>
    %5 = stablehlo.transpose %0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    return %4, %5 : tensor<f32>, tensor<3xf32>
  }
  func.func @main(%arg0: tensor<3xf32> {tf.aliasing_output = 3 : i32}) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<3xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
    %cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %c = stablehlo.constant dense<0> : tensor<i64>
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %1 = stablehlo.convert %c : (tensor<i64>) -> tensor<f32>
    %2 = stablehlo.convert %cst : tensor<f32>
    %3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<1xf32>
    %4 = stablehlo.pad %3, %1, low = [0], high = [2], interior = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<3xf32>
    %5 = stablehlo.transpose %4, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %6 = stablehlo.reshape %5 : (tensor<3xf32>) -> tensor<3xf32>
    %7 = stablehlo.transpose %6, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %8 = stablehlo.pad %3, %1, low = [1], high = [1], interior = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<3xf32>
    %9 = stablehlo.transpose %8, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %10 = stablehlo.reshape %9 : (tensor<3xf32>) -> tensor<3xf32>
    %11 = stablehlo.transpose %10, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %12 = stablehlo.pad %3, %1, low = [2], high = [0], interior = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<3xf32>
    %13 = stablehlo.transpose %12, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %14 = stablehlo.reshape %13 : (tensor<3xf32>) -> tensor<3xf32>
    %15 = stablehlo.transpose %14, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %16 = stablehlo.transpose %7, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %17 = stablehlo.reshape %16 : (tensor<3xf32>) -> tensor<3x1xf32>
    %18 = stablehlo.transpose %17, dims = [1, 0] : (tensor<3x1xf32>) -> tensor<1x3xf32>
    %19 = stablehlo.transpose %11, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %20 = stablehlo.reshape %19 : (tensor<3xf32>) -> tensor<3x1xf32>
    %21 = stablehlo.transpose %20, dims = [1, 0] : (tensor<3x1xf32>) -> tensor<1x3xf32>
    %22 = stablehlo.transpose %15, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %23 = stablehlo.reshape %22 : (tensor<3xf32>) -> tensor<3x1xf32>
    %24 = stablehlo.transpose %23, dims = [1, 0] : (tensor<3x1xf32>) -> tensor<1x3xf32>
    %25 = stablehlo.concatenate %18, %21, %24, dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
    %26 = stablehlo.transpose %0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %27 = stablehlo.transpose %25, dims = [0, 1] : (tensor<3x3xf32>) -> tensor<3x3xf32>
    %28:2 = enzyme.fwddiff @"Const{typeof(sum_squares)}(Main.sum_squares)_autodiff"(%26, %27) {activity = [#enzyme<activity enzyme_dup>], ret_activity = [#enzyme<activity enzyme_dupnoneed>, #enzyme<activity enzyme_const>], width = 3 : i64} : (tensor<3xf32>, tensor<3x3xf32>) -> (tensor<3xf32>, tensor<3xf32>)
    %29 = stablehlo.transpose %28#0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %30 = stablehlo.slice %29 [0:1] : (tensor<3xf32>) -> tensor<1xf32>
    %31 = stablehlo.transpose %30, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %32 = stablehlo.reshape %31 : (tensor<1xf32>) -> tensor<f32>
    %33 = stablehlo.transpose %32, dims = [] : (tensor<f32>) -> tensor<f32>
    %34 = stablehlo.slice %29 [1:2] : (tensor<3xf32>) -> tensor<1xf32>
    %35 = stablehlo.transpose %34, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %36 = stablehlo.reshape %35 : (tensor<1xf32>) -> tensor<f32>
    %37 = stablehlo.transpose %36, dims = [] : (tensor<f32>) -> tensor<f32>
    %38 = stablehlo.slice %29 [2:3] : (tensor<3xf32>) -> tensor<1xf32>
    %39 = stablehlo.transpose %38, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
    %40 = stablehlo.reshape %39 : (tensor<1xf32>) -> tensor<f32>
    %41 = stablehlo.transpose %40, dims = [] : (tensor<f32>) -> tensor<f32>
    %42 = stablehlo.transpose %28#1, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    %43 = stablehlo.transpose %42, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
    return %33, %37, %41, %43 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<3xf32>
  }
}
julia> @code_hlo optimize="inline,enzyme-batch,inline,$(Reactant.Compiler.enzyme_pass),inline" Enzyme.gradient(Forward, sum_squares, x)
loc("abs/abs"("/mnt/software/lux/Reactant.jl/src/TracedRNumber.jl":505:0)): error: 'stablehlo.compare' op all non-scalar operands/results must have the same shape and base type

@avik-pal (Collaborator, Author) commented:

module @reactant_f_gener... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<6x2xf32> {tf.aliasing_output = 1 : i32}, %arg1: tensor<2x4xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<4xf32>, tensor<6x2xf32>, tensor<2x4xf32>) attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<6x2xf32>) -> tensor<2x6xf32>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<2x4xf32>) -> tensor<4x2xf32>
    %2 = stablehlo.slice %0 [0:2, 0:1] : (tensor<2x6xf32>) -> tensor<2x1xf32>
    %3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
    %4 = stablehlo.reshape %3 : (tensor<1x2xf32>) -> tensor<2xf32>
    %5 = stablehlo.transpose %4, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %6 = stablehlo.transpose %5, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %7 = stablehlo.reshape %6 : (tensor<2xf32>) -> tensor<1x2xf32>
    %8 = stablehlo.transpose %7, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
    %9 = stablehlo.convert %8 : tensor<2x1xf32>
    %10 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %11 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %12 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %13 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %14 = stablehlo.dot_general %11, %13, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
    %15 = stablehlo.transpose %14, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
    %16 = stablehlo.reshape %15 : (tensor<1x4xf32>) -> tensor<4xf32>
    %17 = stablehlo.transpose %16, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %18 = stablehlo.slice %0 [0:2, 1:2] : (tensor<2x6xf32>) -> tensor<2x1xf32>
    %19 = stablehlo.transpose %18, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
    %20 = stablehlo.reshape %19 : (tensor<1x2xf32>) -> tensor<2xf32>
    %21 = stablehlo.transpose %20, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %22 = stablehlo.transpose %21, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %23 = stablehlo.reshape %22 : (tensor<2xf32>) -> tensor<1x2xf32>
    %24 = stablehlo.transpose %23, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
    %25 = stablehlo.convert %24 : tensor<2x1xf32>
    %26 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %27 = stablehlo.broadcast_in_dim %26, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %29 = stablehlo.broadcast_in_dim %28, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %30 = stablehlo.dot_general %27, %29, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
    %31 = stablehlo.transpose %30, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
    %32 = stablehlo.reshape %31 : (tensor<1x4xf32>) -> tensor<4xf32>
    %33 = stablehlo.transpose %32, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %34 = stablehlo.slice %0 [0:2, 2:3] : (tensor<2x6xf32>) -> tensor<2x1xf32>
    %35 = stablehlo.transpose %34, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
    %36 = stablehlo.reshape %35 : (tensor<1x2xf32>) -> tensor<2xf32>
    %37 = stablehlo.transpose %36, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %38 = stablehlo.transpose %37, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %39 = stablehlo.reshape %38 : (tensor<2xf32>) -> tensor<1x2xf32>
    %40 = stablehlo.transpose %39, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
    %41 = stablehlo.convert %40 : tensor<2x1xf32>
    %42 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %43 = stablehlo.broadcast_in_dim %42, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %44 = stablehlo.broadcast_in_dim %41, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %45 = stablehlo.broadcast_in_dim %44, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %46 = stablehlo.dot_general %43, %45, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
    %47 = stablehlo.transpose %46, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
    %48 = stablehlo.reshape %47 : (tensor<1x4xf32>) -> tensor<4xf32>
    %49 = stablehlo.transpose %48, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %50 = stablehlo.slice %0 [0:2, 3:4] : (tensor<2x6xf32>) -> tensor<2x1xf32>
    %51 = stablehlo.transpose %50, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
    %52 = stablehlo.reshape %51 : (tensor<1x2xf32>) -> tensor<2xf32>
    %53 = stablehlo.transpose %52, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %54 = stablehlo.transpose %53, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %55 = stablehlo.reshape %54 : (tensor<2xf32>) -> tensor<1x2xf32>
    %56 = stablehlo.transpose %55, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
    %57 = stablehlo.convert %56 : tensor<2x1xf32>
    %58 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %59 = stablehlo.broadcast_in_dim %58, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %60 = stablehlo.broadcast_in_dim %57, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %61 = stablehlo.broadcast_in_dim %60, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %62 = stablehlo.dot_general %59, %61, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
    %63 = stablehlo.transpose %62, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
    %64 = stablehlo.reshape %63 : (tensor<1x4xf32>) -> tensor<4xf32>
    %65 = stablehlo.transpose %64, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %66 = stablehlo.slice %0 [0:2, 4:5] : (tensor<2x6xf32>) -> tensor<2x1xf32>
    %67 = stablehlo.transpose %66, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
    %68 = stablehlo.reshape %67 : (tensor<1x2xf32>) -> tensor<2xf32>
    %69 = stablehlo.transpose %68, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %70 = stablehlo.transpose %69, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %71 = stablehlo.reshape %70 : (tensor<2xf32>) -> tensor<1x2xf32>
    %72 = stablehlo.transpose %71, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
    %73 = stablehlo.convert %72 : tensor<2x1xf32>
    %74 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %75 = stablehlo.broadcast_in_dim %74, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %76 = stablehlo.broadcast_in_dim %73, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %77 = stablehlo.broadcast_in_dim %76, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %78 = stablehlo.dot_general %75, %77, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
    %79 = stablehlo.transpose %78, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
    %80 = stablehlo.reshape %79 : (tensor<1x4xf32>) -> tensor<4xf32>
    %81 = stablehlo.transpose %80, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %82 = stablehlo.slice %0 [0:2, 5:6] : (tensor<2x6xf32>) -> tensor<2x1xf32>
    %83 = stablehlo.transpose %82, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
    %84 = stablehlo.reshape %83 : (tensor<1x2xf32>) -> tensor<2xf32>
    %85 = stablehlo.transpose %84, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %86 = stablehlo.transpose %85, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %87 = stablehlo.reshape %86 : (tensor<2xf32>) -> tensor<1x2xf32>
    %88 = stablehlo.transpose %87, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
    %89 = stablehlo.convert %88 : tensor<2x1xf32>
    %90 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %91 = stablehlo.broadcast_in_dim %90, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
    %92 = stablehlo.broadcast_in_dim %89, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %93 = stablehlo.broadcast_in_dim %92, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
    %94 = stablehlo.dot_general %91, %93, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
    %95 = stablehlo.transpose %94, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
    %96 = stablehlo.reshape %95 : (tensor<1x4xf32>) -> tensor<4xf32>
    %97 = stablehlo.transpose %96, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %98 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %99 = stablehlo.broadcast_in_dim %98, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %100 = stablehlo.broadcast_in_dim %33, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %101 = stablehlo.broadcast_in_dim %100, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %102 = stablehlo.add %99, %101 : tensor<4xf32>
    %103 = stablehlo.broadcast_in_dim %102, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %104 = stablehlo.broadcast_in_dim %103, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %105 = stablehlo.broadcast_in_dim %49, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %106 = stablehlo.broadcast_in_dim %105, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %107 = stablehlo.add %104, %106 : tensor<4xf32>
    %108 = stablehlo.broadcast_in_dim %107, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %109 = stablehlo.broadcast_in_dim %108, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %110 = stablehlo.broadcast_in_dim %65, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %111 = stablehlo.broadcast_in_dim %110, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %112 = stablehlo.add %109, %111 : tensor<4xf32>
    %113 = stablehlo.broadcast_in_dim %112, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %114 = stablehlo.broadcast_in_dim %113, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %115 = stablehlo.broadcast_in_dim %81, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %116 = stablehlo.broadcast_in_dim %115, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %117 = stablehlo.add %114, %116 : tensor<4xf32>
    %118 = stablehlo.broadcast_in_dim %117, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %119 = stablehlo.broadcast_in_dim %118, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %120 = stablehlo.broadcast_in_dim %97, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %121 = stablehlo.broadcast_in_dim %120, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %122 = stablehlo.add %119, %121 : tensor<4xf32>
    %123 = stablehlo.transpose %122, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
    %124 = stablehlo.transpose %0, dims = [1, 0] : (tensor<2x6xf32>) -> tensor<6x2xf32>
    %125 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x2xf32>) -> tensor<2x4xf32>
    return %123, %124, %125 : tensor<4xf32>, tensor<6x2xf32>, tensor<2x4xf32>
  }
}
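
The module above reads like an unrolled map: each slice of the first operand feeds one dot_general against the shared second operand, and the six partial results are summed. A function of roughly this shape would trace to such IR; this is a guess, since the source of the (truncated) f_gener... function is not shown here:

using Reactant

# Hypothetical function matching the IR shape: one matvec per column of
# x, accumulated with +. The auto-batching passes are expected to fuse
# the six (4x2)x(2x1) dot_generals into one batched contraction.
f_gener(x, y) = mapreduce(col -> y * col, +, eachcol(x))

x_ra = Reactant.to_rarray(rand(Float32, 2, 6))
y_ra = Reactant.to_rarray(rand(Float32, 4, 2))

@code_hlo f_gener(x_ra, y_ra)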

@avik-pal force-pushed the ap/autobatching_passes branch 4 times, most recently from 8943304 to fb3d3cf, on September 22, 2025 03:03
@avik-pal (Collaborator, Author) commented:

map_of_slices is the only failing test left, and it seems to give incorrect results.

@avik-pal (Collaborator, Author) commented:

julia> @code_hlo mapped_sub(x_ra, y_ra)
module @reactant_mapped_sub attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<3x5x10xf32>, %arg1: tensor<3x5x10xf32>) -> tensor<5x3x10xf32> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32>
    %1 = stablehlo.slice %0 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %2 = stablehlo.slice %0 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %3 = stablehlo.slice %0 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %4 = stablehlo.slice %0 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %5 = stablehlo.slice %0 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %6 = stablehlo.reshape %5 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
    %7 = stablehlo.reshape %4 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
    %8 = stablehlo.reshape %3 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
    %9 = stablehlo.reshape %2 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
    %10 = stablehlo.reshape %1 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
    %11 = stablehlo.transpose %6, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
    %12 = stablehlo.transpose %7, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
    %13 = stablehlo.transpose %8, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
    %14 = stablehlo.transpose %9, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
    %15 = stablehlo.transpose %10, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
    %16 = stablehlo.concatenate %11, %12, %13, %14, %15, dim = 0 : (tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>) -> tensor<5x3x10xf32>
    %17 = stablehlo.transpose %arg1, dims = [1, 0, 2] : (tensor<3x5x10xf32>) -> tensor<5x3x10xf32>
    %18 = stablehlo.subtract %16, %17 : tensor<5x3x10xf32>
    %19 = stablehlo.slice %18 [0:1, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
    %20 = stablehlo.slice %18 [4:5, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
    %21 = stablehlo.slice %18 [3:4, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
    %22 = stablehlo.slice %18 [2:3, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
    %23 = stablehlo.slice %18 [1:2, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
    %24 = stablehlo.concatenate %20, %21, %22, %23, %19, dim = 0 : (tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>) -> tensor<5x3x10xf32>
    return %24 : tensor<5x3x10xf32>
  }
}
julia> @code_hlo compile_options=CompileOptions(; disable_auto_batching_passes=true) mapped_sub(x_ra, y_ra)
module @reactant_mapped_sub attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
  func.func @main(%arg0: tensor<3x5x10xf32>, %arg1: tensor<3x5x10xf32>) -> tensor<5x3x10xf32> attributes {enzymexla.memory_effects = []} {
    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32>
    %1 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32>
    %2 = stablehlo.slice %0 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %3 = stablehlo.slice %1 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %4 = stablehlo.subtract %2, %3 : tensor<10x1x3xf32>
    %5 = stablehlo.slice %0 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %6 = stablehlo.slice %1 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %7 = stablehlo.subtract %5, %6 : tensor<10x1x3xf32>
    %8 = stablehlo.slice %0 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %9 = stablehlo.slice %1 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %10 = stablehlo.subtract %8, %9 : tensor<10x1x3xf32>
    %11 = stablehlo.slice %0 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %12 = stablehlo.slice %1 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %13 = stablehlo.subtract %11, %12 : tensor<10x1x3xf32>
    %14 = stablehlo.slice %0 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %15 = stablehlo.slice %1 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
    %16 = stablehlo.subtract %14, %15 : tensor<10x1x3xf32>
    %17 = stablehlo.transpose %4, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
    %18 = stablehlo.transpose %7, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
    %19 = stablehlo.transpose %10, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
    %20 = stablehlo.transpose %13, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
    %21 = stablehlo.transpose %16, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
    %22 = stablehlo.concatenate %17, %18, %19, %20, %21, dim = 0 : (tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>) -> tensor<5x3x10xf32>
    return %22 : tensor<5x3x10xf32>
  }
}

The concatenate ordering somehow gets reversed when the auto-batching passes are enabled.
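
For reference, mapped_sub is presumably a map of `-` over matching slices that are then restacked; the actual test function is not shown here, so this is a sketch under that assumption:

using Reactant

# Assumed definition: subtract matching slices along one dimension and
# stack the results back together. With the auto-batching passes enabled,
# the slices come back out of the final concatenate in reversed order,
# which is the bug observed above.
mapped_sub(x, y) = stack(map(-, eachslice(x; dims=2), eachslice(y; dims=2)))

x_ra = Reactant.to_rarray(rand(Float32, 10, 5, 3))
y_ra = Reactant.to_rarray(rand(Float32, 10, 5, 3))

@code_hlo mapped_sub(x_ra, y_ra)
@code_hlo compile_options=CompileOptions(; disable_auto_batching_passes=true) mapped_sub(x_ra, y_ra)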

@avik-pal force-pushed the ap/autobatching_passes branch from 99f084b to dde5a3e on September 24, 2025 12:05
@avik-pal (Collaborator, Author) commented Sep 24, 2025:

Now this is interesting:

acos: Error During Test at /__w/Reactant.jl/Reactant.jl/test/ops.jl:794
  Test threw exception
  Expression: acos.(Array(x)) ≈ #= /__w/Reactant.jl/Reactant.jl/test/ops.jl:794 =# @jit(Ops.acos(x))
  UNKNOWN: <unknown>:0: error: 'mhlo.acos' op can't be translated to XLA HLO
  <unknown>:0: note: see current operation: %0 = "mhlo.acos"(%arg0) : (tensor<3xf64>) -> tensor<3xf64>

Probably an fp64 issue on TPUs, though it is strange that it never showed up before.

EDIT: this also seems to fail on main.
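
A minimal reproduction, paraphrased from the failing test (the exact setup at test/ops.jl:794 is assumed):

using Reactant, Test
using Reactant: Ops

x = Reactant.to_rarray(Float64[0.1, 0.2, 0.3])

# On TPU this fails during lowering: `mhlo.acos` has no XLA HLO
# translation for the f64 operand, producing the error above.
@test acos.(Array(x)) ≈ @jit Ops.acos(x)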

@github-actions (bot) commented:

Reactant.jl Benchmarks

Benchmark suite Current: 7588c64 Previous: 2131b27 Ratio
DeepONet ([64, 1024], [1, 128])/forward/CPU/Default 0.0018566370000000002 s 0.0024067840000000004 s 0.77
DeepONet ([64, 1024], [1, 128])/forward/CPU/DisableScatterGatherPad 0.0014481490000000001 s 0.002254101 s 0.64
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisablePadAfterEnzyme 0.0040894830000000005 s 0.005826802000000001 s 0.70
DeepONet ([64, 1024], [1, 128])/backward/CPU/DefaultAfterEnzyme 0.004064674 s 0.005924193 s 0.69
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherPadBeforeEnzyme 0.004104079 s 0.005839057000000001 s 0.70
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherPadAll 0.003928868 s 0.005868277000000001 s 0.67
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisablePadBeforeEnzyme 0.004189838 s 0.00568523 s 0.74
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisablePadAll 0.003848447 s 0.005890702 s 0.65
DeepONet ([64, 1024], [1, 128])/forward/CPU/DisableScatterGather 0.001549824 s 0.002402238 s 0.65
DeepONet ([64, 1024], [1, 128])/backward/CPU/DefaultAll 0.0040704750000000005 s 0.006035788 s 0.67
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableTransposeReshapeAfterEnzyme 0.004056754 s 0.005928088000000001 s 0.68
DeepONet ([64, 1024], [1, 128])/forward/CPU/XLA 0.001910898 s 0.0027047720000000003 s 0.71
DeepONet ([64, 1024], [1, 128])/backward/CPU/XLA 0.003902 s 0.005446732 s 0.72
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherAfterEnzyme 0.004108513 s 0.005861321 s 0.70
DeepONet ([64, 1024], [1, 128])/backward/CPU/DefaultBeforeEnzyme 0.003987011 s 0.0060816590000000005 s 0.66
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableTransposeReshapeBeforeEnzyme 0.003943455 s 0.005698693 s 0.69
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherPadAfterEnzyme 0.004052483 s 0.006128887 s 0.66
DeepONet ([64, 1024], [1, 128])/forward/CPU/DisablePad 0.00156591 s 0.00244653 s 0.64
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableTransposeReshapeAll 0.004135262000000001 s 0.0060840880000000005 s 0.68
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherAll 0.003938176 s 0.006018556 s 0.65
DeepONet ([64, 1024], [1, 128])/forward/CPU/DisableTransposeReshape 0.001965469 s 0.0026089150000000003 s 0.75
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherBeforeEnzyme 0.0040453550000000005 s 0.005988157 s 0.68
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/DisablePad 0.002070267 s 0.00207267 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherAll 0.0006746590000000001 s 0.000659719 s 1.02
DeepONet ([64, 1024], [1, 128])/backward/CUDA/XLA 0.0007607450000000001 s 0.000781538 s 0.97
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableTransposeReshapeAll 0.007142959000000001 s 0.007126589 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DefaultBeforeEnzyme 0.0006859860000000001 s 0.0006768330000000001 s 1.01
FNO [64, 64, 1, 4]/backward/CUDA/DefaultAll 0.00294448 s 0.002952203 s 1.00
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/DisableScatterGatherPad 0.002055082 s 0.002088262 s 0.98
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableTransposeReshapeBeforeEnzyme 0.007237473 s 0.007271700000000001 s 1.00
FNO [64, 64, 1, 4]/backward/CUDA/DefaultBeforeEnzyme 0.002992541 s 0.003008483 s 0.99
DeepONet ([64, 1024], [1, 128])/forward/CUDA/XLA 0.00040807200000000005 s 0.00031795 s 1.28
FNO [64, 64, 1, 4]/backward/CUDA/DisablePadBeforeEnzyme 0.002995333 s 0.0030012750000000003 s 1.00
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherPadAll 0.002970879 s 0.0029658870000000004 s 1.00
FNO [64, 64, 1, 4]/forward/CUDA/DisablePad 0.001102949 s 0.001096432 s 1.01
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisablePadBeforeEnzyme 0.0072452820000000005 s 0.007244951 s 1.00
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherPadBeforeEnzyme 0.007201009 s 0.007269816 s 0.99
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherAfterEnzyme 0.007177705 s 0.007139463 s 1.01
FNO [64, 64, 1, 4]/forward/CUDA/DisableScatterGatherPad 0.001089207 s 0.00109638 s 0.99
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisablePadAll 0.0006744450000000001 s 0.0006655410000000001 s 1.01
DeepONet ([64, 1024], [1, 128])/forward/CUDA/DisableTransposeReshape 0.000383983 s 0.00031572100000000004 s 1.22
ViT tiny [256, 256, 3, 4]/forward/CUDA/DisableScatterGatherPad 0.003333565 s 0.003225472 s 1.03
ViT tiny [256, 256, 3, 4]/forward/CUDA/DisableScatterGather 0.003335367 s 0.0031815460000000004 s 1.05
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DefaultAfterEnzyme 0.000668823 s 0.0006713520000000001 s 1.00
ViT tiny [256, 256, 3, 4]/backward/CUDA/XLA 0.012436622000000001 s 0.012320307 s 1.01
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherAfterEnzyme 0.002930331 s 0.0029470190000000004 s 0.99
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DefaultAfterEnzyme 0.007131237 s 0.0071298460000000004 s 1.00
DeepONet ([64, 1024], [1, 128])/forward/CUDA/DisableScatterGather 0.000323033 s 0.000312852 s 1.03
FNO [64, 64, 1, 4]/forward/CUDA/DisableTransposeReshape 0.00113328 s 0.0011316800000000001 s 1.00
FNO [64, 64, 1, 4]/backward/CUDA/DisablePadAll 0.002970673 s 0.0029653120000000003 s 1.00
ViT tiny [256, 256, 3, 4]/forward/CUDA/XLA 0.0033959880000000004 s 0.003401124 s 1.00
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherBeforeEnzyme 0.007233291 s 0.007224094 s 1.00
FNO [64, 64, 1, 4]/backward/CUDA/DisableTransposeReshapeAll 0.003053605 s 0.0030645620000000003 s 1.00
ViT tiny [256, 256, 3, 4]/forward/CUDA/DisableTransposeReshape 0.0033168470000000004 s 0.003162512 s 1.05
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisablePadAll 0.007110697 s 0.007115695000000001 s 1.00
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherPadAfterEnzyme 0.007124266000000001 s 0.0071315860000000005 s 1.00
DeepONet ([64, 1024], [1, 128])/forward/CUDA/DisablePad 0.00032266500000000003 s 0.00030670800000000005 s 1.05
FNO [64, 64, 1, 4]/backward/CUDA/DisableTransposeReshapeBeforeEnzyme 0.003101855 s 0.003136779 s 0.99
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherAfterEnzyme 0.0006821370000000001 s 0.0006719590000000001 s 1.02
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableTransposeReshapeAfterEnzyme 0.0006630350000000001 s 0.0006703060000000001 s 0.99
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherPadAll 0.007131274000000001 s 0.007114193 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableTransposeReshapeBeforeEnzyme 0.000689004 s 0.000683454 s 1.01
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DefaultAll 0.007136877000000001 s 0.007128227000000001 s 1.00
FNO [64, 64, 1, 4]/backward/CUDA/XLA 0.003098953 s 0.0031011740000000004 s 1.00
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/XLA 0.0072632420000000005 s 0.0072945530000000005 s 1.00
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/DisableTransposeReshape 0.002096872 s 0.002086091 s 1.01
FNO [64, 64, 1, 4]/backward/CUDA/DisableTransposeReshapeAfterEnzyme 0.0030684840000000002 s 0.0030906840000000002 s 0.99
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DefaultAll 0.000678717 s 0.0006604100000000001 s 1.03
FNO [64, 64, 1, 4]/backward/CUDA/DisablePadAfterEnzyme 0.002937916 s 0.002950761 s 1.00
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherAll 0.007114579 s 0.007119805000000001 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherPadAfterEnzyme 0.0006692240000000001 s 0.000666676 s 1.00
FNO [64, 64, 1, 4]/forward/CUDA/XLA 0.001168523 s 0.001170003 s 1.00
ViT tiny [256, 256, 3, 4]/forward/CUDA/DisablePad 0.002515443 s 0.0025442660000000003 s 0.99
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/XLA 0.002147673 s 0.002106246 s 1.02
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisablePadAfterEnzyme 0.007148421 s 0.007130592000000001 s 1.00
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherBeforeEnzyme 0.0029982370000000004 s 0.003008989 s 1.00
ViT tiny [256, 256, 3, 4]/forward/CUDA/Default 0.0025684220000000003 s 0.0026219620000000003 s 0.98
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherBeforeEnzyme 0.0006927880000000001 s 0.000684178 s 1.01
FNO [64, 64, 1, 4]/forward/CUDA/DisableScatterGather 0.0010950060000000001 s 0.0010816130000000001 s 1.01
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableTransposeReshapeAfterEnzyme 0.00714066 s 0.007133554 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisablePadAfterEnzyme 0.0006785150000000001 s 0.000675371 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherPadAll 0.00068933 s 0.0006535110000000001 s 1.05
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherAll 0.002947022 s 0.002966312 s 0.99
FNO [64, 64, 1, 4]/backward/CUDA/DefaultAfterEnzyme 0.002924631 s 0.0029365600000000004 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableTransposeReshapeAll 0.000680089 s 0.0006594950000000001 s 1.03
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/Default 0.002071323 s 0.0020791380000000003 s 1.00
FNO [64, 64, 1, 4]/forward/CUDA/Default 0.001097006 s 0.0010892480000000001 s 1.01
DeepONet ([64, 1024], [1, 128])/forward/CUDA/DisableScatterGatherPad 0.000309783 s 0.00031698800000000004 s 0.98
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/DisableScatterGather 0.002055575 s 0.002082716 s 0.99
ViT tiny [256, 256, 3, 4]/backward/CUDA/DefaultAll 0.010291049 s 0.010472342 s 0.98
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisablePadBeforeEnzyme 0.000715529 s 0.000728067 s 0.98
DeepONet ([64, 1024], [1, 128])/forward/CUDA/Default 0.000384106 s 0.00031147300000000004 s 1.23
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherPadAfterEnzyme 0.0029304390000000004 s 0.002953442 s 0.99
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherPadBeforeEnzyme 0.002994436 s 0.003013833 s 0.99
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DefaultBeforeEnzyme 0.007188503000000001 s 0.007236543000000001 s 0.99
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherPadBeforeEnzyme 0.000683378 s 0.000674917 s 1.01
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherPadAfterEnzyme 0.00478776 s 0.00461879 s 1.04
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/DisablePad 0.00132105 s 0.00130101 s 1.02
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherPadAll 0.0030123600000000004 s 0.0030014100000000004 s 1.00
ViT tiny [256, 256, 3, 4]/backward/TPU/XLA 0.00277333 s 0.0027932 s 0.99
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisablePadBeforeEnzyme 0.000348 s 0.00034408 s 1.01
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableTransposeReshapeAfterEnzyme 0.00034279000000000004 s 0.00032938 s 1.04
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherAfterEnzyme 0.0028914500000000003 s 0.00286118 s 1.01
FNO [64, 64, 1, 4]/backward/TPU/DefaultAll 0.003005329 s 0.00301053 s 1.00
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherPadAll 0.00476733 s 0.0046343600000000006 s 1.03
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherPadBeforeEnzyme 0.00033307000000000003 s 0.00034229000000000003 s 0.97
DeepONet ([64, 1024], [1, 128])/forward/TPU/Default 0.00017654000000000001 s 0.00015838 s 1.11
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisablePadBeforeEnzyme 0.00477038 s 0.004628500000000001 s 1.03
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/DisableScatterGatherPad 0.00130867 s 0.00129856 s 1.01
FNO [64, 64, 1, 4]/backward/TPU/DisableTransposeReshapeAfterEnzyme 0.00299545 s 0.00298065 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherBeforeEnzyme 0.00033365 s 0.0003303 s 1.01
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherAfterEnzyme 0.0003347 s 0.00033314 s 1.00
ViT tiny [256, 256, 3, 4]/forward/TPU/DisableScatterGatherPad 0.00061604 s 0.0006215400000000001 s 0.99
FNO [64, 64, 1, 4]/backward/TPU/DisableTransposeReshapeAll 0.00300531 s 0.00296983 s 1.01
FNO [64, 64, 1, 4]/backward/TPU/DisableTransposeReshapeBeforeEnzyme 0.0030023800000000002 s 0.0029898100000000003 s 1.00
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DefaultAfterEnzyme 0.00476489 s 0.00463141 s 1.03
FNO [64, 64, 1, 4]/forward/TPU/DisableScatterGather 0.00110545 s 0.0011204300000000002 s 0.99
FNO [64, 64, 1, 4]/backward/TPU/DisablePadBeforeEnzyme 0.00301042 s 0.0029917700000000004 s 1.01
ViT tiny [256, 256, 3, 4]/forward/TPU/XLA 0.00101647 s 0.00101515 s 1.00
ViT tiny [256, 256, 3, 4]/backward/TPU/DefaultAll 0.0024631600000000003 s 0.00246533 s 1.00
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/Default 0.0013048900000000002 s 0.00128643 s 1.01
DeepONet ([64, 1024], [1, 128])/forward/TPU/DisableScatterGatherPad 0.00018294 s 0.0002035 s 0.90
DeepONet ([64, 1024], [1, 128])/backward/TPU/DefaultAll 0.00034422 s 0.00036082000000000003 s 0.95
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DefaultAll 0.00475847 s 0.00462008 s 1.03
DeepONet ([64, 1024], [1, 128])/forward/TPU/DisableScatterGather 0.00018494 s 0.00019098900000000002 s 0.97
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/DisableTransposeReshape 0.00130736 s 0.00128951 s 1.01
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherPadBeforeEnzyme 0.00301446 s 0.00299278 s 1.01
ViT tiny [256, 256, 3, 4]/forward/TPU/DisablePad 0.0006243500000000001 s 0.00062405 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/TPU/DefaultBeforeEnzyme 0.0003372 s 0.00036066 s 0.93
DeepONet ([64, 1024], [1, 128])/forward/TPU/DisablePad 0.00018842000000000002 s 0.00019264 s 0.98
DeepONet ([64, 1024], [1, 128])/backward/TPU/XLA 0.00034975 s 0.00037036 s 0.94
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableTransposeReshapeAll 0.004752050000000001 s 0.0046303 s 1.03
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherAll 0.00300896 s 0.0030121310000000004 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherPadAll 0.00033156900000000004 s 0.00034628 s 0.96
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableTransposeReshapeBeforeEnzyme 0.004766109 s 0.00461928 s 1.03
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableTransposeReshapeAfterEnzyme 0.00477023 s 0.0046304300000000005 s 1.03
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherAll 0.00474051 s 0.004631290000000001 s 1.02
FNO [64, 64, 1, 4]/backward/TPU/DisablePadAfterEnzyme 0.0028922400000000003 s 0.0028658200000000003 s 1.01
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherBeforeEnzyme 0.00476063 s 0.00462901 s 1.03
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherAfterEnzyme 0.004748399 s 0.004632819000000001 s 1.02
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/XLA 0.0012423500000000001 s 0.00122663 s 1.01
ViT tiny [256, 256, 3, 4]/forward/TPU/DisableTransposeReshape 0.00063924 s 0.00062295 s 1.03
ViT tiny [256, 256, 3, 4]/forward/TPU/Default 0.00063086 s 0.00060841 s 1.04
FNO [64, 64, 1, 4]/backward/TPU/DefaultBeforeEnzyme 0.0030050000000000003 s 0.00301759 s 1.00
FNO [64, 64, 1, 4]/forward/TPU/Default 0.00111855 s 0.00111681 s 1.00
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/XLA 0.00465067 s 0.00453087 s 1.03
FNO [64, 64, 1, 4]/forward/TPU/XLA 0.00138557 s 0.00139206 s 1.00
DeepONet ([64, 1024], [1, 128])/forward/TPU/DisableTransposeReshape 0.00017470000000000002 s 0.00018154 s 0.96
FNO [64, 64, 1, 4]/backward/TPU/DisablePadAll 0.003001 s 0.003002799 s 1.00
DeepONet ([64, 1024], [1, 128])/forward/TPU/XLA 0.00026143 s 0.00028878000000000004 s 0.91
FNO [64, 64, 1, 4]/forward/TPU/DisableScatterGatherPad 0.00110544 s 0.0011168900000000002 s 0.99
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisablePadAfterEnzyme 0.00476395 s 0.00463045 s 1.03
DeepONet ([64, 1024], [1, 128])/backward/TPU/DefaultAfterEnzyme 0.00033756000000000004 s 0.00036080000000000004 s 0.94
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisablePadAll 0.00479154 s 0.00462816 s 1.04
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DefaultBeforeEnzyme 0.00475882 s 0.00464774 s 1.02
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/DisableScatterGather 0.00130991 s 0.00129289 s 1.01
ViT tiny [256, 256, 3, 4]/forward/TPU/DisableScatterGather 0.00062874 s 0.00061041 s 1.03
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherBeforeEnzyme 0.00301922 s 0.0030071300000000002 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableTransposeReshapeBeforeEnzyme 0.00033915000000000003 s 0.00033028 s 1.03
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherPadAfterEnzyme 0.00033798000000000004 s 0.00033678 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisablePadAll 0.00034544 s 0.00032839 s 1.05
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherAll 0.00033853 s 0.00033048 s 1.02
FNO [64, 64, 1, 4]/forward/TPU/DisableTransposeReshape 0.0011446100000000001 s 0.00114841 s 1.00
FNO [64, 64, 1, 4]/backward/TPU/DefaultAfterEnzyme 0.00287951 s 0.0028636900000000003 s 1.01
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherPadAfterEnzyme 0.0029170000000000003 s 0.0028570400000000004 s 1.02
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherPadBeforeEnzyme 0.0047691 s 0.0046275100000000005 s 1.03
FNO [64, 64, 1, 4]/forward/TPU/DisablePad 0.0011034 s 0.00111677 s 0.99
FNO [64, 64, 1, 4]/backward/TPU/XLA 0.00318736 s 0.0031743400000000003 s 1.00
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableTransposeReshapeAll 0.00034306 s 0.00033025000000000003 s 1.04
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisablePadAfterEnzyme 0.00033275000000000004 s 0.0003396 s 0.98

This comment was automatically generated by a workflow using github-action-benchmark.

@avik-pal force-pushed the ap/autobatching_passes branch from 1a32cf5 to 70aa857 on September 24, 2025 19:05
@avik-pal force-pushed the ap/autobatching_passes branch from 3120869 to 7588c64 on September 24, 2025 20:51
@avik-pal (Collaborator, Author) commented:

@wsmoses this should be good to go on my end. I marked the acos tests as broken for now.

Opened an upstream issue: openxla/xla#31796.

@avik-pal merged commit 79752b7 into main on Sep 25, 2025 (70 of 72 checks passed)
@avik-pal deleted the ap/autobatching_passes branch on September 25, 2025 03:10