Skip to content

Commit 2349f0a

Browse files
authored
optimizer: improve inlining algorithm robustness (JuliaLang#44445)
Explicitly check the conditions assumed by `ir_inline_item!`/`ir_inline_unionsplit!` within the call analysis phase. This commit also includes a small refactor to use same code for handling both concrete and abstract callsite, and it should slightly improve the handling of abstract, constant-prop'ed callsite.
1 parent c3d7edc commit 2349f0a

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

base/compiler/ssair/inlining.jl

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ struct InvokeCase
5757
end
5858

5959
struct InliningCase
60-
sig # ::Type
60+
sig # Type
6161
item # Union{InliningTodo, MethodInstance, ConstantCase}
6262
function InliningCase(@nospecialize(sig), @nospecialize(item))
6363
@assert isa(item, Union{InliningTodo, InvokeCase, ConstantCase}) "invalid inlining item"
@@ -67,10 +67,10 @@ end
6767

6868
struct UnionSplit
6969
fully_covered::Bool
70-
atype # ::Type
70+
atype::DataType
7171
cases::Vector{InliningCase}
7272
bbs::Vector{Int}
73-
UnionSplit(fully_covered::Bool, atype, cases::Vector{InliningCase}) =
73+
UnionSplit(fully_covered::Bool, atype::DataType, cases::Vector{InliningCase}) =
7474
new(fully_covered, atype, cases, Int[])
7575
end
7676

@@ -474,12 +474,11 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
474474
@assert length(bbs) >= length(cases)
475475
for i in 1:length(cases)
476476
ithcase = cases[i]
477-
metharg = ithcase.sig
477+
metharg = ithcase.sig::DataType # checked within `handle_cases!`
478478
case = ithcase.item
479479
next_cond_bb = bbs[i]
480-
@assert isa(metharg, DataType)
481480
cond = true
482-
aparams, mparams = atype.parameters::SimpleVector, metharg.parameters::SimpleVector
481+
aparams, mparams = atype.parameters, metharg.parameters
483482
@assert length(aparams) == length(mparams)
484483
if i != length(cases) || !fully_covered ||
485484
(!params.trust_inference && isdispatchtuple(cases[i].sig))
@@ -1222,7 +1221,6 @@ function analyze_single_call!(
12221221
end
12231222
end
12241223

1225-
12261224
atype = argtypes_to_type(argtypes)
12271225
if handled_all_cases && revisit_idx !== nothing
12281226
# If there's only one case that's not a dispatchtuple, we can
@@ -1234,7 +1232,7 @@ function analyze_single_call!(
12341232
# cases are split off from an ::Any typed fallback.
12351233
(i, j) = revisit_idx
12361234
match = infos[i].results[j]
1237-
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
1235+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases, true)
12381236
elseif length(cases) == 0 && only_method isa Method
12391237
# if the signature is fully covered and there is only one applicable method,
12401238
# we can try to inline it even if the signature is not a dispatch tuple.
@@ -1248,9 +1246,7 @@ function analyze_single_call!(
12481246
@assert length(meth) == 1
12491247
match = meth[1]
12501248
end
1251-
item = analyze_method!(match, argtypes, flag, state)
1252-
item === nothing && return nothing
1253-
push!(cases, InliningCase(match.spec_types, item))
1249+
handle_match!(match, argtypes, flag, state, cases, true) || return nothing
12541250
any_covers_full = handled_all_cases = match.fully_covers
12551251
end
12561252

@@ -1290,30 +1286,31 @@ function handle_const_call!(
12901286
handled_all_cases &= handle_inf_result!(result, argtypes, flag, state, cases)
12911287
else
12921288
@assert result === nothing
1293-
handled_all_cases &= isdispatchtuple(match.spec_types) && handle_match!(match, argtypes, flag, state, cases)
1289+
handled_all_cases &= handle_match!(match, argtypes, flag, state, cases)
12941290
end
12951291
end
12961292
end
12971293

12981294
# if the signature is fully covered and there is only one applicable method,
12991295
# we can try to inline it even if the signature is not a dispatch tuple
13001296
atype = argtypes_to_type(argtypes)
1301-
if length(cases) == 0 && length(results) == 1 && isa(results[1], InferenceResult)
1302-
(; mi) = item = InliningTodo(results[1]::InferenceResult, argtypes)
1303-
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
1304-
validate_sparams(mi.sparam_vals) || return nothing
1305-
item === nothing && return nothing
1306-
push!(cases, InliningCase(mi.specTypes, item))
1307-
any_covers_full = handled_all_cases = atype <: mi.specTypes
1297+
if length(cases) == 0
1298+
length(results) == 1 || return nothing
1299+
result = results[1]
1300+
isa(result, InferenceResult) || return nothing
1301+
handle_inf_result!(result, argtypes, flag, state, cases, true) || return nothing
1302+
spec_types = cases[1].sig
1303+
any_covers_full = handled_all_cases = atype <: spec_types
13081304
end
13091305

13101306
handle_cases!(ir, idx, stmt, atype, cases, any_covers_full && handled_all_cases, todo, state.params)
13111307
end
13121308

13131309
function handle_match!(
13141310
match::MethodMatch, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1315-
cases::Vector{InliningCase})
1311+
cases::Vector{InliningCase}, allow_abstract::Bool = false)
13161312
spec_types = match.spec_types
1313+
allow_abstract || isdispatchtuple(spec_types) || return false
13171314
item = analyze_method!(match, argtypes, flag, state)
13181315
item === nothing && return false
13191316
_any(case->case.sig === spec_types, cases) && return true
@@ -1323,10 +1320,10 @@ end
13231320

13241321
function handle_inf_result!(
13251322
result::InferenceResult, argtypes::Vector{Any}, flag::UInt8, state::InliningState,
1326-
cases::Vector{InliningCase})
1323+
cases::Vector{InliningCase}, allow_abstract::Bool = false)
13271324
(; mi) = item = InliningTodo(result, argtypes)
13281325
spec_types = mi.specTypes
1329-
isdispatchtuple(spec_types) || return false
1326+
allow_abstract || isdispatchtuple(spec_types) || return false
13301327
validate_sparams(mi.sparam_vals) || return false
13311328
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
13321329
item === nothing && return false
@@ -1351,6 +1348,8 @@ function handle_cases!(ir::IRCode, idx::Int, stmt::Expr, @nospecialize(atype),
13511348
if fully_covered && length(cases) == 1
13521349
handle_single_case!(ir, idx, stmt, cases[1].item, todo, params)
13531350
elseif length(cases) > 0
1351+
isa(atype, DataType) || return nothing
1352+
all(case::InliningCase->isa(case.sig, DataType), cases) || return nothing
13541353
push!(todo, idx=>UnionSplit(fully_covered, atype, cases))
13551354
end
13561355
return nothing

0 commit comments

Comments
 (0)