Skip to content

Commit e3d366f

Browse files
authored
generators: expose caller world to GeneratedFunctionStub (#48611)
Expose the demanded world to the GeneratedFunctionStub caller, for users such as Cassette. If this argument is used, the user must return a CodeInfo with the min/max world field set correctly. Make the internal representation a tiny bit more compact also, removing a little bit of unnecessary metadata. Remove support for returning `body isa CodeInfo` via this wrapper, since it is impossible to return a correct object via the GeneratedFunctionStub since it strips off the world argument, which is required for it to do so. This also removes support for not inferring these fully (expand_early=false). Also answer method lookup queries about the future correctly, by refusing to answer them. This helps keeps execution correct as methods get added to the system asynchronously. This reverts "fix #25678: return matters for generated functions (#40778)" (commit 92c84bf), since this is no longer sensible to return here anyways, so it is no longer permitted or supported by this macro. Fixes various issues where we failed to specify the correct world.
1 parent 6412a56 commit e3d366f

31 files changed

+203
-203
lines changed

base/Base.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ in_sysimage(pkgid::PkgId) = pkgid in _sysimage_modules
478478
for match = _methods(+, (Int, Int), -1, get_world_counter())
479479
m = match.method
480480
delete!(push!(Set{Method}(), m), m)
481-
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match)))
481+
copy(Core.Compiler.retrieve_code_info(Core.Compiler.specialize_method(match), typemax(UInt)))
482482

483483
empty!(Set())
484484
push!(push!(Set{Union{GlobalRef,Symbol}}(), :two), GlobalRef(Base, :two))

base/boot.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -590,28 +590,25 @@ println(@nospecialize a...) = println(stdout, a...)
590590

591591
struct GeneratedFunctionStub
592592
gen
593-
argnames::Array{Any,1}
594-
spnames::Union{Nothing, Array{Any,1}}
595-
line::Int
596-
file::Symbol
597-
expand_early::Bool
593+
argnames::SimpleVector
594+
spnames::SimpleVector
598595
end
599596

