Skip to content

Commit 8b85adb

Browse files
authored
Merge pull request #3 from TLipede/add-cuda-zygote
2 parents 04a70f8 + c69d4fc commit 8b85adb

File tree

2 files changed

+21
-1
lines changed

2 files changed

+21
-1
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,16 @@ version = "0.1.0"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
7+
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
8+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
79
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
810
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
911
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1012

1113
[compat]
1214
Adapt = "3.0"
15+
ChainRulesCore = "1.13"
16+
CUDA = "3.8"
1317
MLUtils = "0.2"
1418
NNlib = "0.8"
1519

src/OneHotArrays.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module OneHotArrays
22

33
using Adapt
4+
using ChainRulesCore
5+
using CUDA
46
using LinearAlgebra
57
using MLUtils
68
using NNlib
@@ -21,7 +23,6 @@ OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1,
2123
OneHotArray(indices::T, L::Integer) where {T<:Integer} = OneHotArray{T, L, 0, 1, T}(indices)
2224
OneHotArray(indices::I, L::Integer) where {T, N, I<:AbstractArray{T, N}} = OneHotArray{T, L, N, N+1, I}(indices)
2325

24-
2526
_indices(x::OneHotArray) = x.indices
2627
_indices(x::Base.ReshapedArray{<: Any, <: Any, <: OneHotArray}) =
2728
reshape(parent(x).indices, x.dims[2:end])
@@ -68,7 +69,14 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
6869
x[i,j] ? s : _isonehot(x) ? Base.replace_with_centered_mark(s) : s
6970
end
7071

72+
# copy CuArray versions back before trying to print them:
73+
Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:CuArray}) where {T, L, N, var"N+1"} =
74+
Base.print_array(io, adapt(Array, X))
75+
Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}}) where {T, L, N, var"N+1"} =
76+
Base.print_array(io, adapt(Array, X))
77+
7178
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}}) where N = Array{Bool, N}
79+
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray}) where N = CuArray{Bool, N}
7280

7381
function Base.cat(x::OneHotLike{<:Any, L}, xs::OneHotLike{<:Any, L}...; dims::Int) where L
7482
if isone(dims) || any(x -> !_isonehot(x), (x, xs...))
@@ -91,6 +99,8 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri
9199

92100
Adapt.adapt_structure(T, x::OneHotArray{<:Any, L}) where L = OneHotArray(adapt(T, _indices(x)), L)
93101

102+
Base.BroadcastStyle(::Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}}) where N = CUDA.CuArrayStyle{N}()
103+
94104
Base.map(f, x::OneHotLike) = Base.broadcast(f, x)
95105

96106
Base.argmax(x::OneHotLike; dims = Colon()) =
@@ -228,6 +238,12 @@ function _fast_argmax(x::OneHotLike)
228238
end
229239
end
230240

241+
ChainRulesCore.@non_differentiable onehot(::Any...)
242+
ChainRulesCore.@non_differentiable onehotbatch(::Any...)
243+
ChainRulesCore.@non_differentiable onecold(::Any...)
244+
245+
ChainRulesCore.@non_differentiable (::Type{<:OneHotArray})(indices::Any, L::Integer)
246+
231247
function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, L}) where L
232248
_isonehot(B) || return invoke(*, Tuple{AbstractMatrix, AbstractMatrix}, A, B)
233249
size(A, 2) == L || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $L"))

0 commit comments

Comments
 (0)