Skip to content

Commit 851d9c0

Browse files
committed
Clean up test case
1 parent 647e06b commit 851d9c0

File tree

1 file changed

+94
-69
lines changed

1 file changed

+94
-69
lines changed

test/compiler/contextual.jl

Lines changed: 94 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,110 @@
1-
using Core.Compiler: method_instances, retrieve_code_info, CodeInfo,
2-
MethodInstance, SSAValue, GotoNode, Slot, SlotNumber
3-
using Base.Meta: isexpr
4-
using Test
1+
module MiniCassette
2+
# A minimal demonstration of the cassette mechanism. Doesn't support all the
3+
# fancy features, but sufficient to exercise this code path in the compiler.
54

6-
struct Ctx; end
5+
using Core.Compiler: method_instances, retrieve_code_info, CodeInfo,
6+
MethodInstance, SSAValue, GotoNode, Slot, SlotNumber, quoted,
7+
signature_type
8+
using Base: _methods_by_ftype
9+
using Base.Meta: isexpr
10+
using Test
711

8-
# A no-op cassette-like transform
9-
function transform_expr(expr, map_slot_number, map_ssa_value)
10-
transform(expr) = transform_expr(expr, map_slot_number, map_ssa_value)
11-
if isexpr(expr, :call)
12-
return Expr(:call, overdub, SlotNumber(1), map(transform, expr.args)...)
13-
elseif isexpr(expr, :gotoifnot)
14-
return Expr(:gotoifnot, transform(expr), map_ssa_value(SSAValue(expr.args[2])).id)
15-
elseif isa(expr, GotoNode)
16-
return GotoNode(map_ssa_value(SSAValue(expr.label)).id)
17-
elseif isa(expr, Slot)
18-
return map_slot_number(expr.id)
19-
elseif isa(expr, SSAValue)
20-
return map_ssa_value(expr)
21-
else
22-
return expr
23-
end
24-
end
12+
export Ctx, overdub
2513

26-
function transform!(ci, nargs)
27-
code = ci.code
28-
ci.slotnames = Symbol[Symbol("#self#"), :ctx, :f, :args, ci.slotnames[nargs+1:end]...]
29-
ci.slotflags = UInt8[(0x00 for i = 1:4)..., ci.slotflags[nargs+1:end]...]
30-
# Insert one SSAValue for every argument statement
31-
for i = 1:nargs
32-
pushfirst!(code, Expr(:getfield, SlotNumber(3), i))
33-
end
34-
function map_slot_number(slot)
35-
if slot == 1
36-
# self in the original function is now `f`
37-
return SlotNumber(3)
38-
elseif 2 <= slot <= nargs + 1
39-
# Arguments get inserted as ssa values at the top of the function
40-
return SSAValue(slot - 2)
14+
struct Ctx; end
15+
16+
# A no-op cassette-like transform
17+
function transform_expr(expr, map_slot_number, map_ssa_value, sparams)
18+
transform(expr) = transform_expr(expr, map_slot_number, map_ssa_value, sparams)
19+
if isexpr(expr, :call)
20+
return Expr(:call, overdub, SlotNumber(2), map(transform, expr.args)...)
21+
elseif isexpr(expr, :gotoifnot)
22+
return Expr(:gotoifnot, transform(expr.args[1]), map_ssa_value(SSAValue(expr.args[2])).id)
23+
elseif isexpr(expr, :static_parameter)
24+
return quoted(sparams[expr.args[1]])
25+
elseif isa(expr, Expr)
26+
return Expr(expr.head, map(transform, expr.args)...)
27+
elseif isa(expr, GotoNode)
28+
return GotoNode(map_ssa_value(SSAValue(expr.label)).id)
29+
elseif isa(expr, Slot)
30+
return map_slot_number(expr.id)
31+
elseif isa(expr, SSAValue)
32+
return map_ssa_value(expr)
4133
else
42-
# The first non-argument slot will be 5
43-
return SlotNumber(slot - (nargs + 1) + 4)
34+
return expr
4435
end
4536
end
46-
map_ssa_value(ssa::SSAValue) = SSAValue(ssa.id + nargs)
47-
for i = (nargs+1:length(code))
48-
code[i] = transform_expr(code[i], map_slot_number, map_ssa_value)
37+
38+
function transform!(ci, nargs, sparams)
39+
code = ci.code
40+
ci.slotnames = Symbol[Symbol("#self#"), :ctx, :f, :args, ci.slotnames[nargs+1:end]...]
41+
ci.slotflags = UInt8[(0x00 for i = 1:4)..., ci.slotflags[nargs+1:end]...]
42+
# Insert one SSAValue for every argument statement
43+
prepend!(code, [Expr(:call, getfield, SlotNumber(4), i) for i = 1:nargs])
44+
prepend!(ci.codelocs, [0 for i = 1:nargs])
45+
ci.ssavaluetypes += nargs
46+
function map_slot_number(slot)
47+
if slot == 1
48+
# self in the original function is now `f`
49+
return SlotNumber(3)
50+
elseif 2 <= slot <= nargs + 1
51+
# Arguments get inserted as ssa values at the top of the function
52+
return SSAValue(slot - 1)
53+
else
54+
# The first non-argument slot will be 5
55+
return SlotNumber(slot - (nargs + 1) + 4)
56+
end
57+
end
58+
map_ssa_value(ssa::SSAValue) = SSAValue(ssa.id + nargs)
59+
for i = (nargs+1:length(code))
60+
code[i] = transform_expr(code[i], map_slot_number, map_ssa_value, sparams)
61+
end
4962
end
50-
end
5163

