Skip to content

Commit 4307da6

Browse files
committed
refactored to purge name space of SparseArrays
1 parent 70c8b73 commit 4307da6

File tree

1 file changed

+88
-93
lines changed

1 file changed

+88
-93
lines changed

stdlib/SparseArrays/src/linalg.jl

Lines changed: 88 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -235,20 +235,11 @@ function dot(A::SparseMatrixCSC{T1,S1},B::SparseMatrixCSC{T2,S2}) where {T1,T2,S
235235
end
236236

237237
## triangular sparse handling
238-
abstract type UnitDiagonal end
239-
struct UnitDiagonalYes <:UnitDiagonal end
240-
struct UnitDiagonalNo <:UnitDiagonal end
241-
abstract type AdjointElement end
242-
struct AdjointElementYes <:AdjointElement end
243-
struct AdjointElementNo <:AdjointElement end
244-
245-
possible_adjoint(adj::AdjointElementYes, a ) = adjoint(a)
246-
possible_adjoint(adj::AdjointElementNo, a ) = a
247-
AdjointElement(::Adjoint) = AdjointElementYes()
248-
AdjointElement(::Any) = AdjointElementNo()
249-
UnitDiagonal(::UnitUpperTriangular) = UnitDiagonalYes()
250-
UnitDiagonal(::UnitLowerTriangular) = UnitDiagonalYes()
251-
UnitDiagonal(::Any) = UnitDiagonalNo()
238+
239+
possible_adjoint(adj::Bool, a::Real ) = a
240+
possible_adjoint(adj::Bool, a ) = adj ? adjoint(a) : a
241+
242+
const UnitDiagonalTriangular = Union{UnitUpperTriangular,UnitLowerTriangular}
252243

253244
const LowerTriangularPlain{T} = Union{
254245
LowerTriangular{T,<:SparseMatrixCSCUnion{T}},
@@ -280,8 +271,21 @@ const TriangularSparse{T} = Union{
280271
LowerTriangularSparse{T}, UpperTriangularSparse{T}} where T
281272

282273
## triangular multipliers
283-
function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal)
284-
# forward substitution for UpperTriangular SparseCSC matrices
274+
function lmul!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
275+
@assert !has_offset_axes(A, B)
276+
nrowB, ncolB = size(B, 1), size(B, 2)
277+
ncol = LinearAlgebra.checksquare(A)
278+
if nrowB != ncol
279+
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
280+
end
281+
_lmul!(A, B)
282+
end
283+
284+
# forward multiplication for UpperTriangular SparseCSC matrices
285+
function _lmul!(U::UpperTriangularPlain, B::StridedVecOrMat)
286+
A = U.data
287+
unit = U isa UnitDiagonalTriangular
288+
285289
nrowB, ncolB = size(B, 1), size(B, 2)
286290
aa = getnzval(A)
287291
ja = getrowval(A)
@@ -292,7 +296,7 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
292296
for j = 1:nrowB
293297
i1 = ia[j]
294298
i2 = ia[j + 1] - 1
295-
done = unit isa UnitDiagonalYes
299+
done = unit
296300

297301
bj = B[joff + j]
298302
for ii = i1:i2
@@ -301,7 +305,7 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
301305
if jai < j
302306
B[joff + jai] += aii * bj
303307
elseif jai == j
304-
if unit isa UnitDiagonalNo
308+
if !unit
305309
B[joff + j] *= aii
306310
done = true
307311
end
@@ -318,8 +322,11 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
318322
B
319323
end
320324

321-
function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal)
322-
# backward substitution for LowerTriangular SparseCSC matrices
325+
# backward multiplication for LowerTriangular SparseCSC matrices
326+
function _lmul!(L::LowerTriangularPlain, B::StridedVecOrMat)
327+
A = L.data
328+
unit = L isa UnitDiagonalTriangular
329+
323330
nrowB, ncolB = size(B, 1), size(B, 2)
324331
aa = getnzval(A)
325332
ja = getrowval(A)
@@ -330,7 +337,7 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
330337
for j = nrowB:-1:1
331338
i1 = ia[j]
332339
i2 = ia[j + 1] - 1
333-
done = unit isa UnitDiagonalYes
340+
done = unit
334341

335342
bj = B[joff + j]
336343
for ii = i2:-1:i1
@@ -339,7 +346,7 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
339346
if jai > j
340347
B[joff + jai] += aii * bj
341348
elseif jai == j
342-
if unit isa UnitDiagonalNo
349+
if !unit
343350
B[joff + j] *= aii
344351
done = true
345352
end
@@ -356,8 +363,12 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
356363
B
357364
end
358365

