
Commit f2a91bf

feat: more coverage for NNlib functions (#258)
* feat: use dynamic slicing
* feat: special case `gather!` for the most common cases
* feat: use `@trace` to implement softmax
* refactor: directly overload inplace conv routine from NNlib
* refactor: overload inplace pooling layers
* refactor: overload inplace batched matmul
* fix: reactant needs latest reactant core
* fix: temporarily avoid tracing in softmax and logsoftmax
1 parent 9d666f8 commit f2a91bf
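
In short: the conv, pooling, and batched-matmul overloads move from NNlib's out-of-place entry points (`conv`, `maxpool`, `meanpool`, `batched_mul`) to their in-place `!` counterparts, writing the StableHLO result into the preallocated traced output. A minimal usage sketch, assuming Reactant's usual `ConcreteRArray` and `@compile` entry points (which are not part of this diff):

```julia
using Reactant, NNlib

x = Reactant.ConcreteRArray(randn(Float32, 16, 8))

# While tracing, NNlib.softmax ends up in the
# NNlib.softmax!(::TracedRArray, ...) overload changed below.
f = Reactant.@compile NNlib.softmax(x)
f(x)
```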

File tree

5 files changed (+85 -60 lines)


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ LinearAlgebra = "1.10"
 NNlib = "0.9"
 OrderedCollections = "1"
 Preferences = "1.4"
-ReactantCore = "0.1"
+ReactantCore = "0.1.1"
 Reactant_jll = "0.0.24"
 Scratch = "1.2"
 Statistics = "1.10"

ext/ReactantNNlibExt.jl

Lines changed: 73 additions & 53 deletions
@@ -3,6 +3,7 @@ module ReactantNNlibExt
 using NNlib
 using Reactant:
     Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
+using ReactantCore: @trace
 using LinearAlgebra: LinearAlgebra, triu
 
 for (jlop, hloop) in (
@@ -20,38 +21,46 @@ for (jlop, hloop) in (
     end
 end
 
-# TODO handle non finite cases
 function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
     max_ = NNlib.fast_maximum(x; dims)
-    #if all(isfinite, max_)
-    @fastmath out .= exp.(x .- max_)
-    #else
-    #    _zero, _one, _inf = T(0), T(1), T(Inf)
-    #    @fastmath @. out = ifelse(isequal(max_,_inf), ifelse(isequal(x,_inf), _one, _zero), exp(x - max_))
-    #end
+    # XXX: Once reverse mode of if is properly supported, we can make it @trace
+    # zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
+    # one_num = Reactant.promote_to(TracedRNumber{T}, 1)
+    # @trace if all(isfinite, max_)
+    @. out = exp(x - max_)
+    # else
+    #     cond = max_ .== Inf
+    #     true_pred = ifelse.(x .== Inf, one_num, zero_num)
+    #     @. out = ifelse(cond, true_pred, exp(x - max_))
+    # end
     tmp = dims isa Colon ? sum(out) : sum!(max_, out)
-    return out ./= tmp
+    out ./= tmp
+    return out
 end
 
 function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
     max_ = NNlib.fast_maximum(x; dims)
-    # if all(isfinite, max_)
-    @fastmath out .= x .- max_
+    # XXX: Once reverse mode of if is properly supported, we can make it @trace
+    # inf_num = Reactant.promote_to(TracedRNumber{T}, Inf)
+    # zero_num = Reactant.promote_to(TracedRNumber{T}, 0)
+    # @trace if all(isfinite, max_)
+    @. out = x - max_
     # else
-    #     _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
-    #     @. out = ifelse(
-    #         isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
-    #     )
+    #     cond = max_ .== Inf
+    #     true_pred = ifelse.(x .== Inf, zero_num, -inf_num)
+    #     @. out = ifelse(cond, true_pred, x - max_)
     # end
     @fastmath log_ = log.(sum(exp, out; dims))
-    return out .-= log_
+    out .-= log_
+    return out
 end
 
-function NNlib.conv(
-    x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
+function NNlib.conv!(
+    y::TracedRArray{T,N}, x::AnyTracedRArray, W::AnyTracedRArray, cdims::DenseConvDims
 ) where {T,N}
-    x = materialize_traced_array(x)
-    W = materialize_traced_array(W)
+    # StableHLO expects matching element types
+    x = T.(materialize_traced_array(x))
+    W = T.(materialize_traced_array(W))
 
     kernel_size = NNlib.kernel_size(cdims)
     padding = NNlib.padding(cdims)
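
Both normalization overloads above use the standard max-shift trick: subtracting the per-slice maximum before exponentiating prevents `exp` from overflowing while leaving the result unchanged, since softmax is invariant to adding a constant. A plain-Julia reference for the semantics (a sketch only; it ignores the non-finite handling that is still commented out above):

```julia
# Reference semantics of the traced softmax!/logsoftmax! overloads.
function softmax_ref(x; dims=1)
    max_ = maximum(x; dims)            # shift so the largest exponent is 0
    out = exp.(x .- max_)
    return out ./ sum(out; dims)
end

function logsoftmax_ref(x; dims=1)
    shifted = x .- maximum(x; dims)
    return shifted .- log.(sum(exp, shifted; dims))
end
```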
@@ -77,33 +86,31 @@ function NNlib.conv(
         pl, pr = padding[2i - 1], padding[2i]
         d = dilation[i]
         s = stride[i]
-
-        (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
+        return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
     end
     output_batch_dim = input_batch_dim
     output_feature_dim = input_feature_dim
     output_spatial_dims = input_spatial_dims
 
-    output_shape = (output_spatial_shapes..., size(W, kernel_output_dim), size(x, N))
-
-    dimension_numbers = """
-        #stablehlo.conv<raw
-          input_batch_dimension = $(input_batch_dim - 1),
-          input_feature_dimension = $(input_feature_dim - 1),
-          input_spatial_dimensions = [$(join(input_spatial_dims .- 1, ", "))],
-          kernel_output_feature_dimension = $(kernel_output_dim - 1),
-          kernel_input_feature_dimension = $(kernel_input_dim - 1),
-          kernel_spatial_dimensions = [$(join(kernel_spatial_dims .- 1, ", "))],
-          output_batch_dimension = $( output_batch_dim - 1 ),
-          output_feature_dimension = $( output_feature_dim - 1),
-          output_spatial_dimensions = [$(join(output_spatial_dims .- 1, ", "))],
-        >"""
-    dimension_numbers = parse(Reactant.MLIR.IR.Attribute, dimension_numbers)
+    #! format: off
+    dimension_numbers = MLIR.API.stablehloConvDimensionNumbersGet(
+        MLIR.IR.context(),
+        Int64(input_batch_dim - 1),
+        Int64(input_feature_dim - 1),
+        length(input_spatial_dims), Int64[i - 1 for i in input_spatial_dims],
+        Int64(kernel_input_dim - 1),
+        Int64(kernel_output_dim - 1),
+        length(kernel_spatial_dims), Int64[i - 1 for i in kernel_spatial_dims],
+        Int64(output_batch_dim - 1),
+        Int64(output_feature_dim - 1),
+        length(output_spatial_dims), Int64[i - 1 for i in output_spatial_dims],
+    )
+    #! format: on
 
     padding = Reactant.MLIR.IR.DenseElementsAttribute(
         reshape(collect(padding), (num_spatial_dims, 2))
     )
-    result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))
+    result_type = Reactant.MLIR.IR.TensorType(size(y), Reactant.MLIR.IR.Type(T))
 
     weight = W.mlir_data
     if !flipkernel
@@ -126,8 +133,8 @@ function NNlib.conv(
         feature_group_count,
         batch_group_count=1,
     )
-
-    return TracedRArray{T,N}((), Reactant.MLIR.IR.result(conv), output_shape)
+    y.mlir_data = Reactant.MLIR.IR.result(conv)
+    return y
 end
 
 function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
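
Two things happen in this hunk: the result shape is now taken from the preallocated `y` instead of being recomputed, and the textual `#stablehlo.conv<raw ...>` attribute is built through the StableHLO C API instead of being parsed from a string. The `dim - 1` arithmetic translates Julia's 1-based, batch-last layout into StableHLO's 0-based dimension numbers. An illustrative sketch for a 2D convolution (the values here are assumptions for a 4D WHCN input, not code from the diff):

```julia
# How NNlib's (spatial..., C, N) layout maps to zero-based StableHLO
# dimension numbers for a 4D input (2D convolution).
N = 4
input_spatial_dims = 1:(N - 2)    # W, H
input_feature_dim = N - 1         # C
input_batch_dim = N               # batch comes last in Julia
(
    batch = input_batch_dim - 1,                 # 3
    feature = input_feature_dim - 1,             # 2
    spatial = collect(input_spatial_dims .- 1),  # [0, 1]
)
```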
@@ -198,27 +205,39 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
     return TracedRArray{T,N}((), Reactant.MLIR.IR.result(reduction), size(result_type))
 end
 
-function NNlib.maxpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
-    return reduce_window(
-        Reactant.MLIR.Dialects.stablehlo.maximum, x, pdims; init=typemin(T)
-    )
+function NNlib.maxpool!(
+    y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
+) where {T}
+    y.mlir_data =
+        reduce_window(
+            Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
+        ).mlir_data
+    return y
 end
 
-function NNlib.meanpool(x::AnyTracedRArray{T}, pdims::NNlib.PoolDims) where {T}
-    numel = prod(NNlib.kernel_size(pdims))
-    return reduce_window(Reactant.MLIR.Dialects.stablehlo.add, x, pdims; init=zero(T)) ./
-           T(numel)
+function NNlib.meanpool!(
+    y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
+) where {T}
+    res = reduce_window(Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T))
+    y.mlir_data = (res ./ T(prod(NNlib.kernel_size(pdims)))).mlir_data
+    return y
 end
 
 NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
-NNlib.batched_adjoint(x::AnyTracedRArray{<:Real,3}) = NNlib.batched_transpose(x)
+function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
+    y = permutedims(x, (2, 1, 3))
+    conj!(y)
+    return y
+end
 
-function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) where {T}
+function NNlib.batched_mul!(
+    res::TracedRArray{T1,3}, x::AnyTracedRArray{T2,3}, y::AnyTracedRArray{T3,3}
+) where {T1,T2,T3}
     if (size(x, 3) != size(y, 3) && size(x, 3) != 1 && size(y, 3) != 1) ||
         (size(x, 2) != size(y, 1))
         throw(
             DimensionMismatch(
-                lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_matmul.",
+                lazy"size(x) = $(size(x)), size(y) = $(size(y)) inconsistent for batched_mul.",
             ),
         )
     end
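
Both pooling overloads funnel through the shared `reduce_window` helper: max pooling reduces each window with `stablehlo.maximum` starting from `typemin(T)`, and mean pooling is a sum reduction divided by the window size. They now target NNlib's in-place entry points, which outside of tracing look like this (plain NNlib sketch; the shapes are illustrative):

```julia
using NNlib

x = randn(Float32, 8, 8, 3, 2)                     # W×H×C×N
pdims = NNlib.PoolDims(x, 2)                       # 2×2 windows, stride 2
y = similar(x, NNlib.output_size(pdims)..., 3, 2)  # (4, 4, 3, 2)
NNlib.maxpool!(y, x, pdims)                        # the method overloaded above
```

The new `batched_adjoint` method materializes `permutedims` plus `conj!`, so it is no longer restricted to real element types.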
@@ -227,7 +246,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
 
     B = max(size(x, 1), size(y, 1))
     out_shape = (B, size(x, 2), size(y, 3))
-    resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(x.mlir_data)))
+    resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data)))
 
     if size(x, 1) != size(y, 1)
         if size(x, 1) == 1
@@ -244,7 +263,7 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
     prec = MLIR.IR.Attribute(
         MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
     )
-    res = TracedRArray{T,3}(
+    tmp = TracedRArray{T1,3}(
         (),
         MLIR.IR.result(
             MLIR.Dialects.stablehlo.dot_general(
@@ -258,7 +277,8 @@ function NNlib.batched_mul(x::AnyTracedRArray{T,3}, y::AnyTracedRArray{T,3}) whe
         ),
         size(resty),
     )
-    return permutedims(res, (2, 3, 1))
+    res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data
+    return res
 end
 
 function NNlib.pad_constant(
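
The in-place `batched_mul!` keeps NNlib's broadcasting rule for the batch dimension: a batch of size 1 on either operand is reused against every slice of the other. The traced implementation emits a single `stablehlo.dot_general` with the batch dimension leading and permutes back at the end, staging the result through `tmp` before writing `res.mlir_data`. For reference, the semantics in plain NNlib:

```julia
using NNlib

x = randn(Float32, 4, 3, 5)   # five 4×3 matrices
y = randn(Float32, 3, 2, 1)   # one 3×2 matrix, broadcast across the batch
z = NNlib.batched_mul(x, y)   # size (4, 2, 5)
@assert z[:, :, 2] ≈ x[:, :, 2] * y[:, :, 1]
```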

lib/ReactantCore/Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "ReactantCore"
 uuid = "a3311ec8-5e00-46d5-b541-4f83e724a433"
 authors = ["William Moses <[email protected]>", "Valentin Churavy <[email protected]>", "Sergio Sánchez Ramírez <[email protected]>", "Paul Berg <[email protected]>", "Avik Pal <[email protected]>"]
-version = "0.1.0"
+version = "0.1.1"
 
 [deps]
 ExpressionExplorer = "21656369-7473-754a-2065-74616d696c43"

lib/ReactantCore/src/ReactantCore.jl

Lines changed: 5 additions & 3 deletions
@@ -15,6 +15,10 @@ end
 
 MissingTracedValue() = MissingTracedValue(())
 
+const SPECIAL_SYMBOLS = [
+    :(:), :nothing, :missing, :Inf, :Inf16, :Inf32, :Inf64, :Base, :Core
+]
+
 # Code generation
 """
     @trace <expr>
@@ -79,7 +83,7 @@ You need to ensure that all branches have the same type.
 
 ### Certain Symbols are Reserved
 
-Symbols like `nothing`, `missing` and `:` are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
+Symbols like $(SPECIAL_SYMBOLS) are not allowed as variables in `@trace` expressions. While certain cases might work but these are not guaranteed to work. For
 example, the following will not work:
 
 ```julia
@@ -299,6 +303,4 @@ function error_if_return(expr)
     end
 end
 
-const SPECIAL_SYMBOLS = [:(:), :nothing, :missing]
-
 end
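
Moving `SPECIAL_SYMBOLS` above the docstring is what makes the `$(SPECIAL_SYMBOLS)` interpolation work: docstrings are evaluated when they are attached, so any interpolated constant must already be defined. A self-contained sketch of the pattern (the names here are hypothetical, not from the diff):

```julia
const GREETINGS = [:hello, :hi]   # must be defined before the docstring below

"""
    greet()

Supported greetings: $(GREETINGS)
"""
greet() = first(GREETINGS)
```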

src/TracedRArray.jl

Lines changed: 5 additions & 2 deletions
@@ -538,8 +538,10 @@ function Base.mapreduce(
         dims = [dims]
     end
 
+    op_in_T = Core.Compiler.return_type(f, Tuple{T})
+
     if isnothing(init)
-        init = Base.reduce_empty(Base.BottomRF(op), Core.Compiler.return_type(f, Tuple{T}))
+        init = Base.reduce_empty(Base.BottomRF(op), op_in_T)
     else
         init = init::T
     end
@@ -561,7 +563,8 @@ function Base.mapreduce(
     fnbody = MLIR.IR.Block(in_tys, [MLIR.IR.Location() for arg in in_tys])
 
     args = (
-        TracedRNumber{T}((), MLIR.IR.argument(fnbody, i)) for (i, ty) in enumerate(in_tys)
+        TracedRNumber{op_in_T}((), MLIR.IR.argument(fnbody, i)) for
+        (i, ty) in enumerate(in_tys)
     )
 
     res = MLIR.IR.block!(fnbody) do
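
Hoisting `op_in_T` matters because the mapped function `f` can change the element type, so the arguments of the MLIR reduction block must use the mapped type rather than the input eltype `T`. A quick illustration of what `Core.Compiler.return_type` computes here:

```julia
# mapreduce(abs2, +, x) over complex inputs reduces Float32 values,
# not ComplexF32 ones.
f, T = abs2, ComplexF32
op_in_T = Core.Compiler.return_type(f, Tuple{T})     # Float32
init = Base.reduce_empty(Base.BottomRF(+), op_in_T)  # 0.0f0
```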
