@@ -78,27 +78,20 @@ _mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
# Strided arguments sharing one BLAS-compatible element type: hand the whole
# computation, including the `alpha`/`beta` scaling factors, over to `gemv!`.
function generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T},
                            x::StridedVector{T}, alpha::Number,
                            beta::Number) where {T<:BlasFloat}
    return gemv!(y, tA, A, x, alpha, beta)
end
81- generic_matvecmul! (y:: StridedVector{T} , tA, A:: StridedVecOrMat{T} , x:: StridedVector{T} ,
82- _add:: MulAddMul = MulAddMul ()) where {T<: BlasFloat } =
83- gemv! (y, tA, A, x, _add. alpha, _add. beta)
81+
# Real (possibly transposed) matrix applied to a complex vector.
# The matrix is multiplied with the real and the imaginary part separately;
# this method only forwards its arguments to `gemv!`.
function generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T},
                            x::StridedVector{Complex{T}}, alpha::Number,
                            beta::Number) where {T<:BlasReal}
    return gemv!(y, tA, A, x, alpha, beta)
end
89- generic_matvecmul! (y:: StridedVector{Complex{T}} , tA, A:: StridedVecOrMat{T} , x:: StridedVector{Complex{T}} ,
90- _add:: MulAddMul = MulAddMul ()) where {T<: BlasReal } =
91- gemv! (y, tA, A, x, _add. alpha, _add. beta)
87+
# Complex matrix applied to a real vector.
# The matrix is reinterpreted as a real matrix so that a real matvec
# computation can be used. That only works together with BLAS when A is
# untransposed (tA == 'N'); `gemv!` performs that check itself, so it is not
# repeated here.
function generic_matvecmul!(y::StridedVector{Complex{T}}, tA,
                            A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
                            alpha::Number, beta::Number) where {T<:BlasReal}
    return gemv!(y, tA, A, x, alpha, beta)
end
99- generic_matvecmul! (y:: StridedVector{Complex{T}} , tA, A:: StridedVecOrMat{Complex{T}} , x:: StridedVector{T} ,
100- _add:: MulAddMul = MulAddMul ()) where {T<: BlasReal } =
101- gemv! (y, tA, A, x, _add. alpha, _add. beta)
10295
10396# Vector-Matrix multiplication
# Adjoint-wrapped vector times matrix, evaluated as the adjoint of `A' * x'`.
function *(x::AdjointAbsVec, A::AbstractMatrix)
    return adjoint(adjoint(A) * adjoint(x))
end
@@ -539,9 +532,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
539532 if tA_uc in (' S' , ' H' )
540533 # re-wrap again and use plain ('N') matvec mul algorithm,
541534 # because _generic_matvecmul! can't handle the HermOrSym cases specifically
542- return @stable_muladdmul _generic_matvecmul! (y, ' N' , wrap (A, tA), x, MulAddMul ( α, β) )
535+ return _generic_matvecmul! (y, ' N' , wrap (A, tA), x, α, β)
543536 else
544- return @stable_muladdmul _generic_matvecmul! (y, tA, A, x, MulAddMul ( α, β) )
537+ return _generic_matvecmul! (y, tA, A, x, α, β)
545538 end
546539end
547540
@@ -564,7 +557,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
564557 return y
565558 else
566559 Anew, ta = tA_uc in (' S' , ' H' ) ? (wrap (A, tA), oftype (tA, ' N' )) : (A, tA)
567- return @stable_muladdmul _generic_matvecmul! (y, ta, Anew, x, MulAddMul ( α, β) )
560+ return _generic_matvecmul! (y, ta, Anew, x, α, β)
568561 end
569562end
570563
@@ -591,9 +584,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
591584 elseif tA_uc in (' S' , ' H' )
592585 # re-wrap again and use plain ('N') matvec mul algorithm,
593586 # because _generic_matvecmul! can't handle the HermOrSym cases specifically
594- return @stable_muladdmul _generic_matvecmul! (y, ' N' , wrap (A, tA), x, MulAddMul ( α, β) )
587+ return _generic_matvecmul! (y, ' N' , wrap (A, tA), x, α, β)
595588 else
596- return @stable_muladdmul _generic_matvecmul! (y, tA, A, x, MulAddMul ( α, β) )
589+ return _generic_matvecmul! (y, tA, A, x, α, β)
597590 end
598591end
599592
@@ -825,82 +818,83 @@ end
825818# NOTE: the generic version is also called as fallback for
826819# strides != 1 cases
827820
828- Base. @constprop :aggressive generic_matvecmul! (C:: AbstractVector , tA, A:: AbstractVecOrMat , B:: AbstractVector , alpha:: Number , beta:: Number ) =
829- @stable_muladdmul generic_matvecmul! (C, tA, A, B, MulAddMul (alpha, beta))
# Legacy method, retained for backward compatibility: accepts a `MulAddMul`
# (defaulting to `MulAddMul()`) and forwards its `alpha` and `beta` fields to
# the method that takes the two factors as separate `Number` arguments.
function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                            _add::MulAddMul = MulAddMul())
    return generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
end
# Generic matrix-vector product entry point taking `alpha` and `beta` directly.
# For the flags 'S' and 'H' the matrix is re-wrapped via `wrap(A, tA)` and the
# multiplication proceeds with an 'N' flag of the same type as `tA`; any other
# flag is passed through unchanged.
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat,
                                    B::AbstractVector, alpha::Number, beta::Number)
    # `uppercase` potentially converts a WrapperChar to a plain Char
    flag = uppercase(tA)
    if flag in ('S', 'H')
        wrapped = wrap(A, tA)
        plain = oftype(tA, 'N')
        return _generic_matvecmul!(C, plain, wrapped, B, alpha, beta)
    end
    return _generic_matvecmul!(C, tA, A, B, alpha, beta)
