Skip to content

Commit 01ae5d9

Browse files
authored
Merge pull request #10 from N5N3/N5N3-patch-1
replace all AbstractArray with AbstractVecOrMat
2 parents eb4b99a + 7e0be45 commit 01ae5d9

File tree

1 file changed

+27
-29
lines changed

1 file changed

+27
-29
lines changed

stdlib/LinearAlgebra/src/diagonal.jl

Lines changed: 27 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,11 @@ function mul!(C::AbstractMatrix, Da::Diagonal, Db::Diagonal, alpha::Number, beta
316316
return C
317317
end
318318

319+
#TODO: many of /, \ related function has no size check and singular check
320+
(/)(A::AbstractVecOrMat, D::Diagonal) =
321+
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)
319322
(/)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag ./ Db.diag)
320323

321-
ldiv!(x::AbstractArray, A::Diagonal, b::AbstractArray) = (x .= A.diag .\ b)
322-
323324
function rdiv!(A::AbstractMatrix, D::Diagonal)
324325
require_one_based_indexing(A)
325326
dd = D.diag
@@ -339,8 +340,30 @@ function rdiv!(A::AbstractMatrix, D::Diagonal)
339340
A
340341
end
341342

342-
(/)(A::AbstractArray, D::Diagonal) =
343-
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)
343+
(\)(D::Diagonal, A::AbstractMatrix) =
344+
ldiv!(D, (typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A))
345+
(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b
346+
(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag)
347+
348+
ldiv!(x::AbstractVecOrMat, A::Diagonal, b::AbstractVecOrMat) = (x .= A.diag .\ b)
349+
350+
function ldiv!(D::Diagonal, B::AbstractVecOrMat)
351+
m, n = size(B, 1), size(B, 2)
352+
if m != length(D.diag)
353+
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $m rows"))
354+
end
355+
(m == 0 || n == 0) && return B
356+
for j = 1:n
357+
for i = 1:m
358+
di = D.diag[i]
359+
if di == 0
360+
throw(SingularException(i))
361+
end
362+
B[i,j] = di \ B[i,j]
363+
end
364+
end
365+
return B
366+
end
344367

345368
# (l/r)mul!, l/rdiv!, *, / and \ Optimization for AbstractTriangular.
346369
# These functions are generally more efficient if we calculate the whole data field.
@@ -469,30 +492,6 @@ for f in (:exp, :cis, :log, :sqrt,
469492
@eval $f(D::Diagonal) = Diagonal($f.(D.diag))
470493
end
471494

472-
(\)(D::Diagonal, A::AbstractMatrix) =
473-
ldiv!(D, (typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A))
474-
475-
(\)(D::Diagonal, b::AbstractVector) = D.diag .\ b
476-
(\)(Da::Diagonal, Db::Diagonal) = Diagonal(Da.diag .\ Db.diag)
477-
478-
function ldiv!(D::Diagonal, B::AbstractVecOrMat)
479-
m, n = size(B, 1), size(B, 2)
480-
if m != length(D.diag)
481-
throw(DimensionMismatch("diagonal matrix is $(length(D.diag)) by $(length(D.diag)) but right hand side has $m rows"))
482-
end
483-
(m == 0 || n == 0) && return B
484-
for j = 1:n
485-
for i = 1:m
486-
di = D.diag[i]
487-
if di == 0
488-
throw(SingularException(i))
489-
end
490-
B[i,j] = di \ B[i,j]
491-
end
492-
end
493-
return B
494-
end
495-
496495
function inv(D::Diagonal{T}) where T
497496
Di = similar(D.diag, typeof(inv(zero(T))))
498497
for i = 1:length(D.diag)
@@ -572,7 +571,6 @@ function _mapreduce_prod(f, x, D::Diagonal, y)
572571
end
573572
end
574573

575-
576574
function cholesky!(A::Diagonal, ::Val{false} = Val(false); check::Bool = true)
577575
info = 0
578576
for (i, di) in enumerate(A.diag)

0 commit comments

Comments
 (0)