58
58
59
59
Base. getindex (a:: TracedRArray{T,0} ) where {T} = TracedRNumber {T} ((), a. mlir_data)
60
60
61
+ function generate_index_list (i1, is... )
62
+ list = reshape (i1, :, 1 ) .- 1
63
+ for i in is
64
+ i = reshape (i, :, 1 )
65
+ lorig = size (list, 1 )
66
+ list = repeat (list, size (i, 1 ), 1 )
67
+ i = repeat (i; inner= (lorig, 1 )) .- 1
68
+ list = hcat (list, i)
69
+ end
70
+ return list
71
+ end
72
+
73
+ function scalar_index_to_cartesian (idx:: AbstractVector{T} , sz:: NTuple{N,Int} ) where {T,N}
74
+ idx = idx .- 1
75
+ idxs = materialize_traced_array (reshape (idx .% T (sz[1 ]), :, 1 ))
76
+ idx = idx .÷ T (sz[1 ])
77
+ for i in 2 : N
78
+ idxs = hcat (idxs, idx .% T (sz[i]))
79
+ idx = idx .÷ T (sz[i])
80
+ end
81
+ return idxs
82
+ end
83
+
84
+ function Base. getindex (
85
+ a:: TracedRArray{T,N} , indices:: Union{Int,TracedRNumber{Int}}
86
+ ) where {T,N}
87
+ if indices isa Int
88
+ indices = TracedUtils. promote_to (TracedRNumber{Int}, indices)
89
+ end
90
+ indices = TracedUtils. broadcast_to_size (indices, (1 ,))
91
+ return Ops. gather_getindex (a, scalar_index_to_cartesian (indices, size (a)))[1 ]
92
+ end
93
+
94
+ function Base. getindex (a:: TracedRArray{T,N} , indices) where {T,N}
95
+ if ! (indices isa TracedRArray)
96
+ indices = TracedUtils. promote_to (TracedRArray{Int,1 }, collect (indices))
97
+ end
98
+ return Ops. gather_getindex (a, scalar_index_to_cartesian (indices, size (a)))
99
+ end
100
+
101
+ Base. getindex (a:: TracedRArray{T,N} , :: Colon ) where {T,N} = materialize_traced_array (vec (a))
102
+
103
+ function Base. getindex (a:: TracedRArray{T,N} , indices:: CartesianIndex{N} ) where {T,N}
104
+ indices =
105
+ materialize_traced_array (
106
+ reshape (
107
+ TracedUtils. promote_to (TracedRArray{Int,1 }, vcat (Tuple (indices)... )), 1 , N
108
+ ),
109
+ ) .- 1
110
+ return Ops. gather_getindex (a, indices)[1 ]
111
+ end
112
+
61
113
function Base. getindex (a:: TracedRArray{T,N} , indices:: Vararg{Any,N} ) where {T,N}
62
114
indices = map (enumerate (indices)) do (idx, i)
63
115
i isa Colon && return 1 : size (a, idx)
64
116
i isa CartesianIndex && return Tuple (i)
65
117
return i
66
118
end
67
119
68
- non_contiguous_getindex = false
120
+ use_gather_getindex = false
69
121
for idxs in indices
70
122
idxs isa Number && continue
123
+ if idxs isa Reactant. TracedType
124
+ use_gather_getindex = true
125
+ break
126
+ end
71
127
contiguous = all (isone, diff (idxs))
72
- # XXX : We want to throw error even for dynamic indexing
73
128
if typeof (contiguous) <: Bool && ! contiguous
74
- non_contiguous_getindex = true
129
+ use_gather_getindex = true
75
130
break
76
131
end
77
132
end
78
133
79
- if non_contiguous_getindex
80
- indices_tuples = collect (Iterators. product (indices... ))
81
- indices = Matrix {Int} (
82
- undef, (length (indices_tuples), length (first (indices_tuples)))
83
- )
84
- for (i, idx) in enumerate (indices_tuples)
85
- indices[i, :] .= idx .- 1
86
- end
87
- indices = TracedUtils. promote_to (TracedRArray{Int,2 }, indices)
88
- res = Ops. gather_getindex (a, indices)
89
- return Ops. reshape (res, size (indices_tuples)... )
134
+ if use_gather_getindex
135
+ indices_list = map (Base. Fix1 (TracedUtils. promote_to, TracedRArray{Int,1 }), indices)
136
+ indices_list = generate_index_list (indices_list... )
137
+ res = Ops. gather_getindex (a, indices_list)
138
+ return Ops. reshape (res, length .(indices)... )
90
139
end
91
140
92
141
start_indices = map (indices) do i
@@ -99,7 +148,7 @@ function Base.getindex(a::TracedRArray{T,N}, indices::Vararg{Any,N}) where {T,N}
99
148
100
149
x = TracedRArray {T,N} ((), res, Tuple (length .(indices)))
101
150
ddims = findall (Base. Fix2 (isa, Integer), indices)
102
- isempty (ddims) || return dropdims (x; dims= Tuple (ddims))
151
+ isempty (ddims) || return materialize_traced_array ( dropdims (x; dims= Tuple (ddims) ))
103
152
return x
104
153
end
105
154
@@ -119,27 +168,24 @@ function Base.setindex!(a::TracedRArray{T,N}, v, indices::Vararg{Any,N}) where {
119
168
return i
120
169
end
121
170
122
- non_contiguous_setindex = false
171
+ use_scatter_setindex = false
123
172
for idxs in indices
124
173
idxs isa Number && continue
174
+ if idxs isa Reactant. TracedType
175
+ use_scatter_setindex = true
176
+ break
177
+ end
125
178
contiguous = all (isone, diff (idxs))
126
- # XXX : We want to throw error even for dynamic indexing
127
179
if typeof (contiguous) <: Bool && ! contiguous
128
- non_contiguous_setindex = true
180
+ use_scatter_setindex = true
129
181
break
130
182
end
131
183
end
132
184
133
- if non_contiguous_setindex
134
- indices_tuples = collect (Iterators. product (indices... ))
135
- indices = Matrix {Int} (
136
- undef, (length (indices_tuples), length (first (indices_tuples)))
137
- )
138
- for (i, idx) in enumerate (indices_tuples)
139
- indices[i, :] .= idx .- 1
140
- end
141
- indices = TracedUtils. promote_to (TracedRArray{Int,2 }, indices)
142
- res = Ops. scatter_setindex (a, indices, Ops. reshape (v, length (v)))
185
+ if use_scatter_setindex
186
+ indices_list = map (Base. Fix1 (TracedUtils. promote_to, TracedRArray{Int,1 }), indices)
187
+ indices_list = generate_index_list (indices_list... )
188
+ res = Ops. scatter_setindex (a, indices_list, Ops. reshape (v, length (v)))
143
189
a. mlir_data = res. mlir_data
144
190
return v
145
191
end
@@ -512,15 +558,16 @@ Base.all(f::Function, x::AnyTracedRArray) = mapreduce(f, &, x)
512
558
Base. any (f:: Function , x:: AnyTracedRArray ) = mapreduce (f, | , x)
513
559
514
560
# outer repeat
515
- # Overridden because we don't need to further recur into the definitions here
516
- function Base. repeat (x:: AnyTracedRArray{T,N} , counts:: Vararg{Int,M} ) where {T,N,M}
561
+ function Base. _RepeatInnerOuter. repeat_outer (
562
+ x:: AnyTracedRArray{T,N} , counts:: NTuple{M,Int}
563
+ ) where {T,N,M}
517
564
P = max (N, M) # potentially padded
518
565
519
566
# (d1, d2, ..., dP) -> (d1, 1, d2, 1, ..., dP, 1)
520
567
interleaved_size = ones (Int, 2 P)
521
568
interleaved_size[1 : 2 : (2 N)] .= size (x)
522
569
523
- x_interleaved = reshape (x , interleaved_size... )
570
+ x_interleaved = reshape (materialize_traced_array (x) , interleaved_size... )
524
571
525
572
# (d1, 1, d2, 1, ..., dP, 1) -> (d1, r1, d2, r2, ..., dP, rP)
526
573
broadcast_target_size = interleaved_size
@@ -531,9 +578,31 @@ function Base.repeat(x::AnyTracedRArray{T,N}, counts::Vararg{Int,M}) where {T,N,
531
578
# (d1, r1, d2, r2, ..., dP, rP) -> (d1*r1, d2*r2, ..., dP*rP)
532
579
final_size = vec (prod (reshape (broadcast_target_size, 2 , :); dims= 1 ))
533
580
534
- x_final = reshape (x_broadcasted, final_size... )
581
+ return materialize_traced_array (reshape (x_broadcasted, final_size... ))
582
+ end
583
+
584
+ # inner repeat
585
+ function Base. _RepeatInnerOuter. repeat_inner (
586
+ x:: AnyTracedRArray{T,N} , counts:: NTuple{M,Int}
587
+ ) where {T,N,M}
588
+ P = max (N, M) # potentially padded
589
+
590
+ # (d1, d2, ..., dP) -> (1, d1, 1, d2, 1, ..., 1, dP)
591
+ interleaved_size = ones (Int, 2 P)
592
+ interleaved_size[2 : 2 : (2 N)] .= size (x)
593
+
594
+ x_interleaved = reshape (materialize_traced_array (x), interleaved_size... )
595
+
596
+ # (1, d1, 1, d2, 1, ..., 1, dP) -> (r1, d1, r2, d2, ..., rP, dP)
597
+ broadcast_target_size = interleaved_size
598
+ broadcast_target_size[1 : 2 : (2 N)] .= counts
599
+
600
+ x_broadcasted = TracedUtils. broadcast_to_size (x_interleaved, broadcast_target_size)
601
+
602
+ # (r1, d1, r2, d2, ..., rP, dP) -> (d1*r1, d2*r2, ..., dP*rP)
603
+ final_size = vec (prod (reshape (broadcast_target_size, 2 , :); dims= 1 ))
535
604
536
- return x_final
605
+ return materialize_traced_array ( reshape (x_broadcasted, final_size ... ))
537
606
end
538
607
539
608
end
0 commit comments