Skip to content

Commit 2049baa

Browse files
add entry point to construct an OpaqueClosure from pre-optimized IRCode (#44197)
* add entry point to construct an OpaqueClosure from pre-optimized IRCode * update `jl_new_codeinst` signature * fixes to OpaqueClosure argument count handling and MethodError display * more test coverage Co-authored-by: Shuhei Kadowaki <[email protected]>
1 parent 9320fba commit 2049baa

File tree

5 files changed

+134
-24
lines changed

5 files changed

+134
-24
lines changed

base/errorshow.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,11 @@ function show_method_candidates(io::IO, ex::MethodError, @nospecialize kwargs=()
409409
buf = IOBuffer()
410410
iob0 = iob = IOContext(buf, io)
411411
tv = Any[]
412-
sig0 = method.sig
412+
if func isa Core.OpaqueClosure
413+
sig0 = signature_type(func, typeof(func).parameters[1])
414+
else
415+
sig0 = method.sig
416+
end
413417
while isa(sig0, UnionAll)
414418
push!(tv, sig0.var)
415419
iob = IOContext(iob, :unionall_env => sig0.var)

base/methodshow.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ end
7979

8080
# NOTE: second argument is deprecated and is no longer used
8181
function kwarg_decl(m::Method, kwtype = nothing)
82+
if m.sig === Tuple # OpaqueClosure
83+
return Symbol[]
84+
end
8285
mt = get_methodtable(m)
8386
if isdefined(mt, :kwsorter)
8487
kwtype = typeof(mt.kwsorter)

src/method.c

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ extern jl_value_t *jl_builtin_getfield;
1919
extern jl_value_t *jl_builtin_tuple;
2020

2121
jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
22-
jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);
22+
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);
2323

2424
static void check_c_types(const char *where, jl_value_t *rt, jl_value_t *at)
2525
{
@@ -51,11 +51,14 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve
5151
return jl_module_globalref(module, (jl_sym_t*)expr);
5252
}
5353
else if (jl_is_returnnode(expr)) {
54-
jl_value_t *val = resolve_globals(jl_returnnode_value(expr), module, sparam_vals, binding_effects, eager_resolve);
55-
if (val != jl_returnnode_value(expr)) {
56-
JL_GC_PUSH1(&val);
57-
expr = jl_new_struct(jl_returnnode_type, val);
58-
JL_GC_POP();
54+
jl_value_t *retval = jl_returnnode_value(expr);
55+
if (retval) {
56+
jl_value_t *val = resolve_globals(retval, module, sparam_vals, binding_effects, eager_resolve);
57+
if (val != retval) {
58+
JL_GC_PUSH1(&val);
59+
expr = jl_new_struct(jl_returnnode_type, val);
60+
JL_GC_POP();
61+
}
5962
}
6063
return expr;
6164
}
@@ -102,7 +105,7 @@ static jl_value_t *resolve_globals(jl_value_t *expr, jl_module_t *module, jl_sve
102105
if (!jl_is_code_info(ci)) {
103106
jl_error("opaque_closure_method: lambda should be a CodeInfo");
104107
}
105-
jl_method_t *m = jl_make_opaque_closure_method(module, name, nargs, functionloc, (jl_code_info_t*)ci, isva);
108+
jl_method_t *m = jl_make_opaque_closure_method(module, name, jl_unbox_long(nargs), functionloc, (jl_code_info_t*)ci, isva);
106109
return (jl_value_t*)m;
107110
}
108111
if (e->head == jl_cfunction_sym) {
@@ -782,7 +785,7 @@ JL_DLLEXPORT jl_method_t *jl_new_method_uninit(jl_module_t *module)
782785
// method definition ----------------------------------------------------------
783786

784787
jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
785-
jl_value_t *nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva)
788+
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva)
786789
{
787790
jl_method_t *m = jl_new_method_uninit(module);
788791
JL_GC_PUSH1(&m);
@@ -796,7 +799,7 @@ jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name
796799
assert(jl_is_symbol(name));
797800
m->name = (jl_sym_t*)name;
798801
}
799-
m->nargs = jl_unbox_long(nargs) + 1;
802+
m->nargs = nargs + 1;
800803
assert(jl_is_linenode(functionloc));
801804
jl_value_t *file = jl_linenode_file(functionloc);
802805
m->file = jl_is_symbol(file) ? (jl_sym_t*)file : jl_empty_sym;

src/opaque_closure.c

