Skip to content

Commit 85429bd

Browse files
jishnubLilithHafner
authored andcommitted
Use mul! in Diagonal*Matrix (JuliaLang#42321)
1 parent 565f6d0 commit 85429bd

File tree

5 files changed

+171
-61
lines changed

5 files changed

+171
-61
lines changed

stdlib/LinearAlgebra/src/LinearAlgebra.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@ import Base: USE_BLAS64, abs, acos, acosh, acot, acoth, acsc, acsch, adjoint, as
1212
asin, asinh, atan, atanh, axes, big, broadcast, ceil, cis, conj, convert, copy, copyto!, cos,
1313
cosh, cot, coth, csc, csch, eltype, exp, fill!, floor, getindex, hcat,
1414
getproperty, imag, inv, isapprox, isequal, isone, iszero, IndexStyle, kron, kron!, length, log, map, ndims,
15-
oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
15+
one, oneunit, parent, power_by_squaring, print_matrix, promote_rule, real, round, sec, sech,
1616
setindex!, show, similar, sin, sincos, sinh, size, sqrt,
17-
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec
17+
strides, stride, tan, tanh, transpose, trunc, typed_hcat, vec, zero
1818
using Base: IndexLinear, promote_eltype, promote_op, promote_typeof,
1919
@propagate_inbounds, @pure, reduce, typed_hvcat, typed_vcat, require_one_based_indexing,
2020
splat

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 85 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -198,19 +198,37 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
198198
Diagonal(Base.literal_pow.(^, D.diag, valp)) # for speed
199199
Base.literal_pow(::typeof(^), D::Diagonal, ::Val{-1}) = inv(D) # for disambiguation
200200

201+
function _muldiag_size_check(A, B)
202+
nA = size(A, 2)
203+
mB = size(B, 1)
204+
@noinline throw_dimerr(::AbstractMatrix, nA, mB) = throw(DimensionMismatch("second dimension of A, $nA, does not match first dimension of B, $mB"))
205+
@noinline throw_dimerr(::AbstractVector, nA, mB) = throw(DimensionMismatch("second dimension of D, $nA, does not match length of V, $mB"))
206+
nA == mB || throw_dimerr(B, nA, mB)
207+
return nothing
208+
end
209+
# the output matrix should have the same size as the non-diagonal input matrix or vector
210+
@noinline throw_dimerr(szC, szA) = throw(DimensionMismatch("output matrix has size: $szC, but should have size $szA"))
211+
_size_check_out(C, ::Diagonal, A) = _size_check_out(C, A)
212+
_size_check_out(C, A, ::Diagonal) = _size_check_out(C, A)
213+
_size_check_out(C, A::Diagonal, ::Diagonal) = _size_check_out(C, A)
214+
function _size_check_out(C, A)
215+
szA = size(A)
216+
szC = size(C)
217+
szA == szC || throw_dimerr(szC, szA)
218+
return nothing
219+
end
220+
function _muldiag_size_check(C, A, B)
221+
_muldiag_size_check(A, B)
222+
_size_check_out(C, A, B)
223+
end
224+
201225
function (*)(Da::Diagonal, Db::Diagonal)
202-
nDa, mDb = size(Da, 2), size(Db, 1)
203-
if nDa != mDb
204-
throw(DimensionMismatch("second dimension of Da, $nDa, does not match first dimension of Db, $mDb"))
205-
end
226+
_muldiag_size_check(Da, Db)
206227
return Diagonal(Da.diag .* Db.diag)
207228
end
208229

209230
function (*)(D::Diagonal, V::AbstractVector)
210-
nD = size(D, 2)
211-
if nD != length(V)
212-
throw(DimensionMismatch("second dimension of D, $nD, does not match length of V, $(length(V))"))
213-
end
231+
_muldiag_size_check(D, V)
214232
return D.diag .* V
215233
end
216234

@@ -220,29 +238,12 @@ end
220238
lmul!(D, copy_oftype(B, promote_op(*, eltype(B), eltype(D.diag))))
221239

222240
(*)(A::AbstractMatrix, D::Diagonal) =
223-
rmul!(copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))), D)
241+
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), A, D)
224242
(*)(D::Diagonal, A::AbstractMatrix) =
225-
lmul!(D, copy_similar(A, promote_op(*, eltype(A), eltype(D.diag))))
243+
mul!(similar(A, promote_op(*, eltype(A), eltype(D.diag)), size(A)), D, A)
226244

