Skip to content

Commit e204e20

Browse files
Kenoaviatesk
andauthored
irinterp: Consider cfg information from discovered errors (#49692)
If we infer a call to `Union{}`, we can terminate further abstract interpretation. However, this of course also means that we can make use of that information to refine the types of any phis that may have originated from the basic block containing the call that was refined to `Union{}`. Co-authored-by: Shuhei Kadowaki <[email protected]>
1 parent b9b8b38 commit e204e20

File tree

2 files changed

+103
-36
lines changed

2 files changed

+103
-36
lines changed

base/compiler/ssair/irinterp.jl

Lines changed: 70 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,49 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRIn
5858
return RTEffects(rt, effects)
5959
end
6060

61+
function update_phi!(irsv::IRInterpretationState, from::Int, to::Int)
62+
ir = irsv.ir
63+
if length(ir.cfg.blocks[to].preds) == 0
64+
# Kill the entire block
65+
for bidx = ir.cfg.blocks[to].stmts
66+
ir.stmts[bidx][:inst] = nothing
67+
ir.stmts[bidx][:type] = Bottom
68+
ir.stmts[bidx][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
69+
end
70+
return
71+
end
72+
for sidx = ir.cfg.blocks[to].stmts
73+
sinst = ir.stmts[sidx][:inst]
74+
isa(sinst, Nothing) && continue # allowed between `PhiNode`s
75+
isa(sinst, PhiNode) || break
76+
for (eidx, edge) in enumerate(sinst.edges)
77+
if edge == from
78+
deleteat!(sinst.edges, eidx)
79+
deleteat!(sinst.values, eidx)
80+
push!(irsv.ssa_refined, sidx)
81+
break
82+
end
83+
end
84+
end
85+
end
86+
update_phi!(irsv::IRInterpretationState) = (from::Int, to::Int)->update_phi!(irsv, from, to)
87+
88+
function kill_terminator_edges!(irsv::IRInterpretationState, term_idx::Int, bb::Int=block_for_inst(irsv.ir, term_idx))
89+
ir = irsv.ir
90+
inst = ir[SSAValue(term_idx)][:inst]
91+
if isa(inst, GotoIfNot)
92+
kill_edge!(ir, bb, inst.dest, update_phi!(irsv))
93+
kill_edge!(ir, bb, bb+1, update_phi!(irsv))
94+
elseif isa(inst, GotoNode)
95+
kill_edge!(ir, bb, inst.label, update_phi!(irsv))
96+
elseif isa(inst, ReturnNode)
97+
# Nothing to do
98+
else
99+
@assert !isexpr(inst, :enter)
100+
kill_edge!(ir, bb, bb+1, update_phi!(irsv))
101+
end
102+
end
103+
61104
function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union{Int,Nothing},
62105
@nospecialize(inst), @nospecialize(typ), irsv::IRInterpretationState,
63106
extra_reprocess::Union{Nothing,BitSet,BitSetBoundedMinPrioritySet})
@@ -66,30 +109,6 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union
66109
cond = inst.cond
67110
condval = maybe_extract_const_bool(argextype(cond, ir))
68111
if condval isa Bool
69-
function update_phi!(from::Int, to::Int)
70-
if length(ir.cfg.blocks[to].preds) == 0
71-
# Kill the entire block
72-
for bidx = ir.cfg.blocks[to].stmts
73-
ir.stmts[bidx][:inst] = nothing
74-
ir.stmts[bidx][:type] = Bottom
75-
ir.stmts[bidx][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
76-
end
77-
return
78-
end
79-
for sidx = ir.cfg.blocks[to].stmts
80-
sinst = ir.stmts[sidx][:inst]
81-
isa(sinst, Nothing) && continue # allowed between `PhiNode`s
82-
isa(sinst, PhiNode) || break
83-
for (eidx, edge) in enumerate(sinst.edges)
84-
if edge == from
85-
deleteat!(sinst.edges, eidx)
86-
deleteat!(sinst.values, eidx)
87-
push!(irsv.ssa_refined, sidx)
88-
break
89-
end
90-
end
91-
end
92-
end
93112
if isa(cond, SSAValue)
94113
kill_def_use!(irsv.tpdum, cond, idx)
95114
end
@@ -100,10 +119,10 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union
100119
if condval
101120
ir.stmts[idx][:inst] = nothing
102121
ir.stmts[idx][:type] = Any
103-
kill_edge!(ir, bb, inst.dest, update_phi!)
122+
kill_edge!(ir, bb, inst.dest, update_phi!(irsv))
104123
else
105124
ir.stmts[idx][:inst] = GotoNode(inst.dest)
106-
kill_edge!(ir, bb, bb+1, update_phi!)
125+
kill_edge!(ir, bb, bb+1, update_phi!(irsv))
107126
end
108127
return true
109128
end
@@ -123,9 +142,6 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union
123142
rt, nothrow = concrete_eval_invoke(interp, inst, inst.args[1]::MethodInstance, irsv)
124143
if nothrow
125144
ir.stmts[idx][:flag] |= IR_FLAG_NOTHROW
126-
if isa(rt, Const) && is_inlineable_constant(rt.val)
127-
ir.stmts[idx][:inst] = quoted(rt.val)
128-
end
129145
end
130146
elseif head === :throw_undef_if_not || # TODO: Terminate interpretation early if known false?
131147
head === :gc_preserve_begin ||
@@ -148,9 +164,17 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union
148164
else
149165
error("reprocess_instruction!: unhandled instruction found")
150166
end
151-
if rt !== nothing && !(typeinf_lattice(interp), typ, rt)
152-
ir.stmts[idx][:type] = rt
153-
return true
167+
if rt !== nothing
168+
if isa(rt, Const)
169+
ir.stmts[idx][:type] = rt
170+
if is_inlineable_constant(rt.val)
171+
ir.stmts[idx][:inst] = quoted(rt.val)
172+
end
173+
return true
174+
elseif !(typeinf_lattice(interp), typ, rt)
175+
ir.stmts[idx][:type] = rt
176+
return true
177+
end
154178
end
155179
return false
156180
end
@@ -227,12 +251,22 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR
227251
any_refined = true
228252
delete!(ssa_refined, idx)
229253
end
230-
if any_refined && reprocess_instruction!(interp,
231-
idx, bb, inst, typ, irsv, extra_reprocess)
232-
push!(ssa_refined, idx)
254+
did_reprocess = false
255+
if any_refined
256+
did_reprocess = reprocess_instruction!(interp,
257+
idx, bb, inst, typ, irsv, extra_reprocess)
258+
if did_reprocess
259+
push!(ssa_refined, idx)
260+
inst = ir.stmts[idx][:inst]
261+
typ = ir.stmts[idx][:type]
262+
end
263+
end
264+
if idx == lstmt
265+
process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) && @goto residual_scan
266+
(isa(inst, GotoNode) || isa(inst, GotoIfNot) || isa(inst, ReturnNode) || isexpr(inst, :enter)) && continue
233267
end
234-
idx == lstmt && process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) && @goto residual_scan
235268
if typ === Bottom && !isa(inst, PhiNode)
269+
kill_terminator_edges!(irsv, lstmt, bb)
236270
break
237271
end
238272
end

