@@ -86,6 +86,15 @@ function Base.copyto!(dst::AbstractArray{T,N}, src::OneHotArray{<:Any,<:Any,N,<:
86
86
end
87
87
function Base. copyto! (dst:: Array{T,N} , src:: OneHotArray{<:Any,<:Any,N,<:AnyGPUArray} ) where {T,N}
88
88
copyto! (dst, adapt (Array, src))
89
+
90
+ @inline function Base. setindex! (x:: OneHotArray{<:Any, N} , v, i:: Integer , I:: Vararg{Integer, N} ) where N
91
+ @boundscheck checkbounds (x, i, I... )
92
+ if Bool (v)
93
+ @inbounds x. indices[I... ] = i
94
+ elseif x. indices[I... ] == i
95
+ @inbounds x. indices[I... ] = 0
96
+ end
97
+ x
89
98
end
90
99
91
100
function Base. showarg (io:: IO , x:: OneHotArray , toplevel)
104
113
# copy CuArray versions back before trying to print them:
105
114
for fun in (:show , :print_array ) # print_array is used by 3-arg show
106
115
@eval begin
107
- Base.$ fun (io:: IO , X:: OneHotLike{T, N, var"N+1", <:AbstractGPUArray} ) where {T, N, var"N+1" } =
116
+ Base.$ fun (io:: IO , X:: OneHotLike{T, N, var"N+1", <:AbstractGPUArray} ) where {T, N, var"N+1" } =
108
117
Base.$ fun (io, adapt (Array, X))
109
- Base.$ fun (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}} ) where {T, N} =
118
+ Base.$ fun (io:: IO , X:: LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}} ) where {T, N} =
110
119
Base.$ fun (io, adapt (Array, X))
111
120
end
112
121
end
0 commit comments