Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 17 additions & 37 deletions src/blockbroadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ similar(bc::Broadcasted{PseudoBlockStyle{N}}, ::Type{T}) where {T,N} =
SubBlockIterator(subblock_lasts::Vector{Int}, block_lasts::Vector{Int})
SubBlockIterator(A::AbstractArray, bs::NTuple{N,AbstractUnitRange{Int}} where N, dim::Integer)

An iterator for iterating `BlockIndexRange` of the blocks specified by
Return an iterator over the `BlockIndexRange`s of the blocks specified by
`subblock_lasts`. The `Block` index part of `BlockIndexRange` is
determined by `subblock_lasts`. That is to say, the `Block` index first
specifies one of the block represented by `subblock_lasts` and then the
Expand All @@ -63,41 +63,27 @@ be ensured by the caller.
```jldoctest
julia> using BlockArrays

julia> import BlockArrays: SubBlockIterator, BlockIndexRange
julia> import BlockArrays: SubBlockIterator

julia> A = BlockArray(1:6, 1:3);

julia> subblock_lasts = blocklasts(axes(A, 1));

julia> @assert subblock_lasts == [1, 3, 6];
julia> subblock_lasts = blocklasts(axes(A, 1))
3-element ArrayLayouts.RangeCumsum{Int64, UnitRange{Int64}}:
1
3
6

julia> block_lasts = [1, 3, 4, 6];

julia> for idx in SubBlockIterator(subblock_lasts, block_lasts)
B = @show view(A, idx)
@assert !(parent(B) isa BlockArray)
idx :: BlockIndexRange
idx.block :: Block{1}
idx.indices :: Tuple{UnitRange}
end
view(A, idx) = 1:1
view(A, idx) = 2:3
view(A, idx) = 4:4
view(A, idx) = 5:6

julia> [idx.block.n[1] for idx in SubBlockIterator(subblock_lasts, block_lasts)]
4-element Vector{Int64}:
1
2
3
3
julia> itr = SubBlockIterator(subblock_lasts, block_lasts)
SubBlockIterator([1, 3, 6], [1, 3, 4, 6])

julia> [idx.indices[1] for idx in SubBlockIterator(subblock_lasts, block_lasts)]
4-element Vector{UnitRange{Int64}}:
1:1
1:2
1:1
2:3
julia> collect(itr)
4-element Vector{BlockArrays.BlockIndexRange{1, Tuple{UnitRange{Int64}}}}:
Block(1)[1:1]
Block(2)[1:2]
Block(3)[1:1]
Block(3)[2:3]
```
"""
struct SubBlockIterator
Expand All @@ -114,15 +100,9 @@ Base.length(it::SubBlockIterator) = length(it.block_lasts)
SubBlockIterator(arr::AbstractArray, bs::NTuple{N,AbstractUnitRange{Int}}, dim::Integer) where N =
SubBlockIterator(blocklasts(axes(arr, dim)), blocklasts(bs[dim]))

function Base.iterate(it::SubBlockIterator, state=nothing)
if state === nothing
i,j = 1,1
else
i, j = state
end
length(it.block_lasts)+1 == i && return nothing
function Base.iterate(it::SubBlockIterator, (i, j) = (1,1))
i > length(it.block_lasts) && return nothing
idx = i == 1 ? (1:it.block_lasts[i]) : (it.block_lasts[i-1]+1:it.block_lasts[i])

bir = Block(j)[j == 1 ? idx : idx .- it.subblock_lasts[j-1]]
if it.subblock_lasts[j] == it.block_lasts[i]
j += 1
Expand Down