@@ -2968,4 +2968,130 @@ end
2968
2968
]
2969
2969
end
2970
2970
2971
"""
    batch_norm_inference(
        operand, scale, offset, mean, variance; epsilon, feature_index, location
    )

Emit a `stablehlo.batch_norm_inference` op and wrap its first result in a
`TracedRArray{T,N}` with the same size as `operand`.

# Arguments
- `operand`: the `N`-dimensional traced input.
- `scale`: traced vector of length `size(operand, feature_index)`, or `nothing`, in which
  case a vector filled with `T(1)` is used.
- `offset`: traced vector of the same length, or `nothing`, in which case a vector filled
  with `T(0)` is used.
- `mean`, `variance`: traced vectors whose length must equal
  `size(operand, feature_index)`.

# Keywords
- `epsilon`: converted with `Float32(epsilon)` before being handed to the op.
- `feature_index`: 1-based feature dimension of `operand`; the op receives
  `feature_index - 1`.
- `location`: MLIR location attached to the emitted ops.

# Throws
- `ArgumentError` if `feature_index` is not in `1:N`.
- `AssertionError` if any of the vector lengths do not match the feature dimension.
"""
@noinline function batch_norm_inference(
    operand::TracedRArray{T,N},
    scale::Union{TracedRArray{T,1},Nothing},
    offset::Union{TracedRArray{T,1},Nothing},
    mean::TracedRArray{T,1},
    variance::TracedRArray{T,1};
    epsilon,
    feature_index::Int64,
    location=mlir_stacktrace("batch_norm_inference", @__FILE__, @__LINE__),
) where {T,N}
    # Base's generic `size(A, d)` reports 1 for `d > ndims(A)`, so an out-of-range index
    # would otherwise slip through as `len == 1` instead of being rejected here.
    1 <= feature_index <= N || throw(
        ArgumentError("feature_index must be in 1:$(N), got $(feature_index)")
    )

    len = size(operand, feature_index)
    @assert length(mean) == length(variance) == len

    # Missing affine parameters default to the identity transform (scale 1, offset 0).
    if scale === nothing
        scale = fill(T(1), len; location)
    else
        @assert size(scale) == (len,)
    end

    if offset === nothing
        offset = fill(T(0), len; location)
    else
        @assert size(offset) == (len,)
    end

    return TracedRArray{T,N}(
        (),
        MLIR.IR.result(
            stablehlo.batch_norm_inference(
                operand.mlir_data,
                scale.mlir_data,
                offset.mlir_data,
                mean.mlir_data,
                variance.mlir_data;
                epsilon=Float32(epsilon),
                # The op takes a 0-based dimension index.
                feature_index=feature_index - 1,
                location,
            ),
            1,
        ),
        size(operand),
    )
end
3014
+
3015
"""
    batch_norm_training(operand, scale, offset; epsilon, feature_index, location)

Emit a `stablehlo.batch_norm_training` op and return its three results as traced arrays.

# Arguments
- `operand`: the `N`-dimensional traced input.
- `scale`: traced vector of length `size(operand, feature_index)`, or `nothing`, in which
  case a vector filled with `T(1)` is used.
- `offset`: traced vector of the same length, or `nothing`, in which case a vector filled
  with `T(0)` is used.

# Keywords
- `epsilon`: converted with `Float32(epsilon)` before being handed to the op.
- `feature_index`: 1-based feature dimension of `operand`; the op receives
  `feature_index - 1`.
- `location`: MLIR location attached to the emitted ops.

# Returns
A 3-tuple built from results 1, 2 and 3 of the op:
1. a `TracedRArray{T,N}` with the size of `operand`;
2. a `TracedRArray{T,1}` of length `size(operand, feature_index)`;
3. a `TracedRArray{T,1}` of the same length.

NOTE(review): per the StableHLO specification results 2 and 3 are the batch mean and
batch variance; confirm against the dialect bindings.

# Throws
- `ArgumentError` if `feature_index` is not in `1:N`.
- `AssertionError` if a provided `scale` or `offset` has the wrong size.
"""
@noinline function batch_norm_training(
    operand::TracedRArray{T,N},
    scale::Union{TracedRArray{T,1},Nothing},
    offset::Union{TracedRArray{T,1},Nothing};
    epsilon,
    feature_index::Int64,
    location=mlir_stacktrace("batch_norm_training", @__FILE__, @__LINE__),
) where {T,N}
    # Base's generic `size(A, d)` reports 1 for `d > ndims(A)`, so an out-of-range index
    # would otherwise slip through as `len == 1` instead of being rejected here.
    1 <= feature_index <= N || throw(
        ArgumentError("feature_index must be in 1:$(N), got $(feature_index)")
    )

    len = size(operand, feature_index)

    # Missing affine parameters default to the identity transform (scale 1, offset 0).
    if scale === nothing
        scale = fill(T(1), len; location)
    else
        @assert size(scale) == (len,)
    end

    if offset === nothing
        offset = fill(T(0), len; location)
    else
        @assert size(offset) == (len,)
    end

    batch_norm_train_op = stablehlo.batch_norm_training(
        operand.mlir_data,
        scale.mlir_data,
        offset.mlir_data;
        epsilon=Float32(epsilon),
        # The op takes a 0-based dimension index.
        feature_index=feature_index - 1,
        location,
    )

    return (
        TracedRArray{T,N}((), MLIR.IR.result(batch_norm_train_op, 1), size(operand)),
        TracedRArray{T,1}((), MLIR.IR.result(batch_norm_train_op, 2), (len,)),
        TracedRArray{T,1}((), MLIR.IR.result(batch_norm_train_op, 3), (len,)),
    )
