Skip to content

Commit 5d0a527

Browse files
committed
simplify Broadcast object computations
Code should normally preserve values, not the types of values. This ensures the user can define styles with metadata, and requires less type-parameter-based programming, but rather can focus on the values.
1 parent 1512d6f commit 5d0a527

File tree

2 files changed

+42
-30
lines changed

2 files changed

+42
-30
lines changed

base/broadcast.jl

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -167,16 +167,28 @@ BroadcastStyle(a::AbstractArrayStyle{M}, ::DefaultArrayStyle{N}) where {M,N} =
167167
# copyto!(dest::AbstractArray, bc::Broadcasted{MyStyle})
168168

169169
struct Broadcasted{Style<:Union{Nothing,BroadcastStyle}, Axes, F, Args<:Tuple} <: Base.AbstractBroadcasted
170+
style::Style
170171
f::F
171172
args::Args
172173
axes::Axes # the axes of the resulting object (may be bigger than implied by `args` if this is nested inside a larger `Broadcasted`)
173-
end
174174

175-
Broadcasted(f::F, args::Args, axes=nothing) where {F, Args<:Tuple} =
176-
Broadcasted{typeof(combine_styles(args...))}(f, args, axes)
177-
function Broadcasted{Style}(f::F, args::Args, axes=nothing) where {Style, F, Args<:Tuple}
178-
# using Core.Typeof rather than F preserves inferrability when f is a type
179-
Broadcasted{Style, typeof(axes), Core.Typeof(f), Args}(f, args, axes)
175+
Broadcasted(style::Union{Nothing,BroadcastStyle}, f::Tuple, args::Tuple) = error() # disambiguation: tuple is not callable
176+
function Broadcasted(style::Union{Nothing,BroadcastStyle}, f::F, args::Tuple, axes=nothing) where {F}
177+
# using Core.Typeof rather than F preserves inferrability when f is a type
178+
return new{typeof(style), typeof(axes), Core.Typeof(f), typeof(args)}(style, f, args, axes)
179+
end
180+
181+
function Broadcasted(f::F, args::Tuple, axes=nothing) where {F}
182+
Broadcasted(combine_styles(args...)::BroadcastStyle, f, args, axes)
183+
end
184+
185+
function Broadcasted{Style}(f::F, args, axes=nothing) where {Style, F}
186+
return new{Style, typeof(axes), Core.Typeof(f), typeof(args)}(Style()::Style, f, args, axes)
187+
end
188+
189+
function Broadcasted{Style,Axes,F,Args}(f, args, axes) where {Style,Axes,F,Args}
190+
return new{Style, Axes, F, Args}(Style()::Style, f, args, axes)
191+
end
180192
end
181193

182194
struct AndAnd end
@@ -194,16 +206,16 @@ function broadcasted(::OrOr, a, bc::Broadcasted)
194206
broadcasted((a, args...) -> a || bcf.f(args...), a, bcf.args...)
195207
end
196208

197-
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{Style,Axes,F,Args}) where {NewStyle,Style,Axes,F,Args} =
209+
Base.convert(::Type{Broadcasted{NewStyle}}, bc::Broadcasted{<:Any,Axes,F,Args}) where {NewStyle,Axes,F,Args} =
198210
Broadcasted{NewStyle,Axes,F,Args}(bc.f, bc.args, bc.axes)::Broadcasted{NewStyle,Axes,F,Args}
199211

200212
function Base.show(io::IO, bc::Broadcasted{Style}) where {Style}
201213
print(io, Broadcasted)
202214
# Only show the style parameter if we have a set of axes — representing an instantiated
203215
# "outermost" Broadcasted. The styles of nested Broadcasteds represent an intermediate
204216
# computation that is not relevant for dispatch, confusing, and just extra line noise.
205-
bc.axes isa Tuple && print(io, '{', Style, '}')
206-
print(io, '(', bc.f, ", ", bc.args, ')')
217+
bc.axes isa Tuple && print(io, "{", Style, "}")
218+
print(io, "(", bc.f, ", ", bc.args, ")")
207219
nothing
208220
end
209221

@@ -231,7 +243,7 @@ BroadcastStyle(::Type{<:Broadcasted{Style}}) where {Style} = Style()
231243
BroadcastStyle(::Type{<:Broadcasted{S}}) where {S<:Union{Nothing,Unknown}} =
232244
throw(ArgumentError("Broadcasted{Unknown} wrappers do not have a style assigned"))
233245

