50 changes: 50 additions & 0 deletions HISTORY.md
@@ -1,5 +1,55 @@
# 0.42.0

## Breaking Changes

**AdvancedVI 0.5**

Turing.jl v0.42 updates `AdvancedVI.jl` compatibility to 0.5.
Most of the changes introduced in `[email protected]` are structural, with some changes spilling out into the interface.
The summary below lists only the changes that affect end users of Turing.
For a more comprehensive list of changes, please refer to the [changelogs](https://github.com/TuringLang/AdvancedVI.jl/blob/main/HISTORY.md) in `AdvancedVI`.

`AdvancedVI` v0.5 introduces a new interface layer for defining variational algorithms. As a result, `Turing.vi` now receives a keyword argument `algorithm`, where `algorithm <: AdvancedVI.AbstractVariationalAlgorithm` holds all algorithm-specific configuration. Accordingly, keyword arguments of `vi` that were algorithm-specific, such as `objective`, `operator`, and `averager`, have been moved into fields of the relevant `<: AdvancedVI.AbstractVariationalAlgorithm` structs.
For example,

```julia
vi(model, q, n_iters; objective=RepGradELBO(10), operator=AdvancedVI.ClipScale())
```

is now

```julia
vi(
model,
q,
n_iters;
algorithm=KLMinRepGradDescent(adtype; n_samples=10, operator=AdvancedVI.ClipScale()),
)
```

Similarly,

```julia
vi(
model,
q,
n_iters;
objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()),
operator=AdvancedVI.ProximalLocationScaleEntropy(),
)
```

is now

```julia
vi(model, q, n_iters; algorithm=KLMinRepGradProxDescent(adtype; n_samples=10))
```

Additionally,

- The default hyperparameters of `DoG` and `DoWG` have been altered.
- The deprecated `[email protected]`-era interface is now removed.
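Putting the pieces together, a complete invocation under the new interface might look like the following minimal sketch. The toy model, the choice of `AutoForwardDiff` as AD backend, the iteration count, and the hyperparameters are all illustrative assumptions, not prescribed values:

```julia
using Turing

@model function demo(y)
    # Toy conjugate-normal model, purely for illustration.
    μ ~ Normal(0, 1)
    y ~ Normal(μ, 1)
end

model = demo(1.0)

# Initial mean-field Gaussian variational approximation.
q0 = q_meanfield_gaussian(model)

# KL-minimizing stochastic gradient descent under the new `algorithm` keyword.
adtype = AutoForwardDiff()
q, state, info = vi(
    model,
    q0,
    1000;
    adtype,
    algorithm=KLMinRepGradDescent(adtype; n_samples=10, operator=AdvancedVI.ClipScale()),
)
```

The returned `state` can be passed back to a subsequent `vi` call to resume optimization, as described in the updated docstring below.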

# 0.41.0

## DynamicPPL 0.38
2 changes: 1 addition & 1 deletion Project.toml
@@ -55,7 +55,7 @@ Accessors = "0.1"
AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6, 0.7, 0.8"
AdvancedMH = "0.8"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
AdvancedVI = "0.5"
BangBang = "0.4.2"
Bijectors = "0.14, 0.15"
Compat = "4.15.0"
4 changes: 3 additions & 1 deletion src/Turing.jl
@@ -116,10 +116,12 @@ export
externalsampler,
# Variational inference - AdvancedVI
vi,
ADVI,
q_locationscale,
q_meanfield_gaussian,
q_fullrank_gaussian,
KLMinRepGradProxDescent,
KLMinRepGradDescent,
KLMinScoreGradDescent,
# ADTypes
AutoForwardDiff,
AutoReverseDiff,
81 changes: 37 additions & 44 deletions src/variational/VariationalInference.jl
@@ -1,21 +1,24 @@

module Variational

using DynamicPPL
using AdvancedVI:
AdvancedVI, KLMinRepGradDescent, KLMinRepGradProxDescent, KLMinScoreGradDescent
using ADTypes
using Bijectors: Bijectors
using Distributions
using DynamicPPL
using LinearAlgebra
using LogDensityProblems
using Random
using ..Turing: DEFAULT_ADTYPE, PROGRESS

import ..Turing: DEFAULT_ADTYPE, PROGRESS

import AdvancedVI
import Bijectors

export vi, q_locationscale, q_meanfield_gaussian, q_fullrank_gaussian

include("deprecated.jl")
export vi,
q_locationscale,
q_meanfield_gaussian,
q_fullrank_gaussian,
KLMinRepGradProxDescent,
KLMinRepGradDescent,
KLMinScoreGradDescent

"""
q_initialize_scale(
@@ -248,76 +251,66 @@ end
"""
vi(
[rng::Random.AbstractRNG,]
model::DynamicPPL.Model;
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective::AdvancedVI.AbstractVariationalObjective = AdvancedVI.RepGradELBO(
10; entropy = AdvancedVI.ClosedFormEntropyZeroGradient()
max_iter::Int;
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
algorithm::AdvancedVI.AbstractVariationalAlgorithm = KLMinRepGradProxDescent(
adtype; n_samples=10
),
show_progress::Bool = Turing.PROGRESS[],
optimizer::Optimisers.AbstractRule = AdvancedVI.DoWG(),
averager::AdvancedVI.AbstractAverager = AdvancedVI.PolynomialAveraging(),
operator::AdvancedVI.AbstractOperator = AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType = Turing.DEFAULT_ADTYPE,
kwargs...
)

Approximating the target `model` via variational inference by optimizing `objective` with the initialization `q`.
Approximate the target `model` via the variational inference algorithm `algorithm` by starting from the initial variational approximation `q`.
This is a thin wrapper around `AdvancedVI.optimize`.
The default `algorithm`, `KLMinRepGradProxDescent` ([relevant docs](https://turinglang.org/AdvancedVI.jl/dev/klminrepgradproxdescent/)), assumes `q` uses `AdvancedVI.MvLocationScale`, which can be constructed by invoking `q_fullrank_gaussian` or `q_meanfield_gaussian`.
For other variational families, refer to `AdvancedVI` to determine the best algorithm and options.

# Arguments
- `model`: The target `DynamicPPL.Model`.
- `q`: The initial variational approximation.
- `n_iterations`: Number of optimization steps.
- `max_iter`: Maximum number of steps.

# Keyword Arguments
- `objective`: Variational objective to be optimized.
- `adtype`: Automatic differentiation backend to be applied to the log-density. The default value for `algorithm` also uses this backend for differentiating the variational objective.
- `algorithm`: Variational inference algorithm.
- `show_progress`: Whether to show the progress bar.
- `optimizer`: Optimization algorithm.
- `averager`: Parameter averaging strategy.
- `operator`: Operator applied after each optimization step.
- `adtype`: Automatic differentiation backend.

See the docs of `AdvancedVI.optimize` for additional keyword arguments.

# Returns
- `q`: Variational distribution formed by the last iterate of the optimization run.
- `q_avg`: Variational distribution formed by the averaged iterates according to `averager`.
- `state`: Collection of states used for optimization. This can be used to resume from a past call to `vi`.
- `info`: Information generated during the optimization run.
- `q`: Output variational distribution of `algorithm`.
- `state`: Collection of states used by `algorithm`. This can be used to resume from a past call to `vi`.
- `info`: Information generated while executing `algorithm`.
"""
function vi(
rng::Random.AbstractRNG,
model::DynamicPPL.Model,
q,
n_iterations::Int;
objective=AdvancedVI.RepGradELBO(
10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()
max_iter::Int,
args...;
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
algorithm::AdvancedVI.AbstractVariationalAlgorithm=KLMinRepGradProxDescent(
adtype; n_samples=10
),
show_progress::Bool=PROGRESS[],
optimizer=AdvancedVI.DoWG(),
averager=AdvancedVI.PolynomialAveraging(),
operator=AdvancedVI.ProximalLocationScaleEntropy(),
adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE,
kwargs...,
)
return AdvancedVI.optimize(
rng,
LogDensityFunction(model),
objective,
algorithm,
max_iter,
LogDensityFunction(model; adtype),
q,
n_iterations;
args...;
show_progress=show_progress,
adtype,
optimizer,
averager,
operator,
kwargs...,
)
end

function vi(model::DynamicPPL.Model, q, n_iterations::Int; kwargs...)
return vi(Random.default_rng(), model, q, n_iterations; kwargs...)
function vi(model::DynamicPPL.Model, q, max_iter::Int; kwargs...)
return vi(Random.default_rng(), model, q, max_iter; kwargs...)
end

end
61 changes: 0 additions & 61 deletions src/variational/deprecated.jl

This file was deleted.

3 changes: 2 additions & 1 deletion test/Project.toml
@@ -44,7 +44,7 @@ AbstractMCMC = "5"
AbstractPPL = "0.11, 0.12, 0.13"
AdvancedMH = "0.6, 0.7, 0.8"
AdvancedPS = "0.7"
AdvancedVI = "0.4"
AdvancedVI = "0.5"
Aqua = "0.8"
BangBang = "0.4"
Bijectors = "0.14, 0.15"
@@ -53,6 +53,7 @@ Combinatorics = "1"
Distributions = "0.25"
DistributionsAD = "0.6.3"
DynamicHMC = "2.1.6, 3.0"
DynamicPPL = "0.38"
FiniteDifferences = "0.10.8, 0.11, 0.12"
ForwardDiff = "0.10.12 - 0.10.32, 0.10, 1"
HypothesisTests = "0.11"