diff --git a/base/compiler/ssair/irinterp.jl b/base/compiler/ssair/irinterp.jl index d58d18b188757..ad6077fa48859 100644 --- a/base/compiler/ssair/irinterp.jl +++ b/base/compiler/ssair/irinterp.jl @@ -58,6 +58,49 @@ function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, irsv::IRIn return RTEffects(rt, effects) end +function update_phi!(irsv::IRInterpretationState, from::Int, to::Int) + ir = irsv.ir + if length(ir.cfg.blocks[to].preds) == 0 + # Kill the entire block + for bidx = ir.cfg.blocks[to].stmts + ir.stmts[bidx][:inst] = nothing + ir.stmts[bidx][:type] = Bottom + ir.stmts[bidx][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW + end + return + end + for sidx = ir.cfg.blocks[to].stmts + sinst = ir.stmts[sidx][:inst] + isa(sinst, Nothing) && continue # allowed between `PhiNode`s + isa(sinst, PhiNode) || break + for (eidx, edge) in enumerate(sinst.edges) + if edge == from + deleteat!(sinst.edges, eidx) + deleteat!(sinst.values, eidx) + push!(irsv.ssa_refined, sidx) + break + end + end + end +end +update_phi!(irsv::IRInterpretationState) = (from::Int, to::Int)->update_phi!(irsv, from, to) + +function kill_terminator_edges!(irsv::IRInterpretationState, term_idx::Int, bb::Int=block_for_inst(irsv.ir, term_idx)) + ir = irsv.ir + inst = ir[SSAValue(term_idx)][:inst] + if isa(inst, GotoIfNot) + kill_edge!(ir, bb, inst.dest, update_phi!(irsv)) + kill_edge!(ir, bb, bb+1, update_phi!(irsv)) + elseif isa(inst, GotoNode) + kill_edge!(ir, bb, inst.label, update_phi!(irsv)) + elseif isa(inst, ReturnNode) + # Nothing to do + else + @assert !isexpr(inst, :enter) + kill_edge!(ir, bb, bb+1, update_phi!(irsv)) + end +end + function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union{Int,Nothing}, @nospecialize(inst), @nospecialize(typ), irsv::IRInterpretationState, extra_reprocess::Union{Nothing,BitSet,BitSetBoundedMinPrioritySet}) @@ -66,30 +109,6 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union cond = inst.cond condval = maybe_extract_const_bool(argextype(cond, ir)) if condval isa Bool - function update_phi!(from::Int, to::Int) - if length(ir.cfg.blocks[to].preds) == 0 - # Kill the entire block - for bidx = ir.cfg.blocks[to].stmts - ir.stmts[bidx][:inst] = nothing - ir.stmts[bidx][:type] = Bottom - ir.stmts[bidx][:flag] = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW - end - return - end - for sidx = ir.cfg.blocks[to].stmts - sinst = ir.stmts[sidx][:inst] - isa(sinst, Nothing) && continue # allowed between `PhiNode`s - isa(sinst, PhiNode) || break - for (eidx, edge) in enumerate(sinst.edges) - if edge == from - deleteat!(sinst.edges, eidx) - deleteat!(sinst.values, eidx) - push!(irsv.ssa_refined, sidx) - break - end - end - end - end if isa(cond, SSAValue) kill_def_use!(irsv.tpdum, cond, idx) end @@ -100,10 +119,10 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union if condval ir.stmts[idx][:inst] = nothing ir.stmts[idx][:type] = Any - kill_edge!(ir, bb, inst.dest, update_phi!) + kill_edge!(ir, bb, inst.dest, update_phi!(irsv)) else ir.stmts[idx][:inst] = GotoNode(inst.dest) - kill_edge!(ir, bb, bb+1, update_phi!) + kill_edge!(ir, bb, bb+1, update_phi!(irsv)) end return true end @@ -123,9 +142,6 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union rt, nothrow = concrete_eval_invoke(interp, inst, inst.args[1]::MethodInstance, irsv) if nothrow ir.stmts[idx][:flag] |= IR_FLAG_NOTHROW - if isa(rt, Const) && is_inlineable_constant(rt.val) - ir.stmts[idx][:inst] = quoted(rt.val) - end end elseif head === :throw_undef_if_not || # TODO: Terminate interpretation early if known false? head === :gc_preserve_begin || @@ -148,9 +164,17 @@ function reprocess_instruction!(interp::AbstractInterpreter, idx::Int, bb::Union else error("reprocess_instruction!: unhandled instruction found") end - if rt !== nothing && !⊑(typeinf_lattice(interp), typ, rt) - ir.stmts[idx][:type] = rt - return true + if rt !== nothing + if isa(rt, Const) + ir.stmts[idx][:type] = rt + if is_inlineable_constant(rt.val) + ir.stmts[idx][:inst] = quoted(rt.val) + end + return true + elseif !⊑(typeinf_lattice(interp), typ, rt) + ir.stmts[idx][:type] = rt + return true + end end return false end @@ -227,12 +251,22 @@ function _ir_abstract_constant_propagation(interp::AbstractInterpreter, irsv::IR any_refined = true delete!(ssa_refined, idx) end - if any_refined && reprocess_instruction!(interp, - idx, bb, inst, typ, irsv, extra_reprocess) - push!(ssa_refined, idx) + did_reprocess = false + if any_refined + did_reprocess = reprocess_instruction!(interp, + idx, bb, inst, typ, irsv, extra_reprocess) + if did_reprocess + push!(ssa_refined, idx) + inst = ir.stmts[idx][:inst] + typ = ir.stmts[idx][:type] + end + end + if idx == lstmt + process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) && @goto residual_scan + (isa(inst, GotoNode) || isa(inst, GotoIfNot) || isa(inst, ReturnNode) || isexpr(inst, :enter)) && continue end - idx == lstmt && process_terminator!(ir, inst, idx, bb, all_rets, bb_ip) && @goto residual_scan if typ === Bottom && !isa(inst, PhiNode) + kill_terminator_edges!(irsv, lstmt, bb) break end end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 1b137d1d8f661..5987e10401bc8 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -4871,3 +4871,36 @@ function nt_splat_partial(x::Int) Val{tuple(nt...)[2]}() end @test @inferred(nt_splat_partial(42)) == Val{2}() + +# Test that irinterp refines based on discovered errors +Base.@assume_effects :foldable Base.@constprop :aggressive function kill_error_edge(b1, b2, xs, x) + y = b1 ? "julia" : xs[] + if b2 + a = length(y) + else + a = sin(y) + end + a + x +end + +Base.@assume_effects :foldable Base.@constprop :aggressive function kill_error_edge(b1, b2, xs, ys, x) + y = b1 ? xs[] : ys[] + if b2 + a = length(y) + else + a = sin(y) + end + a + x +end + +let src = code_typed1((Bool,Base.RefValue{Any},Int,)) do b2, xs, x + kill_error_edge(true, b2, xs, x) + end + @test count(@nospecialize(x)->isa(x, Core.PhiNode), src.code) == 0 +end + +let src = code_typed1((Bool,Base.RefValue{String}, Base.RefValue{Any},Int,)) do b2, xs, ys, x + kill_error_edge(true, b2, xs, ys, x) + end + @test count(@nospecialize(x)->isa(x, Core.PhiNode), src.code) == 0 +end