Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "OneHotArrays"
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
version = "0.2.8"
version = "0.2.9"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
18 changes: 12 additions & 6 deletions src/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1})
return NNlib.gather(A, _indices(B))
end

function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
B_dim = length(_indices(parent(B)))
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
end

for wrapper in [:Adjoint, :Transpose]
@eval begin
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T
Expand All @@ -31,6 +25,18 @@ for wrapper in [:Adjoint, :Transpose]

return A[onecold(b)]
end

# note that the fill! is the same thing done by NNlib.scatter so it is not more expensive
function LinearAlgebra.mul!(Y::AbstractMatrix, A::AbstractMatrix, B::$wrapper{Bool,<:OneHotMatrix})
if size(A,2) ≠ size(B,1)
throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2)) ≠ $(size(B,1))"))
end
if !(size(Y,1) == size(A,1) && size(Y,2) == size(B,2))
throw(DimensionMismatch("Invalid output matrix size for multiplication of matrix sizes $(size(A)) and $(size(B))"))
end
fill!(Y, zero(eltype(Y)))
return NNlib.scatter!(+, Y, A, _indices(parent(B)))
end
end
end

Expand Down
17 changes: 14 additions & 3 deletions test/gpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,13 @@ end
y = onehotbatch(ones(3), 1:2) |> cu;
@test (repr("text/plain", y); true)

gA = rand(3, 2) |> cu;

#NOTE: this would require something that can copute gradient... we don't have that here?
#NOTE: this would require something that can compute gradient... we don't have that here?
#@test gradient(A -> sum(A * y), gA)[1] isa CuArray
end

@testset "LinearAlgebra" begin
y = onehotbatch(ones(3), 1:2) |> cu;
gA = rand(3, 2) |> cu;

# some specialized implementations call only mul! and not *, so we must ensure this works
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) ≈ gA*y
Expand All @@ -66,6 +69,14 @@ end
y = reshape(y, 3, 2)
gA = rand(2, 3) |> cu
@test_broken LinearAlgebra.mul!(similar(gA, 2, 2), gA, y) ≈ gA*y

A = cu([1 3 5; 2 4 6; 3 6 9])
b3_dense = cu(Array(OneHotMatrix([1, 1, 2], 4)))
b3 = OneHotMatrix(cu([1, 1, 2]), 4)

d1 = fill(NaN, 3, 4) |> cu
@test mul!(d1, A, b3') == A * b3_dense'
@test mul!(d1, A, transpose(b3)) == A * transpose(b3_dense)
end

@testset "onehotbatch(::CuArray, ::UnitRange)" begin
Expand Down
24 changes: 22 additions & 2 deletions test/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,36 @@ end
c1 = fill(NaN, 3, 4)
@test mul!(c1, A, b1) == A * b1
@test c1 == A * b1

c4 = fill(NaN, 3, 6)
@test mul!(c4, A, b4) == A * b4 # b4 is reshaped but still one-hot
@test mul!(c4, A', b4) == A' * b4
c6 = fill(NaN, 3, 4)
@test mul!(c6, A, b6) == A * b6 # b4 is reshaped and not one-hot
@test mul!(c6, A', b6) == A' * b6

@test_throws DimensionMismatch mul!(c1, A, b2)
@test_throws DimensionMismatch mul!(c1, A, b4)
@test_throws DimensionMismatch mul!(c4, A, b1)
@test_throws DimensionMismatch mul!(zeros(10, 3), A, b1)

# note that we have separate implementations for a couple of mul! for the time being

d1 = fill(NaN, 3, 4)
@test mul!(d1, A, b3') == A * Array(b3')
@test mul!(d1, A, transpose(b3)) == A * Array(transpose(b3))

d2 = fill(NaN, 3, 6)
@test mul!(d2, A, b5') == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
@test mul!(d2, A, transpose(b5)) == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))

d3 = fill(NaN, 3, 6)
@test mul!(d3, A, b7') == A[:,[1, 2, 3, 1, 2, 3]]
@test mul!(d3, A, transpose(b7)) == A[:,[1, 2, 3, 1, 2, 3]]

d4 = fill(NaN, 4, 4)
@test_throws DimensionMismatch mul!(d4, A, b3')
@test_throws DimensionMismatch mul!(d4, A, transpose(b3))
@test_throws DimensionMismatch mul!(d1, fill(1, (4,4)), b3')
@test_throws DimensionMismatch mul!(d1, fill(1, (4,4)), transpose(b3))
end
Loading