Skip to content

Commit b9b4dfa

Browse files
authored
Reduce generic matrix*vector latency (#56289)
```julia
julia> using LinearAlgebra

julia> A = rand(Int,4,4); x = rand(Int,4); y = similar(x);

julia> @time mul!(y, A, x, 2, 2);
  0.330489 seconds (792.22 k allocations: 41.519 MiB, 8.75% gc time, 99.99% compilation time) # master
  0.134212 seconds (339.89 k allocations: 17.103 MiB, 15.23% gc time, 99.98% compilation time) # This PR
```

Main changes:

- `generic_matvecmul!` and `_generic_matvecmul!` now accept `alpha` and `beta` arguments instead of `MulAddMul(alpha, beta)`. The methods that accept a `MulAddMul(alpha, beta)` are also retained for backward compatibility, but these now forward `alpha` and `beta`, instead of the other way around.
- Narrow the scope of the `@stable_muladdmul` applications. We now construct the `MulAddMul(alpha, beta)` object only where it is needed in a function call, and we annotate the call site with `@stable_muladdmul`. This leads to smaller branches.
- Create a new internal function with methods for the `'N'`, `'T'` and `'C'` cases, so that firstly, there's less code duplication, and secondly, the `_generic_matvecmul!` method is now simple enough to enable constant propagation. This eliminates the unnecessary branches, and only the one that is taken is compiled.

Together, this reduces the TTFX substantially.
1 parent 005608a commit b9b4dfa

File tree

2 files changed

+61
-67
lines changed

2 files changed

+61
-67
lines changed

stdlib/LinearAlgebra/src/matmul.jl

Lines changed: 60 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -78,27 +78,20 @@ _mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
7878
generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
7979
alpha::Number, beta::Number) where {T<:BlasFloat} =
8080
gemv!(y, tA, A, x, alpha, beta)
81-
generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
82-
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
83-
gemv!(y, tA, A, x, _add.alpha, _add.beta)
81+
8482
# Real (possibly transposed) matrix times complex vector.
8583
# Multiply the matrix with the real and imaginary parts separately
8684
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
8785
alpha::Number, beta::Number) where {T<:BlasReal} =
8886
gemv!(y, tA, A, x, alpha, beta)
89-
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
90-
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
91-
gemv!(y, tA, A, x, _add.alpha, _add.beta)
87+
9288
# Complex matrix times real vector.
9389
# Reinterpret the matrix as a real matrix and do real matvec computation.
9490
# works only in cooperation with BLAS when A is untransposed (tA == 'N')
9591
# but that check is included in gemv! anyway
9692
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
9793
alpha::Number, beta::Number) where {T<:BlasReal} =
9894
gemv!(y, tA, A, x, alpha, beta)
99-
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
100-
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
101-
gemv!(y, tA, A, x, _add.alpha, _add.beta)
10295

10396
# Vector-Matrix multiplication
10497
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
@@ -539,9 +532,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
539532
if tA_uc in ('S', 'H')
540533
# re-wrap again and use plain ('N') matvec mul algorithm,
541534
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
542-
return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
535+
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, α, β)
543536
else
544-
return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
537+
return _generic_matvecmul!(y, tA, A, x, α, β)
545538
end
546539
end
547540

@@ -564,7 +557,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
564557
return y
565558
else
566559
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
567-
return @stable_muladdmul _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β))
560+
return _generic_matvecmul!(y, ta, Anew, x, α, β)
568561
end
569562
end
570563

@@ -591,9 +584,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
591584
elseif tA_uc in ('S', 'H')
592585
# re-wrap again and use plain ('N') matvec mul algorithm,
593586
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
594-
return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
587+
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, α, β)
595588
else
596-
return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
589+
return _generic_matvecmul!(y, tA, A, x, α, β)
597590
end
598591
end
599592

@@ -825,82 +818,83 @@ end
825818
# NOTE: the generic version is also called as fallback for
826819
# strides != 1 cases
827820

828-
Base.@constprop :aggressive generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, alpha::Number, beta::Number) =
829-
@stable_muladdmul generic_matvecmul!(C, tA, A, B, MulAddMul(alpha, beta))
821+
# legacy method, retained for backward compatibility
822+
generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) =
823+
generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
830824
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
831-
_add::MulAddMul = MulAddMul())
825+
alpha::Number, beta::Number)
832826
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
833827
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
834-
return _generic_matvecmul!(C, ta, Anew, B, _add)
828+
return _generic_matvecmul!(C, ta, Anew, B, alpha, beta)
835829
end
836830