52-
function overdub_generator(self, c, f, args)
53-
mis = method_instances(f.instance, args)
54-
@assert length(mis) == 1
55-
mi = mis[1]
56-
# Unsupported in this mini-cassette
57-
@assert !mi.def.isva
58-
code_info = retrieve_code_info(mi)
59-
@assert isa(code_info, CodeInfo)
60-
code_info = copy(code_info)
61-
if isdefined(code_info, :edges)
62-
code_info.edges = MethodInstance[mi]
64+
function overdub_generator(self, c, f, args)
65+
if f <: Core.Builtin || !isdefined(f, :instance)
66+
return :(return f(args...))
67+
end
68+
69+
tt = Tuple{f, args...}
70+
mthds = _methods_by_ftype(tt, -1, typemax(UInt))
71+
@assert length(mthds) == 1
72+
mtypes, msp, m = mthds[1]
73+
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance}, (Any, Any, Any), m, mtypes, msp)
74+
# Unsupported in this mini-cassette
75+
@assert !mi.def.isva
76+
code_info = retrieve_code_info(mi)
77+
@assert isa(code_info, CodeInfo)
78+
code_info = copy(code_info)
79+
if isdefined(code_info, :edges)
80+
code_info.edges = MethodInstance[mi]
81+
end
82+
transform!(code_info, length(args), msp)
83+
code_info
6384
end
64-
transform!(code_info, length(args))
65-
code_info
66-
end
6785

68-
@eval function overdub(c::Ctx, f, args...)
69-
$(Expr(:meta, :generated_only))
70-
$(Expr(:meta,
71-
:generated,
72-
Expr(:new,
73-
Core.GeneratedFunctionStub,
74-
:overdub_generator,
75-
Any[:overdub, :ctx, :f, :args],
76-
Any[],
77-
@__LINE__,
78-
QuoteNode(Symbol(@__FILE__)),
79-
true)))
86+
@eval function overdub(c::Ctx, f, args...)
87+
$(Expr(:meta, :generated_only))
88+
$(Expr(:meta,
89+
:generated,
90+
Expr(:new,
91+
Core.GeneratedFunctionStub,
92+
:overdub_generator,
93+
Any[:overdub, :ctx, :f, :args],
94+
Any[],
95+
@__LINE__,
96+
QuoteNode(Symbol(@__FILE__)),
97+
true)))
98+
end
8099
end
81100

101+
using .MiniCassette
102+
103+
# Test #265 for Cassette
82104
f() = 1
83105
@test overdub(Ctx(), f) === 1
84106
f() = 2
85107
@test overdub(Ctx(), f) === 2
108+
109+
# Test that MiniCassette is at least somewhat capable by overdubbing gcd
110+
@test overdub(Ctx(), gcd, 10, 20) === gcd(10, 20)

0 commit comments

Comments
 (0)