@@ -235,20 +235,11 @@ function dot(A::SparseMatrixCSC{T1,S1},B::SparseMatrixCSC{T2,S2}) where {T1,T2,S
235235end
236236
237237# # triangular sparse handling
238- abstract type UnitDiagonal end
239- struct UnitDiagonalYes <: UnitDiagonal end
240- struct UnitDiagonalNo <: UnitDiagonal end
241- abstract type AdjointElement end
242- struct AdjointElementYes <: AdjointElement end
243- struct AdjointElementNo <: AdjointElement end
244-
245- possible_adjoint (adj:: AdjointElementYes , a ) = adjoint (a)
246- possible_adjoint (adj:: AdjointElementNo , a ) = a
247- AdjointElement (:: Adjoint ) = AdjointElementYes ()
248- AdjointElement (:: Any ) = AdjointElementNo ()
249- UnitDiagonal (:: UnitUpperTriangular ) = UnitDiagonalYes ()
250- UnitDiagonal (:: UnitLowerTriangular ) = UnitDiagonalYes ()
251- UnitDiagonal (:: Any ) = UnitDiagonalNo ()
238+
239+ possible_adjoint (adj:: Bool , a:: Real ) = a
240+ possible_adjoint (adj:: Bool , a ) = adj ? adjoint (a) : a
241+
242+ const UnitDiagonalTriangular = Union{UnitUpperTriangular,UnitLowerTriangular}
252243
253244const LowerTriangularPlain{T} = Union{
254245 LowerTriangular{T,<: SparseMatrixCSCUnion{T} },
@@ -280,8 +271,21 @@ const TriangularSparse{T} = Union{
280271 LowerTriangularSparse{T}, UpperTriangularSparse{T}} where T
281272
282273# # triangular multipliers
283- function fwdTriMul! (A:: SparseMatrixCSCUnion , B:: AbstractVecOrMat , unit:: UnitDiagonal )
284- # forward substitution for UpperTriangular SparseCSC matrices
274+ function lmul! (A:: TriangularSparse{T} , B:: StridedVecOrMat{T} ) where T
275+ @assert ! has_offset_axes (A, B)
276+ nrowB, ncolB = size (B, 1 ), size (B, 2 )
277+ ncol = LinearAlgebra. checksquare (A)
278+ if nrowB != ncol
279+ throw (DimensionMismatch (" A is $(ncol) columns and B has $(nrowB) rows" ))
280+ end
281+ _lmul! (A, B)
282+ end
283+
284+ # forward multiplication for UpperTriangular SparseCSC matrices
285+ function _lmul! (U:: UpperTriangularPlain , B:: StridedVecOrMat )
286+ A = U. data
287+ unit = U isa UnitDiagonalTriangular
288+
285289 nrowB, ncolB = size (B, 1 ), size (B, 2 )
286290 aa = getnzval (A)
287291 ja = getrowval (A)
@@ -292,7 +296,7 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
292296 for j = 1 : nrowB
293297 i1 = ia[j]
294298 i2 = ia[j + 1 ] - 1
295- done = unit isa UnitDiagonalYes
299+ done = unit
296300
297301 bj = B[joff + j]
298302 for ii = i1: i2
@@ -301,7 +305,7 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
301305 if jai < j
302306 B[joff + jai] += aii * bj
303307 elseif jai == j
304- if unit isa UnitDiagonalNo
308+ if ! unit
305309 B[joff + j] *= aii
306310 done = true
307311 end
@@ -318,8 +322,11 @@ function fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
318322 B
319323end
320324
321- function bwdTriMul! (A:: SparseMatrixCSCUnion , B:: AbstractVecOrMat , unit:: UnitDiagonal )
322- # backward substitution for LowerTriangular SparseCSC matrices
325+ # backward multiplication for LowerTriangular SparseCSC matrices
326+ function _lmul! (L:: LowerTriangularPlain , B:: StridedVecOrMat )
327+ A = L. data
328+ unit = L isa UnitDiagonalTriangular
329+
323330 nrowB, ncolB = size (B, 1 ), size (B, 2 )
324331 aa = getnzval (A)
325332 ja = getrowval (A)
@@ -330,7 +337,7 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
330337 for j = nrowB: - 1 : 1
331338 i1 = ia[j]
332339 i2 = ia[j + 1 ] - 1
333- done = unit isa UnitDiagonalYes
340+ done = unit
334341
335342 bj = B[joff + j]
336343 for ii = i2: - 1 : i1
@@ -339,7 +346,7 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
339346 if jai > j
340347 B[joff + jai] += aii * bj
341348 elseif jai == j
342- if unit isa UnitDiagonalNo
349+ if ! unit
343350 B[joff + j] *= aii
344351 done = true
345352 end
@@ -356,8 +363,12 @@ function bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDiag
356363 B
357364end
358365
359- function _fwdTriMul! (A:: SparseMatrixCSCUnion , B:: AbstractVecOrMat , unit:: UnitDiagonal , adj:: AdjointElement )
360- # forward substitution for adjoint and transpose of LowerTriangular CSC matrices
366+ # forward multiplication for adjoint and transpose of LowerTriangular CSC matrices
367+ function _lmul! (U:: UpperTriangularWrapped , B:: StridedVecOrMat )
368+ A = U. parent. data
369+ unit = U. parent isa UnitDiagonalTriangular
370+ adj = U isa Adjoint
371+
361372 nrowB, ncolB = size (B, 1 ), size (B, 2 )
362373 aa = getnzval (A)
363374 ja = getrowval (A)
@@ -370,7 +381,7 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
370381 i1 = ia[j]
371382 i2 = ia[j + 1 ] - 1
372383 akku = Z
373- j0 = unit isa UnitDiagonalNo ? j : j + 1
384+ j0 = ! unit ? j : j + 1
374385
375386 # loop through column j of A - only structural non-zeros
376387 for ii = i2: - 1 : i1
@@ -382,7 +393,7 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
382393 break
383394 end
384395 end
385- if unit isa UnitDiagonalYes
396+ if unit
386397 akku += B[joff + j]
387398 end
388399 B[joff + j] = akku
@@ -392,8 +403,12 @@ function _fwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
392403 B
393404end
394405
395- function _bwdTriMul! (A:: SparseMatrixCSCUnion , B:: AbstractVecOrMat , unit:: UnitDiagonal , adj:: AdjointElement )
396- # multiply with adjoint and transpose of LowerTriangular CSC matrices
406+ # backward multiplication with adjoint and transpose of LowerTriangular CSC matrices
407+ function _lmul! (L:: LowerTriangularWrapped , B:: StridedVecOrMat )
408+ A = L. parent. data
409+ unit = L. parent isa UnitDiagonalTriangular
410+ adj = L isa Adjoint
411+
397412 nrowB, ncolB = size (B, 1 ), size (B, 2 )
398413 aa = getnzval (A)
399414 ja = getrowval (A)
@@ -406,7 +421,7 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
406421 i1 = ia[j]
407422 i2 = ia[j + 1 ] - 1
408423 akku = Z
409- j0 = unit isa UnitDiagonalNo ? j : j - 1
424+ j0 = ! unit ? j : j - 1
410425
411426 # loop through column j of A - only structural non-zeros
412427 for ii = i1: i2
@@ -418,7 +433,7 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
418433 break
419434 end
420435 end
421- if unit isa UnitDiagonalYes
436+ if unit
422437 akku += B[joff + j]
423438 end
424439 B[joff + j] = akku
@@ -428,31 +443,22 @@ function _bwdTriMul!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDia
428443 B
429444end
430445
431- function lmul! (A:: TriangularSparse{T} , B:: StridedVecOrMat{T} ) where T
446+ # # triangular solvers
447+ function ldiv! (A:: TriangularSparse{T} , B:: StridedVecOrMat{T} ) where T
432448 @assert ! has_offset_axes (A, B)
433449 nrowB, ncolB = size (B, 1 ), size (B, 2 )
434450 ncol = LinearAlgebra. checksquare (A)
435451 if nrowB != ncol
436452 throw (DimensionMismatch (" A is $(ncol) columns and B has $(nrowB) rows" ))
437453 end
438- _lmul ! (A, B)
454+ _ldiv ! (A, B)
439455end
440456
441- _lmul! (L:: LowerTriangularPlain , B:: StridedVecOrMat ) =
442- bwdTriMul! (L. data, B, UnitDiagonal (L))
443-
444- _lmul! (L:: LowerTriangularWrapped , B:: StridedVecOrMat ) =
445- _bwdTriMul! (L. parent. data, B, UnitDiagonal (L. parent), AdjointElement (L))
446-
447- _lmul! (U:: UpperTriangularPlain , B:: StridedVecOrMat ) =
448- fwdTriMul! (U. data, B, UnitDiagonal (U))
449-
450- _lmul! (U:: UpperTriangularWrapped , B:: StridedVecOrMat ) =
451- _fwdTriMul! (U. parent. data, B, UnitDiagonal (U. parent), AdjointElement (U))
452-
453- # # triangular solvers
454- function fwdTriSolve! (A:: SparseMatrixCSCUnion , B:: AbstractVecOrMat , unit:: UnitDiagonal )
455457# forward substitution for LowerTriangular CSC matrices
458+ function _ldiv! (L:: LowerTriangularPlain , B:: StridedVecOrMat )
459+ A = L. data
460+ unit = L isa UnitDiagonalTriangular
461+
456462 nrowB, ncolB = size (B, 1 ), size (B, 2 )
457463 aa = getnzval (A)
458464 ja = getrowval (A)
@@ -472,12 +478,12 @@ function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
472478 bj = B[joff + j]
473479 # check for zero pivot and divide with pivot
474480 if jai == j
475- if unit isa UnitDiagonalNo
481+ if ! unit
476482 bj /= aa[ii]
477483 B[joff + j] = bj
478484 end
479485 ii += 1
480- elseif unit isa UnitDiagonalNo
486+ elseif ! unit
481487 throw (LinearAlgebra. SingularException (j))
482488 end
483489
@@ -491,8 +497,11 @@ function fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
491497 B
492498end
493499
494- function bwdTriSolve! (A:: SparseMatrixCSCUnion , B:: AbstractVecOrMat , unit:: UnitDiagonal )
495500# backward substitution for UpperTriangular CSC matrices
501+ function _ldiv! (U:: UpperTriangularPlain , B:: StridedVecOrMat )
502+ A = U. data
503+ unit = U isa UnitDiagonalTriangular
504+
496505 nrowB, ncolB = size (B, 1 ), size (B, 2 )
497506 aa = getnzval (A)
498507 ja = getrowval (A)
@@ -512,12 +521,12 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
512521 bj = B[joff + j]
513522 # check for zero pivot and divide with pivot
514523 if jai == j
515- if unit isa UnitDiagonalNo
524+ if ! unit
516525 bj /= aa[ii]
517526 B[joff + j] = bj
518527 end
519528 ii -= 1
520- elseif unit isa UnitDiagonalNo
529+ elseif ! unit
521530 throw (LinearAlgebra. SingularException (j))
522531 end
523532
@@ -531,8 +540,12 @@ function bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitDi
531540 B
532541end
533542
534- function _fwdTriSolve! (A:: SparseMatrixCSCUnion , B:: AbstractVecOrMat , unit:: UnitDiagonal , adj:: AdjointElement )
535543# forward substitution for adjoint and transpose of UpperTriangular CSC matrices
544+ function _ldiv! (L:: LowerTriangularWrapped , B:: StridedVecOrMat )
545+ A = L. parent. data
546+ unit = L. parent isa UnitDiagonalTriangular
547+ adj = L isa Adjoint
548+
536549 nrowB, ncolB = size (B, 1 ), size (B, 2 )
537550 aa = getnzval (A)
538551 ja = getrowval (A)
@@ -547,14 +560,14 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
547560 done = false
548561
549562 # loop through column j of A - only structural non-zeros
550- for ip = i1: i2
551- i = ja[ip ]
552- if i < j
553- aai = possible_adjoint (adj, aa[ip ])
554- akku -= B[joff + i ] * aai
555- elseif i == j
556- if unit isa UnitDiagonalNo
557- aai = possible_adjoint (adj, aa[ip ])
563+ for ii = i1: i2
564+ jai = ja[ii ]
565+ if jai < j
566+ aai = possible_adjoint (adj, aa[ii ])
567+ akku -= B[joff + jai ] * aai
568+ elseif jai == j
569+ if ! unit
570+ aai = possible_adjoint (adj, aa[ii ])
558571 akku /= aai
559572 end
560573 done = true
@@ -563,7 +576,7 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
563576 break
564577 end
565578 end
566- if ! done && unit isa UnitDiagonalNo
579+ if ! done && ! unit
567580 throw (LinearAlgebra. SingularException (j))
568581 end
569582 B[joff + j] = akku
@@ -573,8 +586,12 @@ function _fwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
573586 B
574587end
575588
576- function _bwdTriSolve! (A:: SparseMatrixCSCUnion , B:: AbstractVecOrMat , unit:: UnitDiagonal , adj:: AdjointElement )
577589# backward substitution for adjoint and transpose of LowerTriangular CSC matrices
590+ function _ldiv! (U:: UpperTriangularWrapped , B:: StridedVecOrMat )
591+ A = U. parent. data
592+ unit = U. parent isa UnitDiagonalTriangular
593+ adj = U isa Adjoint
594+
578595 nrowB, ncolB = size (B, 1 ), size (B, 2 )
579596 aa = getnzval (A)
580597 ja = getrowval (A)
@@ -589,14 +606,14 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
589606 done = false
590607
591608 # loop through column j of A - only structural non-zeros
592- for ip = i2: - 1 : i1
593- i = ja[ip ]
594- if i > j
595- aai = possible_adjoint (adj, aa[ip ])
596- akku -= B[joff + i ] * aai
597- elseif i == j
598- if unit isa UnitDiagonalNo
599- aai = possible_adjoint (adj, aa[ip ])
609+ for ii = i2: - 1 : i1
610+ jai = ja[ii ]
611+ if jai > j
612+ aai = possible_adjoint (adj, aa[ii ])
613+ akku -= B[joff + jai ] * aai
614+ elseif jai == j
615+ if ! unit
616+ aai = possible_adjoint (adj, aa[ii ])
600617 akku /= aai
601618 end
602619 done = true
@@ -605,7 +622,7 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
605622 break
606623 end
607624 end
608- if ! done && unit isa UnitDiagonalNo
625+ if ! done && ! unit
609626 throw (LinearAlgebra. SingularException (j))
610627 end
611628 B[joff + j] = akku
@@ -615,28 +632,6 @@ function _bwdTriSolve!(A::SparseMatrixCSCUnion, B::AbstractVecOrMat, unit::UnitD
615632 B
616633end
617634
618- function ldiv! (A:: TriangularSparse{T} , B:: StridedVecOrMat{T} ) where T
619- @assert ! has_offset_axes (A, B)
620- nrowB, ncolB = size (B, 1 ), size (B, 2 )
621- ncol = LinearAlgebra. checksquare (A)
622- if nrowB != ncol
623- throw (DimensionMismatch (" A is $(ncol) columns and B has $(nrowB) rows" ))
624- end
625- _ldiv! (A, B)
626- end
627-
628- _ldiv! (L:: LowerTriangularPlain , B:: StridedVecOrMat ) =
629- fwdTriSolve! (L. data, B, UnitDiagonal (L))
630-
631- _ldiv! (L:: LowerTriangularWrapped , B:: StridedVecOrMat ) =
632- _fwdTriSolve! (L. parent. data, B, UnitDiagonal (L. parent), AdjointElement (L))
633-
634- _ldiv! (U:: UpperTriangularPlain , B:: StridedVecOrMat ) =
635- bwdTriSolve! (U. data, B, UnitDiagonal (U))
636-
637- _ldiv! (U:: UpperTriangularWrapped , B:: StridedVecOrMat ) =
638- _bwdTriSolve! (U. parent. data, B, UnitDiagonal (U. parent), AdjointElement (U))
639-
640635(\ )(L:: TriangularSparse , B:: SparseMatrixCSC ) = ldiv! (L, Array (B))
641636(* )(L:: TriangularSparse , B:: SparseMatrixCSC ) = lmul! (L, Array (B))
642637
0 commit comments