Skip to content

Commit 3ee51a6

Browse files
authored
fix: better MLIR codegen for simple getindex/setindex (#1200)
* fix: emit slice instead of gather * fix: emit dus instead of scatter * fix: diag indexing * fix: indexing * fix: promotion * fix: ops * fix: index * chore: run formatter * fix: diag indices * fix: we don't need diag specialization * fix: contiguous check * fix: generalize subarray indexing behavior * fix: try fixing ambiguity * fix: more bad dispatches * fix: fixup * fix: some ambiguities * fix: more cleanup * fix: materialize subarray * fix: more fixes * Update src/TracedRArray.jl * fix: remove unwanted overlay * Update test/indexing.jl * fix: fill! compile
1 parent d8281bb commit 3ee51a6

File tree

8 files changed

+280
-204
lines changed

8 files changed

+280
-204
lines changed

src/ConcreteRArray.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -547,25 +547,22 @@ function Base.fill!(a::ConcretePJRTArray{T,N}, val) where {T,N}
547547
return a
548548
end
549549

550-
idxs = ntuple(Returns(Colon()), N)
551-
fn = compile(mysetindex!, (a, val, idxs...))
552-
fn(a, val, idxs...)
550+
fn = compile(fill!, (a, val))
551+
fn(a, val)
553552
return a
554553
end
555554

556555
function Base.fill!(a::ConcreteIFRTArray{T,N}, val) where {T,N}
557556
isempty(a) && throw("Cannot setindex! to empty buffer")
558557

559-
idxs = ntuple(Returns(Colon()), N)
560-
fn = compile(mysetindex!, (a, val, idxs...))
561-
fn(a, val, idxs...)
558+
fn = compile(fill!, (a, val))
559+
fn(a, val)
562560
return a
563561
end
564562

565563
function Base.fill!(x::Union{AnyConcreteIFRTArray,AnyConcretePJRTArray}, val)
566-
idxs = ntuple(Returns(Colon()), ndims(x))
567-
fn = compile(mysetindex!, (x, val, idxs...))
568-
fn(x, val, idxs...)
564+
fn = compile(fill!, (x, val))
565+
fn(x, val)
569566
return x
570567
end
571568

src/Ops.jl

Lines changed: 98 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ end
532532
limit_indices = limit_indices
533533
rsize = limit_indices .- start_indices
534534
@assert all(rsize .> 0) "Invalid slice dimensions"
535-
strides = isnothing(strides) ? [1, size(x)[1:(end - 1)]...] : strides
535+
strides = isnothing(strides) ? ones(Int64, N) : strides
536536
res = MLIR.IR.result(
537537
stablehlo.slice(
538538
x.mlir_data;
@@ -1699,7 +1699,8 @@ instead.
16991699
@noinline function scatter_setindex(
17001700
dest::TracedRArray{T,N},
17011701
scatter_indices::TracedRArray{Int64,2},
1702-
updates::TracedRArray{T2,1},
1702+
updates::TracedRArray{T2,1};
1703+
location=mlir_stacktrace("scatter_setindex", @__FILE__, @__LINE__),
17031704
) where {T,N,T2}
17041705
@assert length(updates) == size(scatter_indices, 1)
17051706
@assert size(scatter_indices, 2) == N
@@ -1716,14 +1717,39 @@ instead.
17161717
push!(block, return_op)
17171718
pushfirst!(update_computation, block)
17181719

1719-
#! format: off
1720-
update_window_dims = Int64[]
1721-
inserted_window_dims = collect(Int64, 0:(N - 1))
1722-
input_batching_dims = Int64[]
1723-
scatter_indices_batching_dims = Int64[]
1724-
scatter_dims_to_operand_dims = collect(Int64, 0:(N - 1))
1725-
index_vector_dim = Int64(1)
1720+
return scatter(
1721+
[dest],
1722+
scatter_indices,
1723+
[updates];
1724+
update_computation,
1725+
update_window_dims=Int64[],
1726+
inserted_window_dims=collect(Int64, 0:(N - 1)),
1727+
input_batching_dims=Int64[],
1728+
scatter_indices_batching_dims=Int64[],
1729+
scatter_dims_to_operand_dims=collect(Int64, 0:(N - 1)),
1730+
index_vector_dim=Int64(1),
1731+
location,
1732+
)[1]
1733+
end
1734+
1735+
@noinline function scatter(
1736+
dest::Vector{TracedRArray{T,N}},
1737+
scatter_indices::TracedRArray{Int64,2},
1738+
updates::Vector{<:TracedRArray{T}};
1739+
update_computation::MLIR.IR.Region,
1740+
update_window_dims::Vector{Int64},
1741+
inserted_window_dims::Vector{Int64},
1742+
input_batching_dims::Vector{Int64},
1743+
scatter_indices_batching_dims::Vector{Int64},
1744+
scatter_dims_to_operand_dims::Vector{Int64},
1745+
index_vector_dim::Int64,
1746+
location=mlir_stacktrace("scatter", @__FILE__, @__LINE__),
1747+
) where {T,N}
1748+
scatter_indices = subtract(
1749+
scatter_indices, fill(Int64(1), size(scatter_indices)); location
1750+
)
17261751

1752+
#! format: off
17271753
scatter_dimension_numbers = MLIR.API.stablehloScatterDimensionNumbersGet(
17281754
MLIR.IR.context(),
17291755
length(update_window_dims), update_window_dims,
@@ -1735,21 +1761,22 @@ instead.
17351761
)
17361762
#! format: on
17371763

1738-
return TracedRArray{T,N}(
1739-
(),
1740-
MLIR.IR.result(
1741-
MLIR.Dialects.stablehlo.scatter(
1742-
[dest.mlir_data],
1743-
scatter_indices.mlir_data,
1744-
[updates.mlir_data];
1745-
result_0=[mlir_type(TracedRArray{T,N}, size(dest))],
1746-
update_computation,
1747-
scatter_dimension_numbers,
1748-
),
1749-
1,
1750-
),
1751-
size(dest),
1764+
dest_values = [d.mlir_data for d in dest]
1765+
update_values = [u.mlir_data for u in updates]
1766+
scatter_op = stablehlo.scatter(
1767+
dest_values,
1768+
scatter_indices.mlir_data,
1769+
update_values;
1770+
update_computation,
1771+
scatter_dimension_numbers,
1772+
result_0=[mlir_type(TracedRArray{T,N}, size(d)) for d in dest],
1773+
location,
17521774
)
1775+
1776+
return [
1777+
TracedRArray{T,N}((), MLIR.IR.result(scatter_op, i), size(dest[i])) for
1778+
i in eachindex(dest)
1779+
]
17531780
end
17541781

17551782
"""
@@ -1760,7 +1787,9 @@ specified by `gather_indices`. If the indices are contiguous it is recommended t
17601787
use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
17611788
"""
17621789
@noinline function gather_getindex(
1763-
src::TracedRArray{T,N}, gather_indices::TracedRArray{Int64,2}
1790+
src::TracedRArray{T,N},
1791+
gather_indices::TracedRArray{Int64,2};
1792+
location=mlir_stacktrace("gather_getindex", @__FILE__, @__LINE__),
17641793
) where {T,N}
17651794
@assert size(gather_indices, 2) == N
17661795

@@ -1770,14 +1799,42 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
17701799
)
17711800
end
17721801

1773-
#! format: off
1774-
offset_dims = Int64[1]
1775-
collapsed_slice_dims = collect(Int64, 0:(N - 2))
1776-
operand_batching_dims = Int64[]
1777-
start_indices_batching_dims = Int64[]
1778-
start_index_map = collect(Int64, 0:(N - 1))
1779-
index_vector_dim = Int64(1)
1802+
return reshape(
1803+
gather(
1804+
src,
1805+
gather_indices;
1806+
offset_dims=Int64[1],
1807+
collapsed_slice_dims=collect(Int64, 0:(N - 2)),
1808+
operand_batching_dims=Int64[],
1809+
start_indices_batching_dims=Int64[],
1810+
start_index_map=collect(Int64, 0:(N - 1)),
1811+
index_vector_dim=Int64(1),
1812+
slice_sizes=ones(Int64, N),
1813+
indices_are_sorted=false,
1814+
location,
1815+
),
1816+
[size(gather_indices, 1)],
1817+
)
1818+
end
1819+
1820+
@noinline function gather(
1821+
src::TracedRArray{T,N},
1822+
gather_indices::TracedRArray{Int64,2};
1823+
offset_dims::Vector{Int64},
1824+
collapsed_slice_dims::Vector{Int64},
1825+
operand_batching_dims::Vector{Int64},
1826+
start_indices_batching_dims::Vector{Int64},
1827+
start_index_map::Vector{Int64},
1828+
index_vector_dim::Int64,
1829+
slice_sizes::Vector{Int64},
1830+
indices_are_sorted::Bool=false,
1831+
location=mlir_stacktrace("gather", @__FILE__, @__LINE__),
1832+
) where {T,N}
1833+
gather_indices = subtract(
1834+
gather_indices, fill(Int64(1), size(gather_indices)); location
1835+
)
17801836

1837+
#! format: off
17811838
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
17821839
MLIR.IR.context(),
17831840
Int64(length(offset_dims)), offset_dims,
@@ -1789,20 +1846,18 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
17891846
)
17901847
#! format: on
17911848

1792-
return reshape(
1793-
TracedRArray{T}(
1794-
MLIR.IR.result(
1795-
MLIR.Dialects.stablehlo.gather(
1796-
src.mlir_data,
1797-
gather_indices.mlir_data;
1798-
dimension_numbers,
1799-
slice_sizes=Base.fill(Int64(1), N),
1800-
indices_are_sorted=false,
1801-
),
1802-
1,
1849+
return TracedRArray{T}(
1850+
MLIR.IR.result(
1851+
MLIR.Dialects.stablehlo.gather(
1852+
src.mlir_data,
1853+
gather_indices.mlir_data;
1854+
dimension_numbers,
1855+
slice_sizes,
1856+
indices_are_sorted,
1857+
location,
18031858
),
1859+
1,
18041860
),
1805-
size(gather_indices, 1),
18061861
)
18071862
end
18081863

src/Overlay.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,6 @@
88
return f
99
end
1010

11-
@reactant_overlay @noinline function Base.setindex!(
12-
a::AnyTracedRArray{T,N}, v, indices::Vararg{Any,N}
13-
) where {T,N}
14-
ancestor_indices = TracedUtils.get_ancestor_indices(a, indices...)
15-
(Base.inferencebarrier(setindex!))(Reactant.ancestor(a), v, ancestor_indices...)
16-
return a
17-
end
18-
1911
# Enzyme.jl overlays
2012
@reactant_overlay @noinline function Enzyme.autodiff_deferred(
2113
rmode::Enzyme.Mode, f::FA, rt::Type{A}, args::Vararg{Annotation,Nargs}

0 commit comments

Comments
 (0)