Skip to content

Commit bc6da93

Browse files
authored
inference: enable constant propagation for invoked calls, fixes #41024 (#41383)
* inference: enable constant propagation for `invoke`d calls, fixes #41024 Especially useful for defining mixins with typed interface fields, e.g. ```julia abstract type AbstractInterface end # mixin, which expects common field `x::Int` function Base.getproperty(x::AbstractInterface, sym::Symbol) if sym === :x return getfield(x, sym)::Int # inferred field else return getfield(x, sym) # fallback end end abstract type AbstractInterfaceExtended <: AbstractInterface end # extended mixin, which expects additional common field `y::Rational{Int}` function Base.getproperty(x::AbstractInterfaceExtended, sym::Symbol) if sym === :y return getfield(x, sym)::Rational{Int} end return Base.@invoke getproperty(x::AbstractInterface, sym::Symbol) end ``` As a bonus, inliner is able to use `InferenceResult` as a fast inlining pass for constant-prop'ed `invoke`s * improve compile-time latency * Update base/compiler/abstractinterpretation.jl * Update base/compiler/abstractinterpretation.jl
1 parent 7c566b1 commit bc6da93

File tree

4 files changed

+91
-17
lines changed

4 files changed

+91
-17
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 37 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1135,7 +1135,8 @@ function abstract_call_unionall(argtypes::Vector{Any})
11351135
end
11361136

11371137
function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
1138-
ft = widenconst(argtype_by_index(argtypes, 2))
1138+
ft′ = argtype_by_index(argtypes, 2)
1139+
ft = widenconst(ft′)
11391140
ft === Bottom && return CallMeta(Bottom, false)
11401141
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3))
11411142
types === Bottom && return CallMeta(Bottom, false)
@@ -1149,15 +1150,30 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
11491150
nargtype = Tuple{ft, nargtype.parameters...}
11501151
argtype = Tuple{ft, argtype.parameters...}
11511152
result = findsup(types, method_table(interp))
1152-
if result === nothing
1153-
return CallMeta(Any, false)
1154-
end
1153+
result === nothing && return CallMeta(Any, false)
11551154
method, valid_worlds = result
11561155
update_valid_age!(sv, valid_worlds)
11571156
(ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
1158-
rt, edge = typeinf_edge(interp, method, ti, env, sv)
1157+
(; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv)
11591158
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
1160-
return CallMeta(rt, InvokeCallInfo(MethodMatch(ti, env, method, argtype <: method.sig)))
1159+
match = MethodMatch(ti, env, method, argtype <: method.sig)
1160+
# try constant propagation with manual inlinings of some of the heuristics
1161+
# since some checks within `abstract_call_method_with_const_args` seem a bit costly
1162+
const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
1163+
argtypes′ = argtypes[4:end]
1164+
const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
1165+
pushfirst!(argtypes′, ft)
1166+
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
1167+
# for i in 1:length(argtypes′)
1168+
# t, a = ti.parameters[i], argtypes′[i]
1169+
# argtypes′[i] = t ⊑ a ? t : a
1170+
# end
1171+
const_rt, const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
1172+
if const_rt !== rt && const_rt rt
1173+
return CallMeta(const_rt, InvokeCallInfo(match, const_result))
1174+
else
1175+
return CallMeta(rt, InvokeCallInfo(match, nothing))
1176+
end
11611177
end
11621178

11631179
# call where the function is known exactly
@@ -1291,17 +1307,12 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
12911307
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
12921308
#print("call ", e.args[1], argtypes, "\n\n")
12931309
ft = argtypes[1]
1294-
if isa(ft, Const)
1295-
f = ft.val
1296-
elseif isconstType(ft)
1297-
f = ft.parameters[1]
1298-
elseif isa(ft, DataType) && isdefined(ft, :instance)
1299-
f = ft.instance
1300-
elseif isa(ft, PartialOpaque)
1310+
f = argtype_to_function(ft)
1311+
if isa(ft, PartialOpaque)
13011312
return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv)
13021313
elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
13031314
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
1304-
else
1315+
elseif f === nothing
13051316
# non-constant function, but the number of arguments is known
13061317
# and the ft is not a Builtin or IntrinsicFunction
13071318
if typeintersect(widenconst(ft), Union{Builtin, Core.OpaqueClosure}) != Union{}
@@ -1313,6 +1324,18 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
13131324
return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods)
13141325
end
13151326

