Skip to content

Commit 32ac4c5

Browse files
authored
fix: setindex with traced indices (#1197)
1 parent 832a6da commit 32ac4c5

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/TracedRArray.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,9 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
377377
v = TracedUtils.promote_to(TracedRArray{T,N}, v)
378378
else
379379
v = TracedUtils.promote_to(TracedRArray{T,ndims(v)}, v)
380-
non_integer_indices = [!(idx isa Integer) for idx in indices]
380+
non_integer_indices = [
381+
!(idx isa Union{Integer,TracedRNumber{<:Integer}}) for idx in indices
382+
]
381383
broadcast_dims = findall(non_integer_indices)
382384
if length(broadcast_dims) == N
383385
v = TracedUtils.broadcast_to_size(v, length.(indices))

test/control_flow.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -836,3 +836,19 @@ end
836836
recursivenottraced.x[3].x = recursivenottraced
837837
@test !Reactant.ReactantCore.is_traced(recursivenottraced)
838838
end
839+
840+
function loop_batched(x)
841+
y = similar(x)
842+
@trace for i in 1:size(x, 1)
843+
y[i, :] = x[i, :] .+ 1
844+
y[i, :] = y[i, :] .^ 2
845+
end
846+
return y
847+
end
848+
849+
@testset "setindex: batched" begin
850+
x = rand(1024, 128)
851+
x_ra = Reactant.to_rarray(x)
852+
853+
@test @jit(loop_batched(x_ra)) loop_batched(x)
854+
end

0 commit comments

Comments
 (0)