Skip to content

Commit 0aa14c1

Browse files
authored
Add @might_produce macro (#198)
* Add `@might_produce_kwargs` macro * Add docs entry * Format * Expand to cover positional args as well * rename doctest * Format * Use function reference instead of another module * Fix Turing performance tests In v0.40 (or earlier?), TracedModel was moved from essential to src/mcmc/particle_mcmc.jl. * Improve macro docstring * Improve warning
1 parent f154425 commit 0aa14c1

File tree

6 files changed

+122
-7
lines changed

6 files changed

+122
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.9.4"
6+
version = "0.9.5"
77

88
[deps]
99
MistyClosures = "dbe65cb8-6be2-42dd-bbc5-4196aaced4f4"

docs/src/index.md

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,14 @@ They divide neatly into two kinds of functions: those which are used to manipula
1111
[`TapedTask`](@ref)s, and those which are intended to be used _inside_ a
1212
[`TapedTask`](@ref).
1313

14-
## Manipulation of [`TapedTask`](@ref)s:
14+
## Manipulation of [`TapedTask`](@ref)s
1515
```@docs; canonical=true
1616
Libtask.consume
1717
Base.copy(::Libtask.TapedTask)
1818
Libtask.set_taped_globals!
1919
```
2020

21-
## Functions for use inside a [`TapedTask`](@ref)s:
21+
## Functions for use inside a [`TapedTask`](@ref)s
2222
```@docs; canonical=true
2323
Libtask.produce
2424
Libtask.get_taped_globals
@@ -28,4 +28,5 @@ An opt-in mechanism marks functions that might contain `Libtask.produce` stateme
2828

2929
```@docs; canonical=true
3030
Libtask.might_produce(::Type{<:Tuple})
31+
Libtask.@might_produce
3132
```

perf/p0.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ end
1414

1515
# Case 1: Sample from the prior.
1616
rng = MersenneTwister()
17-
m = Turing.Core.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng)
17+
m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), SampleFromPrior(), VarInfo(), rng)
1818
f = m.evaluator[1];
1919
args = m.evaluator[2:end];
2020

@@ -27,7 +27,7 @@ println("Run a tape...")
2727
@btime t.tf(args...)
2828

2929
# Case 2: SMC sampler
30-
m = Turing.Core.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng)
30+
m = Turing.Inference.TracedModel(gdemo(1.5, 2.0), Sampler(SMC(50)), VarInfo(), rng)
3131
f = m.evaluator[1];
3232
args = m.evaluator[2:end];
3333

perf/p2.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ Random.seed!(rng, 2)
5252
iterations = 500
5353
model_fun = infiniteGMM(data)
5454

55-
m = Turing.Core.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
55+
m = Turing.Inference.TracedModel(model_fun, Sampler(SMC(50)), VarInfo(), rng)
5656
f = m.evaluator[1]
5757
args = m.evaluator[2:end]
5858

src/copyable_task.jl

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -354,11 +354,76 @@ end
354354
`true` if a call to method with signature `sig` is permitted to contain
355355
`Libtask.produce` statements.
356356
357-
This is an opt-in mechanism. the fallback method of this function returns `false` indicating
357+
This is an opt-in mechanism. The fallback method of this function returns `false` indicating
358358
that, by default, we assume that calls do not contain `Libtask.produce` statements.
359359
"""
360360
might_produce(::Type{<:Tuple}) = false
361361

362+
"""
363+
@might_produce(f)
364+
365+
If `f` is a function that may call `Libtask.produce` inside it, then `@might_produce(f)`
366+
will generate the appropriate methods needed to ensure that `Libtask.might_produce` returns
367+
`true` for **all** relevant signatures of `f`. This works even if `f` has methods with
368+
keyword arguments.
369+
370+
!!! note
371+
Because `@might_produce f` is applied to all possible signatures, there are performance
372+
penalties associated with marking all methods of `f` as produceable if only one method
373+
can actually call `produce`. If performance is critical, please use the
374+
[`might_produce`](@ref) function directly.
375+
376+
```jldoctest might_produce_macro
377+
julia> # For this demonstration we need to mark `g` as not being inlineable.
378+
@noinline function g(x; y, z=0)
379+
produce(x + y + z)
380+
end
381+
g (generic function with 1 method)
382+
383+
julia> function f()
384+
g(1; y=2, z=3)
385+
end
386+
f (generic function with 1 method)
387+
388+
julia> # This returns nothing because `g` isn't yet marked as being able to `produce`.
389+
consume(Libtask.TapedTask(nothing, f))
390+
391+
julia> Libtask.@might_produce(g)
392+
393+
julia> # Now it works!
394+
consume(Libtask.TapedTask(nothing, f))
395+
6
396+
"""
397+
macro might_produce(f)
398+
# See https://github.com/TuringLang/Libtask.jl/issues/197 for discussion of this macro.
399+
quote
400+
function $(Libtask).might_produce(::Type{<:Tuple{typeof($(esc(f))),Vararg}})
401+
return true
402+
end
403+
possible_n_kwargs = unique(map(length Base.kwarg_decl, methods($(esc(f)))))
404+
if possible_n_kwargs != [0]
405+
# Oddly we need to interpolate the module and not the function: either
406+
# `$(might_produce)` or $(Libtask.might_produce) seem more natural but both of
407+
# those cause the entire `Libtask.might_produce` to be treated as a single
408+
# symbol. See https://discourse.julialang.org/t/128613
409+
function $(Libtask).might_produce(
410+
::Type{<:Tuple{typeof(Core.kwcall),<:NamedTuple,typeof($(esc(f))),Vararg}}
411+
)
412+
return true
413+
end
414+
for n in possible_n_kwargs
415+
# We only need `Any` and not `<:Any` because tuples are covariant.
416+
kwarg_types = fill(Any, n)
417+
function $(Libtask).might_produce(
418+
::Type{<:Tuple{<:Function,kwarg_types...,typeof($(esc(f))),Vararg}}
419+
)
420+
return true
421+
end
422+
end
423+
end
424+
end
425+
end
426+
362427
# Helper struct used in `derive_copyable_task_ir`.
363428
struct TupleRef
364429
n::Int

test/copyable_task.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,4 +251,53 @@
251251
@test Libtask.consume(tt) === :a
252252
@test Libtask.consume(tt) === nothing
253253
end
254+
255+
@testset "@might_produce macro" begin
256+
# Positional arguments only
257+
@noinline g1(x) = produce(x)
258+
f1(x) = g1(x)
259+
# Without marking it as might_produce
260+
tt = Libtask.TapedTask(nothing, f1, 0)
261+
@test Libtask.consume(tt) === nothing
262+
# Now marking it
263+
Libtask.@might_produce(g1)
264+
tt = Libtask.TapedTask(nothing, f1, 0)
265+
@test Libtask.consume(tt) === 0
266+
@test Libtask.consume(tt) === nothing
267+
268+
# Keyword arguments only
269+
@noinline g2(x; y=1, z=2) = produce(x + y + z)
270+
f2(x) = g2(x)
271+
# Without marking it as might_produce
272+
tt = Libtask.TapedTask(nothing, f2, 0)
273+
@test Libtask.consume(tt) === nothing
274+
# Now marking it
275+
Libtask.@might_produce(g2)
276+
tt = Libtask.TapedTask(nothing, f2, 0)
277+
@test Libtask.consume(tt) === 3
278+
@test Libtask.consume(tt) === nothing
279+
280+
# A function with multiple methods.
281+
# The function reference is used to ensure that it really doesn't get inlined
282+
# (otherwise, for reasons that are yet unknown, these functions do get inlined when
283+
# inside a testset)
284+
@noinline g3(x) = produce(x)
285+
@noinline g3(x, y; z) = produce(x + y + z)
286+
@noinline g3(x, y, z; p, q) = produce(x + y + z + p + q)
287+
function f3(x, fref)
288+
fref[](x)
289+
fref[](x, 1; z=2)
290+
fref[](x, 1, 2; p=3, q=4)
291+
return nothing
292+
end
293+
tt = Libtask.TapedTask(nothing, f3, 0, Ref(g3))
294+
@test Libtask.consume(tt) === nothing
295+
# Now marking it
296+
Libtask.@might_produce(g3)
297+
tt = Libtask.TapedTask(nothing, f3, 0, Ref(g3))
298+
@test Libtask.consume(tt) === 0
299+
@test Libtask.consume(tt) === 3
300+
@test Libtask.consume(tt) === 10
301+
@test Libtask.consume(tt) === nothing
302+
end
254303
end

0 commit comments

Comments
 (0)