Skip to content

Commit a12e1cc

Browse files
committed
Align module base between invalidation and edge tracking
Our implicit edge tracking for bindings does not explicitly store any edges for bindings in the *current* module. The idea behind this is that it is a good time-space tradeoff for invalidation, because substantially all binding references made by code in a module will be to bindings in that same module, while the total number of methods within a module is limited and substantially smaller than the total number of methods in the entire system. However, we have an issue where the code that stores these edges and the invalidation code disagree on which module is the *current* one. The edge-storing code was using the module in which the method was defined, while the invalidation code was using the one in which the MethodTable is defined. With these being misaligned, we can miss necessary invalidations. Both options are possible in principle, but I think the former is better, because the module in which the method is defined is also the module that we are likely to have a lot of references to (since its bindings get referenced implicitly just by writing symbols in the code). However, this presents a problem: we don't actually have a way to iterate over all the methods defined in a particular module without doing the brute-force thing of scanning all methods and filtering. To address this, build on the deferred scanning code added in #57615 to also add any scanned methods to an explicit list in `Module`. This costs some space, but only proportional to the number of defined methods (and thus proportional to the written source code). Note that we don't actually observe any issues in the test suite on master due to this bug. However, this is only because we are grossly over-invalidating (#57617), which hides the invalidations that this issue causes us to miss.
1 parent 44975e1 commit a12e1cc

File tree

7 files changed

+52
-13
lines changed

7 files changed

+52
-13
lines changed

base/invalidation.jl

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,10 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core
136136

137137
if need_to_invalidate_code
138138
if (b.flags & BINDING_FLAG_ANY_IMPLICIT_EDGES) != 0
139-
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
140-
for method in MethodList(mt)
141-
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
142-
end
143-
return true
139+
nmethods = ccall(:jl_module_scanned_methods_length, Csize_t, (Any,), gr.mod)
140+
for i = 1:nmethods
141+
method = ccall(:jl_module_scanned_methods_getindex, Any, (Any, Csize_t), gr.mod, i)::Method
142+
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
144143
end
145144
end
146145
if isdefined(b, :backedges)
@@ -166,7 +165,7 @@ function invalidate_code_for_globalref!(b::Core.Binding, invalidated_bpart::Core
166165
# have a binding that is affected by this change.
167166
usings_backedges = ccall(:jl_get_module_usings_backedges, Any, (Any,), gr.mod)
168167
if usings_backedges !== nothing
169-
for user in usings_backedges::Vector{Any}
168+
for user::Module in usings_backedges::Vector{Any}
170169
user_binding = ccall(:jl_get_module_binding_or_nothing, Any, (Any, Any), user, gr.name)
171170
user_binding === nothing && continue
172171
isdefined(user_binding, :partitions) || continue

src/gc-stock.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2147,6 +2147,9 @@ STATIC_INLINE void gc_mark_module_binding(jl_ptls_t ptls, jl_module_t *mb_parent
21472147
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->usings_backedges);
21482148
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->usings_backedges, &nptr);
21492149
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->usings_backedges);
2150+
gc_assert_parent_validity((jl_value_t *)mb_parent, (jl_value_t *)mb_parent->scanned_methods);
2151+
gc_try_claim_and_push(mq, (jl_value_t *)mb_parent->scanned_methods, &nptr);
2152+
gc_heap_snapshot_record_binding_partition_edge((jl_value_t*)mb_parent, mb_parent->scanned_methods);
21502153
size_t nusings = module_usings_length(mb_parent);
21512154
if (nusings > 0) {
21522155
// this is only necessary because bindings for "using" modules

src/julia.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,7 @@ typedef struct _jl_module_t {
782782
jl_sym_t *file;
783783
int32_t line;
784784
jl_value_t *usings_backedges;
785+
jl_value_t *scanned_methods;
785786
// hidden fields:
786787
arraylist_t usings; /* arraylist of struct jl_module_using */ // modules with all bindings potentially imported
787788
jl_uuid_t build_id;
@@ -2059,6 +2060,7 @@ JL_DLLEXPORT int jl_get_module_infer(jl_module_t *m);
20592060
JL_DLLEXPORT void jl_set_module_max_methods(jl_module_t *self, int value);
20602061
JL_DLLEXPORT int jl_get_module_max_methods(jl_module_t *m);
20612062
JL_DLLEXPORT jl_value_t *jl_get_module_usings_backedges(jl_module_t *m);
2063+
JL_DLLEXPORT jl_value_t *jl_get_module_scanned_methods(jl_module_t *m);
20622064
JL_DLLEXPORT jl_value_t *jl_get_module_binding_or_nothing(jl_module_t *m, jl_sym_t *s);
20632065

20642066
// get binding for reading

src/julia_internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ jl_code_info_t *jl_new_code_info_from_ir(jl_expr_t *ast);
722722
JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void);
723723
JL_DLLEXPORT void jl_resolve_definition_effects_in_ir(jl_array_t *stmts, jl_module_t *m, jl_svec_t *sparam_vals, jl_value_t *binding_edge,
724724
int binding_effects);
725-
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
725+
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge);
726726
JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge);
727727

