Skip to content

Commit 8d2e0f3

Browse files
timholyvtjnash
authored andcommitted
Precompile correct invoke-targets (#46907)
This fixes backedge-based invalidation when a precompiled `invoke` is followed by loading a package that adds new specializations for the `invoke`d method. An example is LowRankApprox.jl, where FillArrays adds a specialization to `unique`. (cherry picked from commit 698beed)
1 parent 10288cc commit 8d2e0f3

File tree

2 files changed

+60
-33
lines changed

2 files changed

+60
-33
lines changed

src/dump.c

Lines changed: 58 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,7 @@ static void jl_collect_backedges(jl_array_t *edges, jl_array_t *ext_targets)
13271327
jl_value_t *invokeTypes;
13281328
jl_method_instance_t *c;
13291329
size_t i;
1330+
size_t world = jl_get_world_counter();
13301331
void **table = edges_map.table; // edges is caller => callees
13311332
size_t table_size = edges_map.size;
13321333
for (i = 0; i < table_size; i += 2) {
@@ -1369,15 +1370,28 @@ static void jl_collect_backedges(jl_array_t *edges, jl_array_t *ext_targets)
13691370
size_t min_valid = 0;
13701371
size_t max_valid = ~(size_t)0;
13711372
int ambig = 0;
1372-
jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig);
1373-
if (matches == jl_false) {
1374-
valid = 0;
1375-
break;
1376-
}
1377-
size_t k;
1378-
for (k = 0; k < jl_array_len(matches); k++) {
1379-
jl_method_match_t *match = (jl_method_match_t *)jl_array_ptr_ref(matches, k);
1380-
jl_array_ptr_set(matches, k, match->method);
1373+
jl_value_t *matches;
1374+
if (mode == 2 && callee && jl_is_method_instance(callee) && jl_is_type(sig)) {
1375+
// invoke, use subtyping
1376+
jl_methtable_t *mt = jl_method_get_table(((jl_method_instance_t*)callee)->def.method);
1377+
size_t min_world, max_world;
1378+
matches = jl_gf_invoke_lookup_worlds(sig, (jl_value_t*)mt, world, &min_world, &max_world);
1379+
if (matches == jl_nothing) {
1380+
valid = 0;
1381+
break;
1382+
}
1383+
matches = (jl_value_t*)((jl_method_match_t*)matches)->method;
1384+
} else {
1385+
matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig);
1386+
if (matches == jl_false) {
1387+
valid = 0;
1388+
break;
1389+
}
1390+
size_t k;
1391+
for (k = 0; k < jl_array_len(matches); k++) {
1392+
jl_method_match_t *match = (jl_method_match_t *)jl_array_ptr_ref(matches, k);
1393+
jl_array_ptr_set(matches, k, match->method);
1394+
}
13811395
}
13821396
jl_array_ptr_1d_push(ext_targets, mode == 1 ? NULL : sig);
13831397
jl_array_ptr_1d_push(ext_targets, callee);
@@ -2495,8 +2509,10 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
24952509
jl_array_t *valids = jl_alloc_array_1d(jl_array_uint8_type, l);
24962510
memset(jl_array_data(valids), 1, l);
24972511
jl_value_t *loctag = NULL, *matches = NULL;
2498-
JL_GC_PUSH2(&loctag, &matches);
2512+
jl_methtable_t *mt = NULL;
2513+
JL_GC_PUSH3(&loctag, &matches, &mt);
24992514
*pvalids = valids;
2515+
size_t world = jl_get_world_counter();
25002516
for (i = 0; i < l; i++) {
25012517
jl_value_t *invokesig = jl_array_ptr_ref(targets, i * 3);
25022518
jl_value_t *callee = jl_array_ptr_ref(targets, i * 3 + 1);
@@ -2508,33 +2524,43 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
25082524
else {
25092525
sig = callee == NULL ? invokesig : callee;
25102526
}
2511-
jl_array_t *expected = (jl_array_t*)jl_array_ptr_ref(targets, i * 3 + 2);
2512-
assert(jl_is_array(expected));
2527+
jl_value_t *expected = jl_array_ptr_ref(targets, i * 3 + 2);
25132528
int valid = 1;
25142529
size_t min_valid = 0;
25152530
size_t max_valid = ~(size_t)0;
25162531
int ambig = 0;
2517-
// TODO: possibly need to included ambiguities too (for the optimizer correctness)?
2518-
matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig);
2519-
if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) {
2520-
valid = 0;
2521-
}
2522-
else {
2523-
size_t j, k, l = jl_array_len(expected);
2524-
for (k = 0; k < jl_array_len(matches); k++) {
2525-
jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, k);
2526-
jl_method_t *m = match->method;
2527-
for (j = 0; j < l; j++) {
2528-
if (m == (jl_method_t*)jl_array_ptr_ref(expected, j))
2532+
int use_invoke = invokesig == NULL || callee == NULL ? 0 : 1;
2533+
if (!use_invoke) {
2534+
// TODO: possibly need to included ambiguities too (for the optimizer correctness)?
2535+
matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig);
2536+
if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) {
2537+
valid = 0;
2538+
}
2539+
else {
2540+
assert(jl_is_array(expected));
2541+
size_t j, k, l = jl_array_len(expected);
2542+
for (k = 0; k < jl_array_len(matches); k++) {
2543+
jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, k);
2544+
jl_method_t *m = match->method;
2545+
for (j = 0; j < l; j++) {
2546+
if (m == (jl_method_t*)jl_array_ptr_ref(expected, j))
2547+
break;
2548+
}
2549+
if (j == l) {
2550+
// intersection has a new method or a method was
2551+
// deleted--this is now probably no good, just invalidate
2552+
// everything about it now
2553+
valid = 0;
25292554
break;
2555+
}
25302556
}
2531-
if (j == l) {
2532-
// intersection has a new method or a method was
2533-
// deleted--this is now probably no good, just invalidate
2534-
// everything about it now
2535-
valid = 0;
2536-
break;
2537-
}
2557+
}
2558+
} else {
2559+
mt = jl_method_get_table(((jl_method_instance_t*)callee)->def.method);
2560+
size_t min_world, max_world;
2561+
matches = jl_gf_invoke_lookup_worlds(invokesig, (jl_value_t*)mt, world, &min_world, &max_world);
2562+
if (matches == jl_nothing || expected != (jl_value_t*)((jl_method_match_t*)matches)->method) {
2563+
valid = 0;
25382564
}
25392565
}
25402566
jl_array_uint8_set(valids, i, valid);
@@ -2546,7 +2572,7 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids)
25462572
jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
25472573
loctag = jl_box_uint64(jl_worklist_key(serializer_worklist));
25482574
jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag);
2549-
if (matches != jl_false) {
2575+
if (!use_invoke && matches != jl_false) {
25502576
// setdiff!(matches, expected)
25512577
size_t j, k, ins = 0;
25522578
for (j = 0; j < jl_array_len(matches); j++) {

src/julia_internal.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -644,13 +644,14 @@ jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value
644644
jl_value_t *jl_gf_invoke(jl_value_t *types, jl_value_t *f, jl_value_t **args, size_t nargs);
645645
JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, jl_value_t *mt, int lim, int include_ambiguous,
646646
size_t world, size_t *min_valid, size_t *max_valid, int *ambig);
647+
JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, jl_value_t *mt, size_t world, size_t *min_world, size_t *max_world);
647648

648649
JL_DLLEXPORT jl_datatype_t *jl_first_argument_datatype(jl_value_t *argtypes JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
649650
JL_DLLEXPORT jl_value_t *jl_argument_datatype(jl_value_t *argt JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
650651
JL_DLLEXPORT jl_methtable_t *jl_method_table_for(
651652
jl_value_t *argtypes JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
652653
JL_DLLEXPORT jl_methtable_t *jl_method_get_table(
653-
jl_method_t *method) JL_NOTSAFEPOINT;
654+
jl_method_t *method JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT;
654655
jl_methtable_t *jl_argument_method_table(jl_value_t *argt JL_PROPAGATES_ROOT);
655656

656657
JL_DLLEXPORT int jl_pointer_egal(jl_value_t *t);

0 commit comments

Comments
 (0)