diff --git a/src/oneelement.jl b/src/oneelement.jl index 46b2b471..4593a3f4 100644 --- a/src/oneelement.jl +++ b/src/oneelement.jl @@ -93,7 +93,58 @@ function *(A::OneElementMatrix, B::AbstractFillVector) OneElement(val, A.ind[1], size(A,1)) end -@inline function __mulonel!(y, A, x, alpha, beta) +# Special matrix types + +function *(A::OneElementMatrix, D::Diagonal) + check_matmul_sizes(A, D) + nzcol = A.ind[2] + val = if nzcol in axes(D,1) + A.val * D[nzcol, nzcol] + else + A.val * zero(eltype(D)) + end + OneElement(val, A.ind, size(A)) +end +function *(D::Diagonal, A::OneElementMatrix) + check_matmul_sizes(D, A) + nzrow = A.ind[1] + val = if nzrow in axes(D,2) + D[nzrow, nzrow] * A.val + else + zero(eltype(D)) * A.val + end + OneElement(val, A.ind, size(A)) +end + +# Inplace multiplication + +# We use this for out overloads for _mul! for OneElement because its more efficient +# due to how efficient 2 arg mul is when one or more of the args are OneElement +function __mulonel!(C, A, B, alpha, beta) + ABα = A * B * alpha + if iszero(beta) + C .= ABα + else + C .= ABα .+ C .* beta + end + return C +end +# These methods remove the ambituity in _mul!. This isn't strictly necessary, but this makes Aqua happy. +function _mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha, beta) + __mulonel!(C, A, B, alpha, beta) +end +function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::OneElementMatrix, alpha, beta) + __mulonel!(C, A, B, alpha, beta) +end + +function mul!(C::AbstractMatrix, A::OneElementMatrix, B::OneElementMatrix, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end +function mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end + +@inline function __mul!(y, A::AbstractMatrix, x::OneElement, alpha, beta) αx = alpha * x.val ind1 = x.ind[1] if iszero(beta) @@ -104,19 +155,19 @@ end return y end -function _mulonel!(y, A, x::OneElementVector, alpha::Number, beta::Number) +function _mul!(y::AbstractVector, A::AbstractMatrix, x::OneElementVector, alpha, beta) check_matmul_sizes(y, A, x) - if x.ind[1] ∉ axes(x,1) # in this case x is all zeros + if iszero(getindex_value(x)) mul!(y, A, Zeros{eltype(x)}(axes(x)), alpha, beta) return y end - __mulonel!(y, A, x, alpha, beta) + __mul!(y, A, x, alpha, beta) y end -function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number) +function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::OneElementMatrix, alpha, beta) check_matmul_sizes(C, A, B) - if B.ind[1] ∉ axes(B,1) || B.ind[2] ∉ axes(B,2) # in this case x is all zeros + if iszero(getindex_value(B)) mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta) return C end @@ -127,24 +178,128 @@ function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number) view(C, :, B.ind[2]+1:size(C,2)) .*= beta end y = view(C, :, B.ind[2]) - __mulonel!(y, A, B, alpha, beta) + __mul!(y, A, B, alpha, beta) + C +end +function _mul!(C::AbstractMatrix, A::Diagonal, B::OneElementMatrix, alpha, beta) + check_matmul_sizes(C, A, B) + if iszero(getindex_value(B)) + mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta) + return C + end + if iszero(beta) + C .= zero(eltype(C)) + else + view(C, :, 1:B.ind[2]-1) .*= beta + view(C, :, B.ind[2]+1:size(C,2)) .*= beta + end + ABα = A * B * alpha + nzrow, nzcol = B.ind + if iszero(beta) + C[B.ind...] = ABα[B.ind...] + else + y = view(C, :, nzcol) + y .= view(ABα, :, nzcol) .+ y .* beta + end + C +end + +function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha, beta) + check_matmul_sizes(C, A, B) + if iszero(getindex_value(A)) + mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta) + return C + end + if iszero(beta) + C .= zero(eltype(C)) + else + view(C, 1:A.ind[1]-1, :) .*= beta + view(C, A.ind[1]+1:size(C,1), :) .*= beta + end + y = view(C, A.ind[1], :) + ind2 = A.ind[2] + Aval = A.val + if iszero(beta) + y .= Aval .* view(B, ind2, :) .* alpha + else + y .= Aval .* view(B, ind2, :) .* alpha .+ y .* beta + end + C +end +function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::Diagonal, alpha, beta) + check_matmul_sizes(C, A, B) + if iszero(getindex_value(A)) + mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta) + return C + end + if iszero(beta) + C .= zero(eltype(C)) + else + view(C, 1:A.ind[1]-1, :) .*= beta + view(C, A.ind[1]+1:size(C,1), :) .*= beta + end + ABα = A * B * alpha + nzrow, nzcol = A.ind + if iszero(beta) + C[A.ind...] = ABα[A.ind...] + else + y = view(C, nzrow, :) + y .= view(ABα, nzrow, :) .+ y .* beta + end + C +end + +function _mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractVector, alpha, beta) + check_matmul_sizes(C, A, B) + if iszero(getindex_value(A)) + mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta) + return C + end + nzrow, nzcol = A.ind + if iszero(beta) + C .= zero(eltype(C)) + else + view(C, 1:nzrow-1) .*= beta + view(C, nzrow+1:size(C,1)) .*= beta + end + Aval = A.val + if iszero(beta) + C[nzrow] = Aval * B[nzcol] * alpha + else + C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta + end C end for MT in (:StridedMatrix, :(Transpose{<:Any, <:StridedMatrix}), :(Adjoint{<:Any, <:StridedMatrix})) @eval function mul!(y::StridedVector, A::$MT, x::OneElementVector, alpha::Number, beta::Number) - _mulonel!(y, A, x, alpha, beta) + _mul!(y, A, x, alpha, beta) end +end +for MT in (:StridedMatrix, :(Transpose{<:Any, <:StridedMatrix}), :(Adjoint{<:Any, <:StridedMatrix}), + :Diagonal) @eval function mul!(C::StridedMatrix, A::$MT, B::OneElementMatrix, alpha::Number, beta::Number) - _mulonel!(C, A, B, alpha, beta) + _mul!(C, A, B, alpha, beta) end + @eval function mul!(C::StridedMatrix, A::OneElementMatrix, B::$MT, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) + end +end +function mul!(C::StridedVector, A::OneElementMatrix, B::StridedVector, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) end function mul!(y::AbstractVector, A::AbstractFillMatrix, x::OneElementVector, alpha::Number, beta::Number) - _mulonel!(y, A, x, alpha, beta) + _mul!(y, A, x, alpha, beta) end function mul!(C::AbstractMatrix, A::AbstractFillMatrix, B::OneElementMatrix, alpha::Number, beta::Number) - _mulonel!(C, A, B, alpha, beta) + _mul!(C, A, B, alpha, beta) +end +function mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractFillVector, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) +end +function mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractFillMatrix, alpha::Number, beta::Number) + _mul!(C, A, B, alpha, beta) end # adjoint/transpose diff --git a/test/runtests.jl b/test/runtests.jl index be4b8fec..0c2cced9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2113,14 +2113,16 @@ end @testset "matmul" begin A = reshape(Float64[1:9;], 3, 3) + v = reshape(Float64[1:3;], 3) testinds(w::AbstractArray) = testinds(size(w)) testinds(szw::Tuple{Int}) = (szw .- 1, szw .+ 1) function testinds(szA::Tuple{Int,Int}) (szA .- 1, szA .+ (-1,0), szA .+ (0,-1), szA .+ 1, szA .+ (1,-1), szA .+ (-1,1)) end - function test_A_mul_OneElement(A, (w, w2)) - @testset for ind in testinds(w) - x = OneElement(3, ind, size(w)) + # test matvec if w is a vector, or matmat if w is a matrix + function test_mat_mul_OneElement(A, (w, w2), sz) + @testset for ind in testinds(sz) + x = OneElement(3, ind, sz) xarr = Array(x) Axarr = A * xarr Aadjxarr = A' * xarr @@ -2143,15 +2145,69 @@ end @test mul!(w2, F, x, 1.0, 1.0) ≈ Array(F) * xarr .+ 1 end end + function test_OneElementMatrix_mul_mat(A, (w, w2), sz) + @testset for ind in testinds(sz) + O = OneElement(3, ind, sz) + Oarr = Array(O) + OarrA = Oarr * A + OarrAadj = Oarr * A' + + @test O * A ≈ OarrA + @test O * A' ≈ OarrAadj + @test O * transpose(A) ≈ Oarr * transpose(A) + + @test mul!(w, O, A) ≈ OarrA + # check columnwise to ensure zero columns + @test all(((c1, c2),) -> c1 ≈ c2, zip(eachcol(w), eachcol(OarrA))) + @test mul!(w, O, A') ≈ OarrAadj + w .= 1 + @test mul!(w, O, A, 1.0, 2.0) ≈ OarrA .+ 2 + w .= 1 + @test mul!(w, O, A', 1.0, 2.0) ≈ OarrAadj .+ 2 + + F = Fill(3, size(A)) + w2 .= 1 + @test mul!(w2, O, F, 1.0, 1.0) ≈ Oarr * Array(F) .+ 1 + end + end + function test_OneElementMatrix_mul_vec(v, (w, w2), sz) + @testset for ind in testinds(sz) + O = OneElement(3, ind, sz) + Oarr = Array(O) + Oarrv = Oarr * v + + @test O * v == Oarrv + + @test mul!(w, O, v) == Oarrv + # check rowwise to ensure zero rows + @test all(((r1, r2),) -> r1 == r2, zip(eachrow(w), eachrow(Oarrv))) + w .= 1 + @test mul!(w, O, v, 1.0, 2.0) == Oarrv .+ 2 + + F = Fill(3, size(v)) + w2 .= 1 + @test mul!(w2, O, F, 1.0, 1.0) == Oarr * Array(F) .+ 1 + end + end @testset "Matrix * OneElementVector" begin w = zeros(size(A,1)) w2 = MVector{length(w)}(w) - test_A_mul_OneElement(A, (w, w2)) + test_mat_mul_OneElement(A, (w, w2), size(w)) end @testset "Matrix * OneElementMatrix" begin C = zeros(size(A)) C2 = MMatrix{size(C)...}(C) - test_A_mul_OneElement(A, (C, C2)) + test_mat_mul_OneElement(A, (C, C2), size(C)) + end + @testset "OneElementMatrix * Vector" begin + w = zeros(size(v)) + w2 = MVector{size(v)...}(v) + test_OneElementMatrix_mul_vec(v, (w, w2), size(A)) + end + @testset "OneElementMatrix * Matrix" begin + C = zeros(size(A)) + C2 = MMatrix{size(C)...}(C) + test_OneElementMatrix_mul_mat(A, (C, C2), size(A)) end @testset "OneElementMatrix * OneElement" begin @testset for ind in testinds(A) @@ -2159,10 +2215,14 @@ end v = OneElement(4, ind[2], size(A,1)) @test O * v isa OneElement @test O * v == Array(O) * Array(v) + @test mul!(ones(size(O,1)), O, v) == O * v + @test mul!(ones(size(O,1)), O, v, 2, 1) == 2 * O * v .+ 1 B = OneElement(4, ind, size(A)) @test O * B isa OneElement @test O * B == Array(O) * Array(B) + @test mul!(ones(size(O,1), size(B,2)), O, B) == O * B + @test mul!(ones(size(O,1), size(B,2)), O, B, 2, 1) == 2 * O * B .+ 1 end @test OneElement(3, (2,3), (5,4)) * OneElement(2, 2, 4) == Zeros(5) @@ -2191,6 +2251,23 @@ end B = Zeros(4) @test A * B === Zeros(5) end + @testset "Diagonal and OneElementMatrix" begin + for ind in ((2,3), (2,2), (10,10)) + O = OneElement(3, ind, (4,3)) + Oarr = Array(O) + C = zeros(size(O)) + D = Diagonal(axes(O,1)) + @test D * O == D * Oarr + @test mul!(C, D, O) == D * O + C .= 1 + @test mul!(C, D, O, 2, 2) == 2 * D * O .+ 2 + D = Diagonal(axes(O,2)) + @test O * D == Oarr * D + @test mul!(C, O, D) == O * D + C .= 1 + @test mul!(C, O, D, 2, 2) == 2 * O * D .+ 2 + end + end end @testset "multiplication/division by a number" begin