Skip to content

Commit 241fd14

Browse files
authored
feat: indexing using traced values (#434)
* feat: indexing using traced values * feat: implement repeat inner * feat: support scalar linear indexing + tests * fix: regression in cartesian index support * Update src/TracedRArray.jl
1 parent 7d2b898 commit 241fd14

File tree

4 files changed

+173
-45
lines changed

4 files changed

+173
-45
lines changed

ext/ReactantNNlibExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
323323
This case is not optimized and will be slow." maxlog = 1
324324
dims = NNlib.scatter_dims(src, dst, idxs)
325325
colons = ntuple(Returns(Colon()), dims)
326-
start_sizes = ntuple(i -> size(src, i), dims)
326+
start_sizes = ntuple(Base.Fix1(size, src), dims)
327327
results = map(CartesianIndices(idxs)) do k
328328
res = @allowscalar src[colons..., Tuple(idxs[k])...]
329329
res isa TracedRNumber && (res = TracedUtils.broadcast_to_size(res, (1,)))

src/TracedRArray.jl

Lines changed: 102 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -58,35 +58,84 @@ end
5858

5959
Base.getindex(a::TracedRArray{T,0}) where {T} = TracedRNumber{T}((), a.mlir_data)
6060

61+
function generate_index_list(i1, is...)
62+
list = reshape(i1, :, 1) .- 1
63+
for i in is
64+
i = reshape(i, :, 1)
65+
lorig = size(list, 1)
66+
list = repeat(list, size(i, 1), 1)
67+
i = repeat(i; inner=(lorig, 1)) .- 1
68+
list = hcat(list, i)
69+
end
70+
return list
71+
end
72+
73+
function scalar_index_to_cartesian(idx::AbstractVector{T}, sz::NTuple{N,Int}) where {T,N}
74+
idx = idx .- 1
75+
idxs = materialize_traced_array(reshape(idx .% T(sz[1]), :, 1))
76+
idx = idx T(sz[1])
77+
for i in 2:N
78+
idxs = hcat(idxs, idx .% T(sz[i]))
79+
idx = idx T(sz[i])
80+
end
81+
return idxs
82+
end
83+
84+
function Base.getindex(
85+
a::TracedRArray{T,N}, indices::Union{Int,TracedRNumber{Int}}
86+
) where {T,N}
87+
if indices isa Int
88+
indices = TracedUtils.promote_to(TracedRNumber{Int}, indices)
89+
end
90+
indices = TracedUtils.broadcast_to_size(indices, (1,))
91+
return Ops.gather_getindex(a, scalar_index_to_cartesian(indices, size(a)))[1]
92+
end
93+
94+
function Base.getindex(a::TracedRArray{T,N}, indices) where {T,N}
95+
if !(indices isa TracedRArray)
96+
indices = TracedUtils.promote_to(TracedRArray{Int,1}, collect(indices))
97+
end
98+
return Ops.gather_getindex(a, scalar_index_to_cartesian(indices, size(a)))
99+
end
100+
101+
Base.getindex(a::TracedRArray{T,N}, ::Colon) where {T,N} = materialize_traced_array(vec(a))
102+
103+
function Base.getindex(a::TracedRArray{T,N}, indices::CartesianIndex{N}) where {T,N}
104+
indices =
105+
materialize_traced_array(
106+
reshape(
107+
TracedUtils.promote_to(TracedRArray{Int,1}, vcat(Tuple(indices)...)), 1, N
108+
),
109+
) .- 1
110+
return Ops.gather_getindex(a, indices)[1]
111+
end
112+
61113
function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
62114
indices = map(enumerate(indices)) do (idx, i)
63115
i isa Colon && return 1:size(a, idx)
64116
i isa CartesianIndex && return Tuple(i)
65117
return i
66118
end
67119

68-
non_contiguous_getindex = false
120+
use_gather_getindex = false
69121
for idxs in indices
70122
idxs isa Number && continue
123+
if idxs isa Reactant.TracedType
124+
use_gather_getindex = true
125+
break
126+
end
71127
contiguous = all(isone, diff(idxs))
72-
# XXX: We want to throw error even for dynamic indexing
73128
if typeof(contiguous) <: Bool && !contiguous
74-
non_contiguous_getindex = true
129+
use_gather_getindex = true
75130
break
76131
end
77132
end
78133

79-
if non_contiguous_getindex
80-
indices_tuples = collect(Iterators.product(indices...))
81-
indices = Matrix{Int}(
82-
undef, (length(indices_tuples), length(first(indices_tuples)))
83-
)
84-
for (i, idx) in enumerate(indices_tuples)
85-
indices[i, :] .= idx .- 1
86-
end
87-
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
88-
res = Ops.gather_getindex(a, indices)
89-
return Ops.reshape(res, size(indices_tuples)...)
134+
if use_gather_getindex
135+
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
136+
indices_list = generate_index_list(indices_list...)
137+
res = Ops.gather_getindex(a, indices_list)
138+
return Ops.reshape(res, length.(indices)...)
90139
end
91140

92141
start_indices = map(indices) do i
@@ -99,7 +148,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
99148

100149
x = TracedRArray{T,N}((), res, Tuple(length.(indices)))
101150
ddims = findall(Base.Fix2(isa, Integer), indices)
102-
isempty(ddims) || return dropdims(x; dims=Tuple(ddims))
151+
isempty(ddims) || return materialize_traced_array(dropdims(x; dims=Tuple(ddims)))
103152
return x
104153
end
105154

@@ -119,27 +168,24 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
119168
return i
120169
end
121170

122-
non_contiguous_setindex = false
171+
use_scatter_setindex = false
123172
for idxs in indices
124173
idxs isa Number && continue
174+
if idxs isa Reactant.TracedType
175+
use_scatter_setindex = true
176+
break
177+
end
125178
contiguous = all(isone, diff(idxs))
126-
# XXX: We want to throw error even for dynamic indexing
127179
if typeof(contiguous) <: Bool && !contiguous
128-
non_contiguous_setindex = true
180+
use_scatter_setindex = true
129181
break
130182
end
131183
end
132184

133-
if non_contiguous_setindex
134-
indices_tuples = collect(Iterators.product(indices...))
135-
indices = Matrix{Int}(
136-
undef, (length(indices_tuples), length(first(indices_tuples)))
137-
)
138-
for (i, idx) in enumerate(indices_tuples)
139-
indices[i, :] .= idx .- 1
140-
end
141-
indices = TracedUtils.promote_to(TracedRArray{Int,2}, indices)
142-
res = Ops.scatter_setindex(a, indices, Ops.reshape(v, length(v)))
185+
if use_scatter_setindex
186+
indices_list = map(Base.Fix1(TracedUtils.promote_to, TracedRArray{Int,1}), indices)
187+
indices_list = generate_index_list(indices_list...)
188+
res = Ops.scatter_setindex(a, indices_list, Ops.reshape(v, length(v)))
143189
a.mlir_data = res.mlir_data
144190
return v
145191
end
@@ -512,15 +558,16 @@ Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x)
512558
Base.any(f::Function, x::AnyTracedRArray) = mapreduce(f, |, x)
513559

