Skip to content

Commit 02b7b04

Browse files
authored
simplify Broadcast object computations (#49395)
Code should normally preserve values, not the types of values. This ensures the user can define styles with metadata, and requires less type-parameter-based programming, but rather can focus on the values.
1 parent c237c0a commit 02b7b04

File tree

2 files changed

+42
-30
lines changed

2 files changed

+42
-30
lines changed

base/broadcast.jl

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,28 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
167167
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle})
168168

169169
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted
170+
style::Style
170171
f::F
171172
args::Args
172173
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`)
173-
end
174174

175-
Broadcasted(f::F, args::Args, axes=nothing) where {F, Args<:Tuple} =
176-
Broadcasted{typeof(combine_styles(args...))}(f, args, axes)
177-
function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Args<:Tuple}
178-
# using Core.Typeof rather than F preserves inferrability when f is a type
179-
Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes)
175+
Broadcasted(style::Union{Nothing,BroadcastStyle}, f::Tuple, args::Tuple) = error() # disambiguation: tuple is not callable
176+
function Broadcasted(style::Union{Nothing,BroadcastStyle}, f::F, args::Tuple, axes=nothing) where {F}
177+
# using Core.Typeof rather than F preserves inferrability when f is a type
178+
return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(style, f, args, axes)
179+
end
180+
181+
function Broadcasted(f::F, args::Tuple, axes=nothing) where {F}
182+
Broadcasted(combine_styles(args...)::BroadcastStyle, f, args, axes)
183+
end
184+
185+
function Broadcasted{Style}(f::F, args, axes=nothing) where {Style, F}
186+
return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(Style()::Style, f, args, axes)
187+
end
188+
189+
function Broadcasted{Style,Axes,F,Args}(f, args, axes) where {Style,Axes,F,Args}
190+
return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
191+
end
180192
end
181193

182194
struct AndAnd end
@@ -194,16 +206,16 @@ function broadcasted(::OrOr, a, bc::Broadcasted)
194206
broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...)
195207
end
196208

197-
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
209+
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{<:Any,Axes,F,Args}) where {NewStyle,Axes,F,Args} =
198210
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)::Broadcasted{NewStyle,Axes,F,Args}
199211

200212
function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
201213
print(io, Broadcasted)
202214
# Only show the style parameter if we have a set of axes — representing an instantiated
203215
# "outermost" Broadcasted. The styles of nested Broadcasteds represent an intermediate
204216
# computation that is not relevant for dispatch, confusing, and just extra line noise.
205-
bc.axes isa Tuple && print(io, '{', Style, '}')
206-
print(io, '(', bc.f, ", ", bc.args, ')')
217+
bc.axes isa Tuple && print(io, "{", Style, "}")
218+
print(io, "(", bc.f, ", ", bc.args, ")")
207219
nothing
208220
end
209221

@@ -231,7 +243,7 @@ BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
231243
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} =
232244
throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned"))
233245

234-
argtype(::Type{Broadcasted{Style,Axes,F,Args}}) where {Style,Axes,F,Args} = Args
246+
argtype(::Type{BC}) where {BC<:Broadcasted} = fieldtype(BC, :args)
235247
argtype(bc::Broadcasted) = argtype(typeof(bc))
236248

237249
@inline Base.eachindex(bc::Broadcasted) = _eachindex(axes(bc))
@@ -262,7 +274,7 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
262274
end
263275

264276
Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
265-
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2))
277+
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, :args))
266278
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N
267279

268280
_maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T))))
@@ -289,14 +301,14 @@ Custom [`BroadcastStyle`](@ref)s may override this default in cases where it is
289301
to compute and verify the resulting `axes` on-demand, leaving the `axis` field
290302
of the `Broadcasted` object empty (populated with [`nothing`](@ref)).
291303
"""
292-
@inline function instantiate(bc::Broadcasted{Style}) where {Style}
304+
@inline function instantiate(bc::Broadcasted)
293305
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
294306
axes = combine_axes(bc.args...)
295307
else
296308
axes = bc.axes
297309
check_broadcast_axes(axes, bc.args...)
298310
end
299-
return Broadcasted{Style}(bc.f, bc.args, axes)
311+
return Broadcasted(bc.style, bc.f, bc.args, axes)
300312
end
301313
instantiate(bc::Broadcasted{<:AbstractArrayStyle{0}}) = bc
302314
# Tuples don't need axes, but when they have axes (for .= assignment), we need to check them (#33020)
@@ -325,7 +337,7 @@ becomes
325337
This is an optional operation that may make custom implementation of broadcasting easier in
326338
some cases.
327339
"""
328-
function flatten(bc::Broadcasted{Style}) where {Style}
340+
function flatten(bc::Broadcasted)
329341
isflat(bc) && return bc
330342
# concatenate the nested arguments into {a, b, c, d}
331343
args = cat_nested(bc)
@@ -341,7 +353,7 @@ function flatten(bc::Broadcasted{Style}) where {Style}
341353
newf = @inline function(args::Vararg{Any,N}) where N
342354
f(makeargs(args...)...)
343355
end
344-
return Broadcasted{Style}(newf, args, bc.axes)
356+
return Broadcasted(bc.style, newf, args, bc.axes)
345357
end
346358
end
347359

@@ -895,11 +907,11 @@ materialize(x) = x
895907
return materialize!(dest, instantiate(Broadcasted(identity, (x,), axes(dest))))
896908
end
897909

898-
@inline function materialize!(dest, bc::Broadcasted{Style}) where {Style}
910+
@inline function materialize!(dest, bc::Broadcasted{<:Any})
899911
return materialize!(combine_styles(dest, bc), dest, bc)
900912
end
901-
@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{Style}) where {Style}
902-
return copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
913+
@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{<:Any})
914+
return copyto!(dest, instantiate(Broadcasted(bc.style, bc.f, bc.args, axes(dest))))
903915
end
904916

905917
## general `copy` methods
@@ -909,7 +921,7 @@ copy(bc::Broadcasted{<:Union{Nothing,Unknown}}) =
909921

910922
const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}
911923

912-
@inline function copy(bc::Broadcasted{Style}) where {Style}
924+
@inline function copy(bc::Broadcasted)
913925
ElType = combine_eltypes(bc.f, bc.args)
914926
if Base.isconcretetype(ElType)
915927
# We can trust it and defer to the simpler `copyto!`
@@ -968,7 +980,7 @@ broadcast_unalias(::Nothing, src) = src
968980
# Preprocessing a `Broadcasted` does two things:
969981
# * unaliases any arguments from `dest`
970982
# * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices
971-
@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes)
983+
@inline preprocess(dest, bc::Broadcasted) = Broadcasted(bc.style, bc.f, preprocess_args(dest, bc.args), bc.axes)
972984
preprocess(dest, x) = extrude(broadcast_unalias(dest, x))
973985

974986
@inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...)
@@ -1038,11 +1050,11 @@ ischunkedbroadcast(R, args::Tuple{<:BroadcastedChunkableOp,Vararg{Any}}) = ischu
10381050
ischunkedbroadcast(R, args::Tuple{}) = true
10391051

10401052
# Convert compatible functions to chunkable ones. They must also be green-lighted as ChunkableOps
1041-
liftfuncs(bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, map(liftfuncs, bc.args), bc.axes)
1042-
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(sign)}) where {Style} = Broadcasted{Style}(identity, map(liftfuncs, bc.args), bc.axes)
1043-
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(!)}) where {Style} = Broadcasted{Style}(~, map(liftfuncs, bc.args), bc.axes)
1044-
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(*)}) where {Style} = Broadcasted{Style}(&, map(liftfuncs, bc.args), bc.axes)
1045-
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(==)}) where {Style} = Broadcasted{Style}((~)(xor), map(liftfuncs, bc.args), bc.axes)
1053+
liftfuncs(bc::Broadcasted{<:Any,<:Any,<:Any}) = Broadcasted(bc.style, bc.f, map(liftfuncs, bc.args), bc.axes)
1054+
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(sign)}) = Broadcasted(bc.style, identity, map(liftfuncs, bc.args), bc.axes)
1055+
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(!)}) = Broadcasted(bc.style, ~, map(liftfuncs, bc.args), bc.axes)
1056+
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(*)}) = Broadcasted(bc.style, &, map(liftfuncs, bc.args), bc.axes)
1057+
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(==)}) = Broadcasted(bc.style, (~)(xor), map(liftfuncs, bc.args), bc.axes)
10461058
liftfuncs(x) = x
10471059

10481060
liftchunks(::Tuple{}) = ()
@@ -1315,26 +1327,26 @@ end
13151327
return broadcasted((args...) -> f(args...; kwargs...), args...)
13161328
end
13171329
end
1318-
@inline function broadcasted(f, args...)
1330+
@inline function broadcasted(f::F, args...) where {F}
13191331
args′ = map(broadcastable, args)
13201332
broadcasted(combine_styles(args′...), f, args′...)
13211333
end
13221334
# Due to the current Type{T}/DataType specialization heuristics within Tuples,
13231335
# the totally generic varargs broadcasted(f, args...) method above loses Type{T}s in
13241336
# mapping broadcastable across the args. These additional methods with explicit
13251337
# arguments ensure we preserve Type{T}s in the first or second argument position.
1326-
@inline function broadcasted(f, arg1, args...)
1338+
@inline function broadcasted(f::F, arg1, args...) where {F}
13271339
arg1′ = broadcastable(arg1)
13281340
args′ = map(broadcastable, args)
13291341
broadcasted(combine_styles(arg1′, args′...), f, arg1′, args′...)
13301342
end
1331-
@inline function broadcasted(f, arg1, arg2, args...)
1343+
@inline function broadcasted(f::F, arg1, arg2, args...) where {F}
13321344
arg1′ = broadcastable(arg1)
13331345
arg2′ = broadcastable(arg2)
13341346
args′ = map(broadcastable, args)
13351347
broadcasted(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...)
13361348
end
1337-
@inline broadcasted(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args)
1349+
@inline broadcasted(style::BroadcastStyle, f::F, args...) where {F} = Broadcasted(style, f, args)
13381350

13391351
"""
13401352
BroadcastFunction{F} <: Function

test/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ let
880880
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{Broadcast.ArrayConflict}
881881
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
882882

883-
@test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}()
883+
@test @inferred(Base.IteratorSize(Broadcast.broadcasted(+, (1,2,3), a1, zeros(3,3,3)))) === Base.HasShape{3}()
884884

885885
# inference on nested
886886
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))

0 commit comments

Comments
 (0)