|
80 | 80 | x::T; location=mlir_stacktrace("constant", @__FILE__, @__LINE__)
|
81 | 81 | ) where {T<:Number}
|
82 | 82 | x isa TracedRNumber && return x
|
83 |
| - res = constant(fill(x); location) |
| 83 | + res = fill(x; location) |
84 | 84 | return TracedRNumber{T}((), res.mlir_data)
|
85 | 85 | end
|
86 | 86 |
|
# Varargs entry point: normalize `fill(v, d1, d2, ...)` into the tuple form
# `fill(v, (d1, d2, ...))` so all dimension handling funnels through one path.
function fill(
    v, dims::Base.DimOrInd...; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
)
    return fill(v, dims; location=location)
end
# Tuple of mixed `Integer`/`Base.OneTo` dimensions: convert each entry to a
# plain dimension length via `Base.to_dim`, then re-dispatch on the result.
function fill(
    v,
    dims::NTuple{N,Union{Integer,Base.OneTo}};
    location=mlir_stacktrace("fill", @__FILE__, @__LINE__),
) where {N}
    normalized = map(Base.to_dim, dims)
    return fill(v, normalized; location=location)
end
# Tuple of integer dimensions: lower to the `Vector{Int}` leaf methods.
#
# NOTE: `collect(dims)` alone would produce a `Vector` of the tuple's actual
# element type (e.g. `Vector{Int32}` for `Int32` dims), which matches none of
# the `shape::Vector{Int}` leaf methods and would raise a MethodError even
# though this signature accepts any `Integer`. `collect(Int, dims)` converts
# every dimension to `Int` so dispatch always reaches a leaf method.
function fill(
    v, dims::NTuple{N,Integer}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
) where {N}
    return fill(v, collect(Int, dims); location)
end
# Empty dimension tuple: request a zero-dimensional (scalar) tensor by passing
# an empty `Int` shape vector to the leaf methods.
function fill(v, ::Tuple{}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__))
    return fill(v, Int[]; location=location)
end
| 107 | + |
# Filling with an already-traced scalar: defer to `Base.fill` on the traced
# number rather than emitting a new constant op.
#
# Fix: `location` previously had no default, unlike every other `fill` method
# here, so calling this method without an explicit `location` threw an
# UndefKeywordError. It now defaults like its siblings (backward compatible);
# the value is accepted for interface consistency but unused, since no MLIR op
# is created in this path.
function fill(
    number::TracedRNumber{T},
    shape::Vector{Int};
    location=mlir_stacktrace("fill", @__FILE__, @__LINE__),
) where {T}
    return Base.fill(number, Tuple(shape))
end
| 111 | + |
# Generate a specialized `fill` method for each element type that the MLIR C
# API exposes a dedicated dense-splat attribute constructor for. Each generated
# method builds a splat attribute of `number` over `shape` and wraps it in a
# `stablehlo.constant` op, returning the resulting traced array.
for (jltype, splat_getter) in (
    (Bool, :mlirDenseElementsAttrBoolSplatGet),
    (UInt8, :mlirDenseElementsAttrUInt8SplatGet),
    (Int8, :mlirDenseElementsAttrInt8SplatGet),
    (UInt32, :mlirDenseElementsAttrUInt32SplatGet),
    (Int32, :mlirDenseElementsAttrInt32SplatGet),
    (UInt64, :mlirDenseElementsAttrUInt64SplatGet),
    (Int64, :mlirDenseElementsAttrInt64SplatGet),
    (Float32, :mlirDenseElementsAttrFloatSplatGet),
    (Float64, :mlirDenseElementsAttrDoubleSplatGet),
)
    @eval begin
        @noinline function fill(
            number::$jltype,
            shape::Vector{Int};
            location=mlir_stacktrace("fill", @__FILE__, @__LINE__),
        )
            tensor_ty = MLIR.IR.TensorType(shape, MLIR.IR.Type($jltype); location=location)
            splat_attr = MLIR.API.$splat_getter(tensor_ty, number)
            const_op = stablehlo.constant(;
                output=tensor_ty, value=splat_attr, location=location
            )
            return TracedRArray{$jltype,length(shape)}(
                (), MLIR.IR.result(const_op), shape
            )
        end
    end
end
| 139 | + |
# Convert a Julia scalar into the MLIR attribute used as the splat element.
_fill_element_attr(x) = MLIR.IR.Attribute(x)
# Complex scalars are encoded as a two-element array attribute holding the
# real and imaginary parts.
function _fill_element_attr(z::Complex)
    parts = [MLIR.IR.Attribute(Base.real(z)), MLIR.IR.Attribute(Base.imag(z))]
    return MLIR.IR.Attribute(parts)
