Skip to content

Commit 2fc1f78

Browse files
Removes hardcoded policy obs group for symmetry (#111)
1 parent 830fa98 commit 2fc1f78

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "rsl-rl-lib"
7-
version = "3.0.0"
7+
version = "3.0.1"
88
keywords = ["reinforcement-learning", "isaac", "leggedrobotics", "rl-pytorch"]
99
maintainers = [
1010
{ name="Clemens Schwarke", email="[email protected]" },

rsl_rl/algorithms/ppo.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,7 @@ def update(self): # noqa: C901
215215
num_aug = 1
216216
# original batch size
217217
# we assume policy group is always there and needs augmentation
218-
original_batch_size = obs_batch["policy"].shape[0]
218+
original_batch_size = obs_batch.batch_size[0]
219219

220220
# check if we should normalize advantages per mini batch
221221
if self.normalize_advantage_per_mini_batch:
@@ -227,14 +227,14 @@ def update(self): # noqa: C901
227227
# augmentation using symmetry
228228
data_augmentation_func = self.symmetry["data_augmentation_func"]
229229
# returned shape: [batch_size * num_aug, ...]
230-
obs_batch, actions_batch = data_augmentation_func( # TODO: needs changes on the isaac lab side
230+
obs_batch, actions_batch = data_augmentation_func(
231231
obs=obs_batch,
232232
actions=actions_batch,
233233
env=self.symmetry["_env"],
234234
)
235235
# compute number of augmentations per sample
236236
# we assume policy group is always there and needs augmentation
237-
num_aug = int(obs_batch["policy"].shape[0] / original_batch_size)
237+
num_aug = int(obs_batch.batch_size[0] / original_batch_size)
238238
# repeat the rest of the batch
239239
# -- actor
240240
old_actions_log_prob_batch = old_actions_log_prob_batch.repeat(num_aug, 1)

0 commit comments

Comments
 (0)