test/compiler/inference.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4871,3 +4871,36 @@ function nt_splat_partial(x::Int)
48714871
Val{tuple(nt...)[2]}()
48724872
end
48734873
@test @inferred(nt_splat_partial(42)) == Val{2}()
4874+
4875+
# Test that irinterp refines based on discovered errors
4876+
Base.@assume_effects :foldable Base.@constprop :aggressive function kill_error_edge(b1, b2, xs, x)
4877+
y = b1 ? "julia" : xs[]
4878+
if b2
4879+
a = length(y)
4880+
else
4881+
a = sin(y)
4882+
end
4883+
a + x
4884+
end
4885+
4886+
Base.@assume_effects :foldable Base.@constprop :aggressive function kill_error_edge(b1, b2, xs, ys, x)
4887+
y = b1 ? xs[] : ys[]
4888+
if b2
4889+
a = length(y)
4890+
else
4891+
a = sin(y)
4892+
end
4893+
a + x
4894+
end
4895+
4896+
let src = code_typed1((Bool,Base.RefValue{Any},Int,)) do b2, xs, x
4897+
kill_error_edge(true, b2, xs, x)
4898+
end
4899+
@test count(@nospecialize(x)->isa(x, Core.PhiNode), src.code) == 0
4900+
end
4901+
4902+
let src = code_typed1((Bool,Base.RefValue{String}, Base.RefValue{Any},Int,)) do b2, xs, ys, x
4903+
kill_error_edge(true, b2, xs, ys, x)
4904+
end
4905+
@test count(@nospecialize(x)->isa(x, Core.PhiNode), src.code) == 0
4906+
end

0 commit comments

Comments
 (0)