
Commit 3a0710d

feat: add dynamic_update_slice_const_prop pass + tril + triu (#334)
* feat: add dynamic_update_slice_const_prop pass
* refactor: move linear algebra overloads to a different file
* feat: add triu and tril impl
* refactor: minimize batch_op
* feat: add Ops.compare
* refactor: use ops in base dispatches
* refactor: move linear algebra tests
* fix: tril defn and inplace ops
* test: add inplace tests
1 parent 1bb0000 commit 3a0710d

File tree

9 files changed: +183 −119 lines

ext/ReactantNNlibExt.jl

Lines changed: 0 additions & 15 deletions
@@ -298,21 +298,6 @@ function NNlib.pad_constant(
     return TracedRArray{T,N}((), res, size(MLIR.IR.type(res)))
 end

-function NNlib.make_causal_mask(x::AnyTracedRArray; dims::Int=2)
-    len = size(x, dims)
-    # directly generating booleans was causing an incorrect constant attribute generation,
-    # but the optimized IR removes the type cast so we are probably ok
-    mask = MLIR.IR.DenseElementsAttribute(collect(triu(fill(1, (len, len)))))
-    return Reactant.promote_to(
-        TracedRArray{Bool,2},
-        TracedRArray{Int,2}(
-            (),
-            MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=mask), 1),
-            (len, len),
-        ),
-    )
-end
-
 # XXX: reevaluate this manual optimization once
 # https://github.com/EnzymeAD/Enzyme-JAX/issues/164 is handled
 function NNlib.gather!(

src/Compiler.jl

Lines changed: 1 addition & 0 deletions
@@ -245,6 +245,7 @@ const opt_passes::String = join(
         "pad_dot_general<1>(1)",
         "if_inline<1>",
         "if_to_select<1>",
+        "dynamic_update_slice_const_prop",
     ],
     ';',
) *
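The new entry extends the semicolon-joined opt_passes pipeline that Reactant hands to the Enzyme-JAX optimizer; judging by the name, the pass constant-propagates stablehlo.dynamic_update_slice ops whose operands are known at compile time. As a hedged sketch of the kind of traced code that can produce such ops (write_block! is hypothetical, and whether a given setindex! lowers to a dynamic-update-slice depends on the tracer):

# Hypothetical illustration: an indexed assignment with constant indices
# typically traces to a dynamic-update-slice, which this pass can fold
# when the operands are compile-time constants.
function write_block!(x)
    x[1:2, 1:2] = ones(2, 2)
    return x
end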

src/Ops.jl

Lines changed: 32 additions & 0 deletions
@@ -1014,4 +1014,36 @@ function select(
     return TracedRNumber{T}((), res)
 end

+# comparison
+function compare(
+    lhs::Union{TracedRArray{T},TracedRNumber{T}},
+    rhs::Union{TracedRArray{T},TracedRNumber{T}};
+    comparison_direction::String,
+    compare_type=nothing,
+    location=mlir_stacktrace("compare", @__FILE__, @__LINE__),
+) where {T}
+    @assert comparison_direction in ("EQ", "NE", "GE", "GT", "LE", "LT")
+    @assert size(lhs) == size(rhs)
+    if lhs isa TracedRNumber
+        @assert rhs isa TracedRNumber
+    else
+        @assert rhs isa TracedRArray
+    end
+
+    res = MLIR.IR.result(
+        MLIR.Dialects.stablehlo.compare(
+            lhs.mlir_data,
+            rhs.mlir_data;
+            comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
+                MLIR.IR.context(), comparison_direction
+            ),
+            compare_type,
+            location,
+        ),
+        1,
+    )
+    lhs isa TracedRNumber && return TracedRNumber{Bool}((), res)
+    return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
+end
+
 end
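Ops.compare centralizes all six StableHLO comparison directions behind a single helper that returns a Bool-typed traced value of the same shape as its inputs. A minimal usage sketch, assuming inputs prepared with Reactant.to_rarray and compiled with @jit as in this repo's tests (the function f is illustrative, not part of the commit):

using Reactant

# Elementwise x < y via the new helper; lowers to stablehlo.compare.
f(x, y) = Reactant.Ops.compare(x, y; comparison_direction="LT")

x_ra = Reactant.to_rarray(rand(4))
y_ra = Reactant.to_rarray(rand(4))
@jit f(x_ra, y_ra)  # Bool-element traced result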

src/Reactant.jl

Lines changed: 3 additions & 0 deletions
@@ -98,8 +98,11 @@ include("utils.jl")
 include("ConcreteRArray.jl")
 include("TracedRNumber.jl")
 include("TracedRArray.jl")
+
 include("Ops.jl")
+
+include("linear_algebra.jl")

 const TracedType = Union{TracedRArray,TracedRNumber,MissingTracedValue}

 include("ControlFlow.jl")

src/TracedRArray.jl

Lines changed: 0 additions & 89 deletions
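The three LinearAlgebra.mul! overloads below are removed here unchanged; they reappear verbatim in the new src/linear_algebra.jl further down.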
@@ -415,95 +415,6 @@ for (jlop, hloop, hlocomp, merge) in
 end
 end

-function LinearAlgebra.mul!(
-    @nospecialize(C::TracedRArray{T1,1}),
-    @nospecialize(A::AnyTracedRArray{T2,2}),
-    @nospecialize(B::AnyTracedRArray{T3,1}),
-    α::Number=true,
-    β::Number=false,
-) where {T1,T2,T3}
-    # TODO: the reshape operations are not getting optimized; we should directly call dot_general
-    rC = reshape(C, :, 1)
-    LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β)
-    C.mlir_data = get_mlir_data(vec(rC))
-    return C
-end
-
-function LinearAlgebra.mul!(
-    @nospecialize(C::TracedRArray{T1,2}),
-    @nospecialize(A::AnyTracedRArray{T2,2}),
-    @nospecialize(B::AnyTracedRArray{T3,1}),
-    α::Number=true,
-    β::Number=false,
-) where {T1,T2,T3}
-    LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β)
-    return C
-end
-
-function LinearAlgebra.mul!(
-    @nospecialize(C::TracedRArray{T1,2}),
-    @nospecialize(A::AnyTracedRArray{T2,2}),
-    @nospecialize(B::AnyTracedRArray{T3,2}),
-    α::Number=true,
-    β::Number=false,
-) where {T1,T2,T3}
-    if size(C) != (size(A, 1), size(B, 2))
-        throw(
-            DimensionMismatch(
-                "C has size $(size(C)), A has size $(size(A)), B has size $(size(B))"
-            ),
-        )
-    end
-    if size(A, 2) != size(B, 1)
-        throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))"))
-    end
-    resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1))
-    dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
-        MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0]
-    )
-    prec = MLIR.IR.Attribute(
-        MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
-    )
-    precar = MLIR.IR.Attribute([prec, prec])
-    res = MLIR.IR.result(
-        MLIR.Dialects.stablehlo.dot_general(
-            get_mlir_data(A),
-            get_mlir_data(B);
-            result_0=resty,
-            dot_dimension_numbers=dot_dimension_numbers,
-            precision_config=precar,
-        ),
-        1,
-    )
-    if iszero(β)
-        if isone(α)
-            C.mlir_data = res
-        else
-            C.mlir_data = MLIR.IR.result(
-                MLIR.Dialects.stablehlo.multiply(
-                    res, broadcast_to_size(T1(α), size(C)).mlir_data
-                ),
-                1,
-            )
-        end
-    else
-        α_res = MLIR.IR.result(
-            MLIR.Dialects.stablehlo.multiply(
-                res, broadcast_to_size(T1(α), size(C)).mlir_data
-            ),
-            1,
-        )
-        β_C = MLIR.IR.result(
-            MLIR.Dialects.stablehlo.multiply(
-                C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data
-            ),
-            1,
-        )
-        C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1)
-    end
-    return C
-end
-
 function Enzyme.Compiler.active_reg_inner(
     ::Type{TracedRArray{T,N}},
     seen::ST,

src/TracedRNumber.jl

Lines changed: 1 addition & 13 deletions
@@ -151,19 +151,7 @@ for (jlop, hloop, hlocomp) in (
    function $(jlop)(
        @nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
    ) where {T}
-        return TracedRNumber{Bool}(
-            (),
-            MLIR.IR.result(
-                MLIR.Dialects.stablehlo.$(hloop)(
-                    lhs.mlir_data,
-                    rhs.mlir_data;
-                    comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
-                        MLIR.IR.context(), $hlocomp
-                    ),
-                ),
-                1,
-            ),
-        )
+        return Ops.compare(lhs, rhs; comparison_direction=$(hlocomp))
    end

    function $(jlop)(@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs)) where {T}
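With this refactor, the scalar comparison operators generated by this loop all route through the new Ops.compare rather than emitting stablehlo.compare inline, so scalars and arrays now share one lowering path. A one-line sketch of the effect (g is illustrative, not part of the commit):

g(a, b) = a < b  # on TracedRNumbers, now dispatches to Ops.compare and yields a TracedRNumber{Bool}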

src/linear_algebra.jl

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
+function LinearAlgebra.mul!(
+    @nospecialize(C::TracedRArray{T1,1}),
+    @nospecialize(A::AnyTracedRArray{T2,2}),
+    @nospecialize(B::AnyTracedRArray{T3,1}),
+    α::Number=true,
+    β::Number=false,
+) where {T1,T2,T3}
+    # TODO: the reshape operations are not getting optimized; we should directly call dot_general
+    rC = reshape(C, :, 1)
+    LinearAlgebra.mul!(rC, A, reshape(B, :, 1), α, β)
+    C.mlir_data = get_mlir_data(vec(rC))
+    return C
+end
+
+function LinearAlgebra.mul!(
+    @nospecialize(C::TracedRArray{T1,2}),
+    @nospecialize(A::AnyTracedRArray{T2,2}),
+    @nospecialize(B::AnyTracedRArray{T3,1}),
+    α::Number=true,
+    β::Number=false,
+) where {T1,T2,T3}
+    LinearAlgebra.mul!(C, A, reshape(B, :, 1), α, β)
+    return C
+end
+
+function LinearAlgebra.mul!(
+    @nospecialize(C::TracedRArray{T1,2}),
+    @nospecialize(A::AnyTracedRArray{T2,2}),
+    @nospecialize(B::AnyTracedRArray{T3,2}),
+    α::Number=true,
+    β::Number=false,
+) where {T1,T2,T3}
+    if size(C) != (size(A, 1), size(B, 2))
+        throw(
+            DimensionMismatch(
+                "C has size $(size(C)), A has size $(size(A)), B has size $(size(B))"
+            ),
+        )
+    end
+    if size(A, 2) != size(B, 1)
+        throw(DimensionMismatch("A has size $(size(A)), B has size $(size(B))"))
+    end
+    resty = MLIR.IR.TensorType(size(C), MLIR.IR.Type(T1))
+    dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
+        MLIR.IR.context(), 0, [], 0, [], 1, [1], 1, [0]
+    )
+    prec = MLIR.IR.Attribute(
+        MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
+    )
+    precar = MLIR.IR.Attribute([prec, prec])
+    res = MLIR.IR.result(
+        MLIR.Dialects.stablehlo.dot_general(
+            get_mlir_data(A),
+            get_mlir_data(B);
+            result_0=resty,
+            dot_dimension_numbers=dot_dimension_numbers,
+            precision_config=precar,
+        ),
+        1,
+    )
+    if iszero(β)
+        if isone(α)
+            C.mlir_data = res
+        else
+            C.mlir_data = MLIR.IR.result(
+                MLIR.Dialects.stablehlo.multiply(
+                    res, broadcast_to_size(T1(α), size(C)).mlir_data
+                ),
+                1,
+            )
+        end
+    else
+        α_res = MLIR.IR.result(
+            MLIR.Dialects.stablehlo.multiply(
+                res, broadcast_to_size(T1(α), size(C)).mlir_data
+            ),
+            1,
+        )
+        β_C = MLIR.IR.result(
+            MLIR.Dialects.stablehlo.multiply(
+                C.mlir_data, broadcast_to_size(T1(β), size(C)).mlir_data
+            ),
+            1,
+        )
+        C.mlir_data = MLIR.IR.result(MLIR.Dialects.stablehlo.add(α_res, β_C), 1)
+    end
+    return C
+end
+
+function LinearAlgebra.triu!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
+    iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1)
+    iota_2 = Ops.subtract(
+        Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X))
+    )
+    idxs = Ops.compare(iota_1, iota_2; comparison_direction="LE")
+    X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data
+    return X
+end
+
+function LinearAlgebra.tril!(@nospecialize(X::TracedRArray{T,2}), k::Integer) where {T}
+    iota_1 = Ops.iota(Int64, [size(X)...]; iota_dimension=1)
+    iota_2 = Ops.subtract(
+        Ops.iota(Int64, [size(X)...]; iota_dimension=2), broadcast_to_size(k, size(X))
+    )
+    idxs = Ops.compare(iota_1, iota_2; comparison_direction="GE")
+    X.mlir_data = Ops.select(idxs, X, zero(X)).mlir_data
+    return X
+end
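The triu!/tril! kernels build a row-index grid and a k-shifted column-index grid with Ops.iota, then select between X and zero elementwise. The same masking logic in plain, untraced Julia, as a sketch for intuition (triu_mask is a hypothetical helper, not part of this commit):

using LinearAlgebra

# Keep entries on or above the k-th superdiagonal, i.e. where i <= j - k.
function triu_mask(X::AbstractMatrix, k::Integer=0)
    rows = [i for i in 1:size(X, 1), j in 1:size(X, 2)]  # analogous to iota_dimension=1
    cols = [j for i in 1:size(X, 1), j in 1:size(X, 2)]  # analogous to iota_dimension=2
    return ifelse.(rows .<= cols .- k, X, zero(eltype(X)))
end

A = rand(4, 6)
triu_mask(A, 1) ≈ triu(A, 1)  # true; flipping LE to GE gives the tril variant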

test/linear_algebra.jl renamed to test/integration/linear_algebra.jl

Lines changed: 37 additions & 1 deletion
@@ -45,7 +45,7 @@ function mul_with_view3(A, x)
     return C
 end

-@testset begin
+@testset "Matrix Multiplication" begin
     A = rand(4, 4)
     x = rand(4, 2)
     b = rand(4)
@@ -77,3 +77,39 @@ end
     @jit(mul!(C_ra, A_ra, x_ra))
     @test C_ra ≈ A * x
 end
+
+@testset "triu & tril" begin
+    A = rand(4, 6)
+    A_ra = Reactant.to_rarray(A)
+
+    @test @jit(triu(A_ra)) ≈ triu(A)
+    @test @jit(tril(A_ra)) ≈ tril(A)
+    @test @jit(triu(A_ra, 2)) ≈ triu(A, 2)
+    @test @jit(tril(A_ra, 2)) ≈ tril(A, 2)
+    @test @jit(triu(A_ra, -1)) ≈ triu(A, -1)
+    @test @jit(tril(A_ra, -1)) ≈ tril(A, -1)
+
+    A_ra = Reactant.to_rarray(A)
+    @jit(triu!(A_ra))
+    @test A_ra ≈ triu(A)
+
+    A_ra = Reactant.to_rarray(A)
+    @jit(tril!(A_ra))
+    @test A_ra ≈ tril(A)
+
+    A_ra = Reactant.to_rarray(A)
+    @jit(triu!(A_ra, 2))
+    @test A_ra ≈ triu(A, 2)
+
+    A_ra = Reactant.to_rarray(A)
+    @jit(tril!(A_ra, 2))
+    @test A_ra ≈ tril(A, 2)
+
+    A_ra = Reactant.to_rarray(A)
+    @jit(triu!(A_ra, -1))
+    @test A_ra ≈ triu(A, -1)
+
+    A_ra = Reactant.to_rarray(A)
+    @jit(tril!(A_ra, -1))
+    @test A_ra ≈ tril(A, -1)
+end
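Each in-place block re-materializes A_ra with Reactant.to_rarray(A) because triu! and tril! mutate the traced buffer; reusing the previous handle would make later assertions compare against already-masked data.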

test/runtests.jl

Lines changed: 1 addition & 1 deletion
@@ -56,10 +56,10 @@ const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))
     @safetestset "Shortcuts to MLIR ops" include("ops.jl")
     @safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
     @safetestset "Control Flow" include("control_flow.jl")
-    @safetestset "Linear Algebra" include("linear_algebra.jl")
 end

 if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
+    @safetestset "Linear Algebra" include("integration/linear_algebra.jl")
     @safetestset "AbstractFFTs" include("integration/fft.jl")
 end