@@ -177,8 +177,14 @@ function fill(v, ::Tuple{}; location=mlir_stacktrace("fill", @__FILE__, @__LINE_
177
177
return fill (v, Int[]; location)
178
178
end
179
179
180
- function fill (number:: TracedRNumber{T} , shape:: Vector{Int} ; location) where {T}
181
- return Base. fill (number, Tuple (shape))
180
+ function fill (
181
+ number:: TracedRNumber{T} ,
182
+ shape:: Vector{Int} ;
183
+ location= mlir_stacktrace (" fill" , @__FILE__ , @__LINE__ ),
184
+ ) where {T}
185
+ return broadcast_in_dim (
186
+ TracedRArray {T,0} ((), number. mlir_data, ()), Int64[], shape; location
187
+ )
182
188
end
183
189
184
190
for (T, mlir_func) in (
@@ -2746,4 +2752,95 @@ end
2746
2752
return TracedRArray {T,ndims(res)} ((), res, size (res))
2747
2753
end
2748
2754
2755
+ # Currently this is very simplistic and doesn't linearize/delinearize and supports only
2756
+ # a single argument (similar to how Julia's mapslices works)
2757
+ @noinline function batch (
2758
+ f:: F ,
2759
+ A:: TracedRArray{T,N} ,
2760
+ dims:: Vector{Int} ;
2761
+ location= mlir_stacktrace (" batch" , @__FILE__ , @__LINE__ ),
2762
+ ) where {F,T,N}
2763
+ sort! (dims)
2764
+
2765
+ # First we permute and make sure the batch dims are at the beginning
2766
+ batch_dims = Int64[i for i in 1 : N if i ∉ dims]
2767
+ permutation = zeros (Int64, N)
2768
+ for (i, d) in enumerate (batch_dims)
2769
+ permutation[i] = d
2770
+ end
2771
+ for (i, d) in enumerate (dims)
2772
+ permutation[i + length (batch_dims)] = d
2773
+ end
2774
+
2775
+ A = Ops. transpose (A, permutation; location)
2776
+
2777
+ sample_input = fill (T (0 ), [size (A, i) for i in (length (batch_dims) + 1 ): N]; location)
2778
+ # TODO : detect and forbid internal mutations
2779
+ mlir_fn_res = Reactant. TracedUtils. make_mlir_fn (
2780
+ f,
2781
+ (sample_input,),
2782
+ (),
2783
+ " unbatched_" * string (f),
2784
+ false ;
2785
+ args_in_result= :none ,
2786
+ do_transpose= false ,
2787
+ )
2788
+
2789
+ @assert ! mlir_fn_res. fnwrapped " Currently we don't support batching closures."
2790
+
2791
+ func = mlir_fn_res. f
2792
+ @assert MLIR. IR. nregions (func) == 1
2793
+
2794
+ result = only (mlir_fn_res. linear_results)
2795
+ batch_shape = [size (A, i) for i in 1 : length (batch_dims)]
2796
+
2797
+ if result isa TracedRArray
2798
+ @assert ndims (result) == ndims (sample_input)
2799
+ output_type = MLIR. IR. TensorType (
2800
+ vcat (batch_shape, collect (Int64, size (result))),
2801
+ MLIR. IR. Type (unwrapped_eltype (result)),
2802
+ )
2803
+ elseif result isa TracedRNumber
2804
+ output_type = MLIR. IR. TensorType (
2805
+ batch_shape, MLIR. IR. Type (unwrapped_eltype (result))
2806
+ )
2807
+ else
2808
+ error (" Unsupported result type $(typeof (result)) " )
2809
+ end
2810
+
2811
+ batched_result = batch ([A], [output_type], batch_shape; fn= func, location)[1 ]
2812
+
2813
+ if result isa TracedRNumber
2814
+ batched_result = Ops. reshape (
2815
+ batched_result, vcat (batch_shape, ones (Int64, ndims (sample_input))); location
2816
+ )
2817
+ end
2818
+
2819
+ return Ops. transpose (batched_result, invperm (permutation); location)
2820
+ end
2821
+
2822
+ @noinline function batch (
2823
+ inputs:: Vector{<:Union{<:TracedRArray,<:MLIR.IR.Value}} ,
2824
+ output_types:: Vector{<:MLIR.IR.Type} ,
2825
+ batch_shape:: Vector{Int64} ;
2826
+ fn,
2827
+ location= mlir_stacktrace (" batch" , @__FILE__ , @__LINE__ ),
2828
+ )
2829
+ op = MLIR. Dialects. enzyme. batch (
2830
+ [i isa TracedRArray ? i. mlir_data : i for i in inputs];
2831
+ outputs= output_types,
2832
+ fn= MLIR. IR. FlatSymbolRefAttribute (
2833
+ String (Reactant. TracedUtils. get_attribute_by_name (fn, " sym_name" ))
2834
+ ),
2835
+ batch_shape= MLIR. IR. DenseArrayAttribute (batch_shape),
2836
+ location,
2837
+ )
2838
+
2839
+ return [
2840
+ TracedRArray {MLIR.IR.julia_type(eltype(out_type)),ndims(out_type)} (
2841
+ (), MLIR. IR. result (op, i), size (out_type)
2842
+ ) for (i, out_type) in enumerate (output_types)
2843
+ ]
2844
+ end
2845
+
2749
2846
end # module Ops
0 commit comments