Description
We're duplicating a lot of code and a lot of effort by having a bunch of sampler (or rather, `InferenceAlgorithm`) implementations in Turing.jl itself.
There are a few reasons why this is / was the case:
1. The old approach to Gibbs sampling required hooking into the `assume` and `observe` statements for samplers and mutating the varinfo in a particular way, even if the functionality of the sampler itself (when used outside of Gibbs) didn't require it.
2. The samplers in Turing.jl would often offer more convenient constructors, while the sampler packages themselves, e.g. AdvancedHMC.jl, would offer a more flexible but also more complicated interface.
3. `InferenceAlgorithm` allows us to overload the `sample` call explicitly to do some "non-standard" things, e.g. use `chain_type=MCMCChains.Chains` as the default instead of `chain_type=Vector`, which is the default in AbstractMCMC.jl.
Everything but (3) is "easily" addressable (i.e. it only requires dev time, not necessarily any discussion of how to do it):
- (1) is being addressed in "Replace old Gibbs sampler with the experimental one" #2328 (issue ref: "Remove old Gibbs sampler, make the experimental one the default" #2318). This should therefore be resolved very soon.
- (2) should be addressed by simply moving any convenience constructors from Turing.jl itself into the respective package. There's no reason why we should keep convenience constructors in a different package (Turing.jl in this case) than the package implementing the samplers. Effort has been made towards this, e.g. "Convinience constructors" AdvancedHMC.jl#325, but we need to go through all the samplers and check which are missing "convenience" constructors. Related issues that should be addressed in downstream packages: "Improve documentation" AdvancedMH.jl#107, "Can we remove `DensityModel`?" AdvancedMH.jl#108.
- (3) is somewhat tricky. There are a few aspects we need to handle: a) how to default to `chain_type=Chains` for Turing.jl models (ref: "Remove overly specialized bundle_samples" AbstractMCMC.jl#120, "Default `bundle_samples` is quite annoying" AbstractMCMC.jl#118), b) how to allow extraction of other interesting information than just the realizations of the variables from `sample` calls, and c) extraction of the parameter names used to construct the chain. See the section below for a more extensive discussion of this issue. Relevant issue: "Remove `hmc.jl` and `mh.jl` in light of upstreamed 'getparams' into AbstractMCMC" #2367.
Removing the `InferenceAlgorithm` type (3)
Problem
Currently, most of the code for the samplers in Turing.jl lives outside of Turing.jl, while inside Turing.jl we define a "duplicate" which is not an `AbstractMCMC.AbstractSampler` (as typically expected by `AbstractMCMC.sample`), but instead a subtype of `Turing.Inference.InferenceAlgorithm`:

Turing.jl/src/mcmc/Inference.jl, lines 91 to 95 in c0a4ee9
```julia
abstract type InferenceAlgorithm end
abstract type ParticleInference <: InferenceAlgorithm end
abstract type Hamiltonian <: InferenceAlgorithm end
abstract type StaticHamiltonian <: Hamiltonian end
abstract type AdaptiveHamiltonian <: Hamiltonian end
```
But exactly because these are not subtypes of `AbstractMCMC.AbstractSampler`, we can overload `sample` calls to do more than what `sample` does for a given `AbstractSampler`.
One of the things we do is make `chain_type=Chains` the default rather than `chain_type=Vector` (as is the default in AbstractMCMC.jl):

Turing.jl/src/mcmc/Inference.jl, lines 337 to 359 in c0a4ee9
```julia
function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    sampler::Sampler{<:InferenceAlgorithm},
    ensemble::AbstractMCMC.AbstractMCMCEnsemble,
    N::Integer,
    n_chains::Integer;
    chain_type=MCMCChains.Chains,
    progress=PROGRESS[],
    kwargs...,
)
    return AbstractMCMC.mcmcsample(
        rng,
        model,
        sampler,
        ensemble,
        N,
        n_chains;
        chain_type=chain_type,
        progress=progress,
        kwargs...,
    )
end
```
Another is to perform some simple model checks to stop the user from doing things they shouldn't, e.g. accidentally using a model twice (this is done using `DynamicPPL.check_model`):

Turing.jl/src/mcmc/Inference.jl, lines 296 to 306 in c0a4ee9
```julia
function AbstractMCMC.sample(
    rng::AbstractRNG,
    model::AbstractModel,
    alg::InferenceAlgorithm,
    N::Integer;
    check_model::Bool=true,
    kwargs...,
)
    check_model && _check_model(model, alg)
    return AbstractMCMC.sample(rng, model, Sampler(alg, model), N; kwargs...)
end
```
However, as mentioned before, having to repeat all these sampler constructors just to go from working with an `AbstractSampler` to an `InferenceAlgorithm` so we can do these things is a) very annoying to maintain, and b) very confusing for newcomers who want to contribute.
Now, the problem is that we cannot simply start overloading `sample(model::DynamicPPL.Model, sampler::AbstractMCMC.AbstractSampler, ...)` calls, since sampler packages might define something like `sample(model::AbstractMCMC.AbstractModel, sampler::MySampler, ...)` (note that `DynamicPPL.Model <: AbstractMCMC.AbstractModel`), which would give rise to a host of method ambiguities.
Someone might say "oh, but nobody is going to implement `sample(model::AbstractMCMC.AbstractModel, sampler::MySampler, ...)`; they're always going to implement a sampler for a specific model type, e.g. `AbstractMCMC.LogDensityModel`", but this is not great for two reasons: a) "meta" samplers, i.e. samplers that use other samplers as components, might want to be agnostic to what the underlying model is, since such a "meta" sampler doesn't interact directly with the model itself, and b) if we do so, we're claiming that `DynamicPPL.Model` is, in some way, a special and more important model type than all other subtypes of `AbstractModel`, which is the exact opposite of what we wanted to do with AbstractMCMC.jl (we wanted it to be a "sampler package for all", not just for Turing.jl).
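To make the ambiguity concrete, here is a minimal, self-contained sketch (using stand-in types and a stand-in `mysample` function rather than the real ones) of what happens when one package specializes on the model type and another on the sampler type:

```julia
# Stand-ins only: these types mimic the relationship between
# AbstractMCMC.AbstractModel / DynamicPPL.Model and
# AbstractMCMC.AbstractSampler / an external package's sampler.
abstract type SomeAbstractModel end
abstract type SomeAbstractSampler end

struct OurModel <: SomeAbstractModel end       # plays the role of DynamicPPL.Model
struct TheirSampler <: SomeAbstractSampler end # plays the role of an external sampler

# Turing.jl would like to specialize on the model type...
mysample(::OurModel, ::SomeAbstractSampler) = "Turing-flavoured sample"

# ...while the sampler package specializes on the sampler type.
mysample(::SomeAbstractModel, ::TheirSampler) = "sampler-package sample"

# Neither method is more specific than the other, so this call errors:
mysample(OurModel(), TheirSampler())
# ERROR: MethodError: mysample(::OurModel, ::TheirSampler) is ambiguous.
```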
`externalsampler`, introduced in #2008, is a step towards this, but in the end we don't want to require `externalsampler` to wrap every sampler passed to Turing.jl; we really only want it to wrap samplers which do not support all the additional niceties that Turing.jl's current `sample` provides.
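For reference, a hedged sketch of what the `externalsampler` route looks like today (the model, proposal, and constructor details here are illustrative, not canonical, and may differ between package versions):

```julia
using Turing
using LinearAlgebra: I
import AdvancedMH

# A toy model with two latent variables and no observations.
@model function demo()
    x ~ Normal()
    y ~ Normal(x, 1)
end

# A random-walk Metropolis-Hastings sampler from AdvancedMH.jl with a
# 2-dimensional Gaussian proposal (one dimension per latent variable).
rwmh = AdvancedMH.RWMH(MvNormal(zeros(2), I))

# `externalsampler` wraps the AbstractMCMC-compatible sampler so that Turing.jl
# knows how to run it on a DynamicPPL model and return an MCMCChains.Chains.
chain = sample(demo(), externalsampler(rwmh), 1_000)
```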
Solution 1: rename or duplicate `sample`
The only true solution I see, which is very, very annoying, is to either
1. Not export `AbstractMCMC.sample` from Turing.jl, and instead define and export a separate `Turing.sample` which is a fancy wrapper around `AbstractMCMC.sample`.
2. Define a new entry point for `sample` from Turing.jl with a different name, e.g. `infer` or `mcmc` (or even reuse the `mcmcsample` name that AbstractMCMC.jl uses internally, but make it public).
Neither of these is ideal, tbh.
(1) sucks because so many packages use `StatsBase.sample` (as we do in AbstractMCMC.jl) for this very reasonable interface, so diverging from it is confusing; moreover, we'd easily end up with naming collisions in the user's namespace, e.g. `using Turing, AbstractMCMC` would immediately export two distinct `sample` functions, as sketched below.
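A self-contained illustration of that collision (the modules here are stand-ins, not the real packages):

```julia
# If Turing.jl exported its own `sample` function instead of extending
# `StatsBase.sample` (which AbstractMCMC re-exports), both packages would
# export the same name as *different* functions:
module FakeAbstractMCMC
export sample
sample(args...) = "AbstractMCMC-style sample"
end

module FakeTuring
export sample
sample(args...) = "Turing's own fancy sample"
end

using .FakeAbstractMCMC, .FakeTuring

sample()
# WARNING: both FakeTuring and FakeAbstractMCMC export "sample";
#          uses of it in module Main must be qualified
# ERROR: UndefVarError: `sample` not defined
```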
(2) is also a bit annoying, both because it would be a highly breaking change and because, well, `sample` is a much better name 🤷
IMHO, (2) is the best option here though. If we define a method called `mcmc` or `mcmcsample` (ideally we'd do something with `AbstractMCMC.mcmcsample`) which is exported from Turing.jl, we could do away with all of `InferenceAlgorithm` and its implementations in favour of a single (or a few) overloads of this method.
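For concreteness, here is a rough sketch of what such an entry point could look like; every name and signature below is hypothetical and only meant to show where the Turing-specific conveniences would live:

```julia
using Random
import AbstractMCMC, DynamicPPL, MCMCChains

# Hypothetical `mcmc` entry point: accepts any AbstractMCMC sampler directly,
# applies the Turing.jl niceties, and then defers to plain `AbstractMCMC.sample`.
function mcmc(
    rng::Random.AbstractRNG,
    model::DynamicPPL.Model,
    sampler::AbstractMCMC.AbstractSampler,
    N::Integer;
    check_model::Bool=true,
    chain_type=MCMCChains.Chains,  # Turing-flavoured default instead of Vector
    kwargs...,
)
    # Mirrors what `_check_model` does today via `DynamicPPL.check_model`.
    check_model && DynamicPPL.check_model(model)
    return AbstractMCMC.sample(rng, model, sampler, N; chain_type=chain_type, kwargs...)
end

# Convenience method without an explicit RNG.
function mcmc(model::DynamicPPL.Model, sampler::AbstractMCMC.AbstractSampler, N::Integer; kwargs...)
    return mcmc(Random.default_rng(), model, sampler, N; kwargs...)
end
```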