Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 166 additions & 11 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,58 @@ function *(A::OneElementMatrix, B::AbstractFillVector)
OneElement(val, A.ind[1], size(A,1))
end

@inline function __mulonel!(y, A, x, alpha, beta)
# Special matrix types

function *(A::OneElementMatrix, D::Diagonal)
check_matmul_sizes(A, D)
nzcol = A.ind[2]
val = if nzcol in axes(D,1)
A.val * D[nzcol, nzcol]
else
A.val * zero(eltype(D))
end
OneElement(val, A.ind, size(A))
end
function *(D::Diagonal, A::OneElementMatrix)
check_matmul_sizes(D, A)
nzrow = A.ind[1]
val = if nzrow in axes(D,2)
D[nzrow, nzrow] * A.val
else
zero(eltype(D)) * A.val
end
OneElement(val, A.ind, size(A))
end

# Inplace multiplication

# We use this for out overloads for _mul! for OneElement because its more efficient
# due to how efficient 2 arg mul is when one or more of the args are OneElement
function __mulonel!(C, A, B, alpha, beta)
ABα = A * B * alpha
if iszero(beta)
C .= ABα
else
C .= ABα .+ C .* beta
end
return C
end
# These methods remove the ambituity in _mul!. This isn't strictly necessary, but this makes Aqua happy.
function _mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha, beta)
__mulonel!(C, A, B, alpha, beta)
end
function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::OneElementMatrix, alpha, beta)
__mulonel!(C, A, B, alpha, beta)
end

function mul!(C::AbstractMatrix, A::OneElementMatrix, B::OneElementMatrix, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end
function mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end

@inline function __mul!(y, A::AbstractMatrix, x::OneElement, alpha, beta)
αx = alpha * x.val
ind1 = x.ind[1]
if iszero(beta)
Expand All @@ -104,19 +155,19 @@ end
return y
end

function _mulonel!(y, A, x::OneElementVector, alpha::Number, beta::Number)
function _mul!(y::AbstractVector, A::AbstractMatrix, x::OneElementVector, alpha, beta)
check_matmul_sizes(y, A, x)
if x.ind[1] ∉ axes(x,1) # in this case x is all zeros
if iszero(getindex_value(x))
mul!(y, A, Zeros{eltype(x)}(axes(x)), alpha, beta)
return y
end
Comment on lines +160 to 163
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am dubious as to if this check is worth the potential type instability.
I feel like such a OneElement where that element is also zero is almost never constructed.

Copy link
Member Author

@jishnub jishnub Feb 2, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not type-unstable, though. It just calls a different method that returns the same vector, and should be type-stable as well. This is a guard against the current implementation allowing the indices to not lie within the axes (which probably should be disallowed).

E.g.:

julia> @report_opt mul!(zeros(2), ones(2,2), OneElement(1,2))
No errors detected

__mulonel!(y, A, x, alpha, beta)
__mul!(y, A, x, alpha, beta)
y
end

function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number)
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::OneElementMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
if B.ind[1] ∉ axes(B,1) || B.ind[2] ∉ axes(B,2) # in this case x is all zeros
if iszero(getindex_value(B))
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
return C
end
Expand All @@ -127,24 +178,128 @@ function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number)
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
end
y = view(C, :, B.ind[2])
__mulonel!(y, A, B, alpha, beta)
__mul!(y, A, B, alpha, beta)
C
end
function _mul!(C::AbstractMatrix, A::Diagonal, B::OneElementMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(getindex_value(B))
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
return C
end
Comment on lines +186 to +189
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

per above I am dubious

if iszero(beta)
C .= zero(eltype(C))
else
view(C, :, 1:B.ind[2]-1) .*= beta
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
end
ABα = A * B * alpha
nzrow, nzcol = B.ind
if iszero(beta)
C[B.ind...] = ABα[B.ind...]
else
y = view(C, :, nzcol)
y .= view(ABα, :, nzcol) .+ y .* beta
end
C
end

function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(getindex_value(A))
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
return C
end
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:A.ind[1]-1, :) .*= beta
view(C, A.ind[1]+1:size(C,1), :) .*= beta
end
y = view(C, A.ind[1], :)
ind2 = A.ind[2]
Aval = A.val
if iszero(beta)
y .= Aval .* view(B, ind2, :) .* alpha
else
y .= Aval .* view(B, ind2, :) .* alpha .+ y .* beta
end
C
end
function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::Diagonal, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(getindex_value(A))
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
return C
end
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:A.ind[1]-1, :) .*= beta
view(C, A.ind[1]+1:size(C,1), :) .*= beta
end
ABα = A * B * alpha
nzrow, nzcol = A.ind
if iszero(beta)
C[A.ind...] = ABα[A.ind...]
else
y = view(C, nzrow, :)
y .= view(ABα, nzrow, :) .+ y .* beta
end
C
end