Lines changed: 68 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,23 @@ JL_DLLEXPORT int jl_is_valid_oc_argtype(jl_tupletype_t *argt, jl_method_t *sourc
2222
return 1;
2323
}
2424

25-
jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
26-
jl_value_t *source_, jl_value_t **env, size_t nenv)
25+
static jl_value_t *prepend_type(jl_value_t *t0, jl_tupletype_t *t)
26+
{
27+
jl_svec_t *sig_args = NULL;
28+
JL_GC_PUSH1(&sig_args);
29+
size_t nsig = 1 + jl_svec_len(t->parameters);
30+
sig_args = jl_alloc_svec_uninit(nsig);
31+
jl_svecset(sig_args, 0, t0);
32+
for (size_t i = 0; i < nsig-1; ++i) {
33+
jl_svecset(sig_args, 1+i, jl_tparam(t, i));
34+
}
35+
jl_value_t *sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
36+
JL_GC_POP();
37+
return sigtype;
38+
}
39+
40+
static jl_opaque_closure_t *new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
41+
jl_value_t *source_, jl_value_t *captures)
2742
{
2843
if (!jl_is_tuple_type((jl_value_t*)argt)) {
2944
jl_error("OpaqueClosure argument tuple must be a tuple type");
@@ -40,26 +55,19 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_
4055
}
4156
if (jl_nparams(argt) + 1 - jl_is_va_tuple(argt) < source->nargs - source->isva)
4257
jl_error("Argument type tuple has too few required arguments for method");
43-
jl_task_t *ct = jl_current_task;
58+
jl_value_t *sigtype = NULL;
59+
JL_GC_PUSH1(&sigtype);
60+
sigtype = prepend_type(jl_typeof(captures), argt);
61+
4462
jl_value_t *oc_type JL_ALWAYS_LEAFTYPE;
4563
oc_type = jl_apply_type2((jl_value_t*)jl_opaque_closure_type, (jl_value_t*)argt, rt_ub);
4664
JL_GC_PROMISE_ROOTED(oc_type);
47-
jl_value_t *captures = NULL, *sigtype = NULL;
48-
jl_svec_t *sig_args = NULL;
49-
JL_GC_PUSH3(&captures, &sigtype, &sig_args);
50-
captures = jl_f_tuple(NULL, env, nenv);
5165

52-
size_t nsig = 1 + jl_svec_len(argt->parameters);
53-
sig_args = jl_alloc_svec_uninit(nsig);
54-
jl_svecset(sig_args, 0, jl_typeof(captures));
55-
for (size_t i = 0; i < nsig-1; ++i) {
56-
jl_svecset(sig_args, 1+i, jl_tparam(argt, i));
57-
}
58-
sigtype = (jl_value_t*)jl_apply_tuple_type_v(jl_svec_data(sig_args), nsig);
5966
jl_method_instance_t *mi = jl_specializations_get_linfo(source, sigtype, jl_emptysvec);
6067
size_t world = jl_atomic_load_acquire(&jl_world_counter);
6168
jl_code_instance_t *ci = jl_compile_method_internal(mi, world);
6269

70+
jl_task_t *ct = jl_current_task;
6371
jl_opaque_closure_t *oc = (jl_opaque_closure_t*)jl_gc_alloc(ct->ptls, sizeof(jl_opaque_closure_t), oc_type);
6472
JL_GC_POP();
6573
oc->source = source;
@@ -82,6 +90,52 @@ jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_
8290
return oc;
8391
}
8492

