Skip to content

Commit a531d55

Browse files
allow NNlib v0.9 (#38)
* allow NNlib v0.9 * fix test
1 parent 469b192 commit a531d55

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1212

1313
[compat]
1414
Adapt = "3.0"
15-
CUDA = "3.8"
15+
CUDA = "4"
1616
ChainRulesCore = "1.13"
1717
Compat = "4.2"
1818
GPUArraysCore = "0.1.0"
19-
NNlib = "0.8"
19+
NNlib = "0.8, 0.9"
2020
Zygote = "0.6.35"
2121
julia = "1.6"
2222

test/gpu.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,11 @@ end
2323
@test (repr("text/plain", y); true)
2424

2525
gA = rand(3, 2) |> cu;
26-
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
26+
if VERSION >= v"1.9" && CUDA.functional()
27+
@test gradient(A -> sum(A * y), gA)[1] isa CuArray
28+
else
29+
@test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
30+
end
2731
end
2832

2933
@testset "onehotbatch(::CuArray, ::UnitRange)" begin

0 commit comments

Comments
 (0)