Skip to content

Commit 66d549b

Browse files
committed
optimizer: supports callsite annotations of inlining, fixes #18773
Enable `@inline`/`@noinline` annotations on function callsites. From #40754. Now `@inline` and `@noinline` can be applied to a code block and then the compiler will try to (not) inline calls within the block: ```julia @inline f(...) # The compiler will try to inline `f` @inline f(...) + g(...) # The compiler will try to inline `f`, `g` and `+` @inline f(args...) = ... # Of course annotations on a definition is still allowed ``` Here are couple of notes on how those callsite annotations will work: - callsite annotation always has the precedence over the annotation applied to the definition of the called function, whichever we use `@inline`/`@noinline`: ```julia @inline function explicit_inline(args...) # body end let @noinline explicit_inline(args...) # this call will not be inlined end ``` - when callsite annotations are nested, the innermost annotations has the precedence ```julia @noinline let a0, b0 = ... a = @inline f(a0) # the compiler will try to inline this call b = notinlined(b0) # the compiler will NOT try to inline this call return a, b end ``` They're both tested and included in documentations.
1 parent 6ea0b78 commit 66d549b

File tree

15 files changed

+312
-67
lines changed

15 files changed

+312
-67
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -590,7 +590,7 @@ function maybe_get_const_prop_profitable(interp::AbstractInterpreter, result::Me
590590
return nothing
591591
end
592592
mi = mi::MethodInstance
593-
if !force && !const_prop_methodinstance_heuristic(interp, method, mi)
593+
if !force && !const_prop_methodinstance_heuristic(interp, match, mi)
594594
add_remark!(interp, sv, "[constprop] Disabled by method instance heuristic")
595595
return nothing
596596
end
@@ -692,7 +692,8 @@ end
692692
# This is a heuristic to avoid trying to const prop through complicated functions
693693
# where we would spend a lot of time, but are probably unlikely to get an improved
694694
# result anyway.
695-
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method::Method, mi::MethodInstance)
695+
function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, match::MethodMatch, mi::MethodInstance)
696+
method = match.method
696697
if method.is_for_opaque_closure
697698
# Not inlining an opaque closure can be very expensive, so be generous
698699
# with the const-prop-ability. It is quite possible that we can't infer
@@ -710,7 +711,8 @@ function const_prop_methodinstance_heuristic(interp::AbstractInterpreter, method
710711
if isdefined(code, :inferred) && !cache_inlineable
711712
cache_inf = code.inferred
712713
if !(cache_inf === nothing)
713-
cache_inlineable = inlining_policy(interp)(cache_inf) !== nothing
714+
# TODO maybe we want to respect callsite `@inline`/`@noinline` annotations here ?
715+
cache_inlineable = inlining_policy(interp)(cache_inf, nothing, match) !== nothing
714716
end
715717
end
716718
if !cache_inlineable
@@ -1889,7 +1891,8 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
18891891
if isa(fname, SlotNumber)
18901892
changes = StateUpdate(fname, VarState(Any, false), changes, false)
18911893
end
1892-
elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect
1894+
elseif hd === :code_coverage_effect ||
1895+
(hd !== :boundscheck && hd !== nothing && is_meta_expr_head(hd)) # :boundscheck can be narrowed to Bool
18931896
# these do not generate code
18941897
else
18951898
t = abstract_eval_statement(interp, stmt, changes, frame)

base/compiler/optimize.jl

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,20 @@ struct InliningState{S <: Union{EdgeTracker, Nothing}, T, P}
2828
policy::P
2929
end
3030

31-
function default_inlining_policy(@nospecialize(src))
31+
function default_inlining_policy(@nospecialize(src), stmt_flag::Union{Nothing,UInt8}, match::Union{MethodMatch,InferenceResult})
3232
if isa(src, CodeInfo) || isa(src, Vector{UInt8})
3333
src_inferred = ccall(:jl_ir_flag_inferred, Bool, (Any,), src)
34-
src_inlineable = ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
34+
src_inlineable = is_stmt_inline(stmt_flag) || ccall(:jl_ir_flag_inlineable, Bool, (Any,), src)
3535
return src_inferred && src_inlineable ? src : nothing
36-
end
37-
if isa(src, OptimizationState) && isdefined(src, :ir)
38-
return src.src.inlineable ? src.ir : nothing
36+
elseif isa(src, OptimizationState) && isdefined(src, :ir)
37+
return (is_stmt_inline(stmt_flag) || src.src.inlineable) ? src.ir : nothing
38+
elseif src === nothing && is_stmt_inline(stmt_flag) && isa(match, MethodMatch)
39+
# when the source isn't available at this moment, try to re-infer and inline it
40+
# HACK in order to avoid cycles here, we disable inlining and makes sure the following inference never comes here
41+
# TODO sort out `AbstractInterpreter` interface to handle this well, and also inference should try to keep the source if the statement will be inlined
42+
interp = NativeInterpreter(; opt_params = OptimizationParams(; inlining = false))
43+
src, rt = typeinf_code(interp, match.method, match.spec_types, match.sparams, true)
44+
return src
3945
end
4046
return nothing
4147
end
@@ -129,6 +135,10 @@ const SLOT_USEDUNDEF = 32 # slot has uses that might raise UndefVarError
129135
# This statement was marked as @inbounds by the user. If replaced by inlining,
130136
# any contained boundschecks may be removed
131137
const IR_FLAG_INBOUNDS = 0x01
138+
# This statement was marked as @inline by the user
139+
const IR_FLAG_INLINE = 0x01 << 1
140+
# This statement was marked as @noinline by the user
141+
const IR_FLAG_NOINLINE = 0x01 << 2
132142
# This statement may be removed if its result is unused. In particular it must
133143
# thus be both pure and effect free.
134144
const IR_FLAG_EFFECT_FREE = 0x01 << 4
@@ -174,6 +184,11 @@ function isinlineable(m::Method, me::OptimizationState, params::OptimizationPara
174184
return inlineable
175185
end
176186

187+
is_stmt_inline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_INLINE != 0
188+
is_stmt_inline(::Nothing) = false
189+
is_stmt_noinline(stmt_flag::UInt8) = stmt_flag & IR_FLAG_NOINLINE != 0
190+
is_stmt_noinline(::Nothing) = false # not used for now
191+
177192
# These affect control flow within the function (so may not be removed
178193
# if there is no usage within the function), but don't affect the purity
179194
# of the function as a whole.
@@ -360,6 +375,7 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, sv::
360375
renumber_ir_elements!(code, changemap, labelmap)
361376

362377
inbounds_depth = 0 # Number of stacked inbounds
378+
inline_flags = BitVector()
363379
meta = Any[]
364380
flags = fill(0x00, length(code))
365381
for i = 1:length(code)
@@ -374,16 +390,38 @@ function convert_to_ircode(ci::CodeInfo, code::Vector{Any}, coverage::Bool, sv::
374390
inbounds_depth -= 1
375391
end
376392
stmt = nothing
393+
elseif isexpr(stmt, :inline)
394+
if stmt.args[1]::Bool
395+
push!(inline_flags, true)
396+
else
397+
pop!(inline_flags)
398+
end
399+
stmt = nothing
400+
elseif isexpr(stmt, :noinline)
401+
if stmt.args[1]::Bool
402+
push!(inline_flags, false)
403+
else
404+
pop!(inline_flags)
405+
end
406+
stmt = nothing
377407
else
378408
stmt = normalize(stmt, meta)
379409
end
380410
code[i] = stmt
381-
if !(stmt === nothing)
411+
if stmt !== nothing
382412
if inbounds_depth > 0
383413
flags[i] |= IR_FLAG_INBOUNDS
384414
end
415+
if !isempty(inline_flags)
416+
if last(inline_flags)
417+
flags[i] |= IR_FLAG_INLINE
418+
else
419+
flags[i] |= IR_FLAG_NOINLINE
420+
end
421+
end
385422
end
386423
end
424+
@assert isempty(inline_flags) "malformed meta flags"
387425
strip_trailing_junk!(ci, code, stmtinfo, flags)
388426
cfg = compute_basic_blocks(code)
389427
types = Any[]

base/compiler/ssair/inlining.jl

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -604,6 +604,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
604604
argexprs::Vector{Any}, atypes::Vector{Any}, arginfos::Vector{Any},
605605
arg_start::Int, istate::InliningState)
606606

607+
flag = ir.stmts[idx][:flag]
607608
new_argexprs = Any[argexprs[arg_start]]
608609
new_atypes = Any[atypes[arg_start]]
609610
# loop over original arguments and flatten any known iterators
@@ -659,8 +660,9 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
659660
info = call.info
660661
handled = false
661662
if isa(info, ConstCallInfo)
662-
if maybe_handle_const_call!(ir, state1.id, new_stmt, info, new_sig,
663-
call.rt, istate, false, todo)
663+
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
664+
ir, state1.id, new_stmt, info, new_sig,call.rt, istate, flag, false, todo)
665+
664666
handled = true
665667
else
666668
info = info.call
@@ -671,7 +673,7 @@ function rewrite_apply_exprargs!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::
671673
MethodMatchInfo[info] : info.matches
672674
# See if we can inline this call to `iterate`
673675
analyze_single_call!(ir, todo, state1.id, new_stmt,
674-
new_sig, call.rt, info, istate)
676+
new_sig, call.rt, info, istate, flag)
675677
end
676678
if i != length(thisarginfo.each)
677679
valT = getfield_tfunc(call.rt, Const(1))
@@ -719,16 +721,16 @@ function compileable_specialization(et::Union{EdgeTracker, Nothing}, (; linfo)::
719721
return mi
720722
end
721723

722-
function resolve_todo(todo::InliningTodo, state::InliningState)
723-
spec = todo.spec::DelayedInliningSpec
724+
function resolve_todo(todo::InliningTodo, state::InliningState, flag::UInt8)
725+
(; match) = todo.spec::DelayedInliningSpec
724726

725727
#XXX: update_valid_age!(min_valid[1], max_valid[1], sv)
726728
isconst, src = false, nothing
727-
if isa(spec.match, InferenceResult)
728-
let inferred_src = spec.match.src
729+
if isa(match, InferenceResult)
730+
let inferred_src = match.src
729731
if isa(inferred_src, Const)
730732
if !is_inlineable_constant(inferred_src.val)
731-
return compileable_specialization(state.et, spec.match)
733+
return compileable_specialization(state.et, match)
732734
end
733735
isconst, src = true, quoted(inferred_src.val)
734736
else
@@ -756,12 +758,10 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
756758
return ConstantCase(src)
757759
end
758760

759-
if src !== nothing
760-
src = state.policy(src)
761-
end
761+
src = state.policy(src, flag, match)
762762

763763
if src === nothing
764-
return compileable_specialization(et, spec.match)
764+
return compileable_specialization(et, match)
765765
end
766766

767767
if isa(src, IRCode)
@@ -772,9 +772,9 @@ function resolve_todo(todo::InliningTodo, state::InliningState)
772772
return InliningTodo(todo.mi, src)
773773
end
774774

775-
function resolve_todo(todo::UnionSplit, state::InliningState)
775+
function resolve_todo(todo::UnionSplit, state::InliningState, flag::UInt8)
776776
UnionSplit(todo.fully_covered, todo.atype,
777-
Pair{Any,Any}[sig=>resolve_todo(item, state) for (sig, item) in todo.cases])
777+
Pair{Any,Any}[sig=>resolve_todo(item, state, flag) for (sig, item) in todo.cases])
778778
end
779779

