@@ -86,6 +86,7 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El
86
86
return data
87
87
# XLA.from_row_major(data)
88
88
end
89
+ Base. Array (x:: ConcreteRArray ) = convert (Array, x)
89
90
90
91
function synchronize (x:: Union{ConcreteRArray,ConcreteRNumber} )
91
92
XLA. synced_buffer (x. data)
@@ -145,6 +146,20 @@ for T in (ConcreteRNumber, ConcreteRArray{<:Any,0})
145
146
end
146
147
end
147
148
149
+ function Base. isapprox (x:: ConcreteRArray , y:: AbstractArray ; kwargs... )
150
+ return Base. isapprox (convert (Array, x), convert (Array, y); kwargs... )
151
+ end
152
+ function Base. isapprox (x:: AbstractArray , y:: ConcreteRArray ; kwargs... )
153
+ return Base. isapprox (convert (Array, x), convert (Array, y); kwargs... )
154
+ end
155
+ function Base. isapprox (x:: ConcreteRArray , y:: ConcreteRArray ; kwargs... )
156
+ return Base. isapprox (convert (Array, x), convert (Array, y); kwargs... )
157
+ end
158
+
159
+ Base.:(== )(x:: ConcreteRArray , y:: AbstractArray ) = convert (Array, x) == convert (Array, y)
160
+ Base.:(== )(x:: AbstractArray , y:: ConcreteRArray ) = convert (Array, x) == convert (Array, y)
161
+ Base.:(== )(x:: ConcreteRArray , y:: ConcreteRArray ) = convert (Array, x) == convert (Array, y)
162
+
148
163
function Base. show (io:: IO , X:: ConcreteRScalar{T} ) where {T}
149
164
if X. data == XLA. AsyncEmptyBuffer
150
165
println (io, " <Empty buffer>" )
@@ -171,12 +186,11 @@ function Base.show(io::IO, X::ConcreteRArray)
171
186
return print (io, " $(typeof (X)) ($(str) )" )
172
187
end
173
188
174
- const getindex_warned = Ref (false )
175
189
function Base. getindex (a:: ConcreteRArray{T} , args:: Vararg{Int,N} ) where {T,N}
176
190
if a. data == XLA. AsyncEmptyBuffer
177
191
throw (" Cannot getindex from empty buffer" )
178
192
end
179
- # error("""Scalar indexing is disallowed.""")
193
+
180
194
XLA. await (a. data)
181
195
if XLA. BufferOnCPU (a. data. buffer)
182
196
buf = a. data. buffer
@@ -193,16 +207,8 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
193
207
return unsafe_load (ptr, start)
194
208
end
195
209
end
196
- if ! getindex_warned[]
197
- @warn (
198
- """ Performing scalar get-indexing on task $(current_task ()) .
199
- Invocation resulted in scalar indexing of a ConcreteRArray.
200
- This is typically caused by calling an iterating implementation of a method.
201
- Such implementations *do not* execute on device, but very slowly on the CPU,
202
- and require expensive copies and synchronization each time and therefore should be avoided."""
203
- )
204
- getindex_warned[] = true
205
- end
210
+
211
+ GPUArraysCore. assertscalar (" getindex(::ConcreteRArray, ::Vararg{Int, N})" )
206
212
return convert (Array, a)[args... ]
207
213
end
208
214
@@ -211,12 +217,11 @@ function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
211
217
return nothing
212
218
end
213
219
214
- const setindex_warned = Ref (false )
215
-
216
220
function Base. setindex! (a:: ConcreteRArray{T} , v, args:: Vararg{Int,N} ) where {T,N}
217
221
if a. data == XLA. AsyncEmptyBuffer
218
222
throw (" Cannot setindex! to empty buffer" )
219
223
end
224
+
220
225
XLA. await (a. data)
221
226
if XLA. BufferOnCPU (a. data. buffer)
222
227
buf = a. data. buffer
@@ -234,19 +239,8 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
234
239
end
235
240
return a
236
241
end
237
- if ! setindex_warned[]
238
- @warn (
239
- """ Performing scalar set-indexing on task $(current_task ()) .
240
- Invocation resulted in scalar indexing of a ConcreteRArray.
241
- This is typically caused by calling an iterating implementation of a method.
242
- Such implementations *do not* execute on device, but very slowly on the CPU,
243
- and require expensive copies and synchronization each time and therefore should be avoided.
244
-
245
- This error message will only be printed for the first invocation for brevity.
246
- """
247
- )
248
- setindex_warned[] = true
249
- end
242
+
243
+ GPUArraysCore. assertscalar (" setindex!(::ConcreteRArray, ::Any, ::Vararg{Int, N})" )
250
244
fn = Reactant. compile (mysetindex!, (a, v, args... ))
251
245
fn (a, v, args... )
252
246
return a
0 commit comments