532
532
limit_indices = limit_indices
533
533
rsize = limit_indices .- start_indices
534
534
@assert all (rsize .> 0 ) " Invalid slice dimensions"
535
- strides = isnothing (strides) ? [ 1 , size (x)[ 1 : ( end - 1 )] . .. ] : strides
535
+ strides = isnothing (strides) ? ones (Int64, N) : strides
536
536
res = MLIR. IR. result (
537
537
stablehlo. slice (
538
538
x. mlir_data;
@@ -1699,7 +1699,8 @@ instead.
1699
1699
@noinline function scatter_setindex (
1700
1700
dest:: TracedRArray{T,N} ,
1701
1701
scatter_indices:: TracedRArray{Int64,2} ,
1702
- updates:: TracedRArray{T2,1} ,
1702
+ updates:: TracedRArray{T2,1} ;
1703
+ location= mlir_stacktrace (" scatter_setindex" , @__FILE__ , @__LINE__ ),
1703
1704
) where {T,N,T2}
1704
1705
@assert length (updates) == size (scatter_indices, 1 )
1705
1706
@assert size (scatter_indices, 2 ) == N
@@ -1716,14 +1717,39 @@ instead.
1716
1717
push! (block, return_op)
1717
1718
pushfirst! (update_computation, block)
1718
1719
1719
- # ! format: off
1720
- update_window_dims = Int64[]
1721
- inserted_window_dims = collect (Int64, 0 : (N - 1 ))
1722
- input_batching_dims = Int64[]
1723
- scatter_indices_batching_dims = Int64[]
1724
- scatter_dims_to_operand_dims = collect (Int64, 0 : (N - 1 ))
1725
- index_vector_dim = Int64 (1 )
1720
+ return scatter (
1721
+ [dest],
1722
+ scatter_indices,
1723
+ [updates];
1724
+ update_computation,
1725
+ update_window_dims= Int64[],
1726
+ inserted_window_dims= collect (Int64, 0 : (N - 1 )),
1727
+ input_batching_dims= Int64[],
1728
+ scatter_indices_batching_dims= Int64[],
1729
+ scatter_dims_to_operand_dims= collect (Int64, 0 : (N - 1 )),
1730
+ index_vector_dim= Int64 (1 ),
1731
+ location,
1732
+ )[1 ]
1733
+ end
1734
+
1735
+ @noinline function scatter (
1736
+ dest:: Vector{TracedRArray{T,N}} ,
1737
+ scatter_indices:: TracedRArray{Int64,2} ,
1738
+ updates:: Vector{<:TracedRArray{T}} ;
1739
+ update_computation:: MLIR.IR.Region ,
1740
+ update_window_dims:: Vector{Int64} ,
1741
+ inserted_window_dims:: Vector{Int64} ,
1742
+ input_batching_dims:: Vector{Int64} ,
1743
+ scatter_indices_batching_dims:: Vector{Int64} ,
1744
+ scatter_dims_to_operand_dims:: Vector{Int64} ,
1745
+ index_vector_dim:: Int64 ,
1746
+ location= mlir_stacktrace (" scatter" , @__FILE__ , @__LINE__ ),
1747
+ ) where {T,N}
1748
+ scatter_indices = subtract (
1749
+ scatter_indices, fill (Int64 (1 ), size (scatter_indices)); location
1750
+ )
1726
1751
1752
+ # ! format: off
1727
1753
scatter_dimension_numbers = MLIR. API. stablehloScatterDimensionNumbersGet (
1728
1754
MLIR. IR. context (),
1729
1755
length (update_window_dims), update_window_dims,
@@ -1735,21 +1761,22 @@ instead.
1735
1761
)
1736
1762
# ! format: on
1737
1763
1738
- return TracedRArray {T,N} (
1739
- (),
1740
- MLIR. IR. result (
1741
- MLIR. Dialects. stablehlo. scatter (
1742
- [dest. mlir_data],
1743
- scatter_indices. mlir_data,
1744
- [updates. mlir_data];
1745
- result_0= [mlir_type (TracedRArray{T,N}, size (dest))],
1746
- update_computation,
1747
- scatter_dimension_numbers,
1748
- ),
1749
- 1 ,
1750
- ),
1751
- size (dest),
1764
+ dest_values = [d. mlir_data for d in dest]
1765
+ update_values = [u. mlir_data for u in updates]
1766
+ scatter_op = stablehlo. scatter (
1767
+ dest_values,
1768
+ scatter_indices. mlir_data,
1769
+ update_values;
1770
+ update_computation,
1771
+ scatter_dimension_numbers,
1772
+ result_0= [mlir_type (TracedRArray{T,N}, size (d)) for d in dest],
1773
+ location,
1752
1774
)
1775
+
1776
+ return [
1777
+ TracedRArray {T,N} ((), MLIR. IR. result (scatter_op, i), size (dest[i])) for
1778
+ i in eachindex (dest)
1779
+ ]
1753
1780
end
1754
1781
1755
1782
"""
@@ -1760,7 +1787,9 @@ specified by `gather_indices`. If the indices are contiguous it is recommended t
1760
1787
use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1761
1788
"""
1762
1789
@noinline function gather_getindex (
1763
- src:: TracedRArray{T,N} , gather_indices:: TracedRArray{Int64,2}
1790
+ src:: TracedRArray{T,N} ,
1791
+ gather_indices:: TracedRArray{Int64,2} ;
1792
+ location= mlir_stacktrace (" gather_getindex" , @__FILE__ , @__LINE__ ),
1764
1793
) where {T,N}
1765
1794
@assert size (gather_indices, 2 ) == N
1766
1795
@@ -1770,14 +1799,42 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1770
1799
)
1771
1800
end
1772
1801
1773
- # ! format: off
1774
- offset_dims = Int64[1 ]
1775
- collapsed_slice_dims = collect (Int64, 0 : (N - 2 ))
1776
- operand_batching_dims = Int64[]
1777
- start_indices_batching_dims = Int64[]
1778
- start_index_map = collect (Int64, 0 : (N - 1 ))
1779
- index_vector_dim = Int64 (1 )
1802
+ return reshape (
1803
+ gather (
1804
+ src,
1805
+ gather_indices;
1806
+ offset_dims= Int64[1 ],
1807
+ collapsed_slice_dims= collect (Int64, 0 : (N - 2 )),
1808
+ operand_batching_dims= Int64[],
1809
+ start_indices_batching_dims= Int64[],
1810
+ start_index_map= collect (Int64, 0 : (N - 1 )),
1811
+ index_vector_dim= Int64 (1 ),
1812
+ slice_sizes= ones (Int64, N),
1813
+ indices_are_sorted= false ,
1814
+ location,
1815
+ ),
1816
+ [size (gather_indices, 1 )],
1817
+ )
1818
+ end
1819
+
1820
+ @noinline function gather (
1821
+ src:: TracedRArray{T,N} ,
1822
+ gather_indices:: TracedRArray{Int64,2} ;
1823
+ offset_dims:: Vector{Int64} ,
1824
+ collapsed_slice_dims:: Vector{Int64} ,
1825
+ operand_batching_dims:: Vector{Int64} ,
1826
+ start_indices_batching_dims:: Vector{Int64} ,
1827
+ start_index_map:: Vector{Int64} ,
1828
+ index_vector_dim:: Int64 ,
1829
+ slice_sizes:: Vector{Int64} ,
1830
+ indices_are_sorted:: Bool = false ,
1831
+ location= mlir_stacktrace (" gather" , @__FILE__ , @__LINE__ ),
1832
+ ) where {T,N}
1833
+ gather_indices = subtract (
1834
+ gather_indices, fill (Int64 (1 ), size (gather_indices)); location
1835
+ )
1780
1836
1837
+ # ! format: off
1781
1838
dimension_numbers = MLIR. API. stablehloGatherDimensionNumbersGet (
1782
1839
MLIR. IR. context (),
1783
1840
Int64 (length (offset_dims)), offset_dims,
@@ -1789,20 +1846,18 @@ use [`MLIR.Dialects.stablehlo.dynamic_slice`](@ref) instead.
1789
1846
)
1790
1847
# ! format: on
1791
1848
1792
- return reshape (
1793
- TracedRArray {T} (
1794
- MLIR. IR. result (
1795
- MLIR. Dialects. stablehlo. gather (
1796
- src. mlir_data,
1797
- gather_indices. mlir_data;
1798
- dimension_numbers,
1799
- slice_sizes= Base. fill (Int64 (1 ), N),
1800
- indices_are_sorted= false ,
1801
- ),
1802
- 1 ,
1849
+ return TracedRArray {T} (
1850
+ MLIR. IR. result (
1851
+ MLIR. Dialects. stablehlo. gather (
1852
+ src. mlir_data,
1853
+ gather_indices. mlir_data;
1854
+ dimension_numbers,
1855
+ slice_sizes,
1856
+ indices_are_sorted,
1857
+ location,
1803
1858
),
1859
+ 1 ,
1804
1860
),
1805
- size (gather_indices, 1 ),
1806
1861
)
1807
1862
end
1808
1863
0 commit comments