diff --git a/Project.toml b/Project.toml index 90e50c4..6bf6f1e 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "OneHotArrays" uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.8" +version = "0.2.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -17,7 +17,7 @@ ChainRulesCore = "1.13" Compat = "4.2" GPUArraysCore = "0.1, 0.2" NNlib = "0.8, 0.9" -Zygote = "0.6.35" +Zygote = "0.6.77, 0.7.7" julia = "1.6" [extras] diff --git a/src/array.jl b/src/array.jl index 45c91e9..2365c04 100644 --- a/src/array.jl +++ b/src/array.jl @@ -88,6 +88,20 @@ function Base.copyto!(dst::Array{T,N}, src::OneHotArray{<:Any,<:Any,N,<:AnyGPUAr copyto!(dst, adapt(Array, src)) end +@inline function Base.setindex!(x::OneHotArray{<:Any, N}, v, i::Integer, I::Vararg{Integer, N}) where N + @boundscheck checkbounds(x, i, I...) + if Bool(v) + @inbounds x.indices[I...] = i + elseif x.indices[I...] == i + # writing 0, at position of the 1 => move the 1 down if possible + i == x.nlabels && throw(ArgumentError("`setindex!` here would leave the `OneHotArray` without a hot one (in this column)")) + @inbounds x.indices[I...] = i+1 + else + # writing 0, where it's already 0 => do nothing + end + x +end + function Base.showarg(io::IO, x::OneHotArray, toplevel) print(io, ndims(x) == 1 ? "OneHotVector(" : ndims(x) == 2 ? "OneHotMatrix(" : "OneHotArray(") Base.showarg(io, x.indices, false) @@ -104,9 +118,9 @@ end # copy CuArray versions back before trying to print them: for fun in (:show, :print_array) # print_array is used by 3-arg show @eval begin - Base.$fun(io::IO, X::OneHotLike{T, N, var"N+1", <:AbstractGPUArray}) where {T, N, var"N+1"} = + Base.$fun(io::IO, X::OneHotLike{T, N, var"N+1", <:AbstractGPUArray}) where {T, N, var"N+1"} = Base.$fun(io, adapt(Array, X)) - Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}}) where {T, N} = + Base.$fun(io::IO, X::LinearAlgebra.AdjOrTrans{Bool, <:OneHotLike{T, N, <:Any, <:AbstractGPUArray}}) where {T, N} = Base.$fun(io, adapt(Array, X)) end end diff --git a/test/array.jl b/test/array.jl index 02db456..45424d6 100644 --- a/test/array.jl +++ b/test/array.jl @@ -49,6 +49,31 @@ end @test_throws BoundsError oa[:, :] end +@testset "Writing" begin + x = onehotbatch([1,2,3], 0:4) + x[1] = 1 + @test x[:, 1] == [1, 0, 0, 0, 0] + x[2, 1] = 1.0 + @test x[:, 1] == [0, 1, 0, 0, 0] + x[2] = 0 + @test x[:, 1] == [0, 0, 1, 0, 0] # writing 0 pushes 1 down + + y = onehotbatch([4,0,1], 0:4) + @test copyto!(x, y) == y + @test x[:, 1] == [0, 0, 0, 0, 1] + x .= 1 + @test x[end, :] == [1, 1, 1] + x .= y + @test x[:, 3] == [0, 1, 0, 0, 0] + + @test_throws ArgumentError y[5,1] = 0 # can't push 1 off the end + @test sum(y) == 3 # has not been corrupted before error + + @test_throws BoundsError x[6,1] = 0 + @test_throws BoundsError x[16] = 0 + @test_throws InexactError x[2] = 1.5 +end + @testset "Concatenating" begin # vector cat @test hcat(ov, ov) == OneHotMatrix(vcat(ov.indices, ov.indices), 10)