end
836830
837- function _generic_matvecmul! (C:: AbstractVector , tA, A:: AbstractVecOrMat , B:: AbstractVector ,
838- _add:: MulAddMul = MulAddMul ())
839- require_one_based_indexing (C, A, B)
840- @assert tA in (' N' , ' T' , ' C' )
841- mB = length (B)
842- mA, nA = lapack_size (tA, A)
843- if mB != nA
844- throw (DimensionMismatch (lazy " matrix A has dimensions ($mA,$nA), vector B has length $mB" ))
845- end
846- if mA != length (C)
847- throw (DimensionMismatch (lazy " result C has length $(length(C)), needs length $mA" ))
848- end
849-
# Legacy method, retained for backward compatibility: unpacks the `alpha` and
# `beta` fields of a `MulAddMul` (defaulting to `MulAddMul()`) and forwards
# them to the method that takes the two factors as separate arguments.
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                             _add::MulAddMul = MulAddMul())
    return _generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
end
# Kernel for the case where an elementwise map `f` is applied to the entries
# of `A` (the visible caller passes `transpose` for 'T' and `adjoint` for 'C').
#
# For every index `k` of `C` the sum over `i` of `f(A[(k-1)*size(A, 1) + i]) * B[i]`
# is accumulated, i.e. `A` is read through linear indices one block of
# `size(A, 1)` entries at a time, and the sum is combined with `C[k]` through
# `_modify!(MulAddMul(alpha, beta), s, C, k)`.
# NOTE(review): presumably that stores `s*alpha + C[k]*beta` — confirm against
# the definition of `_modify!`.
#
# No dimension or indexing checks are done here and all accesses are
# `@inbounds`; `_generic_matvecmul!` validates sizes and one-based indexing
# before calling this kernel.
#
# Returns `C`, matching the `identity` method of this function.
function __generic_matvecmul!(f::F, C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
                              alpha::Number, beta::Number) where {F}
    Astride = size(A, 1)
    @inbounds begin
        if length(B) == 0
            # Nothing to sum over: feed `false` as the product term.
            for k in eachindex(C)
                @stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k)
            end
        else
            for k in eachindex(C)
                aoffs = (k-1)*Astride
                # Derive the accumulator type from an actual term so that it
                # is already wide enough for the running sum.
                firstterm = f(A[aoffs + 1]) * B[1]
                s = zero(firstterm + firstterm)
                for i in eachindex(B)
                    s += f(A[aoffs+i]) * B[i]
                end
                @stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k)
            end
        end
    end
    # Return the mutated output explicitly instead of the `nothing` produced by
    # the trailing loop, so both kernel methods have the same return value.
    return C
end
# Kernel for the untransposed case (`f === identity`).
#
# `C` is first prepared in place: scaled by `beta` when `beta` is nonzero,
# otherwise set to `false` when `B` is empty, otherwise set to a zero of the
# type of an actual product term. Then, for every index `col` of `B`, the
# block of `A` starting at linear offset `(col-1)*size(A, 1)` is added to `C`,
# scaled by `MulAddMul(alpha, beta)(B[col])`.
#
# No dimension or indexing checks are done here and all accesses are
# `@inbounds`; `_generic_matvecmul!` validates sizes and one-based indexing
# before calling this kernel. Returns `C`.
function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::AbstractVecOrMat,
                              B::AbstractVector, alpha::Number, beta::Number)
    colstride = size(A, 1)
    rows = eachindex(C)
    @inbounds begin
        # Pass 1: bring C into its starting state.
        if !iszero(beta)
            for r in rows
                C[r] *= beta
            end
        elseif length(B) == 0
            for r in rows
                C[r] = false
            end
        else
            for r in rows
                C[r] = zero(A[r]*B[1] + A[r]*B[1])
            end
        end
        # Pass 2: accumulate one scaled block of A per entry of B.
        for col in eachindex(B)
            offset = (col - 1) * colstride
            scaled = @stable_muladdmul MulAddMul(alpha,beta)(B[col])
            for r in rows
                C[r] += A[offset + r] * scaled
            end
        end
    end
    return C
end
# Checked front end for the generic matrix-vector kernels.
#
# Requires one-based indexing of `C`, `A` and `B`, asserts that `tA` is one of
# 'N', 'T' or 'C', and compares `length(B)` and `length(C)` against the pair
# returned by `lapack_size(tA, A)`, throwing `DimensionMismatch` on
# disagreement. It then calls `__generic_matvecmul!` with `transpose` for 'T',
# `adjoint` for 'C' and `identity` otherwise, and returns `C`.
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
                             alpha::Number, beta::Number)
    require_one_based_indexing(C, A, B)
    @assert tA in ('N', 'T', 'C')
    mB = length(B)
    mA, nA = lapack_size(tA, A)
    mB == nA ||
        throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
    mA == length(C) ||
        throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))

    # Each branch passes a literal function so the matching kernel method is
    # selected by dispatch.
    if tA == 'C'
        __generic_matvecmul!(adjoint, C, A, B, alpha, beta)
    elseif tA == 'T'
        __generic_matvecmul!(transpose, C, A, B, alpha, beta)
    else # tA == 'N'
        __generic_matvecmul!(identity, C, A, B, alpha, beta)
    end
    return C
end
906900
0 commit comments