Skip to content

Commit 99278fa

Browse files
jishnuboxinabox
andauthored
Add missing matrix multiplication methods involving OneElement (#347)
* Add missing matrix multiplication methods involving OneElement * multiplications with Diagonal * Add suggested comment to `__muloneel!` Co-authored-by: Frames White <[email protected]> --------- Co-authored-by: Frames White <[email protected]>
1 parent cff6c44 commit 99278fa

File tree

2 files changed

+248
-16
lines changed

2 files changed

+248
-16
lines changed

src/oneelement.jl

Lines changed: 166 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,58 @@ function *(A::OneElementMatrix, B::AbstractFillVector)
9393
OneElement(val, A.ind[1], size(A,1))
9494
end
9595

96-
@inline function __mulonel!(y, A, x, alpha, beta)
96+
# Special matrix types
97+
98+
function *(A::OneElementMatrix, D::Diagonal)
99+
check_matmul_sizes(A, D)
100+
nzcol = A.ind[2]
101+
val = if nzcol in axes(D,1)
102+
A.val * D[nzcol, nzcol]
103+
else
104+
A.val * zero(eltype(D))
105+
end
106+
OneElement(val, A.ind, size(A))
107+
end
108+
function *(D::Diagonal, A::OneElementMatrix)
109+
check_matmul_sizes(D, A)
110+
nzrow = A.ind[1]
111+
val = if nzrow in axes(D,2)
112+
D[nzrow, nzrow] * A.val
113+
else
114+
zero(eltype(D)) * A.val
115+
end
116+
OneElement(val, A.ind, size(A))
117+
end
118+
119+
# Inplace multiplication
120+
121+
# We use this for out overloads for _mul! for OneElement because its more efficient
122+
# due to how efficient 2 arg mul is when one or more of the args are OneElement
123+
function __mulonel!(C, A, B, alpha, beta)
124+
ABα = A * B * alpha
125+
if iszero(beta)
126+
C .= ABα
127+
else
128+
C .= ABα .+ C .* beta
129+
end
130+
return C
131+
end
132+
# These methods remove the ambituity in _mul!. This isn't strictly necessary, but this makes Aqua happy.
133+
function _mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha, beta)
134+
__mulonel!(C, A, B, alpha, beta)
135+
end
136+
function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::OneElementMatrix, alpha, beta)
137+
__mulonel!(C, A, B, alpha, beta)
138+
end
139+
140+
function mul!(C::AbstractMatrix, A::OneElementMatrix, B::OneElementMatrix, alpha::Number, beta::Number)
141+
_mul!(C, A, B, alpha, beta)
142+
end
143+
function mul!(C::AbstractVector, A::OneElementMatrix, B::OneElementVector, alpha::Number, beta::Number)
144+
_mul!(C, A, B, alpha, beta)
145+
end
146+
147+
@inline function __mul!(y, A::AbstractMatrix, x::OneElement, alpha, beta)
97148
αx = alpha * x.val
98149
ind1 = x.ind[1]
99150
if iszero(beta)
@@ -104,19 +155,19 @@ end
104155
return y
105156
end
106157

107-
function _mulonel!(y, A, x::OneElementVector, alpha::Number, beta::Number)
158+
function _mul!(y::AbstractVector, A::AbstractMatrix, x::OneElementVector, alpha, beta)
108159
check_matmul_sizes(y, A, x)
109-
if x.ind[1] axes(x,1) # in this case x is all zeros
160+
if iszero(getindex_value(x))
110161
mul!(y, A, Zeros{eltype(x)}(axes(x)), alpha, beta)
111162
return y
112163
end
113-
__mulonel!(y, A, x, alpha, beta)
164+
__mul!(y, A, x, alpha, beta)
114165
y
115166
end
116167

