Skip to content

Commit 5a60501

Browse files
authored
feat: inherit scalar indexing functionality from GPUArraysCore (#268)
* feat: inherit scalar indexing functionality from GPUArraysCore * chore: run formatter * fix: always warn inside tracing unless opt-out * chore: reexport @allowscalar * feat: add isapprox for array types * fix: test fixes for scalar indexing * fix: allow scalar indexing in gather
1 parent 05bd81f commit 5a60501

14 files changed

+92
-85
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
99
Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
1010
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
11+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1112
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
1213
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1314
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
@@ -35,6 +36,7 @@ ArrayInterface = "7.10"
3536
CEnum = "0.4, 0.5"
3637
Downloads = "1.6"
3738
Enzyme = "0.13"
39+
GPUArraysCore = "0.1, 0.2"
3840
LinearAlgebra = "1.10"
3941
NNlib = "0.9"
4042
OrderedCollections = "1"

ext/ReactantNNlibExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module ReactantNNlibExt
22

33
using NNlib
4+
using GPUArraysCore: @allowscalar
45
using Reactant:
56
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
67
using ReactantCore: @trace
@@ -367,7 +368,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
367368
colons = ntuple(Returns(Colon()), dims)
368369
start_sizes = ntuple(i -> size(src, i), dims)
369370
results = map(CartesianIndices(idxs)) do k
370-
res = src[colons..., Tuple(idxs[k])...]
371+
res = @allowscalar src[colons..., Tuple(idxs[k])...]
371372
res isa TracedRNumber && (res = Reactant.broadcast_to_size(res, (1,)))
372373
return reshape(res, start_sizes..., :)
373374
end

src/ConcreteRArray.jl

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ function Base.convert(::Type{T}, X::ConcreteRArray{ElType,N}) where {T<:Array,El
8686
return data
8787
# XLA.from_row_major(data)
8888
end
89+
Base.Array(x::ConcreteRArray) = convert(Array, x)
8990

9091
function synchronize(x::Union{ConcreteRArray,ConcreteRNumber})
9192
XLA.synced_buffer(x.data)
@@ -145,6 +146,20 @@ for T in (ConcreteRNumber, ConcreteRArray{<:Any,0})
145146
end
146147
end
147148

149+
function Base.isapprox(x::ConcreteRArray, y::AbstractArray; kwargs...)
150+
return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...)
151+
end
152+
function Base.isapprox(x::AbstractArray, y::ConcreteRArray; kwargs...)
153+
return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...)
154+
end
155+
function Base.isapprox(x::ConcreteRArray, y::ConcreteRArray; kwargs...)
156+
return Base.isapprox(convert(Array, x), convert(Array, y); kwargs...)
157+
end
158+
159+
Base.:(==)(x::ConcreteRArray, y::AbstractArray) = convert(Array, x) == convert(Array, y)
160+
Base.:(==)(x::AbstractArray, y::ConcreteRArray) = convert(Array, x) == convert(Array, y)
161+
Base.:(==)(x::ConcreteRArray, y::ConcreteRArray) = convert(Array, x) == convert(Array, y)
162+
148163
function Base.show(io::IO, X::ConcreteRScalar{T}) where {T}
149164
if X.data == XLA.AsyncEmptyBuffer
150165
println(io, "<Empty buffer>")
@@ -171,12 +186,11 @@ function Base.show(io::IO, X::ConcreteRArray)
171186
return print(io, "$(typeof(X))($(str))")
172187
end
173188

174-
const getindex_warned = Ref(false)
175189
function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
176190
if a.data == XLA.AsyncEmptyBuffer
177191
throw("Cannot getindex from empty buffer")
178192
end
179-
# error("""Scalar indexing is disallowed.""")
193+
180194
XLA.await(a.data)
181195
if XLA.BufferOnCPU(a.data.buffer)
182196
buf = a.data.buffer
@@ -193,16 +207,8 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
193207
return unsafe_load(ptr, start)
194208
end
195209
end
196-
if !getindex_warned[]
197-
@warn(
198-
"""Performing scalar get-indexing on task $(current_task()).
199-
Invocation resulted in scalar indexing of a ConcreteRArray.
200-
This is typically caused by calling an iterating implementation of a method.
201-
Such implementations *do not* execute on device, but very slowly on the CPU,
202-
and require expensive copies and synchronization each time and therefore should be avoided."""
203-
)
204-
getindex_warned[] = true
205-
end
210+
211+
GPUArraysCore.assertscalar("getindex(::ConcreteRArray, ::Vararg{Int, N})")
206212
return convert(Array, a)[args...]
207213
end
208214

@@ -211,12 +217,11 @@ function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
211217
return nothing
212218
end
213219