837-
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
838-
_add::MulAddMul = MulAddMul())
839-
require_one_based_indexing(C, A, B)
840-
@assert tA in ('N', 'T', 'C')
841-
mB = length(B)
842-
mA, nA = lapack_size(tA, A)
843-
if mB != nA
844-
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
845-
end
846-
if mA != length(C)
847-
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
848-
end
849-
831+
# legacy method, retained for backward compatibility
832+
_generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) =
833+
_generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
834+
function __generic_matvecmul!(f::F, C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
835+
alpha::Number, beta::Number) where {F}
850836
Astride = size(A, 1)
851-
852837
@inbounds begin
853-
if tA == 'T' # fastest case
854-
if nA == 0
855-
for k = 1:mA
856-
_modify!(_add, false, C, k)
857-
end
858-
else
859-
for k = 1:mA
860-
aoffs = (k-1)*Astride
861-
firstterm = transpose(A[aoffs + 1])*B[1]
862-
s = zero(firstterm + firstterm)
863-
for i = 1:nA
864-
s += transpose(A[aoffs+i]) * B[i]
865-
end
866-
_modify!(_add, s, C, k)
867-
end
868-
end
869-
elseif tA == 'C'
870-
if nA == 0
871-
for k = 1:mA
872-
_modify!(_add, false, C, k)
838+
if length(B) == 0
839+
for k = eachindex(C)
840+
@stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k)
873841
end
874842
else
875-
for k = 1:mA
843+
for k = eachindex(C)
876844
aoffs = (k-1)*Astride
877-
firstterm = A[aoffs + 1]'B[1]
845+
firstterm = f(A[aoffs + 1]) * B[1]
878846
s = zero(firstterm + firstterm)
879-
for i = 1:nA
880-
s += A[aoffs + i]'B[i]
847+
for i = eachindex(B)
848+
s += f(A[aoffs+i]) * B[i]
881849
end
882-
_modify!(_add, s, C, k)
850+
@stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k)
883851
end
884852
end
885-
else # tA == 'N'
886-
for i = 1:mA
887-
if !iszero(_add.beta)
888-
C[i] *= _add.beta
889-
elseif mB == 0
853+
end
854+
end
855+
function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
856+
alpha::Number, beta::Number)
857+
Astride = size(A, 1)
858+
@inbounds begin
859+
for i = eachindex(C)
860+
if !iszero(beta)
861+
C[i] *= beta
862+
elseif length(B) == 0
890863
C[i] = false
891864
else
892865
C[i] = zero(A[i]*B[1] + A[i]*B[1])
893866
end
894867
end
895-
for k = 1:mB
868+
for k = eachindex(B)
896869
aoffs = (k-1)*Astride
897-
b = _add(B[k])
898-
for i = 1:mA
870+
b = @stable_muladdmul MulAddMul(alpha,beta)(B[k])
871+
for i = eachindex(C)
899872
C[i] += A[aoffs + i] * b
900873
end
901874
end
902875
end
903-
end # @inbounds
876+
return C
877+
end
878+
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
879+
alpha::Number, beta::Number)
880+
require_one_based_indexing(C, A, B)
881+
@assert tA in ('N', 'T', 'C')
882+
mB = length(B)
883+
mA, nA = lapack_size(tA, A)
884+
if mB != nA
885+
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
886+
end
887+
if mA != length(C)
888+
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
889+
end
890+
891+
if tA == 'T' # fastest case
892+
__generic_matvecmul!(transpose, C, A, B, alpha, beta)
893+
elseif tA == 'C'
894+
__generic_matvecmul!(adjoint, C, A, B, alpha, beta)
895+
else # tA == 'N'
896+
__generic_matvecmul!(identity, C, A, B, alpha, beta)
897+
end
904898
C
905899
end
906900

stdlib/LinearAlgebra/src/triangular.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1066,7 +1066,7 @@ for TC in (:AbstractVector, :AbstractMatrix)
10661066
if isone(alpha) && iszero(beta)
10671067
return _trimul!(C, A, B)
10681068
else
1069-
return @stable_muladdmul generic_matvecmul!(C, 'N', A, B, MulAddMul(alpha, beta))
1069+
return _generic_matvecmul!(C, 'N', A, B, alpha, beta)
10701070
end
10711071
end
10721072
end

0 commit comments

Comments
 (0)