1327+
function argtype_to_function(@nospecialize(ft))
1328+
if isa(ft, Const)
1329+
return ft.val
1330+
elseif isconstType(ft)
1331+
return ft.parameters[1]
1332+
elseif isa(ft, DataType) && isdefined(ft, :instance)
1333+
return ft.instance
1334+
else
1335+
return nothing
1336+
end
1337+
end
1338+
13161339
function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
13171340
isref = false
13181341
if T === Bottom

base/compiler/ssair/inlining.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1049,12 +1049,12 @@ is_builtin(s::Signature) =
10491049
isa(s.f, Builtin) ||
10501050
s.ft Builtin
10511051

1052-
function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallInfo,
1052+
function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo,
10531053
state::InliningState, todo::Vector{Pair{Int, Any}})
10541054
stmt = ir.stmts[idx][:inst]
10551055
calltype = ir.stmts[idx][:type]
10561056

1057-
if !info.match.fully_covers
1057+
if !match.fully_covers
10581058
# TODO: We could union split out the signature check and continue on
10591059
return nothing
10601060
end
@@ -1064,7 +1064,17 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallIn
10641064
atypes = atypes[4:end]
10651065
pushfirst!(atypes, atype0)
10661066

1067-
result = analyze_method!(info.match, atypes, state, calltype)
1067+
if isa(result, InferenceResult)
1068+
item = InliningTodo(result, atypes, calltype)
1069+
validate_sparams(item.mi.sparam_vals) || return nothing
1070+
if argtypes_to_type(atypes) <: item.mi.def.sig
1071+
state.mi_cache !== nothing && (item = resolve_todo(item, state))
1072+
handle_single_case!(ir, stmt, idx, item, true, todo)
1073+
return nothing
1074+
end
1075+
end
1076+
1077+
result = analyze_method!(match, atypes, state, calltype)
10681078
handle_single_case!(ir, stmt, idx, result, true, todo)
10691079
return nothing
10701080
end

base/compiler/stmtinfo.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ method being processed.
108108
"""
109109
struct InvokeCallInfo
110110
match::MethodMatch
111+
result::Union{Nothing,InferenceResult}
111112
end
112113

113114
struct OpaqueClosureCallInfo

test/compiler/inference.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3355,3 +3355,43 @@ let
33553355
Expr(:opaque_closure_method, nothing, 2, LineNumberNode(0, nothing), ci)))(true, 1.0)
33563356
@test Base.return_types(oc, Tuple{}) == Any[Float64]
33573357
end
3358+
3359+
@testset "constant prop' on `invoke` calls" begin
3360+
m = Module()
3361+
3362+
# simple cases
3363+
@eval m begin
3364+
f(a::Any, sym::Bool) = sym ? Any : :any
3365+
f(a::Number, sym::Bool) = sym ? Number : :number
3366+
end
3367+
@test (@eval m Base.return_types((Any,)) do a
3368+
Base.@invoke f(a::Any, true::Bool)
3369+
end) == Any[Type{Any}]
3370+
@test (@eval m Base.return_types((Any,)) do a
3371+
Base.@invoke f(a::Number, true::Bool)
3372+
end) == Any[Type{Number}]
3373+
@test (@eval m Base.return_types((Any,)) do a
3374+
Base.@invoke f(a::Any, false::Bool)
3375+
end) == Any[Symbol]
3376+
@test (@eval m Base.return_types((Any,)) do a
3377+
Base.@invoke f(a::Number, false::Bool)
3378+
end) == Any[Symbol]
3379+
3380+
# https://github.com/JuliaLang/julia/issues/41024
3381+
@eval m begin
3382+
# mixin, which expects common field `x::Int`
3383+
abstract type AbstractInterface end
3384+
Base.getproperty(x::AbstractInterface, sym::Symbol) =
3385+
sym === :x ? getfield(x, sym)::Int :
3386+
return getfield(x, sym) # fallback
3387+
3388+
# extended mixin, which expects additional field `y::Rational{Int}`
3389+
abstract type AbstractInterfaceExtended <: AbstractInterface end
3390+
Base.getproperty(x::AbstractInterfaceExtended, sym::Symbol) =
3391+
sym === :y ? getfield(x, sym)::Rational{Int} :
3392+
return Base.@invoke getproperty(x::AbstractInterface, sym::Symbol)
3393+
end
3394+
@test (@eval m Base.return_types((AbstractInterfaceExtended,)) do x
3395+
x.x
3396+
end) == Any[Int]
3397+
end

0 commit comments

Comments
 (0)