514560
# outer repeat
515-
# Overridden because we don't need to further recur into the definitions here
516-
function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,M}
561+
function Base._RepeatInnerOuter.repeat_outer(
562+
x::AnyTracedRArray{T,N}, counts::NTuple{M,Int}
563+
) where {T,N,M}
517564
P = max(N, M) # potentially padded
518565

519566
# (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1)
520567
interleaved_size = ones(Int, 2P)
521568
interleaved_size[1:2:(2N)] .= size(x)
522569

523-
x_interleaved = reshape(x, interleaved_size...)
570+
x_interleaved = reshape(materialize_traced_array(x), interleaved_size...)
524571

525572
# (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP)
526573
broadcast_target_size = interleaved_size
@@ -531,9 +578,31 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,
531578
# (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP)
532579
final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1))
533580

534-
x_final = reshape(x_broadcasted, final_size...)
581+
return materialize_traced_array(reshape(x_broadcasted, final_size...))
582+
end
583+
584+
# inner repeat
585+
function Base._RepeatInnerOuter.repeat_inner(
586+
x::AnyTracedRArray{T,N}, counts::NTuple{M,Int}
587+
) where {T,N,M}
588+
P = max(N, M) # potentially padded
589+
590+
# (d1, d2, ..., dP) -> (1, d1, 1, d2, 1, ..., 1, dP)
591+
interleaved_size = ones(Int, 2P)
592+
interleaved_size[2:2:(2N)] .= size(x)
593+
594+
x_interleaved = reshape(materialize_traced_array(x), interleaved_size...)
595+
596+
# (1, d1, 1, d2, 1, ..., 1, dP) -> (r1, d1, r2, d2, ..., rP, dP)
597+
broadcast_target_size = interleaved_size
598+
broadcast_target_size[1:2:(2N)] .= counts
599+
600+
x_broadcasted = TracedUtils.broadcast_to_size(x_interleaved, broadcast_target_size)
601+
602+
# (r1, d1, r2, d2, ..., rP, dP) -> (d1*r1, d2*r2, ..., dP*rP)
603+
final_size = vec(prod(reshape(broadcast_target_size, 2, :); dims=1))
535604

