diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index fc3c3a60115e6..2efeb66293f0b 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -241,7 +241,7 @@ function cfg_inline_unionsplit!(ir::IRCode, idx::Int, push!(from_bbs, length(state.new_cfg_blocks)) # TODO: Right now we unconditionally generate a fallback block # in case of subtyping errors - This is probably unnecessary. - if i != length(cases) || (!fully_covered || !params.trust_inference) + if i != length(cases) || (!fully_covered || (!params.trust_inference && isdispatchtuple(cases[i].sig))) # This block will have the next condition or the final else case push!(state.new_cfg_blocks, BasicBlock(StmtRange(idx, idx))) push!(state.new_cfg_blocks[cond_bb].succs, length(state.new_cfg_blocks)) @@ -481,7 +481,8 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, cond = true aparams, mparams = atype.parameters::SimpleVector, metharg.parameters::SimpleVector @assert length(aparams) == length(mparams) - if i != length(cases) || !fully_covered || !params.trust_inference + if i != length(cases) || !fully_covered || + (!params.trust_inference && isdispatchtuple(cases[i].sig)) for i in 1:length(aparams) a, m = aparams[i], mparams[i] # If this is always true, we don't need to check for it @@ -538,7 +539,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, bb += 1 # We're now in the fall through block, decide what to do if fully_covered - if !params.trust_inference + if !params.trust_inference && isdispatchtuple(cases[end].sig) e = Expr(:call, GlobalRef(Core, :throw), FATAL_TYPE_BOUND_ERROR) insert_node_here!(compact, NewInstruction(e, Union{}, line)) insert_node_here!(compact, NewInstruction(ReturnNode(), Union{}, line)) @@ -1170,7 +1171,10 @@ function analyze_single_call!( cases = InliningCase[] local only_method = nothing # keep track of whether there is one matching method local meth::MethodLookupResult - local fully_covered = true + local handled_all_cases = true + local any_covers_full = false + local revisit_idx = nothing + for i in 1:length(infos) meth = infos[i].results if meth.ambig @@ -1179,7 +1183,7 @@ function analyze_single_call!( return nothing elseif length(meth) == 0 # No applicable methods; try next union split - fully_covered = false + handled_all_cases = false continue else if length(meth) == 1 && only_method !== false @@ -1192,16 +1196,43 @@ function analyze_single_call!( only_method = false end end - for match in meth - fully_covered &= handle_match!(match, argtypes, flag, state, cases) - fully_covered &= match.fully_covers + for (j, match) in enumerate(meth) + any_covers_full |= match.fully_covers + if !isdispatchtuple(match.spec_types) + if !match.fully_covers + handled_all_cases = false + continue + end + if revisit_idx === nothing + revisit_idx = (i, j) + else + handled_all_cases = false + revisit_idx = nothing + end + else + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) + end end end - # if the signature is fully covered and there is only one applicable method, - # we can try to inline it even if the signature is not a dispatch tuple + atype = argtypes_to_type(argtypes) - if length(cases) == 0 && only_method isa Method + if handled_all_cases && revisit_idx !== nothing + # If there's only one case that's not a dispatchtuple, we can + # still unionsplit by visiting all the other cases first. + # This is useful for code like: + # foo(x::Int) = 1 + # foo(@nospecialize(x::Any)) = 2 + # where we where only a small number of specific dispatchable + # cases are split off from an ::Any typed fallback. + (i, j) = revisit_idx + match = infos[i].results[j] + handled_all_cases &= handle_match!(match, argtypes, flag, state, cases) + elseif length(cases) == 0 && only_method isa Method + # if the signature is fully covered and there is only one applicable method, + # we can try to inline it even if the signature is not a dispatch tuple. + # -- But don't try it if we already tried to handle the match in the revisit_idx + # case, because that'll (necessarily) be the same method. if length(infos) > 1 (metharg, methsp) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), atype, only_method.sig)::SimpleVector @@ -1213,10 +1244,10 @@ function analyze_single_call!( item = analyze_method!(match, argtypes, flag, state) item === nothing && return nothing push!(cases, InliningCase(match.spec_types, item)) - fully_covered = match.fully_covers + any_covers_full = handled_all_cases = match.fully_covers end - handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo, state.params) + handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) end # similar to `analyze_single_call!`, but with constant results @@ -1227,7 +1258,8 @@ function handle_const_call!( (; call, results) = cinfo infos = isa(call, MethodMatchInfo) ? MethodMatchInfo[call] : call.matches cases = InliningCase[] - local fully_covered = true + local handled_all_cases = true + local any_covers_full = false local j = 0 for i in 1:length(infos) meth = infos[i].results @@ -1237,22 +1269,22 @@ function handle_const_call!( return nothing elseif length(meth) == 0 # No applicable methods; try next union split - fully_covered = false + handled_all_cases = false continue end for match in meth j += 1 result = results[j] + any_covers_full |= match.fully_covers if isa(result, ConstResult) case = const_result_item(result, state) push!(cases, InliningCase(result.mi.specTypes, case)) elseif isa(result, InferenceResult) - fully_covered &= handle_inf_result!(result, argtypes, flag, state, cases) + handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases) else @assert result === nothing - fully_covered &= handle_match!(match, argtypes, flag, state, cases) + handled_all_cases &= isdispatchtuple(match.spec_types) && handle_match!(match, argtypes, flag, state, cases) end - fully_covered &= match.fully_covers end end @@ -1265,17 +1297,16 @@ function handle_const_call!( validate_sparams(mi.sparam_vals) || return nothing item === nothing && return nothing push!(cases, InliningCase(mi.specTypes, item)) - fully_covered = atype <: mi.specTypes + any_covers_full = handled_all_cases = atype <: mi.specTypes end - handle_cases!(ir, idx, stmt, atype, cases, fully_covered, todo, state.params) + handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params) end function handle_match!( match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState, cases::Vector{InliningCase}) spec_types = match.spec_types - isdispatchtuple(spec_types) || return false item = analyze_method!(match, argtypes, flag, state) item === nothing && return false _any(case->case.sig === spec_types, cases) && return true diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 1f55ceb94a062..bba9f41bf64d3 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -314,15 +314,17 @@ end @inline tchanged(@nospecialize(n), @nospecialize(o)) = o === NOT_FOUND || (n !== NOT_FOUND && !(n ⊑ o)) @inline schanged(@nospecialize(n), @nospecialize(o)) = (n !== o) && (o === NOT_FOUND || (n !== NOT_FOUND && !issubstate(n::VarState, o::VarState))) -widenconditional(@nospecialize typ) = typ -function widenconditional(typ::AnyConditional) - if typ.vtype === Union{} - return Const(false) - elseif typ.elsetype === Union{} - return Const(true) - else - return Bool +function widenconditional(@nospecialize typ) + if isa(typ, AnyConditional) + if typ.vtype === Union{} + return Const(false) + elseif typ.elsetype === Union{} + return Const(true) + else + return Bool + end end + return typ end widenconditional(t::LimitedAccuracy) = error("unhandled LimitedAccuracy") diff --git a/test/compiler/inline.jl b/test/compiler/inline.jl index 3259c752d9aa0..7619d4e8a0308 100644 --- a/test/compiler/inline.jl +++ b/test/compiler/inline.jl @@ -1099,3 +1099,12 @@ end let src = code_typed1(f44200) @test count(x -> isa(x, Core.PiNode), src.code) == 0 end + +# Test that peeling off one case from (::Any) doesn't introduce +# a dynamic dispatch. +@noinline f_peel(x::Int) = Base.inferencebarrier(1) +@noinline f_peel(@nospecialize(x::Any)) = Base.inferencebarrier(2) +g_call_peel(x) = f_peel(x) +let src = code_typed1(g_call_peel, Tuple{Any}) + @test count(isinvoke(:f_peel), src.code) == 2 +end diff --git a/test/worlds.jl b/test/worlds.jl index a6cbed9560a8d..015ff470a56dd 100644 --- a/test/worlds.jl +++ b/test/worlds.jl @@ -191,7 +191,7 @@ f_gen265(x::Type{Int}) = 3 # intermediate worlds by later additions to the method table that # would have capped those specializations if they were still valid f26506(@nospecialize(x)) = 1 -g26506(x) = f26506(x[1]) +g26506(x) = Base.inferencebarrier(f26506)(x[1]) z = Any["ABC"] f26506(x::Int) = 2 g26506(z) # Places an entry for f26506(::String) in mt.name.cache