end
3052
+
3053
"""
    batch_norm_grad(
        operand, scale, mean, variance, grad_output; epsilon, feature_index, location
    )

Emit a `stablehlo.batch_norm_grad` op and return `(grad_operand, grad_scale, grad_offset)`.

# Arguments
- `operand`: the `N`-dimensional traced input.
- `scale`: traced vector of length `size(operand, feature_index)`, or `nothing`, in which
  case a vector filled with `T(1)` is passed to the op.
- `mean`, `variance`: traced vectors whose length must equal
  `size(operand, feature_index)`.
- `grad_output`: traced array that must have the same size as `operand`.

# Keywords
- `epsilon`: converted with `Float32(epsilon)` before being handed to the op.
- `feature_index`: 1-based feature dimension of `operand`; the op receives
  `feature_index - 1`.
- `location`: MLIR location attached to the emitted ops.

# Returns
A 3-tuple `(grad_operand, grad_scale, grad_offset)` built from results 1, 2 and 3 of the
op. `grad_operand` is a `TracedRArray{T,N}` with the size of `operand`. When `scale` was
`nothing`, both `grad_scale` and `grad_offset` are returned as `nothing`; otherwise they
are `TracedRArray{T,1}` values of length `size(operand, feature_index)`.

# Throws
- `ArgumentError` if `feature_index` is not in `1:N`.
- `AssertionError` if any size check against `operand` fails.
"""
@noinline function batch_norm_grad(
    operand::TracedRArray{T,N},
    scale::Union{TracedRArray{T,1},Nothing},
    mean::TracedRArray{T,1},
    variance::TracedRArray{T,1},
    grad_output::TracedRArray{T,N};
    epsilon,
    feature_index::Int64,
    location=mlir_stacktrace("batch_norm_grad", @__FILE__, @__LINE__),
) where {T,N}
    # Base's generic `size(A, d)` reports 1 for `d > ndims(A)`, so an out-of-range index
    # would otherwise slip through as `len == 1` instead of being rejected here.
    1 <= feature_index <= N || throw(
        ArgumentError("feature_index must be in 1:$(N), got $(feature_index)")
    )

    len = size(operand, feature_index)
    @assert length(mean) == length(variance) == len
    @assert size(grad_output) == size(operand)

    # Remember whether the caller supplied `scale` before it is replaced by a default;
    # this decides whether the parameter gradients are returned at all.
    has_affine = scale !== nothing

    if !has_affine
        scale = fill(T(1), len; location)
    else
        @assert size(scale) == (len,)
    end

    batch_norm_grad_op = stablehlo.batch_norm_grad(
        operand.mlir_data,
        scale.mlir_data,
        mean.mlir_data,
        variance.mlir_data,
        grad_output.mlir_data;
        epsilon=Float32(epsilon),
        # The op takes a 0-based dimension index.
        feature_index=feature_index - 1,
        location,
    )

    grad_operand = TracedRArray{T,N}(
        (), MLIR.IR.result(batch_norm_grad_op, 1), size(operand)
    )
    grad_scale = TracedRArray{T,1}((), MLIR.IR.result(batch_norm_grad_op, 2), (len,))
    grad_offset = TracedRArray{T,1}((), MLIR.IR.result(batch_norm_grad_op, 3), (len,))

    return (
        grad_operand, has_affine ? grad_scale : nothing, has_affine ? grad_offset : nothing
    )
end
3096
+
2971
3097
end # module Ops
0 commit comments