227-
function rmul!(A::AbstractMatrix, D::Diagonal)
228-
require_one_based_indexing(A)
229-
nA, nD = size(A, 2), length(D.diag)
230-
if nA != nD
231-
throw(DimensionMismatch("second dimension of A, $nA, does not match the first of D, $nD"))
232-
end
233-
A .= A .* permutedims(D.diag)
234-
return A
235-
end
236-
237-
function lmul!(D::Diagonal, B::AbstractVecOrMat)
238-
require_one_based_indexing(B)
239-
nB, nD = size(B, 1), length(D.diag)
240-
if nB != nD
241-
throw(DimensionMismatch("second dimension of D, $nD, does not match the first of B, $nB"))
242-
end
243-
B .= D.diag .* B
244-
return B
245-
end
245+
rmul!(A::AbstractMatrix, D::Diagonal) = mul!(A, A, D)
246+
lmul!(D::Diagonal, B::AbstractVecOrMat) = mul!(B, D, B)
246247

247248
rmul!(A::Union{LowerTriangular,UpperTriangular}, D::Diagonal) = typeof(A)(rmul!(A.data, D))
248249
function rmul!(A::UnitLowerTriangular, D::Diagonal)
@@ -306,37 +307,66 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
306307
lmul!(D, At)
307308
end
308309

309-
rmul!(A::Diagonal, B::Diagonal) = Diagonal(A.diag .*= B.diag)
310-
lmul!(A::Diagonal, B::Diagonal) = Diagonal(B.diag .= A.diag .* B.diag)
310+
@inline function __muldiag!(out, D::Diagonal, B, alpha, beta)
311+
if iszero(beta)
312+
out .= (D.diag .* B) .*ₛ alpha
313+
else
314+
out .= (D.diag .* B) .*ₛ alpha .+ out .* beta
315+
end
316+
return out
317+
end
318+
319+
@inline function __muldiag!(out, A, D::Diagonal, alpha, beta)
320+
if iszero(beta)
321+
out .= (A .* permutedims(D.diag)) .*ₛ alpha
322+
else
323+
out .= (A .* permutedims(D.diag)) .*ₛ alpha .+ out .* beta
324+
end
325+
return out
326+
end
327+
328+
@inline function __muldiag!(out::Diagonal, D1::Diagonal, D2::Diagonal, alpha, beta)
329+
if iszero(beta)
330+
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha
331+
else
332+
out.diag .= (D1.diag .* D2.diag) .*ₛ alpha .+ out.diag .* beta
333+
end
334+
return out
335+
end
336+
337+
# only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
338+
@inline __muldiag!(out, D1::Diagonal, D2::Diagonal, alpha, beta) =
339+
mul!(out, D1, D2, alpha, beta)
340+
341+
@inline function _muldiag!(out, A, B, alpha, beta)
342+
_muldiag_size_check(out, A, B)
343+
__muldiag!(out, A, B, alpha, beta)
344+
return out
345+
end
311346

312347
# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
313-
@inline mul!(out::AbstractVector, A::Diagonal, in::AbstractVector, alpha::Number, beta::Number) =
314-
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
315-
@inline mul!(out::AbstractMatrix, A::Diagonal, in::AbstractMatrix, alpha::Number, beta::Number) =
316-
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
317-
@inline mul!(out::AbstractMatrix, A::Diagonal, in::Adjoint{<:Any,<:AbstractVecOrMat},
318-
alpha::Number, beta::Number) =
319-
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
320-
@inline mul!(out::AbstractMatrix, A::Diagonal, in::Transpose{<:Any,<:AbstractVecOrMat},
321-
alpha::Number, beta::Number) =
322-
out .= (A.diag .* in) .*ₛ alpha .+ out .*ₛ beta
323-
324-
@inline mul!(out::AbstractMatrix, in::AbstractMatrix, A::Diagonal, alpha::Number, beta::Number) =
325-
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
326-
@inline mul!(out::AbstractMatrix, in::Adjoint{<:Any,<:AbstractVecOrMat}, A::Diagonal,
327-
alpha::Number, beta::Number) =
328-
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
329-
@inline mul!(out::AbstractMatrix, in::Transpose{<:Any,<:AbstractVecOrMat}, A::Diagonal,
330-
alpha::Number, beta::Number) =
331-
out .= (in .* permutedims(A.diag)) .*ₛ alpha .+ out .*ₛ beta
348+
@inline mul!(out::AbstractVector, D::Diagonal, V::AbstractVector, alpha::Number, beta::Number) =
349+
_muldiag!(out, D, V, alpha, beta)
350+
@inline mul!(out::AbstractMatrix, D::Diagonal, B::AbstractMatrix, alpha::Number, beta::Number) =
351+
_muldiag!(out, D, B, alpha, beta)
352+
@inline mul!(out::AbstractMatrix, D::Diagonal, B::Adjoint{<:Any,<:AbstractVecOrMat},
353+
alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta)
354+
@inline mul!(out::AbstractMatrix, D::Diagonal, B::Transpose{<:Any,<:AbstractVecOrMat},
355+
alpha::Number, beta::Number) = _muldiag!(out, D, B, alpha, beta)
356+
357+
@inline mul!(out::AbstractMatrix, A::AbstractMatrix, D::Diagonal, alpha::Number, beta::Number) =
358+
_muldiag!(out, A, D, alpha, beta)
359+
@inline mul!(out::AbstractMatrix, A::Adjoint{<:Any,<:AbstractVecOrMat}, D::Diagonal,
360+
alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta)
361+
@inline mul!(out::AbstractMatrix, A::Transpose{<:Any,<:AbstractVecOrMat}, D::Diagonal,
362+
alpha::Number, beta::Number) = _muldiag!(out, A, D, alpha, beta)
363+
@inline mul!(C::Diagonal, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number) =
364+
_muldiag!(C, Da, Db, alpha, beta)
332365

