Skip to content

Commit dffb8cc

Browse files
committed
cfunction macro: extend cfunction capabilities
Provide static support for handling dynamic calls and closures
1 parent f88af9f commit dffb8cc

34 files changed

+1101
-401
lines changed

base/c.jl

Lines changed: 38 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,56 @@ respectively.
1616
"""
1717
cglobal
1818

19+
struct CFunction
20+
ptr::Ptr{Cvoid}
21+
f::Any
22+
_1::Ptr{Cvoid}
23+
_2::Ptr{Cvoid}
24+
let construtor = false end
25+
end
26+
unsafe_convert(::Type{Ptr{Cvoid}}, cf::CFunction) = cf.ptr
27+
1928
"""
20-
cfunction(f::Function, returntype::Type, argtypes::Type) -> Ptr{Cvoid}
29+
@cfunction(callable, ReturnType, (ArgumentTypes...,)) -> Ptr{Cvoid}
30+
@cfunction(\$callable, ReturnType, (ArgumentTypes...,)) -> CFunction
31+
32+
Generate a C-callable function pointer from the Julia function `closure`
33+
for the given type signature.
34+
35+
Note that the argument type tuple must be a literal tuple, and not a tuple-valued variable or expression
36+
(although it can include a splat expression). And that these arguments will be evaluated in global scope
37+
during compile-time (not deferred until runtime).
38+
Adding a `\$` in front of the function argument changes this to instead create a runtime closure
39+
over the local variable `callable`.
2140
22-
Generate C-callable function pointer from the Julia function `f`. Type annotation of the return
23-
value in the callback function is a must for situations where Julia cannot infer the return
24-
type automatically.
41+
See [manual section on ccall and cfunction usage](@ref Calling-C-and-Fortran-Code).
2542
2643
# Examples
2744
```julia-repl
2845
julia> function foo(x::Int, y::Int)
2946
return x + y
3047
end
3148
32-
julia> cfunction(foo, Int, Tuple{Int,Int})
49+
julia> @cfunction(foo, Int, (Int, Int))
3350
Ptr{Cvoid} @0x000000001b82fcd0
3451
```
3552
"""
36-
cfunction(f, r, a) = ccall(:jl_function_ptr, Ptr{Cvoid}, (Any, Any, Any), f, r, a)
53+
macro cfunction(f, at, rt)
54+
if !(isa(rt, Expr) && rt.head === :tuple)
55+
throw(ArgumentError("@cfunction argument types must be a literal tuple"))
56+
end
57+
rt.head = :call
58+
pushfirst!(rt.args, GlobalRef(Core, :svec))
59+
if isa(f, Expr) && f.head === :$
60+
fptr = f.args[1]
61+
typ = CFunction
62+
else
63+
fptr = QuoteNode(f)
64+
typ = Ptr{Cvoid}
65+
end
66+
cfun = Expr(:cfunction, typ, fptr, at, rt, QuoteNode(:ccall))
67+
return esc(cfun)
68+
end
3769

3870
if ccall(:jl_is_char_signed, Ref{Bool}, ())
3971
const Cchar = Int8

base/compiler/abstractinterpretation.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -663,8 +663,8 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt
663663
return abstract_call_gf_by_type(f, argtypes, atype, sv)
664664
end
665665

666-
function abstract_eval_call(e::Expr, vtypes::VarTable, sv::InferenceState)
667-
argtypes = Any[abstract_eval(a, vtypes, sv) for a in e.args]
666+
# wrapper around `abstract_call` for first computing if `f` is available
667+
function abstract_eval_call(fargs::Union{Tuple{},Vector{Any}}, argtypes::Vector{Any}, vtypes::VarTable, sv::InferenceState)
668668
#print("call ", e.args[1], argtypes, "\n\n")
669669
for x in argtypes
670670
x === Bottom && return Bottom
@@ -689,7 +689,7 @@ function abstract_eval_call(e::Expr, vtypes::VarTable, sv::InferenceState)
689689
end
690690
return abstract_call_gf_by_type(nothing, argtypes, argtypes_to_type(argtypes), sv)
691691
end
692-
return abstract_call(f, e.args, argtypes, vtypes, sv)
692+
return abstract_call(f, fargs, argtypes, vtypes, sv)
693693
end
694694

695695
function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
@@ -730,6 +730,18 @@ function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
730730
return T
731731
end
732732

733+
function abstract_eval_cfunction(e::Expr, vtypes::VarTable, sv::InferenceState)
734+
f = abstract_eval(e.args[2], vtypes, sv)
735+
# rt = sp_type_rewrap(e.args[3], sv.linfo, true)
736+
at = Any[ sp_type_rewrap(argt, sv.linfo, false) for argt in e.args[4]::SimpleVector ]
737+
pushfirst!(at, f)
738+
# this may be the wrong world for the call,
739+
# but some of the result is likely to be valid anyways
740+
# and that may help generate better codegen
741+
abstract_eval_call((), at, vtypes, sv)
742+
nothing
743+
end
744+
733745
function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
734746
if isa(e, QuoteNode)
735747
return AbstractEvalConstant((e::QuoteNode).value)
@@ -748,7 +760,8 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
748760
end
749761
e = e::Expr
750762
if e.head === :call
751-
t = abstract_eval_call(e, vtypes, sv)
763+
argtypes = Any[ abstract_eval(a, vtypes, sv) for a in e.args ]
764+
t = abstract_eval_call(e.args, argtypes, vtypes, sv)
752765
elseif e.head === :new
753766
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
754767
for i = 2:length(e.args)
@@ -767,6 +780,10 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
767780
t = Bottom
768781
end
769782
end
783+
elseif e.head === :cfunction
784+
t = e.args[1]
785+
isa(t, Type) || (t = Any)
786+
abstract_eval_cfunction(e, vtypes, sv)
770787
elseif e.head === :static_parameter
771788
n = e.args[1]
772789
t = Any

base/compiler/optimize.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -769,17 +769,26 @@ function substitute!(
769769
head = e.head
770770
if head === :static_parameter
771771
return quoted(spvals[e.args[1]])
772+
elseif head === :cfunction
773+
@assert !isa(spsig, UnionAll) || !isempty(spvals)
774+
if !(e.args[2] isa QuoteNode) # very common no-op
775+
e.args[2] = substitute!(e.args[2], na, argexprs, spsig, spvals, offset, boundscheck)
776+
end
777+
e.args[3] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[3], spsig, spvals)
778+
e.args[4] = svec(Any[
779+
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
780+
for argt
781+
in e.args[4] ]...)
772782
elseif head === :foreigncall
773783
@assert !isa(spsig, UnionAll) || !isempty(spvals)
774784
for i = 1:length(e.args)
775785
if i == 2
776786
e.args[2] = ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), e.args[2], spsig, spvals)
777787
elseif i == 3
778-
argtuple = Any[
788+
e.args[3] = svec(Any[
779789
ccall(:jl_instantiate_type_in_env, Any, (Any, Any, Ptr{Any}), argt, spsig, spvals)
780790
for argt
781-
in e.args[3] ]
782-
e.args[3] = svec(argtuple...)
791+
in e.args[3] ]...)
783792
elseif i == 4
784793
@assert isa((e.args[4]::QuoteNode).value, Symbol)
785794
elseif i == 5

base/compiler/validation.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ const VALID_EXPR_HEADS = IdDict{Any,Any}(
2222
:meta => 0:typemax(Int),
2323
:global => 1:1,
2424
:foreigncall => 3:typemax(Int),
25+
:cfunction => 6:6,
2526
:isdefined => 1:1,
2627
:simdloop => 0:0,
2728
:gc_preserve_begin => 0:typemax(Int),
@@ -139,9 +140,11 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
139140
end
140141
validate_val!(x.args[1])
141142
elseif head === :call || head === :invoke || head == :gc_preserve_end || head === :meta ||
142-
head === :inbounds || head === :foreigncall || head === :const || head === :enter ||
143-
head === :leave || head === :method || head === :global || head === :static_parameter ||
144-
head === :new || head === :thunk || head === :simdloop || head === :throw_undef_if_not || head === :unreachable
143+
head === :inbounds || head === :foreigncall || head === :cfunction ||
144+
head === :const || head === :enter || head === :leave ||
145+
head === :method || head === :global || head === :static_parameter ||
146+
head === :new || head === :thunk || head === :simdloop ||
147+
head === :throw_undef_if_not || head === :unreachable
145148
validate_val!(x)
146149
else
147150
push!(errors, InvalidCodeError("invalid statement", x))
@@ -221,7 +224,7 @@ end
221224

222225
function is_valid_rvalue(lhs, x)
223226
is_valid_argument(x) && return true
224-
if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :gc_preserve_begin)
227+
if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin)
225228
return true
226229
# TODO: disallow `globalref = call` when .typ field is removed
227230
#return isa(lhs, SSAValue) || isa(lhs, Slot)

base/deprecated.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,12 @@ end
495495

496496
# PR #23066
497497
@deprecate cfunction(f, r, a::Tuple) cfunction(f, r, Tuple{a...})
498+
@noinline function cfunction(f, r, a)
499+
@nospecialize(f, r, a)
500+
depwarn("The function `cfunction` is now written as a macro `@cfunction`.", :cfunction)
501+
return ccall(:jl_function_ptr, Ptr{Cvoid}, (Any, Any, Any), f, r, a)
502+
end
503+
export cfunction
498504

499505
# PR 23341
500506
@eval GMP @deprecate gmp_version() version() false

base/exports.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -924,7 +924,7 @@ export
924924
withenv,
925925

926926
# C interface
927-
cfunction,
927+
@cfunction,
928928
cglobal,
929929
disable_sigint,
930930
pointer,

base/libuv.jl

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,13 +85,20 @@ function process_events(block::Bool)
8585
end
8686
end
8787

88+
function uv_alloc_buf end
89+
function uv_readcb end
90+
function uv_writecb_task end
91+
function uv_return_spawn end
92+
function uv_asynccb end
93+
function uv_timercb end
94+
8895
function reinit_stdio()
89-
global uv_jl_alloc_buf = cfunction(uv_alloc_buf, Cvoid, Tuple{Ptr{Cvoid}, Csize_t, Ptr{Cvoid}})
90-
global uv_jl_readcb = cfunction(uv_readcb, Cvoid, Tuple{Ptr{Cvoid}, Cssize_t, Ptr{Cvoid}})
91-
global uv_jl_writecb_task = cfunction(uv_writecb_task, Cvoid, Tuple{Ptr{Cvoid}, Cint})
92-
global uv_jl_return_spawn = cfunction(uv_return_spawn, Cvoid, Tuple{Ptr{Cvoid}, Int64, Int32})
93-
global uv_jl_asynccb = cfunction(uv_asynccb, Cvoid, Tuple{Ptr{Cvoid}})
94-
global uv_jl_timercb = cfunction(uv_timercb, Cvoid, Tuple{Ptr{Cvoid}})
96+
global uv_jl_alloc_buf = @cfunction(uv_alloc_buf, Cvoid, (Ptr{Cvoid}, Csize_t, Ptr{Cvoid}))
97+
global uv_jl_readcb = @cfunction(uv_readcb, Cvoid, (Ptr{Cvoid}, Cssize_t, Ptr{Cvoid}))
98+
global uv_jl_writecb_task = @cfunction(uv_writecb_task, Cvoid, (Ptr{Cvoid}, Cint))
99+
global uv_jl_return_spawn = @cfunction(uv_return_spawn, Cvoid, (Ptr{Cvoid}, Int64, Int32))
100+
global uv_jl_asynccb = @cfunction(uv_asynccb, Cvoid, (Ptr{Cvoid},))
101+
global uv_jl_timercb = @cfunction(uv_timercb, Cvoid, (Ptr{Cvoid},))
95102

96103
global uv_eventloop = ccall(:jl_global_event_loop, Ptr{Cvoid}, ())
97104
global stdin = init_stdio(ccall(:jl_stdin_stream, Ptr{Cvoid}, ()))

base/threadcall.jl

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,6 @@ const max_ccall_threads = parse(Int, get(ENV, "UV_THREADPOOL_SIZE", "4"))
44
const thread_notifiers = Union{Condition, Nothing}[nothing for i in 1:max_ccall_threads]
55
const threadcall_restrictor = Semaphore(max_ccall_threads)
66

7-
function notify_fun(idx)
8-
global thread_notifiers
9-
notify(thread_notifiers[idx])
10-
return
11-
end
12-
137
"""
148
@threadcall((cfunc, clib), rettype, (argtypes...), argvals...)
159
@@ -36,62 +30,71 @@ macro threadcall(f, rettype, argtypes, argvals...)
3630
argvals = map(esc, argvals)
3731