214-
const setindex_warned = Ref(false)
215-
216220
function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N}
217221
if a.data == XLA.AsyncEmptyBuffer
218222
throw("Cannot setindex! to empty buffer")
219223
end
224+
220225
XLA.await(a.data)
221226
if XLA.BufferOnCPU(a.data.buffer)
222227
buf = a.data.buffer
@@ -234,19 +239,8 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
234239
end
235240
return a
236241
end
237-
if !setindex_warned[]
238-
@warn(
239-
"""Performing scalar set-indexing on task $(current_task()).
240-
Invocation resulted in scalar indexing of a ConcreteRArray.
241-
This is typically caused by calling an iterating implementation of a method.
242-
Such implementations *do not* execute on device, but very slowly on the CPU,
243-
and require expensive copies and synchronization each time and therefore should be avoided.
244-
245-
This error message will only be printed for the first invocation for brevity.
246-
"""
247-
)
248-
setindex_warned[] = true
249-
end
242+
243+
GPUArraysCore.assertscalar("setindex!(::ConcreteRArray, ::Any, ::Vararg{Int, N})")
250244
fn = Reactant.compile(mysetindex!, (a, v, args...))
251245
fn(a, v, args...)
252246
return a

src/Reactant.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ using ReactantCore: ReactantCore, @trace, MissingTracedValue
44

55
using LinearAlgebra: LinearAlgebra
66
using Adapt: Adapt, WrappedArray
7+
using GPUArraysCore: GPUArraysCore, @allowscalar, allowscalar # keep this import to allow users to do `Reactant.allowscalar(false)`
8+
9+
export @allowscalar # re-exported from GPUArraysCore
710

811
# auxiliary types and functions
912
include("OrderedIdDict.jl")
@@ -114,8 +117,7 @@ function set_default_backend(backend::XLA.Client)
114117
end
115118

116119
function set_default_backend(backend::String)
117-
backend = XLA.backends[backend]
118-
return XLA.default_backend[] = backend
120+
return set_default_backend(XLA.backends[backend])
119121
end
120122

121123
end # module

src/TracedRArray.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,7 @@ end
5959
function Base.getindex(
6060
a::TracedRArray{T,N}, index::Vararg{Union{Int,TracedRNumber{Int}},N}
6161
) where {T,N}
62-
@warn(
63-
"""Performing scalar indexing on task $(current_task()).
64-
Invocation resulted in scalar indexing of a TracedRArray.
65-
This is typically caused by calling an iterating implementation of a method.
66-
Such implementations *do not* execute on device, but very slowly on the CPU,
67-
and require expensive copies and synchronization each time and therefore should be avoided."""
68-
)
62+
GPUArraysCore.assertscalar("getindex(::TracedRArray, ::Vararg{Int, N})")
6963

7064
start_indices = [promote_to(TracedRNumber{Int}, i - 1).mlir_data for i in index]
7165
slice_sizes = [Int64(1) for _ in index]

src/XLA.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,7 @@ function __init__()
130130
end
131131
end
132132
end
133+
133134
return nothing
134135
end
135136

test/basic.jl

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ using Test
33
using Enzyme
44
using Statistics
55

6-
# Reactant.set_default_backend("gpu")
7-
86
fastmax(x::AbstractArray{T}) where {T} = reduce(max, x; dims=1, init=float(T)(-Inf))
97

108
using InteractiveUtils
@@ -16,7 +14,7 @@ using InteractiveUtils
1614

1715
a = Reactant.ConcreteRArray(x)
1816

19-
c_res = sum(a)
17+
c_res = @allowscalar sum(a)
2018
@test c_res r_res
2119

2220
@test @jit(sum(a)) r_res
@@ -29,7 +27,7 @@ end
2927

3028
a = Reactant.ConcreteRArray(x)
3129

32-
c_res = fastmax(a)
30+
c_res = @allowscalar fastmax(a)
3331
@test c_res r_res
3432

3533
@test @jit(fastmax(a)) r_res
@@ -45,7 +43,7 @@ sinexpbc(x) = sinexp.(x)
4543

4644
a = Reactant.ConcreteRArray(x)
4745

48-
c_res = sinexpbc(a)
46+
c_res = @allowscalar sinexpbc(a)
4947
@test c_res r_res
5048

5149
@test @jit(sinexpbc(a)) r_res
@@ -427,10 +425,10 @@ end
427425
@test y y_ra
428426

429427
x_ra_array = Array(x_ra)
430-
@test all(iszero, x_ra_array[1, :])
431-
@test all(iszero, x_ra_array[2, :])
432-
@test all(isone, x_ra_array[3, :])
433-
@test all(isone, x_ra_array[4, :])
428+
@test @allowscalar all(iszero, x_ra_array[1, :])
429+
@test @allowscalar all(iszero, x_ra_array[2, :])
430+
@test @allowscalar all(isone, x_ra_array[3, :])
431+
@test @allowscalar all(isone, x_ra_array[4, :])
434432
end
435433

436434
tuple_byref(x) = (; a=(; b=x))
@@ -504,14 +502,14 @@ end
504502

