@@ -141,16 +141,23 @@ function Base.getindex(a::TracedRArray{T,1}, indices::CartesianIndex{1}) where {
141
141
end
142
142
143
143
function _getindex_linear (a:: TracedRArray{T,N} , indices:: AbstractArray ) where {T,N}
144
+ if ! (indices isa Reactant. TracedType) && __contiguous_indices (vec (indices))
145
+ a_flat = materialize_traced_array (vec (a))
146
+ indices_flat = vec (indices)
147
+ return Ops. reshape (
148
+ Ops. dynamic_slice (a_flat, [first (indices_flat)], [length (indices_flat)]),
149
+ collect (size (indices)),
150
+ )
151
+ end
152
+
144
153
if ! (indices isa TracedRArray)
145
154
indices = collect (indices)
146
155
eltype (indices) <: CartesianIndex && (indices = LinearIndices (size (a))[indices])
147
156
indices = TracedUtils. promote_to (TracedRArray{Int,ndims (indices)}, indices)
148
157
end
149
- return materialize_traced_array (
150
- reshape (
151
- Ops. gather_getindex (a, scalar_index_to_cartesian (vec (indices), size (a))),
152
- size (indices),
153
- ),
158
+ return Ops. reshape (
159
+ Ops. gather_getindex (a, scalar_index_to_cartesian (vec (indices), size (a))),
160
+ collect (size (indices)),
154
161
)
155
162
end
156
163
@@ -279,34 +286,12 @@ end
279
286
function Base. setindex! (
280
287
a:: TracedRArray{T,N} , v, index:: Union{Int,TracedRNumber{Int}}
281
288
) where {T,N}
282
- _setindex_scalar! (a, v, index)
283
- return a
289
+ return _setindex_scalar! (a, v, index)
284
290
end
285
-
286
- # Avoid ambiguity
287
291
function Base. setindex! (
288
292
a:: TracedRArray{T,1} , v, index:: Union{Int,TracedRNumber{Int}}
289
293
) where {T}
290
- _setindex_scalar! (a, v, index)
291
- return a
292
- end
293
-
294
- function _setindex_unitrange! (a, v, indices)
295
- originalsz = size (a)
296
- flattened = Ops. reshape (a, [prod (originalsz)])
297
- result = Ops. dynamic_update_slice (flattened, v[begin : length (indices)], [first (indices)])
298
- result = Ops. reshape (result, collect (originalsz))
299
- set_mlir_data! (a, get_mlir_data (result))
300
- return a
301
- end
302
-
303
- function Base. setindex! (
304
- a:: Reactant.TracedRArray{T,N} , v, indices:: UnitRange{Int}
305
- ) where {T,N}
306
- return _setindex_unitrange! (a, v, indices)
307
- end
308
- function Base. setindex! (a:: Reactant.TracedRArray{T,1} , v, indices:: UnitRange{Int} ) where {T}
309
- return _setindex_unitrange! (a, v, indices)
294
+ return _setindex_scalar! (a, v, index)
310
295
end
311
296
312
297
function Base. setindex! (a:: TracedRArray{T,N} , v, index:: CartesianIndex{N} ) where {T,N}
@@ -324,7 +309,22 @@ function Base.setindex!(a::TracedRArray{T,N}, v, index::CartesianIndex{N}) where
324
309
return a
325
310
end
326
311
327
- function Base. setindex! (a:: TracedRArray{T,N} , v, indices) where {T,N}
312
+ function _setindex_linear! (a:: TracedRArray{T,N} , v, indices:: AbstractArray ) where {T,N}
313
+ if ! (indices isa Reactant. TracedType) && __contiguous_indices (vec (indices))
314
+ res = Ops. reshape (
315
+ Ops. dynamic_update_slice (
316
+ materialize_traced_array (vec (a)),
317
+ TracedUtils. broadcast_to_size (
318
+ TracedUtils. promote_to (TracedRArray{T,1 }, vec (v)), (length (indices),)
319
+ ),
320
+ [first (indices)],
321
+ ),
322
+ collect (size (a)),
323
+ )
324
+ set_mlir_data! (a, get_mlir_data (res))
325
+ return a
326
+ end
327
+
328
328
if ! (indices isa TracedRArray)
329
329
indices = collect (indices)
330
330
eltype (indices) <: CartesianIndex && (indices = LinearIndices (size (a))[indices])
@@ -339,6 +339,13 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N}
339
339
return a
340
340
end
341
341
342
+ function Base. setindex! (a:: TracedRArray{T,N} , v, indices:: AbstractArray ) where {T,N}
343
+ return _setindex_linear! (a, v, indices)
344
+ end
345
+ function Base. setindex! (a:: TracedRArray{T,1} , v, indices:: AbstractArray ) where {T,N}
346
+ return _setindex_linear! (a, v, indices)
347
+ end
348
+
342
349
function Base. setindex! (a:: TracedRArray{T,N} , v, indices:: Vararg{Any,N} ) where {T,N}
343
350
if (N == 1 ) && (indices isa Colon)
344
351
# Remove ambiguity from the previous
0 commit comments