Skip to content

Commit 4c90ed9

Browse files
authored
Merge pull request #42172 from JuliaLang/jn/42168
fix collect on stateful iterators
2 parents 60423e2 + 68e0813 commit 4c90ed9

File tree

11 files changed

+77
-88
lines changed

11 files changed

+77
-88
lines changed

base/array.jl

Lines changed: 39 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -643,23 +643,38 @@ julia> collect(Float64, 1:2:5)
643643
"""
644644
collect(::Type{T}, itr) where {T} = _collect(T, itr, IteratorSize(itr))
645645

646-
_collect(::Type{T}, itr, isz::HasLength) where {T} = copyto!(Vector{T}(undef, Int(length(itr)::Integer)), itr)
647-
_collect(::Type{T}, itr, isz::HasShape) where {T} = copyto!(similar(Array{T}, axes(itr)), itr)
646+
_collect(::Type{T}, itr, isz::Union{HasLength,HasShape}) where {T} =
647+
copyto!(_array_for(T, isz, _similar_shape(itr, isz)), itr)
648648
function _collect(::Type{T}, itr, isz::SizeUnknown) where T
649649
a = Vector{T}()
650650
for x in itr
651-
push!(a,x)
651+
push!(a, x)
652652
end
653653
return a
654654
end
655655

656656
# make a collection similar to `c` and appropriate for collecting `itr`
657-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown) where {T} = similar(c, T, 0)
658-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength) where {T} =
659-
similar(c, T, Int(length(itr)::Integer))
660-
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape) where {T} =
661-
similar(c, T, axes(itr))
662-
_similar_for(c, ::Type{T}, itr, isz) where {T} = similar(c, T)
657+
_similar_for(c, ::Type{T}, itr, isz, shp) where {T} = similar(c, T)
658+
659+
_similar_shape(itr, ::SizeUnknown) = nothing
660+
_similar_shape(itr, ::HasLength) = length(itr)::Integer
661+
_similar_shape(itr, ::HasShape) = axes(itr)
662+
663+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::SizeUnknown, ::Nothing) where {T} =
664+
similar(c, T, 0)
665+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasLength, len::Integer) where {T} =
666+
similar(c, T, len)
667+
_similar_for(c::AbstractArray, ::Type{T}, itr, ::HasShape, axs) where {T} =
668+
similar(c, T, axs)
669+
670+
# make a collection appropriate for collecting `itr::Generator`
671+
_array_for(::Type{T}, ::SizeUnknown, ::Nothing) where {T} = Vector{T}(undef, 0)
672+
_array_for(::Type{T}, ::HasLength, len::Integer) where {T} = Vector{T}(undef, Int(len))
673+
_array_for(::Type{T}, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)
674+
675+
# used by syntax lowering for simple typed comprehensions
676+
_array_for(::Type{T}, itr, isz) where {T} = _array_for(T, isz, _similar_shape(itr, isz))
677+
663678

664679
"""
665680
collect(collection)
@@ -698,10 +713,10 @@ collect(A::AbstractArray) = _collect_indices(axes(A), A)
698713
collect_similar(cont, itr) = _collect(cont, itr, IteratorEltype(itr), IteratorSize(itr))
699714

700715
_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
701-
copyto!(_similar_for(cont, eltype(itr), itr, isz), itr)
716+
copyto!(_similar_for(cont, eltype(itr), itr, isz, _similar_shape(itr, isz)), itr)
702717

703718
function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
704-
a = _similar_for(cont, eltype(itr), itr, isz)
719+
a = _similar_for(cont, eltype(itr), itr, isz, nothing)
705720
for x in itr
706721
push!(a,x)
707722
end
@@ -759,24 +774,19 @@ else
759774
end
760775
end
761776

762-
_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr))
763-
_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr))
764-
_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len)
765-
_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs)
766-
767777
function collect(itr::Generator)
768778
isz = IteratorSize(itr.iter)
769779
et = @default_eltype(itr)
770780
if isa(isz, SizeUnknown)
771781
return grow_to!(Vector{et}(), itr)
772782
else
773-
shape = isz isa HasLength ? length(itr) : axes(itr)
783+
shp = _similar_shape(itr, isz)
774784
y = iterate(itr)
775785
if y === nothing
776-
return _array_for(et, itr.iter, isz)
786+
return _array_for(et, isz, shp)
777787
end
778788
v1, st = y
779-
dest = _array_for(typeof(v1), itr.iter, isz, shape)
789+
dest = _array_for(typeof(v1), isz, shp)
780790
# The typeassert gives inference a helping hand on the element type and dimensionality
781791
# (work-around for #28382)
782792
et′ = et <: Type ? Type : et
@@ -786,15 +796,22 @@ function collect(itr::Generator)
786796
end
787797

