Skip to content

Commit 997e336

Browse files
authored
inference: fixes and improvements for backedge computation (#46741)
This commit consists of the following changes: * inference: setup separate functions for each backedge kind Also changes the argument list so that they are ordered as `(caller, [backedge information])`. * inference: fix backedge computation for const-prop'ed callsite With this commit `abstract_call_method_with_const_args` doesn't add backedge but rather returns the backedge to the caller, letting the callers like `abstract_call_gf_by_type` and `abstract_invoke` take the responsibility to add backedge to current context appropriately. As a result, this fixes the backedge calculation for const-prop'ed `invoke` callsite. For example, for the following call graph, ```julia foo(a::Int) = a > 0 ? :int : println(a) foo(a::Integer) = a > 0 ? "integer" : println(a) bar(a::Int) = @invoke foo(a::Integer) ``` Previously we added the wrong backedge `nothing, bar(Int64) from bar(Int64)`: ```julia julia> last(only(code_typed(()->bar(42)))) String julia> let m = only(methods(foo, (UInt,))) @eval Core.Compiler for (sig, caller) in BackedgeIterator($m.specializations[1].backedges) println(sig, ", ", caller) end end Tuple{typeof(Main.foo), Integer}, bar(Int64) from bar(Int64) nothing, bar(Int64) from bar(Int64) ``` but now we only add `invoke`-backedge: ```julia julia> last(only(code_typed(()->bar(42)))) String julia> let m = only(methods(foo, (UInt,))) @eval Core.Compiler for (sig, caller) in BackedgeIterator($m.specializations[1].backedges) println(sig, ", ", caller) end end Tuple{typeof(Main.foo), Integer}, bar(Int64) from bar(Int64) ``` * inference: make `BackedgePair` struct * add invalidation test for `invoke` call * optimizer: fixup inlining backedge calculation Should fix the following backedge calculation: ```julia julia> m = which(unique, Tuple{Any}) unique(itr) @ Base set.jl:170 julia> specs = collect(Iterators.filter(m.specializations) do mi mi === nothing && return false return mi.specTypes.parameters[end] === Vector{Int} # find specialization of `unique(::Any)` for `::Vector{Int}` end) Any[] julia> Base._unique_dims([1,2,3],:) # no existing callers with specialization `Vector{Int}`, let's make one 3-element Vector{Int64}: 1 2 3 julia> mi = only(Iterators.filter(m.specializations) do mi mi === nothing && return false return mi.specTypes.parameters[end] === Vector{Int} # find specialization of `unique(::Any)` for `::Vector{Int}` end) MethodInstance for unique(::Vector{Int64}) julia> mi.def unique(itr) @ Base set.jl:170 julia> mi.backedges 3-element Vector{Any}: Tuple{typeof(unique), Any} MethodInstance for Base._unique_dims(::Vector{Int64}, ::Colon) MethodInstance for Base._unique_dims(::Vector{Int64}, ::Colon) # <= now we don't register this backedge ```
1 parent 94c3a15 commit 997e336

File tree

8 files changed

+171
-106
lines changed

8 files changed

+171
-106
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
126126
for sig_n in splitsigs
127127
result = abstract_call_method(interp, method, sig_n, svec(), multiple_matches, sv)
128128
(; rt, edge, effects) = result
129-
edge === nothing || push!(edges, edge)
130129
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
131130
this_arginfo = ArgInfo(fargs, this_argtypes)
132131
const_call_result = abstract_call_method_with_const_args(interp, result,
@@ -135,12 +134,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
135134
if const_call_result !== nothing
136135
if const_call_result.rt ᵢ rt
137136
rt = const_call_result.rt
138-
(; effects, const_result) = const_call_result
137+
(; effects, const_result, edge) = const_call_result
139138
end
140139
end
141140
all_effects = merge_effects(all_effects, effects)
142141
push!(const_results, const_result)
143142
any_const_result |= const_result !== nothing
143+
edge === nothing || push!(edges, edge)
144144
this_rt = tmerge(this_rt, rt)
145145
if bail_out_call(interp, this_rt, sv)
146146
break
@@ -153,7 +153,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
153153
(; rt, edge, effects) = result
154154
this_conditional = ignorelimited(rt)
155155
this_rt = widenwrappedconditional(rt)
156-
edge === nothing || push!(edges, edge)
157156
# try constant propagation with argtypes for this match
158157
# this is in preparation for inlining, or improving the return result
159158
this_argtypes = isa(matches, MethodMatches) ? argtypes : matches.applicable_argtypes[i]
@@ -169,12 +168,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
169168
if this_const_rt ᵢ this_rt
170169
this_conditional = this_const_conditional
171170
this_rt = this_const_rt
172-
(; effects, const_result) = const_call_result
171+
(; effects, const_result, edge) = const_call_result
173172
end
174173
end
175174
all_effects = merge_effects(all_effects, effects)
176175
push!(const_results, const_result)
177176
any_const_result |= const_result !== nothing
177+
edge === nothing || push!(edges, edge)
178178
end
179179
@assert !(this_conditional isa Conditional) "invalid lattice element returned from inter-procedural context"
180180
seen += 1
@@ -483,15 +483,15 @@ function add_call_backedges!(interp::AbstractInterpreter,
483483
end
484484
end
485485
for edge in edges
486-
add_backedge!(edge, sv)
486+
add_backedge!(sv, edge)
487487
end
488488
# also need an edge to the method table in case something gets
489489
# added that did not intersect with any existing method
490490
if isa(matches, MethodMatches)
491-
matches.fullmatch || add_mt_backedge!(matches.mt, atype, sv)
491+
matches.fullmatch || add_mt_backedge!(sv, matches.mt, atype)
492492
else
493493
for (thisfullmatch, mt) in zip(matches.fullmatches, matches.mts)
494-
thisfullmatch || add_mt_backedge!(mt, atype, sv)
494+
thisfullmatch || add_mt_backedge!(sv, mt, atype)
495495
end
496496
end
497497
end
@@ -838,17 +838,18 @@ function concrete_eval_call(interp::AbstractInterpreter,
838838
f = invoke
839839
end
840840
world = get_world_counter(interp)
841+
edge = result.edge::MethodInstance
841842
value = try
842843
Core._call_in_world_total(world, f, args...)
843844
catch
844845
# The evaluation threw. By :consistent-cy, we're guaranteed this would have happened at runtime
845-
return ConstCallResults(Union{}, ConcreteResult(result.edge::MethodInstance, result.effects), result.effects)
846+
return ConstCallResults(Union{}, ConcreteResult(edge, result.effects), result.effects, edge)
846847
end
847848
if is_inlineable_constant(value) || call_result_unused(sv)
848849
# If the constant is not inlineable, still do the const-prop, since the
849850
# code that led to the creation of the Const may be inlineable in the same
850851
# circumstance and may be optimizable.
851-
return ConstCallResults(Const(value), ConcreteResult(result.edge::MethodInstance, EFFECTS_TOTAL, value), EFFECTS_TOTAL)
852+
return ConstCallResults(Const(value), ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL, edge)
852853
end
853854
return false
854855
else # eligible for semi-concrete evaluation
@@ -875,10 +876,12 @@ struct ConstCallResults
875876
rt::Any
876877
const_result::ConstResult
877878
effects::Effects
879+
edge::MethodInstance
878880
ConstCallResults(@nospecialize(rt),
879881
const_result::ConstResult,
880-
effects::Effects) =
881-
new(rt, const_result, effects)
882+
effects::Effects,
883+
edge::MethodInstance) =
884+
new(rt, const_result, effects, edge)
882885
end
883886

884887
function abstract_call_method_with_const_args(interp::AbstractInterpreter,
@@ -888,10 +891,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
888891
return nothing
889892
end
890893
res = concrete_eval_call(interp, f, result, arginfo, sv, invokecall)
891-
if isa(res, ConstCallResults)
892-
add_backedge!(res.const_result.mi, sv, invokecall === nothing ? nothing : invokecall.lookupsig)
893-
return res
894-
end
894+
isa(res, ConstCallResults) && return res
895895
mi = maybe_get_const_prop_profitable(interp, result, f, arginfo, match, sv)
896896
mi === nothing && return nothing
897897
# try semi-concrete evaluation
@@ -903,7 +903,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
903903
if isa(ir, IRCode)
904904
T = ir_abstract_constant_propagation(interp, mi_cache, sv, mi, ir, arginfo.argtypes)
905905
if !isa(T, Type) || typeintersect(T, Bool) === Union{}
906-
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects)
906+
return ConstCallResults(T, SemiConcreteResult(mi, ir, result.effects), result.effects, mi)
907907
end
908908
end
909909
end
@@ -936,8 +936,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter,
936936
result = inf_result.result
937937
# if constant inference hits a cycle, just bail out
938938
isa(result, InferenceState) && return nothing
939-
add_backedge!(mi, sv)
940-
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects)
939+
return ConstCallResults(result, ConstPropResult(inf_result), inf_result.ipo_effects, mi)
941940
end
942941

943942
# if there's a possibility we could get a better result with these constant arguments
@@ -1692,7 +1691,6 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
16921691
ti = tienv[1]; env = tienv[2]::SimpleVector
16931692
result = abstract_call_method(interp, method, ti, env, false, sv)
16941693
(; rt, edge, effects) = result
1695-
edge !== nothing && add_backedge!(edge::MethodInstance, sv, lookupsig)
16961694
match = MethodMatch(ti, env, method, argtype <: method.sig)
16971695
res = nothing
16981696
sig = match.spec_types
@@ -1711,10 +1709,11 @@ function abstract_invoke(interp::AbstractInterpreter, (; fargs, argtypes)::ArgIn
17111709
const_result = nothing
17121710
if const_call_result !== nothing
17131711
if (typeinf_lattice(interp), const_call_result.rt, rt)
1714-
(; rt, effects, const_result) = const_call_result
1712+
(; rt, effects, const_result, edge) = const_call_result
17151713
end
17161714
end
17171715
effects = Effects(effects; nonoverlayed=!overlayed)
1716+
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
17181717
return CallMeta(from_interprocedural!(ipo_lattice(interp), rt, sv, arginfo, sig), effects, InvokeCallInfo(match, const_result))
17191718
end
17201719

@@ -1848,7 +1847,6 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
18481847
sig = argtypes_to_type(arginfo.argtypes)
18491848
result = abstract_call_method(interp, closure.source, sig, Core.svec(), false, sv)
18501849
(; rt, edge, effects) = result
1851-
edge !== nothing && add_backedge!(edge, sv)
18521850
tt = closure.typ
18531851
sigT = (unwrap_unionall(tt)::DataType).parameters[1]
18541852
match = MethodMatch(sig, Core.svec(), closure.source, sig <: rewrap_unionall(sigT, tt))
@@ -1858,7 +1856,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
18581856
nothing, arginfo, match, sv)
18591857
if const_call_result !== nothing
18601858
if const_call_result.rt rt
1861-
(; rt, effects, const_result) = const_call_result
1859+
(; rt, effects, const_result, edge) = const_call_result
18621860
end
18631861
end
18641862
end
@@ -1874,6 +1872,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
18741872
end
18751873
end
18761874
rt = from_interprocedural!(ipo, rt, sv, arginfo, match.spec_types)
1875+
edge !== nothing && add_backedge!(sv, edge)
18771876
return CallMeta(rt, effects, info)
18781877
end
18791878

base/compiler/inferencestate.jl

Lines changed: 26 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -478,38 +478,49 @@ function record_ssa_assign!(ssa_id::Int, @nospecialize(new), frame::InferenceSta
478478
return nothing
479479
end
480480

481-
function add_cycle_backedge!(frame::InferenceState, caller::InferenceState, currpc::Int)
481+
function add_cycle_backedge!(caller::InferenceState, frame::InferenceState, currpc::Int)
482482
update_valid_age!(frame, caller)
483483
backedge = (caller, currpc)
484484
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
485-
add_backedge!(frame.linfo, caller)
485+
add_backedge!(caller, frame.linfo)
486486
return frame
487487
end
488488

489489
# temporarily accumulate our edges to later add as backedges in the callee
490-
function add_backedge!(li::MethodInstance, caller::InferenceState, @nospecialize(invokesig=nothing))
491-
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
492-
edges = caller.stmt_edges[caller.currpc]
493-
if edges === nothing
494-
edges = caller.stmt_edges[caller.currpc] = []
490+
function add_backedge!(caller::InferenceState, li::MethodInstance)
491+
edges = get_stmt_edges!(caller)
492+
if edges !== nothing
493+
push!(edges, li)
495494
end
496-
if invokesig !== nothing
497-
push!(edges, invokesig)
495+
return nothing
496+
end
497+
498+
function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), li::MethodInstance)
499+
edges = get_stmt_edges!(caller)
500+
if edges !== nothing
501+
push!(edges, invokesig, li)
498502
end
499-
push!(edges, li)
500503
return nothing
501504
end
502505

503506
# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
504-
function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::InferenceState)
505-
isa(caller.linfo.def, Method) || return # don't add backedges to toplevel exprs
507+
function add_mt_backedge!(caller::InferenceState, mt::Core.MethodTable, @nospecialize(typ))
508+
edges = get_stmt_edges!(caller)
509+
if edges !== nothing
510+
push!(edges, mt, typ)
511+
end
512+
return nothing
513+
end
514+
515+
function get_stmt_edges!(caller::InferenceState)
516+
if !isa(caller.linfo.def, Method)
517+
return nothing # don't add backedges to toplevel exprs
518+
end
506519
edges = caller.stmt_edges[caller.currpc]
507520
if edges === nothing
508521
edges = caller.stmt_edges[caller.currpc] = []
509522
end
510-
push!(edges, mt)
511-
push!(edges, typ)
512-
return nothing
523+
return edges
513524
end
514525

515526
function empty_backedges!(frame::InferenceState, currpc::Int = frame.currpc)

base/compiler/optimize.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,14 +64,13 @@ EdgeTracker() = EdgeTracker(Any[], 0:typemax(UInt))
6464
intersect!(et::EdgeTracker, range::WorldRange) =
6565
et.valid_worlds[] = intersect(et.valid_worlds[], range)
6666

67-
push!(et::EdgeTracker, mi::MethodInstance) = push!(et.edges, mi)
68-
function add_edge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
69-
invokesig === nothing && return push!(et.edges, mi)
70-
push!(et.edges, invokesig, mi)
67+
function add_backedge!(et::EdgeTracker, mi::MethodInstance)
68+
push!(et.edges, mi)
69+
return nothing
7170
end
72-
function push!(et::EdgeTracker, ci::CodeInstance)
73-
intersect!(et, WorldRange(min_world(li), max_world(li)))
74-
push!(et, ci.def)
71+
function add_invoke_backedge!(et::EdgeTracker, @nospecialize(invokesig), mi::MethodInstance)
72+
push!(et.edges, invokesig, mi)
73+
return nothing
7574
end
7675

7776
struct InliningState{S <: Union{EdgeTracker, Nothing}, MICache, I<:AbstractInterpreter}

0 commit comments

Comments
 (0)