@@ -61,10 +61,14 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
61
61
end
62
62
63
63
# copy CuArray versions back before trying to print them:
64
- Base. print_array (io:: IO , X:: OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray} ) where {T, L, N, var"N+1" } =
65
- Base. print_array (io, adapt (Array, X))
66
- Base. print_array (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}} ) where {T, L, N, var"N+1" } =
67
- Base. print_array (io, adapt (Array, X))
64
+ for fun in (:show , :print_array ) # print_array is used by 3-arg show
65
+ @eval begin
66
+ Base.$ fun (io:: IO , X:: OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray} ) where {T, L, N, var"N+1" } =
67
+ Base.$ fun (io, adapt (Array, X))
68
+ Base.$ fun (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, <:Any, <:AbstractGPUArray}} ) where {T, L, N} =
69
+ Base.$ fun (io, adapt (Array, X))
70
+ end
71
+ end
68
72
69
73
_onehot_bool_type (:: OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}} ) where {var"N+1" } = Array{Bool, var"N+1" }
70
74
_onehot_bool_type (:: OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:AbstractGPUArray} ) where {var"N+1" } = AbstractGPUArray{Bool, var"N+1" }
0 commit comments