Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 60 additions & 66 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,20 @@ _mul!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector,
generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasFloat} =
gemv!(y, tA, A, x, alpha, beta)
generic_matvecmul!(y::StridedVector{T}, tA, A::StridedVecOrMat{T}, x::StridedVector{T},
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
gemv!(y, tA, A, x, _add.alpha, _add.beta)

# Real (possibly transposed) matrix times complex vector.
# Multiply the matrix with the real and imaginary parts separately
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, tA, A, x, alpha, beta)
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{T}, x::StridedVector{Complex{T}},
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
gemv!(y, tA, A, x, _add.alpha, _add.beta)

# Complex matrix times real vector.
# Reinterpret the matrix as a real matrix and do real matvec computation.
# works only in cooperation with BLAS when A is untransposed (tA == 'N')
# but that check is included in gemv! anyway
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
alpha::Number, beta::Number) where {T<:BlasReal} =
gemv!(y, tA, A, x, alpha, beta)
generic_matvecmul!(y::StridedVector{Complex{T}}, tA, A::StridedVecOrMat{Complex{T}}, x::StridedVector{T},
_add::MulAddMul = MulAddMul()) where {T<:BlasReal} =
gemv!(y, tA, A, x, _add.alpha, _add.beta)

# Vector-Matrix multiplication
(*)(x::AdjointAbsVec, A::AbstractMatrix) = (A'*x')'
Expand Down Expand Up @@ -539,9 +532,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{T}, tA::AbstractChar
if tA_uc in ('S', 'H')
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, α, β)
else
return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, α, β)
end
end

Expand All @@ -564,7 +557,7 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
return y
else
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
return @stable_muladdmul _generic_matvecmul!(y, ta, Anew, x, MulAddMul(α, β))
return _generic_matvecmul!(y, ta, Anew, x, α, β)
end
end

Expand All @@ -591,9 +584,9 @@ Base.@constprop :aggressive function gemv!(y::StridedVector{Complex{T}}, tA::Abs
elseif tA_uc in ('S', 'H')
# re-wrap again and use plain ('N') matvec mul algorithm,
# because _generic_matvecmul! can't handle the HermOrSym cases specifically
return @stable_muladdmul _generic_matvecmul!(y, 'N', wrap(A, tA), x, MulAddMul(α, β))
return _generic_matvecmul!(y, 'N', wrap(A, tA), x, α, β)
else
return @stable_muladdmul _generic_matvecmul!(y, tA, A, x, MulAddMul(α, β))
return _generic_matvecmul!(y, tA, A, x, α, β)
end
end

Expand Down Expand Up @@ -825,82 +818,83 @@ end
# NOTE: the generic version is also called as fallback for
# strides != 1 cases

Base.@constprop :aggressive generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, alpha::Number, beta::Number) =
@stable_muladdmul generic_matvecmul!(C, tA, A, B, MulAddMul(alpha, beta))
# legacy method, retained for backward compatibility
generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) =
generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
@inline function generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
alpha::Number, beta::Number)
tA_uc = uppercase(tA) # potentially convert a WrapperChar to a Char
Anew, ta = tA_uc in ('S', 'H') ? (wrap(A, tA), oftype(tA, 'N')) : (A, tA)
return _generic_matvecmul!(C, ta, Anew, B, _add)
return _generic_matvecmul!(C, ta, Anew, B, alpha, beta)
end

function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
_add::MulAddMul = MulAddMul())
require_one_based_indexing(C, A, B)
@assert tA in ('N', 'T', 'C')
mB = length(B)
mA, nA = lapack_size(tA, A)
if mB != nA
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
end
if mA != length(C)
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
end

# legacy method, retained for backward compatibility
_generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector, _add::MulAddMul = MulAddMul()) =
_generic_matvecmul!(C, tA, A, B, _add.alpha, _add.beta)
function __generic_matvecmul!(f::F, C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number) where {F}
Astride = size(A, 1)

@inbounds begin
if tA == 'T' # fastest case
if nA == 0
for k = 1:mA
_modify!(_add, false, C, k)
end
else
for k = 1:mA
aoffs = (k-1)*Astride
firstterm = transpose(A[aoffs + 1])*B[1]
s = zero(firstterm + firstterm)
for i = 1:nA
s += transpose(A[aoffs+i]) * B[i]
end
_modify!(_add, s, C, k)
end
end
elseif tA == 'C'
if nA == 0
for k = 1:mA
_modify!(_add, false, C, k)
if length(B) == 0
for k = eachindex(C)
@stable_muladdmul _modify!(MulAddMul(alpha,beta), false, C, k)
end
else
for k = 1:mA
for k = eachindex(C)
aoffs = (k-1)*Astride
firstterm = A[aoffs + 1]'B[1]
firstterm = f(A[aoffs + 1]) * B[1]
s = zero(firstterm + firstterm)
for i = 1:nA
s += A[aoffs + i]'B[i]
for i = eachindex(B)
s += f(A[aoffs+i]) * B[i]
end
_modify!(_add, s, C, k)
@stable_muladdmul _modify!(MulAddMul(alpha,beta), s, C, k)
end
end
else # tA == 'N'
for i = 1:mA
if !iszero(_add.beta)
C[i] *= _add.beta
elseif mB == 0
end
end
function __generic_matvecmul!(::typeof(identity), C::AbstractVector, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number)
Astride = size(A, 1)
@inbounds begin
for i = eachindex(C)
if !iszero(beta)
C[i] *= beta
elseif length(B) == 0
C[i] = false
else
C[i] = zero(A[i]*B[1] + A[i]*B[1])
end
end
for k = 1:mB
for k = eachindex(B)
aoffs = (k-1)*Astride
b = _add(B[k])
for i = 1:mA
b = @stable_muladdmul MulAddMul(alpha,beta)(B[k])
for i = eachindex(C)
C[i] += A[aoffs + i] * b
end
end
end
end # @inbounds
return C
end
function _generic_matvecmul!(C::AbstractVector, tA, A::AbstractVecOrMat, B::AbstractVector,
alpha::Number, beta::Number)
require_one_based_indexing(C, A, B)
@assert tA in ('N', 'T', 'C')
mB = length(B)
mA, nA = lapack_size(tA, A)
if mB != nA
throw(DimensionMismatch(lazy"matrix A has dimensions ($mA,$nA), vector B has length $mB"))
end
if mA != length(C)
throw(DimensionMismatch(lazy"result C has length $(length(C)), needs length $mA"))
end

if tA == 'T' # fastest case
__generic_matvecmul!(transpose, C, A, B, alpha, beta)
elseif tA == 'C'
__generic_matvecmul!(adjoint, C, A, B, alpha, beta)
else # tA == 'N'
__generic_matvecmul!(identity, C, A, B, alpha, beta)
end
C
end

Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1066,7 +1066,7 @@ for TC in (:AbstractVector, :AbstractMatrix)
if isone(alpha) && iszero(beta)
return _trimul!(C, A, B)
else
return @stable_muladdmul generic_matvecmul!(C, 'N', A, B, MulAddMul(alpha, beta))
return _generic_matvecmul!(C, 'N', A, B, alpha, beta)
end
end
end
Expand Down