Skip to content

Commit 9d6cc5d

Browse files
authored
fix: add GPU backend as default if available (#272)
* fix: add GPU backend as default if available * test: fix gpu tests
1 parent 9e8eec0 commit 9d6cc5d

File tree

3 files changed

+25
-16
lines changed

3 files changed

+25
-16
lines changed

.buildkite/pipeline.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ steps:
77
group:
88
- core
99
- neural_networks
10+
- integration
1011
plugins:
1112
- JuliaCI/julia#v1:
1213
version: "{{matrix.version}}"

src/XLA.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ function __init__()
124124
try
125125
gpu = GPUClient()
126126
backends["gpu"] = gpu
127+
default_backend[] = gpu
127128
catch e
128129
println(stdout, e)
129130
end

test/linear_algebra.jl

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -46,27 +46,34 @@ function mul_with_view3(A, x)
4646
end
4747

4848
@testset begin
49-
A = Reactant.to_rarray(rand(4, 4))
50-
x = Reactant.to_rarray(rand(4, 2))
51-
b = Reactant.to_rarray(rand(4))
49+
A = rand(4, 4)
50+
x = rand(4, 2)
51+
b = rand(4)
5252

53-
@test @jit(muladd2(A, x, b)) muladd2(A, x, b)
54-
@test @jit(muladd_5arg(A, x, b)) muladd2(A, x, b)
55-
@test @jit(muladd_5arg2(A, x, b)) 2 .* A * x .+ b
53+
A_ra = Reactant.to_rarray(A)
54+
x_ra = Reactant.to_rarray(x)
55+
b_ra = Reactant.to_rarray(b)
5656

57-
@test @jit(mul_with_view1(A, x)) mul_with_view1(A, x)
57+
@test @jit(muladd2(A_ra, x_ra, b_ra)) muladd2(A, x, b)
58+
@test @jit(muladd_5arg(A_ra, x_ra, b_ra)) muladd2(A, x, b)
59+
@test @jit(muladd_5arg2(A_ra, x_ra, b_ra)) 2 .* A * x .+ b
5860

59-
x2 = Reactant.to_rarray(rand(4))
60-
@test @jit(mul_with_view2(A, x2)) mul_with_view2(A, x2)
61-
@test @jit(mul_with_view3(A, x2)) mul_with_view3(A, x2)
61+
@test @jit(mul_with_view1(A_ra, x_ra)) mul_with_view1(A, x)
62+
63+
x2 = rand(4)
64+
x2_ra = Reactant.to_rarray(x2)
65+
66+
@test @jit(mul_with_view2(A_ra, x2_ra)) mul_with_view2(A, x2)
67+
@test @jit(mul_with_view3(A_ra, x2_ra)) mul_with_view3(A, x2)
6268

6369
# Mixed Precision
64-
x = Reactant.to_rarray(rand(Float32, 4, 2))
70+
x = rand(Float32, 4, 2)
71+
x_ra = Reactant.to_rarray(x)
6572

66-
@test @jit(muladd2(A, x, b)) muladd2(A, x, b)
67-
@test @jit(muladd_5arg(A, x, b)) muladd2(A, x, b)
73+
@test @jit(muladd2(A_ra, x_ra, b_ra)) muladd2(A, x, b)
74+
@test @jit(muladd_5arg(A_ra, x_ra, b_ra)) muladd2(A, x, b)
6875

69-
C = similar(A, Float32, size(A, 1), size(x, 2))
70-
@jit(mul!(C, A, x))
71-
@test C A * x
76+
C_ra = similar(A_ra, Float32, size(A, 1), size(x, 2))
77+
@jit(mul!(C_ra, A_ra, x_ra))
78+
@test C_ra A * x
7279
end

0 commit comments

Comments
 (0)