Merged

212 commits
5fa938e
various changes
HenriDeh Feb 17, 2022
fbe4bc7
improve SAC docstring
HenriDeh Feb 18, 2022
f30ca9d
SAC: add target network and start policy defaults
HenriDeh Feb 18, 2022
aab6b5d
make GaussianNetwork normalizer customizable
HenriDeh Feb 18, 2022
98f1633
create MPOPolicy struct
HenriDeh Feb 21, 2022
d4d4b9c
Merge branch 'master' of https://github.com/HenriDeh/ReinforcementLea…
HenriDeh Feb 22, 2022
fab1b83
continue mpo
HenriDeh Feb 25, 2022
dae8af0
make GaussianNetwork normalizer customizable
HenriDeh Feb 18, 2022
967a7cc
add multiple action samples per state
HenriDeh Feb 25, 2022
1d50b26
Merge branch 'GaussianNets'
HenriDeh Feb 25, 2022
391cd25
Add MPO algorithm
HenriDeh Feb 25, 2022
99d5c89
custom normalizer and multi action sampling
HenriDeh Feb 25, 2022
78ac2b8
Complete docs on gaussian normalizer
HenriDeh Feb 25, 2022
857987d
Upgrade manifest format and update dependecies
HenriDeh Mar 3, 2022
1cc3c35
add default initializer
HenriDeh Mar 3, 2022
027d0b3
Fix logp_pi
HenriDeh Mar 3, 2022
075b1a8
add convenience function
HenriDeh Mar 3, 2022
fefd141
add unit tests for GaussianNetwork
HenriDeh Mar 3, 2022
4d488b8
use isapprox in tests
HenriDeh Mar 4, 2022
e24854f
Update src/ReinforcementLearningCore/src/policies/q_based_policies/le…
findmyway Mar 4, 2022
eb2008b
add unknown words
findmyway Mar 4, 2022
f7348da
Rand directly on the device for all NN approx
HenriDeh Mar 4, 2022
76fb288
fix CUDA functional
HenriDeh Mar 4, 2022
6d627b6
Fix CUDA rand test
HenriDeh Mar 4, 2022
ef06ba4
add CURAND rng
HenriDeh Mar 4, 2022
a8c5fa1
Merge branch 'master' into GaussianNets
findmyway Mar 4, 2022
6c919a7
logdet of matrix from L decomposition
HenriDeh Mar 10, 2022
24ee6de
computing logpdf given cholesky of Covariance
HenriDeh Mar 10, 2022
4fe1f61
add CovGaussianNetwork
HenriDeh Mar 10, 2022
5d163aa
add unit tests
HenriDeh Mar 10, 2022
0cc6d2e
disallow scalar indexing in tests
HenriDeh Mar 10, 2022
76cb606
Merge branch 'master' into CovGaussianNet
HenriDeh Mar 10, 2022
2cea000
Merge branch 'CovGaussianNet'
HenriDeh Mar 10, 2022
2ec382d
Merge branch 'master' into mpo
HenriDeh Mar 10, 2022
05e0659
add 2D input compatibility
HenriDeh Mar 10, 2022
c86869e
Fix and add tests
HenriDeh Mar 10, 2022
2a94983
add unknown words
HenriDeh Mar 10, 2022
99c29b6
stabilize test
HenriDeh Mar 10, 2022
25a3c50
Add failing tests
HenriDeh Mar 11, 2022
936b1c3
Fix gradient of GaussianNetwork
HenriDeh Mar 11, 2022
88b9749
add more tests
HenriDeh Mar 11, 2022
4dd7dc8
remove a problamatic newline
HenriDeh Mar 11, 2022
eb77b64
Merge branch 'GN_gradient'
HenriDeh Mar 11, 2022
6da3523
Merge branch 'master' into mpo
HenriDeh Mar 11, 2022
e0f5028
Update src/ReinforcementLearningCore/src/policies/q_based_policies/le…
HenriDeh Mar 13, 2022
1e129ae
Update src/ReinforcementLearningCore/src/policies/q_based_policies/le…
HenriDeh Mar 13, 2022
5daeee2
rename vec_to_tril
HenriDeh Mar 14, 2022
09e62c5
use similar zero constructor
HenriDeh Mar 14, 2022
5dfe9fe
remove CUDA specific method
HenriDeh Mar 14, 2022
5a4881c
Merge branch 'CovGaussianNet' of https://github.com/HenriDeh/Reinforc…
HenriDeh Mar 14, 2022
72d2040
spacing
HenriDeh Mar 14, 2022
cc3452a
rename all to vec_to_tril
HenriDeh Mar 14, 2022
f0ad02d
Fix NaN problem with vec_to_tril
HenriDeh Mar 14, 2022
8b7569d
Merge branch 'master' into CovGaussianNet
HenriDeh Mar 14, 2022
523498c
Merge branch 'CovGaussianNet'
HenriDeh Mar 14, 2022
043dea5
Merge branch 'master' into mpo
HenriDeh Mar 14, 2022
b9df23f
typo
HenriDeh Mar 14, 2022
ed8eec9
Return L instead of Sigma
HenriDeh Mar 14, 2022
743c0a8
Change tests accordingly
HenriDeh Mar 14, 2022
40b2660
Merge branch 'CovGaussianNet'
HenriDeh Mar 15, 2022
95f708c
Merge branch 'master' into mpo
HenriDeh Mar 15, 2022
c1f79d8
change comment
HenriDeh Mar 16, 2022
89d144c
revise struct
HenriDeh Mar 17, 2022
08137c5
add constructor
HenriDeh Mar 17, 2022
00cc3f1
simplify call
HenriDeh Mar 17, 2022
b58efea
update in two functions. full et diag gaussians
HenriDeh Mar 17, 2022
5c51bd8
Merge branch 'master' into mpo
HenriDeh Mar 17, 2022
8e3e5aa
Revert "SAC: add target network and start policy defaults"
HenriDeh Mar 18, 2022
c1f50f3
Merge branch 'mpo' of https://github.com/HenriDeh/ReinforcementLearni…
HenriDeh Mar 18, 2022
d2d0132
Revert "Revert "SAC: add target network and start policy defaults""
HenriDeh Mar 18, 2022
b302a76
move update step to policy call
HenriDeh Mar 18, 2022
132c225
Revert "remove a problamatic newline"
HenriDeh Mar 18, 2022
de32919
Revert "remove a problamatic newline"
HenriDeh Mar 18, 2022
1f791b6
Merge branch 'mpo' of https://github.com/HenriDeh/ReinforcementLearni…
HenriDeh Mar 18, 2022
7a0742e
remove assignement in struct
HenriDeh Mar 18, 2022
288f385
Move Optim to Zoo
HenriDeh Mar 18, 2022
463f659
include mpo
HenriDeh Mar 18, 2022
81ceb7a
remove comas
HenriDeh Mar 18, 2022
e371da3
updgrade manifest
HenriDeh Mar 18, 2022
22cf087
fix constructor
HenriDeh Mar 18, 2022
102656b
make CovGaussian compatible with vec state
HenriDeh Mar 18, 2022
ccb62f5
remove max entropy
HenriDeh Mar 18, 2022
ce557df
more changes
HenriDeh Mar 21, 2022
9bf34a3
add functor and gpu utils
HenriDeh Mar 21, 2022
486a74b
new manifest
HenriDeh Mar 22, 2022
b351a75
switch to batch sampler, loop inside update calls
HenriDeh Mar 22, 2022
1f4d8d2
Use a smarter ldiv solve
HenriDeh Mar 22, 2022
3aa003f
use a mean instead of sum
HenriDeh Mar 22, 2022
43d8289
return mu and Sigma in identical shapes
HenriDeh Mar 22, 2022
c3431ee
make kldiv gpu friendly
HenriDeh Mar 23, 2022
84117b8
fix a typo
HenriDeh Mar 23, 2022
03a4468
Move kldivergences
HenriDeh Mar 23, 2022
c619eb8
remove scalar indexing
HenriDeh Mar 24, 2022
f232b81
export kldiv
HenriDeh Mar 24, 2022
e189127
fix losses
HenriDeh Mar 24, 2022
bbf1fcd
add grad checks
HenriDeh Mar 24, 2022
c8023b3
fix parenthesis
HenriDeh Mar 24, 2022
1db38c7
add a testmode
HenriDeh Mar 25, 2022
3b6e780
add logger
HenriDeh Mar 25, 2022
87e5b6c
Change dual to univariate
HenriDeh Mar 25, 2022
4a7e42a
change CUDA logdet
HenriDeh Mar 29, 2022
7e770fb
improve mvnormlogpdf
HenriDeh Mar 29, 2022
6d45dcd
add is_return_log_prob
HenriDeh Mar 29, 2022
979d5ea
log_prob is useless
HenriDeh Mar 29, 2022
f06b179
some fixes for diag norm
HenriDeh Mar 29, 2022
c108d69
return logprob fix
HenriDeh Mar 29, 2022
dcc90c8
Create rewardnormalizer.jl
HenriDeh Mar 29, 2022
b668608
inlcude and export
HenriDeh Mar 29, 2022
b41c40e
Fix NaN
HenriDeh Mar 29, 2022
6b762cd
comment
HenriDeh Mar 29, 2022
256cc9e
typo
HenriDeh Mar 29, 2022
4bbd01a
Merge branch 'reward_normalizer' into mpo
HenriDeh Mar 29, 2022
55743ad
rename file
HenriDeh Mar 30, 2022
9d670a1
add an exponential MA
HenriDeh Mar 30, 2022
24afb4b
actually move to Zoo
HenriDeh Mar 30, 2022
db5b3cf
Merge branch 'reward_normalizer' into mpo
HenriDeh Mar 30, 2022
b51c948
add reward normalizer
HenriDeh Mar 30, 2022
82f7383
fix identity normalizer
HenriDeh Mar 30, 2022
79a7623
use logexpfunctions
HenriDeh Mar 31, 2022
b244f78
Add a categorical Network
HenriDeh Apr 26, 2022
7f42d55
include
HenriDeh Apr 28, 2022
b9eeb85
Merge branch 'master' into mpo
HenriDeh Jun 23, 2022
48513a7
Merge branch 'master' into mpo
HenriDeh Jun 23, 2022
173ddb8
implement optimise! and remove normalizer
HenriDeh Jun 23, 2022
23c91a8
single batch update_policy
HenriDeh Jun 23, 2022
e839c70
remove batch_sampler
HenriDeh Jun 24, 2022
52e81b5
remove update args and batch sizes
HenriDeh Jun 24, 2022
684856b
remove reward normalizer
HenriDeh Jun 24, 2022
1565fd6
delete reward normalizer
HenriDeh Jun 24, 2022
b426e9b
Merge branch 'master' into discretenetwork
HenriDeh Jun 24, 2022
056ce8a
add test
HenriDeh Jun 24, 2022
56591f0
remove rn include
HenriDeh Jun 27, 2022
97fcdf0
add logdetLorU back
HenriDeh Jun 28, 2022
48c43a5
add dependencies
HenriDeh Jun 28, 2022
1938b74
use Approximator
HenriDeh Jun 28, 2022
c2ab6cb
incldue mpo
HenriDeh Jun 28, 2022
df34538
fix doc underscores
HenriDeh Jun 28, 2022
ff5741a
move to networks.jl
HenriDeh Jun 30, 2022
c9baf73
add back tests for other networks
HenriDeh Jun 30, 2022
400dd5e
Merge branch 'master' into discretenetwork
HenriDeh Jun 30, 2022
f4574f6
Merge branch 'master' into discretenetwork
HenriDeh Jul 1, 2022
a009d46
fixes and add tests
HenriDeh Jul 1, 2022
99a3196
fix typo
HenriDeh Jul 1, 2022
25ee90d
fix ci ?
HenriDeh Jul 4, 2022
83c196c
add action_masking
HenriDeh Jul 4, 2022
166e158
Merge branch 'JuliaReinforcementLearning:master' into mpo
HenriDeh Jul 5, 2022
38b3866
Merge branch 'master' into mpo
HenriDeh Jul 5, 2022
6ff1d8f
restore CovGaussianNetwork tests
HenriDeh Jul 5, 2022
28a7043
rename logits kwarg to log_prob
HenriDeh Jul 5, 2022
dbfb8ae
updating flux api
HenriDeh Jul 6, 2022
1ee8012
add diagnormlogpdf
HenriDeh Jul 18, 2022
c56405a
kldiv
HenriDeh Jul 19, 2022
352e7c0
finalize losses
HenriDeh Jul 19, 2022
4fdb93a
manifest update
HenriDeh Jul 22, 2022
b2d26ed
add kwarg propagation to approximator
HenriDeh Jul 22, 2022
6e16447
fix a few things
HenriDeh Jul 22, 2022
d3f16e1
fit to paper style
HenriDeh Jul 27, 2022
1784f99
Merge branch 'mpo' of https://github.com/HenriDeh/ReinforcementLearni…
HenriDeh Jul 27, 2022
63ad5e8
fix networks normalizer
HenriDeh Jul 28, 2022
2041aa7
use paper looping logic
HenriDeh Jul 28, 2022
dd90147
add experiment
HenriDeh Jul 28, 2022
221d915
Merge branch 'mpo' of https://github.com/HenriDeh/ReinforcementLearni…
HenriDeh Jul 28, 2022
2ff354f
update dependency
findmyway Jul 28, 2022
d6c85f6
move normalizer in gaussians
HenriDeh Jul 28, 2022
a65058c
add eta caching
HenriDeh Jul 28, 2022
3444b1d
update exp
HenriDeh Jul 28, 2022
a9e35be
Merge branch 'mpo' of https://github.com/HenriDeh/ReinforcementLearni…
HenriDeh Jul 28, 2022
6ae59b3
Merge pull request #65 from findmyway/mpo
HenriDeh Jul 28, 2022
36a887d
Revert "add eta caching"
HenriDeh Jul 28, 2022
513ff66
environment
HenriDeh Jul 28, 2022
156a72e
finishing stuff
HenriDeh Dec 12, 2022
8159580
Merge branch 'gumbelsoftmax' into discretenetwork
HenriDeh Dec 12, 2022
45faa27
Merge branch 'discretenetwork' into mpo
HenriDeh Dec 12, 2022
5887cab
Merge branch 'master' into mpo
HenriDeh Dec 12, 2022
dde1692
adding docstrings
HenriDeh Dec 12, 2022
6fe1e85
Merge branch 'master' into mpo
HenriDeh Dec 16, 2022
bbc986f
Merge branch 'master' into mpo
HenriDeh Dec 16, 2022
69ff077
Merge branch 'master' into mpo
HenriDeh Dec 19, 2022
e06554e
add missing Nothing argument
HenriDeh Dec 19, 2022
80a0ee2
add MPO experiments
HenriDeh Dec 19, 2022
b4ddc8a
add new deps for the 36th time
HenriDeh Dec 19, 2022
f977e05
import Random
HenriDeh Dec 19, 2022
07ff000
remove normalizers again
HenriDeh Dec 19, 2022
6d349cc
train qnetworks with different batches
HenriDeh Dec 19, 2022
f20b39c
fix CUDA deps
HenriDeh Dec 19, 2022
5dbf6c9
Make a doc page
HenriDeh Dec 19, 2022
e99b78a
fix spelling
HenriDeh Dec 19, 2022
8deaee8
remove normalizer in tests
HenriDeh Dec 19, 2022
708f040
fix typo
HenriDeh Dec 19, 2022
c585fc3
"fix" spelling
HenriDeh Dec 19, 2022
b726371
unfix DQN
HenriDeh Dec 19, 2022
aab75c0
Attempting to solve CI
HenriDeh Dec 20, 2022
aca37df
add missing dep
HenriDeh Dec 20, 2022
d25e0e9
add my name to cspell
HenriDeh Dec 20, 2022
138f847
trying again
HenriDeh Dec 20, 2022
33a074e
up compat for trajectories
HenriDeh Dec 20, 2022
0e3e0af
exp
HenriDeh Dec 20, 2022
7be5c74
Merge branch 'master' into mpo
HenriDeh Dec 20, 2022
2de4b3c
add missing pkg dep
HenriDeh Dec 20, 2022
ca6e3d5
fix doc
HenriDeh Dec 20, 2022
1c48bac
update tutorial
HenriDeh Dec 21, 2022
616bc04
update compat for trajectories
HenriDeh Dec 21, 2022
0e8944e
use rng
HenriDeh Dec 21, 2022
a5c7474
change default HPs
HenriDeh Dec 21, 2022
2c78e86
ci fix maybe
HenriDeh Dec 21, 2022
d7382d7
add plots
HenriDeh Dec 21, 2022
f4d44b2
use ignore_derivatives()
HenriDeh Dec 21, 2022
134060e
fix runtests and tangle
HenriDeh Dec 21, 2022
913feb7
fix runtests
HenriDeh Dec 21, 2022
ebc4aa8
fix devmode
HenriDeh Dec 21, 2022
5147182
fix ci
HenriDeh Dec 21, 2022
970128d
correct a mistake in doc
HenriDeh Dec 21, 2022
3 changes: 2 additions & 1 deletion .cspell/cspell.json
@@ -180,7 +180,8 @@
"rsold",
"rsnew",
"unnormalized",
"baedan"
"baedan",
"Dehaybe"
],
"ignoreWords": [],
"minWordLength": 5,
10 changes: 9 additions & 1 deletion .cspell/julia_words.txt
@@ -5284,4 +5284,12 @@ inworld
Posteriori
normalised
kldivergence
devmode
qnetworks
mpodual
lagrangeμ
mvnormkldivergence
diagnormkldivergence
normkldivergence
sqmahal
logdpf
devmode
4 changes: 2 additions & 2 deletions .github/workflows/ci.yml
@@ -21,7 +21,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.6'
- '1.8'
- '1'
os:
- ubuntu-latest
@@ -157,7 +157,7 @@ jobs:
# - run: python -m pip install --user matplotlib
# - uses: julia-actions/setup-julia@v1
# with:
# version: '1.6'
# version: '1.8'
# - name: Build homepage
# run: |
# cd docs/homepage
1 change: 1 addition & 0 deletions Project.toml
@@ -4,6 +4,7 @@ authors = ["Johanni Brea <[email protected]>", "Jun Tian <tianjun.c
version = "0.11.0"

[deps]
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ReinforcementLearningBase = "e575027e-6cd6-5018-9292-cdc6200d2b44"
ReinforcementLearningCore = "de1b191a-4ae0-4afa-a27b-92d07f46b2d6"
3 changes: 3 additions & 0 deletions docs/make.jl
@@ -51,6 +51,9 @@ makedocs(
"Which algorithm should I use?" => "Which_algorithm_should_I_use.md",
"Episodic vs. Non-episodic environments" => "non_episodic.md",
],
"Zoo Algorithms" => [
"MPO" => "src/Zoo Algorithms/MPO.md"
],
"FAQ" => "FAQ.md",
experiments,
"Tips for Developers" => "tips.md",
121 changes: 121 additions & 0 deletions docs/src/Zoo Algorithms/MPO.md
@@ -0,0 +1,121 @@
# Maximum a Posteriori Policy Optimization

ReinforcementLearningZoo provides an implementation of the Maximum a Posteriori Policy Optimization (MPO) algorithm. This algorithm was initially proposed by [Abdolmaleki et al. (2018)](https://arxiv.org/abs/1806.06920) and is further detailed in a [subsequent paper](https://arxiv.org/abs/1812.02256). This implementation is not identical to that of the paper, for several reasons that we will detail later. The purpose of this page is to guide an RLZoo user through the creation of an experiment that uses the MPO algorithm. We will recreate [one of the three experiments](../../../src/ReinforcementLearningExperiments/deps/experiments/experiments/Policy%20Gradient/JuliaRL_MPO_CartPole.jl) available in RLExperiments.jl.

The implementation of MPO comes in three forms (one for each CartPole experiment):

- With a Categorical Actor (for discrete action spaces)
- With a Diagonal Gaussian (the standard actor for continuous action spaces in RL)
- With a Full Gaussian (which can learn a covariance between the different action dimensions)

The latter is the approach used in the paper for continuous actions. It is implemented but is very slow on a GPU at the moment. Although more expressive, it may not be worth the extra computation time.

## Learning a continuous Cartpole policy
First, we instantiate the environment from the package `ReinforcementLearningEnvironments`. We wrap it into an `ActionTransformedEnv` with a `tanh` mapping to constrain the action to [-1, 1].

```julia
using ReinforcementLearning, Flux

env = ActionTransformedEnv(CartPoleEnv(continuous = true), action_mapping = x->tanh(only(x)))
```

Because we want our experiment to be reproducible, we also use a seed.

```julia
using Random
Random.seed!(123)
```

Then we instantiate an `MPOPolicy`:
```julia
policy = MPOPolicy(
actor = Approximator(GaussianNetwork(
Chain(Dense(4, 64, tanh), Dense(64,64,tanh)),
Dense(64, 1),
Dense(64, 1)), ADAM(3f-4)),
qnetwork1 = Approximator(Chain(Dense(5, 64, gelu), Dense(64,64,gelu), Dense(64,1)), ADAM(3f-4)),
qnetwork2 = Approximator(Chain(Dense(5, 64, gelu), Dense(64,64,gelu), Dense(64,1)), ADAM(3f-4)),
action_sample_size = 32,
ϵμ = 0.1f0,
ϵΣ = 1f-2,
ϵ = 0.1f0)
```
`MPOPolicy` needs an actor that is an `Approximator`; we use a deep neural network and the `ADAM` optimiser from the `Flux.jl` package. Notice that the NN is a `GaussianNetwork` made of three parts. The first is a common body with an input size equal to the length of the state of the environment (4 in this case). Then we have two "heads", one for the mean of the Gaussian policy and one for the standard deviation. With a `GaussianNetwork`, both heads must have the same output size (the size of the action vectors, 1 in this case) and no activation at the output layers.

Then we have `qnetwork1` and `qnetwork2`. This implementation of MPO uses twin Q-networks with targets. Both must be `Approximator`s, but they do not necessarily need the same architecture. The input size should be the size of the state plus the size of the action (5 here), and the output size must be 1. The original MPO paper uses the Retrace algorithm instead of 1-step TD to train the critics; Retrace is currently not implemented in RL.jl.

`MPOPolicy` has several keyword arguments in its constructor. We omit the least important ones here (those that are not specific to MPO); you can see them using `?MPOPolicy` in the REPL.

- `action_sample_size` is the number of actions sampled for each state during the E-step of the algorithm ($K$ in the second paper).
- `ϵ` is the maximum KL divergence between the E-step variational distribution and the current policy.
- `ϵμ` is the maximum KL divergence between the updated policy at the M-step and the current policy, with respect to the mean of the Gaussian.
- `ϵΣ` is the maximum KL divergence between the updated policy at the M-step and the current policy, with respect to the standard deviation of the Gaussian. It should typically be lower than `ϵμ` to ensure the standard deviation does not shrink to 0 before the mean settles around its optimum (both constraints are written out just after this list).
- `α_scale = 1f0` and `αΣ_scale = 100f0` are the gradient-descent learning rates for the Lagrange penalties on the mean and covariance, respectively. We leave them at their default values here.
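
For Gaussian policies, these two constraints correspond to the decoupled KL terms of the MPO paper. As a sketch in the paper's notation, with $\mu_k, \Sigma_k$ the parameters of the current policy and $d$ the action dimension:

```math
\frac{1}{2}(\mu - \mu_k)^\top \Sigma_k^{-1} (\mu - \mu_k) \le \epsilon_\mu,
\qquad
\frac{1}{2}\left(\operatorname{tr}\left(\Sigma^{-1}\Sigma_k\right) - d + \ln\frac{\det\Sigma}{\det\Sigma_k}\right) \le \epsilon_\Sigma .
```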

The next step is to wrap this policy into an `Agent`. An agent is a combination of a policy and a `Trajectory`. We will use the following trajectory.

```julia
trajectory = Trajectory(
CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,),action = Float32 => (1,)),
MetaSampler(
actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 1000)
),
InsertSampleRatioController(ratio = 1/1000, threshold = 1000)
)
```

MPO needs to store `SART` traces, i.e. State-Action-Reward-Terminal-NextState. Here we use a fixed-size buffer with a capacity of 1000 steps. Then we specify the sampler. MPO needs a specific type of sampler called a `MetaSampler`. A `MetaSampler` contains several named samplers, here one named `:actor` and the other `:critic`. As you may have guessed, one samples batches to update the actor and the other to update the critic (the Q-networks). You must use these exact names. Each sampler must be a `MultiBatchSampler`, which samples multiple batches so that the networks can be updated for several iterations. Here we update the critic 1000 times but the actor only 10 times. The actor sampler only needs to sample `(:state,)` traces; the critic needs the `SS′ART` traces to perform the 1-step TD update on the `qnetwork`s. Here we sample batches of 32 transitions; this is of course a hyperparameter that you can tune to your liking.
Finally, we configure the `InsertSampleRatioController`. We start sampling to update the networks once we have inserted `threshold = 1000` transitions in the buffer (that is, when the buffer is full). You can choose another value, but it does not make sense to pick one that is larger than the capacity of the buffer. The `ratio` defines how many environment steps are taken between each sample call. In this case, we take 1000 steps to collect data before sampling and updating the networks.

To summarize, with this setup the algorithm will perform the following (a rough sketch in plain Julia follows the list):
1. Interact 1000 times with the environment to fill the buffer.
2. Sample 1000 batches of 32 state-action-reward-terminal-next_state transitions.
3. Update each qnetwork 500 times, once with each batch.
4. Sample 10 batches of 32 states.
5. Update the actor 10 times.
6. Perform 1000 new steps with the updated policy, replacing the old transitions in the buffer.
7. Unless the stopping criterion is met, go back to 2.
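
For intuition, here is a minimal, self-contained sketch of that cycle in plain Julia. It does not use the actual RL.jl types: `collect_step!`, `sample_critic_batch`, `sample_actor_batch`, `update_critic!` and `update_actor!` are hypothetical placeholders standing in for what `Agent`, `Trajectory` and `MPOPolicy` do internally, and the controller condition is only an approximation of `InsertSampleRatioController`.

```julia
# Hypothetical stubs so the sketch runs on its own; in a real experiment this work
# is done by Agent, Trajectory and MPOPolicy.
collect_step!() = nothing
sample_critic_batch() = nothing
sample_actor_batch() = nothing
update_critic!(batch) = nothing
update_actor!(batch) = nothing

function training_cycle!(total_steps; threshold = 1000, ratio = 1/1000,
                         n_critic_batches = 1000, n_actor_batches = 10)
    inserted = 0  # transitions inserted into the buffer so far
    rounds = 0    # sampling/update rounds performed so far
    for _ in 1:total_steps
        collect_step!()  # interact once with the environment and insert a transition
        inserted += 1
        # Roughly one update round per 1/ratio insertions, once `threshold` transitions exist.
        if inserted >= threshold && rounds < (inserted - threshold + 1) * ratio
            for _ in 1:n_critic_batches
                update_critic!(sample_critic_batch())  # 1-step TD update of the twin Q-networks
            end
            for _ in 1:n_actor_batches
                update_actor!(sample_actor_batch())    # E-step + M-step on a batch of states
            end
            rounds += 1
        end
    end
end

training_cycle!(50_000)
```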

We can now create the agent, and run the experiment for 50,000 steps:
```julia
agent = Agent(policy = policy, trajectory = trajectory)
stop_condition = StopAfterStep(50_000, is_show_progress=true)
hook = TotalRewardPerEpisode()
run(agent, env, stop_condition, hook)
```

This should take a couple of minutes on a recent CPU. You can plot the result, for example with UnicodePlots:
```julia
using UnicodePlots
lineplot(hook.episodes, hook.mean_rewards, xlabel="episode", ylabel="mean episode reward", title = "Cartpole Continuous Action Space")
```

### Learning on a GPU

If you have a CUDA-compatible GPU, you can accelerate your experiments by transferring the neural networks to it. `MPOPolicy` comes with a method for the `gpu` function from the `Flux` package.

```julia
using CUDA

policy = gpu(policy) #Recreate a new policy if you already trained it.
agent = Agent(policy = policy, trajectory = trajectory)
stop_condition = StopAfterStep(50_000, is_show_progress=true)
hook = TotalRewardPerEpisode()
run(agent, env, stop_condition, hook) #Using the GPU is slower in this case because the NN and the batch size are small.
```

## Learning a discrete Cartpole policy

Using MPO with a discrete action space only requires a few simple changes (a sketch of the modified setup follows the list):
1. Instantiate the environment with `continuous = false`.
2. Instead of using a `GaussianNetwork`, use a `CategoricalNetwork`.
3. The action is now a one-hot vector of length two, because the action size is 2.
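
A minimal sketch of what that setup might look like. This is an assumption-laden illustration, not the exact experiment from RLExperiments: it assumes `CategoricalNetwork` simply wraps a model producing one logit per discrete action (check `?CategoricalNetwork` for the exact constructor), that the Q-networks take the state concatenated with the one-hot action as input (4 + 2 = 6), and that the categorical variant of `MPOPolicy` only needs `ϵ` (see `?MPOPolicy` for the exact keywords).

```julia
env = CartPoleEnv(continuous = false)  # two discrete actions

policy = MPOPolicy(
    # Assumed constructor: CategoricalNetwork wrapping a model with one logit per action.
    actor = Approximator(CategoricalNetwork(
        Chain(Dense(4, 64, tanh), Dense(64, 64, tanh), Dense(64, 2))), ADAM(3f-4)),
    # Q-networks now take state (4) + one-hot action (2) = 6 inputs.
    qnetwork1 = Approximator(Chain(Dense(6, 64, gelu), Dense(64, 64, gelu), Dense(64, 1)), ADAM(3f-4)),
    qnetwork2 = Approximator(Chain(Dense(6, 64, gelu), Dense(64, 64, gelu), Dense(64, 1)), ADAM(3f-4)),
    action_sample_size = 32,
    ϵ = 0.1f0)

# The trajectory's action trace should now match the one-hot length (2,):
trajectory = Trajectory(
    CircularArraySARTTraces(capacity = 1000, state = Float32 => (4,), action = Float32 => (2,)),
    MetaSampler(
        actor = MultiBatchSampler(BatchSampler{(:state,)}(32), 10),
        critic = MultiBatchSampler(BatchSampler{SS′ART}(32), 1000)
    ),
    InsertSampleRatioController(ratio = 1/1000, threshold = 1000)
)
```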

## How to use the CovGaussianNetwork

`CovGaussianNetwork` allows approximating a policy with correlations between action dimensions, unlike the `GaussianNetwork`, which only models a standard deviation for each dimension independently. In practice, this only requires two changes to the above example with `GaussianNetwork`:
1. Use a `CovGaussianNetwork` instead of a `GaussianNetwork`.
2. The output size of the second head ($\Sigma$) should not be the action size ($|A|$) but $\frac{|A|(|A|+1)}{2}$, the number of entries of a lower-triangular Cholesky factor (see the small helper below). For the CartPole environment, this remains 1 since the action is of length 1.
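
A tiny helper (not part of RL.jl, just an illustration) to compute that head size:

```julia
# Number of entries in the lower-triangular Cholesky factor of a |A|×|A| covariance matrix.
sigma_head_size(action_size::Int) = action_size * (action_size + 1) ÷ 2

sigma_head_size(1)  # 1, as for the continuous CartPole action
sigma_head_size(3)  # 6
```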


5 changes: 3 additions & 2 deletions src/ReinforcementLearningCore/Project.toml
@@ -13,6 +13,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Parsers = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -26,18 +27,18 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
[compat]
AbstractTrees = "0.3, 0.4"
Adapt = "3"
Crayons = "4"
CUDA = "3.5"
ChainRulesCore = "1"
CircularArrayBuffers = "0.1"
Crayons = "4"
Distributions = "0.25"
FillArrays = "0.8, 0.9, 0.10, 0.11, 0.12, 0.13"
Flux = "0.13"
Functors = "0.1, 0.2, 0.3"
Parsers = "2"
ProgressMeter = "1.2"
ReinforcementLearningBase = "0.10, 0.11"
ReinforcementLearningTrajectories = "0.1.5"
ReinforcementLearningTrajectories = "0.1.8"
StatsBase = "0.32, 0.33"
UnicodePlots = "1.3, 2, 3"
julia = "1.6"
2 changes: 1 addition & 1 deletion src/ReinforcementLearningCore/src/policies/learners.jl
@@ -18,7 +18,7 @@ Base.show(io::IO, m::MIME"text/plain", A::Approximator) = show(io, m, convert(An

@functor Approximator (model,)

(A::Approximator)(args...) = A.model(args...)
(A::Approximator)(args...; kwargs...) = A.model(args...; kwargs...)

RLBase.optimise!(A::Approximator, gs) =
Flux.Optimise.update!(A.optimiser, Flux.params(A), gs)
84 changes: 82 additions & 2 deletions src/ReinforcementLearningCore/src/utils/distributions.jl
@@ -1,14 +1,15 @@
export normlogpdf, mvnormlogpdf
export normlogpdf, mvnormlogpdf, diagnormlogpdf, mvnormkldivergence, diagnormkldivergence, normkldivergence

using Flux: unsqueeze, stack
using LinearAlgebra

# watch https://github.com/JuliaStats/Distributions.jl/issues/1183
const log2π = log(2.0f0π)

"""
normlogpdf(μ, σ, x; ϵ = 1.0f-8)

GPU automatic differentiable version for the logpdf function of normal distributions.
GPU automatic differentiable version for the logpdf function of a univariate normal distribution.
Adding an epsilon value to guarantee numeric stability if sigma is exactly zero
(e.g. if relu is used in output layer).
"""
@@ -17,6 +18,24 @@ function normlogpdf(μ, σ, x; ϵ=1.0f-8)
-(z .^ 2 .+ log2π) / 2.0f0 .- log.(σ .+ ϵ)
end

"""
diagnormlogpdf(μ, σ, x; ϵ = 1.0f-8)

GPU automatic differentiable version for the logpdf function of normal distributions with
diagonal covariance. Adding an epsilon value to guarantee numeric stability if sigma is
exactly zero (e.g. if relu is used in output layer).
"""
function diagnormlogpdf(μ, σ, x; ϵ = 1.0f-8)
v = (σ .+ ϵ) .^2
-0.5f0*(log(prod(v)) .+ inv.(v)'*((x .- μ).^2) .+ length(μ)*log2π)
end

#3D tensor version
function diagnormlogpdf(μ::AbstractArray{<:Any,3}, σ::AbstractArray{<:Any,3}, x::AbstractArray{<:Any,3}; ϵ = 1.0f-8)
logp = [diagnormlogpdf(μ[:, :, k], σ[:, :, k], x[:, :, k]) for k in 1:size(x, 3)]
return reduce((x,y)->cat(x,y,dims=3), logp) #returns a 3D vector
end

"""
mvnormlogpdf(μ::AbstractVecOrMat, L::AbstractMatrix, x::AbstractVecOrMat)

@@ -47,3 +66,64 @@ function mvnormlogpdf(μ::A, LorU::A, x::A; ϵ=1.0f-8) where {A<:AbstractArray}
logp = [mvnormlogpdf(μ[:, :, k], LorU[:, :, k], x[:, :, k]) for k in 1:size(x, 3)]
return unsqueeze(stack(logp, 2), dims=1) #returns a 3D vector
end

#Used for mvnormlogpdf
"""
logdetLorU(LorU::AbstractMatrix)
Log-determinant of the positive semi-definite matrix A = L*U (Cholesky lower and upper triangular factors), given L or U.
Has a sign uncertainty for non-PSD matrices.
"""
function logdetLorU(LorU::CuArray)
return 2*sum(log.(diag(LorU)))
end

#Cpu fallback
logdetLorU(LorU::AbstractMatrix) = logdet(LorU)*2

"""
mvnormkldivergence(μ1, L1, μ2, L2)

GPU differentiable implementation of the kl_divergence between two MultiVariate Gaussian distributions with mean vectors `μ1, μ2` respectively and
with cholesky decomposition of covariance matrices `L1, L2`.
"""
function mvnormkldivergence(μ1, L1M, μ2, L2M)
L1 = LowerTriangular(L1M)
L2 = LowerTriangular(L2M)
U1 = UpperTriangular(permutedims(L1M))
U2 = UpperTriangular(permutedims(L2M))
d = size(μ1,1)
logdet = logdetLorU(L2M) - logdetLorU(L1M)
M1 = L1*U1
L2i = inv(L2)
U2i = inv(U2)
M2i = U2i*L2i
X = M2i*M1
trace = tr(X) # trace of inv(Σ2) * Σ1
sqmahal = sum(abs2.(L2i*(μ2 .- μ1))) #mahalanobis square distance
return (logdet - d + trace + sqmahal)/2
end

"""
diagnormkldivergence(μ1, σ1, μ2, σ2)

GPU differentiable implementation of the kl_divergence between two MultiVariate Gaussian distributions with mean vectors `μ1, μ2` respectively and
diagonal standard deviations `σ1, σ2`. Arguments must be Vectors or single-column Matrices.
"""
function diagnormkldivergence(μ1, σ1, μ2, σ2)
v1, v2 = σ1.^2, σ2.^2
d = size(μ1,1)
logdet = sum(log.(v2)) - sum(log.(v1))
trace = sum(v1 ./ v2)
sqmahal = sum((μ2 .- μ1) .^2 ./ v2)
return (logdet - d + trace + sqmahal)/2
end

"""
normkldivergence(μ1, σ1, μ2, σ2)

GPU differentiable implementation of the kl_divergence between two univariate Gaussian
distributions with means `μ1, μ2` and standard deviations `σ1, σ2` respectively.
"""
function normkldivergence(μ1, σ1, μ2, σ2)
log(σ2) - log(σ1) + (σ1^2 + (μ1 - μ2)^2)/(2σ2^2) - typeof(μ1)(0.5)
end