788798
_collect(c, itr, ::EltypeUnknown, isz::SizeUnknown) =
789-
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz), itr)
799+
grow_to!(_similar_for(c, @default_eltype(itr), itr, isz, nothing), itr)
790800

791801
function _collect(c, itr, ::EltypeUnknown, isz::Union{HasLength,HasShape})
802+
et = @default_eltype(itr)
803+
shp = _similar_shape(itr, isz)
792804
y = iterate(itr)
793805
if y === nothing
794-
return _similar_for(c, @default_eltype(itr), itr, isz)
806+
return _similar_for(c, et, itr, isz, shp)
795807
end
796808
v1, st = y
797-
collect_to_with_first!(_similar_for(c, typeof(v1), itr, isz), v1, itr, st)
809+
dest = _similar_for(c, typeof(v1), itr, isz, shp)
810+
# The typeassert gives inference a helping hand on the element type and dimensionality
811+
# (work-around for #28382)
812+
et′ = et <: Type ? Type : et
813+
RT = dest isa AbstractArray ? AbstractArray{<:et′, ndims(dest)} : Any
814+
collect_to_with_first!(dest, v1, itr, st)::RT
798815
end
799816

800817
function collect_to_with_first!(dest::AbstractArray, v1, itr, st)

base/compiler/ssair/inlining.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ function inline_into_block!(state::CFGInliningState, block::Int)
109109
new_range = state.first_bb+1:block
110110
l = length(state.new_cfg_blocks)
111111
state.bb_rename[new_range] = (l+1:l+length(new_range))
112-
append!(state.new_cfg_blocks, map(copy, state.cfg.blocks[new_range]))
112+
append!(state.new_cfg_blocks, (copy(block) for block in state.cfg.blocks[new_range]))
113113
push!(state.merged_orig_blocks, last(new_range))
114114
end
115115
state.first_bb = block

base/compiler/ssair/passes.jl

Lines changed: 2 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,31 +27,6 @@ function try_compute_fieldidx_args(typ::DataType, args::Vector{Any})
2727
return try_compute_fieldidx(typ, field)
2828
end
2929

30-
function lift_defuse(cfg::CFG, ssa::SSADefUse)
31-
# We remove from `uses` any block where all uses are dominated
32-
# by a def. This prevents insertion of dead phi nodes at the top
33-
# of such a block if that block happens to be in a loop
34-
ordered = Tuple{Int, Int, Bool}[(x, block_for_inst(cfg, x), true) for x in ssa.uses]
35-
for x in ssa.defs
36-
push!(ordered, (x, block_for_inst(cfg, x), false))
37-
end
38-
ordered = sort(ordered, by=x->x[1])
39-
bb_defs = Int[]
40-
bb_uses = Int[]
41-
last_bb = last_def_bb = 0
42-
for (_, bb, is_use) in ordered
43-
if bb != last_bb && is_use
44-
push!(bb_uses, bb)
45-
end
46-
last_bb = bb
47-
if last_def_bb != bb && !is_use
48-
push!(bb_defs, bb)
49-
last_def_bb = bb
50-
end
51-
end
52-
SSADefUse(bb_uses, bb_defs, Int[])
53-
end
54-
5530
function find_curblock(domtree::DomTree, allblocks::Vector{Int}, curblock::Int)
5631
# TODO: This can be much faster by looking at current level and only
5732
# searching for those blocks in a sorted order
@@ -1209,12 +1184,12 @@ function cfg_simplify!(ir::IRCode)
12091184
# Compute (renamed) successors and predecessors given (renamed) block
12101185
function compute_succs(i)
12111186
orig_bb = follow_merged_succ(result_bbs[i])
1212-
return map(i -> bb_rename_succ[i], bbs[orig_bb].succs)
1187+
return Int[bb_rename_succ[i] for i in bbs[orig_bb].succs]
12131188
end
12141189
function compute_preds(i)
12151190
orig_bb = result_bbs[i]
12161191
preds = bbs[orig_bb].preds
1217-
return map(pred -> bb_rename_pred[pred], preds)
1192+
return Int[bb_rename_pred[pred] for pred in preds]
12181193
end
12191194

