@@ -3,6 +3,7 @@ module ReactantNNlibExt
3
3
using NNlib
4
4
using Reactant:
5
5
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
6
+ using ReactantCore: @trace
6
7
using LinearAlgebra: LinearAlgebra, triu
7
8
8
9
for (jlop, hloop) in (
@@ -20,38 +21,46 @@ for (jlop, hloop) in (
20
21
end
21
22
end
22
23
23
- # TODO handle non finite cases
24
24
# In-place softmax for traced arrays: fills `out` with `exp(x - m)` divided by its sum,
# where `m` is whatever `NNlib.fast_maximum(x; dims)` returns, and returns `out`.
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
    shift = NNlib.fast_maximum(x; dims)

    # XXX: Only the all-finite path is emitted for now. Once reverse mode of `if` is
    # properly supported, this can become `@trace if all(isfinite, shift)` with an
    # else-branch that handles an infinite maximum:
    #     zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
    #     one_num = Reactant.promote_to(TracedRNumber{T}, 1)
    #     cond = shift .== Inf
    #     true_pred = ifelse.(x .== Inf, one_num, zero_num)
    #     @. out = ifelse(cond, true_pred, exp(x - shift))
    out .= exp.(x .- shift)

    # For non-Colon `dims`, `sum!` reuses `shift` as the destination for the totals.
    if dims isa Colon
        total = sum(out)
    else
        total = sum!(shift, out)
    end
    out ./= total
    return out
end
35
40
36
41
# In-place log-softmax for traced arrays: writes `x - m` into `out`, where `m` is the
# result of `NNlib.fast_maximum(x; dims)`, then subtracts `log(sum(exp, out; dims))`
# and returns `out`.
function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
    shift = NNlib.fast_maximum(x; dims)

    # XXX: Only the all-finite path is emitted for now. Once reverse mode of `if` is
    # properly supported, this can become `@trace if all(isfinite, shift)` with an
    # else-branch that handles an infinite maximum:
    #     inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
    #     zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
    #     cond = shift .== Inf
    #     true_pred = ifelse.(x .== Inf, zero_num, -inf_num)
    #     @. out = ifelse(cond, true_pred, x - shift)
    out .= x .- shift

    @fastmath lse = log.(sum(exp, out; dims))
    out .= out .- lse
    return out
end
49
57
50
- function NNlib. conv (
51
- x :: AnyTracedRArray {T,N} , W:: AnyTracedRArray{T} , cdims:: DenseConvDims
58
+ function NNlib. conv! (
59
+ y :: TracedRArray {T,N} , x :: AnyTracedRArray , W:: AnyTracedRArray , cdims:: DenseConvDims
52
60
) where {T,N}
53
- x = materialize_traced_array (x)
54
- W = materialize_traced_array (W)
61
+ # StableHLO expects matching element types
62
+ x = T .(materialize_traced_array (x))
63
+ W = T .(materialize_traced_array (W))
55
64
56
65
kernel_size = NNlib. kernel_size (cdims)
57
66
padding = NNlib. padding (cdims)
@@ -77,33 +86,31 @@ function NNlib.conv(
77
86
pl, pr = padding[2 i - 1 ], padding[2 i]
78
87
d = dilation[i]
79
88
s = stride[i]
80
-
81
- (size (x, i) + pl + pr - d * (K - 1 ) - 1 ) ÷ s + 1
89
+ return (size (x, i) + pl + pr - d * (K - 1 ) - 1 ) ÷ s + 1
82
90
end
83
91
output_batch_dim = input_batch_dim
84
92
output_feature_dim = input_feature_dim
85
93
output_spatial_dims = input_spatial_dims
86
94
87
- output_shape = (output_spatial_shapes... , size (W, kernel_output_dim), size (x, N))
88
-
89
- dimension_numbers = """
90
- #stablehlo.conv<raw
91
- input_batch_dimension = $(input_batch_dim - 1 ) ,
92
- input_feature_dimension = $(input_feature_dim - 1 ) ,
93
- input_spatial_dimensions = [$(join (input_spatial_dims .- 1 , " , " )) ],
94
- kernel_output_feature_dimension = $(kernel_output_dim - 1 ) ,
95
- kernel_input_feature_dimension = $(kernel_input_dim - 1 ) ,
96
- kernel_spatial_dimensions = [$(join (kernel_spatial_dims .- 1 , " , " )) ],
97
- output_batch_dimension = $( output_batch_dim - 1 ) ,
98
- output_feature_dimension = $( output_feature_dim - 1 ) ,
99
- output_spatial_dimensions = [$(join (output_spatial_dims .- 1 , " , " )) ],
100
- >"""
101
- dimension_numbers = parse (Reactant. MLIR. IR. Attribute, dimension_numbers)
95
+ # ! format: off
96
+ dimension_numbers = MLIR. API. stablehloConvDimensionNumbersGet (
97
+ MLIR. IR. context (),
98
+ Int64 (input_batch_dim - 1 ),
99
+ Int64 (input_feature_dim - 1 ),
100
+ length (input_spatial_dims), Int64[i - 1 for i in input_spatial_dims],
101
+ Int64 (kernel_input_dim - 1 ),
102
+ Int64 (kernel_output_dim - 1 ),
103
+ length (kernel_spatial_dims), Int64[i - 1 for i in kernel_spatial_dims],
104
+ Int64 (output_batch_dim - 1 ),
105
+ Int64 (output_feature_dim - 1 ),
106
+ length (output_spatial_dims), Int64[i - 1 for i in output_spatial_dims],
107
+ )
108
+ # ! format: on
102
109
103
110
padding = Reactant. MLIR. IR. DenseElementsAttribute (
104
111
reshape (collect (padding), (num_spatial_dims, 2 ))
105
112
)
106
- result_type = Reactant. MLIR. IR. TensorType (output_shape , Reactant. MLIR. IR. Type (T))
113
+ result_type = Reactant. MLIR. IR. TensorType (size (y) , Reactant. MLIR. IR. Type (T))
107
114
108
115
weight = W. mlir_data
109
116
if ! flipkernel
@@ -126,8 +133,8 @@ function NNlib.conv(
126
133
feature_group_count,
127
134
batch_group_count= 1 ,
128
135
)
129
-
130
- return TracedRArray {T,N} ((), Reactant . MLIR . IR . result (conv), output_shape)
136
+ y . mlir_data = Reactant . MLIR . IR . result (conv)
137
+ return y
131
138
end
132
139
133
140
function reduce_window (f, x:: AnyTracedRArray{T,N} , pdims; init) where {T,N}
@@ -198,27 +205,39 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
198
205
return TracedRArray {T,N} ((), Reactant. MLIR. IR. result (reduction), size (result_type))
199
206
end
200
207
201
# In-place max pooling for traced arrays. `x` is converted element-wise to `y`'s
# element type `T`, reduced with `stablehlo.maximum` via `reduce_window`, and the
# resulting MLIR value is stored in `y`, which is returned.
function NNlib.maxpool!(
    y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
    pooled = reduce_window(
        Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
    )
    y.mlir_data = pooled.mlir_data
    return y
end
206
217
207
# In-place mean pooling for traced arrays. `x` is converted element-wise to `y`'s
# element type `T`, summed per window with `stablehlo.add` via `reduce_window`, and
# divided by the number of elements in the kernel. The resulting MLIR value is stored
# in `y`, which is returned.
function NNlib.meanpool!(
    y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
    window_sum = reduce_window(
        Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T)
    )
    window_len = T(prod(NNlib.kernel_size(pdims)))
    averaged = window_sum ./ window_len
    y.mlir_data = averaged.mlir_data
    return y
end
212
225
213
226
# Batched transpose of a rank-3 traced array: swaps the first two dimensions and
# leaves the third in place.
function NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T}
    return permutedims(x, (2, 1, 3))
end
214
# Batched adjoint of a rank-3 traced array: swaps the first two dimensions, then
# conjugates the permuted array in place and returns it.
function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
    swapped = permutedims(x, (2, 1, 3))
    conj!(swapped)
    return swapped
end
215
232
216
- function NNlib. batched_mul (x:: AnyTracedRArray{T,3} , y:: AnyTracedRArray{T,3} ) where {T}
233
+ function NNlib. batched_mul! (
234
+ res:: TracedRArray{T1,3} , x:: AnyTracedRArray{T2,3} , y:: AnyTracedRArray{T3,3}
235
+ ) where {T1,T2,T3}
217
236
if (size (x, 3 ) != size (y, 3 ) && size (x, 3 ) != 1 && size (y, 3 ) != 1 ) ||
218
237
(size (x, 2 ) != size (y, 1 ))
219
238
throw (
220
239
DimensionMismatch (
221
- lazy " size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul ." ,
240
+ lazy " size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_mul ." ,
222
241
),
223
242
)
224
243
end
@@ -227,7 +246,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
227
246
228
247
B = max (size (x, 1 ), size (y, 1 ))
229
248
out_shape = (B, size (x, 2 ), size (y, 3 ))
230
- resty = MLIR. IR. TensorType (out_shape, eltype (MLIR. IR. type (x . mlir_data)))
249
+ resty = MLIR. IR. TensorType (out_shape, eltype (MLIR. IR. type (res . mlir_data)))
231
250
232
251
if size (x, 1 ) != size (y, 1 )
233
252
if size (x, 1 ) == 1
@@ -244,7 +263,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
244
263
prec = MLIR. IR. Attribute (
245
264
MLIR. API. stablehloPrecisionAttrGet (MLIR. IR. context (), " DEFAULT" )
246
265
)
247
- res = TracedRArray {T ,3} (
266
+ tmp = TracedRArray {T1 ,3} (
248
267
(),
249
268
MLIR. IR. result (
250
269
MLIR. Dialects. stablehlo. dot_general (
@@ -258,7 +277,8 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
258
277
),
259
278
size (resty),
260
279
)
261
- return permutedims (res, (2 , 3 , 1 ))
280
+ res. mlir_data = permutedims (tmp, (2 , 3 , 1 )). mlir_data
281
+ return res
262
282
end
263
283
264
284
function NNlib. pad_constant (
0 commit comments