@@ -198,19 +198,37 @@ Base.literal_pow(::typeof(^), D::Diagonal, valp::Val) =
198198 Diagonal (Base. literal_pow .(^ , D. diag, valp)) # for speed
199199Base. literal_pow (:: typeof (^ ), D:: Diagonal , :: Val{-1} ) = inv (D) # for disambiguation
200200
201+ function _muldiag_size_check (A, B)
202+ nA = size (A, 2 )
203+ mB = size (B, 1 )
204+ @noinline throw_dimerr (:: AbstractMatrix , nA, mB) = throw (DimensionMismatch (" second dimension of A, $nA , does not match first dimension of B, $mB " ))
205+ @noinline throw_dimerr (:: AbstractVector , nA, mB) = throw (DimensionMismatch (" second dimension of D, $nA , does not match length of V, $mB " ))
206+ nA == mB || throw_dimerr (B, nA, mB)
207+ return nothing
208+ end
209+ # the output matrix should have the same size as the non-diagonal input matrix or vector
210+ @noinline throw_dimerr (szC, szA) = throw (DimensionMismatch (" output matrix has size: $szC , but should have size $szA " ))
211+ _size_check_out (C, :: Diagonal , A) = _size_check_out (C, A)
212+ _size_check_out (C, A, :: Diagonal ) = _size_check_out (C, A)
213+ _size_check_out (C, A:: Diagonal , :: Diagonal ) = _size_check_out (C, A)
214+ function _size_check_out (C, A)
215+ szA = size (A)
216+ szC = size (C)
217+ szA == szC || throw_dimerr (szC, szA)
218+ return nothing
219+ end
220+ function _muldiag_size_check (C, A, B)
221+ _muldiag_size_check (A, B)
222+ _size_check_out (C, A, B)
223+ end
224+
201225function (* )(Da:: Diagonal , Db:: Diagonal )
202- nDa, mDb = size (Da, 2 ), size (Db, 1 )
203- if nDa != mDb
204- throw (DimensionMismatch (" second dimension of Da, $nDa , does not match first dimension of Db, $mDb " ))
205- end
226+ _muldiag_size_check (Da, Db)
206227 return Diagonal (Da. diag .* Db. diag)
207228end
208229
209230function (* )(D:: Diagonal , V:: AbstractVector )
210- nD = size (D, 2 )
211- if nD != length (V)
212- throw (DimensionMismatch (" second dimension of D, $nD , does not match length of V, $(length (V)) " ))
213- end
231+ _muldiag_size_check (D, V)
214232 return D. diag .* V
215233end
216234
@@ -220,29 +238,12 @@ end
220238 lmul! (D, copy_oftype (B, promote_op (* , eltype (B), eltype (D. diag))))
221239
222240(* )(A:: AbstractMatrix , D:: Diagonal ) =
223- rmul! ( copy_similar (A, promote_op (* , eltype (A), eltype (D. diag))) , D)
241+ mul! ( similar (A, promote_op (* , eltype (A), eltype (D. diag)), size (A)), A , D)
224242(* )(D:: Diagonal , A:: AbstractMatrix ) =
225- lmul! (D, copy_similar (A, promote_op (* , eltype (A), eltype (D. diag))) )
243+ mul! ( similar (A, promote_op (* , eltype (A), eltype (D. diag)), size (A)), D, A )
226244
227- function rmul! (A:: AbstractMatrix , D:: Diagonal )
228- require_one_based_indexing (A)
229- nA, nD = size (A, 2 ), length (D. diag)
230- if nA != nD
231- throw (DimensionMismatch (" second dimension of A, $nA , does not match the first of D, $nD " ))
232- end
233- A .= A .* permutedims (D. diag)
234- return A
235- end
236-
237- function lmul! (D:: Diagonal , B:: AbstractVecOrMat )
238- require_one_based_indexing (B)
239- nB, nD = size (B, 1 ), length (D. diag)
240- if nB != nD
241- throw (DimensionMismatch (" second dimension of D, $nD , does not match the first of B, $nB " ))
242- end
243- B .= D. diag .* B
244- return B
245- end
245+ rmul! (A:: AbstractMatrix , D:: Diagonal ) = mul! (A, A, D)
246+ lmul! (D:: Diagonal , B:: AbstractVecOrMat ) = mul! (B, D, B)
246247
247248rmul! (A:: Union{LowerTriangular,UpperTriangular} , D:: Diagonal ) = typeof (A)(rmul! (A. data, D))
248249function rmul! (A:: UnitLowerTriangular , D:: Diagonal )
@@ -306,37 +307,66 @@ function *(D::Diagonal, transA::Transpose{<:Any,<:AbstractMatrix})
306307 lmul! (D, At)
307308end
308309
309- rmul! (A:: Diagonal , B:: Diagonal ) = Diagonal (A. diag .*= B. diag)
310- lmul! (A:: Diagonal , B:: Diagonal ) = Diagonal (B. diag .= A. diag .* B. diag)
310+ @inline function __muldiag! (out, D:: Diagonal , B, alpha, beta)
311+ if iszero (beta)
312+ out .= (D. diag .* B) .* ₛ alpha
313+ else
314+ out .= (D. diag .* B) .* ₛ alpha .+ out .* beta
315+ end
316+ return out
317+ end
318+
319+ @inline function __muldiag! (out, A, D:: Diagonal , alpha, beta)
320+ if iszero (beta)
321+ out .= (A .* permutedims (D. diag)) .* ₛ alpha
322+ else
323+ out .= (A .* permutedims (D. diag)) .* ₛ alpha .+ out .* beta
324+ end
325+ return out
326+ end
327+
328+ @inline function __muldiag! (out:: Diagonal , D1:: Diagonal , D2:: Diagonal , alpha, beta)
329+ if iszero (beta)
330+ out. diag .= (D1. diag .* D2. diag) .* ₛ alpha
331+ else
332+ out. diag .= (D1. diag .* D2. diag) .* ₛ alpha .+ out. diag .* beta
333+ end
334+ return out
335+ end
336+
337+ # only needed for ambiguity resolution, as mul! is explicitly defined for these arguments
338+ @inline __muldiag! (out, D1:: Diagonal , D2:: Diagonal , alpha, beta) =
339+ mul! (out, D1, D2, alpha, beta)
340+
341+ @inline function _muldiag! (out, A, B, alpha, beta)
342+ _muldiag_size_check (out, A, B)
343+ __muldiag! (out, A, B, alpha, beta)
344+ return out
345+ end
311346
312347# Get ambiguous method if try to unify AbstractVector/AbstractMatrix here using AbstractVecOrMat
313- @inline mul! (out:: AbstractVector , A:: Diagonal , in:: AbstractVector , alpha:: Number , beta:: Number ) =
314- out .= (A. diag .* in) .* ₛ alpha .+ out .* ₛ beta
315- @inline mul! (out:: AbstractMatrix , A:: Diagonal , in:: AbstractMatrix , alpha:: Number , beta:: Number ) =
316- out .= (A. diag .* in) .* ₛ alpha .+ out .* ₛ beta
317- @inline mul! (out:: AbstractMatrix , A:: Diagonal , in:: Adjoint{<:Any,<:AbstractVecOrMat} ,
318- alpha:: Number , beta:: Number ) =
319- out .= (A. diag .* in) .* ₛ alpha .+ out .* ₛ beta
320- @inline mul! (out:: AbstractMatrix , A:: Diagonal , in:: Transpose{<:Any,<:AbstractVecOrMat} ,
321- alpha:: Number , beta:: Number ) =
322- out .= (A. diag .* in) .* ₛ alpha .+ out .* ₛ beta
323-
324- @inline mul! (out:: AbstractMatrix , in:: AbstractMatrix , A:: Diagonal , alpha:: Number , beta:: Number ) =
325- out .= (in .* permutedims (A. diag)) .* ₛ alpha .+ out .* ₛ beta
326- @inline mul! (out:: AbstractMatrix , in:: Adjoint{<:Any,<:AbstractVecOrMat} , A:: Diagonal ,
327- alpha:: Number , beta:: Number ) =
328- out .= (in .* permutedims (A. diag)) .* ₛ alpha .+ out .* ₛ beta
329- @inline mul! (out:: AbstractMatrix , in:: Transpose{<:Any,<:AbstractVecOrMat} , A:: Diagonal ,
330- alpha:: Number , beta:: Number ) =
331- out .= (in .* permutedims (A. diag)) .* ₛ alpha .+ out .* ₛ beta
348+ @inline mul! (out:: AbstractVector , D:: Diagonal , V:: AbstractVector , alpha:: Number , beta:: Number ) =
349+ _muldiag! (out, D, V, alpha, beta)
350+ @inline mul! (out:: AbstractMatrix , D:: Diagonal , B:: AbstractMatrix , alpha:: Number , beta:: Number ) =
351+ _muldiag! (out, D, B, alpha, beta)
352+ @inline mul! (out:: AbstractMatrix , D:: Diagonal , B:: Adjoint{<:Any,<:AbstractVecOrMat} ,
353+ alpha:: Number , beta:: Number ) = _muldiag! (out, D, B, alpha, beta)
354+ @inline mul! (out:: AbstractMatrix , D:: Diagonal , B:: Transpose{<:Any,<:AbstractVecOrMat} ,
355+ alpha:: Number , beta:: Number ) = _muldiag! (out, D, B, alpha, beta)
356+
357+ @inline mul! (out:: AbstractMatrix , A:: AbstractMatrix , D:: Diagonal , alpha:: Number , beta:: Number ) =
358+ _muldiag! (out, A, D, alpha, beta)
359+ @inline mul! (out:: AbstractMatrix , A:: Adjoint{<:Any,<:AbstractVecOrMat} , D:: Diagonal ,
360+ alpha:: Number , beta:: Number ) = _muldiag! (out, A, D, alpha, beta)
361+ @inline mul! (out:: AbstractMatrix , A:: Transpose{<:Any,<:AbstractVecOrMat} , D:: Diagonal ,
362+ alpha:: Number , beta:: Number ) = _muldiag! (out, A, D, alpha, beta)
363+ @inline mul! (C:: Diagonal , Da:: Diagonal , Db:: Diagonal , alpha:: Number , beta:: Number ) =
364+ _muldiag! (C, Da, Db, alpha, beta)
332365
333366function mul! (C:: AbstractMatrix , Da:: Diagonal , Db:: Diagonal , alpha:: Number , beta:: Number )
334- mA = size (Da, 1 )
335- mB = size (Db, 1 )
336- mA == mB || throw (DimensionMismatch (" A has dimensions ($mA ,$mA ) but B has dimensions ($mB ,$mB )" ))
337- mC, nC = size (C)
338- mC == nC == mA || throw (DimensionMismatch (" output matrix has size: ($mC ,$nC ), but should have size ($mA ,$mA )" ))
367+ _muldiag_size_check (C, Da, Db)
339368 require_one_based_indexing (C)
369+ mA = size (Da, 1 )
340370 da = Da. diag
341371 db = Db. diag
342372 _rmul_or_fill! (C, beta)
0 commit comments