Skip to content

Commit a4d9436

Browse files
committed
Treat non-indexable types as scalars in broadcast
1 parent 1069aae commit a4d9436

File tree

7 files changed

+45
-14
lines changed

7 files changed

+45
-14
lines changed

base/abstractarray.jl

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,6 @@ immutable IndicesList <: IndicesBehavior end # indices like (:cat, :dog,
151151

152152
indicesbehavior(A::AbstractArray) = indicesbehavior(typeof(A))
153153
indicesbehavior{T<:AbstractArray}(::Type{T}) = IndicesStartAt1()
154-
indicesbehavior(::Number) = IndicesStartAt1()
155154

156155
abstract IndicesPerformance
157156
immutable IndicesFast1D <: IndicesPerformance end # indices(A, d) is fast
@@ -412,8 +411,9 @@ end
412411
promote_indices(a::AbstractArray, b::AbstractArray) = _promote_indices(indicesbehavior(a), indicesbehavior(b), a, b)
413412
_promote_indices(::IndicesStartAt1, ::IndicesStartAt1, a, b) = a
414413
_promote_indices(::IndicesBehavior, ::IndicesBehavior, a, b) = throw(ArgumentError("types $(typeof(a)) and $(typeof(b)) do not have promote_indices defined"))
415-
promote_indices(a::Number, b::AbstractArray) = b
416-
promote_indices(a::AbstractArray, b::Number) = a
414+
promote_indices(a, b::AbstractArray) = b
415+
promote_indices(a::AbstractArray, b) = a
416+
promote_indices(a, b) = a
417417

418418
# Strip off the index-changing container---this assumes that `parent`
419419
# performs such an operation. TODO: since few things in Base need this, it
@@ -1459,10 +1459,20 @@ end
14591459
promote_eltype_op(::Any) = (@_pure_meta; Bottom)
14601460
promote_eltype_op{T}(op, ::AbstractArray{T}) = (@_pure_meta; promote_op(op, T))
14611461
promote_eltype_op{T}(op, ::T ) = (@_pure_meta; promote_op(op, T))
1462+
promote_eltype_op{T}(op, Ts::AbstractArray{DataType}, ::AbstractArray{T}) = typejoin((promote_op(op, S, T) for S in Ts)...)
1463+
promote_eltype_op{T}(op, Ts::AbstractArray{DataType}, ::Type{T} ) = typejoin((promote_op(op, S, T) for S in Ts)...)
1464+
promote_eltype_op{T}(op, Ts::AbstractArray{DataType}, ::T ) = typejoin((promote_op(op, S, T) for S in Ts)...)
1465+
promote_eltype_op{R<:DataType,S}(op, Ts::AbstractArray{R}, ::AbstractArray{S}) = promote_eltype_op(op, Ts, S)
14621466
promote_eltype_op{R,S}(op, ::AbstractArray{R}, ::AbstractArray{S}) = (@_pure_meta; promote_op(op, R, S))
14631467
promote_eltype_op{R,S}(op, ::AbstractArray{R}, ::S) = (@_pure_meta; promote_op(op, R, S))
14641468
promote_eltype_op{R,S}(op, ::R, ::AbstractArray{S}) = (@_pure_meta; promote_op(op, R, S))
1465-
promote_eltype_op(op, A, B, C, D...) = (@_pure_meta; promote_op(op, eltype(A), promote_eltype_op(op, B, C, D...)))
1469+
promote_eltype_op{R,S}(op, ::AbstractArray{R}, ::Type{S}) = (@_pure_meta; promote_op(op, R, S))
1470+
promote_eltype_op{R,S}(op, ::Type{R}, ::AbstractArray{S}) = (@_pure_meta; promote_op(op, R, S))
1471+
promote_eltype_op{R,S}(op, ::Type{R}, ::Type{S}) = (@_pure_meta; promote_op(op, R, S))
1472+
promote_eltype_op{R,S}(op, ::Type{R}, ::S) = (@_pure_meta; promote_op(op, R, S))
1473+
promote_eltype_op{R,S}(op, ::R, ::Type{S}) = (@_pure_meta; promote_op(op, R, S))
1474+
promote_eltype_op{R,S}(op, ::R, ::S) = (@_pure_meta; promote_op(op, R, S))
1475+
promote_eltype_op(op, A, B, C, D...) = promote_eltype_op(op, A, promote_eltype_op(op, B, C, D...))
14661476

