Skip to content

Commit 68e0813

Browse files
committed
fix collect on stateful iterators
Generalization of #41919 Fixes #42168
1 parent 4a048d3 commit 68e0813

File tree

5 files changed

+52
-34
lines changed

5 files changed

+52
-34
lines changed

base/array.jl

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -643,23 +643,38 @@ julia> collect(Float64, 1:2:5)
643643
"""
644644
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))
645645

646-
_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
647-
_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr)
646+
_collect(::Type{T}, itr, isz::Union{HasLength,HasShape}) where {T} =
647+
copyto!(_array_for(T, isz, _similar_shape(itr, isz)), itr)
648648
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
649649
a = Vector{T}()
650650
for x in itr
651-
push!(a,x)
651+
push!(a, x)
652652
end
653653
return a
654654
end
655655

656656
# make a collection similar to `c` and appropriate for collecting `itr`
657-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown) where {T} = similar(c, T, 0)
658-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength) where {T} =
659-
similar(c, T, Int(length(itr)::Integer))
660-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape) where {T} =
661-
similar(c, T, axes(itr))
662-
_similar_for(c, ::Type{T}, itr, isz) where {T} = similar(c, T)
657+
_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T)
658+
659+
_similar_shape(itr, ::SizeUnknown) = nothing
660+
_similar_shape(itr, ::HasLength) = length(itr)::Integer
661+
_similar_shape(itr, ::HasShape) = axes(itr)
662+
663+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown, ::Nothing) where {T} =
664+
similar(c, T, 0)
665+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength, len::Integer) where {T} =
666+
similar(c, T, len)
667+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape, axs) where {T} =
668+
similar(c, T, axs)
669+
670+
# make a collection appropriate for collecting `itr::Generator`
671+
_array_for(::Type{T}, ::SizeUnknown, ::Nothing) where {T} = Vector{T}(undef, 0)
672+
_array_for(::Type{T}, ::HasLength, len::Integer) where {T} = Vector{T}(undef, Int(len))
673+
_array_for(::Type{T}, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)
674+
675+
# used by syntax lowering for simple typed comprehensions
676+
_array_for(::Type{T}, itr, isz) where {T} = _array_for(T, isz, _similar_shape(itr, isz))
677+
663678

664679
"""
665680
collect(collection)
@@ -698,10 +713,10 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
698713
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))
699714

700715
_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
701-
copyto!(_similar_for(cont, eltype(itr), itr, isz), itr)
716+
copyto!(_similar_for(cont, eltype(itr), itr, isz, _similar_shape(itr, isz)), itr)
702717

703718
function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
704-
a = _similar_for(cont, eltype(itr), itr, isz)
719+
a = _similar_for(cont, eltype(itr), itr, isz, nothing)
705720
for x in itr
706721
push!(a,x)
707722
end
@@ -759,24 +774,19 @@ else
759774
end
760775
end
761776

762-
_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
763-
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
764-
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
765-
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)
766-
767777
function collect(itr::Generator)
768778
isz = IteratorSize(itr.iter)
769779
et = @default_eltype(itr)
770780
if isa(isz, SizeUnknown)
771781
return grow_to!(Vector{et}(), itr)
772782
else
773-
shape = isz isa HasLength ? length(itr) : axes(itr)
783+
shp = _similar_shape(itr, isz)
774784
y = iterate(itr)
775785
if y === nothing
776-
return _array_for(et, itr.iter, isz)
786+
return _array_for(et, isz, shp)
777787
end
778788
v1, st = y
779-
dest = _array_for(typeof(v1), itr.iter, isz, shape)
789+
dest = _array_for(typeof(v1), isz, shp)
780790
# The typeassert gives inference a helping hand on the element type and dimensionality
781791
# (work-around for #28382)
782792
et′ = et <: Type ? Type : et
@@ -786,15 +796,22 @@ function collect(itr::Generator)
786796
end
787797

788798
_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
789-
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz), itr)
799+
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz, nothing), itr)
790800

