Skip to content

Commit 430737b

Browse files
committed
some more tests
1 parent 4e2d74a commit 430737b

File tree

4 files changed

+43
-9
lines changed

4 files changed

+43
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "OneHotArrays"
22
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3-
version = "0.2.8"
3+
version = "0.2.9"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/linalg.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1})
1010
return NNlib.gather(A, _indices(B))
1111
end
1212

13-
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
14-
B_dim = length(_indices(parent(B)))
15-
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
16-
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
17-
end
18-
1913
for wrapper in [:Adjoint, :Transpose]
2014
@eval begin
2115
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T
@@ -31,6 +25,18 @@ for wrapper in [:Adjoint, :Transpose]
3125

3226
return A[onecold(b)]
3327
end
28+
29+
# note that the fill! is the same thing done by NNlib.scatter so it is not more expensive
30+
function LinearAlgebra.mul!(Y::AbstractMatrix, A::AbstractMatrix, B::$wrapper{Bool,<:OneHotMatrix})
31+
if size(A,2) size(B,1)
32+
throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2))$(size(B,1))"))
33+
end
34+
if !(size(Y,1) == size(A,1) && size(Y,2) == size(B,2))
35+
throw(DimensionMismatch("Invalid output matrix size for multiplication of matrix sizes $(size(A)) and $(size(B))"))
36+
end
37+
fill!(Y, zero(eltype(Y)))
38+
return NNlib.scatter!(+, Y, A, _indices(parent(B)))
39+
end
3440
end
3541
end
3642

test/gpu.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,14 @@ end
4242
@test Array{Float32}(cx) == Array{Float32}(x) == collect(x)
4343
@test convert(AbstractArray{Float32}, cx) isa CuArray{Float32}
4444
@test collect(convert(AbstractArray{Float32}, cx)) == collect(x)
45+
46+
A = cu([1 3 5; 2 4 6; 3 6 9])
47+
b3_dense = cu(Array(OneHotMatrix([1, 1, 2], 4)))
48+
b3 = OneHotMatrix(cu([1, 1, 2]), 4)
49+
50+
d1 = fill(NaN, 3, 4) |> cu
51+
@test mul!(d1, A, b3') == A * b3_dense'
52+
@test mul!(d1, A, transpose(b3)) == A * transpose(b3_dense)
4553
end
4654

4755
@testset "onehot gpu" begin

test/linalg.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,36 @@ end
5959
c1 = fill(NaN, 3, 4)
6060
@test mul!(c1, A, b1) == A * b1
6161
@test c1 == A * b1
62-
62+
6363
c4 = fill(NaN, 3, 6)
6464
@test mul!(c4, A, b4) == A * b4 # b4 is reshaped but still one-hot
6565
@test mul!(c4, A', b4) == A' * b4
6666
c6 = fill(NaN, 3, 4)
6767
@test mul!(c6, A, b6) == A * b6 # b4 is reshaped and not one-hot
6868
@test mul!(c6, A', b6) == A' * b6
69-
69+
7070
@test_throws DimensionMismatch mul!(c1, A, b2)
7171
@test_throws DimensionMismatch mul!(c1, A, b4)
7272
@test_throws DimensionMismatch mul!(c4, A, b1)
7373
@test_throws DimensionMismatch mul!(zeros(10, 3), A, b1)
74+
75+
# note that we have separate implementations for a couple of mul! for the time being
76+
77+
d1 = fill(NaN, 3, 4)
78+
@test mul!(d1, A, b3') == A * Array(b3')
79+
@test mul!(d1, A, transpose(b3)) == A * Array(transpose(b3))
80+
81+
d2 = fill(NaN, 3, 6)
82+
@test mul!(d2, A, b5') == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
83+
@test mul!(d2, A, transpose(b5)) == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
84+
85+
d3 = fill(NaN, 3, 6)
86+
@test mul!(d3, A, b7') == A[:,[1, 2, 3, 1, 2, 3]]
87+
@test mul!(d3, A, transpose(b7)) == A[:,[1, 2, 3, 1, 2, 3]]
88+
89+
d4 = fill(NaN, 4, 4)
90+
@test_throws DimensionMismatch mul!(d4, A, b3')
91+
@test_throws DimensionMismatch mul!(d4, A, transpose(b3))
92+
@test_throws DimensionMismatch mul!(d1, fill(1, (4,4)), b3')
93+
@test_throws DimensionMismatch mul!(d1, fill(1, (4,4)), transpose(b3))
7494
end

0 commit comments

Comments
 (0)