728728
int get_next_edge(jl_array_t *list, int i, jl_value_t** invokesig, jl_code_instance_t **caller) JL_NOTSAFEPOINT;

src/method.c

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,20 @@ static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
3939
}
4040
}
4141

42+
// Record `meth` in module `m`'s list of scanned methods.
//
// The list is created lazily: `m->scanned_methods` holds `jl_nothing` until
// the first method is recorded, at which point an empty `Vector{Any}` is
// installed. Both the lazy creation and the push happen under `m->lock`.
static void jl_add_scanned_method(jl_module_t *m, jl_method_t *meth)
{
    JL_LOCK(&m->lock);
    jl_value_t *list = m->scanned_methods;
    if (list == jl_nothing) {
        // First method recorded for this module: install the backing array
        // and notify the GC that `m` now references a new object.
        list = (jl_value_t*)jl_alloc_vec_any(0);
        m->scanned_methods = list;
        jl_gc_wb(m, list);
    }
    jl_array_ptr_1d_push((jl_array_t*)list, (jl_value_t*)meth);
    JL_UNLOCK(&m->lock);
}
52+
4253
JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
4354
{
44-
if (!jl_atomic_load_relaxed(&m->did_scan_source)) {
55+
if (!jl_atomic_fetch_or(&m->did_scan_source, 1)) {
4556
jl_code_info_t *code = NULL;
4657
JL_GC_PUSH1(&code);
4758
if (!jl_is_code_info(src))
@@ -50,13 +61,15 @@ JL_DLLEXPORT void jl_scan_method_source_now(jl_method_t *m, jl_value_t *src)
5061
code = (jl_code_info_t*)src;
5162
jl_array_t *stmts = code->code;
5263
size_t i, l = jl_array_nrows(stmts);
64+
int any_implicit = 0;
5365
for (i = 0; i < l; i++) {
5466
jl_value_t *stmt = jl_array_ptr_ref(stmts, i);
5567
if (jl_is_globalref(stmt)) {
56-
jl_maybe_add_binding_backedge((jl_globalref_t*)stmt, m->module, (jl_value_t*)m);
68+
any_implicit |= jl_maybe_add_binding_backedge((jl_globalref_t*)stmt, m->module, (jl_value_t*)m);
5769
}
5870
}
59-
jl_atomic_store_relaxed(&m->did_scan_source, 1);
71+
if (any_implicit)
72+
jl_add_scanned_method(m->module, m);
6073
JL_GC_POP();
6174
}
6275
}

src/module.c

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,7 @@ JL_DLLEXPORT jl_module_t *jl_new_module__(jl_sym_t *name, jl_module_t *parent)
319319
m->build_id.hi = ~(uint64_t)0;
320320
jl_atomic_store_relaxed(&m->counter, 1);
321321
m->usings_backedges = jl_nothing;
322+
m->scanned_methods = jl_nothing;
322323
m->nospecialize = 0;
323324
m->optlevel = -1;
324325
m->compile = -1;
@@ -1163,6 +1164,22 @@ JL_DLLEXPORT jl_value_t *jl_get_module_usings_backedges(jl_module_t *m)
11631164
return m->usings_backedges;
11641165
}
11651166