14671477
## 1 argument
14681478
map!{F}(f::F, A::AbstractArray) = map!(f, A, A)

base/broadcast.jl

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,9 @@ export broadcast_getindex, broadcast_setindex!
1313
## Calculate the broadcast shape of the arguments, or error if incompatible
1414
# array inputs
1515
broadcast_shape() = ()
16-
broadcast_shape(A) = shape(A)
17-
@inline broadcast_shape(A, B...) = broadcast_shape((), shape(A), map(shape, B)...)
16+
broadcast_shape(A) = ()
17+
broadcast_shape(A::AbstractArray) = shape(A)
18+
@inline broadcast_shape(A, B...) = broadcast_shape((), broadcast_shape(A), map(broadcast_shape, B)...)
1819
# shape inputs
1920
broadcast_shape(shape::Tuple) = shape
2021
@inline broadcast_shape(shape::Tuple, shape1::Tuple, shapes::Tuple...) = broadcast_shape(_bcs((), shape, shape1), shapes...)
@@ -40,7 +41,7 @@ _bcsm(a::Number, b::Number) = a == b || b == 1
4041
## Check that all arguments are broadcast compatible with shape
4142
# comparing one input against a shape
4243
check_broadcast_shape(shp) = nothing
43-
check_broadcast_shape(shp, A) = check_broadcast_shape(shp, shape(A))
44+
check_broadcast_shape(shp, A) = check_broadcast_shape(shp, broadcast_shape(A))
4445
check_broadcast_shape(::Tuple{}, ::Tuple{}) = nothing
4546
check_broadcast_shape(shp, ::Tuple{}) = nothing
4647
check_broadcast_shape(::Tuple{}, Ashp::Tuple) = throw(DimensionMismatch("cannot broadcast array to have fewer dimensions"))
@@ -63,8 +64,8 @@ end
6364
@inline _newindex(out, I) = out # can truncate if indexmap is shorter than I
6465
@inline _newindex(out, I, keep::Bool, indexmap...) = _newindex((out..., ifelse(keep, I[1], 1)), tail(I), indexmap...)
6566

66-
newindexer(sz, x::Number) = ()
67-
@inline newindexer(sz, A) = _newindexer(sz, size(A))
67+
newindexer(sz, x) = ()
68+
@inline newindexer(sz, A::AbstractArray) = _newindexer(sz, size(A))
6869
@inline _newindexer(sz, szA::Tuple{}) = ()
6970
@inline _newindexer(sz, szA) = (sz[1] == szA[1], _newindexer(tail(sz), tail(szA))...)
7071

@@ -79,6 +80,10 @@ const bitcache_size = 64 * bitcache_chunks # do not change this
7980
dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
8081
Base.copy_to_bitarray_chunks!(Bc, ((bind - 1) << 6) + 1, C, 1, min(bitcache_size, (length(Bc)-bind+1) << 6))
8182

83+
# Since we can't make T[1] return T, use this inside `_broadcast!`
84+
@inline _broadcast_getvals(A, I) = A
85+
@inline _broadcast_getvals(A::AbstractArray, I) = A[I]
86+
8287
## Broadcasting core
8388
# nargs encodes the number of As arguments (which matches the number
8489
# of indexmaps). The first two type parameters are to ensure specialization.
@@ -92,7 +97,7 @@ dumpbitcache(Bc::Vector{UInt64}, bind::Int, C::Vector{Bool}) =
9297
# reverse-broadcast the indices
9398
@nexprs $nargs i->(I_i = newindex(I, imap_i))
9499
# extract array values
95-
@nexprs $nargs i->(@inbounds val_i = A_i[I_i])
100+
@nexprs $nargs i->(@inbounds val_i = _broadcast_getvals(A_i, I_i))
96101
# call the function and store the result
97102
@inbounds B[I] = @ncall $nargs f val
98103
end
@@ -140,7 +145,10 @@ end
140145
B
141146
end
142147

143-
@inline broadcast(f, As...) = broadcast!(f, allocate_for(Array{promote_eltype_op(f, As...)}, As, broadcast_shape(As...)), As...)
148+
@inline _broadcast(::Type{Val{false}}, f, a...) = f(a...)
149+
@inline _broadcast(::Type{Val{true}}, f, As...) = broadcast!(f, allocate_for(Array{promote_eltype_op(f, As...)}, As, broadcast_shape(As...)), As...)
150+
151+
@inline broadcast(f, As...) = (b = any(isa(T,AbstractArray) for T in As); _broadcast(Val{b}, f, As...))
144152