end
| 146 | + |
# Generic fallback for element types without a dedicated MLIR splat getter
# (anything not covered by the generated methods above, e.g. complex types,
# which go through `_fill_element_attr`). Builds the splat with the generic
# `mlirDenseElementsAttrSplatGet` and wraps it in a `stablehlo.constant`.
@noinline function fill(
    element::T, shape::Vector{Int}; location=mlir_stacktrace("fill", @__FILE__, @__LINE__)
) where {T}
    # Consistency fix: pass `location` through to the tensor type, matching the
    # type-specialized methods (previously only the constant op received it).
    tt = MLIR.IR.TensorType(shape, MLIR.IR.Type(T); location=location)
    splatattr = MLIR.API.mlirDenseElementsAttrSplatGet(tt, _fill_element_attr(element))
    cst_op = stablehlo.constant(; output=tt, value=splatattr, location=location)
    cst = MLIR.IR.result(cst_op)
    ta = TracedRArray{T,length(shape)}((), cst, shape)
    return ta
end
| 157 | + |
87 | 158 | # unary elementwise ops
|
88 | 159 | for (dialect, op) in [
|
89 | 160 | (:stablehlo, :abs),
|
|
350 | 421 | @noinline function pad(
|
351 | 422 | x::TracedRArray{T,N},
|
352 | 423 | padding_value::TracedRNumber{T};
|
353 |
| - low=fill(0, N), |
354 |
| - high=fill(0, N), |
355 |
| - interior=fill(0, N), |
| 424 | + low=Base.fill(0, N), |
| 425 | + high=Base.fill(0, N), |
| 426 | + interior=Base.fill(0, N), |
356 | 427 | location=mlir_stacktrace("pad", @__FILE__, @__LINE__),
|
357 | 428 | ) where {T,N}
|
358 | 429 | rsize = size(x) .+ low .+ high .+ max.(size(x) .- 1, 0) .* interior
|
@@ -1056,7 +1127,7 @@ end
|
1056 | 1127 | op = chlo.top_k(x.mlir_data; values, indices, k, location)
|
1057 | 1128 | indices = add(
|
1058 | 1129 | TracedRArray{Int32,N}((), MLIR.IR.result(op, 2), rsize),
|
1059 |
| - constant(fill(Int32(1), Tuple(rsize))), |
| 1130 | + fill(Int32(1), Tuple(rsize)), |
1060 | 1131 | ) # return the 1-indexed index
|
1061 | 1132 | indices = convert(TracedRArray{Int64,N}, indices) # julia indexes with Int64 generally
|
1062 | 1133 | values = TracedRArray{T,N}((), MLIR.IR.result(op, 1), rsize)
|
@@ -1160,7 +1231,7 @@ end
|
1160 | 1231 | (; output_state, output) = rng_bit_generator(uT, seed, shape; algorithm, location)
|
1161 | 1232 | output = divide(
|
1162 | 1233 | convert(TracedRArray{T,ndims(output)}, output),
|
1163 |
| - constant(fill(T(typemax(uT)), Tuple(shape)); location), |
| 1234 | + fill(T(typemax(uT)), Tuple(shape); location), |
1164 | 1235 | )
|
1165 | 1236 | return (; output_state, output)
|
1166 | 1237 | end
|
@@ -1200,11 +1271,11 @@ fields:
|
1200 | 1271 | rand_uniform = res.output
|
1201 | 1272 | seed = res.output_state
|
1202 | 1273 | scaled_uniform = subtract(
|
1203 |
| - multiply(rand_uniform, constant(fill(T(2), size(rand_uniform)))), |
1204 |
| - constant(fill(T(1), size(rand_uniform))), |
| 1274 | + multiply(rand_uniform, fill(T(2), size(rand_uniform))), |
| 1275 | + fill(T(1), size(rand_uniform)), |
1205 | 1276 | )
|
1206 | 1277 | probit = erf_inv(scaled_uniform)
|
1207 |
| - rand_normal = multiply(probit, constant(fill(Base.sqrt(T(2)), size(rand_uniform)))) |
| 1278 | + rand_normal = multiply(probit, fill(Base.sqrt(T(2)), size(rand_uniform))) |
1208 | 1279 | return (; output_state=seed, output=rand_normal)
|
1209 | 1280 | end
|
1210 | 1281 |
|
@@ -1570,7 +1641,7 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
|
1570 | 1641 | src.mlir_data,
|
1571 | 1642 | gather_indices.mlir_data;
|
1572 | 1643 | dimension_numbers,
|
1573 |
| - slice_sizes=fill(Int64(1), N), |
| 1644 | + slice_sizes=Base.fill(Int64(1), N), |
1574 | 1645 | indices_are_sorted=false,
|
1575 | 1646 | ),
|
1576 | 1647 | 1,
|
|
0 commit comments