1167+
// Return the number of methods recorded in module `m`'s scanned-methods list.
//
// `m->scanned_methods` starts out as `jl_nothing` and is only replaced by an
// array once the first method is recorded, so a module with no recorded
// methods must be reported as having length 0 rather than being read as an
// array. The field is read under `m->lock`, the same lock under which the
// list is created and appended to.
JL_DLLEXPORT size_t jl_module_scanned_methods_length(jl_module_t *m)
{
    JL_LOCK(&m->lock);
    size_t len = 0;
    if (m->scanned_methods != jl_nothing)
        len = jl_array_len(m->scanned_methods);
    JL_UNLOCK(&m->lock);
    return len;
}
1174+
1175+
// Fetch the `i`-th entry of module `m`'s scanned-methods list.
//
// `i` is 1-based; it is converted to a 0-based slot before indexing. No
// bounds check is performed here, so callers are expected to pass an index
// in `1:jl_module_scanned_methods_length(m)`. The read is done under
// `m->lock`.
JL_DLLEXPORT jl_value_t *jl_module_scanned_methods_getindex(jl_module_t *m, size_t i)
{
    size_t slot = i - 1;
    JL_LOCK(&m->lock);
    jl_value_t *meth = jl_array_ptr_ref(m->scanned_methods, slot);
    JL_UNLOCK(&m->lock);
    return meth;
}
1182+
11661183
JL_DLLEXPORT jl_value_t *jl_get_module_binding_or_nothing(jl_module_t *m, jl_sym_t *s)
11671184
{
11681185
jl_binding_t *b = jl_get_module_binding(m, s, 0);
@@ -1369,10 +1386,10 @@ JL_DLLEXPORT void jl_add_binding_backedge(jl_binding_t *b, jl_value_t *edge)
13691386

13701387
// Called for all GlobalRefs found in lowered code. Adds backedges for cross-module
13711388
// GlobalRefs.
1372-
JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
1389+
JL_DLLEXPORT int jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t *defining_module, jl_value_t *edge)
13731390
{
13741391
if (!edge)
1375-
return;
1392+
return 0;
13761393
jl_binding_t *b = gr->binding;
13771394
if (!b)
13781395
b = jl_get_module_binding(gr->mod, gr->name, 1);
@@ -1381,9 +1398,10 @@ JL_DLLEXPORT void jl_maybe_add_binding_backedge(jl_globalref_t *gr, jl_module_t
13811398
if (gr->mod == defining_module) {
13821399
// No backedge required - invalidation will forward scan
13831400
jl_atomic_fetch_or(&b->flags, BINDING_FLAG_ANY_IMPLICIT_EDGES);
1384-
return;
1401+
return 1;
13851402
}
13861403
jl_add_binding_backedge(b, edge);
1404+
return 0;
13871405
}
13881406

13891407
JL_DLLEXPORT jl_binding_partition_t *jl_replace_binding_locked(jl_binding_t *b,

src/staticdata.c

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,7 @@ static void jl_queue_module_for_serialization(jl_serializer_state *s, jl_module_
812812
}
813813

814814
jl_queue_for_serialization(s, m->usings_backedges);
815+
jl_queue_for_serialization(s, m->scanned_methods);
815816
}
816817

817818
// Anything that requires uniquing or fixing during deserialization needs to be "toplevel"
@@ -1324,6 +1325,9 @@ static void jl_write_module(jl_serializer_state *s, uintptr_t item, jl_module_t
13241325
newm->usings_backedges = NULL;
13251326
arraylist_push(&s->relocs_list, (void*)(reloc_offset + offsetof(jl_module_t, usings_backedges)));
13261327
arraylist_push(&s->relocs_list, (void*)backref_id(s, m->usings_backedges, s->link_ids_relocs));
1328+
newm->scanned_methods = NULL;
1329+
arraylist_push(&s->relocs_list, (void*)(reloc_offset + offsetof(jl_module_t, scanned_methods)));
1330+
arraylist_push(&s->relocs_list, (void*)backref_id(s, m->scanned_methods, s->link_ids_relocs));
13271331

13281332
// After reload, everything that has happened in this process happened semantically at
13291333
// (for .incremental) or before jl_require_world, so reset this flag.

0 commit comments

Comments
 (0)