@@ -93,7 +93,58 @@ function *(A::OneElementMatrix, B::AbstractFillVector)
9393 OneElement (val, A. ind[1 ], size (A,1 ))
9494end
9595
96- @inline function __mulonel! (y, A, x, alpha, beta)
96+ # Special matrix types
97+
98+ function * (A:: OneElementMatrix , D:: Diagonal )
99+ check_matmul_sizes (A, D)
100+ nzcol = A. ind[2 ]
101+ val = if nzcol in axes (D,1 )
102+ A. val * D[nzcol, nzcol]
103+ else
104+ A. val * zero (eltype (D))
105+ end
106+ OneElement (val, A. ind, size (A))
107+ end
108+ function * (D:: Diagonal , A:: OneElementMatrix )
109+ check_matmul_sizes (D, A)
110+ nzrow = A. ind[1 ]
111+ val = if nzrow in axes (D,2 )
112+ D[nzrow, nzrow] * A. val
113+ else
114+ zero (eltype (D)) * A. val
115+ end
116+ OneElement (val, A. ind, size (A))
117+ end
118+
119+ # Inplace multiplication
120+
121+ # We use this for out overloads for _mul! for OneElement because its more efficient
122+ # due to how efficient 2 arg mul is when one or more of the args are OneElement
123+ function __mulonel! (C, A, B, alpha, beta)
124+ ABα = A * B * alpha
125+ if iszero (beta)
126+ C .= ABα
127+ else
128+ C .= ABα .+ C .* beta
129+ end
130+ return C
131+ end
132+ # These methods remove the ambituity in _mul!. This isn't strictly necessary, but this makes Aqua happy.
133+ function _mul! (C:: AbstractVector , A:: OneElementMatrix , B:: OneElementVector , alpha, beta)
134+ __mulonel! (C, A, B, alpha, beta)
135+ end
136+ function _mul! (C:: AbstractMatrix , A:: OneElementMatrix , B:: OneElementMatrix , alpha, beta)
137+ __mulonel! (C, A, B, alpha, beta)
138+ end
139+
140+ function mul! (C:: AbstractMatrix , A:: OneElementMatrix , B:: OneElementMatrix , alpha:: Number , beta:: Number )
141+ _mul! (C, A, B, alpha, beta)
142+ end
143+ function mul! (C:: AbstractVector , A:: OneElementMatrix , B:: OneElementVector , alpha:: Number , beta:: Number )
144+ _mul! (C, A, B, alpha, beta)
145+ end
146+
147+ @inline function __mul! (y, A:: AbstractMatrix , x:: OneElement , alpha, beta)
97148 αx = alpha * x. val
98149 ind1 = x. ind[1 ]
99150 if iszero (beta)
@@ -104,19 +155,19 @@ end
104155 return y
105156end
106157
107- function _mulonel ! (y, A, x:: OneElementVector , alpha:: Number , beta:: Number )
158+ function _mul ! (y:: AbstractVector , A:: AbstractMatrix , x:: OneElementVector , alpha, beta)
108159 check_matmul_sizes (y, A, x)
109- if x . ind[ 1 ] ∉ axes (x, 1 ) # in this case x is all zeros
160+ if iszero ( getindex_value (x))
110161 mul! (y, A, Zeros {eltype(x)} (axes (x)), alpha, beta)
111162 return y
112163 end
113- __mulonel ! (y, A, x, alpha, beta)
164+ __mul ! (y, A, x, alpha, beta)
114165 y
115166end
116167
117- function _mulonel ! (C, A, B:: OneElementMatrix , alpha:: Number , beta:: Number )
168+ function _mul ! (C:: AbstractMatrix , A:: AbstractMatrix , B:: OneElementMatrix , alpha, beta)
118169 check_matmul_sizes (C, A, B)
119- if B . ind[ 1 ] ∉ axes (B, 1 ) || B . ind[ 2 ] ∉ axes (B, 2 ) # in this case x is all zeros
170+ if iszero ( getindex_value (B))
120171 mul! (C, A, Zeros {eltype(B)} (axes (B)), alpha, beta)
121172 return C
122173 end
@@ -127,24 +178,128 @@ function _mulonel!(C, A, B::OneElementMatrix, alpha::Number, beta::Number)
127178 view (C, :, B. ind[2 ]+ 1 : size (C,2 )) .*= beta
128179 end
129180 y = view (C, :, B. ind[2 ])
130- __mulonel! (y, A, B, alpha, beta)
181+ __mul! (y, A, B, alpha, beta)
182+ C
183+ end
184+ function _mul! (C:: AbstractMatrix , A:: Diagonal , B:: OneElementMatrix , alpha, beta)
185+ check_matmul_sizes (C, A, B)
186+ if iszero (getindex_value (B))
187+ mul! (C, A, Zeros {eltype(B)} (axes (B)), alpha, beta)
188+ return C
189+ end
190+ if iszero (beta)
191+ C .= zero (eltype (C))
192+ else
193+ view (C, :, 1 : B. ind[2 ]- 1 ) .*= beta
194+ view (C, :, B. ind[2 ]+ 1 : size (C,2 )) .*= beta
195+ end
196+ ABα = A * B * alpha
197+ nzrow, nzcol = B. ind
198+ if iszero (beta)
199+ C[B. ind... ] = ABα[B. ind... ]
200+ else
201+ y = view (C, :, nzcol)
202+ y .= view (ABα, :, nzcol) .+ y .* beta
203+ end
204+ C
205+ end
206+
207+ function _mul! (C:: AbstractMatrix , A:: OneElementMatrix , B:: AbstractMatrix , alpha, beta)
208+ check_matmul_sizes (C, A, B)
209+ if iszero (getindex_value (A))
210+ mul! (C, Zeros {eltype(A)} (axes (A)), B, alpha, beta)
211+ return C
212+ end
213+ if iszero (beta)
214+ C .= zero (eltype (C))
215+ else
216+ view (C, 1 : A. ind[1 ]- 1 , :) .*= beta
217+ view (C, A. ind[1 ]+ 1 : size (C,1 ), :) .*= beta
218+ end
219+ y = view (C, A. ind[1 ], :)
220+ ind2 = A. ind[2 ]
221+ Aval = A. val
222+ if iszero (beta)
223+ y .= Aval .* view (B, ind2, :) .* alpha
224+ else
225+ y .= Aval .* view (B, ind2, :) .* alpha .+ y .* beta
226+ end
227+ C
228+ end
229+ function _mul! (C:: AbstractMatrix , A:: OneElementMatrix , B:: Diagonal , alpha, beta)
230+ check_matmul_sizes (C, A, B)
231+ if iszero (getindex_value (A))
232+ mul! (C, Zeros {eltype(A)} (axes (A)), B, alpha, beta)
233+ return C
234+ end
235+ if iszero (beta)
236+ C .= zero (eltype (C))
237+ else
238+ view (C, 1 : A. ind[1 ]- 1 , :) .*= beta
239+ view (C, A. ind[1 ]+ 1 : size (C,1 ), :) .*= beta
240+ end
241+ ABα = A * B * alpha
242+ nzrow, nzcol = A. ind
243+ if iszero (beta)
244+ C[A. ind... ] = ABα[A. ind... ]
245+ else
246+ y = view (C, nzrow, :)
247+ y .= view (ABα, nzrow, :) .+ y .* beta
248+ end
249+ C
250+ end
251+
252+ function _mul! (C:: AbstractVector , A:: OneElementMatrix , B:: AbstractVector , alpha, beta)
253+ check_matmul_sizes (C, A, B)
254+ if iszero (getindex_value (A))
255+ mul! (C, Zeros {eltype(A)} (axes (A)), B, alpha, beta)
256+ return C
257+ end
258+ nzrow, nzcol = A. ind
259+ if iszero (beta)
260+ C .= zero (eltype (C))
261+ else
262+ view (C, 1 : nzrow- 1 ) .*= beta
263+ view (C, nzrow+ 1 : size (C,1 )) .*= beta
264+ end
265+ Aval = A. val
266+ if iszero (beta)
267+ C[nzrow] = Aval * B[nzcol] * alpha
268+ else
269+ C[nzrow] = Aval * B[nzcol] * alpha + C[nzrow] * beta
270+ end
131271 C
132272end
133273
134274for MT in (:StridedMatrix , :(Transpose{<: Any , <: StridedMatrix }), :(Adjoint{<: Any , <: StridedMatrix }))
135275 @eval function mul! (y:: StridedVector , A:: $MT , x:: OneElementVector , alpha:: Number , beta:: Number )
136- _mulonel ! (y, A, x, alpha, beta)
276+ _mul ! (y, A, x, alpha, beta)
137277 end
278+ end
279+ for MT in (:StridedMatrix , :(Transpose{<: Any , <: StridedMatrix }), :(Adjoint{<: Any , <: StridedMatrix }),
280+ :Diagonal )
138281 @eval function mul! (C:: StridedMatrix , A:: $MT , B:: OneElementMatrix , alpha:: Number , beta:: Number )
139- _mulonel ! (C, A, B, alpha, beta)
282+ _mul ! (C, A, B, alpha, beta)
140283 end
284+ @eval function mul! (C:: StridedMatrix , A:: OneElementMatrix , B:: $MT , alpha:: Number , beta:: Number )
285+ _mul! (C, A, B, alpha, beta)
286+ end
287+ end
288+ function mul! (C:: StridedVector , A:: OneElementMatrix , B:: StridedVector , alpha:: Number , beta:: Number )
289+ _mul! (C, A, B, alpha, beta)
141290end
142291
143292function mul! (y:: AbstractVector , A:: AbstractFillMatrix , x:: OneElementVector , alpha:: Number , beta:: Number )
144- _mulonel ! (y, A, x, alpha, beta)
293+ _mul ! (y, A, x, alpha, beta)
145294end
146295function mul! (C:: AbstractMatrix , A:: AbstractFillMatrix , B:: OneElementMatrix , alpha:: Number , beta:: Number )
147- _mulonel! (C, A, B, alpha, beta)
296+ _mul! (C, A, B, alpha, beta)
297+ end
298+ function mul! (C:: AbstractVector , A:: OneElementMatrix , B:: AbstractFillVector , alpha:: Number , beta:: Number )
299+ _mul! (C, A, B, alpha, beta)
300+ end
301+ function mul! (C:: AbstractMatrix , A:: OneElementMatrix , B:: AbstractFillMatrix , alpha:: Number , beta:: Number )
302+ _mul! (C, A, B, alpha, beta)
148303end
149304
150305# adjoint/transpose
0 commit comments