791801
function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
802+
et = @default_eltype(itr)
803+
shp = _similar_shape(itr, isz)
792804
y = iterate(itr)
793805
if y === nothing
794-
return _similar_for(c, @default_eltype(itr), itr, isz)
806+
return _similar_for(c, et, itr, isz, shp)
795807
end
796808
v1, st = y
797-
collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)
809+
dest = _similar_for(c, typeof(v1), itr, isz, shp)
810+
# The typeassert gives inference a helping hand on the element type and dimensionality
811+
# (work-around for #28382)
812+
et′ = et <: Type ? Type : et
813+
RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
814+
collect_to_with_first!(dest, v1, itr, st)::RT
798815
end
799816

800817
function collect_to_with_first!(dest::AbstractArray, v1, itr, st)

base/dict.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,6 @@ length(t::ImmutableDict) = count(Returns(true), t)
826826
isempty(t::ImmutableDict) = !isdefined(t, :parent)
827827
empty(::ImmutableDict, ::Type{K}, ::Type{V}) where {K, V} = ImmutableDict{K,V}()
828828

829-
_similar_for(c::Dict, ::Type{Pair{K,V}}, itr, isz) where {K, V} = empty(c, K, V)
830-
_similar_for(c::AbstractDict, ::Type{T}, itr, isz) where {T} =
829+
_similar_for(c::AbstractDict, ::Type{Pair{K,V}}, itr, isz, len) where {K, V} = empty(c, K, V)
830+
_similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} =
831831
throw(ArgumentError("for AbstractDicts, similar requires an element type of Pair;\n if calling map, consider a comprehension instead"))

base/set.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ empty(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
4444
# by default, a Set is returned
4545
emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
4646

47-
_similar_for(c::AbstractSet, ::Type{T}, itr, isz) where {T} = empty(c, T)
47+
_similar_for(c::AbstractSet, ::Type{T}, itr, isz, len) where {T} = empty(c, T)
4848

4949
function show(io::IO, s::Set)
5050
if isempty(s)

src/julia-syntax.scm

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2734,7 +2734,7 @@
27342734
(check-no-return expr)
27352735
(if (has-break-or-continue? expr)
27362736
(error "break or continue outside loop"))
2737-
(let ((result (gensy))
2737+
(let ((result (make-ssavalue))
27382738
(idx (gensy))
27392739
(oneresult (make-ssavalue))
27402740
(prod (make-ssavalue))
@@ -2758,16 +2758,14 @@
27582758
(let ((overall-itr (if (length= itrs 1) (car iv) prod)))
27592759
`(scope-block
27602760
(block
2761-
(local ,result) (local ,idx)
2761+
(local ,idx)
27622762
,.(map (lambda (v r) `(= ,v ,(caddr r))) iv itrs)
27632763
,.(if (length= itrs 1)
27642764
'()
27652765
`((= ,prod (call (top product) ,@iv))))
27662766
(= ,isz (call (top IteratorSize) ,overall-itr))
27672767
(= ,szunk (call (core isa) ,isz (top SizeUnknown)))
2768-
(if ,szunk
2769-
(= ,result (call (curly (core Array) ,ty 1) (core undef) 0))
2770-
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz)))
2768+
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz))
27712769
(= ,idx (call (top first) (call (top LinearIndices) ,result)))
27722770
,(construct-loops (reverse itrs) (reverse iv))
27732771
,result)))))

test/iterators.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -293,11 +293,14 @@ let (a, b) = (1:3, [4 6;
293293
end
294294

295295
# collect stateful iterator
296-
let
297-
itr = (i+1 for i in Base.Stateful([1,2,3]))
296+
let itr
297+
itr = Iterators.Stateful(Iterators.map(identity, 1:5))
298+
@test collect(itr) == 1:5
299+
@test collect(itr) == Int[] # Stateful do not preserve shape
300+
itr = (i+1 for i in Base.Stateful([1, 2, 3]))
298301
@test collect(itr) == [2, 3, 4]
299-
A = zeros(Int, 0, 0)
300-
itr = (i-1 for i in Base.Stateful(A))
302+
@test collect(itr) == Int[] # Stateful do not preserve shape
303+
itr = (i-1 for i in Base.Stateful(zeros(Int, 0, 0)))
301304
@test collect(itr) == Int[] # Stateful do not preserve shape
302305
end
303306

0 commit comments

Comments
 (0)