780780
function validate_sparams(sparams::SimpleVector)
@@ -785,7 +785,7 @@ function validate_sparams(sparams::SimpleVector)
785785
end
786786

787787
function analyze_method!(match::MethodMatch, atypes::Vector{Any},
788-
state::InliningState, @nospecialize(stmttyp))
788+
state::InliningState, @nospecialize(stmttyp), flag::UInt8)
789789
method = match.method
790790
methsig = method.sig
791791

@@ -805,7 +805,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
805805

806806
et = state.et
807807

808-
if !state.params.inlining
808+
if !state.params.inlining || is_stmt_noinline(flag)
809809
return compileable_specialization(et, match)
810810
end
811811

@@ -819,7 +819,7 @@ function analyze_method!(match::MethodMatch, atypes::Vector{Any},
819819
# If we don't have caches here, delay resolving this MethodInstance
820820
# until the batch inlining step (or an external post-processing pass)
821821
state.mi_cache === nothing && return todo
822-
return resolve_todo(todo, state)
822+
return resolve_todo(todo, state, flag)
823823
end
824824

825825
function InliningTodo(mi::MethodInstance, ir::IRCode)
@@ -1044,7 +1044,7 @@ is_builtin(s::Signature) =
10441044
s.ft Builtin
10451045

10461046
function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo,
1047-
state::InliningState, todo::Vector{Pair{Int, Any}})
1047+
state::InliningState, todo::Vector{Pair{Int, Any}}, flag::UInt8)
10481048
stmt = ir.stmts[idx][:inst]
10491049
calltype = ir.stmts[idx][:type]
10501050

