Skip to content

Commit e9b7a72

Browse files
make similar return empty tensors. (#632)
* segfaults
* fill op
* add fill specializations
* remove leftover print + formatting
* add location kwarg to fill
* support fill of complex
* fix constant op
* `constant(fill(` -> `fill(`
* formatting
* add fallback to base fill for tracedrnumber
* Update src/Ops.jl

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 0ddeab2 commit e9b7a72

File tree

4 files changed

+90
-23
lines changed

4 files changed

+90
-23
lines changed

src/Ops.jl

Lines changed: 81 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,10 +80,81 @@ end
8080
x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
8181
) where {T<:Number}
8282
x isa TracedRNumber && return x
83-
res = constant(fill(x); location)
83+
res = fill(x; location)
8484
return TracedRNumber{T}((), res.mlir_data)
8585
end
8686

87+
function fill(
88+
v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
89+
)
90+
return fill(v, dims; location)
91+
end
92+
function fill(
93+
v,
94+
dims::NTuple{N,Union{Integer,Base.OneTo}};
95+
location=mlir_stacktrace("fill", @__FILE__, @__LINE__),
96+
) where {N}
97+
return fill(v, map(Base.to_dim, dims); location)
98+
end
99+
function fill(
100+
v, dims::NTuple{N,Integer}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
101+
) where {N}
102+
return fill(v, collect(dims); location)
103+
end
104+
function fill(v, ::Tuple{}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__))
105+
return fill(v, Int[]; location)
106+
end
107+
108+
function fill(number::TracedRNumber{T}, shape::Vector{Int}; location) where {T}
109+
return Base.fill(number, Tuple(shape))
110+
end
111+
112+
for (T, mlir_func) in (
113+
(Bool, :mlirDenseElementsAttrBoolSplatGet),
114+
(UInt8, :mlirDenseElementsAttrUInt8SplatGet),
115+
(Int8, :mlirDenseElementsAttrInt8SplatGet),
116+
(UInt32, :mlirDenseElementsAttrUInt32SplatGet),
117+
(Int32, :mlirDenseElementsAttrInt32SplatGet),
118+
(UInt64, :mlirDenseElementsAttrUInt64SplatGet),
119+
(Int64, :mlirDenseElementsAttrInt64SplatGet),
120+
(Float32, :mlirDenseElementsAttrFloatSplatGet),
121+
(Float64, :mlirDenseElementsAttrDoubleSplatGet),
122+
)
123+
@eval begin
124+
@noinline function fill(
125+
number::$T,
126+
shape::Vector{Int};
127+
location=mlir_stacktrace("fill", @__FILE__, @__LINE__),
128+
)
129+
tt = MLIR.IR.TensorType(shape, MLIR.IR.Type($T); location=location)
130+
131+
splatattr = MLIR.API.$mlir_func(tt, number)
132+
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
133+
cst = MLIR.IR.result(cst_op)
134+
ta = TracedRArray{$T,length(shape)}((), cst, shape)
135+
return ta
136+
end
137+
end
138+
end
139+
140+
_fill_element_attr(x) = MLIR.IR.Attribute(x)
141+
function _fill_element_attr(x::Complex)
142+
return MLIR.IR.Attribute([
143+
MLIR.IR.Attribute(Base.real(x)), MLIR.IR.Attribute(Base.imag(x))
144+
])
145+
end
146+
147+
@noinline function fill(
148+
element::T, shape::Vector{Int}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
149+
) where {T}
150+
tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T))
151+
splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element))
152+
cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
153+
cst = MLIR.IR.result(cst_op)
154+
ta = TracedRArray{T,length(shape)}((), cst, shape)
155+
return ta
156+
end
157+
87158
# unary elementwise ops
88159
for (dialect, op) in [
89160
(:stablehlo, :abs),
@@ -350,9 +421,9 @@ end
350421
@noinline function pad(
351422
x::TracedRArray{T,N},
352423
padding_value::TracedRNumber{T};
353-
low=fill(0, N),
354-
high=fill(0, N),
355-
interior=fill(0, N),
424+
low=Base.fill(0, N),
425+
high=Base.fill(0, N),
426+
interior=Base.fill(0, N),
356427
location=mlir_stacktrace("pad", @__FILE__, @__LINE__),
357428
) where {T,N}
358429
rsize = size(x) .+ low .+ high .+ max.(size(x) .- 1, 0) .* interior
@@ -1056,7 +1127,7 @@ end
10561127
op = chlo.top_k(x.mlir_data; values, indices, k, location)
10571128
indices = add(
10581129
TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
1059-
constant(fill(Int32(1), Tuple(rsize))),
1130+
fill(Int32(1), Tuple(rsize)),
10601131
) # return the 1-indexed index
10611132
indices = convert(TracedRArray{Int64,N}, indices) # julia indexes with Int64 generally
10621133
values = TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize)
@@ -1160,7 +1231,7 @@ end
11601231
(; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
11611232
output = divide(
11621233
convert(TracedRArray{T,ndims(output)}, output),
1163-
constant(fill(T(typemax(uT)), Tuple(shape)); location),
1234+
fill(T(typemax(uT)), Tuple(shape); location),
11641235
)
11651236
return (; output_state, output)
11661237
end
@@ -1200,11 +1271,11 @@ fields:
12001271
rand_uniform = res.output
12011272
seed = res.output_state
12021273
scaled_uniform = subtract(
1203-
multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))),
1204-
constant(fill(T(1), size(rand_uniform))),
1274+
multiply(rand_uniform, fill(T(2), size(rand_uniform))),
1275+
fill(T(1), size(rand_uniform)),
12051276
)
12061277
probit = erf_inv(scaled_uniform)
1207-
rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform))))
1278+
rand_normal = multiply(probit, fill(Base.sqrt(T(2)), size(rand_uniform)))
12081279
return (; output_state=seed, output=rand_normal)
12091280
end
12101281