536-
return x_final
605+
return materialize_traced_array(reshape(x_broadcasted, final_size...))
537606
end
538607

539608
end

src/TracedRNumber.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ for (jlop, hloop) in (
8484
(:(Base.:*), :multiply),
8585
(:(Base.:/), :divide),
8686
(:(Base.:^), :power),
87+
(:(Base.mod), :remainder),
88+
(:(Base.rem), :remainder),
8789
)
8890
@eval function $(jlop)(
8991
@nospecialize(lhs::TracedRNumber{T}), @nospecialize(rhs::TracedRNumber{T})
@@ -92,6 +94,10 @@ for (jlop, hloop) in (
9294
end
9395
end
9496

97+
function Base.div(@nospecialize(lhs::TracedRNumber{T}), rhs) where {T<:Integer}
98+
return Ops.divide(lhs, TracedUtils.promote_to(TracedRNumber{T}, rhs))
99+
end
100+
95101
function Base.div(
96102
@nospecialize(lhs::TracedRNumber{T}), rhs, ::typeof(RoundDown)
97103
) where {T<:Integer}

test/basic.jl

Lines changed: 64 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -365,12 +365,23 @@ end
365365
end
366366

367367
@testset "repeat" begin
368+
fn_inner(x, counts) = repeat(x; inner=counts)
369+
368370
@testset for (size, counts) in Iterators.product(
369371
[(2,), (2, 3), (2, 3, 4), (2, 3, 4, 5)],
370372
[(), (1,), (2,), (2, 1), (1, 2), (2, 2), (2, 2, 2), (1, 1, 1, 1, 1)],
371373
)
372374
x = rand(size...)
373-
@test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...)
375+
376+
@testset "outer repeat" begin
377+
@test (@jit repeat(Reactant.to_rarray(x), counts...)) == repeat(x, counts...)
378+
end
379+
380+
length(counts) < length(size) && continue
381+
382+
@testset "inner repeat" begin
383+
@test (@jit fn_inner(Reactant.to_rarray(x), counts)) == fn_inner(x, counts)
384+
end
374385
end
375386
end
376387

@@ -751,11 +762,11 @@ end
751762
x = rand(4, 2)
752763
x_ra = Reactant.to_rarray(x)
753764

754-
non_contiguous_indexing1(x) = x[[1, 3, 2], :]
755-
non_contiguous_indexing2(x) = x[:, [1, 2, 2]]
765+
non_contiguous_indexing3(x) = x[[1, 3, 2], :]
766+
non_contiguous_indexing4(x) = x[:, [1, 2, 2]]
756767

757-
@test @jit(non_contiguous_indexing1(x_ra)) non_contiguous_indexing1(x)
758-
@test @jit(non_contiguous_indexing2(x_ra)) non_contiguous_indexing2(x)
768+
@test @jit(non_contiguous_indexing3(x_ra)) non_contiguous_indexing3(x)
769+
@test @jit(non_contiguous_indexing4(x_ra)) non_contiguous_indexing4(x)
759770

760771
x = rand(4, 4, 3)
761772
x_ra = Reactant.to_rarray(x)
@@ -777,17 +788,59 @@ end
777788
x = rand(4, 2)
778789
x_ra = Reactant.to_rarray(x)
779790

780-
non_contiguous_indexing1!(x) = x[[1, 3, 2], :] .= 2
781-
non_contiguous_indexing2!(x) = x[:, [1, 2, 2]] .= 2
791+
non_contiguous_indexing3!(x) = x[[1, 3, 2], :] .= 2
792+
non_contiguous_indexing4!(x) = x[:, [1, 2, 2]] .= 2
782793

783-
@jit(non_contiguous_indexing1!(x_ra))
784-
non_contiguous_indexing1!(x)
794+
@jit(non_contiguous_indexing3!(x_ra))
795+
non_contiguous_indexing3!(x)
785796
@test x_ra x
786797

787798
x = rand(4, 2)
788799
x_ra = Reactant.to_rarray(x)
789800

790-
@jit(non_contiguous_indexing2!(x_ra))
791-
non_contiguous_indexing2!(x)
801+
@jit(non_contiguous_indexing4!(x_ra))
802+
non_contiguous_indexing4!(x)
792803
@test x_ra x
793804
end
805+
806+
@testset "indexing with traced arrays" begin
807+
x = rand(4, 4, 3)
808+
idx1 = [1, 3, 2]
809+
idx3 = [1, 2, 1, 3]
810+
811+
x_ra = Reactant.to_rarray(x)
812+
idx1_ra = Reactant.to_rarray(idx1)
813+
idx3_ra = Reactant.to_rarray(idx3)
814+
815+
getindex1(x, idx1) = x[idx1, :, :]
816+
getindex2(x, idx1) = x[:, idx1, :]
817+
getindex3(x, idx3) = x[:, :, idx3]
818+
getindex4(x, idx1, idx3) = x[idx1, :, idx3]
819+
820+
@test @jit(getindex1(x_ra, idx1_ra)) getindex1(x, idx1)
821+
@test @jit(getindex2(x_ra, idx1_ra)) getindex2(x, idx1)
822+
@test @jit(getindex3(x_ra, idx3_ra)) getindex3(x, idx3)
823+
@test @jit(getindex4(x_ra, idx1_ra, idx3_ra)) getindex4(x, idx1, idx3)
824+
end
825+
826+
@testset "linear indexing" begin
827+
x = rand(4, 4, 3)
828+
x_ra = Reactant.to_rarray(x)
829+
830+
getindex_linear_scalar(x, idx) = @allowscalar x[idx]
831+
832+
@testset for i in 1:length(x)
833+
@test @jit(getindex_linear_scalar(x_ra, i)) getindex_linear_scalar(x, i)
834+
@test @jit(
835+
getindex_linear_scalar(x_ra, Reactant.to_rarray(i; track_numbers=(Number,)))
836+
) getindex_linear_scalar(x, i)
837+
end
838+
839+
idx = rand(1:length(x), 8)
840+
idx_ra = Reactant.to_rarray(idx)
841+
842+
getindex_linear_vector(x, idx) = x[idx]
843+
844+
@test @jit(getindex_linear_vector(x_ra, idx_ra)) getindex_linear_vector(x, idx)
845+
@test @jit(getindex_linear_vector(x_ra, idx)) getindex_linear_vector(x, idx)
846+
end

0 commit comments

Comments
 (0)