1+ const FillVector{F,A} = Fill{F,1 ,A}
2+ const FillMatrix{F,A} = Fill{F,2 ,A}
3+ const OnesVector{F,A} = Ones{F,1 ,A}
4+ const OnesMatrix{F,A} = Ones{F,2 ,A}
5+ const ZerosVector{F,A} = Zeros{F,1 ,A}
6+ const ZerosMatrix{F,A} = Zeros{F,2 ,A}
7+
18# # vec
29
310vec (a:: Ones{T} ) where T = Ones {T} (length (a))
8794* (a:: Zeros{<:Any,2} , b:: Diagonal ) = mult_zeros (a, b)
8895* (a:: Diagonal , b:: Zeros{<:Any,1} ) = mult_zeros (a, b)
8996* (a:: Diagonal , b:: Zeros{<:Any,2} ) = mult_zeros (a, b)
90- function * (a:: Diagonal , b:: AbstractFill{<:Any,2} )
97+
98+ # Cannot unify following methods for Diagonal
99+ # due to ambiguity with general array mult. with fill
100+ function * (a:: Diagonal , b:: FillMatrix )
101+ size (a,2 ) == size (b,1 ) || throw (DimensionMismatch (" A has dimensions $(size (a)) but B has dimensions $(size (b)) " ))
102+ a. diag .* b # use special broadcast
103+ end
104+ function * (a:: FillMatrix , b:: Diagonal )
105+ size (a,2 ) == size (b,1 ) || throw (DimensionMismatch (" A has dimensions $(size (a)) but B has dimensions $(size (b)) " ))
106+ a .* permutedims (b. diag) # use special broadcast
107+ end
108+ function * (a:: Diagonal , b:: OnesMatrix )
91109 size (a,2 ) == size (b,1 ) || throw (DimensionMismatch (" A has dimensions $(size (a)) but B has dimensions $(size (b)) " ))
92110 a. diag .* b # use special broadcast
93111end
94- function * (a:: AbstractFill{<:Any,2} , b:: Diagonal )
112+ function * (a:: OnesMatrix , b:: Diagonal )
95113 size (a,2 ) == size (b,1 ) || throw (DimensionMismatch (" A has dimensions $(size (a)) but B has dimensions $(size (b)) " ))
96114 a .* permutedims (b. diag) # use special broadcast
97115end
@@ -100,23 +118,61 @@ end
100118* (a:: Transpose{T, <:StridedMatrix{T}} , b:: Fill{T, 1} ) where T = reshape (sum (parent (a); dims= 1 ) .* b. value, size (parent (a), 2 ))
101119* (a:: StridedMatrix{T} , b:: Fill{T, 1} ) where T = reshape (sum (a; dims= 2 ) .* b. value, size (a, 1 ))
102120
103- function * (a:: Adjoint{T, <:StridedMatrix{T}} , b:: Fill{T, 2} ) where T
104- fB = similar (parent (a), size (b, 1 ), size (b, 2 ))
105- fill! (fB, b. value)
106- return a* fB
121+ function * (x:: AbstractMatrix , f:: FillMatrix )
122+ axes (x, 2 ) ≠ axes (f, 1 ) &&
123+ throw (DimensionMismatch (" Incompatible matrix multiplication dimensions" ))
124+ m = size (f, 2 )
125+ repeat (sum (x, dims= 2 ) * f. value, 1 , m)
126+ end
127+
128+ function * (f:: FillMatrix , x:: AbstractMatrix )
129+ axes (f, 2 ) ≠ axes (x, 1 ) &&
130+ throw (DimensionMismatch (" Incompatible matrix multiplication dimensions" ))
131+ m = size (f, 1 )
132+ repeat (sum (x, dims= 1 ) * f. value, m, 1 )
107133end
108134
109- function * (a:: Transpose{T, <:StridedMatrix{T}} , b:: Fill{T, 2} ) where T
110- fB = similar (parent (a), size (b, 1 ), size (b, 2 ))
111- fill! (fB, b. value)
112- return a* fB
135+ function * (x:: AbstractMatrix , f:: OnesMatrix )
136+ axes (x, 2 ) ≠ axes (f, 1 ) &&
137+ throw (DimensionMismatch (" Incompatible matrix multiplication dimensions" ))
138+ m = size (f, 2 )
139+ repeat (sum (x, dims= 2 ) * one (eltype (f)), 1 , m)
113140end
114141
115- function * (a:: StridedMatrix{T} , b:: Fill{T, 2} ) where T
116- fB = similar (a, size (b, 1 ), size (b, 2 ))
117- fill! (fB, b. value)
118- return a* fB
142+ function * (f:: OnesMatrix , x:: AbstractMatrix )
143+ axes (f, 2 ) ≠ axes (x, 1 ) &&
144+ throw (DimensionMismatch (" Incompatible matrix multiplication dimensions" ))
145+ m = size (f, 1 )
146+ repeat (sum (x, dims= 1 ) * one (eltype (f)), m, 1 )
119147end
148+
149+ * (x:: FillMatrix , y:: FillMatrix ) = mult_fill (x, y)
150+ * (x:: FillMatrix , y:: OnesMatrix ) = mult_fill (x, y)
151+ * (x:: OnesMatrix , y:: FillMatrix ) = mult_fill (x, y)
152+ * (x:: OnesMatrix , y:: OnesMatrix ) = mult_fill (x, y)
153+ * (x:: ZerosMatrix , y:: OnesMatrix ) = mult_zeros (x, y)
154+ * (x:: ZerosMatrix , y:: FillMatrix ) = mult_zeros (x, y)
155+ * (x:: FillMatrix , y:: ZerosMatrix ) = mult_zeros (x, y)
156+ * (x:: OnesMatrix , y:: ZerosMatrix ) = mult_zeros (x, y)
157+
158+ # function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
159+ # fB = similar(parent(a), size(b, 1), size(b, 2))
160+ # fill!(fB, b.value)
161+ # return a*fB
162+ # end
163+
164+ # function *(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
165+ # fB = similar(parent(a), size(b, 1), size(b, 2))
166+ # fill!(fB, b.value)
167+ # return a*fB
168+ # end
169+
170+ # function *(a::StridedMatrix{T}, b::Fill{T, 2}) where T
171+ # fB = similar(a, size(b, 1), size(b, 2))
172+ # fill!(fB, b.value)
173+ # return a*fB
174+ # end
175+
120176function _adjvec_mul_zeros (a:: Adjoint{T} , b:: Zeros{S, 1} ) where {T, S}
121177 la, lb = length (a), length (b)
122178 if la ≠ lb
0 commit comments