Skip to content

Commit 1d3cdec

Browse files
committed
Add backedges for Cassette
Adds an extra field in `CodeInfo` that allows users (Cassette, et al.) to specify dependencies (as forward edges to `MethodInstance`s) that should be turned into backedges once the CodeInfo is passed over by inference. The test includes a minimal implementation of the Cassette mechansim to exercise this code path.
1 parent ad42d5d commit 1d3cdec

File tree

8 files changed

+139
-8
lines changed

8 files changed

+139
-8
lines changed

base/compiler/inferencestate.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,9 @@ mutable struct InferenceState
8787
inmodule = linfo.def::Module
8888
end
8989

90-
min_valid = UInt(1)
91-
max_valid = get_world_counter()
90+
min_valid = src.min_world
91+
max_valid = src.max_world == typemax(UInt) ?
92+
get_world_counter() : src.max_world
9293
frame = new(
9394
params, result, linfo,
9495
sp, slottypes, inmodule, 0,

base/compiler/typeinfer.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,15 @@ function store_backedges(frame::InferenceState)
191191
end
192192
end
193193
end
194+
edges = frame.src.edges
195+
if edges !== nothing
196+
edges = edges::Vector{MethodInstance}
197+
for edge in edges
198+
@assert isa(edge, MethodInstance)
199+
ccall(:jl_method_instance_add_backedge, Cvoid, (Any, Any), edge, caller)
200+
end
201+
frame.src.edges = nothing
202+
end
194203
end
195204
end
196205

src/jltypes.c

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2069,7 +2069,7 @@ void jl_init_types(void) JL_GC_DISABLED
20692069
jl_code_info_type =
20702070
jl_new_datatype(jl_symbol("CodeInfo"), core,
20712071
jl_any_type, jl_emptysvec,
2072-
jl_perm_symsvec(17,
2072+
jl_perm_symsvec(18,
20732073
"code",
20742074
"codelocs",
20752075
"ssavaluetypes",
@@ -2081,13 +2081,14 @@ void jl_init_types(void) JL_GC_DISABLED
20812081
"slottypes",
20822082
"rettype",
20832083
"parent",
2084+
"edges",
20842085
"min_world",
20852086
"max_world",
20862087
"inferred",
20872088
"inlineable",
20882089
"propagate_inbounds",
20892090
"pure"),
2090-
jl_svec(17,
2091+
jl_svec(18,
20912092
jl_array_any_type,
20922093
jl_any_type,
20932094
jl_any_type,
@@ -2099,13 +2100,14 @@ void jl_init_types(void) JL_GC_DISABLED
20992100
jl_any_type,
21002101
jl_any_type,
21012102
jl_any_type,
2103+
jl_any_type,
21022104
jl_ulong_type,
21032105
jl_ulong_type,
21042106
jl_bool_type,
21052107
jl_bool_type,
21062108
jl_bool_type,
21072109
jl_bool_type),
2108-
0, 1, 17);
2110+
0, 1, 18);
21092111

21102112
jl_method_type =
21112113
jl_new_datatype(jl_symbol("Method"), core,

src/julia.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ typedef struct _jl_code_info_t {
254254
jl_value_t *slottypes; // inferred types of slots
255255
jl_value_t *rettype;
256256
jl_method_instance_t *parent; // context (optionally, if available, otherwise nothing)
257+
jl_value_t *edges; // forward edges to method instances that must be invalidated
257258
size_t min_world;
258259
size_t max_world;
259260
// various boolean properties:

src/method.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -320,6 +320,7 @@ JL_DLLEXPORT jl_code_info_t *jl_new_code_info_uninit(void)
320320
src->inlineable = 0;
321321
src->propagate_inbounds = 0;
322322
src->pure = 0;
323+
src->edges = jl_nothing;
323324
return src;
324325
}
325326

stdlib/Serialization/src/Serialization.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1015,8 +1015,15 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
10151015
ci.slottypes = deserialize(s)
10161016
ci.rettype = deserialize(s)
10171017
ci.parent = deserialize(s)
1018-
ci.min_world = reinterpret(UInt, deserialize(s))
1019-
ci.max_world = reinterpret(UInt, deserialize(s))
1018+
world_or_edges = deserialize(s)
1019+
pre_13 = isa(world_or_edges, Integer)
1020+
if pre_13
1021+
ci.min_world = world_or_edges
1022+
else
1023+
ci.edges = world_or_edges
1024+
ci.min_world = reinterpret(UInt, deserialize(s))
1025+
ci.max_world = reinterpret(UInt, deserialize(s))
1026+
end
10201027
end
10211028
ci.inferred = deserialize(s)
10221029
ci.inlineable = deserialize(s)

test/choosetests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ function choosetests(choices = [])
108108
end
109109

110110
compilertests = ["compiler/inference", "compiler/validation", "compiler/ssair", "compiler/irpasses",
111-
"compiler/codegen", "compiler/inline"]
111+
"compiler/codegen", "compiler/inline", "compiler/contextual"]
112112

113113
if "compiler" in skip_tests
114114
filter!(x -> (x != "compiler" && !(x in compilertests)), tests)

test/compiler/contextual.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
module MiniCassette
2+
# A minimal demonstration of the cassette mechanism. Doesn't support all the
3+
# fancy features, but sufficient to exercise this code path in the compiler.
4+
5+
using Core.Compiler: method_instances, retrieve_code_info, CodeInfo,
6+
MethodInstance, SSAValue, GotoNode, Slot, SlotNumber, quoted,
7+
signature_type
8+
using Base: _methods_by_ftype
9+
using Base.Meta: isexpr
10+
using Test
11+
12+
export Ctx, overdub
13+
14+
struct Ctx; end
15+
16+
# A no-op cassette-like transform
17+
function transform_expr(expr, map_slot_number, map_ssa_value, sparams)
18+
transform(expr) = transform_expr(expr, map_slot_number, map_ssa_value, sparams)
19+
if isexpr(expr, :call)
20+
return Expr(:call, overdub, SlotNumber(2), map(transform, expr.args)...)
21+
elseif isexpr(expr, :gotoifnot)
22+
return Expr(:gotoifnot, transform(expr.args[1]), map_ssa_value(SSAValue(expr.args[2])).id)
23+
elseif isexpr(expr, :static_parameter)
24+
return quoted(sparams[expr.args[1]])
25+
elseif isa(expr, Expr)
26+
return Expr(expr.head, map(transform, expr.args)...)
27+
elseif isa(expr, GotoNode)
28+
return GotoNode(map_ssa_value(SSAValue(expr.label)).id)
29+
elseif isa(expr, Slot)
30+
return map_slot_number(expr.id)
31+
elseif isa(expr, SSAValue)
32+
return map_ssa_value(expr)
33+
else
34+
return expr
35+
end
36+
end
37+
38+
function transform!(ci, nargs, sparams)
39+
code = ci.code
40+
ci.slotnames = Symbol[Symbol("#self#"), :ctx, :f, :args, ci.slotnames[nargs+1:end]...]
41+
ci.slotflags = UInt8[(0x00 for i = 1:4)..., ci.slotflags[nargs+1:end]...]
42+
# Insert one SSAValue for every argument statement
43+
prepend!(code, [Expr(:call, getfield, SlotNumber(4), i) for i = 1:nargs])
44+
prepend!(ci.codelocs, [0 for i = 1:nargs])
45+
ci.ssavaluetypes += nargs
46+
function map_slot_number(slot)
47+
if slot == 1
48+
# self in the original function is now `f`
49+
return SlotNumber(3)
50+
elseif 2 <= slot <= nargs + 1
51+
# Arguments get inserted as ssa values at the top of the function
52+
return SSAValue(slot - 1)
53+
else
54+
# The first non-argument slot will be 5
55+
return SlotNumber(slot - (nargs + 1) + 4)
56+
end
57+
end
58+
map_ssa_value(ssa::SSAValue) = SSAValue(ssa.id + nargs)
59+
for i = (nargs+1:length(code))
60+
code[i] = transform_expr(code[i], map_slot_number, map_ssa_value, sparams)
61+
end
62+
end
63+
64+
function overdub_generator(self, c, f, args)
65+
if f <: Core.Builtin || !isdefined(f, :instance)
66+
return :(return f(args...))
67+
end
68+
69+
tt = Tuple{f, args...}
70+
mthds = _methods_by_ftype(tt, -1, typemax(UInt))
71+
@assert length(mthds) == 1
72+
mtypes, msp, m = mthds[1]
73+
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
74+
# Unsupported in this mini-cassette
75+
@assert !mi.def.isva
76+
code_info = retrieve_code_info(mi)
77+
@assert isa(code_info, CodeInfo)
78+
code_info = copy(code_info)
79+
if isdefined(code_info, :edges)
80+
code_info.edges = MethodInstance[mi]
81+
end
82+
transform!(code_info, length(args), msp)
83+
code_info
84+
end
85+
86+
@eval function overdub(c::Ctx, f, args...)
87+
$(Expr(:meta, :generated_only))
88+
$(Expr(:meta,
89+
:generated,
90+
Expr(:new,
91+
Core.GeneratedFunctionStub,
92+
:overdub_generator,
93+
Any[:overdub, :ctx, :f, :args],
94+
Any[],
95+
@__LINE__,
96+
QuoteNode(Symbol(@__FILE__)),
97+
true)))
98+
end
99+
end
100+
101+
using .MiniCassette
102+
103+
# Test #265 for Cassette
104+
f() = 1
105+
@test overdub(Ctx(), f) === 1
106+
f() = 2
107+
@test overdub(Ctx(), f) === 2
108+
109+
# Test that MiniCassette is at least somewhat capable by overdubbing gcd
110+
@test overdub(Ctx(), gcd, 10, 20) === gcd(10, 20)

0 commit comments

Comments
 (0)