diff --git a/ext/FillArraysPDMatsExt.jl b/ext/FillArraysPDMatsExt.jl index d5a0892d..0f98f775 100644 --- a/ext/FillArraysPDMatsExt.jl +++ b/ext/FillArraysPDMatsExt.jl @@ -3,10 +3,16 @@ module FillArraysPDMatsExt import FillArrays import FillArrays.LinearAlgebra import PDMats +using FillArrays: mult_zeros, AbstractZeros +using PDMats: ScalMat function PDMats.AbstractPDMat(a::LinearAlgebra.Diagonal{T,<:FillArrays.AbstractFill{T,1}}) where {T<:Real} dim = size(a, 1) - return PDMats.ScalMat(dim, FillArrays.getindex_value(a.diag)) + return ScalMat(dim, FillArrays.getindex_value(a.diag)) end +Base.:*(a::ScalMat, b::AbstractZeros{T, 1} where T) = mult_zeros(a, b) +Base.:*(a::ScalMat, b::AbstractZeros{T, 2} where T) = mult_zeros(a, b) +Base.:*(a::AbstractZeros{T, 2} where T, b::ScalMat) = mult_zeros(a, b) # This is implemented in case ScalMat implements right multiplication + end # module diff --git a/test/runtests.jl b/test/runtests.jl index 088089d4..32cd2ed7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2990,6 +2990,11 @@ end @test a.dim == length(diag) @test a.value == first(diag) end + a = ScalMat(4, 1.) + for zero in (Zeros(4), Zeros(4,4)) + @test a * zero === zero + end + @test Zeros(4,4) * a === Zeros(4,4) end @testset "isbanded/isdiag" begin