93+
jl_opaque_closure_t *jl_new_opaque_closure(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
94+
jl_value_t *source_, jl_value_t **env, size_t nenv)
95+
{
96+
jl_value_t *captures = jl_f_tuple(NULL, env, nenv);
97+
JL_GC_PUSH1(&captures);
98+
jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, source_, captures);
99+
JL_GC_POP();
100+
return oc;
101+
}
102+
103+
jl_method_t *jl_make_opaque_closure_method(jl_module_t *module, jl_value_t *name,
104+
int nargs, jl_value_t *functionloc, jl_code_info_t *ci, int isva);
105+
106+
JL_DLLEXPORT jl_code_instance_t* jl_new_codeinst(
107+
jl_method_instance_t *mi, jl_value_t *rettype,
108+
jl_value_t *inferred_const, jl_value_t *inferred,
109+
int32_t const_flags, size_t min_world, size_t max_world,
110+
uint32_t ipo_effects, uint32_t effects, jl_value_t *argescapes,
111+
uint8_t relocatability);
112+
113+
JL_DLLEXPORT void jl_mi_cache_insert(jl_method_instance_t *mi JL_ROOTING_ARGUMENT,
114+
jl_code_instance_t *ci JL_ROOTED_ARGUMENT JL_MAYBE_UNROOTED);
115+
116+
JL_DLLEXPORT jl_opaque_closure_t *jl_new_opaque_closure_from_code_info(jl_tupletype_t *argt, jl_value_t *rt_lb, jl_value_t *rt_ub,
117+
jl_module_t *mod, jl_code_info_t *ci, int lineno, jl_value_t *file, int nargs, int isva, jl_value_t *env)
118+
{
119+
if (!ci->inferred)
120+
jl_error("CodeInfo must already be inferred");
121+
jl_value_t *root = NULL, *sigtype = NULL;
122+
jl_code_instance_t *inst = NULL;
123+
JL_GC_PUSH3(&root, &sigtype, &inst);
124+
root = jl_box_long(lineno);
125+
root = jl_new_struct(jl_linenumbernode_type, root, file);
126+
root = (jl_value_t*)jl_make_opaque_closure_method(mod, jl_nothing, nargs, root, ci, isva);
127+
128+
sigtype = prepend_type(jl_typeof(env), argt);
129+
jl_method_instance_t *mi = jl_specializations_get_linfo((jl_method_t*)root, sigtype, jl_emptysvec);
130+
inst = jl_new_codeinst(mi, rt_ub, NULL, (jl_value_t*)ci,
131+
0, ((jl_method_t*)root)->primary_world, -1, 0, 0, jl_nothing, 0);
132+
jl_mi_cache_insert(mi, inst);
133+
134+
jl_opaque_closure_t *oc = new_opaque_closure(argt, rt_lb, rt_ub, root, env);
135+
JL_GC_POP();
136+
return oc;
137+
}
138+
85139
JL_CALLABLE(jl_new_opaque_closure_jlcall)
86140
{
87141
if (nargs < 4)

test/opaque_closure.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,3 +239,49 @@ end
239239
let oc = @opaque a->sin(a)
240240
@test length(code_typed(oc, (Int,))) == 1
241241
end
242+
243+
# constructing an opaque closure from IRCode
244+
using Core.Compiler: IRCode
245+
using Core: CodeInfo
246+
247+
function OC(ir::IRCode, nargs::Int, isva::Bool, env...)
248+
if (isva && nargs > length(ir.argtypes)) || (!isva && nargs != length(ir.argtypes)-1)
249+
throw(ArgumentError("invalid argument count"))
250+
end
251+
src = ccall(:jl_new_code_info_uninit, Ref{CodeInfo}, ())
252+
src.slotflags = UInt8[]
253+
src.slotnames = fill(:none, nargs+1)
254+
Core.Compiler.replace_code_newstyle!(src, ir, nargs+1)
255+
Core.Compiler.widen_all_consts!(src)
256+
src.inferred = true
257+
# NOTE: we need ir.argtypes[1] == typeof(env)
258+
259+
ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
260+
Tuple{ir.argtypes[2:end]...}, Union{}, Any, @__MODULE__, src, 0, nothing, nargs, isva, env)
261+
end
262+
263+
function OC(src::CodeInfo, env...)
264+
M = src.parent.def
265+
sig = Base.tuple_type_tail(src.parent.specTypes)
266+
267+
ccall(:jl_new_opaque_closure_from_code_info, Any, (Any, Any, Any, Any, Any, Cint, Any, Cint, Cint, Any),
268+
sig, Union{}, Any, @__MODULE__, src, 0, nothing, M.nargs - 1, M.isva, env)
269+
end
270+
271+
let ci = code_typed(+, (Int, Int))[1][1]
272+
ir = Core.Compiler.inflate_ir(ci)
273+
@test OC(ir, 2, false)(40, 2) == 42
274+
@test OC(ci)(40, 2) == 42
275+
end
276+
277+
let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
278+
ir = Core.Compiler.inflate_ir(ci)
279+
@test OC(ir, 2, true)(40, 2) === (40, (2,))
280+
@test OC(ci)(40, 2) === (40, (2,))
281+
end
282+
283+
let ci = code_typed((x, y...)->(x, y), (Int, Int))[1][1]
284+
ir = Core.Compiler.inflate_ir(ci)
285+
@test_throws MethodError OC(ir, 2, true)(1, 2, 3)
286+
@test_throws MethodError OC(ci)(1, 2, 3)
287+
end

0 commit comments

Comments
 (0)