Skip to content

Commit 599d487

Browse files
committed
fix: more cleanup
1 parent e54ccad commit 599d487

File tree

1 file changed

+37
-30
lines changed

1 file changed

+37
-30
lines changed

src/TracedRArray.jl

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -141,16 +141,23 @@ function Base.getindex(a::TracedRArray{T,1}, indices::CartesianIndex{1}) where {
141141
end
142142

143143
function _getindex_linear(a::TracedRArray{T,N}, indices::AbstractArray) where {T,N}
144+
if !(indices isa Reactant.TracedType) && __contiguous_indices(vec(indices))
145+
a_flat = materialize_traced_array(vec(a))
146+
indices_flat = vec(indices)
147+
return Ops.reshape(
148+
Ops.dynamic_slice(a_flat, [first(indices_flat)], [length(indices_flat)]),
149+
collect(size(indices)),
150+
)
151+
end
152+
144153
if !(indices isa TracedRArray)
145154
indices = collect(indices)
146155
eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices])
147156
indices = TracedUtils.promote_to(TracedRArray{Int,ndims(indices)}, indices)
148157
end
149-
return materialize_traced_array(
150-
reshape(
151-
Ops.gather_getindex(a, scalar_index_to_cartesian(vec(indices), size(a))),
152-
size(indices),
153-
),
158+
return Ops.reshape(
159+
Ops.gather_getindex(a, scalar_index_to_cartesian(vec(indices), size(a))),
160+
collect(size(indices)),
154161
)
155162
end
156163

@@ -279,34 +286,12 @@ end
279286
function Base.setindex!(
280287
a::TracedRArray{T,N}, v, index::Union{Int,TracedRNumber{Int}}
281288
) where {T,N}
282-
_setindex_scalar!(a, v, index)
283-
return a
289+
return _setindex_scalar!(a, v, index)
284290
end
285-
286-
# Avoid ambiguity
287291
function Base.setindex!(
288292
a::TracedRArray{T,1}, v, index::Union{Int,TracedRNumber{Int}}
289293
) where {T}
290-
_setindex_scalar!(a, v, index)
291-
return a
292-
end
293-
294-
function _setindex_unitrange!(a, v, indices)
295-
originalsz = size(a)
296-
flattened = Ops.reshape(a, [prod(originalsz)])
297-
result = Ops.dynamic_update_slice(flattened, v[begin:length(indices)], [first(indices)])
298-
result = Ops.reshape(result, collect(originalsz))
299-
set_mlir_data!(a, get_mlir_data(result))
300-
return a
301-
end
302-
303-
function Base.setindex!(
304-
a::Reactant.TracedRArray{T,N}, v, indices::UnitRange{Int}
305-
) where {T,N}
306-
return _setindex_unitrange!(a, v, indices)
307-
end
308-
function Base.setindex!(a::Reactant.TracedRArray{T,1}, v, indices::UnitRange{Int}) where {T}
309-
return _setindex_unitrange!(a, v, indices)
294+
return _setindex_scalar!(a, v, index)
310295
end
311296

312297
function Base.setindex!(a::TracedRArray{T,N}, v, index::CartesianIndex{N}) where {T,N}
@@ -324,7 +309,22 @@ function Base.setindex!(a::TracedRArray{T,N}, v, index::CartesianIndex{N}) where
324309
return a
325310
end
326311

327-
function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N}
312+
function _setindex_linear!(a::TracedRArray{T,N}, v, indices::AbstractArray) where {T,N}
313+
if !(indices isa Reactant.TracedType) && __contiguous_indices(vec(indices))
314+
res = Ops.reshape(
315+
Ops.dynamic_update_slice(
316+
materialize_traced_array(vec(a)),
317+
TracedUtils.broadcast_to_size(
318+
TracedUtils.promote_to(TracedRArray{T,1}, vec(v)), (length(indices),)
319+
),
320+
[first(indices)],
321+
),
322+
collect(size(a)),
323+
)
324+
set_mlir_data!(a, get_mlir_data(res))
325+
return a
326+
end
327+
328328
if !(indices isa TracedRArray)
329329
indices = collect(indices)
330330
eltype(indices) <: CartesianIndex && (indices = LinearIndices(size(a))[indices])
@@ -339,6 +339,13 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices) where {T,N}
339339
return a
340340
end
341341

342+
function Base.setindex!(a::TracedRArray{T,N}, v, indices::AbstractArray) where {T,N}
343+
return _setindex_linear!(a, v, indices)
344+
end
345+
function Base.setindex!(a::TracedRArray{T,1}, v, indices::AbstractArray) where {T,N}
346+
return _setindex_linear!(a, v, indices)
347+
end
348+
342349
function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {T,N}
343350
if (N == 1) && (indices isa Colon)
344351
# Remove ambiguity from the previous

0 commit comments

Comments
 (0)