diff --git a/Project.toml b/Project.toml index 90e50c4..ba2df07 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "OneHotArrays" uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f" -version = "0.2.8" +version = "0.2.9" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/linalg.jl b/src/linalg.jl index 71100b1..03a302d 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -10,12 +10,6 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1}) return NNlib.gather(A, _indices(B)) end -function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix}) - B_dim = length(_indices(parent(B))) - size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim")) - return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2))) -end - for wrapper in [:Adjoint, :Transpose] @eval begin function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T @@ -31,6 +25,18 @@ for wrapper in [:Adjoint, :Transpose] return A[onecold(b)] end + + # note that the fill! is the same thing done by NNlib.scatter so it is not more expensive + function LinearAlgebra.mul!(Y::AbstractMatrix, A::AbstractMatrix, B::$wrapper{Bool,<:OneHotMatrix}) + if size(A,2) ≠ size(B,1) + throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2)) ≠ $(size(B,1))")) + end + if !(size(Y,1) == size(A,1) && size(Y,2) == size(B,2)) + throw(DimensionMismatch("Invalid output matrix size for multiplication of matrix sizes $(size(A)) and $(size(B))")) + end + fill!(Y, zero(eltype(Y))) + return NNlib.scatter!(+, Y, A, _indices(parent(B))) + end end end diff --git a/test/gpu.jl b/test/gpu.jl index 0a80610..5eb0915 100644 --- a/test/gpu.jl +++ b/test/gpu.jl @@ -48,10 +48,13 @@ end y = onehotbatch(ones(3), 1:2) |> cu; @test (repr("text/plain", y); true) - gA = rand(3, 2) |> cu; - - #NOTE: this would require something that can copute gradient... we don't have that here? + #NOTE: this would require something that can compute gradient... we don't have that here? #@test gradient(A -> sum(A * y), gA)[1] isa CuArray +end + +@testset "LinearAlgebra" begin + y = onehotbatch(ones(3), 1:2) |> cu; + gA = rand(3, 2) |> cu; # some specialized implementations call only mul! and not *, so we must ensure this works @test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) ≈ gA*y @@ -66,6 +69,14 @@ end y = reshape(y, 3, 2) gA = rand(2, 3) |> cu @test_broken LinearAlgebra.mul!(similar(gA, 2, 2), gA, y) ≈ gA*y + + A = cu([1 3 5; 2 4 6; 3 6 9]) + b3_dense = cu(Array(OneHotMatrix([1, 1, 2], 4))) + b3 = OneHotMatrix(cu([1, 1, 2]), 4) + + d1 = fill(NaN, 3, 4) |> cu + @test mul!(d1, A, b3') == A * b3_dense' + @test mul!(d1, A, transpose(b3)) == A * transpose(b3_dense) end @testset "onehotbatch(::CuArray, ::UnitRange)" begin diff --git a/test/linalg.jl b/test/linalg.jl index e34ac51..f020f8b 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -59,16 +59,36 @@ end c1 = fill(NaN, 3, 4) @test mul!(c1, A, b1) == A * b1 @test c1 == A * b1 - + c4 = fill(NaN, 3, 6) @test mul!(c4, A, b4) == A * b4 # b4 is reshaped but still one-hot @test mul!(c4, A', b4) == A' * b4 c6 = fill(NaN, 3, 4) @test mul!(c6, A, b6) == A * b6 # b4 is reshaped and not one-hot @test mul!(c6, A', b6) == A' * b6 - + @test_throws DimensionMismatch mul!(c1, A, b2) @test_throws DimensionMismatch mul!(c1, A, b4) @test_throws DimensionMismatch mul!(c4, A, b1) @test_throws DimensionMismatch mul!(zeros(10, 3), A, b1) + + # note that we have separate implementations for a couple of mul! for the time being + + d1 = fill(NaN, 3, 4) + @test mul!(d1, A, b3') == A * Array(b3') + @test mul!(d1, A, transpose(b3)) == A * Array(transpose(b3)) + + d2 = fill(NaN, 3, 6) + @test mul!(d2, A, b5') == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3)) + @test mul!(d2, A, transpose(b5)) == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3)) + + d3 = fill(NaN, 3, 6) + @test mul!(d3, A, b7') == A[:,[1, 2, 3, 1, 2, 3]] + @test mul!(d3, A, transpose(b7)) == A[:,[1, 2, 3, 1, 2, 3]] + + d4 = fill(NaN, 4, 4) + @test_throws DimensionMismatch mul!(d4, A, b3') + @test_throws DimensionMismatch mul!(d4, A, transpose(b3)) + @test_throws DimensionMismatch mul!(d1, fill(1, (4,4)), b3') + @test_throws DimensionMismatch mul!(d1, fill(1, (4,4)), transpose(b3)) end