117-
function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number)
168+
function _mul!(C::AbstractMatrix, A::AbstractMatrix, B::OneElementMatrix, alpha, beta)
118169
check_matmul_sizes(C, A, B)
119-
if B.ind[1] axes(B,1) || B.ind[2] axes(B,2) # in this case x is all zeros
170+
if iszero(getindex_value(B))
120171
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
121172
return C
122173
end
@@ -127,24 +178,128 @@ function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number)
127178
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
128179
end
129180
y = view(C, :, B.ind[2])
130-
__mulonel!(y, A, B, alpha, beta)
181+
__mul!(y, A, B, alpha, beta)
182+
C
183+
end
184+
function _mul!(C::AbstractMatrix, A::Diagonal, B::OneElementMatrix, alpha, beta)
185+
check_matmul_sizes(C, A, B)
186+
if iszero(getindex_value(B))
187+
mul!(C, A, Zeros{eltype(B)}(axes(B)), alpha, beta)
188+
return C
189+
end
190+
if iszero(beta)
191+
C .= zero(eltype(C))
192+
else
193+
view(C, :, 1:B.ind[2]-1) .*= beta
194+
view(C, :, B.ind[2]+1:size(C,2)) .*= beta
195+
end
196+
ABα = A * B * alpha
197+
nzrow, nzcol = B.ind
198+
if iszero(beta)
199+
C[B.ind...] = ABα[B.ind...]
200+
else
201+
y = view(C, :, nzcol)
202+
y .= view(ABα, :, nzcol) .+ y .* beta
203+
end
204+
C
205+
end
206+
207+
function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractMatrix, alpha, beta)
208+
check_matmul_sizes(C, A, B)
209+
if iszero(getindex_value(A))
210+
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
211+
return C
212+
end
213+
if iszero(beta)
214+
C .= zero(eltype(C))
215+
else
216+
view(C, 1:A.ind[1]-1, :) .*= beta
217+
view(C, A.ind[1]+1:size(C,1), :) .*= beta
218+
end
219+
y = view(C, A.ind[1], :)
220+
ind2 = A.ind[2]
221+
Aval = A.val
222+
if iszero(beta)
223+
y .= Aval .* view(B, ind2, :) .* alpha
224+
else
225+
y .= Aval .* view(B, ind2, :) .* alpha .+ y .* beta
226+
end
227+
C
228+
end
229+
function _mul!(C::AbstractMatrix, A::OneElementMatrix, B::Diagonal, alpha, beta)
230+
check_matmul_sizes(C, A, B)
231+
if iszero(getindex_value(A))
232+
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
233+
return C
234+
end
235+
if iszero(beta)
236+
C .= zero(eltype(C))
237+
else
238+
view(C, 1:A.ind[1]-1, :) .*= beta
239+
view(C, A.ind[1]+1:size(C,1), :) .*= beta
240+
end
241+
ABα = A * B * alpha
242+
nzrow, nzcol = A.ind
243+
if iszero(beta)
244+
C[A.ind...] = ABα[A.ind...]
245+
else
246+
y = view(C, nzrow, :)
247+
y .= view(ABα, nzrow, :) .+ y .* beta
248+
end
249+
C
250+
end
251+
252+
function _mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractVector, alpha, beta)
253+
check_matmul_sizes(C, A, B)
254+
if iszero(getindex_value(A))
255+
mul!(C, Zeros{eltype(A)}(axes(A)), B, alpha, beta)
256+
return C
257+
end
258+
nzrow, nzcol = A.ind
259+
if iszero(beta)
260+
C .= zero(eltype(C))
261+
else
262+
view(C, 1:nzrow-1) .*= beta
263+
view(C, nzrow+1:size(C,1)) .*= beta
264+
end
265+
Aval = A.val
266+
if iszero(beta)
267+
C[nzrow] = Aval * B[nzcol] * alpha
268+
else
269+
C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta
270+
end
131271
C
132272
end
133273

134274
for MT in (:StridedMatrix, :(Transpose{<:Any, <:StridedMatrix}), :(Adjoint{<:Any, <:StridedMatrix}))
135275
@eval function mul!(y::StridedVector, A::$MT, x::OneElementVector, alpha::Number, beta::Number)
136-
_mulonel!(y, A, x, alpha, beta)
276+
_mul!(y, A, x, alpha, beta)
137277
end
278+
end
279+
for MT in (:StridedMatrix, :(Transpose{<:Any, <:StridedMatrix}), :(Adjoint{<:Any, <:StridedMatrix}),
280+
:Diagonal)
138281
@eval function mul!(C::StridedMatrix, A::$MT, B::OneElementMatrix, alpha::Number, beta::Number)
139-
_mulonel!(C, A, B, alpha, beta)
282+
_mul!(C, A, B, alpha, beta)
140283
end
284+
@eval function mul!(C::StridedMatrix, A::OneElementMatrix, B::$MT, alpha::Number, beta::Number)
285+
_mul!(C, A, B, alpha, beta)
286+
end
287+
end
288+
function mul!(C::StridedVector, A::OneElementMatrix, B::StridedVector, alpha::Number, beta::Number)
289+
_mul!(C, A, B, alpha, beta)
141290
end
142291