333366
function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta::Number)
334-
mA = size(Da, 1)
335-
mB = size(Db, 1)
336-
mA == mB || throw(DimensionMismatch("A has dimensions ($mA,$mA) but B has dimensions ($mB,$mB)"))
337-
mC, nC = size(C)
338-
mC == nC == mA || throw(DimensionMismatch("output matrix has size: ($mC,$nC), but should have size ($mA,$mA)"))
367+
_muldiag_size_check(C, Da, Db)
339368
require_one_based_indexing(C)
369+
mA = size(Da, 1)
340370
da = Da.diag
341371
db = Db.diag
342372
_rmul_or_fill!(C, beta)

stdlib/LinearAlgebra/src/generic.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
# inside this function.
99
function *end
1010
Broadcast.broadcasted(::typeof(*ₛ), out, beta) =
11-
iszero(beta::Number) ? false : broadcasted(*, out, beta)
11+
iszero(beta::Number) ? false :
12+
isone(beta::Number) ? broadcasted(identity, out) : broadcasted(*, out, beta)
1213

1314
"""
1415
MulAddMul(alpha, beta)

stdlib/LinearAlgebra/src/special.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,16 @@ function fill!(A::Union{Diagonal,Bidiagonal,Tridiagonal,SymTridiagonal}, x)
308308
not be filled with $x, since some of its entries are constrained."))
309309
end
310310

311-
one(A::Diagonal{T}) where T = Diagonal(fill!(similar(A.diag, typeof(one(T))), one(T)))
311+
one(D::Diagonal) = Diagonal(one.(D.diag))
312312
one(A::Bidiagonal{T}) where T = Bidiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))), A.uplo)
313313
one(A::Tridiagonal{T}) where T = Tridiagonal(fill!(similar(A.du, typeof(one(T))), zero(one(T))), fill!(similar(A.d, typeof(one(T))), one(T)), fill!(similar(A.dl, typeof(one(T))), zero(one(T))))
314314
one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(one(T))), one(T)), fill!(similar(A.ev, typeof(one(T))), zero(one(T))))
315315
# equals and approx equals methods for structured matrices
316316
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl
317317

318+
zero(D::Diagonal) = Diagonal(zero.(D.diag))
319+
oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag))
320+
318321
# SymTridiagonal and Bidiagonal have the same field names
319322
==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv
320323
==(B::Bidiagonal, A::Diagonal) = A == B

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 78 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,41 @@ let D1 = Diagonal(rand(5)), D2 = Diagonal(rand(5))
578578
@test LinearAlgebra.lmul!(adjoint(D1),copy(D2)) == adjoint(D1)*D2
579579
end
580580

581+
@testset "multiplication of a Diagonal with a Matrix" begin
582+
A = collect(reshape(1:8, 4, 2));
583+
B = BigFloat.(A);
584+
DL = Diagonal(collect(axes(A, 1)));
585+
DR = Diagonal(Float16.(collect(axes(A, 2))));
586+
587+
@test DL * A == collect(DL) * A
588+
@test A * DR == A * collect(DR)
589+
@test DL * B == collect(DL) * B
590+
@test B * DR == B * collect(DR)
591+
592+
A = reshape([ones(2,2), ones(2,2)*2, ones(2,2)*3, ones(2,2)*4], 2, 2)
593+
Ac = collect(A)
594+
D = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))])
595+
Dc = collect(D)
596+
@test A * D == Ac * Dc
597+
@test D * A == Dc * Ac
598+
@test D * D == Dc * Dc
599+
600+
AS = similar(A)
601+
mul!(AS, A, D, true, false)
602+
@test AS == A * D
603+
604+
D2 = similar(D)
605+
mul!(D2, D, D)
606+
@test D2 == D * D
607+
608+
D2[diagind(D2)] .= D[diagind(D)]
609+
lmul!(D, D2)
610+
@test D2 == D * D
611+
D2[diagind(D2)] .= D[diagind(D)]
612+
rmul!(D2, D)
613+
@test D2 == D * D
614+
end
615+
581616
@testset "multiplication of QR Q-factor and Diagonal (#16615 spot test)" begin
582617
D = Diagonal(randn(5))
583618
Q = qr(randn(5, 5)).Q
@@ -686,12 +721,35 @@ end
686721
xt = transpose(x)
687722
A = reshape([[1 2; 3 4], zeros(Int,2,2), zeros(Int, 2, 2), [5 6; 7 8]], 2, 2)
688723
D = Diagonal(A)
689-
@test x'*D == x'*A == copy(x')*D == copy(x')*A
690-
@test xt*D == xt*A == copy(xt)*D == copy(xt)*A
724+
@test x'*D == x'*A == collect(x')*D == collect(x')*A
725+
@test xt*D == xt*A == collect(xt)*D == collect(xt)*A
726+
outadjxD = similar(x'*D); outtrxD = similar(xt*D);
727+
mul!(outadjxD, x', D)
728+
@test outadjxD == x'*D
729+
mul!(outtrxD, xt, D)
730+
@test outtrxD == xt*D
731+
732+
D1 = Diagonal([[1 2; 3 4]])
733+
@test D1 * x' == D1 * collect(x') == collect(D1) * collect(x')
734+
@test D1 * xt == D1 * collect(xt) == collect(D1) * collect(xt)
735+
outD1adjx = similar(D1 * x'); outD1trx = similar(D1 * xt);
736+
mul!(outadjxD, D1, x')
737+
@test outadjxD == D1*x'
738+
mul!(outtrxD, D1, xt)
739+
@test outtrxD == D1*xt
740+
691741
y = [x, x]
692742
yt = transpose(y)
693743
@test y'*D*y == (y'*D)*y == (y'*A)*y
694744
@test yt*D*y == (yt*D)*y == (yt*A)*y
745+
outadjyD = similar(y'*D); outtryD = similar(yt*D);
746+
outadjyD2 = similar(collect(y'*D)); outtryD2 = similar(collect(yt*D));
747+
mul!(outadjyD, y', D)
748+
mul!(outadjyD2, y', D)
749+
@test outadjyD == outadjyD2 == y'*D
750+
mul!(outtryD, yt, D)
751+
mul!(outtryD2, yt, D)
752+
@test outtryD == outtryD2 == yt*D
695753
end
696754

697755
@testset "Multiplication of single element Diagonal (#36746, #40726)" begin
@@ -826,4 +884,22 @@ end
826884
@test \(x, B) == /(B, x)
827885
end
828886

887+
@testset "zero and one" begin
888+
D1 = Diagonal(rand(3))
889+
@test D1 + zero(D1) == D1
890+
@test D1 * one(D1) == D1
891+
@test D1 * oneunit(D1) == D1
892+
@test oneunit(D1) isa typeof(D1)
893+
D2 = Diagonal([collect(reshape(1:4, 2, 2)), collect(reshape(5:8, 2, 2))])
894+
@test D2 + zero(D2) == D2
895+
@test D2 * one(D2) == D2
896+
@test D2 * oneunit(D2) == D2
897+
@test oneunit(D2) isa typeof(D2)
898+
D3 = Diagonal([D2, D2]);
899+
@test D3 + zero(D3) == D3
900+
@test D3 * one(D3) == D3
901+
@test D3 * oneunit(D3) == D3
902+
@test oneunit(D3) isa typeof(D3)
903+
end
904+
829905
end # module TestDiagonal

0 commit comments

Comments
 (0)