From f6badbd0b9944471ed43cd6fd9ad173fbb405839 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sat, 24 Sep 2022 08:02:13 -0500 Subject: [PATCH 1/3] Record invoke-appropriate targets This fixes backedge-based invalidation when a precompiled `invoke` is followed by loading a package that adds new specializations for the `invoke`d method. An example is LowRankApprox.jl, where FillArrays adds a specialization to `unique`. --- src/dump.c | 87 ++++++++++++++++++++++++++++---------------- src/julia_internal.h | 1 + test/precompile.jl | 20 ++++++++++ 3 files changed, 77 insertions(+), 31 deletions(-) diff --git a/src/dump.c b/src/dump.c index c8fe2b731920a..995a5c4868003 100644 --- a/src/dump.c +++ b/src/dump.c @@ -1366,6 +1366,7 @@ static void jl_collect_backedges(jl_array_t *edges, jl_array_t *ext_targets) jl_value_t *invokeTypes; jl_method_instance_t *c; size_t i; + size_t world = jl_get_world_counter(); void **table = edges_map.table; // edges is caller => callees size_t table_size = edges_map.size; for (i = 0; i < table_size; i += 2) { @@ -1408,15 +1409,28 @@ static void jl_collect_backedges(jl_array_t *edges, jl_array_t *ext_targets) size_t min_valid = 0; size_t max_valid = ~(size_t)0; int ambig = 0; - jl_value_t *matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig); - if (matches == jl_false) { - valid = 0; - break; - } - size_t k; - for (k = 0; k < jl_array_len(matches); k++) { - jl_method_match_t *match = (jl_method_match_t *)jl_array_ptr_ref(matches, k); - jl_array_ptr_set(matches, k, match->method); + jl_value_t *matches; + if (mode == 2 && callee && jl_is_method_instance(callee) && jl_is_type(sig)) { + // invoke, use subtyping + jl_methtable_t *mt = jl_method_get_table(((jl_method_instance_t*)callee)->def.method); + size_t min_world, max_world; + matches = jl_gf_invoke_lookup_worlds(sig, (jl_value_t*)mt, world, &min_world, &max_world); + if (matches == jl_nothing) { + valid = 0; + break; + } + matches = (jl_value_t*)((jl_method_match_t*)matches)->method; + } else { + matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig); + if (matches == jl_false) { + valid = 0; + break; + } + size_t k; + for (k = 0; k < jl_array_len(matches); k++) { + jl_method_match_t *match = (jl_method_match_t *)jl_array_ptr_ref(matches, k); + jl_array_ptr_set(matches, k, match->method); + } } jl_array_ptr_1d_push(ext_targets, mode == 1 ? NULL : sig); jl_array_ptr_1d_push(ext_targets, callee); @@ -2544,6 +2558,7 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids) jl_value_t *loctag = NULL, *matches = NULL; JL_GC_PUSH2(&loctag, &matches); *pvalids = valids; + size_t world = jl_get_world_counter(); for (i = 0; i < l; i++) { jl_value_t *invokesig = jl_array_ptr_ref(targets, i * 3); jl_value_t *callee = jl_array_ptr_ref(targets, i * 3 + 1); @@ -2555,33 +2570,43 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids) else { sig = callee == NULL ? invokesig : callee; } - jl_array_t *expected = (jl_array_t*)jl_array_ptr_ref(targets, i * 3 + 2); - assert(jl_is_array(expected)); + jl_value_t *expected = jl_array_ptr_ref(targets, i * 3 + 2); int valid = 1; size_t min_valid = 0; size_t max_valid = ~(size_t)0; int ambig = 0; - // TODO: possibly need to included ambiguities too (for the optimizer correctness)? - matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig); - if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) { - valid = 0; - } - else { - size_t j, k, l = jl_array_len(expected); - for (k = 0; k < jl_array_len(matches); k++) { - jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, k); - jl_method_t *m = match->method; - for (j = 0; j < l; j++) { - if (m == (jl_method_t*)jl_array_ptr_ref(expected, j)) + int use_invoke = invokesig == NULL || callee == NULL ? 0 : 1; + if (!use_invoke) { + // TODO: possibly need to included ambiguities too (for the optimizer correctness)? + matches = jl_matching_methods((jl_tupletype_t*)sig, jl_nothing, -1, 0, jl_atomic_load_acquire(&jl_world_counter), &min_valid, &max_valid, &ambig); + if (matches == jl_false || jl_array_len(matches) != jl_array_len(expected)) { + valid = 0; + } + else { + assert(jl_is_array(expected)); + size_t j, k, l = jl_array_len(expected); + for (k = 0; k < jl_array_len(matches); k++) { + jl_method_match_t *match = (jl_method_match_t*)jl_array_ptr_ref(matches, k); + jl_method_t *m = match->method; + for (j = 0; j < l; j++) { + if (m == (jl_method_t*)jl_array_ptr_ref(expected, j)) + break; + } + if (j == l) { + // intersection has a new method or a method was + // deleted--this is now probably no good, just invalidate + // everything about it now + valid = 0; break; + } } - if (j == l) { - // intersection has a new method or a method was - // deleted--this is now probably no good, just invalidate - // everything about it now - valid = 0; - break; - } + } + } else { + jl_methtable_t *mt = jl_method_get_table(((jl_method_instance_t*)callee)->def.method); + size_t min_world, max_world; + matches = jl_gf_invoke_lookup_worlds(invokesig, (jl_value_t*)mt, world, &min_world, &max_world); + if (matches == jl_nothing || expected != (jl_value_t*)((jl_method_match_t*)matches)->method) { + valid = 0; } } jl_array_uint8_set(valids, i, valid); @@ -2593,7 +2618,7 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids) jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag); loctag = jl_box_uint64(jl_worklist_key(serializer_worklist)); jl_array_ptr_1d_push(_jl_debug_method_invalidation, loctag); - if (matches != jl_false) { + if (!use_invoke && matches != jl_false) { // setdiff!(matches, expected) size_t j, k, ins = 0; for (j = 0; j < jl_array_len(matches); j++) { diff --git a/src/julia_internal.h b/src/julia_internal.h index 635a14b6a2f26..91315b13f9d08 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -716,6 +716,7 @@ jl_value_t *jl_gf_invoke_by_method(jl_method_t *method, jl_value_t *gf, jl_value jl_value_t *jl_gf_invoke(jl_value_t *types, jl_value_t *f, jl_value_t **args, size_t nargs); JL_DLLEXPORT jl_value_t *jl_matching_methods(jl_tupletype_t *types, jl_value_t *mt, int lim, int include_ambiguous, size_t world, size_t *min_valid, size_t *max_valid, int *ambig); +JL_DLLEXPORT jl_value_t *jl_gf_invoke_lookup_worlds(jl_value_t *types, jl_value_t *mt, size_t world, size_t *min_world, size_t *max_world); JL_DLLEXPORT jl_datatype_t *jl_first_argument_datatype(jl_value_t *argtypes JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_value_t *jl_argument_datatype(jl_value_t *argt JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; diff --git a/test/precompile.jl b/test/precompile.jl index f6936197917a8..b0ef7593cf739 100644 --- a/test/precompile.jl +++ b/test/precompile.jl @@ -931,6 +931,7 @@ precompile_test_harness("invoke") do dir module $InvokeModule export f, g, h, q, fnc, gnc, hnc, qnc # nc variants do not infer to a Const export f44320, g44320 + export getlast # f is for testing invoke that occurs within a dependency f(x::Real) = 0 f(x::Int) = x < 5 ? 1 : invoke(f, Tuple{Real}, x) @@ -954,6 +955,16 @@ precompile_test_harness("invoke") do dir f44320(::Any) = 2 g44320() = invoke(f44320, Tuple{Any}, 0) g44320() + + # Adding new specializations should not invalidate `invoke`s + function getlast(itr) + x = nothing + for y in itr + x = y + end + return x + end + getlast(a::AbstractArray) = invoke(getlast, Tuple{Any}, a) end """) write(joinpath(dir, "$CallerModule.jl"), @@ -981,6 +992,8 @@ precompile_test_harness("invoke") do dir # Issue #44320 f44320(::Real) = 3 + call_getlast(x) = getlast(x) + # force precompilation begin Base.Experimental.@force_compile @@ -996,6 +1009,7 @@ precompile_test_harness("invoke") do dir callqnci(3) internal(3) internalnc(3) + call_getlast([1,2,3]) end # Now that we've precompiled, invalidate with a new method that overrides the `invoke` dispatch @@ -1007,6 +1021,9 @@ precompile_test_harness("invoke") do dir end """) Base.compilecache(Base.PkgId(string(CallerModule))) + @eval using $InvokeModule: $InvokeModule + MI = getfield(@__MODULE__, InvokeModule) + @eval $MI.getlast(a::UnitRange) = a.stop @eval using $CallerModule M = getfield(@__MODULE__, CallerModule) @@ -1060,6 +1077,9 @@ precompile_test_harness("invoke") do dir m = only(methods(M.g44320)) @test m.specializations[1].cache.max_world == typemax(UInt) + m = which(MI.getlast, (Any,)) + @test m.specializations[1].cache.max_world == typemax(UInt) + # Precompile specific methods for arbitrary arg types invokeme(x) = 1 invokeme(::Int) = 2 From 7deb9a2737d1e2622c64029063e17b99b856fa6e Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Mon, 26 Sep 2022 15:15:18 -0500 Subject: [PATCH 2/3] GC-protect mt --- src/dump.c | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/dump.c b/src/dump.c index 995a5c4868003..354ba4c704312 100644 --- a/src/dump.c +++ b/src/dump.c @@ -2556,7 +2556,8 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids) jl_array_t *valids = jl_alloc_array_1d(jl_array_uint8_type, l); memset(jl_array_data(valids), 1, l); jl_value_t *loctag = NULL, *matches = NULL; - JL_GC_PUSH2(&loctag, &matches); + jl_methtable_t *mt = NULL; + JL_GC_PUSH3(&loctag, &matches, &mt); *pvalids = valids; size_t world = jl_get_world_counter(); for (i = 0; i < l; i++) { @@ -2602,7 +2603,7 @@ static void jl_verify_edges(jl_array_t *targets, jl_array_t **pvalids) } } } else { - jl_methtable_t *mt = jl_method_get_table(((jl_method_instance_t*)callee)->def.method); + mt = jl_method_get_table(((jl_method_instance_t*)callee)->def.method); size_t min_world, max_world; matches = jl_gf_invoke_lookup_worlds(invokesig, (jl_value_t*)mt, world, &min_world, &max_world); if (matches == jl_nothing || expected != (jl_value_t*)((jl_method_match_t*)matches)->method) { From 439ca3ee5802a25c269adcb44777eab0cb1bacc8 Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Tue, 27 Sep 2022 04:23:10 -0500 Subject: [PATCH 3/3] Add ASAN annotation to jl_method_get_table --- src/julia_internal.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/julia_internal.h b/src/julia_internal.h index 91315b13f9d08..0b505e4234f5c 100644 --- a/src/julia_internal.h +++ b/src/julia_internal.h @@ -723,7 +723,7 @@ JL_DLLEXPORT jl_value_t *jl_argument_datatype(jl_value_t *argt JL_PROPAGATES_ROO JL_DLLEXPORT jl_methtable_t *jl_method_table_for( jl_value_t *argtypes JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; JL_DLLEXPORT jl_methtable_t *jl_method_get_table( - jl_method_t *method) JL_NOTSAFEPOINT; + jl_method_t *method JL_PROPAGATES_ROOT) JL_NOTSAFEPOINT; jl_methtable_t *jl_argument_method_table(jl_value_t *argt JL_PROPAGATES_ROOT); JL_DLLEXPORT int jl_pointer_egal(jl_value_t *t);