Skip to content

Commit 7eb4d33

Browse files
dkarraschantoine-levitt
authored andcommitted
Improve performance of svd and eigen of Diagonals (JuliaLang#43856)
1 parent 4da06f7 commit 7eb4d33

File tree

2 files changed

+52
-10
lines changed

2 files changed

+52
-10
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -675,20 +675,38 @@ function eigen(D::Diagonal; permute::Bool=true, scale::Bool=true, sortby::Union{
675675
if any(!isfinite, D.diag)
676676
throw(ArgumentError("matrix contains Infs or NaNs"))
677677
end
678-
Eigen(sorteig!(eigvals(D), eigvecs(D), sortby)...)
678+
Td = Base.promote_op(/, eltype(D), eltype(D))
679+
λ = eigvals(D)
680+
if !isnothing(sortby)
681+
p = sortperm(λ; alg=QuickSort, by=sortby)
682+
λ = λ[p] # make a copy, otherwise this permutes D.diag
683+
evecs = zeros(Td, size(D))
684+
@inbounds for i in eachindex(p)
685+
evecs[p[i],i] = one(Td)
686+
end
687+
else
688+
evecs = Matrix{Td}(I, size(D))
689+
end
690+
Eigen(λ, evecs)
679691
end
680692

681693
#Singular system
682694
svdvals(D::Diagonal{<:Number}) = sort!(abs.(D.diag), rev = true)
683695
svdvals(D::Diagonal) = [svdvals(v) for v in D.diag]
684-
function svd(D::Diagonal{T}) where T<:Number
685-
S = abs.(D.diag)
686-
piv = sortperm(S, rev = true)
687-
U = Diagonal(D.diag ./ S)
688-
Up = hcat([U[:,i] for i = 1:length(D.diag)][piv]...)
689-
V = Diagonal(fill!(similar(D.diag), one(T)))
690-
Vp = hcat([V[:,i] for i = 1:length(D.diag)][piv]...)
691-
return SVD(Up, S[piv], copy(Vp'))
696+
function svd(D::Diagonal{T}) where {T<:Number}
697+
d = D.diag
698+
s = abs.(d)
699+
piv = sortperm(s, rev = true)
700+
S = s[piv]
701+
Td = typeof(oneunit(T)/oneunit(T))
702+
U = zeros(Td, size(D))
703+
Vt = copy(U)
704+
for i in 1:length(d)
705+
j = piv[i]
706+
U[j,i] = d[j] / S[i]
707+
Vt[i,j] = one(Td)
708+
end
709+
return SVD(U, S, Vt)
692710
end
693711

694712
# disambiguation methods: * and / of Diagonal and Adj/Trans AbsVec

stdlib/LinearAlgebra/test/diagonal.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,10 @@ module TestDiagonal
55
using Test, LinearAlgebra, Random
66
using LinearAlgebra: BlasFloat, BlasComplex
77

8+
const BASE_TEST_PATH = joinpath(Sys.BINDIR, "..", "share", "julia", "test")
9+
isdefined(Main, :Furlongs) || @eval Main include(joinpath($(BASE_TEST_PATH), "testhelpers", "Furlongs.jl"))
10+
using .Main.Furlongs
11+
812
n=12 #Size of matrix problem to test
913
Random.seed!(1)
1014

@@ -344,8 +348,12 @@ Random.seed!(1)
344348

345349
@testset "Eigensystem" begin
346350
eigD = eigen(D)
347-
@test Diagonal(eigD.values) D
351+
@test Diagonal(eigD.values) == D
348352
@test eigD.vectors == Matrix(I, size(D))
353+
eigsortD = eigen(D, sortby=LinearAlgebra.eigsortby)
354+
@test eigsortD.values !== D.diag
355+
@test eigsortD.values == sort(D.diag, by=LinearAlgebra.eigsortby)
356+
@test Matrix(eigsortD) == D
349357
end
350358

351359
@testset "ldiv" begin
@@ -411,6 +419,22 @@ Random.seed!(1)
411419
@test svd(D).V == V
412420
end
413421

422+
@testset "svd/eigen with Diagonal{Furlong}" begin
423+
Du = Furlong.(D)
424+
@test Du isa Diagonal{<:Furlong{1}}
425+
F = svd(Du)
426+
U, s, V = F
427+
@test map(x -> x.val, Matrix(F)) map(x -> x.val, Du)
428+
@test svdvals(Du) == s
429+
@test U isa AbstractMatrix{<:Furlong{0}}
430+
@test V isa AbstractMatrix{<:Furlong{0}}
431+
@test s isa AbstractVector{<:Furlong{1}}
432+
E = eigen(Du)
433+
vals, vecs = E
434+
@test Matrix(E) == Du
435+
@test vals isa AbstractVector{<:Furlong{1}}
436+
@test vecs isa AbstractMatrix{<:Furlong{0}}
437+
end
414438
end
415439

416440
@testset "rdiv! (#40887)" begin

0 commit comments

Comments
 (0)