234-
argtype(::Type{Broadcasted{Style,Axes,F,Args}}) where {Style,Axes,F,Args} = Args
246+
argtype(::Type{BC}) where {BC<:Broadcasted} = fieldtype(BC, :args)
235247
argtype(bc::Broadcasted) = argtype(typeof(bc))
236248

237249
@inline Base.eachindex(bc::Broadcasted) = _eachindex(axes(bc))
@@ -262,7 +274,7 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
262274
end
263275

264276
Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
265-
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2))
277+
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, :args))
266278
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N
267279

268280
_maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T))))
@@ -289,14 +301,14 @@ Custom [`BroadcastStyle`](@ref)s may override this default in cases where it is
289301
to compute and verify the resulting `axes` on-demand, leaving the `axis` field
290302
of the `Broadcasted` object empty (populated with [`nothing`](@ref)).
291303
"""
292-
@inline function instantiate(bc::Broadcasted{Style}) where {Style}
304+
@inline function instantiate(bc::Broadcasted)
293305
if bc.axes isa Nothing # Not done via dispatch to make it easier to extend instantiate(::Broadcasted{Style})
294306
axes = combine_axes(bc.args...)
295307
else
296308
axes = bc.axes
297309
check_broadcast_axes(axes, bc.args...)
298310
end
299-
return Broadcasted{Style}(bc.f, bc.args, axes)
311+
return Broadcasted(bc.style, bc.f, bc.args, axes)
300312
end
301313
instantiate(bc::Broadcasted{<:AbstractArrayStyle{0}}) = bc
302314
# Tuples don't need axes, but when they have axes (for .= assignment), we need to check them (#33020)
@@ -325,7 +337,7 @@ becomes
325337
This is an optional operation that may make custom implementation of broadcasting easier in
326338
some cases.
327339
"""
328-
function flatten(bc::Broadcasted{Style}) where {Style}
340+
function flatten(bc::Broadcasted)
329341
isflat(bc) && return bc
330342
# concatenate the nested arguments into {a, b, c, d}
331343
args = cat_nested(bc)
@@ -341,7 +353,7 @@ function flatten(bc::Broadcasted{Style}) where {Style}
341353
newf = @inline function(args::Vararg{Any,N}) where N
342354
f(makeargs(args...)...)
343355
end
344-
return Broadcasted{Style}(newf, args, bc.axes)
356+
return Broadcasted(bc.style, newf, args, bc.axes)
345357
end
346358
end
347359

@@ -895,11 +907,11 @@ materialize(x) = x
895907
return materialize!(dest, instantiate(Broadcasted(identity, (x,), axes(dest))))
896908
end
897909

898-
@inline function materialize!(dest, bc::Broadcasted{Style}) where {Style}
910+
@inline function materialize!(dest, bc::Broadcasted{<:Any})
899911
return materialize!(combine_styles(dest, bc), dest, bc)
900912
end
901-
@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{Style}) where {Style}
902-
return copyto!(dest, instantiate(Broadcasted{Style}(bc.f, bc.args, axes(dest))))
913+
@inline function materialize!(::BroadcastStyle, dest, bc::Broadcasted{<:Any})
914+
return copyto!(dest, instantiate(Broadcasted(bc.style, bc.f, bc.args, axes(dest))))
903915
end
904916

905917
## general `copy` methods
@@ -909,7 +921,7 @@ copy(bc::Broadcasted{<:Union{Nothing,Unknown}}) =
909921

910922
const NonleafHandlingStyles = Union{DefaultArrayStyle,ArrayConflict}
911923

912-
@inline function copy(bc::Broadcasted{Style}) where {Style}
924+
@inline function copy(bc::Broadcasted)
913925
ElType = combine_eltypes(bc.f, bc.args)
914926
if Base.isconcretetype(ElType)
915927
# We can trust it and defer to the simpler `copyto!`
@@ -968,7 +980,7 @@ broadcast_unalias(::Nothing, src) = src
968980
# Preprocessing a `Broadcasted` does two things:
969981
# * unaliases any arguments from `dest`
970982
# * "extrudes" the arguments where it is advantageous to pre-compute the broadcasted indices
971-
@inline preprocess(dest, bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, preprocess_args(dest, bc.args), bc.axes)
983+
@inline preprocess(dest, bc::Broadcasted) = Broadcasted(bc.style, bc.f, preprocess_args(dest, bc.args), bc.axes)
972984
preprocess(dest, x) = extrude(broadcast_unalias(dest, x))
973985

