From 8db640078329ba09884c32bfd29418732595e51b Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 22 Oct 2024 16:30:10 +0530 Subject: [PATCH 1/4] Reduce generic matrix*vector latency --- stdlib/LinearAlgebra/src/matmul.jl | 47 +++++++++++--------------- stdlib/LinearAlgebra/src/triangular.jl | 2 +- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index f64422fd9cb8a..632fd404404ea 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -78,17 +78,13 @@ _mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector, generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T}, alpha::Number, beta::Number) where {T<:BlasFloat} = gemv!(y, tA, A, x, alpha, beta) -generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T}, - _add::MulAddMul = MulAddMul()) where {T<:BlasFloat} = - gemv!(y, tA, A, x, _add.alpha, _add.beta) + # Real (possibly transposed) matrix times complex vector. # Multiply the matrix with the real and imaginary parts separately generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}}, alpha::Number, beta::Number) where {T<:BlasReal} = gemv!(y, tA, A, x, alpha, beta) -generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}}, - _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = - gemv!(y, tA, A, x, _add.alpha, _add.beta) + # Complex matrix times real vector. # Reinterpret the matrix as a real matrix and do real matvec computation. # works only in cooperation with BLAS when A is untransposed (tA == 'N') @@ -96,9 +92,6 @@ generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::S generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, alpha::Number, beta::Number) where {T<:BlasReal} = gemv!(y, tA, A, x, alpha, beta) -generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T}, - _add::MulAddMul = MulAddMul()) where {T<:BlasReal} = - gemv!(y, tA, A, x, _add.alpha, _add.beta) # Vector-Matrix multiplication (*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')' @@ -539,9 +532,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar if tA_uc in ('S', 'H') # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically - return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) + return _generic_matvecmul!(y, 'N', wrap(A, tA), x, α, β) else - return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return _generic_matvecmul!(y, tA, A, x, α, β) end end @@ -564,7 +557,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs return y else Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) - return @stable_muladdmul _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β)) + return _generic_matvecmul!(y, ta, Anew, x, α, β) end end @@ -591,9 +584,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs elseif tA_uc in ('S', 'H') # re-wrap again and use plain ('N') matvec mul algorithm, # because _generic_matvecmul! can't handle the HermOrSym cases specifically - return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β)) + return _generic_matvecmul!(y, 'N', wrap(A, tA), x, α, β) else - return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β)) + return _generic_matvecmul!(y, tA, A, x, α, β) end end @@ -825,17 +818,17 @@ end # NOTE: the generic version is also called as fallback for # strides != 1 cases -Base.@constprop :aggressive generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, alpha::Number, beta::Number) = - @stable_muladdmul generic_matvecmul!(C, tA, A, B, MulAddMul(alpha, beta)) -@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, - _add::MulAddMul = MulAddMul()) +generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) = + generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) +@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, + alpha::Number, beta::Number) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) - return _generic_matvecmul!(C, ta, Anew, B, _add) + return _generic_matvecmul!(C, ta, Anew, B, alpha, beta) end function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, - _add::MulAddMul = MulAddMul()) + alpha::Number, beta::Number) require_one_based_indexing(C, A, B) @assert tA in ('N', 'T', 'C') mB = length(B) @@ -853,7 +846,7 @@ function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::Abst if tA == 'T' # fastest case if nA == 0 for k = 1:mA - _modify!(_add, false, C, k) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k) end else for k = 1:mA @@ -863,13 +856,13 @@ function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::Abst for i = 1:nA s += transpose(A[aoffs+i]) * B[i] end - _modify!(_add, s, C, k) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k) end end elseif tA == 'C' if nA == 0 for k = 1:mA - _modify!(_add, false, C, k) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k) end else for k = 1:mA @@ -879,13 +872,13 @@ function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::Abst for i = 1:nA s += A[aoffs + i]'B[i] end - _modify!(_add, s, C, k) + @stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k) end end else # tA == 'N' for i = 1:mA - if !iszero(_add.beta) - C[i] *= _add.beta + if !iszero(beta) + C[i] *= beta elseif mB == 0 C[i] = false else @@ -894,7 +887,7 @@ function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::Abst end for k = 1:mB aoffs = (k-1)*Astride - b = _add(B[k]) + b = @stable_muladdmul MulAddMul(alpha,beta)(B[k]) for i = 1:mA C[i] += A[aoffs + i] * b end diff --git a/stdlib/LinearAlgebra/src/triangular.jl b/stdlib/LinearAlgebra/src/triangular.jl index d6994f4b4dd58..1a7d04115c97d 100644 --- a/stdlib/LinearAlgebra/src/triangular.jl +++ b/stdlib/LinearAlgebra/src/triangular.jl @@ -1066,7 +1066,7 @@ for TC in (:AbstractVector, :AbstractMatrix) if isone(alpha) && iszero(beta) return _trimul!(C, A, B) else - return @stable_muladdmul generic_matvecmul!(C, 'N', A, B, MulAddMul(alpha, beta)) + return _generic_matvecmul!(C, 'N', A, B, alpha, beta) end end end From 05d5b3779d544fc6567c1169622a49696a79b77e Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 22 Oct 2024 16:42:10 +0530 Subject: [PATCH 2/4] Trim whitespace --- stdlib/LinearAlgebra/src/matmul.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index 632fd404404ea..ad1e5080f8216 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -820,7 +820,7 @@ end generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) = generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) -@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, +@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, alpha::Number, beta::Number) tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA) From 00c2e6fdba533d909a17ae172fce0513dd119a5f Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Tue, 22 Oct 2024 21:07:20 +0530 Subject: [PATCH 3/4] Restore _generic_matvecmul! method that accepts a MulAddMul --- stdlib/LinearAlgebra/src/matmul.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index ad1e5080f8216..bb790afa6d183 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -818,6 +818,7 @@ end # NOTE: the generic version is also called as fallback for # strides != 1 cases +# legacy method, retained for backward compatibility generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) = generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) @inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, @@ -827,6 +828,9 @@ generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector return _generic_matvecmul!(C, ta, Anew, B, alpha, beta) end +# legacy method, retained for backward compatibility +_generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) = + _generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, alpha::Number, beta::Number) require_one_based_indexing(C, A, B) From ff1a06ca7819a5d04303c6a1d1a93aca8dd2fc57 Mon Sep 17 00:00:00 2001 From: Jishnu Bhattacharya Date: Wed, 23 Oct 2024 13:27:48 +0530 Subject: [PATCH 4/4] Split branches into separate functions --- stdlib/LinearAlgebra/src/matmul.jl | 83 ++++++++++++++---------------- 1 file changed, 40 insertions(+), 43 deletions(-) diff --git a/stdlib/LinearAlgebra/src/matmul.jl b/stdlib/LinearAlgebra/src/matmul.jl index bb790afa6d183..a8205a1dde808 100644 --- a/stdlib/LinearAlgebra/src/matmul.jl +++ b/stdlib/LinearAlgebra/src/matmul.jl @@ -831,73 +831,70 @@ end # legacy method, retained for backward compatibility _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) = _generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta) -function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, - alpha::Number, beta::Number) - require_one_based_indexing(C, A, B) - @assert tA in ('N', 'T', 'C') - mB = length(B) - mA, nA = lapack_size(tA, A) - if mB != nA - throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB")) - end - if mA != length(C) - throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA")) - end - +function __generic_matvecmul!(f::F, C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector, + alpha::Number, beta::Number) where {F} Astride = size(A, 1) - @inbounds begin - if tA == 'T' # fastest case - if nA == 0 - for k = 1:mA - @stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k) - end - else - for k = 1:mA - aoffs = (k-1)*Astride - firstterm = transpose(A[aoffs + 1])*B[1] - s = zero(firstterm + firstterm) - for i = 1:nA - s += transpose(A[aoffs+i]) * B[i] - end - @stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k) - end - end - elseif tA == 'C' - if nA == 0 - for k = 1:mA + if length(B) == 0 + for k = eachindex(C) @stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k) end else - for k = 1:mA + for k = eachindex(C) aoffs = (k-1)*Astride - firstterm = A[aoffs + 1]'B[1] + firstterm = f(A[aoffs + 1]) * B[1] s = zero(firstterm + firstterm) - for i = 1:nA - s += A[aoffs + i]'B[i] + for i = eachindex(B) + s += f(A[aoffs+i]) * B[i] end @stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k) end end - else # tA == 'N' - for i = 1:mA + end +end +function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector, + alpha::Number, beta::Number) + Astride = size(A, 1) + @inbounds begin + for i = eachindex(C) if !iszero(beta) C[i] *= beta - elseif mB == 0 + elseif length(B) == 0 C[i] = false else C[i] = zero(A[i]*B[1] + A[i]*B[1]) end end - for k = 1:mB + for k = eachindex(B) aoffs = (k-1)*Astride b = @stable_muladdmul MulAddMul(alpha,beta)(B[k]) - for i = 1:mA + for i = eachindex(C) C[i] += A[aoffs + i] * b end end end - end # @inbounds + return C +end +function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, + alpha::Number, beta::Number) + require_one_based_indexing(C, A, B) + @assert tA in ('N', 'T', 'C') + mB = length(B) + mA, nA = lapack_size(tA, A) + if mB != nA + throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB")) + end + if mA != length(C) + throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA")) + end + + if tA == 'T' # fastest case + __generic_matvecmul!(transpose, C, A, B, alpha, beta) + elseif tA == 'C' + __generic_matvecmul!(adjoint, C, A, B, alpha, beta) + else # tA == 'N' + __generic_matvecmul!(identity, C, A, B, alpha, beta) + end C end