145153
@inline bitbroadcast(f, As...) = broadcast!(f, allocate_for(BitArray, As, broadcast_shape(As...)), As...)
146154

base/float.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,13 @@ promote_rule(::Type{Float64}, ::Type{Float32}) = Float64
199199
widen(::Type{Float16}) = Float32
200200
widen(::Type{Float32}) = Float64
201201

202+
promote_op{Op<:typeof(trunc),T<:Union{Float32,Float64}}(::Op, ::Type{Signed}, ::Type{T}) = Int
203+
promote_op{Op<:typeof(trunc),T<:Union{Float32,Float64}}(::Op, ::Type{Unsigned}, ::Type{T}) = UInt
204+
promote_op{Op<:typeof(trunc),R,S}(::Op, ::Type{R}, ::Type{S}) = R
205+
for f in (ceil, floor, round)
206+
@eval promote_op{Op<:$(typeof(f)),R,S}(::Op, ::Type{R}, ::Type{S}) = promote_op($trunc, R, S)
207+
end
208+
202209
## floating point arithmetic ##
203210
-(x::Float32) = box(Float32,neg_float(unbox(Float32,x)))
204211
-(x::Float64) = box(Float64,neg_float(unbox(Float64,x)))

base/number.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,6 @@ one(x::Number) = oftype(x,1)
6464
one{T<:Number}(::Type{T}) = convert(T,1)
6565

6666
factorial(x::Number) = gamma(x + 1) # fallback for x not Integer
67+
68+
promote_op{T<:Number}(op, ::Type{T}) = typeof(op(one(T)))
69+
promote_op{R,S<:Number}(op::Type{R}, ::Type{S}) = R # to handle ambiguities

base/parse.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,3 +194,5 @@ function parse(str::AbstractString; raise::Bool=true)
194194
end
195195
return ex
196196
end
197+
198+
promote_op{Op<:typeof(parse),R,S}(::Op, ::Type{R}, ::Type{S}) = R

base/promotion.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,9 +222,10 @@ minmax(x::Real, y::Real) = minmax(promote(x, y)...)
222222
# for the multiplication of two types,
223223
# promote_op{R<:MyType,S<:MyType}(::typeof(*), ::Type{R}, ::Type{S}) = MyType{multype(R,S)}
224224
promote_op(::Any) = (@_pure_meta; Bottom)
225-
promote_op(::Any, T) = (@_pure_meta; T)
225+
promote_op(::Any, T) = (@_pure_meta; Any)
226226
promote_op{T}(::Type{T}, ::Any) = (@_pure_meta; T)
227-
promote_op{R,S}(::Any, ::Type{R}, ::Type{S}) = (@_pure_meta; promote_type(R, S))
227+
promote_op{R,S}(::Any, ::Type{R}, ::Type{S}) = (@_pure_meta; Any)
228+
promote_op{Op<:typeof(convert),R,S}(::Op, ::Type{R}, ::Type{S}) = (@_pure_meta; R)
228229
promote_op(op, T, S, U, V...) = (@_pure_meta; promote_op(op, T, promote_op(op, S, U, V...)))
229230

230231
## catch-alls to prevent infinite recursion when definitions are missing ##

test/broadcast.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ m = [1:2;]'
165165
@test @inferred([0,1.2].+reshape([0,-2],1,1,2)) == reshape([0 -2; 1.2 -0.8],2,1,2)
166166
rt = Base.return_types(.+, Tuple{Array{Float64, 3}, Array{Int, 1}})
167167
@test length(rt) == 1 && rt[1] == Array{Float64, 3}
168-
rt = Base.return_types(broadcast, Tuple{Function, Array{Float64, 3}, Array{Int, 1}})
168+
rt = Base.return_types(broadcast, Tuple{typeof(+), Array{Float64, 3}, Array{Int, 1}})
169169
@test length(rt) == 1 && rt[1] == Array{Float64, 3}
170170
rt = Base.return_types(broadcast!, Tuple{Function, Array{Float64, 3}, Array{Float64, 3}, Array{Int, 1}})
171171
@test length(rt) == 1 && rt[1] == Array{Float64, 3}

0 commit comments

Comments
 (0)