Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RuntimeGeneratedFunctions"
uuid = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
authors = ["Chris Rackauckas <[email protected]> and contributors"]
version = "0.3.2"
version = "0.4.0"

[deps]
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expand Down
39 changes: 36 additions & 3 deletions src/RuntimeGeneratedFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,13 @@ then calling the resulting function. The differences are:
* The result is not a named generic function, and doesn't participate in
generic function dispatch; it's more like a callable method.

You need to use `RuntimeGeneratedFunctions.init(your_module)` a single time at
the top level of `your_module` before any other uses of the macro.

# Examples
```
RuntimeGeneratedFunctions.init(@__MODULE__) # Required at module top-level

function foo()
expression = :((x,y)->x+y+1) # May be generated dynamically
f = @RuntimeGeneratedFunction(expression)
Expand All @@ -42,8 +47,11 @@ end
```
"""
macro RuntimeGeneratedFunction(ex)
_ensure_cache_exists!(__module__)
quote
if !($(esc(:(@isdefined($_tagname)))))
error("""You must use `RuntimeGeneratedFunctions.init(@__MODULE__)` at module
top level before using runtime generated functions""")
end
RuntimeGeneratedFunction(
$(esc(_tagname)),
$(esc(ex))
Expand All @@ -59,7 +67,11 @@ end

(f::RuntimeGeneratedFunction)(args::Vararg{Any,N}) where N = generated_callfunc(f, args...)

@inline @generated function generated_callfunc(f::RuntimeGeneratedFunction{moduletag, id, argnames}, __args...) where {moduletag,id,argnames}
# We'll generate a method of this function in every module which wants to use
# @RuntimeGeneratedFunction
function generated_callfunc end

function generated_callfunc_body(moduletag, id, argnames, __args)
setup = (:($(argnames[i]) = @inbounds __args[$i]) for i in 1:length(argnames))
body = _lookup_body(moduletag, id)
@assert body !== nothing
Expand Down Expand Up @@ -122,13 +134,34 @@ function _lookup_body(moduletag, id)
end
end

function _ensure_cache_exists!(mod)
"""
RuntimeGeneratedFunctions.init(mod)

Use this at top level to set up your module `mod` before using
`@RuntimeGeneratedFunction`.
"""
function init(mod)
lock(_cache_lock) do
if !isdefined(mod, _cachename)
mod.eval(quote
const $_cachename = Dict()
struct $_tagname
end

# We create method of `generated_callfunc` in the user's module
# so that any global symbols within the body will be looked up
# in the user's module scope.
#
# This is straightforward but clunky. A neater solution should
# be to explicitly expand in the user's module and return a
# CodeInfo from `generated_callfunc`, but it seems we'd need
# `jl_expand_and_resolve` which doesn't exist until Julia 1.3
# or so. See:
# https://github.com/JuliaLang/julia/pull/32902
# https://github.com/NHDaly/StagedFunctions.jl/blob/master/src/StagedFunctions.jl#L30
@inline @generated function $RuntimeGeneratedFunctions.generated_callfunc(f::$RuntimeGeneratedFunctions.RuntimeGeneratedFunction{$_tagname, id, argnames}, __args...) where {id,argnames}
$RuntimeGeneratedFunctions.generated_callfunc_body($_tagname, id, argnames, __args)
end
end)
end
end
Expand Down
1 change: 1 addition & 0 deletions test/precomp/RGFPrecompTest.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module RGFPrecompTest
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

f = @RuntimeGeneratedFunction(:((x,y)->x+y))
end
20 changes: 20 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
using RuntimeGeneratedFunctions, BenchmarkTools
using Test

RuntimeGeneratedFunctions.init(@__MODULE__)

function f(_du,_u,_p,_t)
@inbounds _du[1] = _u[1]
@inbounds _du[2] = _u[2]
Expand Down Expand Up @@ -107,3 +109,21 @@ for k=1:4
end
@test all(all.(fetch.(tasks)))


# Test that globals are resolved within the correct scope

module GlobalsTest
using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

y = 40
f = @RuntimeGeneratedFunction(:(x->x+y))
end

@test GlobalsTest.f(2) == 42

@test_throws ErrorException @eval(module NotInitTest
using RuntimeGeneratedFunctions
# RuntimeGeneratedFunctions.init(@__MODULE__) # <-- missing
f = @RuntimeGeneratedFunction(:(x->x+y))
end)