
Commit 621c3aa

Merge pull request #9 from mcabbott/gpuarrays
RFC: use GPUArrays, also for testing
2 parents 376f6ed + cea26a4 commit 621c3aa

6 files changed: +104 -12 lines

Project.toml

Lines changed: 8 additions & 2 deletions
@@ -4,8 +4,8 @@ version = "0.1.0"
 
 [deps]
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
-CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
@@ -14,11 +14,17 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 Adapt = "3.0"
 CUDA = "3.8"
 ChainRulesCore = "1.13"
+GPUArrays = "8.2.1"
 MLUtils = "0.2"
 NNlib = "0.8"
+Zygote = "0.6.35"
+julia = "1.6"
 
 [extras]
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["Test"]
+test = ["Test", "CUDA", "Random", "Zygote"]
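Note on the [targets] change: CUDA leaves the hard dependencies and, together with Random and Zygote, becomes a test-only dependency via [extras]. This is standard Pkg behaviour rather than anything specific to this package; a rough sketch of what it means in practice:

    # Packages listed under [extras] and referenced in [targets] are resolved only
    # for testing, so a plain `using OneHotArrays` no longer pulls in CUDA.
    using Pkg
    Pkg.test("OneHotArrays")   # resolves Test, CUDA, Random, Zygote into a temporary test environment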

src/OneHotArrays.jl

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ module OneHotArrays
 
 using Adapt
 using ChainRulesCore
-using CUDA
+using GPUArrays
 using LinearAlgebra
 using MLUtils
 using NNlib
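The only change to the module itself is swapping the CUDA dependency for GPUArrays: AbstractGPUArray is the common supertype of the concrete GPU array types, so the dispatches in src/array.jl below cover CuArray (and the JLArray test type) without loading CUDA. A quick check of the subtyping, assuming CUDA.jl is installed purely for illustration:

    using GPUArrays, CUDA
    CuArray{Float32, 2} <: AbstractGPUArray   # true: CuArray is one AbstractGPUArray backend among several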

src/array.jl

Lines changed: 17 additions & 7 deletions
@@ -61,13 +61,17 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
 end
 
 # copy CuArray versions back before trying to print them:
-Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
-  Base.print_array(io, adapt(Array, X))
-Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
-  Base.print_array(io, adapt(Array, X))
+for fun in (:show, :print_array)  # print_array is used by 3-arg show
+  @eval begin
+    Base.$fun(io::IO, X::OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}) where {T, L, N, var"N+1"} =
+      Base.$fun(io, adapt(Array, X))
+    Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, <:Any, <:AbstractGPUArray}}) where {T, L, N} =
+      Base.$fun(io, adapt(Array, X))
+  end
+end
 
-_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
-_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
+_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}}) where {var"N+1"} = Array{Bool, var"N+1"}
+_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:AbstractGPUArray}) where {var"N+1"} = AbstractGPUArray{Bool, var"N+1"}
 
 function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
   if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
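These show/print_array methods exist because printing a GPU-backed OneHotArray would otherwise read it element by element (scalar indexing), which is slow or disallowed on GPU arrays. adapt(Array, X) uses the Adapt.adapt_structure rule in the next hunk to rebuild the wrapper around a CPU copy of the indices before printing. A rough REPL sketch, assuming a functional CUDA setup:

    using Adapt, CUDA, OneHotArrays
    y = cu(onehotbatch([1, 2, 3], 1:3))   # OneHotMatrix whose indices live in a CuArray
    adapt(Array, y)                       # same OneHotMatrix, indices copied back to a plain Array for display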
@@ -90,7 +94,13 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri
 
 Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
 
-Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
+function Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, var"N+1", T}}) where {var"N+1", T <: AbstractGPUArray}
+  # We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
+  S = Base.BroadcastStyle(T)
+  # S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter,
+  # which isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
+  (typeof(S).name.wrapper){var"N+1"}()
+end
 
 Base.map(f, x::OneHotLike) = Base.broadcast(f, x)
 
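The BroadcastStyle comment is worth spelling out: Base.BroadcastStyle(T) for the underlying index array returns a style parameterised by that array's dimension N, while the OneHotArray itself has N+1 dimensions. typeof(S).name.wrapper recovers the un-parameterised style constructor so it can be re-instantiated one dimension higher. The same trick works on the CPU's DefaultArrayStyle, which also keeps the dimension as its first (and only) type parameter; an illustration, not part of the diff:

    julia> S = Base.BroadcastStyle(Array{Float32, 2})
    Base.Broadcast.DefaultArrayStyle{2}()

    julia> (typeof(S).name.wrapper){3}()   # rebuild the style one dimension higher
    Base.Broadcast.DefaultArrayStyle{3}()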

src/onehot.jl

Lines changed: 2 additions & 2 deletions
@@ -57,8 +57,8 @@ nonzero elements.
 If one of the inputs in `xs` is not found in `labels`, that column is `onehot(default, labels)`
 if `default` is given, else an error.
 
-If `xs` has more dimensions, `M = ndims(xs) > 1`, then the result is an
-`AbstractArray{Bool, M+1}` which is one-hot along the first dimension,
+If `xs` has more dimensions, `N = ndims(xs) > 1`, then the result is an
+`AbstractArray{Bool, N+1}` which is one-hot along the first dimension,
 i.e. `result[:, k...] == onehot(xs[k...], labels)`.
 
 Note that `xs` can be any iterable, such as a string. And that using a tuple
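The corrected docstring wording matches the behaviour: for an N-dimensional input the result gains one leading one-hot dimension. A small illustration of the `result[:, k...] == onehot(xs[k...], labels)` property (not part of the diff):

    julia> xs = [1 3; 2 1];                  # N = ndims(xs) = 2

    julia> y = onehotbatch(xs, 1:3);         # 3×2×2 Bool-valued one-hot array

    julia> y[:, 2, 1] == onehot(xs[2, 1], 1:3)
    true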

test/gpu.jl

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
+
+# Tests from Flux, probably not the optimal testset organisation!
+
+@testset "CUDA" begin
+  x = randn(5, 5)
+  cx = cu(x)
+  @test cx isa CuArray
+
+  @test_broken onecold(cu([1.0, 2.0, 3.0])) == 3  # scalar indexing error?
+
+  x = onehotbatch([1, 2, 3], 1:3)
+  cx = cu(x)
+  @test cx isa OneHotMatrix && cx.indices isa CuArray
+  @test (cx .+ 1) isa CuArray
+
+  xs = rand(5, 5)
+  ys = onehotbatch(1:5, 1:5)
+  @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
+end
+
+@testset "onehot gpu" begin
+  y = onehotbatch(ones(3), 1:2) |> cu;
+  @test (repr("text/plain", y); true)
+
+  gA = rand(3, 2) |> cu;
+  @test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray  # fails with JLArray, bug in Zygote?
+end
+
+@testset "onecold gpu" begin
+  y = onehotbatch(ones(3), 1:10) |> cu;
+  l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
+  @test onecold(y) isa CuArray
+  @test y[3,:] isa CuArray
+  @test onecold(y, l) == ['a', 'a', 'a']
+end
+
+@testset "onehot forward map to broadcast" begin
+  oa = OneHotArray(rand(1:10, 5, 5), 10) |> cu
+  @test all(map(identity, oa) .== oa)
+  @test all(map(x -> 2 * x, oa) .== 2 .* oa)
+end
+
+@testset "show gpu" begin
+  x = onehotbatch([1, 2, 3], 1:3)
+  cx = cu(x)
+  # 3-arg show
+  @test contains(repr("text/plain", cx), "1 ⋅ ⋅")
+  @test contains(repr("text/plain", cx), string(typeof(cx.indices)))
+  # 2-arg show, https://github.com/FluxML/Flux.jl/issues/1905
+  @test repr(cx) == "Bool[1 0 0; 0 1 0; 0 0 1]"
+end
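Two of the tests above are marked @test_broken (the scalar-indexing onecold case and the Zygote gradient). As a reminder of the semantics, which are standard Test behaviour rather than anything in this PR: a broken test is recorded as passing while the expression keeps failing, and turns into a test error once it unexpectedly starts passing, so stale markers get noticed.

    using Test
    @test_broken 1 + 1 == 3   # recorded as Broken; the surrounding testset still passes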

test/runtests.jl

Lines changed: 25 additions & 0 deletions
@@ -12,3 +12,28 @@ end
 @testset "Linear Algebra" begin
   include("linalg.jl")
 end
+
+using Zygote
+import CUDA
+if CUDA.functional()
+  using CUDA  # exports CuArray, etc
+  @info "starting CUDA tests"
+else
+  @info "CUDA not functional, testing via GPUArrays"
+  using GPUArrays
+  GPUArrays.allowscalar(false)
+
+  # GPUArrays provides a fake GPU array, for testing
+  jl_file = normpath(joinpath(pathof(GPUArrays), "..", "..", "test", "jlarray.jl"))
+  using Random  # loaded within jl_file
+  include(jl_file)
+  using .JLArrays
+  cu = jl
+  CuArray{T,N} = JLArray{T,N}
+end
+
+@test cu(rand(3)) .+ 1 isa CuArray
+
+@testset "GPUArrays" begin
+  include("gpu.jl")
+end
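The else branch is the point of the PR title: when CUDA is not functional, the same gpu.jl tests run against JLArray, GPUArrays' reference array type that behaves like a GPU array while living on the CPU, with cu and CuArray rebound so gpu.jl needs no changes. A rough sketch of what the fallback exercises, assuming the JLArrays setup above has already run:

    y = onehotbatch([1, 2, 3], 1:3) |> cu   # here cu === jl, so the indices are a JLArray
    y .+ 1                                  # broadcast picks JLArray's style via the BroadcastStyle method above
    onecold(y)                              # stays a device-style array; allowscalar(false) catches scalar indexing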
