@@ -61,13 +61,17 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
61
61
end
62
62
63
63
# copy CuArray versions back before trying to print them:
64
- Base. print_array (io:: IO , X:: OneHotLike{T, L, N, var"N+1", <:CuArray} ) where {T, L, N, var"N+1" } =
65
- Base. print_array (io, adapt (Array, X))
66
- Base. print_array (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:CuArray}} ) where {T, L, N, var"N+1" } =
67
- Base. print_array (io, adapt (Array, X))
64
+ for fun in (:show , :print_array ) # print_array is used by 3-arg show
65
+ @eval begin
66
+ Base.$ fun (io:: IO , X:: OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray} ) where {T, L, N, var"N+1" } =
67
+ Base.$ fun (io, adapt (Array, X))
68
+ Base.$ fun (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, <:Any, <:AbstractGPUArray}} ) where {T, L, N} =
69
+ Base.$ fun (io, adapt (Array, X))
70
+ end
71
+ end
68
72
69
- _onehot_bool_type (:: OneHotLike{<:Any, <:Any, <:Any, N , <:Union{Integer, AbstractArray}} ) where N = Array{Bool, N }
70
- _onehot_bool_type (:: OneHotLike{<:Any, <:Any, <:Any, N , <:CuArray } ) where N = CuArray {Bool, N }
73
+ _onehot_bool_type (:: OneHotLike{<:Any, <:Any, <:Any, var"N+1" , <:Union{Integer, AbstractArray}} ) where { var"N+1" } = Array{Bool, var"N+1" }
74
+ _onehot_bool_type (:: OneHotLike{<:Any, <:Any, <:Any, var"N+1" , <:AbstractGPUArray } ) where { var"N+1" } = AbstractGPUArray {Bool, var"N+1" }
71
75
72
76
function Base. cat (x:: OneHotLike{<:Any, L} , xs:: OneHotLike{<:Any, L} ...; dims:: Int ) where L
73
77
if isone (dims) || any (x -> ! _isonehot (x), (x, xs... ))
@@ -90,7 +94,13 @@ MLUtils.batch(xs::AbstractArray{<:OneHotVector{<:Any, L}}) where L = OneHotMatri
90
94
91
95
Adapt. adapt_structure (T, x:: OneHotArray{<:Any, L} ) where L = OneHotArray (adapt (T, _indices (x)), L)
92
96
93
- Base. BroadcastStyle (:: Type{<:OneHotArray{<: Any, <: Any, <: Any, N, <: CuArray}} ) where N = CUDA. CuArrayStyle {N} ()
97
+ function Base. BroadcastStyle (:: Type{<:OneHotArray{<: Any, <: Any, <: Any, var"N+1", T}} ) where {var"N+1" , T <: AbstractGPUArray }
98
+ # We want CuArrayStyle{N+1}(). There's an AbstractGPUArrayStyle but it doesn't do what we need.
99
+ S = Base. BroadcastStyle (T)
100
+ # S has dim N not N+1. The following hack to fix it relies on the arraystyle having N as its first type parameter, which
101
+ # isn't guaranteed, but there are not so many GPU broadcasting styles in the wild. (Far fewer than there are array wrappers.)
102
+ (typeof (S). name. wrapper){var"N+1" }()
103
+ end
94
104
95
105
Base. map (f, x:: OneHotLike ) = Base. broadcast (f, x)
96
106
0 commit comments