@@ -1058,17 +1058,17 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result):
10581058
atypes = atypes[4:end]
10591059
pushfirst!(atypes, atype0)
10601060

1061-
if isa(result, InferenceResult)
1061+
if isa(result, InferenceResult) && !is_stmt_noinline(flag)
10621062
(; mi) = item = InliningTodo(result, atypes, calltype)
10631063
validate_sparams(mi.sparam_vals) || return nothing
10641064
if argtypes_to_type(atypes) <: mi.def.sig
1065-
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1065+
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
10661066
handle_single_case!(ir, stmt, idx, item, true, todo)
10671067
return nothing
10681068
end
10691069
end
10701070

1071-
result = analyze_method!(match, atypes, state, calltype)
1071+
result = analyze_method!(match, atypes, state, calltype, flag)
10721072
handle_single_case!(ir, stmt, idx, result, true, todo)
10731073
return nothing
10741074
end
@@ -1163,7 +1163,7 @@ end
11631163

11641164
function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int, @nospecialize(stmt),
11651165
sig::Signature, @nospecialize(calltype), infos::Vector{MethodMatchInfo},
1166-
state::InliningState)
1166+
state::InliningState, flag::UInt8)
11671167
cases = Pair{Any, Any}[]
11681168
signature_union = Union{}
11691169
only_method = nothing # keep track of whether there is one matching method
@@ -1197,7 +1197,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
11971197
fully_covered = false
11981198
continue
11991199
end
1200-
case = analyze_method!(match, sig.atypes, state, calltype)
1200+
case = analyze_method!(match, sig.atypes, state, calltype, flag)
12011201
if case === nothing
12021202
fully_covered = false
12031203
continue
@@ -1224,7 +1224,7 @@ function analyze_single_call!(ir::IRCode, todo::Vector{Pair{Int, Any}}, idx::Int
12241224
match = meth[1]
12251225
end
12261226
fully_covered = true
1227-
case = analyze_method!(match, sig.atypes, state, calltype)
1227+
case = analyze_method!(match, sig.atypes, state, calltype, flag)
12281228
case === nothing && return
12291229
push!(cases, Pair{Any,Any}(match.spec_types, case))
12301230
end
@@ -1246,7 +1246,7 @@ end
12461246

12471247
function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
12481248
info::ConstCallInfo, sig::Signature, @nospecialize(calltype),
1249-
state::InliningState,
1249+
state::InliningState, flag::UInt8,
12501250
isinvoke::Bool, todo::Vector{Pair{Int, Any}})
12511251
# when multiple matches are found, bail out and later inliner will union-split this signature
12521252
# TODO effectively use multiple constant analysis results here
@@ -1258,7 +1258,7 @@ function maybe_handle_const_call!(ir::IRCode, idx::Int, stmt::Expr,
12581258
validate_sparams(mi.sparam_vals) || return true
12591259
mthd_sig = mi.def.sig
12601260
mistypes = mi.specTypes
1261-
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1261+
state.mi_cache !== nothing && (item = resolve_todo(item, state, flag))
12621262
if sig.atype <: mthd_sig
12631263
handle_single_case!(ir, stmt, idx, item, isinvoke, todo)
12641264
return true
@@ -1296,6 +1296,8 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
12961296
info = info.info
12971297
end
12981298

1299+
flag = ir.stmts[idx][:flag]
1300+
12991301
# Inference determined this couldn't be analyzed. Don't question it.
13001302
if info === false
13011303
continue
@@ -1305,23 +1307,24 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13051307
# it'll have performed a specialized analysis for just this case. Use its
13061308
# result.
13071309
if isa(info, ConstCallInfo)
1308-
if maybe_handle_const_call!(ir, idx, stmt, info, sig, calltype, state, sig.f === Core.invoke, todo)
1310+
if !is_stmt_noinline(flag) && maybe_handle_const_call!(
1311+
ir, idx, stmt, info, sig, calltype, state, flag, sig.f === Core.invoke, todo)
13091312
continue
13101313
else
13111314
info = info.call
13121315
end
13131316
end
13141317

13151318
if isa(info, OpaqueClosureCallInfo)
1316-
result = analyze_method!(info.match, sig.atypes, state, calltype)
1319+
result = analyze_method!(info.match, sig.atypes, state, calltype, flag)
13171320
handle_single_case!(ir, stmt, idx, result, false, todo)
13181321
continue
13191322
end
13201323

13211324
# Handle invoke
13221325
if sig.f === Core.invoke
13231326
if isa(info, InvokeCallInfo)
1324-
inline_invoke!(ir, idx, sig, info, state, todo)
1327+
inline_invoke!(ir, idx, sig, info, state, todo, flag)
13251328
end
13261329
continue
13271330
end
@@ -1335,7 +1338,7 @@ function assemble_inline_todo!(ir::IRCode, state::InliningState)
13351338
continue
13361339
end
13371340

1338-
analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state)
1341+
analyze_single_call!(ir, todo, idx, stmt, sig, calltype, infos, state, flag)
13391342
end
13401343
todo
13411344
end

