Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,6 +573,10 @@ wrapper_char(A::Hermitian) = WrapperChar('H', A.uplo == 'U')
# A Hermitian wrapper with real elements is tagged with the same character ('S')
# as a Symmetric wrapper. The second field records whether `uplo` is 'U'.
function wrapper_char(A::Hermitian{<:Real})
    isupper = A.uplo == 'U'
    return WrapperChar('S', isupper)
end
# Symmetric wrappers are tagged 'S', together with whether `uplo` is 'U'.
function wrapper_char(A::Symmetric)
    isupper = A.uplo == 'U'
    return WrapperChar('S', isupper)
end

# `wrapper_char_NTC(A)` returns a `Bool` describing the wrapper of `A`
# (the name presumably refers to the 'N'/'T'/'C' wrapper characters):
#   * generic arrays: `true` only if the uppercased `wrapper_char(A)` equals 'N';
#   * strided arrays, `Adjoint` and `Transpose`: always `true`;
#   * `Symmetric` and `Hermitian`: always `false`.
function wrapper_char_NTC(A::AbstractArray)
    return uppercase(wrapper_char(A)) == 'N'
end
wrapper_char_NTC(::Union{StridedArray, Adjoint, Transpose}) = true
wrapper_char_NTC(::Union{Symmetric, Hermitian}) = false

Base.@constprop :aggressive function wrap(A::AbstractVecOrMat, tA::AbstractChar)
# merge the result of this before return, so that we can type-assert the return such
# that even if the tmerge is inaccurate, inference can still identify that the
Expand Down
78 changes: 50 additions & 28 deletions stdlib/LinearAlgebra/src/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -293,15 +293,24 @@ true
@inline function mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat,
                      α::Number, β::Number)
    # All the work happens in `_mul!`; this method only forwards its arguments.
    return _mul!(C, A, B, α, β)
end
# Add a level of indirection and specialize _mul! to avoid ambiguities in mul!
# Strip the wrappers off `A` and `B` (passing their wrapper chars separately) and
# forward to `generic_matmatmul_wrapper!`. The trailing `Val` carries, in the type
# domain, whether both operands satisfy `wrapper_char_NTC`, so that the callee can
# be selected by dispatch on `Val{true}` / `Val{false}`.
@inline _mul!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat, α::Number, β::Number) =
    generic_matmatmul_wrapper!(
        C,
        wrapper_char(A),
        wrapper_char(B),
        _unwrap(A),
        _unwrap(B),
        α, β,
        Val(wrapper_char_NTC(A) & wrapper_char_NTC(B))
    )

# This indirection allows us to specialize on the types of the wrappers of A and B to some extent,
# even though the wrappers are stripped off in mul!
# By default, we ignore the wrapper info and forward the arguments to generic_matmatmul!
Base.@constprop :aggressive function generic_matmatmul_wrapper!(
        C, tA, tB, A, B, α, β, @nospecialize(val))
    # Generic fallback: the trailing `val` argument is not used here; every other
    # argument is handed unchanged to `generic_matmatmul!`.
    return generic_matmatmul!(C, tA, tB, A, B, α, β)
end


"""
rmul!(A, B)

Expand Down Expand Up @@ -368,9 +377,9 @@ julia> lmul!(F.Q, B)
"""
lmul!(A, B)

