-
Notifications
You must be signed in to change notification settings - Fork 28
feat: enable new auto-batching passes #1690
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
055028b
to
afb256b
Compare
Needs some upstream fixes before merging. |
module @reactant_gradient attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func private @"Const{typeof(sum_squares)}(Main.sum_squares)_autodiff"(%arg0: tensor<3xf32>) -> (tensor<f32>, tensor<3xf32>) attributes {enzymexla.memory_effects = []} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
%0 = stablehlo.transpose %arg0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%1 = stablehlo.convert %cst : tensor<f32>
%2 = stablehlo.abs %0 : tensor<3xf32>
%3 = stablehlo.multiply %2, %2 : tensor<3xf32>
%4 = stablehlo.reduce(%3 init: %1) applies stablehlo.add across dimensions = [0] : (tensor<3xf32>, tensor<f32>) -> tensor<f32>
%5 = stablehlo.transpose %0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
return %4, %5 : tensor<f32>, tensor<3xf32>
}
func.func @main(%arg0: tensor<3xf32> {tf.aliasing_output = 3 : i32}) -> (tensor<f32>, tensor<f32>, tensor<f32>, tensor<3xf32>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f32>
%c = stablehlo.constant dense<0> : tensor<i64>
%0 = stablehlo.transpose %arg0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%1 = stablehlo.convert %c : (tensor<i64>) -> tensor<f32>
%2 = stablehlo.convert %cst : tensor<f32>
%3 = stablehlo.broadcast_in_dim %2, dims = [] : (tensor<f32>) -> tensor<1xf32>
%4 = stablehlo.pad %3, %1, low = [0], high = [2], interior = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<3xf32>
%5 = stablehlo.transpose %4, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%6 = stablehlo.reshape %5 : (tensor<3xf32>) -> tensor<3xf32>
%7 = stablehlo.transpose %6, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%8 = stablehlo.pad %3, %1, low = [1], high = [1], interior = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<3xf32>
%9 = stablehlo.transpose %8, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%10 = stablehlo.reshape %9 : (tensor<3xf32>) -> tensor<3xf32>
%11 = stablehlo.transpose %10, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%12 = stablehlo.pad %3, %1, low = [2], high = [0], interior = [0] : (tensor<1xf32>, tensor<f32>) -> tensor<3xf32>
%13 = stablehlo.transpose %12, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%14 = stablehlo.reshape %13 : (tensor<3xf32>) -> tensor<3xf32>
%15 = stablehlo.transpose %14, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%16 = stablehlo.transpose %7, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%17 = stablehlo.reshape %16 : (tensor<3xf32>) -> tensor<3x1xf32>
%18 = stablehlo.transpose %17, dims = [1, 0] : (tensor<3x1xf32>) -> tensor<1x3xf32>
%19 = stablehlo.transpose %11, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%20 = stablehlo.reshape %19 : (tensor<3xf32>) -> tensor<3x1xf32>
%21 = stablehlo.transpose %20, dims = [1, 0] : (tensor<3x1xf32>) -> tensor<1x3xf32>
%22 = stablehlo.transpose %15, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%23 = stablehlo.reshape %22 : (tensor<3xf32>) -> tensor<3x1xf32>
%24 = stablehlo.transpose %23, dims = [1, 0] : (tensor<3x1xf32>) -> tensor<1x3xf32>
%25 = stablehlo.concatenate %18, %21, %24, dim = 0 : (tensor<1x3xf32>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<3x3xf32>
%26 = stablehlo.transpose %0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%27 = stablehlo.transpose %25, dims = [0, 1] : (tensor<3x3xf32>) -> tensor<3x3xf32>
%28:2 = enzyme.fwddiff @"Const{typeof(sum_squares)}(Main.sum_squares)_autodiff"(%26, %27) {activity = [#enzyme<activity enzyme_dup>], ret_activity = [#enzyme<activity enzyme_dupnoneed>, #enzyme<activity enzyme_const>], width = 3 : i64} : (tensor<3xf32>, tensor<3x3xf32>) -> (tensor<3xf32>, tensor<3xf32>)
%29 = stablehlo.transpose %28#0, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%30 = stablehlo.slice %29 [0:1] : (tensor<3xf32>) -> tensor<1xf32>
%31 = stablehlo.transpose %30, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%32 = stablehlo.reshape %31 : (tensor<1xf32>) -> tensor<f32>
%33 = stablehlo.transpose %32, dims = [] : (tensor<f32>) -> tensor<f32>
%34 = stablehlo.slice %29 [1:2] : (tensor<3xf32>) -> tensor<1xf32>
%35 = stablehlo.transpose %34, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%36 = stablehlo.reshape %35 : (tensor<1xf32>) -> tensor<f32>
%37 = stablehlo.transpose %36, dims = [] : (tensor<f32>) -> tensor<f32>
%38 = stablehlo.slice %29 [2:3] : (tensor<3xf32>) -> tensor<1xf32>
%39 = stablehlo.transpose %38, dims = [0] : (tensor<1xf32>) -> tensor<1xf32>
%40 = stablehlo.reshape %39 : (tensor<1xf32>) -> tensor<f32>
%41 = stablehlo.transpose %40, dims = [] : (tensor<f32>) -> tensor<f32>
%42 = stablehlo.transpose %28#1, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
%43 = stablehlo.transpose %42, dims = [0] : (tensor<3xf32>) -> tensor<3xf32>
return %33, %37, %41, %43 : tensor<f32>, tensor<f32>, tensor<f32>, tensor<3xf32>
}
}
|
module @reactant_f_gener... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<6x2xf32> {tf.aliasing_output = 1 : i32}, %arg1: tensor<2x4xf32> {tf.aliasing_output = 2 : i32}) -> (tensor<4xf32>, tensor<6x2xf32>, tensor<2x4xf32>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<6x2xf32>) -> tensor<2x6xf32>
%1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<2x4xf32>) -> tensor<4x2xf32>
%2 = stablehlo.slice %0 [0:2, 0:1] : (tensor<2x6xf32>) -> tensor<2x1xf32>
%3 = stablehlo.transpose %2, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
%4 = stablehlo.reshape %3 : (tensor<1x2xf32>) -> tensor<2xf32>
%5 = stablehlo.transpose %4, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%6 = stablehlo.transpose %5, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%7 = stablehlo.reshape %6 : (tensor<2xf32>) -> tensor<1x2xf32>
%8 = stablehlo.transpose %7, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
%9 = stablehlo.convert %8 : tensor<2x1xf32>
%10 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%11 = stablehlo.broadcast_in_dim %10, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%12 = stablehlo.broadcast_in_dim %9, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%13 = stablehlo.broadcast_in_dim %12, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%14 = stablehlo.dot_general %11, %13, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
%15 = stablehlo.transpose %14, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
%16 = stablehlo.reshape %15 : (tensor<1x4xf32>) -> tensor<4xf32>
%17 = stablehlo.transpose %16, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%18 = stablehlo.slice %0 [0:2, 1:2] : (tensor<2x6xf32>) -> tensor<2x1xf32>
%19 = stablehlo.transpose %18, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
%20 = stablehlo.reshape %19 : (tensor<1x2xf32>) -> tensor<2xf32>
%21 = stablehlo.transpose %20, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%22 = stablehlo.transpose %21, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%23 = stablehlo.reshape %22 : (tensor<2xf32>) -> tensor<1x2xf32>
%24 = stablehlo.transpose %23, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
%25 = stablehlo.convert %24 : tensor<2x1xf32>
%26 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%27 = stablehlo.broadcast_in_dim %26, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%28 = stablehlo.broadcast_in_dim %25, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%29 = stablehlo.broadcast_in_dim %28, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%30 = stablehlo.dot_general %27, %29, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
%31 = stablehlo.transpose %30, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
%32 = stablehlo.reshape %31 : (tensor<1x4xf32>) -> tensor<4xf32>
%33 = stablehlo.transpose %32, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%34 = stablehlo.slice %0 [0:2, 2:3] : (tensor<2x6xf32>) -> tensor<2x1xf32>
%35 = stablehlo.transpose %34, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
%36 = stablehlo.reshape %35 : (tensor<1x2xf32>) -> tensor<2xf32>
%37 = stablehlo.transpose %36, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%38 = stablehlo.transpose %37, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%39 = stablehlo.reshape %38 : (tensor<2xf32>) -> tensor<1x2xf32>
%40 = stablehlo.transpose %39, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
%41 = stablehlo.convert %40 : tensor<2x1xf32>
%42 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%43 = stablehlo.broadcast_in_dim %42, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%44 = stablehlo.broadcast_in_dim %41, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%45 = stablehlo.broadcast_in_dim %44, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%46 = stablehlo.dot_general %43, %45, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
%47 = stablehlo.transpose %46, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
%48 = stablehlo.reshape %47 : (tensor<1x4xf32>) -> tensor<4xf32>
%49 = stablehlo.transpose %48, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%50 = stablehlo.slice %0 [0:2, 3:4] : (tensor<2x6xf32>) -> tensor<2x1xf32>
%51 = stablehlo.transpose %50, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
%52 = stablehlo.reshape %51 : (tensor<1x2xf32>) -> tensor<2xf32>
%53 = stablehlo.transpose %52, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%54 = stablehlo.transpose %53, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%55 = stablehlo.reshape %54 : (tensor<2xf32>) -> tensor<1x2xf32>
%56 = stablehlo.transpose %55, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
%57 = stablehlo.convert %56 : tensor<2x1xf32>
%58 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%59 = stablehlo.broadcast_in_dim %58, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%60 = stablehlo.broadcast_in_dim %57, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%61 = stablehlo.broadcast_in_dim %60, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%62 = stablehlo.dot_general %59, %61, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
%63 = stablehlo.transpose %62, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
%64 = stablehlo.reshape %63 : (tensor<1x4xf32>) -> tensor<4xf32>
%65 = stablehlo.transpose %64, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%66 = stablehlo.slice %0 [0:2, 4:5] : (tensor<2x6xf32>) -> tensor<2x1xf32>
%67 = stablehlo.transpose %66, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
%68 = stablehlo.reshape %67 : (tensor<1x2xf32>) -> tensor<2xf32>
%69 = stablehlo.transpose %68, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%70 = stablehlo.transpose %69, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%71 = stablehlo.reshape %70 : (tensor<2xf32>) -> tensor<1x2xf32>
%72 = stablehlo.transpose %71, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
%73 = stablehlo.convert %72 : tensor<2x1xf32>
%74 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%75 = stablehlo.broadcast_in_dim %74, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%76 = stablehlo.broadcast_in_dim %73, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%77 = stablehlo.broadcast_in_dim %76, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%78 = stablehlo.dot_general %75, %77, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
%79 = stablehlo.transpose %78, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
%80 = stablehlo.reshape %79 : (tensor<1x4xf32>) -> tensor<4xf32>
%81 = stablehlo.transpose %80, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%82 = stablehlo.slice %0 [0:2, 5:6] : (tensor<2x6xf32>) -> tensor<2x1xf32>
%83 = stablehlo.transpose %82, dims = [1, 0] : (tensor<2x1xf32>) -> tensor<1x2xf32>
%84 = stablehlo.reshape %83 : (tensor<1x2xf32>) -> tensor<2xf32>
%85 = stablehlo.transpose %84, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%86 = stablehlo.transpose %85, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
%87 = stablehlo.reshape %86 : (tensor<2xf32>) -> tensor<1x2xf32>
%88 = stablehlo.transpose %87, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32>
%89 = stablehlo.convert %88 : tensor<2x1xf32>
%90 = stablehlo.broadcast_in_dim %1, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%91 = stablehlo.broadcast_in_dim %90, dims = [0, 1] : (tensor<4x2xf32>) -> tensor<4x2xf32>
%92 = stablehlo.broadcast_in_dim %89, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%93 = stablehlo.broadcast_in_dim %92, dims = [0, 1] : (tensor<2x1xf32>) -> tensor<2x1xf32>
%94 = stablehlo.dot_general %91, %93, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<4x2xf32>, tensor<2x1xf32>) -> tensor<4x1xf32>
%95 = stablehlo.transpose %94, dims = [1, 0] : (tensor<4x1xf32>) -> tensor<1x4xf32>
%96 = stablehlo.reshape %95 : (tensor<1x4xf32>) -> tensor<4xf32>
%97 = stablehlo.transpose %96, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%98 = stablehlo.broadcast_in_dim %17, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%99 = stablehlo.broadcast_in_dim %98, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%100 = stablehlo.broadcast_in_dim %33, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%101 = stablehlo.broadcast_in_dim %100, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%102 = stablehlo.add %99, %101 : tensor<4xf32>
%103 = stablehlo.broadcast_in_dim %102, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%104 = stablehlo.broadcast_in_dim %103, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%105 = stablehlo.broadcast_in_dim %49, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%106 = stablehlo.broadcast_in_dim %105, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%107 = stablehlo.add %104, %106 : tensor<4xf32>
%108 = stablehlo.broadcast_in_dim %107, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%109 = stablehlo.broadcast_in_dim %108, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%110 = stablehlo.broadcast_in_dim %65, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%111 = stablehlo.broadcast_in_dim %110, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%112 = stablehlo.add %109, %111 : tensor<4xf32>
%113 = stablehlo.broadcast_in_dim %112, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%114 = stablehlo.broadcast_in_dim %113, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%115 = stablehlo.broadcast_in_dim %81, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%116 = stablehlo.broadcast_in_dim %115, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%117 = stablehlo.add %114, %116 : tensor<4xf32>
%118 = stablehlo.broadcast_in_dim %117, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%119 = stablehlo.broadcast_in_dim %118, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%120 = stablehlo.broadcast_in_dim %97, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%121 = stablehlo.broadcast_in_dim %120, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%122 = stablehlo.add %119, %121 : tensor<4xf32>
%123 = stablehlo.transpose %122, dims = [0] : (tensor<4xf32>) -> tensor<4xf32>
%124 = stablehlo.transpose %0, dims = [1, 0] : (tensor<2x6xf32>) -> tensor<6x2xf32>
%125 = stablehlo.transpose %1, dims = [1, 0] : (tensor<4x2xf32>) -> tensor<2x4xf32>
return %123, %124, %125 : tensor<4xf32>, tensor<6x2xf32>, tensor<2x4xf32>
}
} |
8943304
to
fb3d3cf
Compare
|
julia> @code_hlo mapped_sub(x_ra, y_ra)
module @reactant_mapped_sub attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<3x5x10xf32>, %arg1: tensor<3x5x10xf32>) -> tensor<5x3x10xf32> attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32>
%1 = stablehlo.slice %0 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%2 = stablehlo.slice %0 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%3 = stablehlo.slice %0 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%4 = stablehlo.slice %0 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%5 = stablehlo.slice %0 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%6 = stablehlo.reshape %5 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
%7 = stablehlo.reshape %4 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
%8 = stablehlo.reshape %3 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
%9 = stablehlo.reshape %2 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
%10 = stablehlo.reshape %1 : (tensor<10x1x3xf32>) -> tensor<1x10x3xf32>
%11 = stablehlo.transpose %6, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
%12 = stablehlo.transpose %7, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
%13 = stablehlo.transpose %8, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
%14 = stablehlo.transpose %9, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
%15 = stablehlo.transpose %10, dims = [0, 2, 1] : (tensor<1x10x3xf32>) -> tensor<1x3x10xf32>
%16 = stablehlo.concatenate %11, %12, %13, %14, %15, dim = 0 : (tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>) -> tensor<5x3x10xf32>
%17 = stablehlo.transpose %arg1, dims = [1, 0, 2] : (tensor<3x5x10xf32>) -> tensor<5x3x10xf32>
%18 = stablehlo.subtract %16, %17 : tensor<5x3x10xf32>
%19 = stablehlo.slice %18 [0:1, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
%20 = stablehlo.slice %18 [4:5, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
%21 = stablehlo.slice %18 [3:4, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
%22 = stablehlo.slice %18 [2:3, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
%23 = stablehlo.slice %18 [1:2, 0:3, 0:10] : (tensor<5x3x10xf32>) -> tensor<1x3x10xf32>
%24 = stablehlo.concatenate %20, %21, %22, %23, %19, dim = 0 : (tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>) -> tensor<5x3x10xf32>
return %24 : tensor<5x3x10xf32>
}
}
julia> @code_hlo compile_options=CompileOptions(; disable_auto_batching_passes=true) mapped_sub(x_ra, y_ra)
module @reactant_mapped_sub attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<3x5x10xf32>, %arg1: tensor<3x5x10xf32>) -> tensor<5x3x10xf32> attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32>
%1 = stablehlo.transpose %arg1, dims = [2, 1, 0] : (tensor<3x5x10xf32>) -> tensor<10x5x3xf32>
%2 = stablehlo.slice %0 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%3 = stablehlo.slice %1 [0:10, 0:1, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%4 = stablehlo.subtract %2, %3 : tensor<10x1x3xf32>
%5 = stablehlo.slice %0 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%6 = stablehlo.slice %1 [0:10, 1:2, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%7 = stablehlo.subtract %5, %6 : tensor<10x1x3xf32>
%8 = stablehlo.slice %0 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%9 = stablehlo.slice %1 [0:10, 2:3, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%10 = stablehlo.subtract %8, %9 : tensor<10x1x3xf32>
%11 = stablehlo.slice %0 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%12 = stablehlo.slice %1 [0:10, 3:4, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%13 = stablehlo.subtract %11, %12 : tensor<10x1x3xf32>
%14 = stablehlo.slice %0 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%15 = stablehlo.slice %1 [0:10, 4:5, 0:3] : (tensor<10x5x3xf32>) -> tensor<10x1x3xf32>
%16 = stablehlo.subtract %14, %15 : tensor<10x1x3xf32>
%17 = stablehlo.transpose %4, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
%18 = stablehlo.transpose %7, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
%19 = stablehlo.transpose %10, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
%20 = stablehlo.transpose %13, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
%21 = stablehlo.transpose %16, dims = [1, 2, 0] : (tensor<10x1x3xf32>) -> tensor<1x3x10xf32>
%22 = stablehlo.concatenate %17, %18, %19, %20, %21, dim = 0 : (tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>, tensor<1x3x10xf32>) -> tensor<5x3x10xf32>
return %22 : tensor<5x3x10xf32>
}
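}
The concatenate operand ordering gets reversed somehow: with the auto-batching passes enabled, the slices of the first argument are concatenated in the order 4, 3, 2, 1, 0, whereas with the passes disabled they are concatenated in the order 0, 1, 2, 3, 4. |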
99f084b
to
dde5a3e
Compare
Now this is interesting
Probably an fp64 issue on TPUs, though it is strange that it never showed up before. EDIT: it also seems to fail on main. |
9d20003
to
1a32cf5
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Reactant.jl Benchmarks
Benchmark suite | Current: 7588c64 | Previous: 2131b27 | Ratio |
---|---|---|---|
DeepONet ([64, 1024], [1, 128])/forward/CPU/Default |
0.0018566370000000002 s |
0.0024067840000000004 s |
0.77 |
DeepONet ([64, 1024], [1, 128])/forward/CPU/DisableScatterGatherPad |
0.0014481490000000001 s |
0.002254101 s |
0.64 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisablePadAfterEnzyme |
0.0040894830000000005 s |
0.005826802000000001 s |
0.70 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DefaultAfterEnzyme |
0.004064674 s |
0.005924193 s |
0.69 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherPadBeforeEnzyme |
0.004104079 s |
0.005839057000000001 s |
0.70 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherPadAll |
0.003928868 s |
0.005868277000000001 s |
0.67 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisablePadBeforeEnzyme |
0.004189838 s |
0.00568523 s |
0.74 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisablePadAll |
0.003848447 s |
0.005890702 s |
0.65 |
DeepONet ([64, 1024], [1, 128])/forward/CPU/DisableScatterGather |
0.001549824 s |
0.002402238 s |
0.65 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DefaultAll |
0.0040704750000000005 s |
0.006035788 s |
0.67 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableTransposeReshapeAfterEnzyme |
0.004056754 s |
0.005928088000000001 s |
0.68 |
DeepONet ([64, 1024], [1, 128])/forward/CPU/XLA |
0.001910898 s |
0.0027047720000000003 s |
0.71 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/XLA |
0.003902 s |
0.005446732 s |
0.72 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherAfterEnzyme |
0.004108513 s |
0.005861321 s |
0.70 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DefaultBeforeEnzyme |
0.003987011 s |
0.0060816590000000005 s |
0.66 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableTransposeReshapeBeforeEnzyme |
0.003943455 s |
0.005698693 s |
0.69 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherPadAfterEnzyme |
0.004052483 s |
0.006128887 s |
0.66 |
DeepONet ([64, 1024], [1, 128])/forward/CPU/DisablePad |
0.00156591 s |
0.00244653 s |
0.64 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableTransposeReshapeAll |
0.004135262000000001 s |
0.0060840880000000005 s |
0.68 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherAll |
0.003938176 s |
0.006018556 s |
0.65 |
DeepONet ([64, 1024], [1, 128])/forward/CPU/DisableTransposeReshape |
0.001965469 s |
0.0026089150000000003 s |
0.75 |
DeepONet ([64, 1024], [1, 128])/backward/CPU/DisableScatterGatherBeforeEnzyme |
0.0040453550000000005 s |
0.005988157 s |
0.68 |
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/DisablePad |
0.002070267 s |
0.00207267 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherAll |
0.0006746590000000001 s |
0.000659719 s |
1.02 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/XLA |
0.0007607450000000001 s |
0.000781538 s |
0.97 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableTransposeReshapeAll |
0.007142959000000001 s |
0.007126589 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DefaultBeforeEnzyme |
0.0006859860000000001 s |
0.0006768330000000001 s |
1.01 |
FNO [64, 64, 1, 4]/backward/CUDA/DefaultAll |
0.00294448 s |
0.002952203 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/DisableScatterGatherPad |
0.002055082 s |
0.002088262 s |
0.98 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableTransposeReshapeBeforeEnzyme |
0.007237473 s |
0.007271700000000001 s |
1.00 |
FNO [64, 64, 1, 4]/backward/CUDA/DefaultBeforeEnzyme |
0.002992541 s |
0.003008483 s |
0.99 |
DeepONet ([64, 1024], [1, 128])/forward/CUDA/XLA |
0.00040807200000000005 s |
0.00031795 s |
1.28 |
FNO [64, 64, 1, 4]/backward/CUDA/DisablePadBeforeEnzyme |
0.002995333 s |
0.0030012750000000003 s |
1.00 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherPadAll |
0.002970879 s |
0.0029658870000000004 s |
1.00 |
FNO [64, 64, 1, 4]/forward/CUDA/DisablePad |
0.001102949 s |
0.001096432 s |
1.01 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisablePadBeforeEnzyme |
0.0072452820000000005 s |
0.007244951 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherPadBeforeEnzyme |
0.007201009 s |
0.007269816 s |
0.99 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherAfterEnzyme |
0.007177705 s |
0.007139463 s |
1.01 |
FNO [64, 64, 1, 4]/forward/CUDA/DisableScatterGatherPad |
0.001089207 s |
0.00109638 s |
0.99 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisablePadAll |
0.0006744450000000001 s |
0.0006655410000000001 s |
1.01 |
DeepONet ([64, 1024], [1, 128])/forward/CUDA/DisableTransposeReshape |
0.000383983 s |
0.00031572100000000004 s |
1.22 |
ViT tiny [256, 256, 3, 4]/forward/CUDA/DisableScatterGatherPad |
0.003333565 s |
0.003225472 s |
1.03 |
ViT tiny [256, 256, 3, 4]/forward/CUDA/DisableScatterGather |
0.003335367 s |
0.0031815460000000004 s |
1.05 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DefaultAfterEnzyme |
0.000668823 s |
0.0006713520000000001 s |
1.00 |
ViT tiny [256, 256, 3, 4]/backward/CUDA/XLA |
0.012436622000000001 s |
0.012320307 s |
1.01 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherAfterEnzyme |
0.002930331 s |
0.0029470190000000004 s |
0.99 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DefaultAfterEnzyme |
0.007131237 s |
0.0071298460000000004 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/forward/CUDA/DisableScatterGather |
0.000323033 s |
0.000312852 s |
1.03 |
FNO [64, 64, 1, 4]/forward/CUDA/DisableTransposeReshape |
0.00113328 s |
0.0011316800000000001 s |
1.00 |
FNO [64, 64, 1, 4]/backward/CUDA/DisablePadAll |
0.002970673 s |
0.0029653120000000003 s |
1.00 |
ViT tiny [256, 256, 3, 4]/forward/CUDA/XLA |
0.0033959880000000004 s |
0.003401124 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherBeforeEnzyme |
0.007233291 s |
0.007224094 s |
1.00 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableTransposeReshapeAll |
0.003053605 s |
0.0030645620000000003 s |
1.00 |
ViT tiny [256, 256, 3, 4]/forward/CUDA/DisableTransposeReshape |
0.0033168470000000004 s |
0.003162512 s |
1.05 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisablePadAll |
0.007110697 s |
0.007115695000000001 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherPadAfterEnzyme |
0.007124266000000001 s |
0.0071315860000000005 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/forward/CUDA/DisablePad |
0.00032266500000000003 s |
0.00030670800000000005 s |
1.05 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableTransposeReshapeBeforeEnzyme |
0.003101855 s |
0.003136779 s |
0.99 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherAfterEnzyme |
0.0006821370000000001 s |
0.0006719590000000001 s |
1.02 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableTransposeReshapeAfterEnzyme |
0.0006630350000000001 s |
0.0006703060000000001 s |
0.99 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherPadAll |
0.007131274000000001 s |
0.007114193 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableTransposeReshapeBeforeEnzyme |
0.000689004 s |
0.000683454 s |
1.01 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DefaultAll |
0.007136877000000001 s |
0.007128227000000001 s |
1.00 |
FNO [64, 64, 1, 4]/backward/CUDA/XLA |
0.003098953 s |
0.0031011740000000004 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/XLA |
0.0072632420000000005 s |
0.0072945530000000005 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/DisableTransposeReshape |
0.002096872 s |
0.002086091 s |
1.01 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableTransposeReshapeAfterEnzyme |
0.0030684840000000002 s |
0.0030906840000000002 s |
0.99 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DefaultAll |
0.000678717 s |
0.0006604100000000001 s |
1.03 |
FNO [64, 64, 1, 4]/backward/CUDA/DisablePadAfterEnzyme |
0.002937916 s |
0.002950761 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableScatterGatherAll |
0.007114579 s |
0.007119805000000001 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherPadAfterEnzyme |
0.0006692240000000001 s |
0.000666676 s |
1.00 |
FNO [64, 64, 1, 4]/forward/CUDA/XLA |
0.001168523 s |
0.001170003 s |
1.00 |
ViT tiny [256, 256, 3, 4]/forward/CUDA/DisablePad |
0.002515443 s |
0.0025442660000000003 s |
0.99 |
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/XLA |
0.002147673 s |
0.002106246 s |
1.02 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisablePadAfterEnzyme |
0.007148421 s |
0.007130592000000001 s |
1.00 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherBeforeEnzyme |
0.0029982370000000004 s |
0.003008989 s |
1.00 |
ViT tiny [256, 256, 3, 4]/forward/CUDA/Default |
0.0025684220000000003 s |
0.0026219620000000003 s |
0.98 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherBeforeEnzyme |
0.0006927880000000001 s |
0.000684178 s |
1.01 |
FNO [64, 64, 1, 4]/forward/CUDA/DisableScatterGather |
0.0010950060000000001 s |
0.0010816130000000001 s |
1.01 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DisableTransposeReshapeAfterEnzyme |
0.00714066 s |
0.007133554 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisablePadAfterEnzyme |
0.0006785150000000001 s |
0.000675371 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherPadAll |
0.00068933 s |
0.0006535110000000001 s |
1.05 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherAll |
0.002947022 s |
0.002966312 s |
0.99 |
FNO [64, 64, 1, 4]/backward/CUDA/DefaultAfterEnzyme |
0.002924631 s |
0.0029365600000000004 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableTransposeReshapeAll |
0.000680089 s |
0.0006594950000000001 s |
1.03 |
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/Default |
0.002071323 s |
0.0020791380000000003 s |
1.00 |
FNO [64, 64, 1, 4]/forward/CUDA/Default |
0.001097006 s |
0.0010892480000000001 s |
1.01 |
DeepONet ([64, 1024], [1, 128])/forward/CUDA/DisableScatterGatherPad |
0.000309783 s |
0.00031698800000000004 s |
0.98 |
VGG11 bn=true [224, 224, 3, 4]/forward/CUDA/DisableScatterGather |
0.002055575 s |
0.002082716 s |
0.99 |
ViT tiny [256, 256, 3, 4]/backward/CUDA/DefaultAll |
0.010291049 s |
0.010472342 s |
0.98 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisablePadBeforeEnzyme |
0.000715529 s |
0.000728067 s |
0.98 |
DeepONet ([64, 1024], [1, 128])/forward/CUDA/Default |
0.000384106 s |
0.00031147300000000004 s |
1.23 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherPadAfterEnzyme |
0.0029304390000000004 s |
0.002953442 s |
0.99 |
FNO [64, 64, 1, 4]/backward/CUDA/DisableScatterGatherPadBeforeEnzyme |
0.002994436 s |
0.003013833 s |
0.99 |
VGG11 bn=true [224, 224, 3, 4]/backward/CUDA/DefaultBeforeEnzyme |
0.007188503000000001 s |
0.007236543000000001 s |
0.99 |
DeepONet ([64, 1024], [1, 128])/backward/CUDA/DisableScatterGatherPadBeforeEnzyme |
0.000683378 s |
0.000674917 s |
1.01 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherPadAfterEnzyme |
0.00478776 s |
0.00461879 s |
1.04 |
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/DisablePad |
0.00132105 s |
0.00130101 s |
1.02 |
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherPadAll |
0.0030123600000000004 s |
0.0030014100000000004 s |
1.00 |
ViT tiny [256, 256, 3, 4]/backward/TPU/XLA |
0.00277333 s |
0.0027932 s |
0.99 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisablePadBeforeEnzyme |
0.000348 s |
0.00034408 s |
1.01 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableTransposeReshapeAfterEnzyme |
0.00034279000000000004 s |
0.00032938 s |
1.04 |
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherAfterEnzyme |
0.0028914500000000003 s |
0.00286118 s |
1.01 |
FNO [64, 64, 1, 4]/backward/TPU/DefaultAll |
0.003005329 s |
0.00301053 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherPadAll |
0.00476733 s |
0.0046343600000000006 s |
1.03 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherPadBeforeEnzyme |
0.00033307000000000003 s |
0.00034229000000000003 s |
0.97 |
DeepONet ([64, 1024], [1, 128])/forward/TPU/Default |
0.00017654000000000001 s |
0.00015838 s |
1.11 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisablePadBeforeEnzyme |
0.00477038 s |
0.004628500000000001 s |
1.03 |
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/DisableScatterGatherPad |
0.00130867 s |
0.00129856 s |
1.01 |
FNO [64, 64, 1, 4]/backward/TPU/DisableTransposeReshapeAfterEnzyme |
0.00299545 s |
0.00298065 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherBeforeEnzyme |
0.00033365 s |
0.0003303 s |
1.01 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherAfterEnzyme |
0.0003347 s |
0.00033314 s |
1.00 |
ViT tiny [256, 256, 3, 4]/forward/TPU/DisableScatterGatherPad |
0.00061604 s |
0.0006215400000000001 s |
0.99 |
FNO [64, 64, 1, 4]/backward/TPU/DisableTransposeReshapeAll |
0.00300531 s |
0.00296983 s |
1.01 |
FNO [64, 64, 1, 4]/backward/TPU/DisableTransposeReshapeBeforeEnzyme |
0.0030023800000000002 s |
0.0029898100000000003 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DefaultAfterEnzyme |
0.00476489 s |
0.00463141 s |
1.03 |
FNO [64, 64, 1, 4]/forward/TPU/DisableScatterGather |
0.00110545 s |
0.0011204300000000002 s |
0.99 |
FNO [64, 64, 1, 4]/backward/TPU/DisablePadBeforeEnzyme |
0.00301042 s |
0.0029917700000000004 s |
1.01 |
ViT tiny [256, 256, 3, 4]/forward/TPU/XLA |
0.00101647 s |
0.00101515 s |
1.00 |
ViT tiny [256, 256, 3, 4]/backward/TPU/DefaultAll |
0.0024631600000000003 s |
0.00246533 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/Default |
0.0013048900000000002 s |
0.00128643 s |
1.01 |
DeepONet ([64, 1024], [1, 128])/forward/TPU/DisableScatterGatherPad |
0.00018294 s |
0.0002035 s |
0.90 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DefaultAll |
0.00034422 s |
0.00036082000000000003 s |
0.95 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DefaultAll |
0.00475847 s |
0.00462008 s |
1.03 |
DeepONet ([64, 1024], [1, 128])/forward/TPU/DisableScatterGather |
0.00018494 s |
0.00019098900000000002 s |
0.97 |
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/DisableTransposeReshape |
0.00130736 s |
0.00128951 s |
1.01 |
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherPadBeforeEnzyme |
0.00301446 s |
0.00299278 s |
1.01 |
ViT tiny [256, 256, 3, 4]/forward/TPU/DisablePad |
0.0006243500000000001 s |
0.00062405 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DefaultBeforeEnzyme |
0.0003372 s |
0.00036066 s |
0.93 |
DeepONet ([64, 1024], [1, 128])/forward/TPU/DisablePad |
0.00018842000000000002 s |
0.00019264 s |
0.98 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/XLA |
0.00034975 s |
0.00037036 s |
0.94 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableTransposeReshapeAll |
0.004752050000000001 s |
0.0046303 s |
1.03 |
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherAll |
0.00300896 s |
0.0030121310000000004 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherPadAll |
0.00033156900000000004 s |
0.00034628 s |
0.96 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableTransposeReshapeBeforeEnzyme |
0.004766109 s |
0.00461928 s |
1.03 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableTransposeReshapeAfterEnzyme |
0.00477023 s |
0.0046304300000000005 s |
1.03 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherAll |
0.00474051 s |
0.004631290000000001 s |
1.02 |
FNO [64, 64, 1, 4]/backward/TPU/DisablePadAfterEnzyme |
0.0028922400000000003 s |
0.0028658200000000003 s |
1.01 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherBeforeEnzyme |
0.00476063 s |
0.00462901 s |
1.03 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherAfterEnzyme |
0.004748399 s |
0.004632819000000001 s |
1.02 |
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/XLA |
0.0012423500000000001 s |
0.00122663 s |
1.01 |
ViT tiny [256, 256, 3, 4]/forward/TPU/DisableTransposeReshape |
0.00063924 s |
0.00062295 s |
1.03 |
ViT tiny [256, 256, 3, 4]/forward/TPU/Default |
0.00063086 s |
0.00060841 s |
1.04 |
FNO [64, 64, 1, 4]/backward/TPU/DefaultBeforeEnzyme |
0.0030050000000000003 s |
0.00301759 s |
1.00 |
FNO [64, 64, 1, 4]/forward/TPU/Default |
0.00111855 s |
0.00111681 s |
1.00 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/XLA |
0.00465067 s |
0.00453087 s |
1.03 |
FNO [64, 64, 1, 4]/forward/TPU/XLA |
0.00138557 s |
0.00139206 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/forward/TPU/DisableTransposeReshape |
0.00017470000000000002 s |
0.00018154 s |
0.96 |
FNO [64, 64, 1, 4]/backward/TPU/DisablePadAll |
0.003001 s |
0.003002799 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/forward/TPU/XLA |
0.00026143 s |
0.00028878000000000004 s |
0.91 |
FNO [64, 64, 1, 4]/forward/TPU/DisableScatterGatherPad |
0.00110544 s |
0.0011168900000000002 s |
0.99 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisablePadAfterEnzyme |
0.00476395 s |
0.00463045 s |
1.03 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DefaultAfterEnzyme |
0.00033756000000000004 s |
0.00036080000000000004 s |
0.94 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisablePadAll |
0.00479154 s |
0.00462816 s |
1.04 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DefaultBeforeEnzyme |
0.00475882 s |
0.00464774 s |
1.02 |
VGG11 bn=true [224, 224, 3, 4]/forward/TPU/DisableScatterGather |
0.00130991 s |
0.00129289 s |
1.01 |
ViT tiny [256, 256, 3, 4]/forward/TPU/DisableScatterGather |
0.00062874 s |
0.00061041 s |
1.03 |
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherBeforeEnzyme |
0.00301922 s |
0.0030071300000000002 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableTransposeReshapeBeforeEnzyme |
0.00033915000000000003 s |
0.00033028 s |
1.03 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherPadAfterEnzyme |
0.00033798000000000004 s |
0.00033678 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisablePadAll |
0.00034544 s |
0.00032839 s |
1.05 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableScatterGatherAll |
0.00033853 s |
0.00033048 s |
1.02 |
FNO [64, 64, 1, 4]/forward/TPU/DisableTransposeReshape |
0.0011446100000000001 s |
0.00114841 s |
1.00 |
FNO [64, 64, 1, 4]/backward/TPU/DefaultAfterEnzyme |
0.00287951 s |
0.0028636900000000003 s |
1.01 |
FNO [64, 64, 1, 4]/backward/TPU/DisableScatterGatherPadAfterEnzyme |
0.0029170000000000003 s |
0.0028570400000000004 s |
1.02 |
VGG11 bn=true [224, 224, 3, 4]/backward/TPU/DisableScatterGatherPadBeforeEnzyme |
0.0047691 s |
0.0046275100000000005 s |
1.03 |
FNO [64, 64, 1, 4]/forward/TPU/DisablePad |
0.0011034 s |
0.00111677 s |
0.99 |
FNO [64, 64, 1, 4]/backward/TPU/XLA |
0.00318736 s |
0.0031743400000000003 s |
1.00 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisableTransposeReshapeAll |
0.00034306 s |
0.00033025000000000003 s |
1.04 |
DeepONet ([64, 1024], [1, 128])/backward/TPU/DisablePadAfterEnzyme |
0.00033275000000000004 s |
0.0003396 s |
0.98 |
This comment was automatically generated by workflow using github-action-benchmark.
1a32cf5
to
70aa857
Compare
3120869
to
7588c64
Compare
@wsmoses this should be good to go on my end. I marked the acos tests as broken for now and opened an issue upstream: openxla/xla#31796 |
Uh oh!
There was an error while loading. Please reload this page.