Skip to content

Commit 0ed3903

Browse files
dkarraschKristofferC
authored andcommitted
minor fixes in multiplication with Diagonals (#31443)
* minor fixes in multiplication with Diagonals * correct rmul!(A,D), revert changes in AdjTrans(x)*D * [r/l]mul!: replace conj by adjoint, add transpose * add tests * fix typo * relax some tests, added more tests * simplify tests, strict equality (cherry picked from commit a93185f)
1 parent 57d5e64 commit 0ed3903

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,7 @@ end
168168

169169
function rmul!(A::AbstractMatrix, D::Diagonal)
170170
@assert !has_offset_axes(A)
171-
A .= A .* transpose(D.diag)
171+
A .= A .* permutedims(D.diag)
172172
return A
173173
end
174174

@@ -256,20 +256,20 @@ lmul!(A::Diagonal, B::Diagonal) = Diagonal(B.diag .= A.diag .* B.diag)
256256

257257
function lmul!(adjA::Adjoint{<:Any,<:Diagonal}, B::AbstractMatrix)
258258
A = adjA.parent
259-
return lmul!(conj(A.diag), B)
259+
return lmul!(adjoint(A), B)
260260
end
261261
function lmul!(transA::Transpose{<:Any,<:Diagonal}, B::AbstractMatrix)
262262
A = transA.parent
263-
return lmul!(A.diag, B)
263+
return lmul!(transpose(A), B)
264264
end
265265

266266
function rmul!(A::AbstractMatrix, adjB::Adjoint{<:Any,<:Diagonal})
267267
B = adjB.parent
268-
return rmul!(A, conj(B.diag))
268+
return rmul!(A, adjoint(B))
269269
end
270270
function rmul!(A::AbstractMatrix, transB::Transpose{<:Any,<:Diagonal})
271271
B = transB.parent
272-
return rmul!(A, B.diag)
272+
return rmul!(A, transpose(B))
273273
end
274274

275275
# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
@@ -508,10 +508,9 @@ end
508508
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal) = Adjoint(map((t,s) -> t'*s, D.diag, parent(x)))
509509
*(x::Adjoint{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
510510
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
511-
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map(*, D.diag, parent(x)))
511+
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal) = Transpose(map((t,s) -> transpose(t)*s, D.diag, parent(x)))
512512
*(x::Transpose{<:Any,<:AbstractVector}, D::Diagonal, y::AbstractVector) =
513513
mapreduce(t -> t[1]*t[2]*t[3], +, zip(x, D.diag, y))
514-
# TODO: these methods will yield row matrices, rather than adjoint/transpose vectors
515514

516515
function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
517516
info = 0

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -441,10 +441,20 @@ end
441441
fullBB = copyto!(Matrix{Matrix{T}}(undef, 2, 2), BB)
442442
for (transform1, transform2) in ((identity, identity),
443443
(identity, adjoint ), (adjoint, identity ), (adjoint, adjoint ),
444-
(identity, transpose), (transpose, identity ), (transpose, transpose) )
444+
(identity, transpose), (transpose, identity ), (transpose, transpose),
445+
(identity, Adjoint ), (Adjoint, identity ), (Adjoint, Adjoint ),
446+
(identity, Transpose), (Transpose, identity ), (Transpose, Transpose))
445447
@test *(transform1(D), transform2(B))::typeof(D) *(transform1(Matrix(D)), transform2(Matrix(B))) atol=2 * eps()
446448
@test *(transform1(DD), transform2(BB))::typeof(DD) == *(transform1(fullDD), transform2(fullBB))
447449
end
450+
M = randn(T, 5, 5)
451+
MM = [randn(T, 2, 2) for _ in 1:2, _ in 1:2]
452+
for transform in (identity, adjoint, transpose, Adjoint, Transpose)
453+
@test lmul!(transform(D), copy(M)) == *(transform(Matrix(D)), M)
454+
@test rmul!(copy(M), transform(D)) == *(M, transform(Matrix(D)))
455+
@test lmul!(transform(DD), copy(MM)) == *(transform(fullDD), MM)
456+
@test rmul!(copy(MM), transform(DD)) == *(MM, transform(fullDD))
457+
end
448458
end
449459
end
450460

@@ -454,10 +464,16 @@ end
454464
end
455465

456466
@testset "Multiplication with Adjoint and Transpose vectors (#26863)" begin
457-
x = rand(5)
458-
D = Diagonal(rand(5))
459-
@test x'*D*x == (x'*D)*x == (x'*Array(D))*x
460-
@test Transpose(x)*D*x == (Transpose(x)*D)*x == (Transpose(x)*Array(D))*x
467+
x = collect(1:2)
468+
xt = transpose(x)
469+
A = reshape([[1 2; 3 4], zeros(Int,2,2), zeros(Int, 2, 2), [5 6; 7 8]], 2, 2)
470+
D = Diagonal(A)
471+
@test x'*D == x'*A == copy(x')*D == copy(x')*A
472+
@test xt*D == xt*A == copy(xt)*D == copy(xt)*A
473+
y = [x, x]
474+
yt = transpose(y)
475+
@test y'*D*y == (y'*D)*y == (y'*A)*y
476+
@test yt*D*y == (yt*D)*y == (yt*A)*y
461477
end
462478

463479
@testset "Triangular division by Diagonal #27989" begin

0 commit comments

Comments
 (0)