1
1
module OneHotArrays
2
2
3
3
using Adapt
4
+ using ChainRulesCore
5
+ using CUDA
4
6
using LinearAlgebra
5
7
using MLUtils
6
8
using NNlib
@@ -21,7 +23,6 @@ OneHotArray{T, L, N, I}(indices) where {T, L, N, I} = OneHotArray{T, L, N, N+1,
21
23
OneHotArray (indices:: T , L:: Integer ) where {T<: Integer } = OneHotArray {T, L, 0, 1, T} (indices)
22
24
OneHotArray (indices:: I , L:: Integer ) where {T, N, I<: AbstractArray{T, N} } = OneHotArray {T, L, N, N+1, I} (indices)
23
25
24
-
25
26
_indices (x:: OneHotArray ) = x. indices
26
27
_indices (x:: Base.ReshapedArray{<: Any, <: Any, <: OneHotArray} ) =
27
28
reshape (parent (x). indices, x. dims[2 : end ])
@@ -68,7 +69,14 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
68
69
x[i,j] ? s : _isonehot (x) ? Base. replace_with_centered_mark (s) : s
69
70
end
70
71
72
+ # copy CuArray versions back before trying to print them:
73
+ Base. print_array (io:: IO , X:: OneHotLike{T, L, N, var"N+1", <:CuArray} ) where {T, L, N, var"N+1" } =
74
+ Base. print_array (io, adapt (Array, X))
75
+ Base. print_array (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}} ) where {T, L, N, var"N+1" } =
76
+ Base. print_array (io, adapt (Array, X))
77
+
71
78
_onehot_bool_type (:: OneHotLike{<:Any, <:Any, <:Any, N, <:Union{Integer, AbstractArray}} ) where N = Array{Bool, N}
79
+ _onehot_bool_type (:: OneHotLike{<:Any, <:Any, <:Any, N, <:CuArray} ) where N = CuArray{Bool, N}
72
80
73
81
function Base. cat (x:: OneHotLike{<:Any, L} , xs:: OneHotLike{<:Any, L} ...; dims:: Int ) where L
74
82
if isone (dims) || any (x -> ! _isonehot (x), (x, xs... ))
@@ -91,6 +99,8 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri
91
99
92
100
Adapt. adapt_structure (T, x:: OneHotArray{<:Any, L} ) where L = OneHotArray (adapt (T, _indices (x)), L)
93
101
102
+ Base. BroadcastStyle (:: Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}} ) where N = CUDA. CuArrayStyle {N} ()
103
+
94
104
Base. map (f, x:: OneHotLike ) = Base. broadcast (f, x)
95
105
96
106
Base. argmax (x:: OneHotLike ; dims = Colon ()) =
@@ -228,6 +238,12 @@ function _fast_argmax(x::OneHotLike)
228
238
end
229
239
end
230
240
241
+ ChainRulesCore. @non_differentiable onehot (:: Any... )
242
+ ChainRulesCore. @non_differentiable onehotbatch (:: Any... )
243
+ ChainRulesCore. @non_differentiable onecold (:: Any... )
244
+
245
+ ChainRulesCore. @non_differentiable (:: Type{<:OneHotArray} )(indices:: Any , L:: Integer )
246
+
231
247
function Base.:(* )(A:: AbstractMatrix , B:: OneHotLike{<:Any, L} ) where L
232
248
_isonehot (B) || return invoke (* , Tuple{AbstractMatrix, AbstractMatrix}, A, B)
233
249
size (A, 2 ) == L || throw (DimensionMismatch (" Matrix column must correspond with OneHot size: $(size (A, 2 )) != $L " ))
0 commit comments