Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 27 additions & 41 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,15 @@ SymTridiagonal(A::AbstractTriangular) = SymTridiagonal(Tridiagonal(A))
Tridiagonal(A::AbstractTriangular) =
isbanded(A, -1, 1) ? Tridiagonal(diag(A, -1), diag(A, 0), diag(A, 1)) : # is tridiagonal
throw(ArgumentError("matrix cannot be represented as Tridiagonal"))

UpperTriangular(A::Bidiagonal) =
A.uplo == 'U' ? UpperTriangular{eltype(A), typeof(A)}(A) :
throw(ArgumentError("matrix cannot be represented as UpperTriangular"))
LowerTriangular(A::Bidiagonal) =
A.uplo == 'L' ? LowerTriangular{eltype(A), typeof(A)}(A) :
throw(ArgumentError("matrix cannot be represented as LowerTriangular"))

const ConvertibleSpecialMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,AbstractTriangular}
const PossibleTriangularMatrix = Union{Diagonal, Bidiagonal, AbstractTriangular}

convert(T::Type{<:Diagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
convert(T::Type{<:SymTridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m : T(m)
Expand All @@ -67,6 +73,9 @@ convert(T::Type{<:Tridiagonal}, m::ConvertibleSpecialMatrix) = m isa T ? m :
convert(T::Type{<:LowerTriangular}, m::Union{LowerTriangular,UnitLowerTriangular}) = m isa T ? m : T(m)
convert(T::Type{<:UpperTriangular}, m::Union{UpperTriangular,UnitUpperTriangular}) = m isa T ? m : T(m)

convert(T::Type{<:LowerTriangular}, m::PossibleTriangularMatrix) = m isa T ? m : T(m)
convert(T::Type{<:UpperTriangular}, m::PossibleTriangularMatrix) = m isa T ? m : T(m)

# Constructs two method definitions taking into account (assumed) commutativity
# e.g. @commutative f(x::S, y::T) where {S,T} = x+y is the same is defining
# f(x::S, y::T) where {S,T} = x+y
Expand All @@ -80,51 +89,28 @@ macro commutative(myexpr)
end

for op in (:+, :-)
SpecialMatrices = [:Diagonal, :Bidiagonal, :Tridiagonal, :Matrix]
for (idx, matrixtype1) in enumerate(SpecialMatrices) # matrixtype1 is the sparser matrix type
for matrixtype2 in SpecialMatrices[idx+1:end] # matrixtype2 is the denser matrix type
@eval begin # TODO quite a few of these conversions are NOT defined
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B))
end
end
end

for matrixtype1 in (:SymTridiagonal,) # matrixtype1 is the sparser matrix type
for matrixtype2 in (:Tridiagonal, :Matrix) # matrixtype2 is the denser matrix type
@eval begin
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B))
end
end
end

for matrixtype1 in (:Diagonal, :Bidiagonal) # matrixtype1 is the sparser matrix type
for matrixtype2 in (:SymTridiagonal,) # matrixtype2 is the denser matrix type
@eval begin
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(convert(($matrixtype2), A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, convert(($matrixtype2), B))
for (matrixtype, uplo, converttype) in ((:UpperTriangular, 'U', :UpperTriangular),
(:UnitUpperTriangular, 'U', :UpperTriangular),
(:LowerTriangular, 'L', :LowerTriangular),
(:UnitLowerTriangular, 'L', :LowerTriangular))
@eval begin
function ($op)(A::$matrixtype, B::Bidiagonal)
if B.uplo == $uplo
($op)(A, convert($converttype, B))
else
($op).(A, B)
end
end
end
end

for matrixtype1 in (:Diagonal,)
for (matrixtype2,matrixtype3) in ((:UpperTriangular,:UpperTriangular),
(:UnitUpperTriangular,:UpperTriangular),
(:LowerTriangular,:LowerTriangular),
(:UnitLowerTriangular,:LowerTriangular))
@eval begin
($op)(A::($matrixtype1), B::($matrixtype2)) = ($op)(($matrixtype3)(A), B)
($op)(A::($matrixtype2), B::($matrixtype1)) = ($op)(A, ($matrixtype3)(B))
function ($op)(A::Bidiagonal, B::$matrixtype)
if A.uplo == $uplo
($op)(convert($converttype, A), B)
else
($op).(A, B)
end
end
end
end
for matrixtype in (:SymTridiagonal,:Tridiagonal,:Bidiagonal,:Matrix)
@eval begin
($op)(A::AbstractTriangular, B::($matrixtype)) = ($op)(copyto!(similar(parent(A)), A), B)
($op)(A::($matrixtype), B::AbstractTriangular) = ($op)(A, copyto!(similar(parent(B)), B))
end
end
end

rmul!(A::AbstractTriangular, adjB::Adjoint{<:Any,<:Union{QRCompactWYQ,QRPackedQ}}) =
Expand Down
18 changes: 18 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,24 @@ end
@test Matrix(convert(Spectype,A) - D) ≈ Matrix(A - D)
end
end

UpTri = UpperTriangular(rand(20,20))
LoTri = LowerTriangular(rand(20,20))
Diag = Diagonal(rand(20,20))
Tridiag = Tridiagonal(rand(20, 20))
UpBi = Bidiagonal(rand(20,20), :U)
LoBi = Bidiagonal(rand(20,20), :L)
Sym = SymTridiagonal(rand(20), rand(19))
Dense = rand(20, 20)
mats = [UpTri, LoTri, Diag, Tridiag, UpBi, LoBi, Sym, Dense]

for op in (+, -)
for A in mats
for B in mats
@test (op)(A, B) ≈ (op)(Matrix(A), Matrix(B)) ≈ Matrix((op)(A, B))
end
end
end
end

@testset "Triangular Types and QR" begin
Expand Down