Skip to content

Commit 9c5639d

Browse files
authored
Fix nested broadcasting of BlockedArray (#485)
Fixes #295. That issue seemed to be caused by code in the generic block broadcasting logic such as https://github.com/JuliaArrays/BlockArrays.jl/blob/90ddfa96fb0ade3da4f67ac285fcb340350c19c6/src/blockbroadcast.jl#L121-L129 assuming the broadcasting expression is already flat. This PR adds a call to `Broadcast.flatten` to explicitly flatten the broadcast expression in the generic block broadcast code. It appears that at some point `BlockedStyle` broadcasting expressions were being flattened in this `Broadcast.instantiate` definition: https://github.com/JuliaArrays/BlockArrays.jl/blob/90ddfa96fb0ade3da4f67ac285fcb340350c19c6/src/blockbroadcast.jl#L199-L202 but that was changed in #193. Reverting that change broke the tests introduced in #193 so I guess it is not safe to do that. I have to admit I'm not sure if this is the best place to call `Broadcast.flatten`, a lot of this broadcasting code is hard to follow. The place I put it in this PR fixes #295, doesn't break other tests, and is hopefully least likely to interfere with custom implementations of block broadcasting in downstream packages.
1 parent 3edffcb commit 9c5639d

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

src/blockbroadcast.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,9 +138,10 @@ end
138138
end
139139

140140
function _generic_blockbroadcast_copyto!(dest::AbstractArray,
141-
bc::Broadcasted{<:AbstractBlockStyle{NDims}, <:Any, <:Any, Args}) where {NDims, Args <: Tuple}
141+
bc1::Broadcasted{<:AbstractBlockStyle{NDims}}) where {NDims}
142142

143-
NArgs = fieldcount(Args)
143+
bc = Broadcast.flatten(bc1)
144+
NArgs = length(bc.args)
144145

145146
bs = axes(bc)
146147
if !blockisequal(axes(dest), bs)

test/test_blockbroadcast.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,26 @@ using StaticArrays
6767
@test A .+ 1 .+ B == Vector(A) .+ 1 .+ B == Vector(A) .+ 1 .+ Matrix(B)
6868

6969
@testset "preserve structure" begin
70-
x = BlockedArray(1:6, Fill(3,2))
71-
@test x + x isa BlockedVector{Int,<:AbstractRange}
72-
@test 2x + x isa BlockedVector{Int,<:AbstractRange}
73-
@test 2 .* (x .+ 1) isa BlockedVector{Int,<:AbstractRange}
70+
x = BlockedArray(1:6, Fill(3,2))
71+
@test x + x isa BlockedVector{Int,<:AbstractRange}
72+
@test 2x + x isa BlockedVector{Int,<:AbstractRange}
73+
@test 2 .* (x .+ 1) isa BlockedVector{Int,<:AbstractRange}
74+
end
75+
76+
@testset "nested in-place broadcast" begin
77+
x = BlockedVector(randn(4), [2, 2])
78+
y = BlockedVector(randn(4), [2, 2])
79+
dest = copy(x)
80+
dest .+= 2 .* y
81+
@test dest x + 2y
82+
end
83+
84+
@testset "0-dim nested in-place broadcast" begin
85+
x = BlockedArray(randn(()))
86+
y = BlockedArray(randn(()))
87+
dest = copy(x)
88+
dest .+= 2 .* y
89+
@test dest x + 2y
7490
end
7591
end
7692

0 commit comments

Comments
 (0)