Commit 1e61178

Adds flag for per-batch advantage normalization (#68)
* Adds flag for per-batch advantage normalization
* Make default false for backwards compatibility
1 parent a2b498c commit 1e61178

File tree

3 files changed: +21 -5 lines changed

  config/dummy_config.yaml
  rsl_rl/algorithms/ppo.py
  rsl_rl/storage/rollout_storage.py

config/dummy_config.yaml

Lines changed: 2 additions & 0 deletions

@@ -1,6 +1,8 @@
 algorithm:
   class_name: PPO
   # training parameters
+  # -- advantage normalization
+  normalize_advantage_per_mini_batch: false
   # -- value function
   value_loss_coef: 1.0
   clip_param: 0.2
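
The new key sits with the other PPO hyperparameters under the algorithm section, so a runner that builds the algorithm from this file can forward it as a keyword argument. A minimal sketch of reading the flag, assuming the config is loaded with PyYAML (file path and variable names here are illustrative, not part of this commit):

    import yaml

    # Load the training configuration (path shown for illustration only).
    with open("config/dummy_config.yaml") as f:
        cfg = yaml.safe_load(f)

    alg_cfg = dict(cfg["algorithm"])
    alg_cfg.pop("class_name")  # "PPO"

    # A missing key falls back to False, so configs written before this commit
    # keep the old behavior (one normalization over the whole rollout).
    per_mini_batch = alg_cfg.get("normalize_advantage_per_mini_batch", False)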

rsl_rl/algorithms/ppo.py

Lines changed: 13 additions & 2 deletions

@@ -38,6 +38,7 @@ def __init__(
         schedule="fixed",
         desired_kl=0.01,
         device="cpu",
+        normalize_advantage_per_mini_batch=False,
         # RND parameters
         rnd_cfg: dict | None = None,
         # Symmetry parameters
@@ -48,6 +49,7 @@ def __init__(
         self.desired_kl = desired_kl
         self.schedule = schedule
         self.learning_rate = learning_rate
+        self.normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch

         # RND components
         if rnd_cfg is not None:
@@ -84,8 +86,10 @@ def __init__(
         # PPO components
         self.actor_critic = actor_critic
         self.actor_critic.to(self.device)
-        self.storage = None  # initialized later
+        # Create optimizer
         self.optimizer = optim.Adam(self.actor_critic.parameters(), lr=learning_rate)
+        # Create rollout storage
+        self.storage: RolloutStorage = None  # type: ignore
         self.transition = RolloutStorage.Transition()

         # PPO parameters
@@ -168,7 +172,9 @@ def process_env_step(self, rewards, dones, infos):
     def compute_returns(self, last_critic_obs):
         # compute value for the last step
         last_values = self.actor_critic.evaluate(last_critic_obs).detach()
-        self.storage.compute_returns(last_values, self.gamma, self.lam)
+        self.storage.compute_returns(
+            last_values, self.gamma, self.lam, normalize_advantage=not self.normalize_advantage_per_mini_batch
+        )

     def update(self):  # noqa: C901
         mean_value_loss = 0
@@ -213,6 +219,11 @@ def update(self):  # noqa: C901
             # original batch size
             original_batch_size = obs_batch.shape[0]

+            # check if we should normalize advantages per mini batch
+            if self.normalize_advantage_per_mini_batch:
+                with torch.no_grad():
+                    advantages_batch = (advantages_batch - advantages_batch.mean()) / (advantages_batch.std() + 1e-8)
+
             # Perform symmetric augmentation
             if self.symmetry and self.symmetry["use_data_augmentation"]:
                 # augmentation using symmetry
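
The net effect of the flag: with it disabled (the default), advantages are standardized once over the whole rollout inside compute_returns; with it enabled, they are left raw in storage and each mini-batch is standardized independently here in update(). A small standalone sketch of the two behaviors, using made-up tensor sizes purely for illustration:

    import torch

    def standardize(adv: torch.Tensor) -> torch.Tensor:
        # Zero-mean / unit-std normalization with the same epsilon as the diff.
        return (adv - adv.mean()) / (adv.std() + 1e-8)

    # Illustrative rollout of advantages, split into 4 mini-batches.
    advantages = torch.randn(4096)

    # Flag disabled (default): normalize once over the whole rollout, then slice.
    whole_rollout = standardize(advantages).chunk(4)

    # Flag enabled: keep raw advantages and standardize each mini-batch on its own,
    # as update() now does under torch.no_grad().
    per_mini_batch = [standardize(chunk) for chunk in advantages.chunk(4)]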

rsl_rl/storage/rollout_storage.py

Lines changed: 6 additions & 3 deletions

@@ -129,7 +129,7 @@ def _save_hidden_states(self, hidden_states):
     def clear(self):
         self.step = 0

-    def compute_returns(self, last_values, gamma, lam):
+    def compute_returns(self, last_values, gamma, lam, normalize_advantage: bool = True):
         advantage = 0
         for step in reversed(range(self.num_transitions_per_env)):
             # if we are at the last step, bootstrap the return value
@@ -146,9 +146,12 @@ def compute_returns(self, last_values, gamma, lam):
             # Return: R_t = A(s_t, a_t) + V(s_t)
             self.returns[step] = advantage + self.values[step]

-        # Compute and normalize the advantages
+        # Compute the advantages
         self.advantages = self.returns - self.values
-        self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)
+        # Normalize the advantages if flag is set
+        # This is to prevent double normalization (i.e. if per minibatch normalization is used)
+        if normalize_advantage:
+            self.advantages = (self.advantages - self.advantages.mean()) / (self.advantages.std() + 1e-8)

     def get_statistics(self):
         done = self.dones
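
For reference, the loop in compute_returns is the usual Generalized Advantage Estimation recursion; a compact standalone version with the new normalize_advantage switch might look like the sketch below (the flat (T, num_envs) tensors and the function name are illustrative, not the storage class's actual buffers):

    import torch

    def gae_returns(rewards, values, dones, last_values, gamma=0.99, lam=0.95, normalize_advantage=True):
        # rewards, values, dones: (T, num_envs) tensors; last_values: (num_envs,) bootstrap values.
        num_steps = rewards.shape[0]
        advantages = torch.zeros_like(rewards)
        advantage = torch.zeros_like(last_values)
        for step in reversed(range(num_steps)):
            next_values = last_values if step == num_steps - 1 else values[step + 1]
            not_done = 1.0 - dones[step].float()
            # TD error: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
            delta = rewards[step] + gamma * next_values * not_done - values[step]
            # GAE recursion: A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}
            advantage = delta + gamma * lam * not_done * advantage
            advantages[step] = advantage
        # Return: R_t = A_t + V(s_t)
        returns = advantages + values
        # Skip this when the caller normalizes per mini-batch, to avoid double normalization.
        if normalize_advantage:
            advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
        return returns, advantages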