974986
@inline preprocess_args(dest, args::Tuple) = (preprocess(dest, args[1]), preprocess_args(dest, tail(args))...)
@@ -1038,11 +1050,11 @@ ischunkedbroadcast(R, args::Tuple{<:BroadcastedChunkableOp,Vararg{Any}}) = ischu
10381050
ischunkedbroadcast(R, args::Tuple{}) = true
10391051

10401052
# Convert compatible functions to chunkable ones. They must also be green-lighted as ChunkableOps
1041-
liftfuncs(bc::Broadcasted{Style}) where {Style} = Broadcasted{Style}(bc.f, map(liftfuncs, bc.args), bc.axes)
1042-
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(sign)}) where {Style} = Broadcasted{Style}(identity, map(liftfuncs, bc.args), bc.axes)
1043-
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(!)}) where {Style} = Broadcasted{Style}(~, map(liftfuncs, bc.args), bc.axes)
1044-
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(*)}) where {Style} = Broadcasted{Style}(&, map(liftfuncs, bc.args), bc.axes)
1045-
liftfuncs(bc::Broadcasted{Style,<:Any,typeof(==)}) where {Style} = Broadcasted{Style}((~)(xor), map(liftfuncs, bc.args), bc.axes)
1053+
liftfuncs(bc::Broadcasted{<:Any,<:Any,<:Any}) = Broadcasted(bc.style, bc.f, map(liftfuncs, bc.args), bc.axes)
1054+
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(sign)}) = Broadcasted(bc.style, identity, map(liftfuncs, bc.args), bc.axes)
1055+
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(!)}) = Broadcasted(bc.style, ~, map(liftfuncs, bc.args), bc.axes)
1056+
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(*)}) = Broadcasted(bc.style, &, map(liftfuncs, bc.args), bc.axes)
1057+
liftfuncs(bc::Broadcasted{<:Any,<:Any,typeof(==)}) = Broadcasted(bc.style, (~)(xor), map(liftfuncs, bc.args), bc.axes)
10461058
liftfuncs(x) = x
10471059

10481060
liftchunks(::Tuple{}) = ()
@@ -1315,26 +1327,26 @@ end
13151327
return broadcasted((args...) -> f(args...; kwargs...), args...)
13161328
end
13171329
end
1318-
@inline function broadcasted(f, args...)
1330+
@inline function broadcasted(f::F, args...) where {F}
13191331
args′ = map(broadcastable, args)
13201332
broadcasted(combine_styles(args′...), f, args′...)
13211333
end
13221334
# Due to the current Type{T}/DataType specialization heuristics within Tuples,
13231335
# the totally generic varargs broadcasted(f, args...) method above loses Type{T}s in
13241336
# mapping broadcastable across the args. These additional methods with explicit
13251337
# arguments ensure we preserve Type{T}s in the first or second argument position.
1326-
@inline function broadcasted(f, arg1, args...)
1338+
@inline function broadcasted(f::F, arg1, args...) where {F}
13271339
arg1′ = broadcastable(arg1)
13281340
args′ = map(broadcastable, args)
13291341
broadcasted(combine_styles(arg1′, args′...), f, arg1′, args′...)
13301342
end
1331-
@inline function broadcasted(f, arg1, arg2, args...)
1343+
@inline function broadcasted(f::F, arg1, arg2, args...) where {F}
13321344
arg1′ = broadcastable(arg1)
13331345
arg2′ = broadcastable(arg2)
13341346
args′ = map(broadcastable, args)
13351347
broadcasted(combine_styles(arg1′, arg2′, args′...), f, arg1′, arg2′, args′...)
13361348
end
1337-
@inline broadcasted(::S, f, args...) where S<:BroadcastStyle = Broadcasted{S}(f, args)
1349+
@inline broadcasted(style::BroadcastStyle, f::F, args...) where {F} = Broadcasted(style, f, args)
13381350

13391351
"""
13401352
BroadcastFunction{F} <: Function

test/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -880,7 +880,7 @@ let
880880
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{Broadcast.ArrayConflict}
881881
@test Broadcast.broadcasted(+, AD1(rand(3)), AD2(rand(3))) isa Broadcast.Broadcasted{<:Broadcast.AbstractArrayStyle{Any}}
882882

883-
@test @inferred(Base.IteratorSize(Broadcast.broadcasted((1,2,3),a1,zeros(3,3,3)))) === Base.HasShape{3}()
883+
@test @inferred(Base.IteratorSize(Broadcast.broadcasted(+, (1,2,3), a1, zeros(3,3,3)))) === Base.HasShape{3}()
884884

885885
# inference on nested
886886
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))

0 commit comments

Comments
 (0)