Skip to content

Commit 036f700

Browse files
committed
simple setindex
1 parent 4e2d74a commit 036f700

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

src/array.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,15 @@ function Base.copyto!(dst::AbstractArray{T,N}, src::OneHotArray{<:Any,<:Any,N,<:
8686
end
8787
function Base.copyto!(dst::Array{T,N}, src::OneHotArray{<:Any,<:Any,N,<:AnyGPUArray}) where {T,N}
8888
copyto!(dst, adapt(Array, src))
89+
90+
@inline function Base.setindex!(x::OneHotArray{<:Any, N}, v, i::Integer, I::Vararg{Integer, N}) where N
91+
@boundscheck checkbounds(x, i, I...)
92+
if Bool(v)
93+
@inbounds x.indices[I...] = i
94+
elseif x.indices[I...] == i
95+
@inbounds x.indices[I...] = 0
96+
end
97+
x
8998
end
9099

91100
function Base.showarg(io::IO, x::OneHotArray, toplevel)
@@ -104,9 +113,9 @@ end
104113
# copy CuArray versions back before trying to print them:
105114
for fun in (:show, :print_array) # print_array is used by 3-arg show
106115
@eval begin
107-
Base.$fun(io::IO, X::OneHotLike{T, N, var"N+1", <:AbstractGPUArray}) where {T, N, var"N+1"} =
116+
Base.$fun(io::IO, X::OneHotLike{T, N, var"N+1", <:AbstractGPUArray}) where {T, N, var"N+1"} =
108117
Base.$fun(io, adapt(Array, X))
109-
Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}}) where {T, N} =
118+
Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}}) where {T, N} =
110119
Base.$fun(io, adapt(Array, X))
111120
end
112121
end

0 commit comments

Comments
 (0)