Skip to content

Commit f0e9a2e

Browse files
committed
I was using the wrong dimension checks for mul!
1 parent 4e2d74a commit f0e9a2e

File tree

3 files changed

+29
-9
lines changed

3 files changed

+29
-9
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "OneHotArrays"
22
uuid = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3-
version = "0.2.8"
3+
version = "0.2.9"
44

55
[deps]
66
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/linalg.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,6 @@ function Base.:(*)(A::AbstractMatrix, B::OneHotLike{<:Any, 1})
1010
return NNlib.gather(A, _indices(B))
1111
end
1212

13-
function Base.:(*)(A::AbstractMatrix, B::Adjoint{Bool, <:OneHotMatrix})
14-
B_dim = length(_indices(parent(B)))
15-
size(A, 2) == B_dim || throw(DimensionMismatch("Matrix column must correspond with OneHot size: $(size(A, 2)) != $B_dim"))
16-
return NNlib.scatter(+, A, _indices(parent(B)), dstsize=(size(A,1), size(B,2)))
17-
end
18-
1913
for wrapper in [:Adjoint, :Transpose]
2014
@eval begin
2115
function Base.:*(A::$wrapper{<:Any, <:AbstractMatrix{T}}, b::OneHotVector) where T
@@ -31,6 +25,18 @@ for wrapper in [:Adjoint, :Transpose]
3125

3226
return A[onecold(b)]
3327
end
28+
29+
# note that the fill! is the same thing done by NNlib.scatter so it is not more expensive
30+
function LinearAlgebra.mul!(Y::AbstractMatrix, A::AbstractMatrix, B::$wrapper{Bool,<:OneHotMatrix})
31+
if size(A,2) size(B,1)
32+
throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2))$(size(B,1))"))
33+
end
34+
if !(size(Y,1) == size(A,1) && size(Y,2) == size(B,2))
35+
throw(DimensionMismatch("Invalid output matrix size for multiplication of matrix sizes $(size(A)) and $(size(B))"))
36+
end
37+
fill!(Y, zero(eltype(Y)))
38+
return NNlib.scatter!(+, Y, A, _indices(parent(B)))
39+
end
3440
end
3541
end
3642

test/linalg.jl

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,30 @@ end
5959
c1 = fill(NaN, 3, 4)
6060
@test mul!(c1, A, b1) == A * b1
6161
@test c1 == A * b1
62-
62+
6363
c4 = fill(NaN, 3, 6)
6464
@test mul!(c4, A, b4) == A * b4 # b4 is reshaped but still one-hot
6565
@test mul!(c4, A', b4) == A' * b4
6666
c6 = fill(NaN, 3, 4)
6767
@test mul!(c6, A, b6) == A * b6 # b4 is reshaped and not one-hot
6868
@test mul!(c6, A', b6) == A' * b6
69-
69+
7070
@test_throws DimensionMismatch mul!(c1, A, b2)
7171
@test_throws DimensionMismatch mul!(c1, A, b4)
7272
@test_throws DimensionMismatch mul!(c4, A, b1)
7373
@test_throws DimensionMismatch mul!(zeros(10, 3), A, b1)
74+
75+
# note that we have separate implementations for a couple of mul! for the time being
76+
77+
d1 = fill(NaN, 3, 4)
78+
@test mul!(d1, A, b3') == A * Array(b3')
79+
@test mul!(d1, A, transpose(b3)) == A * Array(transpose(b3))
80+
81+
d2 = fill(NaN, 3, 6)
82+
@test mul!(d2, A, b5') == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
83+
@test mul!(d2, A, transpose(b5)) == hcat(A[:,[1, 2, 3, 3]], A[:,1]+A[:,2], zeros(Int64, 3))
84+
85+
d3 = fill(NaN, 3, 6)
86+
@test mul!(d3, A, b7') == A[:,[1, 2, 3, 1, 2, 3]]
87+
@test mul!(d3, A, transpose(b7)) == A[:,[1, 2, 3, 1, 2, 3]]
7488
end

0 commit comments

Comments
 (0)