Skip to content

Commit 9340251

Browse files
committed
aot: move jl_insert_backedges to Julia side
With #56447, the dependency between `jl_insert_backedges` and method insertion has been eliminated, allowing `jl_insert_backedges` to be performed after loading. As a result, it is now possible to move `jl_insert_backedges` to the Julia side. Currently this commit simply moves the implementation without adding any new features.
1 parent 5040e48 commit 9340251

File tree

7 files changed

+330
-379
lines changed

7 files changed

+330
-379
lines changed

base/Base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,7 @@ include("uuid.jl")
262262
include("pkgid.jl")
263263
include("toml_parser.jl")
264264
include("linking.jl")
265+
include("staticdata.jl")
265266
include("loading.jl")
266267

267268
# misc useful functions & macros

base/loading.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1280,10 +1280,12 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No
12801280
sv = try
12811281
if ocachepath !== nothing
12821282
@debug "Loading object cache file $ocachepath for $(repr("text/plain", pkg))"
1283-
ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint), ocachepath, depmods, false, pkg.name, ignore_native)
1283+
ccall(:jl_restore_package_image_from_file, Any, (Cstring, Any, Cint, Cstring, Cint),
1284+
ocachepath, depmods, #=completeinfo=#false, pkg.name, ignore_native)
12841285
else
12851286
@debug "Loading cache file $path for $(repr("text/plain", pkg))"
1286-
ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring), path, depmods, false, pkg.name)
1287+
ccall(:jl_restore_incremental, Any, (Cstring, Any, Cint, Cstring),
1288+
path, depmods, #=completeinfo=#false, pkg.name)
12871289
end
12881290
finally
12891291
lock(require_lock)
@@ -1292,6 +1294,10 @@ function _include_from_serialized(pkg::PkgId, path::String, ocachepath::Union{No
12921294
return sv
12931295
end
12941296

1297+
edges = sv[3]::Vector{Any}
1298+
ext_edges = sv[4]::Union{Nothing,Vector{Any}}
1299+
StaticData.insert_backedges(edges, ext_edges)
1300+
12951301
restored = register_restored_modules(sv, pkg, path)
12961302

12971303
for M in restored
@@ -4198,7 +4204,7 @@ function precompile(@nospecialize(argt::Type))
41984204
end
41994205

42004206
# Variants that work for `invoke`d calls for which the signature may not be sufficient
4201-
precompile(mi::Core.MethodInstance, world::UInt=get_world_counter()) =
4207+
precompile(mi::MethodInstance, world::UInt=get_world_counter()) =
42024208
(ccall(:jl_compile_method_instance, Cvoid, (Any, Ptr{Cvoid}, UInt), mi, C_NULL, world); return true)
42034209

42044210
"""
@@ -4214,7 +4220,7 @@ end
42144220

42154221
function precompile(@nospecialize(argt::Type), m::Method)
42164222
atype, sparams = ccall(:jl_type_intersection_with_env, Any, (Any, Any), argt, m.sig)::SimpleVector
4217-
mi = Core.Compiler.specialize_method(m, atype, sparams)
4223+
mi = Base.Compiler.specialize_method(m, atype, sparams)
42184224
return precompile(mi)
42194225
end
42204226

base/pcre.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ THREAD_MATCH_CONTEXTS::Vector{Ptr{Cvoid}} = [C_NULL]
2929
PCRE_COMPILE_LOCK = nothing
3030

3131
_tid() = Int(ccall(:jl_threadid, Int16, ())) + 1
32-
_mth() = Int(Core.Intrinsics.atomic_pointerref(cglobal(:jl_n_threads, Cint), :acquire))
32+
_mth() = Base.Threads.maxthreadid()
3333

3434
function get_local_match_context()
3535
tid = _tid()

base/staticdata.jl

Lines changed: 305 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,305 @@
1+
# This file is a part of Julia. License is MIT: https://julialang.org/license
2+
3+
module StaticData
4+
5+
using Core: CodeInstance, MethodInstance
6+
using Base: get_world_counter
7+
8+
const WORLD_AGE_REVALIDATION_SENTINEL::UInt = 1 # needs to sync with staticdata.c
9+
const _jl_debug_method_invalidation = Ref{Union{Nothing,Vector{Any}}}(nothing)
10+
debug_method_invalidation(onoff::Bool) =
11+
_jl_debug_method_invalidation[] = onoff ? Any[] : nothing
12+
13+
function get_ci_mi(codeinst::CodeInstance)
14+
def = codeinst.def
15+
if def isa Core.ABIOverride
16+
return def.def
17+
else
18+
return def::MethodInstance
19+
end
20+
end
21+
22+
# Restore backedges to external targets
23+
# `edges` = [caller1, ...], the list of worklist-owned code instances internally
24+
# `ext_ci_list` = [caller1, ...], the list of worklist-owned code instances externally
25+
function insert_backedges(edges::Vector{Any}, ext_ci_list::Union{Nothing,Vector{Any}})
26+
# determine which CodeInstance objects are still valid in our image
27+
# to enable any applicable new codes
28+
stack = CodeInstance[]
29+
visiting = IdDict{CodeInstance,Int}()
30+
_insert_backedges(edges, stack, visiting)
31+
if ext_ci_list !== nothing
32+
_insert_backedges(ext_ci_list, stack, visiting, #=external=#true)
33+
end
34+
end
35+
36+
function _insert_backedges(edges::Vector{Any}, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int}, external::Bool=false)
37+
for i = 1:length(edges)
38+
codeinst = edges[i]::CodeInstance
39+
verify_method_graph(codeinst, stack, visiting)
40+
minvalid = codeinst.min_world
41+
maxvalid = codeinst.max_world
42+
if maxvalid minvalid
43+
if get_world_counter() == maxvalid
44+
# if this callee is still valid, add all the backedges
45+
Base.Compiler.store_backedges(codeinst, codeinst.edges)
46+
end
47+
if get_world_counter() == maxvalid
48+
maxvalid = typemax(UInt)
49+
@atomic codeinst.max_world = maxvalid
50+
end
51+
if external
52+
caller = get_ci_mi(codeinst)
53+
@assert isdefined(codeinst, :inferred) # See #53586, #53109
54+
inferred = @ccall jl_rettype_inferred(
55+
codeinst.owner::Any, caller::Any, minvalid::UInt, maxvalid::UInt)::Any
56+
if inferred !== nothing
57+
# We already got a code instance for this world age range from
58+
# somewhere else - we don't need this one.
59+
else
60+
@ccall jl_mi_cache_insert(caller::Any, codeinst::Any)::Cvoid
61+
end
62+
end
63+
end
64+
end
65+
end
66+
67+
function verify_method_graph(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
68+
@assert isempty(stack)
69+
@assert isempty(visiting)
70+
child_cycle, minworld, maxworld = verify_method(codeinst, stack, visiting)
71+
@assert child_cycle == 0
72+
@assert isempty(stack)
73+
empty!(visiting)
74+
if Threads.maxthreadid() == 1 # a different thread might simultaneously come to a different, but equally valid, alternative result
75+
@assert maxworld == 0 || codeinst.min_world == minworld
76+
@assert codeinst.max_world == maxworld
77+
end
78+
end
79+
80+
# Test all edges relevant to a method:
81+
# - Visit the entire call graph, starting from edges[idx] to determine if that method is valid
82+
# - Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
83+
# and slightly modified with an early termination option once the computation reaches its minimum
84+
function verify_method(codeinst::CodeInstance, stack::Vector{CodeInstance}, visiting::IdDict{CodeInstance,Int})
85+
world = codeinst.min_world
86+
let max_valid2 = codeinst.max_world
87+
if max_valid2 WORLD_AGE_REVALIDATION_SENTINEL
88+
return 0, world, max_valid2
89+
end
90+
end
91+
current_world = get_world_counter()
92+
local minworld::UInt, maxworld::UInt = 1, current_world
93+
@assert get_ci_mi(codeinst).def isa Method
94+
if haskey(visiting, codeinst)
95+
return visiting[codeinst], minworld, maxworld
96+
end
97+
push!(stack, codeinst)
98+
depth = length(stack)
99+
visiting[codeinst] = depth
100+
# TODO JL_TIMING(VERIFY_IMAGE, VERIFY_Methods)
101+
callees = codeinst.edges
102+
# verify current edges
103+
if isempty(callees)
104+
# quick return: no edges to verify (though we probably shouldn't have gotten here from WORLD_AGE_REVALIDATION_SENTINEL)
105+
elseif maxworld == unsafe_load(cglobal(:jl_require_world, UInt))
106+
# if no new worlds were allocated since serializing the base module, then no new validation is worth doing right now either
107+
minworld = maxworld
108+
else
109+
j = 1
110+
while j length(callees)
111+
local min_valid2::UInt, max_valid2::UInt
112+
edge = callees[j]
113+
@assert !(edge isa Method) # `Method`-edge isn't allowed for the optimized one-edge format
114+
if edge isa Core.BindingPartition
115+
j += 1
116+
continue
117+
end
118+
if edge isa CodeInstance
119+
edge = get_ci_mi(edge)
120+
end
121+
if edge isa MethodInstance
122+
sig = typeintersect((edge.def::Method).sig, edge.specTypes) # TODO??
123+
min_valid2, max_valid2, matches = verify_call(sig, callees, j, 1, world)
124+
j += 1
125+
elseif edge isa Int
126+
sig = callees[j+1]
127+
min_valid2, max_valid2, matches = verify_call(sig, callees, j+2, edge, world)
128+
j += 2 + edge
129+
edge = sig
130+
else
131+
callee = callees[j+1]
132+
if callee isa Core.MethodTable # skip the legacy edge (missing backedge)
133+
j += 2
134+
continue
135+
end
136+
if callee isa CodeInstance
137+
callee = get_ci_mi(callee)
138+
end
139+
if callee isa MethodInstance
140+
meth = callee.def::Method
141+
else
142+
meth = callee::Method
143+
end
144+
min_valid2, max_valid2 = verify_invokesig(edge, meth, world)
145+
matches = nothing
146+
j += 2
147+
end
148+
if minworld < min_valid2
149+
minworld = min_valid2
150+
end
151+
if maxworld > max_valid2
152+
maxworld = max_valid2
153+
end
154+
invalidations = _jl_debug_method_invalidation[]
155+
if max_valid2 typemax(UInt) && invalidations !== nothing
156+
push!(invalidations, edge, "insert_backedges_callee", codeinst, matches)
157+
end
158+
if max_valid2 == 0 && invalidations === nothing
159+
break
160+
end
161+
end
162+
end
163+
# verify recursive edges (if valid, or debugging)
164+
cycle = depth
165+
cause = codeinst
166+
if maxworld 0 || _jl_debug_method_invalidation[] !== nothing
167+
for j = 1:length(callees)
168+
edge = callees[j]
169+
if !(edge isa CodeInstance)
170+
continue
171+
end
172+
callee = edge
173+
local min_valid2::UInt, max_valid2::UInt
174+
child_cycle, min_valid2, max_valid2 = verify_method(callee, stack, visiting)
175+
if minworld < min_valid2
176+
minworld = min_valid2
177+
end
178+
if minworld > max_valid2
179+
max_valid2 = 0
180+
end
181+
if maxworld > max_valid2
182+
cause = callee
183+
maxworld = max_valid2
184+
end
185+
if max_valid2 == 0
186+
# found what we were looking for, so terminate early
187+
break
188+
elseif child_cycle 0 && child_cycle < cycle
189+
# record the cycle will resolve at depth "cycle"
190+
cycle = child_cycle
191+
end
192+
end
193+
end
194+
if maxworld 0 && cycle depth
195+
return cycle, minworld, maxworld
196+
end
197+
# If we are the top of the current cycle, now mark all other parts of
198+
# our cycle with what we found.
199+
# Or if we found a failed edge, also mark all of the other parts of the
200+
# cycle as also having a failed edge.
201+
while length(stack) depth
202+
child = pop!(stack)
203+
if Threads.maxthreadid() == 1 # a different thread might simultaneously come to a different, but equally valid, alternative result
204+
@assert child.max_world == WORLD_AGE_REVALIDATION_SENTINEL
205+
@assert minworld child.min_world
206+
end
207+
if maxworld 0
208+
@atomic child.min_world = minworld
209+
end
210+
@atomic child.max_world = maxworld
211+
@assert visiting[child] == length(stack) + 1
212+
delete!(visiting, child)
213+
invalidations = _jl_debug_method_invalidation[]
214+
if invalidations !== nothing && maxworld < current_world
215+
push!(invalidations, child, "verify_methods", cause)
216+
end
217+
end
218+
return 0, minworld, maxworld
219+
end
220+
221+
function verify_call(@nospecialize(sig), expecteds::Core.SimpleVector, i::Int, n::Int, world::UInt)
222+
# verify that these edges intersect with the same methods as before
223+
lim = _jl_debug_method_invalidation[] !== nothing ? Int(typemax(Int32)) : n
224+
minworld = Ref{UInt}(1)
225+
maxworld = Ref{UInt}(typemax(UInt))
226+
has_ambig = Ref{Int32}(0)
227+
result = Base._methods_by_ftype(sig, nothing, lim, world, #=ambig=#false, minworld, maxworld, has_ambig)
228+
if result === nothing
229+
maxworld[] = 0
230+
else
231+
# setdiff!(result, expected)
232+
if length(result) n
233+
maxworld[] = 0
234+
end
235+
ins = 0
236+
for k = 1:length(result)
237+
match = result[k]::Core.MethodMatch
238+
local found = false
239+
for j = 1:n
240+
t = expecteds[i+j-1]
241+
if t isa Method
242+
meth = t
243+
else
244+
if t isa CodeInstance
245+
t = get_ci_mi(t)
246+
else
247+
t = t::MethodInstance
248+
end
249+
meth = t.def::Method
250+
end
251+
if match.method == meth
252+
found = true
253+
break
254+
end
255+
end
256+
if !found
257+
# intersection has a new method or a method was
258+
# deleted--this is now probably no good, just invalidate
259+
# everything about it now
260+
maxworld[] = 0
261+
if _jl_debug_method_invalidation[] === nothing
262+
break
263+
end
264+
ins += 1
265+
result[ins] = match.method
266+
end
267+
end
268+
if maxworld[] typemax(UInt) && _jl_debug_method_invalidation[] !== nothing
269+
resize!(result, ins)
270+
end
271+
end
272+
return minworld[], maxworld[], result
273+
end
274+
275+
function verify_invokesig(@nospecialize(invokesig), expected::Method, world::UInt)
276+
@assert invokesig isa Type
277+
local minworld::UInt, maxworld::UInt
278+
if invokesig === expected.sig
279+
# the invoke match is `expected` for `expected->sig`, unless `expected` is invalid
280+
minworld = expected.primary_world
281+
maxworld = expected.deleted_world
282+
@assert minworld world
283+
if maxworld < world
284+
maxworld = 0
285+
end
286+
else
287+
minworld = 1
288+
maxworld = typemax(UInt)
289+
mt = Base.get_methodtable(expected)
290+
if mt === nothing
291+
maxworld = 0
292+
else
293+
matched, valid_worlds = Base.Compiler._findsup(invokesig, mt, world)
294+
minworld, maxworld = valid_worlds.min_world, valid_worlds.max_world
295+
if matched === nothing
296+
maxworld = 0
297+
elseif matched.method != expected
298+
maxworld = 0
299+
end
300+
end
301+
end
302+
return minworld, maxworld
303+
end
304+
305+
end # module StaticData

0 commit comments

Comments
 (0)