505503
f2 = @compile f1(x_ra)
506504
res2 = f2(Reactant.to_rarray((5, [3.14]); track_numbers=(Number,)))
507-
@test only(res2) 5 * 3.14
505+
@test @allowscalar(only(res2)) 5 * 3.14
508506
@test res2 isa ConcreteRArray
509507

510508
x_ra = Reactant.to_rarray(x)
511509

512510
f3 = @compile f1(x_ra)
513511
res3 = f3(Reactant.to_rarray((5, [3.14])))
514-
@test only(res3) only(f1(x))
512+
@test @allowscalar(only(res3)) only(f1(x))
515513
@test res3 isa ConcreteRArray
516514
end
517515
end
@@ -544,18 +542,22 @@ end
544542
x_ra = Reactant.to_rarray(x)
545543

546544
y = @jit(clamp!(x_ra, 0.0, 0.25))
547-
@test maximum(y) 0.25
548-
@test minimum(y) 0.0
549-
@test maximum(x_ra) == maximum(y)
550-
@test minimum(x_ra) == minimum(y)
545+
@allowscalar begin
546+
@test maximum(y) 0.25
547+
@test minimum(y) 0.0
548+
@test maximum(x_ra) == maximum(y)
549+
@test minimum(x_ra) == minimum(y)
550+
end
551551

552552
x = randn(2, 3)
553553
x_ra = Reactant.to_rarray(x)
554554

555555
y = @jit(clamp.(x_ra, 0.0, 0.25))
556-
@test maximum(y) 0.25
557-
@test minimum(y) 0.0
558-
@test x_ra x
556+
@allowscalar begin
557+
@test maximum(y) 0.25
558+
@test minimum(y) 0.0
559+
@test x_ra x
560+
end
559561
end
560562

561563
@testset "dynamic indexing" begin
@@ -565,6 +567,8 @@ end
565567
idx = [1, 2, 3]
566568
idx_ra = Reactant.to_rarray(idx)
567569

568-
y = @jit(getindex(x_ra, idx_ra, :))
570+
fn(x, idx) = @allowscalar x[idx, :]
571+
572+
y = @jit(fn(x_ra, idx_ra))
569573
@test y x[idx, :]
570574
end

test/closure.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ using Reactant
44
muler(x) = y -> x * y
55

66
@testset "closure" begin
7-
x = Reactant.ConcreteRArray(ones(2, 2))
8-
y = Reactant.ConcreteRArray(ones(2, 2))
7+
x = ones(2, 2)
8+
y = ones(2, 2)
9+
x_ra = Reactant.ConcreteRArray(x)
10+
y_ra = Reactant.ConcreteRArray(y)
911

10-
f = muler(x)
11-
@test @jit(f(y)) x * y
12+
f = muler(x_ra)
13+
@test @jit(f(y_ra)) x * y
1214
end

test/compile.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,22 +16,25 @@ Base.sum(x::NamedTuple{(:a,),Tuple{T}}) where {T<:Reactant.TracedRArray} = (; a=
1616
end
1717

1818
@testset "world-age" begin
19-
a = Reactant.ConcreteRArray(ones(2, 10))
20-
b = Reactant.ConcreteRArray(ones(10, 2))
19+
a = ones(2, 10)
20+
b = ones(10, 2)
21+
a_ra = Reactant.ConcreteRArray(a)
22+
b_ra = Reactant.ConcreteRArray(b)
2123

22-
fworld(x, y) = @jit(*(x, y))
24+
fworld(x, y) = @jit(x * y)
2325

24-
@test fworld(a, b) ones(2, 2) * 10
26+
@test fworld(a_ra, b_ra) ones(2, 2) * 10
2527
end
2628

2729
@testset "type casting & optimized out returns" begin
28-
a = Reactant.ConcreteRArray(rand(2, 10))
30+
a = ones(2, 10)
31+
a_ra = Reactant.ConcreteRArray(a)
2932

3033
ftype1(x) = Float64.(x)
3134
ftype2(x) = Float32.(x)
3235

33-
y1 = @jit ftype1(a)
34-
y2 = @jit ftype2(a)
36+
y1 = @jit ftype1(a_ra)
37+
y2 = @jit ftype2(a_ra)
3538

3639
@test y1 isa Reactant.ConcreteRArray{Float64,2}
3740
@test y2 isa Reactant.ConcreteRArray{Float32,2}

test/complex.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ end
9999
end
100100

101101
@testset "complex reduction" begin
102-
x_ra = Reactant.ConcreteRArray(randn(ComplexF32, 10, 10))
103-
@test @jit(sum(abs2, x_ra)) sum(abs2, x_ra)
102+
x = randn(ComplexF32, 10, 10)
103+
x_ra = Reactant.ConcreteRArray(x)
104+
@test @jit(sum(abs2, x_ra)) sum(abs2, x)
104105
end

0 commit comments

Comments
 (0)