diff --git a/stdlib/LinearAlgebra/src/symmetric.jl b/stdlib/LinearAlgebra/src/symmetric.jl index 38035e09b4375..b4b5a8825da8c 100644 --- a/stdlib/LinearAlgebra/src/symmetric.jl +++ b/stdlib/LinearAlgebra/src/symmetric.jl @@ -269,10 +269,34 @@ end end end +_conjugation(::Symmetric) = transpose +_conjugation(::Hermitian) = adjoint + diag(A::Symmetric) = symmetric.(diag(parent(A)), sym_uplo(A.uplo)) diag(A::Hermitian) = hermitian.(diag(parent(A)), sym_uplo(A.uplo)) -isdiag(A::HermOrSym) = isdiag(A.uplo == 'U' ? UpperTriangular(A.data) : LowerTriangular(A.data)) +function applytri(f, A::HermOrSym) + if A.uplo == 'U' + f(UpperTriangular(A.data)) + else + f(LowerTriangular(A.data)) + end +end + +function applytri(f, A::HermOrSym, B::HermOrSym) + if A.uplo == B.uplo == 'U' + f(UpperTriangular(A.data), UpperTriangular(B.data)) + elseif A.uplo == B.uplo == 'L' + f(LowerTriangular(A.data), LowerTriangular(B.data)) + elseif A.uplo == 'U' + f(UpperTriangular(A.data), UpperTriangular(_conjugation(B)(B.data))) + else # A.uplo == 'L' + f(UpperTriangular(_conjugation(A)(A.data)), UpperTriangular(B.data)) + end +end +parentof_applytri(f, args...) = applytri(parent ∘ f, args...) + +isdiag(A::HermOrSym) = applytri(isdiag, A) # For A<:Union{Symmetric,Hermitian}, similar(A[, neweltype]) should yield a matrix with the same # symmetry type, uplo flag, and underlying storage type as A. The following methods cover these cases. @@ -314,8 +338,8 @@ Hermitian{T,S}(A::Hermitian) where {T,S<:AbstractMatrix{T}} = Hermitian{T,S}(con AbstractMatrix{T}(A::Hermitian) where {T} = Hermitian(convert(AbstractMatrix{T}, A.data), sym_uplo(A.uplo)) AbstractMatrix{T}(A::Hermitian{T}) where {T} = copy(A) -copy(A::Symmetric{T,S}) where {T,S} = (B = copy(A.data); Symmetric{T,typeof(B)}(B,A.uplo)) -copy(A::Hermitian{T,S}) where {T,S} = (B = copy(A.data); Hermitian{T,typeof(B)}(B,A.uplo)) +copy(A::Symmetric) = (Symmetric(parentof_applytri(copy, A), sym_uplo(A.uplo))) +copy(A::Hermitian) = (Hermitian(parentof_applytri(copy, A), sym_uplo(A.uplo))) function copyto!(dest::Symmetric, src::Symmetric) if src.uplo == dest.uplo @@ -389,9 +413,9 @@ transpose(A::Hermitian) = Transpose(A) real(A::Symmetric{<:Real}) = A real(A::Hermitian{<:Real}) = A -real(A::Symmetric) = Symmetric(real(A.data), sym_uplo(A.uplo)) -real(A::Hermitian) = Hermitian(real(A.data), sym_uplo(A.uplo)) -imag(A::Symmetric) = Symmetric(imag(A.data), sym_uplo(A.uplo)) +real(A::Symmetric) = Symmetric(parentof_applytri(real, A), sym_uplo(A.uplo)) +real(A::Hermitian) = Hermitian(parentof_applytri(real, A), sym_uplo(A.uplo)) +imag(A::Symmetric) = Symmetric(parentof_applytri(imag, A), sym_uplo(A.uplo)) Base.copy(A::Adjoint{<:Any,<:Symmetric}) = Symmetric(copy(adjoint(A.parent.data)), ifelse(A.parent.uplo == 'U', :L, :U)) @@ -401,8 +425,9 @@ Base.copy(A::Transpose{<:Any,<:Hermitian}) = tr(A::Symmetric) = tr(A.data) # to avoid AbstractMatrix fallback (incl. allocations) tr(A::Hermitian) = real(tr(A.data)) -Base.conj(A::HermOrSym) = typeof(A)(conj(A.data), A.uplo) -Base.conj!(A::HermOrSym) = typeof(A)(conj!(A.data), A.uplo) +Base.conj(A::Symmetric) = Symmetric(parentof_applytri(conj, A), sym_uplo(A.uplo)) +Base.conj(A::Hermitian) = Hermitian(parentof_applytri(conj, A), sym_uplo(A.uplo)) +Base.conj!(A::HermOrSym) = typeof(A)(parentof_applytri(conj!, A), A.uplo) # tril/triu function tril(A::Hermitian, k::Integer=0) @@ -496,21 +521,14 @@ for (T, trans, real) in [(:Symmetric, :transpose, :identity), (:(Hermitian{<:Uni end end -(-)(A::Symmetric) = Symmetric(-A.data, sym_uplo(A.uplo)) -(-)(A::Hermitian) = Hermitian(-A.data, sym_uplo(A.uplo)) +(-)(A::Symmetric) = Symmetric(parentof_applytri(-, A), sym_uplo(A.uplo)) +(-)(A::Hermitian) = Hermitian(parentof_applytri(-, A), sym_uplo(A.uplo)) ## Addition/subtraction -for f ∈ (:+, :-), (Wrapper, conjugation) ∈ ((:Hermitian, :adjoint), (:Symmetric, :transpose)) - @eval begin - function $f(A::$Wrapper, B::$Wrapper) - if A.uplo == B.uplo - return $Wrapper($f(parent(A), parent(B)), sym_uplo(A.uplo)) - elseif A.uplo == 'U' - return $Wrapper($f(parent(A), $conjugation(parent(B))), :U) - else - return $Wrapper($f($conjugation(parent(A)), parent(B)), :U) - end - end +for f ∈ (:+, :-), Wrapper ∈ (:Hermitian, :Symmetric) + @eval function $f(A::$Wrapper, B::$Wrapper) + uplo = A.uplo == B.uplo ? sym_uplo(A.uplo) : (:U) + $Wrapper(parentof_applytri($f, A, B), uplo) end end @@ -555,12 +573,12 @@ function dot(x::AbstractVector, A::RealHermSymComplexHerm, y::AbstractVector) end # Scaling with Number -*(A::Symmetric, x::Number) = Symmetric(A.data*x, sym_uplo(A.uplo)) -*(x::Number, A::Symmetric) = Symmetric(x*A.data, sym_uplo(A.uplo)) -*(A::Hermitian, x::Real) = Hermitian(A.data*x, sym_uplo(A.uplo)) -*(x::Real, A::Hermitian) = Hermitian(x*A.data, sym_uplo(A.uplo)) -/(A::Symmetric, x::Number) = Symmetric(A.data/x, sym_uplo(A.uplo)) -/(A::Hermitian, x::Real) = Hermitian(A.data/x, sym_uplo(A.uplo)) +*(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo)) +*(x::Number, A::Symmetric) = Symmetric(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo)) +*(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y * x, A), sym_uplo(A.uplo)) +*(x::Real, A::Hermitian) = Hermitian(parentof_applytri(y -> x * y, A), sym_uplo(A.uplo)) +/(A::Symmetric, x::Number) = Symmetric(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo)) +/(A::Hermitian, x::Real) = Hermitian(parentof_applytri(y -> y/x, A), sym_uplo(A.uplo)) factorize(A::HermOrSym) = _factorize(A) function _factorize(A::HermOrSym{T}; check::Bool=true) where T diff --git a/stdlib/LinearAlgebra/test/symmetric.jl b/stdlib/LinearAlgebra/test/symmetric.jl index b3e0b7b560e7a..d3b24ccf78b0b 100644 --- a/stdlib/LinearAlgebra/test/symmetric.jl +++ b/stdlib/LinearAlgebra/test/symmetric.jl @@ -470,6 +470,42 @@ end end end +@testset "non-isbits algebra" begin + for ST in (Symmetric, Hermitian), uplo in (:L, :U) + M = Matrix{Complex{BigFloat}}(undef,2,2) + M[1,1] = rand() + M[2,2] = rand() + M[1+(uplo==:L), 1+(uplo==:U)] = rand(ComplexF64) + S = ST(M, uplo) + MS = Matrix(S) + @test real(S) == real(MS) + @test imag(S) == imag(MS) + @test conj(S) == conj(MS) + @test conj!(copy(S)) == conj(MS) + @test -S == -MS + @test S + S == MS + MS + @test S - S == MS - MS + @test S*2 == 2*S == 2*MS + @test S/2 == MS/2 + end + @testset "mixed uplo" begin + Mu = Matrix{Complex{BigFloat}}(undef,2,2) + Mu[1,1] = Mu[2,2] = 3 + Mu[1,2] = 2 + 3im + Ml = Matrix{Complex{BigFloat}}(undef,2,2) + Ml[1,1] = Ml[2,2] = 4 + Ml[2,1] = 4 + 5im + for ST in (Symmetric, Hermitian) + Su = ST(Mu, :U) + MSu = Matrix(Su) + Sl = ST(Ml, :L) + MSl = Matrix(Sl) + @test Su + Sl == Sl + Su == MSu + MSl + @test Su - Sl == -(Sl - Su) == MSu - MSl + end + end +end + # bug identified in PR #52318: dot products of quaternionic Hermitian matrices, # or any number type where conj(a)*conj(b) ≠ conj(a*b): @testset "dot Hermitian quaternion #52318" begin @@ -932,4 +968,11 @@ end end end +@testset "conj for immutable" begin + S = Symmetric(reshape((1:16)*im, 4, 4)) + @test conj(S) == conj(Array(S)) + H = Hermitian(reshape((1:16)*im, 4, 4)) + @test conj(H) == conj(Array(H)) +end + end # module TestSymmetric