@@ -195,20 +195,18 @@ function broadcasted(::DefaultArrayStyle{1}, ::typeof(*), a::AbstractRange, b::O
195195 return _range_convert (AbstractVector{TT}, a)
196196end
197197
198- # To fix AD issues with `broadcast(T, x)`
199- # Avoids type inference issues with x -> T(x)
200- struct Constructor{T} end
201-
202- function (:: Constructor{T} )(x) where {T}
203- return T (x)
204- end
205-
206198for op in (:+ , :- )
207199 @eval begin
208200 function broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: AbstractVector , b:: ZerosVector )
209201 broadcast_shape (axes (a), axes (b)) == axes (a) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $b to a Vector first." ))
210202 TT = typeof ($ op (zero (eltype (a)), zero (eltype (b))))
211- eltype (a) === TT ? a : broadcasted (Constructor {TT} (), a)
203+ # Use `TT ∘ (+)` to fix AD issues with `broadcasted(TT, x)`
204+ eltype (a) === TT ? a : broadcasted (TT ∘ (+ ), a)
205+ end
206+ function broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: ZerosVector , b:: AbstractVector )
207+ broadcast_shape (axes (a), axes (b)) == axes (b) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $a to a Vector first." ))
208+ TT = typeof ($ op (zero (eltype (a)), zero (eltype (b))))
209+ $ op === (+ ) && eltype (b) === TT ? b : broadcasted (TT ∘ ($ op), b)
212210 end
213211
214212 broadcasted (:: DefaultArrayStyle{1} , :: typeof ($ op), a:: AbstractFillVector , b:: ZerosVector ) =
@@ -219,18 +217,6 @@ for op in (:+, :-)
219217 end
220218end
221219
222- function broadcasted (:: DefaultArrayStyle{1} , :: typeof (+ ), a:: ZerosVector , b:: AbstractVector )
223- broadcast_shape (axes (a), axes (b)) == axes (b) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $a to a Vector first." ))
224- TT = typeof (zero (eltype (a)) + zero (eltype (b)))
225- eltype (b) === TT ? b : broadcasted (Constructor {TT} (), b)
226- end
227-
228- function broadcasted (:: DefaultArrayStyle{1} , :: typeof (- ), a:: ZerosVector , b:: AbstractVector )
229- broadcast_shape (axes (a), axes (b)) == axes (b) || throw (ArgumentError (" Cannot broadcast $a and $b . Convert $a to a Vector first." ))
230- TT = typeof (zero (eltype (a)) - zero (eltype (b)))
231- broadcasted (TT ∘ (- ), b)
232- end
233-
234220# Need to prevent array-valued fills from broadcasting over entry
235221_broadcast_getindex_value (a:: AbstractFill{<:Number} ) = getindex_value (a)
236222_broadcast_getindex_value (a:: AbstractFill ) = Ref (getindex_value (a))
0 commit comments