Skip to content

Commit fe3a212

Browse files
PbelliveKristofferC
authored andcommitted
Fix dispatch of SparseMatrixCSC*Diagonal multiplication (#29045)
* Fix type signature of mul! methods for multiplying SparseMatrixCSCs with Diagonal matrices. Type signature for diagonal matrices was wrong, causing fallback to generic Matmul. * Add SparseMatrixCSC*Diagonal dispatch test * Fix trailing whitespace * Don't copy with deepcopy (cherry picked from commit 8d99356)
1 parent 90c8718 commit fe3a212

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

stdlib/SparseArrays/src/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -968,7 +968,7 @@ function copyinds!(C::SparseMatrixCSC, A::SparseMatrixCSC)
968968
end
969969

970970
# multiply by diagonal matrix as vector
971-
function mul!(C::SparseMatrixCSC, A::SparseMatrixCSC, D::Diagonal{<:Vector})
971+
function mul!(C::SparseMatrixCSC, A::SparseMatrixCSC, D::Diagonal{T, <:Vector}) where T
972972
m, n = size(A)
973973
b = D.diag
974974
(n==length(b) && size(A)==size(C)) || throw(DimensionMismatch())
@@ -982,7 +982,7 @@ function mul!(C::SparseMatrixCSC, A::SparseMatrixCSC, D::Diagonal{<:Vector})
982982
C
983983
end
984984

985-
function mul!(C::SparseMatrixCSC, D::Diagonal{<:Vector}, A::SparseMatrixCSC)
985+
function mul!(C::SparseMatrixCSC, D::Diagonal{T, <:Vector}, A::SparseMatrixCSC) where T
986986
m, n = size(A)
987987
b = D.diag
988988
(m==length(b) && size(A)==size(C)) || throw(DimensionMismatch())

stdlib/SparseArrays/test/sparse.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using LinearAlgebra
88
using Base.Printf: @printf
99
using Random
1010
using Test: guardseed
11+
using InteractiveUtils: @which
1112

1213
@testset "issparse" begin
1314
@test issparse(sparse(fill(1,5,5)))
@@ -2295,4 +2296,15 @@ end
22952296
@test typeof(a) === typeof(na)
22962297
end
22972298

2299+
#PR #29045
2300+
@testset "Issue #28934" begin
2301+
A = sprand(5,5,0.5)
2302+
D = Diagonal(rand(5))
2303+
C = copy(A)
2304+
m1 = @which mul!(C,A,D)
2305+
m2 = @which mul!(C,D,A)
2306+
@test m1.module == SparseArrays
2307+
@test m2.module == SparseArrays
2308+
end
2309+
22982310
end # module

0 commit comments

Comments
 (0)