base/compiler/typeinfer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ function maybe_compress_codeinfo(interp::AbstractInterpreter, linfo::MethodInsta
343343
nslots = length(ci.slotflags)
344344
resize!(ci.slottypes::Vector{Any}, nslots)
345345
resize!(ci.slotnames, nslots)
346-
return ccall(:jl_compress_ir, Any, (Any, Any), def, ci)
346+
return ccall(:jl_compress_ir, Vector{UInt8}, (Any, Any), def, ci)
347347
else
348348
return ci
349349
end

base/compiler/utilities.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,8 @@ end
5959

6060
# Meta expression head, these generally can't be deleted even when they are
6161
# in a dead branch but can be ignored when analyzing uses/liveness.
62-
is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta || head === :loopinfo)
62+
is_meta_expr_head(head::Symbol) = (head === :inbounds || head === :boundscheck || head === :meta ||
63+
head === :loopinfo || head === :inline || head === :noinline)
6364

6465
sym_isless(a::Symbol, b::Symbol) = ccall(:strcmp, Int32, (Ptr{UInt8}, Ptr{UInt8}), a, b) < 0
6566

@@ -188,7 +189,7 @@ function specialize_method(method::Method, @nospecialize(atypes), sparams::Simpl
188189
if preexisting
189190
# check cached specializations
190191
# for an existing result stored there
191-
return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)
192+
return ccall(:jl_specializations_lookup, Any, (Any, Any), method, atypes)::Union{Nothing,MethodInstance}
192193
end
193194
return ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), method, atypes, sparams)
194195
end

0 commit comments

Comments
 (0)