@@ -14,8 +14,9 @@ using ..Reactant:
14
14
MLIR,
15
15
ancestor,
16
16
unwrapped_eltype
17
+ using .. TracedUtils: TracedUtils, get_mlir_data, set_mlir_data!, materialize_traced_array
18
+
17
19
using ReactantCore: ReactantCore
18
- using .. TracedUtils: TracedUtils, materialize_traced_array
19
20
using GPUArraysCore: GPUArraysCore
20
21
21
22
ReactantCore. is_traced (:: TracedRArray ) = true
@@ -55,25 +56,37 @@ function Base.getindex(
55
56
return TracedRNumber {T} ((), res2)
56
57
end
57
58
58
- function Base. getindex (a:: TracedRArray{T,0} ) where {T}
59
- return TracedRNumber {T} ((), a. mlir_data)
60
- end
59
+ Base. getindex (a:: TracedRArray{T,0} ) where {T} = TracedRNumber {T} ((), a. mlir_data)
61
60
62
- # XXX : We want to support https://github.com/EnzymeAD/Reactant.jl/issues/242 eventually
63
61
function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
64
62
indices = map (enumerate (indices)) do (idx, i)
65
63
i isa Colon && return 1 : size (a, idx)
66
64
i isa CartesianIndex && return Tuple (i)
67
65
return i
68
66
end
69
67
70
- foreach (indices) do idxs
71
- idxs isa Number && return nothing
68
+ non_contiguous_getindex = false
69
+ for idxs in indices
70
+ idxs isa Number && continue
72
71
contiguous = all (isone, diff (idxs))
73
72
# XXX : We want to throw error even for dynamic indexing
74
- if typeof (a) <: Bool
75
- contiguous || error (" non-contiguous indexing is not supported" )
73
+ if typeof (contiguous) <: Bool && ! contiguous
74
+ non_contiguous_getindex = true
75
+ break
76
+ end
77
+ end
78
+
79
+ if non_contiguous_getindex
80
+ indices_tuples = collect (Iterators. product (indices... ))
81
+ indices = Matrix {Int} (
82
+ undef, (length (indices_tuples), length (first (indices_tuples)))
83
+ )
84
+ for (i, idx) in enumerate (indices_tuples)
85
+ indices[i, :] .= idx .- 1
76
86
end
87
+ indices = TracedUtils. promote_to (TracedRArray{Int,2 }, indices)
88
+ res = Ops. gather_getindex (a, indices)
89
+ return Ops. reshape (res, size (indices_tuples)... )
77
90
end
78
91
79
92
start_indices = map (indices) do i
@@ -99,16 +112,41 @@ function Base.getindex(a::WrappedTracedRArray, indices...)
99
112
return getindex (ancestor (a), TracedUtils. get_ancestor_indices (a, indices... )... )
100
113
end
101
114
102
- function Base. setindex! (
103
- a:: TracedRArray{T,N} ,
104
- v,
105
- indices:: Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N} ,
106
- ) where {T,N}
115
+ function Base. setindex! (a:: TracedRArray{T,N} , v, indices:: Vararg{Any,N} ) where {T,N}
107
116
indices = map (enumerate (indices)) do (idx, i)
108
- i isa Int ? (i: i) : (i isa Colon ? (1 : size (a, idx)) : i)
117
+ i isa Colon && return 1 : size (a, idx)
118
+ i isa CartesianIndex && return Tuple (i)
119
+ return i
120
+ end
121
+
122
+ non_contiguous_setindex = false
123
+ for idxs in indices
124
+ idxs isa Number && continue
125
+ contiguous = all (isone, diff (idxs))
126
+ # XXX : We want to throw error even for dynamic indexing
127
+ if typeof (contiguous) <: Bool && ! contiguous
128
+ non_contiguous_setindex = true
129
+ break
130
+ end
131
+ end
132
+
133
+ if non_contiguous_setindex
134
+ indices_tuples = collect (Iterators. product (indices... ))
135
+ indices = Matrix {Int} (
136
+ undef, (length (indices_tuples), length (first (indices_tuples)))
137
+ )
138
+ for (i, idx) in enumerate (indices_tuples)
139
+ indices[i, :] .= idx .- 1
140
+ end
141
+ indices = TracedUtils. promote_to (TracedRArray{Int,2 }, indices)
142
+ res = Ops. scatter_setindex (a, indices, Ops. reshape (v, length (v)))
143
+ a. mlir_data = res. mlir_data
144
+ return v
109
145
end
146
+
110
147
v = TracedUtils. broadcast_to_size (v, length .(indices))
111
148
v = TracedUtils. promote_to (TracedRArray{T,N}, v)
149
+
112
150
indices = [
113
151
(
114
152
TracedUtils. promote_to (TracedRNumber{Int}, i isa Colon ? 1 : first (i)) - 1
@@ -124,11 +162,7 @@ function Base.setindex!(
124
162
return v
125
163
end
126
164
127
- function Base. setindex! (
128
- a:: AnyTracedRArray{T,N} ,
129
- v,
130
- indices:: Vararg{Union{Base.AbstractUnitRange,Colon,Int,TracedRNumber{Int}},N} ,
131
- ) where {T,N}
165
+ function Base. setindex! (a:: AnyTracedRArray{T,N} , v, indices:: Vararg{Any,N} ) where {T,N}
132
166
ancestor_indices = TracedUtils. get_ancestor_indices (a, indices... )
133
167
setindex! (ancestor (a), v, ancestor_indices... )
134
168
return a
0 commit comments