Skip to content

Commit e485be8

Browse files
authored
bpart: Start tracking backedges for bindings (#57213)
This PR adds limited backedge support for Bindings. There are two classes of bindings that get backedges: 1. Cross-module `GlobalRef` bindings (new in this PR) 2. Any globals accesses through intrinsics (i.e. those with forward edges from #57009) This is a time/space trade-off for invalidation. As a result of the first category, invalidating a binding now only needs to scan all the methods defined in the same module as the binding. At the same time, it is anticipated that most binding references are to bindings in the same module, keeping the list of bindings that need explicit (back)edges small.
1 parent dbc6681 commit e485be8

23 files changed

+311
-112
lines changed

Compiler/src/Compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ using Base: @_foldable_meta, @_gc_preserve_begin, @_gc_preserve_end, @nospeciali
6767
partition_restriction, quoted, rename_unionall, rewrap_unionall, specialize_method,
6868
structdiff, tls_world_age, unconstrain_vararg_length, unionlen, uniontype_layout,
6969
uniontypes, unsafe_convert, unwrap_unionall, unwrapva, vect, widen_diagonal,
70-
_uncompressed_ir
70+
_uncompressed_ir, maybe_add_binding_backedge!
7171
using Base.Order
7272

7373
import Base: ==, _topmod, append!, convert, copy, copy!, findall, first, get, get!,

Compiler/src/abstractinterpretation.jl

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(fun
208208
end
209209
if const_edge !== nothing
210210
edge = const_edge
211+
update_valid_age!(sv, world_range(const_edge))
211212
end
212213
end
213214

@@ -2330,6 +2331,7 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
23302331
end
23312332
if const_edge !== nothing
23322333
edge = const_edge
2334+
update_valid_age!(sv, world_range(const_edge))
23332335
end
23342336
end
23352337
rt = from_interprocedural!(interp, rt, sv, arginfo′, sig)
@@ -2396,8 +2398,9 @@ function abstract_eval_getglobal(interp::AbstractInterpreter, sv::AbsIntState, s
23962398
if M isa Const && s isa Const
23972399
M, s = M.val, s.val
23982400
if M isa Module && s isa Symbol
2399-
(ret, bpart) = abstract_eval_globalref(interp, GlobalRef(M, s), saw_latestworld, sv)
2400-
return CallMeta(ret, bpart === nothing ? NoCallInfo() : GlobalAccessInfo(bpart))
2401+
gr = GlobalRef(M, s)
2402+
(ret, bpart) = abstract_eval_globalref(interp, gr, saw_latestworld, sv)
2403+
return CallMeta(ret, bpart === nothing ? NoCallInfo() : GlobalAccessInfo(convert(Core.Binding, gr), bpart))
24012404
end
24022405
return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
24032406
elseif !hasintersect(widenconst(M), Module) || !hasintersect(widenconst(s), Symbol)
@@ -2475,8 +2478,9 @@ function abstract_eval_setglobal!(interp::AbstractInterpreter, sv::AbsIntState,
24752478
if isa(M, Const) && isa(s, Const)
24762479
M, s = M.val, s.val
24772480
if M isa Module && s isa Symbol
2478-
(rt, exct), partition = global_assignment_rt_exct(interp, sv, saw_latestworld, GlobalRef(M, s), v)
2479-
return CallMeta(rt, exct, Effects(setglobal!_effects, nothrow=exct===Bottom), GlobalAccessInfo(partition))
2481+
gr = GlobalRef(M, s)
2482+
(rt, exct), partition = global_assignment_rt_exct(interp, sv, saw_latestworld, gr, v)
2483+
return CallMeta(rt, exct, Effects(setglobal!_effects, nothrow=exct===Bottom), GlobalAccessInfo(convert(Core.Binding, gr), partition))
24802484
end
24812485
return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
24822486
end
@@ -2564,14 +2568,15 @@ function abstract_eval_replaceglobal!(interp::AbstractInterpreter, sv::AbsIntSta
25642568
M, s = M.val, s.val
25652569
M isa Module || return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
25662570
s isa Symbol || return CallMeta(Union{}, TypeError, EFFECTS_THROWS, NoCallInfo())
2567-
partition = abstract_eval_binding_partition!(interp, GlobalRef(M, s), sv)
2571+
gr = GlobalRef(M, s)
2572+
partition = abstract_eval_binding_partition!(interp, gr, sv)
25682573
rte = abstract_eval_partition_load(interp, partition)
25692574
if binding_kind(partition) == BINDING_KIND_GLOBAL
25702575
T = partition_restriction(partition)
25712576
end
25722577
exct = Union{rte.exct, global_assignment_binding_rt_exct(interp, partition, v)[2]}
25732578
effects = merge_effects(rte.effects, Effects(setglobal!_effects, nothrow=exct===Bottom))
2574-
sg = CallMeta(Any, exct, effects, GlobalAccessInfo(partition))
2579+
sg = CallMeta(Any, exct, effects, GlobalAccessInfo(convert(Core.Binding, gr), partition))
25752580
else
25762581
sg = abstract_eval_setglobal!(interp, sv, saw_latestworld, M, s, v)
25772582
end
@@ -2791,6 +2796,7 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
27912796
end
27922797
if const_edge !== nothing
27932798
edge = const_edge
2799+
update_valid_age!(sv, world_range(const_edge))
27942800
end
27952801
end
27962802
end
@@ -3225,7 +3231,8 @@ function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, mod::Module,
32253231
end
32263232

32273233
effects = EFFECTS_TOTAL
3228-
partition = lookup_binding_partition!(interp, GlobalRef(mod, sym), sv)
3234+
gr = GlobalRef(mod, sym)
3235+
partition = lookup_binding_partition!(interp, gr, sv)
32293236
if allow_import !== true && is_some_imported(binding_kind(partition))
32303237
if allow_import === false
32313238
rt = Const(false)
@@ -3243,7 +3250,7 @@ function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, mod::Module,
32433250
effects = Effects(generic_isdefinedglobal_effects, nothrow=true)
32443251
end
32453252
end
3246-
return CallMeta(RTEffects(rt, Union{}, effects), GlobalAccessInfo(partition))
3253+
return CallMeta(RTEffects(rt, Union{}, effects), GlobalAccessInfo(convert(Core.Binding, gr), partition))
32473254
end
32483255

32493256
function abstract_eval_isdefinedglobal(interp::AbstractInterpreter, @nospecialize(M), @nospecialize(s), @nospecialize(allow_import_arg), @nospecialize(order_arg), saw_latestworld::Bool, sv::AbsIntState)
@@ -3454,6 +3461,7 @@ end
34543461

34553462
world_range(ir::IRCode) = ir.valid_worlds
34563463
world_range(ci::CodeInfo) = WorldRange(ci.min_world, ci.max_world)
3464+
world_range(ci::CodeInstance) = WorldRange(ci.min_world, ci.max_world)
34573465
world_range(compact::IncrementalCompact) = world_range(compact.ir)
34583466

34593467
function force_binding_resolution!(g::GlobalRef, world::UInt)

Compiler/src/bootstrap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,6 @@ function activate!(; reflection=true, codegen=false)
9191
Base.REFLECTION_COMPILER[] = Compiler
9292
end
9393
if codegen
94-
activate_codegen!()
94+
bootstrap!()
9595
end
9696
end

Compiler/src/ssair/verify.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ function check_op(ir::IRCode, domtree::DomTree, @nospecialize(op), use_bb::Int,
6767
imported_binding = partition_restriction(bpart)::Core.Binding
6868
bpart = lookup_binding_partition(min_world(ir.valid_worlds), imported_binding)
6969
end
70-
if !is_defined_const_binding(binding_kind(bpart)) || (bpart.max_world < max_world(ir.valid_worlds))
70+
if (!is_defined_const_binding(binding_kind(bpart)) || (bpart.max_world < max_world(ir.valid_worlds))) &&
71+
(op.mod !== Core) && (op.mod !== Base)
72+
# Core and Base are excluded because the frontend uses them for intrinsics, etc.
73+
# TODO: Decide which way to go with these.
7174
@verify_error "Unbound or partitioned GlobalRef not allowed in value position"
7275
raise_error()
7376
end

Compiler/src/stmtinfo.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,12 @@ Represents access to a global through runtime reflection, rather than as a manif
489489
perform such accesses.
490490
"""
491491
struct GlobalAccessInfo <: CallInfo
492+
b::Core.Binding
492493
bpart::Core.BindingPartition
493494
end
494-
GlobalAccessInfo(::Nothing) = NoCallInfo()
495-
add_edges_impl(edges::Vector{Any}, info::GlobalAccessInfo) =
496-
push!(edges, info.bpart)
495+
GlobalAccessInfo(::Core.Binding, ::Nothing) = NoCallInfo()
496+
function add_edges_impl(edges::Vector{Any}, info::GlobalAccessInfo)
497+
push!(edges, info.b)
498+
end
497499

498500
@specialize

Compiler/src/typeinfer.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,8 +544,9 @@ function store_backedges(caller::CodeInstance, edges::SimpleVector)
544544
# ignore `Method`-edges (from e.g. failed `abstract_call_method`)
545545
i += 1
546546
continue
547-
elseif isa(item, Core.BindingPartition)
547+
elseif isa(item, Core.Binding)
548548
i += 1
549+
maybe_add_binding_backedge!(item, caller)
549550
continue
550551
end
551552
if isa(item, CodeInstance)

Compiler/test/ssair.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,7 +134,7 @@ let code = Any[
134134
Expr(:boundscheck),
135135
Compiler.GotoIfNot(SSAValue(1), 6),
136136
# block 2
137-
Expr(:call, GlobalRef(Base, :size), Compiler.Argument(3)),
137+
Expr(:call, size, Compiler.Argument(3)),
138138
Compiler.ReturnNode(),
139139
# block 3
140140
Core.PhiNode(),

base/Base.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ let os = ccall(:jl_get_UNAME, Any, ())
3131
end
3232
end
3333

34+
# subarrays
35+
include("subarray.jl")
36+
include("views.jl")
37+
3438
# numeric operations
3539
include("hashing.jl")
3640
include("rounding.jl")

base/Base_compiler.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,8 +231,6 @@ include("indices.jl")
231231
include("genericmemory.jl")
232232
include("array.jl")
233233
include("abstractarray.jl")
234-
include("subarray.jl")
235-
include("views.jl")
236234
include("baseext.jl")
237235

238236
include("c.jl")

base/invalidation.jl

Lines changed: 97 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,6 @@ function foreach_module_mtable(visit, m::Module, world::UInt)
3535
visit(mt) || return false
3636
end
3737
end
38-
elseif isa(v, Module) && v !== m && parentmodule(v) === m && _nameof(v) === name
39-
# this is the original/primary binding for the submodule
40-
foreach_module_mtable(visit, v, world) || return false
4138
elseif isa(v, Core.MethodTable) && v.module === m && v.name === name
4239
# this is probably an external method table here, so let's
4340
# assume so as there is no way to precisely distinguish them
@@ -48,83 +45,138 @@ function foreach_module_mtable(visit, m::Module, world::UInt)
4845
return true
4946
end
5047

51-
function foreach_reachable_mtable(visit, world::UInt)
52-
visit(TYPE_TYPE_MT) || return
53-
visit(NONFUNCTION_MT) || return
54-
for mod in loaded_modules_array()
55-
foreach_module_mtable(visit, mod, world)
48+
function foreachgr(visit, src::CodeInfo)
49+
stmts = src.code
50+
for i = 1:length(stmts)
51+
stmt = stmts[i]
52+
isa(stmt, GlobalRef) && visit(stmt)
53+
for ur in Compiler.userefs(stmt)
54+
arg = ur[]
55+
isa(arg, GlobalRef) && visit(arg)
56+
end
5657
end
5758
end
5859

59-
function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
60-
found_any = false
61-
labelchangemap = nothing
60+
function anygr(visit, src::CodeInfo)
6261
stmts = src.code
63-
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
64-
isgr(g) = false
6562
for i = 1:length(stmts)
6663
stmt = stmts[i]
67-
if isgr(stmt)
68-
found_any = true
64+
if isa(stmt, GlobalRef)
65+
visit(stmt) && return true
6966
continue
7067
end
7168
for ur in Compiler.userefs(stmt)
7269
arg = ur[]
73-
# If any of the GlobalRefs in this stmt match the one that
74-
# we are about, we need to move out all GlobalRefs to preserve
75-
# effect order, in case we later invalidate a different GR
76-
if isa(arg, GlobalRef)
77-
if isgr(arg)
78-
@assert !isa(stmt, PhiNode)
79-
found_any = true
80-
break
81-
end
82-
end
70+
isa(arg, GlobalRef) && visit(arg) && return true
8371
end
8472
end
85-
return found_any
73+
return false
74+
end
75+
76+
function should_invalidate_code_for_globalref(gr::GlobalRef, src::CodeInfo)
77+
isgr(g::GlobalRef) = gr.mod == g.mod && gr.name === g.name
78+
isgr(g) = false
79+
return anygr(isgr, src)
8680
end
8781

88-
function scan_edge_list(ci::Core.CodeInstance, bpart::Core.BindingPartition)
82+
function scan_edge_list(ci::Core.CodeInstance, binding::Core.Binding)
8983
isdefined(ci, :edges) || return false
9084
edges = ci.edges
9185
i = 1
9286
while i <= length(edges)
93-
if isassigned(edges, i) && edges[i] === bpart
87+
if isassigned(edges, i) && edges[i] === binding
9488
return true
9589
end
9690
i += 1
9791
end
9892
return false
9993
end
10094

95+
function invalidate_method_for_globalref!(gr::GlobalRef, method::Method, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
96+
if isdefined(method, :source)
97+
src = _uncompressed_ir(method)
98+
binding = convert(Core.Binding, gr)
99+
old_stmts = src.code
100+
invalidate_all = should_invalidate_code_for_globalref(gr, src)
101+
for mi in specializations(method)
102+
isdefined(mi, :cache) || continue
103+
ci = mi.cache
104+
while true
105+
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, binding))
106+
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
107+
end
108+
isdefined(ci, :next) || break
109+
ci = ci.next
110+
end
111+
end
112+
end
113+
end
114+
101115
function invalidate_code_for_globalref!(gr::GlobalRef, invalidated_bpart::Core.BindingPartition, new_max_world::UInt)
102116
try
103117
valid_in_valuepos = false
104-
foreach_reachable_mtable(new_max_world) do mt::Core.MethodTable
118+
foreach_module_mtable(gr.mod, new_max_world) do mt::Core.MethodTable
105119
for method in MethodList(mt)
106-
if isdefined(method, :source)
107-
src = _uncompressed_ir(method)
108-
old_stmts = src.code
109-
invalidate_all = should_invalidate_code_for_globalref(gr, src)
110-
for mi in specializations(method)
111-
isdefined(mi, :cache) || continue
112-
ci = mi.cache
113-
while true
114-
if ci.max_world > new_max_world && (invalidate_all || scan_edge_list(ci, invalidated_bpart))
115-
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), ci, new_max_world)
116-
end
117-
isdefined(ci, :next) || break
118-
ci = ci.next
119-
end
120-
end
121-
end
120+
invalidate_method_for_globalref!(gr, method, invalidated_bpart, new_max_world)
122121
end
123122
return true
124123
end
124+
b = convert(Core.Binding, gr)
125+
if isdefined(b, :backedges)
126+
for edge in b.backedges
127+
if isa(edge, CodeInstance)
128+
ccall(:jl_invalidate_code_instance, Cvoid, (Any, UInt), edge, new_max_world)
129+
else
130+
invalidate_method_for_globalref!(gr, edge::Method, invalidated_bpart, new_max_world)
131+
end
132+
end
133+
end
125134
catch err
126135
bt = catch_backtrace()
127136
invokelatest(Base.println, "Internal Error during invalidation:")
128137
invokelatest(Base.display_error, err, bt)
129138
end
130139
end
140+
141+
gr_needs_backedge_in_module(gr::GlobalRef, mod::Module) = gr.mod !== mod
142+
143+
# N.B.: This needs to match jl_maybe_add_binding_backedge
144+
function maybe_add_binding_backedge!(b::Core.Binding, edge::Union{Method, CodeInstance})
145+
method = isa(edge, Method) ? edge : edge.def.def::Method
146+
gr_needs_backedge_in_module(b.globalref, method.module) || return
147+
if !isdefined(b, :backedges)
148+
b.backedges = Any[]
149+
end
150+
!isempty(b.backedges) && b.backedges[end] === edge && return
151+
push!(b.backedges, edge)
152+
end
153+
154+
function binding_was_invalidated(b::Core.Binding)
155+
# At least one partition is required for invalidation
156+
!isdefined(b, :partitions) && return false
157+
b.partitions.min_world > unsafe_load(cglobal(:jl_require_world, UInt))
158+
end
159+
160+
function scan_new_method!(methods_with_invalidated_source::IdSet{Method}, method::Method)
161+
isdefined(method, :source) || return
162+
src = _uncompressed_ir(method)
163+
mod = method.module
164+
foreachgr(src) do gr::GlobalRef
165+
b = convert(Core.Binding, gr)
166+
binding_was_invalidated(b) && push!(methods_with_invalidated_source, method)
167+
maybe_add_binding_backedge!(b, method)
168+
end
169+
end
170+
171+
function scan_new_methods(extext_methods::Vector{Any}, internal_methods::Vector{Any})
172+
methods_with_invalidated_source = IdSet{Method}()
173+
for method in internal_methods
174+
if isa(method, Method)
175+
scan_new_method!(methods_with_invalidated_source, method)
176+
end
177+
end
178+
for tme::Core.TypeMapEntry in extext_methods
179+
scan_new_method!(methods_with_invalidated_source, tme.func::Method)
180+
end
181+
return methods_with_invalidated_source
182+
end

0 commit comments

Comments
 (0)