359-
function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement)
360-
# forward substitution for adjoint and transpose of LowerTriangular CSC matrices
366+
# forward multiplication for adjoint and transpose of LowerTriangular CSC matrices
367+
function _lmul!(U::UpperTriangularWrapped, B::StridedVecOrMat)
368+
A = U.parent.data
369+
unit = U.parent isa UnitDiagonalTriangular
370+
adj = U isa Adjoint
371+
361372
nrowB, ncolB = size(B, 1), size(B, 2)
362373
aa = getnzval(A)
363374
ja = getrowval(A)
@@ -370,7 +381,7 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
370381
i1 = ia[j]
371382
i2 = ia[j + 1] - 1
372383
akku = Z
373-
j0 = unit isa UnitDiagonalNo ? j : j + 1
384+
j0 = !unit ? j : j + 1
374385

375386
# loop through column j of A - only structural non-zeros
376387
for ii = i2:-1:i1
@@ -382,7 +393,7 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
382393
break
383394
end
384395
end
385-
if unit isa UnitDiagonalYes
396+
if unit
386397
akku += B[joff + j]
387398
end
388399
B[joff + j] = akku
@@ -392,8 +403,12 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
392403
B
393404
end
394405

395-
function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement)
396-
# multiply with adjoint and transpose of LowerTriangular CSC matrices
406+
# backward multiplication with adjoint and transpose of LowerTriangular CSC matrices
407+
function _lmul!(L::LowerTriangularWrapped, B::StridedVecOrMat)
408+
A = L.parent.data
409+
unit = L.parent isa UnitDiagonalTriangular
410+
adj = L isa Adjoint
411+
397412
nrowB, ncolB = size(B, 1), size(B, 2)
398413
aa = getnzval(A)
399414
ja = getrowval(A)
@@ -406,7 +421,7 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
406421
i1 = ia[j]
407422
i2 = ia[j + 1] - 1
408423
akku = Z
409-
j0 = unit isa UnitDiagonalNo ? j : j - 1
424+
j0 = !unit ? j : j - 1
410425

411426
# loop through column j of A - only structural non-zeros
412427
for ii = i1:i2
@@ -418,7 +433,7 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
418433
break
419434
end
420435
end
421-
if unit isa UnitDiagonalYes
436+
if unit
422437
akku += B[joff + j]
423438
end
424439
B[joff + j] = akku
@@ -428,31 +443,22 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
428443
B
429444
end
430445

431-
function lmul!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
446+
## triangular solvers
447+
function ldiv!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
432448
@assert !has_offset_axes(A, B)
433449
nrowB, ncolB = size(B, 1), size(B, 2)
434450
ncol = LinearAlgebra.checksquare(A)
435451
if nrowB != ncol
436452
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
437453
end
438-
_lmul!(A, B)
454+
_ldiv!(A, B)
439455
end
440456

441-
_lmul!(L::LowerTriangularPlain, B::StridedVecOrMat) =
442-
bwdTriMul!(L.data, B, UnitDiagonal(L))
443-
444-
_lmul!(L::LowerTriangularWrapped, B::StridedVecOrMat) =
445-
_bwdTriMul!(L.parent.data, B, UnitDiagonal(L.parent), AdjointElement(L))
446-
447-
_lmul!(U::UpperTriangularPlain, B::StridedVecOrMat) =
448-
fwdTriMul!(U.data, B, UnitDiagonal(U))
449-
450-
_lmul!(U::UpperTriangularWrapped, B::StridedVecOrMat) =
451-
_fwdTriMul!(U.parent.data, B, UnitDiagonal(U.parent), AdjointElement(U))
452-
453-
## triangular solvers
454-
function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal)
455457
# forward substitution for LowerTriangular CSC matrices
458+
function _ldiv!(L::LowerTriangularPlain, B::StridedVecOrMat)
459+
A = L.data
460+
unit = L isa UnitDiagonalTriangular
461+
456462
nrowB, ncolB = size(B, 1), size(B, 2)
457463
aa = getnzval(A)
458464
ja = getrowval(A)
@@ -472,12 +478,12 @@ function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
472478
bj = B[joff + j]
473479
# check for zero pivot and divide with pivot
474480
if jai == j
475-
if unit isa UnitDiagonalNo
481+
if !unit
476482
bj /= aa[ii]
477483
B[joff + j] = bj
478484
end
479485
ii += 1
480-
elseif unit isa UnitDiagonalNo
486+
elseif !unit
481487
throw(LinearAlgebra.SingularException(j))
482488
end
483489

@@ -491,8 +497,11 @@ function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
491497
B
492498
end
493499

494-
function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal)
495500
# backward substitution for UpperTriangular CSC matrices
501+
function _ldiv!(U::UpperTriangularPlain, B::StridedVecOrMat)
502+
A = U.data
503+
unit = U isa UnitDiagonalTriangular
504+
496505
nrowB, ncolB = size(B, 1), size(B, 2)
497506
aa = getnzval(A)
498507
ja = getrowval(A)
@@ -512,12 +521,12 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
512521
bj = B[joff + j]
513522
# check for zero pivot and divide with pivot
514523
if jai == j
515-
if unit isa UnitDiagonalNo
524+
if !unit
516525
bj /= aa[ii]
517526
B[joff + j] = bj
518527
end
519528
ii -= 1
520-
elseif unit isa UnitDiagonalNo
529+
elseif !unit
521530
throw(LinearAlgebra.SingularException(j))
522531
end
523532