3832
# construct non-allocating wrapper to call C function
39-
wrapper = :(function wrapper(args_ptr::Ptr{Cvoid}, retval_ptr::Ptr{Cvoid})
33+
wrapper = :(function (args_ptr::Ptr{Cvoid}, retval_ptr::Ptr{Cvoid})
4034
p = args_ptr
35+
# the rest of the body is created below
4136
end)
4237
body = wrapper.args[2].args
4338
args = Symbol[]
44-
for (i,T) in enumerate(argtypes)
39+
for (i, T) in enumerate(argtypes)
4540
arg = Symbol("arg", i)
4641
push!(body, :($arg = unsafe_load(convert(Ptr{$T}, p))))
47-
push!(body, :(p += sizeof($T)))
42+
push!(body, :(p += Core.sizeof($T)))
4843
push!(args, arg)
4944
end
5045
push!(body, :(ret = ccall($f, $rettype, ($(argtypes...),), $(args...))))
5146
push!(body, :(unsafe_store!(convert(Ptr{$rettype}, retval_ptr), ret)))
52-
push!(body, :(return sizeof($rettype)))
47+
push!(body, :(return Int(Core.sizeof($rettype))))
5348

5449
# return code to generate wrapper function and send work request thread queue
55-
:(let
56-
$wrapper
57-
do_threadcall(wrapper, $rettype, Any[$(argtypes...)], Any[$(argvals...)])
50+
wrapper = Expr(Symbol("hygienic-scope"), wrapper, @__MODULE__)
51+
return :(let fun_ptr = @cfunction($wrapper, Int, (Ptr{Cvoid}, Ptr{Cvoid}))
52+
do_threadcall(fun_ptr, $rettype, Any[$(argtypes...)], Any[$(argvals...)])
5853
end)
5954
end
6055

61-
function do_threadcall(wrapper::Function, rettype::Type, argtypes::Vector, argvals::Vector)
56+
function do_threadcall(fun_ptr::Ptr{Cvoid}, rettype::Type, argtypes::Vector, argvals::Vector)
6257
# generate function pointer
63-
fun_ptr = cfunction(wrapper, Int, Tuple{Ptr{Cvoid}, Ptr{Cvoid}})
64-
c_notify_fun = cfunction(notify_fun, Cvoid, Tuple{Cint})
58+
c_notify_fun = @cfunction(
59+
function notify_fun(idx)
60+
global thread_notifiers
61+
notify(thread_notifiers[idx])
62+
return
63+
end, Cvoid, (Cint,))
6564

6665
# cconvert, root and unsafe_convert arguments
6766
roots = Any[]
68-
args_size = isempty(argtypes) ? 0 : sum(sizeof, argtypes)
67+
args_size = isempty(argtypes) ? 0 : sum(Core.sizeof, argtypes)
6968
args_arr = Vector{UInt8}(undef, args_size)
7069
ptr = pointer(args_arr)
7170
for (T, x) in zip(argtypes, argvals)
71+
isbits(T) || throw(ArgumentError("threadcall requires isbits argument types"))
7272
y = cconvert(T, x)
7373
push!(roots, y)
74-
unsafe_store!(convert(Ptr{T}, ptr), unsafe_convert(T, y))
75-
ptr += sizeof(T)
74+
unsafe_store!(convert(Ptr{T}, ptr), unsafe_convert(T, y)::T)
75+
ptr += Core.sizeof(T)
7676
end
7777

7878
# create return buffer
79-
ret_arr = Vector{UInt8}(undef, sizeof(rettype))
79+
ret_arr = Vector{UInt8}(undef, Core.sizeof(rettype))
8080

8181
# wait for a worker thread to be available
8282
acquire(threadcall_restrictor)
8383
idx = findfirst(isequal(nothing), thread_notifiers)::Int
8484
thread_notifiers[idx] = Condition()
8585

86-
# queue up the work to be done
87-
ccall(:jl_queue_work, Cvoid,
88-
(Ptr{Cvoid}, Ptr{UInt8}, Ptr{UInt8}, Ptr{Cvoid}, Cint),
89-
fun_ptr, args_arr, ret_arr, c_notify_fun, idx)
86+
GC.@preserve args_arr ret_arr roots begin
87+
# queue up the work to be done
88+
ccall(:jl_queue_work, Cvoid,
89+
(Ptr{Cvoid}, Ptr{UInt8}, Ptr{UInt8}, Ptr{Cvoid}, Cint),
90+
fun_ptr, args_arr, ret_arr, c_notify_fun, idx)
9091

91-
# wait for a result & return it
92-
wait(thread_notifiers[idx])
93-
thread_notifiers[idx] = nothing
94-
release(threadcall_restrictor)
92+
# wait for a result & return it
93+
wait(thread_notifiers[idx])
94+
thread_notifiers[idx] = nothing
95+
release(threadcall_restrictor)
9596

96-
unsafe_load(convert(Ptr{rettype}, pointer(ret_arr)))
97+
r = unsafe_load(convert(Ptr{rettype}, pointer(ret_arr)))
98+
end
99+
return r
97100
end

doc/src/base/c.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
```@docs
44
ccall
55
Core.Intrinsics.cglobal
6-
Base.cfunction
6+
Base.@cfunction
77
Base.unsafe_convert
88
Base.cconvert
99
Base.unsafe_load

0 commit comments

Comments
 (0)