function _mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractVector, alpha, beta)
check_matmul_sizes(C, A, B)
if iszero(getindex_value(A))
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
return C
end
nzrow, nzcol = A.ind
if iszero(beta)
C .= zero(eltype(C))
else
view(C, 1:nzrow-1) .*= beta
view(C, nzrow+1:size(C,1)) .*= beta
end
Aval = A.val
if iszero(beta)
C[nzrow] = Aval * B[nzcol] * alpha
else
C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta
end
C
end

for MT in (:StridedMatrix, :(Transpose{<:Any, <:StridedMatrix}), :(Adjoint{<:Any, <:StridedMatrix}))
@eval function mul!(y::StridedVector, A::$MT, x::OneElementVector, alpha::Number, beta::Number)
_mulonel!(y, A, x, alpha, beta)
_mul!(y, A, x, alpha, beta)
end
end
for MT in (:StridedMatrix, :(Transpose{<:Any, <:StridedMatrix}), :(Adjoint{<:Any, <:StridedMatrix}),
:Diagonal)
@eval function mul!(C::StridedMatrix, A::$MT, B::OneElementMatrix, alpha::Number, beta::Number)
_mulonel!(C, A, B, alpha, beta)
_mul!(C, A, B, alpha, beta)
end
@eval function mul!(C::StridedMatrix, A::OneElementMatrix, B::$MT, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end
end
function mul!(C::StridedVector, A::OneElementMatrix, B::StridedVector, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end

function mul!(y::AbstractVector, A::AbstractFillMatrix, x::OneElementVector, alpha::Number, beta::Number)
_mulonel!(y, A, x, alpha, beta)
_mul!(y, A, x, alpha, beta)
end
function mul!(C::AbstractMatrix, A::AbstractFillMatrix, B::OneElementMatrix, alpha::Number, beta::Number)
_mulonel!(C, A, B, alpha, beta)
_mul!(C, A, B, alpha, beta)
end
function mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractFillVector, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end
function mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractFillMatrix, alpha::Number, beta::Number)
_mul!(C, A, B, alpha, beta)
end

# adjoint/transpose
Expand Down
87 changes: 82 additions & 5 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2113,14 +2113,16 @@ end

@testset "matmul" begin
A = reshape(Float64[1:9;], 3, 3)
v = reshape(Float64[1:3;], 3)
testinds(w::AbstractArray) = testinds(size(w))
testinds(szw::Tuple{Int}) = (szw .- 1, szw .+ 1)
function testinds(szA::Tuple{Int,Int})
(szA .- 1, szA .+ (-1,0), szA .+ (0,-1), szA .+ 1, szA .+ (1,-1), szA .+ (-1,1))
end
function test_A_mul_OneElement(A, (w, w2))
@testset for ind in testinds(w)
x = OneElement(3, ind, size(w))
# test matvec if w is a vector, or matmat if w is a matrix
function test_mat_mul_OneElement(A, (w, w2), sz)
@testset for ind in testinds(sz)
x = OneElement(3, ind, sz)
xarr = Array(x)
Axarr = A * xarr
Aadjxarr = A' * xarr
Expand All @@ -2143,26 +2145,84 @@ end
@test mul!(w2, F, x, 1.0, 1.0) ≈ Array(F) * xarr .+ 1
end
end
function test_OneElementMatrix_mul_mat(A, (w, w2), sz)
@testset for ind in testinds(sz)
O = OneElement(3, ind, sz)
Oarr = Array(O)
OarrA = Oarr * A
OarrAadj = Oarr * A'

@test O * A ≈ OarrA
@test O * A' ≈ OarrAadj
@test O * transpose(A) ≈ Oarr * transpose(A)

@test mul!(w, O, A) ≈ OarrA
# check columnwise to ensure zero columns
@test all(((c1, c2),) -> c1 ≈ c2, zip(eachcol(w), eachcol(OarrA)))
@test mul!(w, O, A') ≈ OarrAadj
w .= 1
@test mul!(w, O, A, 1.0, 2.0) ≈ OarrA .+ 2
w .= 1
@test mul!(w, O, A', 1.0, 2.0) ≈ OarrAadj .+ 2

F = Fill(3, size(A))
w2 .= 1
@test mul!(w2, O, F, 1.0, 1.0) ≈ Oarr * Array(F) .+ 1
end
end
function test_OneElementMatrix_mul_vec(v, (w, w2), sz)
@testset for ind in testinds(sz)
O = OneElement(3, ind, sz)
Oarr = Array(O)
Oarrv = Oarr * v

@test O * v == Oarrv

@test mul!(w, O, v) == Oarrv
# check rowwise to ensure zero rows
@test all(((r1, r2),) -> r1 == r2, zip(eachrow(w), eachrow(Oarrv)))
w .= 1
@test mul!(w, O, v, 1.0, 2.0) == Oarrv .+ 2

F = Fill(3, size(v))
w2 .= 1
@test mul!(w2, O, F, 1.0, 1.0) == Oarr * Array(F) .+ 1
end
end
@testset "Matrix * OneElementVector" begin
w = zeros(size(A,1))
w2 = MVector{length(w)}(w)
test_A_mul_OneElement(A, (w, w2))
test_mat_mul_OneElement(A, (w, w2), size(w))
end
@testset "Matrix * OneElementMatrix" begin
C = zeros(size(A))
C2 = MMatrix{size(C)...}(C)
test_A_mul_OneElement(A, (C, C2))
test_mat_mul_OneElement(A, (C, C2), size(C))
end
@testset "OneElementMatrix * Vector" begin
w = zeros(size(v))
w2 = MVector{size(v)...}(v)
test_OneElementMatrix_mul_vec(v, (w, w2), size(A))
end
@testset "OneElementMatrix * Matrix" begin
C = zeros(size(A))
C2 = MMatrix{size(C)...}(C)
test_OneElementMatrix_mul_mat(A, (C, C2), size(A))
end
@testset "OneElementMatrix * OneElement" begin
@testset for ind in testinds(A)
O = OneElement(3, ind, size(A))
v = OneElement(4, ind[2], size(A,1))
@test O * v isa OneElement
@test O * v == Array(O) * Array(v)
@test mul!(ones(size(O,1)), O, v) == O * v
@test mul!(ones(size(O,1)), O, v, 2, 1) == 2 * O * v .+ 1

B = OneElement(4, ind, size(A))
@test O * B isa OneElement
@test O * B == Array(O) * Array(B)
@test mul!(ones(size(O,1), size(B,2)), O, B) == O * B
@test mul!(ones(size(O,1), size(B,2)), O, B, 2, 1) == 2 * O * B .+ 1
end

@test OneElement(3, (2,3), (5,4)) * OneElement(2, 2, 4) == Zeros(5)
Expand Down Expand Up @@ -2191,6 +2251,23 @@ end
B = Zeros(4)
@test A * B === Zeros(5)
end
@testset "Diagonal and OneElementMatrix" begin
for ind in ((2,3), (2,2), (10,10))
O = OneElement(3, ind, (4,3))
Oarr = Array(O)
C = zeros(size(O))
D = Diagonal(axes(O,1))
@test D * O == D * Oarr
@test mul!(C, D, O) == D * O
C .= 1
@test mul!(C, D, O, 2, 2) == 2 * D * O .+ 2
D = Diagonal(axes(O,2))
@test O * D == Oarr * D
@test mul!(C, O, D) == O * D
C .= 1
@test mul!(C, O, D, 2, 2) == 2 * O * D .+ 2
end
end
end

@testset "multiplication/division by a number" begin
Expand Down