Skip to content

Commit cea26a4

Browse files
committed
fix 2-arg show
1 parent 4fb6fe9 commit cea26a4

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

src/array.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,14 @@ function Base.replace_in_print_matrix(x::OneHotLike, i::Integer, j::Integer, s::
6161
end
6262

6363
# copy CuArray versions back before trying to print them:
64-
Base.print_array(io::IO, X::OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}) where {T, L, N, var"N+1"} =
65-
Base.print_array(io, adapt(Array, X))
66-
Base.print_array(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}}) where {T, L, N, var"N+1"} =
67-
Base.print_array(io, adapt(Array, X))
64+
for fun in (:show, :print_array) # print_array is used by 3-arg show
65+
@eval begin
66+
Base.$fun(io::IO, X::OneHotLike{T, L, N, var"N+1", <:AbstractGPUArray}) where {T, L, N, var"N+1"} =
67+
Base.$fun(io, adapt(Array, X))
68+
Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, L, N, <:Any, <:AbstractGPUArray}}) where {T, L, N} =
69+
Base.$fun(io, adapt(Array, X))
70+
end
71+
end
6872

6973
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:Union{Integer, AbstractArray}}) where {var"N+1"} = Array{Bool, var"N+1"}
7074
_onehot_bool_type(::OneHotLike{<:Any, <:Any, <:Any, var"N+1", <:AbstractGPUArray}) where {var"N+1"} = AbstractGPUArray{Bool, var"N+1"}

test/gpu.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,13 @@ end
3939
@test all(map(identity, oa) .== oa)
4040
@test all(map(x -> 2 * x, oa) .== 2 .* oa)
4141
end
42+
43+
@testset "show gpu" begin
44+
x = onehotbatch([1, 2, 3], 1:3)
45+
cx = cu(x)
46+
# 3-arg show
47+
@test contains(repr("text/plain", cx), "1 ⋅ ⋅")
48+
@test contains(repr("text/plain", cx), string(typeof(cx.indices)))
49+
# 2-arg show, https://github.com/FluxML/Flux.jl/issues/1905
50+
@test repr(cx) == "Bool[1 0 0; 0 1 0; 0 0 1]"
51+
end

0 commit comments

Comments
 (0)