Commit e6f3dc6

refactor: use Ops instead of direct stablehlo calls
1 parent ca1f1be commit e6f3dc6

File tree: 6 files changed (+77 −252 lines)

src/ControlFlow.jl

Lines changed: 2 additions & 4 deletions
@@ -142,11 +142,9 @@ function get_region_removing_missing_values(compiled_fn, insertions)
     return_op = MLIR.IR.terminator(block)
     for (i, rt) in insertions
         if rt isa TracedRNumber
-            attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, ()))
-            op = MLIR.Dialects.stablehlo.constant(; value=attr)
+            op = Ops.constant(Array{eltype(rt)}(undef, ()))
         elseif rt isa TracedRArray
-            attr = MLIR.IR.DenseElementsAttribute(Array{eltype(rt)}(undef, size(rt)))
-            op = MLIR.Dialects.stablehlo.constant(; value=attr)
+            op = Ops.constant(Array{eltype(rt)}(undef, size(rt)))
         else
             error("Unknown type $(typeof(rt))")
         end
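For context, the removed pattern builds a DenseElementsAttribute and a stablehlo.constant op by hand; Ops.constant accepts the host array directly and performs that construction internally. A minimal before/after sketch (array contents are illustrative, not from the commit):

    # before: two explicit MLIR calls
    attr = MLIR.IR.DenseElementsAttribute(zeros(Float32, 2, 3))
    op = MLIR.Dialects.stablehlo.constant(; value=attr)

    # after: the same constant emitted through the Ops layer
    op = Ops.constant(zeros(Float32, 2, 3))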

src/Interpreter.jl

Lines changed: 3 additions & 33 deletions
@@ -233,17 +233,7 @@ function push_acts!(ad_inputs, x::BatchDuplicated, path, reverse)
     predims = size(x.val)
     cval = MLIR.IR.result(
         MLIR.Dialects.stablehlo.concatenate(
-            [
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.reshape(
-                        v.mlir_data;
-                        result_0=MLIR.IR.TensorType(
-                            Int64[1, predims...], eltype(MLIR.IR.type(v.mlir_data))
-                        ),
-                    ),
-                ) for v in x.dval
-            ];
-            dimension=Int64(0),
+            [Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
         ),
     )
     tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
@@ -258,17 +248,7 @@ function push_acts!(ad_inputs, x::BatchDuplicatedNoNeed, path, reverse)
     predims = size(x.val)
     cval = MLIR.IR.result(
         MLIR.Dialects.stablehlo.concatenate(
-            [
-                MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.reshape(
-                        v.mlir_data;
-                        result_0=MLIR.IR.TensorType(
-                            Int64[1, predims...], eltype(MLIR.IR.type(v.mlir_data))
-                        ),
-                    ),
-                ) for v in x.dval
-            ];
-            dimension=Int64(0),
+            [Ops.reshape(v, Int64[1, predims...]) for v in x.dval]; dimension=Int64(0)
         ),
     )
     tval = TracedRArray{ET,length(predims) + 1}((), cval, (length(x.dval), predims...))
@@ -502,22 +482,12 @@ function overload_autodiff(
             for i in 1:width
                 sz = size(a)
                 starts = Int64[i]
-                strides = Int64[1]
                 limits = Int64[i]
                 for v in sz
                     push!(starts, 0)
                     push!(limits, v)
-                    push!(strides, 1)
                 end
-                sval = MLIR.IR.result(
-                    MLIR.Dialects.stablehlo.slice(
-                        sval;
-                        start_indices=starts,
-                        limit_indices=limits,
-                        stride_indices=strides,
-                    ),
-                    1,
-                )
+                sval = Ops.slice(sval, starts, limits)
                 set!(dresult[i], path[2:end], sval)
            end
        end
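These hunks lean on two Ops helpers: Ops.reshape(value, shape) replaces the stablehlo.reshape call that had to spell out the result tensor type, and Ops.slice(value, starts, limits) replaces stablehlo.slice with its explicit unit strides. A hedged sketch of the call shapes, with argument conventions assumed to mirror the removed stablehlo code:

    # prepend a unit batch dimension to a traced value v
    v1 = Ops.reshape(v, Int64[1, size(v)...])

    # slice with start/limit indices; the old stride_indices vector of ones is implied
    sval = Ops.slice(sval, starts, limits)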

src/Ops.jl

Lines changed: 36 additions & 3 deletions
@@ -28,7 +28,7 @@ function constant(
 end

 function constant(x::ConcreteRArray; kwargs...)
-    return stablehlo.constant(convert(Array, x); kwargs...)
+    return stablehlo.constant(Base.convert(Array, x); kwargs...)
 end

 function constant(
@@ -42,7 +42,9 @@ function constant(
     x::ConcreteRNumber{T}; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
 ) where {T}
     output = mlir_type(TracedRArray{T,0}, ())
-    value = MLIR.IR.DenseElementsAttribute(fill(MLIR.IR.Attribute(convert(T, x)), output))
+    value = MLIR.IR.DenseElementsAttribute(
+        fill(MLIR.IR.Attribute(Base.convert(T, x)), output)
+    )
     res = MLIR.IR.result(stablehlo.constant(; output, value, location))
     return TracedRNumber{T,N}((), res)
 end
@@ -1033,7 +1035,7 @@ function compare(
     end

     res = MLIR.IR.result(
-        MLIR.Dialects.stablehlo.compare(
+        stablehlo.compare(
             lhs.mlir_data,
             rhs.mlir_data;
             comparison_direction=MLIR.API.stablehloComparisonDirectionAttrGet(
@@ -1048,4 +1050,35 @@ function compare(
     return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
 end

+# eltype conversion
+function convert(
+    ::Type{TracedRArray{T,N}},
+    x::TracedRArray;
+    location=mlir_stacktrace("convert", @__FILE__, @__LINE__),
+) where {T,N}
+    @assert N == ndims(x)
+    return TracedRArray{T,N}(
+        (),
+        MLIR.IR.result(
+            stablehlo.convert(
+                x.mlir_data; result=mlir_type(TracedRArray{T,N}, size(x)), location
+            ),
+        ),
+        size(x),
+    )
+end
+
+function convert(
+    ::Type{TracedRNumber{T}},
+    x::TracedRNumber;
+    location=mlir_stacktrace("convert", @__FILE__, @__LINE__),
+) where {T}
+    return TracedRNumber{T}(
+        (),
+        MLIR.IR.result(
+            stablehlo.convert(x.mlir_data; result=mlir_type(TracedRNumber{T}), location)
+        ),
+    )
+end
+
 end
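The two convert methods added here are what the TracedRArray.jl changes below call for element-type conversion. A hypothetical usage inside traced code, assuming x is a TracedRArray{Float32,2} and n a TracedRNumber{Float32}:

    y = Ops.convert(TracedRArray{Float64,2}, x)   # stablehlo.convert, same shape, new eltype
    m = Ops.convert(TracedRNumber{Float64}, n)    # scalar (0-d tensor) variant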

src/TracedRArray.jl

Lines changed: 15 additions & 105 deletions
@@ -36,35 +36,16 @@ ReactantCore.is_traced(::TracedRArray) = true

 new_traced_value(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), nothing, size(A))

+TracedRArray{T,N}(rhs::TracedRArray{T,N}) where {T,N} = rhs
 function TracedRArray{T,N}(rhs::TracedRArray{T0,N}) where {T,T0,N}
-    if T == T0
-        return rhs
-    else
-        return TracedRArray{T,N}(
-            (),
-            MLIR.IR.result(
-                MLIR.Dialects.stablehlo.convert(
-                    rhs.mlir_data; result=mlir_type(TracedRArray{T,N}, size(rhs))
-                ),
-                1,
-            ),
-            size(rhs),
-        )
-    end
+    return Ops.convert(TracedRArray{T,N}, rhs)
 end

 function TracedRArray{T,N}(rhs::WrappedTracedRArray{T0,N}) where {T0,T,N}
     return TracedRArray{T,N}(materialize_traced_array(rhs))
 end

-function TracedRArray{T,N}(rhs::AbstractArray{T0,N}) where {T0,T,N}
-    attr = MLIR.IR.DenseElementsAttribute(collect(rhs))
-    return TracedRArray{T,N}(
-        TracedRArray{T0,length(size(rhs))}(
-            (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(rhs)
-        ),
-    )
-end
+TracedRArray{T,N}(rhs::AbstractArray{T0,N}) where {T0,T,N} = Ops.constant(collect(rhs))

 materialize_traced_array(x::TracedRArray) = x
 materialize_traced_array(x::WrappedTracedRArray) = x[axes(x)...]
@@ -164,7 +145,6 @@ function Base.getindex(
         ),
         1,
     )
-
     return TracedRNumber{T}((), res2)
 end

@@ -254,9 +234,7 @@ Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data,

 # TODO is there a way to create an unitialized `tensor`? does it show an advantage? maybe `fill`?
 function Base.similar(::TracedRArray, ::Type{T}, dims::Dims{N}) where {T,N}
-    attr = MLIR.IR.DenseElementsAttribute(zeros(T, dims))
-    res = MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1)
-    return TracedRArray{T,N}((), res, dims)
+    return Ops.constant(zeros(T, dims))
 end

 function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOContext}}
@@ -266,69 +244,20 @@ function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOC
 end

 function Base.permutedims(A::AnyTracedRArray{T,N}, perm) where {T,N}
-    return TracedRArray{T,N}(
-        (),
-        MLIR.IR.result(
-            MLIR.Dialects.stablehlo.transpose(
-                get_mlir_data(A);
-                permutation=MLIR.IR.DenseArrayAttribute([Int64(i - 1) for i in perm]),
-            ),
-            1,
-        ),
-        Tuple(size(A, i) for i in perm),
-    )
+    return Ops.transpose(materialize_traced_array(A), perm)
 end

-Base.conj(A::TracedRArray) = A
-function Base.conj(A::TracedRArray{T,N}) where {T<:Complex,N}
-    return TracedRArray{T,N}(
-        (),
-        MLIR.IR.result(
-            MLIR.Dialects.chlo.conj(
-                A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
-            ),
-            1,
-        ),
-        size(A),
-    )
-end
-
-Base.conj!(A::TracedRArray) = A
-function Base.conj!(A::TracedRArray{T,N}) where {T<:Complex,N}
-    A.mlir_data = MLIR.IR.result(
-        MLIR.Dialects.chlo.conj(A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))),
-        1,
-    )
+Base.conj!(A::AnyTracedRArray) = A
+function Base.conj!(A::AnyTracedRArray{<:Complex})
+    set_mlir_data!(A, Ops.conj(materialize_traced_array(A)).mlir_data)
     return A
 end

-Base.real(A::TracedRArray) = A
-function Base.real(A::TracedRArray{Complex{T},N}) where {T,N}
-    return TracedRArray{T,N}(
-        (),
-        MLIR.IR.result(
-            MLIR.Dialects.stablehlo.real(
-                A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
-            ),
-            1,
-        ),
-        size(A),
-    )
-end
+Base.real(A::AnyTracedRArray) = A
+Base.real(A::AnyTracedRArray{<:Complex}) = Ops.real(materialize_traced_array(A))

-Base.imag(A::TracedRArray) = zero(A)
-function Base.imag(A::TracedRArray{Complex{T},N}) where {T,N}
-    return TracedRArray{T,N}(
-        (),
-        MLIR.IR.result(
-            MLIR.Dialects.stablehlo.imag(
-                A.mlir_data; result=mlir_type(TracedRArray{T,N}, size(A))
-            ),
-            1,
-        ),
-        size(A),
-    )
-end
+Base.imag(A::AnyTracedRArray) = zero(A)
+Base.imag(A::AnyTracedRArray{<:Complex}) = Ops.imag(materialize_traced_array(A))

 promote_to(::Type{TracedRArray{T,N}}, rhs) where {T,N} = TracedRArray{T,N}(rhs)

@@ -521,13 +450,7 @@ function Base.mapreduce(
     redT = eltype(MLIR.IR.julia_type(MLIR.IR.type(red)))

     if dims != (:)
-        red = MLIR.IR.result(
-            MLIR.Dialects.stablehlo.reshape(
-                red; result_0=MLIR.IR.TensorType(toonedims, eltype(MLIR.IR.type(red)))
-            ),
-            1,
-        )
-        red = TracedRArray{redT,length(toonedims)}((), red, (toonedims...,))
+        red = Ops.reshape(red, toonedims...)
     else
         if length(outdims) == 0
             red = TracedRNumber{redT}((), red)
@@ -633,27 +556,14 @@ function Base.copyto!(dest::TracedRArray{T,N}, src::TracedRArray{T,N}) where {T,
     return dest
 end

-function broadcast_to_size(arg::AbstractArray, rsize)
-    attr = MLIR.IR.DenseElementsAttribute(arg)
-    len = ndims(arg)
-    @assert typeof(len) == Int
-    arg = TracedRArray{eltype(arg),len}(
-        (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), size(arg)
-    )
-    return broadcast_to_size(arg, rsize)
-end
+broadcast_to_size(arg::AbstractArray, rsize) = broadcast_to_size(Ops.constant(arg), rsize)

 function broadcast_to_size(arg::Base.RefValue, rsize)
     # XXX: don't we want to expand here to rsize?
     return arg
 end

-function broadcast_to_size(arg::T, rsize) where {T<:Number}
-    attr = MLIR.IR.DenseElementsAttribute(Base.fill(arg, Tuple(rsize)))
-    return TracedRArray{T,length(rsize)}(
-        (), MLIR.IR.result(MLIR.Dialects.stablehlo.constant(; value=attr), 1), rsize
-    )
-end
+broadcast_to_size(arg::Number, rsize) = Ops.constant(Base.fill(arg, Tuple(rsize)))

 function broadcast_to_size(arg::TracedRNumber, rsize)
     length(rsize) == 0 && return arg
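With these changes the Base-level methods simply defer to the Ops layer. A hypothetical traced snippet exercising a few of the rewritten methods (names, shapes, and element types are made up; behaviour follows the new one-line definitions):

    B = similar(A, Float32, (2, 3))      # Ops.constant(zeros(Float32, (2, 3)))
    P = permutedims(A, (2, 1))           # Ops.transpose(materialize_traced_array(A), (2, 1))
    R = real(C)                          # Ops.real(...) when eltype(C) <: Complex, identity otherwise
    D = broadcast_to_size(1.0, (4, 4))   # Ops.constant(fill(1.0, (4, 4)))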