12201195
BasicBlock[

base/compiler/ssair/show.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,15 @@ show_unquoted(io::IO, val::Argument, indent::Int, prec::Int) = show_unquoted(io,
7979

8080
show_unquoted(io::IO, stmt::PhiNode, indent::Int, ::Int) = show_unquoted_phinode(io, stmt, indent, "%")
8181
function show_unquoted_phinode(io::IO, stmt::PhiNode, indent::Int, prefix::String)
82-
args = map(1:length(stmt.edges)) do i
82+
args = String[let
8383
e = stmt.edges[i]
8484
v = !isassigned(stmt.values, i) ? "#undef" :
8585
sprint() do io′
8686
show_unquoted(io′, stmt.values[i], indent)
8787
end
88-
return "$prefix$e => $v"
89-
end
88+
"$prefix$e => $v"
89+
end for i in 1:length(stmt.edges)
90+
]
9091
print(io, "φ ", '(')
9192
join(io, args, ", ")
9293
print(io, ')')

base/compiler/ssair/slot2ssa.jl

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,6 @@ function scan_entry!(result::Vector{SlotInfo}, idx::Int, @nospecialize(stmt))
3333
end
3434

3535

36-
function lift_defuse(cfg::CFG, defuse)
37-
map(defuse) do slot
38-
SlotInfo(
39-
Int[block_for_inst(cfg, x) for x in slot.defs],
40-
Int[block_for_inst(cfg, x) for x in slot.uses],
41-
slot.any_newvar
42-
)
43-
end
44-
end
45-
4636
function scan_slot_def_use(nargs::Int, ci::CodeInfo, code::Vector{Any})
4737
nslots = length(ci.slotflags)
4838
result = SlotInfo[SlotInfo() for i = 1:nslots]
@@ -524,7 +514,7 @@ function domsort_ssa!(ir::IRCode, domtree::DomTree)
524514
return new_ir
525515
end
526516

527-
function compute_live_ins(cfg::CFG, defuse)
517+
function compute_live_ins(cfg::CFG, defuse #=::Union{SlotInfo,SSADefUse}=#)
528518
# We remove from `uses` any block where all uses are dominated
529519
# by a def. This prevents insertion of dead phi nodes at the top
530520
# of such a block if that block happens to be in a loop
@@ -586,8 +576,8 @@ function recompute_type(node::Union{PhiNode, PhiCNode}, ci::CodeInfo, ir::IRCode
586576
return new_typ
587577
end
588578

589-
function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
590-
slottypes::Vector{Any})
579+
function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree,
580+
defuses::Vector{SlotInfo}, slottypes::Vector{Any})
591581
code = ir.stmts.inst
592582
cfg = ir.cfg
593583
left = Int[]
@@ -616,7 +606,7 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
616606
for (_, exc) in catch_entry_blocks
617607
phicnodes[exc] = Vector{Tuple{SlotNumber, NewSSAValue, PhiCNode}}()
618608
end
619-
@timeit "idf" for (idx, slot) in Iterators.enumerate(defuse)
609+
@timeit "idf" for (idx, slot) in Iterators.enumerate(defuses)
620610
# No uses => no need for phi nodes
621611
isempty(slot.uses) && continue
622612
# TODO: Restore this optimization
@@ -671,9 +661,9 @@ function construct_ssa!(ci::CodeInfo, ir::IRCode, domtree::DomTree, defuse,
671661
end
672662
# Perform SSA renaming
673663
initial_incoming_vals = Any[
674-
if 0 in defuse[x].defs
664+
if 0 in defuses[x].defs
675665
Argument(x)
676-
elseif !defuse[x].any_newvar
666+
elseif !defuses[x].any_newvar
677667
undef_token
678668
else
679669
SSAValue(-2)

base/dict.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -826,6 +826,6 @@ length(t::ImmutableDict) = count(Returns(true), t)
826826
isempty(t::ImmutableDict) = !isdefined(t, :parent)
827827
empty(::ImmutableDict, ::Type{K}, ::Type{V}) where {K, V} = ImmutableDict{K,V}()
828828

829-
_similar_for(c::Dict, ::Type{Pair{K,V}}, itr, isz) where {K, V} = empty(c, K, V)
830-
_similar_for(c::AbstractDict, ::Type{T}, itr, isz) where {T} =
829+
_similar_for(c::AbstractDict, ::Type{Pair{K,V}}, itr, isz, len) where {K, V} = empty(c, K, V)
830+
_similar_for(c::AbstractDict, ::Type{T}, itr, isz, len) where {T} =
831831
throw(ArgumentError("for AbstractDicts, similar requires an element type of Pair;\n if calling map, consider a comprehension instead"))

base/set.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ empty(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
4444
# by default, a Set is returned
4545
emptymutable(s::AbstractSet{T}, ::Type{U}=T) where {T,U} = Set{U}()
4646

47-
_similar_for(c::AbstractSet, ::Type{T}, itr, isz) where {T} = empty(c, T)
47+
_similar_for(c::AbstractSet, ::Type{T}, itr, isz, len) where {T} = empty(c, T)
4848

4949
function show(io::IO, s::Set)
5050
if isempty(s)

src/gf.c

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,6 +477,7 @@ void jl_foreach_reachable_mtable(void (*visit)(jl_methtable_t *mt, void *env), v
477477
}
478478
else {
479479
foreach_mtable_in_module(jl_main_module, visit, env, &visited);
480+
foreach_mtable_in_module(jl_core_module, visit, env, &visited);
480481
}
481482
JL_GC_POP();
482483
}
@@ -493,14 +494,15 @@ static void reset_mt_caches(jl_methtable_t *mt, void *env)
493494

494495

495496
jl_function_t *jl_typeinf_func = NULL;
496-
size_t jl_typeinf_world = 0;
497+
size_t jl_typeinf_world = 1;
497498

498499
JL_DLLEXPORT void jl_set_typeinf_func(jl_value_t *f)
499500
{
501+
size_t newfunc = jl_typeinf_world == 1 && jl_typeinf_func == NULL;
500502
jl_typeinf_func = (jl_function_t*)f;
501503
jl_typeinf_world = jl_get_tls_world_age();
502504
++jl_world_counter; // make type-inference the only thing in this world
503-
if (jl_typeinf_world == 0) {
505+
if (newfunc) {
504506
// give type inference a chance to see all of these
505507
// TODO: also reinfer if max_world != ~(size_t)0
506508
jl_array_t *unspec = jl_alloc_vec_any(0);

src/julia-syntax.scm

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2734,7 +2734,7 @@
27342734
(check-no-return expr)
27352735
(if (has-break-or-continue? expr)
27362736
(error "break or continue outside loop"))
2737-
(let ((result (gensy))
2737+
(let ((result (make-ssavalue))
27382738
(idx (gensy))
27392739
(oneresult (make-ssavalue))
27402740
(prod (make-ssavalue))
@@ -2758,16 +2758,14 @@
27582758
(let ((overall-itr (if (length= itrs 1) (car iv) prod)))
27592759
`(scope-block
27602760
(block
2761-
(local ,result) (local ,idx)
2761+
(local ,idx)
27622762
,.(map (lambda (v r) `(= ,v ,(caddr r))) iv itrs)
27632763
,.(if (length= itrs 1)
27642764
'()
27652765
`((= ,prod (call (top product) ,@iv))))
27662766
(= ,isz (call (top IteratorSize) ,overall-itr))
27672767
(= ,szunk (call (core isa) ,isz (top SizeUnknown)))
2768-
(if ,szunk
2769-
(= ,result (call (curly (core Array) ,ty 1) (core undef) 0))
2770-
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz)))
2768+
(= ,result (call (top _array_for) ,ty ,overall-itr ,isz))
27712769
(= ,idx (call (top first) (call (top LinearIndices) ,result)))
27722770
,(construct-loops (reverse itrs) (reverse iv))
27732771
,result)))))

test/errorshow.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -728,7 +728,7 @@ end
728728

729729
# Test that implementation detail of include() is hidden from the user by default
730730
let bt = try
731-
include("testhelpers/include_error.jl")
731+
@noinline include("testhelpers/include_error.jl")
732732
catch
733733
catch_backtrace()
734734
end
@@ -740,7 +740,7 @@ end
740740
# Test backtrace printing
741741
module B
742742
module C
743-
f(x; y=2.0) = error()
743+
@noinline f(x; y=2.0) = error()
744744
end
745745
module D
746746
import ..C: f
@@ -749,7 +749,8 @@ module B
749749
end
750750

751751
@testset "backtrace" begin
752-
bt = try B.D.g()
752+
bt = try
753+
B.D.g()
753754
catch
754755
catch_backtrace()
755756
end
@@ -777,15 +778,17 @@ if Sys.isapple() || (Sys.islinux() && Sys.ARCH === :x86_64)
777778
pair_repeater_b() = pair_repeater_a()
778779

779780
@testset "repeated stack frames" begin
780-
let bt = try single_repeater()
781+
let bt = try
782+
single_repeater()
781783
catch
782784
catch_backtrace()
783785
end
784786
bt_str = sprint(Base.show_backtrace, bt)
785787
@test occursin(r"repeats \d+ times", bt_str)
786788
end
787789

788-
let bt = try pair_repeater_a()
790+
let bt = try
791+
pair_repeater_a()
789792
catch
790793
catch_backtrace()
791794
end

0 commit comments

Comments
 (0)