143292
function mul!(y::AbstractVector, A::AbstractFillMatrix, x::OneElementVector, alpha::Number, beta::Number)
144-
_mulonel!(y, A, x, alpha, beta)
293+
_mul!(y, A, x, alpha, beta)
145294
end
146295
function mul!(C::AbstractMatrix, A::AbstractFillMatrix, B::OneElementMatrix, alpha::Number, beta::Number)
147-
_mulonel!(C, A, B, alpha, beta)
296+
_mul!(C, A, B, alpha, beta)
297+
end
298+
function mul!(C::AbstractVector, A::OneElementMatrix, B::AbstractFillVector, alpha::Number, beta::Number)
299+
_mul!(C, A, B, alpha, beta)
300+
end
301+
function mul!(C::AbstractMatrix, A::OneElementMatrix, B::AbstractFillMatrix, alpha::Number, beta::Number)
302+
_mul!(C, A, B, alpha, beta)
148303
end
149304

150305
# adjoint/transpose

test/runtests.jl

Lines changed: 82 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2113,14 +2113,16 @@ end
21132113

21142114
@testset "matmul" begin
21152115
A = reshape(Float64[1:9;], 3, 3)
2116+
v = reshape(Float64[1:3;], 3)
21162117
testinds(w::AbstractArray) = testinds(size(w))
21172118
testinds(szw::Tuple{Int}) = (szw .- 1, szw .+ 1)
21182119
function testinds(szA::Tuple{Int,Int})
21192120
(szA .- 1, szA .+ (-1,0), szA .+ (0,-1), szA .+ 1, szA .+ (1,-1), szA .+ (-1,1))
21202121
end
2121-
function test_A_mul_OneElement(A, (w, w2))
2122-
@testset for ind in testinds(w)
2123-
x = OneElement(3, ind, size(w))
2122+
# test matvec if w is a vector, or matmat if w is a matrix
2123+
function test_mat_mul_OneElement(A, (w, w2), sz)
2124+
@testset for ind in testinds(sz)
2125+
x = OneElement(3, ind, sz)
21242126
xarr = Array(x)
21252127
Axarr = A * xarr
21262128
Aadjxarr = A' * xarr
@@ -2143,26 +2145,84 @@ end
21432145
@test mul!(w2, F, x, 1.0, 1.0) Array(F) * xarr .+ 1
21442146
end
21452147
end
2148+
function test_OneElementMatrix_mul_mat(A, (w, w2), sz)
2149+
@testset for ind in testinds(sz)
2150+
O = OneElement(3, ind, sz)
2151+
Oarr = Array(O)
2152+
OarrA = Oarr * A
2153+
OarrAadj = Oarr * A'
2154+
2155+
@test O * A OarrA
2156+
@test O * A' OarrAadj
2157+
@test O * transpose(A) Oarr * transpose(A)
2158+
2159+
@test mul!(w, O, A) OarrA
2160+
# check columnwise to ensure zero columns
2161+
@test all(((c1, c2),) -> c1 c2, zip(eachcol(w), eachcol(OarrA)))
2162+
@test mul!(w, O, A') OarrAadj
2163+
w .= 1
2164+
@test mul!(w, O, A, 1.0, 2.0) OarrA .+ 2
2165+
w .= 1
2166+
@test mul!(w, O, A', 1.0, 2.0) OarrAadj .+ 2
2167+
2168+
F = Fill(3, size(A))
2169+
w2 .= 1
2170+
@test mul!(w2, O, F, 1.0, 1.0) Oarr * Array(F) .+ 1
2171+
end
2172+
end
2173+
function test_OneElementMatrix_mul_vec(v, (w, w2), sz)
2174+
@testset for ind in testinds(sz)
2175+
O = OneElement(3, ind, sz)
2176+
Oarr = Array(O)
2177+
Oarrv = Oarr * v
2178+
2179+
@test O * v == Oarrv
2180+
2181+
@test mul!(w, O, v) == Oarrv
2182+
# check rowwise to ensure zero rows
2183+
@test all(((r1, r2),) -> r1 == r2, zip(eachrow(w), eachrow(Oarrv)))
2184+
w .= 1
2185+
@test mul!(w, O, v, 1.0, 2.0) == Oarrv .+ 2
2186+
2187+
F = Fill(3, size(v))
2188+
w2 .= 1
2189+
@test mul!(w2, O, F, 1.0, 1.0) == Oarr * Array(F) .+ 1
2190+
end
2191+
end
21462192
@testset "Matrix * OneElementVector" begin
21472193
w = zeros(size(A,1))
21482194
w2 = MVector{length(w)}(w)
2149-
test_A_mul_OneElement(A, (w, w2))
2195+
test_mat_mul_OneElement(A, (w, w2), size(w))
21502196
end
21512197
@testset "Matrix * OneElementMatrix" begin
21522198
C = zeros(size(A))
21532199
C2 = MMatrix{size(C)...}(C)
2154-
test_A_mul_OneElement(A, (C, C2))
2200+
test_mat_mul_OneElement(A, (C, C2), size(C))
2201+
end
2202+
@testset "OneElementMatrix * Vector" begin
2203+
w = zeros(size(v))
2204+
w2 = MVector{size(v)...}(v)
2205+
test_OneElementMatrix_mul_vec(v, (w, w2), size(A))
2206+
end
2207+
@testset "OneElementMatrix * Matrix" begin
2208+
C = zeros(size(A))
2209+
C2 = MMatrix{size(C)...}(C)
2210+
test_OneElementMatrix_mul_mat(A, (C, C2), size(A))
21552211
end
21562212
@testset "OneElementMatrix * OneElement" begin
21572213
@testset for ind in testinds(A)
21582214
O = OneElement(3, ind, size(A))
21592215
v = OneElement(4, ind[2], size(A,1))
21602216
@test O * v isa OneElement
21612217
@test O * v == Array(O) * Array(v)
2218+
@test mul!(ones(size(O,1)), O, v) == O * v
2219+
@test mul!(ones(size(O,1)), O, v, 2, 1) == 2 * O * v .+ 1
21622220

21632221
B = OneElement(4, ind, size(A))
21642222
@test O * B isa OneElement
21652223
@test O * B == Array(O) * Array(B)
2224+
@test mul!(ones(size(O,1), size(B,2)), O, B) == O * B
2225+
@test mul!(ones(size(O,1), size(B,2)), O, B, 2, 1) == 2 * O * B .+ 1
21662226
end
21672227

21682228
@test OneElement(3, (2,3), (5,4)) * OneElement(2, 2, 4) == Zeros(5)
@@ -2191,6 +2251,23 @@ end
21912251
B = Zeros(4)
21922252
@test A * B === Zeros(5)
21932253
end
2254+
@testset "Diagonal and OneElementMatrix" begin
2255+
for ind in ((2,3), (2,2), (10,10))
2256+
O = OneElement(3, ind, (4,3))
2257+
Oarr = Array(O)
2258+
C = zeros(size(O))
2259+
D = Diagonal(axes(O,1))
2260+
@test D * O == D * Oarr
2261+
@test mul!(C, D, O) == D * O
2262+
C .= 1
2263+
@test mul!(C, D, O, 2, 2) == 2 * D * O .+ 2
2264+
D = Diagonal(axes(O,2))
2265+
@test O * D == Oarr * D
2266+
@test mul!(C, O, D) == O * D
2267+
C .= 1
2268+
@test mul!(C, O, D, 2, 2) == 2 * O * D .+ 2
2269+
end
2270+
end
21942271
end
21952272

21962273
@testset "multiplication/division by a number" begin

0 commit comments

Comments
 (0)