600-
# invoke and wrap the results of @generated
601-
function (g::GeneratedFunctionStub)(@nospecialize args...)
597+
# invoke and wrap the results of @generated expression
598+
function (g::GeneratedFunctionStub)(world::UInt, source::LineNumberNode, @nospecialize args...)
599+
# args is (spvals..., argtypes...)
602600
body = g.gen(args...)
603-
if body isa CodeInfo
604-
return body
605-
end
606-
lam = Expr(:lambda, g.argnames,
607-
Expr(Symbol("scope-block"),
601+
file = source.file
602+
file isa Symbol || (file = :none)
603+
lam = Expr(:lambda, Expr(:argnames, g.argnames...).args,
604+
Expr(:var"scope-block",
608605
Expr(:block,
609-
LineNumberNode(g.line, g.file),
610-
Expr(:meta, :push_loc, g.file, Symbol("@generated body")),
606+
source,
607+
Expr(:meta, :push_loc, file, :var"@generated body"),
611608
Expr(:return, body),
612609
Expr(:meta, :pop_loc))))
613610
spnames = g.spnames
614-
if spnames === nothing
611+
if spnames === svec()
615612
return lam
616613
else
617614
return Expr(Symbol("with-static-parameters"), lam, spnames...)

base/compiler/abstractinterpretation.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -569,7 +569,7 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
569569
break
570570
end
571571
topmost === nothing || continue
572-
if edge_matches_sv(infstate, method, sig, sparams, hardlimit, sv)
572+
if edge_matches_sv(interp, infstate, method, sig, sparams, hardlimit, sv)
573573
topmost = infstate
574574
edgecycle = true
575575
end
@@ -677,12 +677,13 @@ function abstract_call_method(interp::AbstractInterpreter, method::Method, @nosp
677677
return MethodCallResult(rt, edgecycle, edgelimited, edge, effects)
678678
end
679679

680-
function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
680+
function edge_matches_sv(interp::AbstractInterpreter, frame::InferenceState, method::Method, @nospecialize(sig), sparams::SimpleVector, hardlimit::Bool, sv::InferenceState)
681681
# The `method_for_inference_heuristics` will expand the given method's generator if
682682
# necessary in order to retrieve this field from the generated `CodeInfo`, if it exists.
683683
# The other `CodeInfo`s we inspect will already have this field inflated, so we just
684684
# access it directly instead (to avoid regeneration).
685-
callee_method2 = method_for_inference_heuristics(method, sig, sparams) # Union{Method, Nothing}
685+
world = get_world_counter(interp)
686+
callee_method2 = method_for_inference_heuristics(method, sig, sparams, world) # Union{Method, Nothing}
686687

687688
inf_method2 = frame.src.method_for_inference_limit_heuristics # limit only if user token match
688689
inf_method2 isa Method || (inf_method2 = nothing)
@@ -719,11 +720,11 @@ function edge_matches_sv(frame::InferenceState, method::Method, @nospecialize(si
719720
end
720721

721722
# This function is used for computing alternate limit heuristics
722-
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector)
723-
if isdefined(method, :generator) && method.generator.expand_early && may_invoke_generator(method, sig, sparams)
723+
function method_for_inference_heuristics(method::Method, @nospecialize(sig), sparams::SimpleVector, world::UInt)
724+
if isdefined(method, :generator) && !(method.generator isa Core.GeneratedFunctionStub) && may_invoke_generator(method, sig, sparams)
724725
method_instance = specialize_method(method, sig, sparams)
725726
if isa(method_instance, MethodInstance)
726-
cinfo = get_staged(method_instance)
727+
cinfo = get_staged(method_instance, world)
727728
if isa(cinfo, CodeInfo)
728729
method2 = cinfo.method_for_inference_limit_heuristics
729730
if method2 isa Method

base/compiler/bootstrap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ let interp = NativeInterpreter()
3636
else
3737
tt = Tuple{typeof(f), Vararg{Any}}
3838
end
39-
for m in _methods_by_ftype(tt, 10, typemax(UInt))::Vector
39+
for m in _methods_by_ftype(tt, 10, get_world_counter())::Vector
4040
# remove any TypeVars from the intersection
4141
m = m::MethodMatch
4242
typ = Any[m.spec_types.parameters...]

base/compiler/inferencestate.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -363,7 +363,8 @@ end
363363

364364
function InferenceState(result::InferenceResult, cache::Symbol, interp::AbstractInterpreter)
365365
# prepare an InferenceState object for inferring lambda
366-
src = retrieve_code_info(result.linfo)
366+
world = get_world_counter(interp)
367+
src = retrieve_code_info(result.linfo, world)
367368
src === nothing && return nothing
368369
validate_code_in_debug_mode(result.linfo, src, "lowered")
369370
return InferenceState(result, src, cache, interp)

base/compiler/optimize.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,8 @@ function OptimizationState(linfo::MethodInstance, src::CodeInfo, params::Optimiz
183183
return OptimizationState(linfo, src, nothing, stmt_info, mod, sptypes, slottypes, inlining, nothing, false)
184184
end
185185
function OptimizationState(linfo::MethodInstance, params::OptimizationParams, interp::AbstractInterpreter)
186-
src = retrieve_code_info(linfo)
186+
world = get_world_counter(interp)
187+
src = retrieve_code_info(linfo, world)
187188
src === nothing && return nothing
188189
return OptimizationState(linfo, src, params, interp)
189190
end

base/compiler/typeinfer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1035,7 +1035,7 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance)
10351035
end
10361036
end
10371037
if ccall(:jl_get_module_infer, Cint, (Any,), method.module) == 0 && !generating_sysimg()
1038-
return retrieve_code_info(mi)
1038+
return retrieve_code_info(mi, get_world_counter(interp))
10391039
end
10401040
lock_mi_inference(interp, mi)
10411041
result = InferenceResult(mi, typeinf_lattice(interp))

base/compiler/types.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ struct NativeInterpreter <: AbstractInterpreter
330330
cache = Vector{InferenceResult}() # Initially empty cache
331331

332332
# Sometimes the caller is lazy and passes typemax(UInt).
333-
# we cap it to the current world age
333+
# we cap it to the current world age for correctness
334334
if world == typemax(UInt)
335335
world = get_world_counter()
336336
end

base/compiler/utilities.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -114,23 +114,23 @@ end
114114
invoke_api(li::CodeInstance) = ccall(:jl_invoke_api, Cint, (Any,), li)
115115
use_const_api(li::CodeInstance) = invoke_api(li) == 2
116116

117-
function get_staged(mi::MethodInstance)
117+
function get_staged(mi::MethodInstance, world::UInt)
118118
may_invoke_generator(mi) || return nothing
119119
try
120120
# user code might throw errors – ignore them
121-
ci = ccall(:jl_code_for_staged, Any, (Any,), mi)::CodeInfo
121+
ci = ccall(:jl_code_for_staged, Any, (Any, UInt), mi, world)::CodeInfo
122122
return ci
123123
catch
124124
return nothing
125125
end
126126
end
127127

128-
function retrieve_code_info(linfo::MethodInstance)
128+
function retrieve_code_info(linfo::MethodInstance, world::UInt)
129129
m = linfo.def::Method
130130
c = nothing
131131
if isdefined(m, :generator)
132132
# user code might throw errors – ignore them
133-
c = get_staged(linfo)
133+
c = get_staged(linfo, world)
134134
end
135135
if c === nothing && isdefined(m, :source)
136136
src = m.source

base/compiler/validation.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -200,15 +200,14 @@ end
200200

201201
"""
202202
validate_code!(errors::Vector{InvalidCodeError}, mi::MethodInstance,
203-
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
203+
c::Union{Nothing,CodeInfo})
204204
205205
Validate `mi`, logging any violation by pushing an `InvalidCodeError` into `errors`.
206206
207207
If `isa(c, CodeInfo)`, also call `validate_code!(errors, c)`. It is assumed that `c` is
208-
the `CodeInfo` instance associated with `mi`.
208+
a `CodeInfo` instance associated with `mi`.
209209
"""
210-
function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstance,
211-
c::Union{Nothing,CodeInfo} = Core.Compiler.retrieve_code_info(mi))
210+
function validate_code!(errors::Vector{InvalidCodeError}, mi::Core.MethodInstance, c::Union{Nothing,CodeInfo})
212211
is_top_level = mi.def isa Module
213212
if is_top_level
214213
mnargs = 0

0 commit comments

Comments
 (0)