@@ -1570,7 +1641,7 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
15701641
src.mlir_data,
15711642
gather_indices.mlir_data;
15721643
dimension_numbers,
1573-
slice_sizes=fill(Int64(1), N),
1644+
slice_sizes=Base.fill(Int64(1), N),
15741645
indices_are_sorted=false,
15751646
),
15761647
1,

src/TracedRArray.jl

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -373,9 +373,8 @@ Base.collect(x::TracedRArray) = copy(x) # XXX: Is this correct?
373373

374374
Base.copy(A::TracedRArray{T,N}) where {T,N} = TracedRArray{T,N}((), A.mlir_data, size(A))
375375

376-
# TODO is there a way to create an unitialized `tensor`? does it show an advantage? maybe `fill`?
377376
function Base.similar(::TracedRArray, ::Type{T}, dims::Dims{N}) where {T,N}
378-
return Ops.constant(zeros(unwrapped_eltype(T), dims))
377+
return Ops.fill(zero(unwrapped_eltype(T)), dims)
379378
end
380379

381380
function Base.show(io::IOty, X::TracedRArray{T,N}) where {T,N,IOty<:Union{IO,IOContext}}
@@ -998,12 +997,12 @@ function Base.findmin(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
998997
# Compute linear indices
999998
strds = strides(x)
1000999
iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)]
1001-
iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices))))
1002-
linear_indices = Ops.constant(fill(Int64(1), size(indices)))
1000+
iotas[dims] = Ops.subtract(indices, Ops.fill(Int64(1), size(indices)))
1001+
linear_indices = Ops.fill(Int64(1), size(indices))
10031002
for d in eachindex(iotas)
10041003
linear_indices = Ops.add(
10051004
linear_indices,
1006-
Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))),
1005+
Ops.multiply(iotas[d], Ops.fill(Int64(strds[d]), size(iotas[d]))),
10071006
)
10081007
end
10091008

@@ -1027,12 +1026,12 @@ function Base.findmax(f, x::AnyTracedRArray; dims::Union{Integer,Nothing}=nothin
10271026
# Compute linear indices
10281027
strds = strides(x)
10291028
iotas = [Ops.iota(Int64, [size(indices)...]; iota_dimension=i) for i in 1:ndims(x)]
1030-
iotas[dims] = Ops.subtract(indices, Ops.constant(fill(Int64(1), size(indices))))
1031-
linear_indices = Ops.constant(fill(Int64(1), size(indices)))
1029+
iotas[dims] = Ops.subtract(indices, Ops.fill(Int64(1), size(indices)))
1030+
linear_indices = Ops.fill(Int64(1), size(indices))
10321031
for d in eachindex(iotas)
10331032
linear_indices = Ops.add(
10341033
linear_indices,
1035-
Ops.multiply(iotas[d], Ops.constant(fill(Int64(strds[d]), size(iotas[d])))),
1034+
Ops.multiply(iotas[d], Ops.fill(Int64(strds[d]), size(iotas[d]))),
10361035
)
10371036
end
10381037

src/TracedRNumber.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,7 @@ function TracedUtils.promote_to(::Type{TracedRNumber{T}}, rhs) where {T}
7878
TracedRNumber{Reactant.unwrapped_eltype(rhs)}((), rhs.mlir_data),
7979
)
8080
end
81-
rhs isa Number &&
82-
return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(fill(T(rhs))))
81+
rhs isa Number && return TracedUtils.promote_to(TracedRNumber{T}, Ops.fill(T(rhs)))
8382
return TracedUtils.promote_to(TracedRNumber{T}, Ops.constant(collect(rhs)))
8483
end
8584

src/stdlibs/LinearAlgebra.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,9 +302,7 @@ function LinearAlgebra._diagm(
302302
MLIR.IR.result(MLIR.Dialects.stablehlo.concatenate(concat_inputs; dimension=0), 1),
303303
(size(scatter_indices, 1),),
304304
)
305-
return Ops.scatter_setindex(
306-
Ops.constant(fill(zero(T), (m, n))), scatter_indices, values
307-
)
305+
return Ops.scatter_setindex(Ops.fill(zero(T), (m, n)), scatter_indices, values)
308306
end
309307

310308
# Common Utilities

0 commit comments

Comments (0)