diff --git a/stdlib/LinearAlgebra/src/diagonal.jl b/stdlib/LinearAlgebra/src/diagonal.jl index 1e28d6d040ca0..d849640a351f1 100644 --- a/stdlib/LinearAlgebra/src/diagonal.jl +++ b/stdlib/LinearAlgebra/src/diagonal.jl @@ -316,10 +316,11 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta return C end +#TODO: many of /, \ related function has no size check and singular check +(/)(A::AbstractVecOrMat, D::Diagonal) = + rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D) (/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag) -ldiv!(x::AbstractArray, A::Diagonal, b::AbstractArray) = (x .= A.diag .\ b) - function rdiv!(A::AbstractMatrix, D::Diagonal) require_one_based_indexing(A) dd = D.diag @@ -339,8 +340,30 @@ function rdiv!(A::AbstractMatrix, D::Diagonal) A end -(/)(A::AbstractArray, D::Diagonal) = - rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D) +(\)(D::Diagonal, A::AbstractMatrix) = + ldiv!(D, (typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A)) +(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b +(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag) + +ldiv!(x::AbstractVecOrMat, A::Diagonal, b::AbstractVecOrMat) = (x .= A.diag .\ b) + +function ldiv!(D::Diagonal, B::AbstractVecOrMat) + m, n = size(B, 1), size(B, 2) + if m != length(D.diag) + throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $m rows")) + end + (m == 0 || n == 0) && return B + for j = 1:n + for i = 1:m + di = D.diag[i] + if di == 0 + throw(SingularException(i)) + end + B[i,j] = di \ B[i,j] + end + end + return B +end # (l/r)mul!, l/rdiv!, *, / and \ Optimization for AbstractTriangular. # These functions are generally more efficient if we calculate the whole data field. @@ -469,30 +492,6 @@ for f in (:exp, :cis, :log, :sqrt, @eval $f(D::Diagonal) = Diagonal($f.(D.diag)) end -(\)(D::Diagonal, A::AbstractMatrix) = - ldiv!(D, (typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A)) - -(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b -(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag) - -function ldiv!(D::Diagonal, B::AbstractVecOrMat) - m, n = size(B, 1), size(B, 2) - if m != length(D.diag) - throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $m rows")) - end - (m == 0 || n == 0) && return B - for j = 1:n - for i = 1:m - di = D.diag[i] - if di == 0 - throw(SingularException(i)) - end - B[i,j] = di \ B[i,j] - end - end - return B -end - function inv(D::Diagonal{T}) where T Di = similar(D.diag, typeof(inv(zero(T)))) for i = 1:length(D.diag) @@ -572,7 +571,6 @@ function _mapreduce_prod(f, x, D::Diagonal, y) end end - function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true) info = 0 for (i, di) in enumerate(A.diag)