@@ -531,8 +540,12 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
531540
B
532541
end
533542

534-
function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement)
535543
# forward substitution for adjoint and transpose of UpperTriangular CSC matrices
544+
function _ldiv!(L::LowerTriangularWrapped, B::StridedVecOrMat)
545+
A = L.parent.data
546+
unit = L.parent isa UnitDiagonalTriangular
547+
adj = L isa Adjoint
548+
536549
nrowB, ncolB = size(B, 1), size(B, 2)
537550
aa = getnzval(A)
538551
ja = getrowval(A)
@@ -547,14 +560,14 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
547560
done = false
548561

549562
# loop through column j of A - only structural non-zeros
550-
for ip = i1:i2
551-
i = ja[ip]
552-
if i < j
553-
aai = possible_adjoint(adj, aa[ip])
554-
akku -= B[joff + i] * aai
555-
elseif i == j
556-
if unit isa UnitDiagonalNo
557-
aai = possible_adjoint(adj, aa[ip])
563+
for ii = i1:i2
564+
jai = ja[ii]
565+
if jai < j
566+
aai = possible_adjoint(adj, aa[ii])
567+
akku -= B[joff + jai] * aai
568+
elseif jai == j
569+
if !unit
570+
aai = possible_adjoint(adj, aa[ii])
558571
akku /= aai
559572
end
560573
done = true
@@ -563,7 +576,7 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
563576
break
564577
end
565578
end
566-
if !done && unit isa UnitDiagonalNo
579+
if !done && !unit
567580
throw(LinearAlgebra.SingularException(j))
568581
end
569582
B[joff + j] = akku
@@ -573,8 +586,12 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
573586
B
574587
end
575588

576-
function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiagonal, adj::AdjointElement)
577589
# backward substitution for adjoint and transpose of LowerTriangular CSC matrices
590+
function _ldiv!(U::UpperTriangularWrapped, B::StridedVecOrMat)
591+
A = U.parent.data
592+
unit = U.parent isa UnitDiagonalTriangular
593+
adj = U isa Adjoint
594+
578595
nrowB, ncolB = size(B, 1), size(B, 2)
579596
aa = getnzval(A)
580597
ja = getrowval(A)
@@ -589,14 +606,14 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
589606
done = false
590607

591608
# loop through column j of A - only structural non-zeros
592-
for ip = i2:-1:i1
593-
i = ja[ip]
594-
if i > j
595-
aai = possible_adjoint(adj, aa[ip])
596-
akku -= B[joff + i] * aai
597-
elseif i == j
598-
if unit isa UnitDiagonalNo
599-
aai = possible_adjoint(adj, aa[ip])
609+
for ii = i2:-1:i1
610+
jai = ja[ii]
611+
if jai > j
612+
aai = possible_adjoint(adj, aa[ii])
613+
akku -= B[joff + jai] * aai
614+
elseif jai == j
615+
if !unit
616+
aai = possible_adjoint(adj, aa[ii])
600617
akku /= aai
601618
end
602619
done = true
@@ -605,7 +622,7 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
605622
break
606623
end
607624
end
608-
if !done && unit isa UnitDiagonalNo
625+
if !done && !unit
609626
throw(LinearAlgebra.SingularException(j))
610627
end
611628
B[joff + j] = akku
@@ -615,28 +632,6 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
615632
B
616633
end
617634

618-
function ldiv!(A::TriangularSparse{T}, B::StridedVecOrMat{T}) where T
619-
@assert !has_offset_axes(A, B)
620-
nrowB, ncolB = size(B, 1), size(B, 2)
621-
ncol = LinearAlgebra.checksquare(A)
622-
if nrowB != ncol
623-
throw(DimensionMismatch("A is $(ncol) columns and B has $(nrowB) rows"))
624-
end
625-
_ldiv!(A, B)
626-
end
627-
628-
_ldiv!(L::LowerTriangularPlain, B::StridedVecOrMat) =
629-
fwdTriSolve!(L.data, B, UnitDiagonal(L))
630-
631-
_ldiv!(L::LowerTriangularWrapped, B::StridedVecOrMat) =
632-
_fwdTriSolve!(L.parent.data, B, UnitDiagonal(L.parent), AdjointElement(L))
633-
634-
_ldiv!(U::UpperTriangularPlain, B::StridedVecOrMat) =
635-
bwdTriSolve!(U.data, B, UnitDiagonal(U))
636-
637-
_ldiv!(U::UpperTriangularWrapped, B::StridedVecOrMat) =
638-
_bwdTriSolve!(U.parent.data, B, UnitDiagonal(U.parent), AdjointElement(U))
639-
640635
(\)(L::TriangularSparse, B::SparseMatrixCSC) = ldiv!(L, Array(B))
641636
(*)(L::TriangularSparse, B::SparseMatrixCSC) = lmul!(L, Array(B))
642637

0 commit comments

Comments
 (0)