# THE one big BLAS dispatch
Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number) where {T<:BlasFloat}
# THE one big BLAS dispatch. This is split into two methods to improve latency
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{true}) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
Expand All @@ -389,19 +398,37 @@ Base.@constprop :aggressive function generic_matmatmul!(C::StridedMatrix{T}, tA,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
if tA_uc == 'T' && tB_uc == 'N' && A === B
return syrk_wrapper!(C, 'T', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
return syrk_wrapper!(C, 'N', A, α, β)
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
return herk_wrapper!(C, 'C', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
return herk_wrapper!(C, 'N', A, α, β)
else
return gemm_wrapper!(C, tA, tB, A, B, α, β)
if tA_uc == 'T' && tB_uc == 'N' && A === B
return syrk_wrapper!(C, 'T', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'T' && A === B
return syrk_wrapper!(C, 'N', A, α, β)
elseif tA_uc == 'C' && tB_uc == 'N' && A === B
return herk_wrapper!(C, 'C', A, α, β)
elseif tA_uc == 'N' && tB_uc == 'C' && A === B
return herk_wrapper!(C, 'N', A, α, β)
else
return gemm_wrapper!(C, tA, tB, A, B, α, β)
end
end
Base.@constprop :aggressive function generic_matmatmul_wrapper!(C::StridedMatrix{T}, tA, tB, A::StridedVecOrMat{T}, B::StridedVecOrMat{T},
α::Number, β::Number, ::Val{false}) where {T<:BlasFloat}
mA, nA = lapack_size(tA, A)
mB, nB = lapack_size(tB, B)
if any(iszero, size(A)) || any(iszero, size(B)) || iszero(α)
if size(C) != (mA, nB)
throw(DimensionMismatch(lazy"C has dimensions $(size(C)), should have ($mA,$nB)"))
end
return _rmul_or_fill!(C, β)
end
if size(C) == size(A) == size(B) == (2,2)
return matmul2x2!(C, tA, tB, A, B, α, β)
end
if size(C) == size(A) == size(B) == (3,3)
return matmul3x3!(C, tA, tB, A, B, α, β)
end
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
alpha, beta = promote(α, β, zero(T))
if alpha isa Union{Bool,T} && beta isa Union{Bool,T}
if tA_uc == 'S' && tB_uc == 'N'
Expand All @@ -421,18 +448,13 @@ Base.@constprop :aggressive generic_matmatmul!(C::StridedMatrix{T}, tA, tB, A::S
_add::MulAddMul = MulAddMul()) where {T<:BlasFloat} =
generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)

# Complex matrix times (transposed) real matrix. Reinterpret the first matrix to real for efficiency.
Base.@constprop :aggressive function generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
α::Number, β::Number) where {T<:BlasReal}
# We convert the chars to uppercase to potentially unwrap a WrapperChar,
# and extract the char corresponding to the wrapper type
tA_uc, tB_uc = uppercase(tA), uppercase(tB)
# the map in all ensures constprop by acting on tA and tB individually, instead of looping over them.
if all(map(in(('N', 'T', 'C')), (tA_uc, tB_uc)))
gemm_wrapper!(C, tA, tB, A, B, α, β)
else
_generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), MulAddMul(α, β))
end
# Complex strided `C` and `A` with a real strided `B`, selected by `Val{true}`:
# the product is delegated directly to `gemm_wrapper!`.
function generic_matmatmul_wrapper!(
        C::StridedVecOrMat{Complex{T}}, tA, tB,
        A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
        α::Number, β::Number, ::Val{true}) where {T<:BlasReal}
    return gemm_wrapper!(C, tA, tB, A, B, α, β)
end
# Complex strided `C` and `A` with a real strided `B`, selected by `Val{false}`:
# re-apply the wrappers described by `tA`/`tB` and use `_generic_matmatmul!`,
# packing the scalars into a `MulAddMul`.
Base.@constprop :aggressive function generic_matmatmul_wrapper!(
        C::StridedVecOrMat{Complex{T}}, tA, tB,
        A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
        α::Number, β::Number, ::Val{false}) where {T<:BlasReal}
    wrappedA = wrap(A, tA)
    wrappedB = wrap(B, tB)
    return _generic_matmatmul!(C, wrappedA, wrappedB, MulAddMul(α, β))
end
# legacy method
Base.@constprop :aggressive generic_matmatmul!(C::StridedVecOrMat{Complex{T}}, tA, tB, A::StridedVecOrMat{Complex{T}}, B::StridedVecOrMat{T},
Expand Down