@@ -38,6 +38,7 @@ def __init__(
38
38
schedule = "fixed" ,
39
39
desired_kl = 0.01 ,
40
40
device = "cpu" ,
41
+ normalize_advantage_per_mini_batch = False ,
41
42
# RND parameters
42
43
rnd_cfg : dict | None = None ,
43
44
# Symmetry parameters
@@ -48,6 +49,7 @@ def __init__(
48
49
self .desired_kl = desired_kl
49
50
self .schedule = schedule
50
51
self .learning_rate = learning_rate
52
+ self .normalize_advantage_per_mini_batch = normalize_advantage_per_mini_batch
51
53
52
54
# RND components
53
55
if rnd_cfg is not None :
@@ -84,8 +86,10 @@ def __init__(
84
86
# PPO components
85
87
self .actor_critic = actor_critic
86
88
self .actor_critic .to (self .device )
87
- self . storage = None # initialized later
89
+ # Create optimizer
88
90
self .optimizer = optim .Adam (self .actor_critic .parameters (), lr = learning_rate )
91
+ # Create rollout storage
92
+ self .storage : RolloutStorage = None # type: ignore
89
93
self .transition = RolloutStorage .Transition ()
90
94
91
95
# PPO parameters
def compute_returns(self, last_critic_obs):
    """Compute returns/advantages for the stored rollout.

    Evaluates the critic on the final observation to bootstrap the value
    of the last step, then delegates return/advantage computation to the
    rollout storage.

    Args:
        last_critic_obs: Critic observation at the end of the rollout
            (shape/type defined by the actor-critic — assumed to be a
            tensor accepted by ``actor_critic.evaluate``).
    """
    # Bootstrap: value estimate for the step after the rollout ends.
    # detach() keeps the bootstrap out of the computation graph.
    bootstrap_values = self.actor_critic.evaluate(last_critic_obs).detach()
    # If advantages are normalized per mini-batch later (in update()),
    # skip whole-rollout normalization here to avoid doing it twice.
    normalize_whole_rollout = not self.normalize_advantage_per_mini_batch
    self.storage.compute_returns(
        bootstrap_values,
        self.gamma,
        self.lam,
        normalize_advantage=normalize_whole_rollout,
    )
def update (self ): # noqa: C901
174
180
mean_value_loss = 0
@@ -213,6 +219,11 @@ def update(self): # noqa: C901
213
219
# original batch size
214
220
original_batch_size = obs_batch .shape [0 ]
215
221
222
+ # check if we should normalize advantages per mini batch
223
+ if self .normalize_advantage_per_mini_batch :
224
+ with torch .no_grad ():
225
+ advantages_batch = (advantages_batch - advantages_batch .mean ()) / (advantages_batch .std () + 1e-8 )
226
+
216
227
# Perform symmetric augmentation
217
228
if self .symmetry and self .symmetry ["use_data_augmentation" ]:
218
229
# augmentation using symmetry
0 commit comments