-
Notifications
You must be signed in to change notification settings - Fork 388
Open
Description
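When the teacher is recurrent, the teacher observation normalizer is initialized with `rnn_hidden_dim`, which I believe is incorrect given how `evaluate()` uses it. In `__init__`, `num_teacher_obs` is overwritten with `rnn_hidden_dim` before `EmpiricalNormalization(num_teacher_obs)` is constructed. However, `evaluate()` applies `teacher_obs_normalizer` to the raw teacher observations *before* passing them through `memory_t`, so the normalizer receives inputs of the original teacher-observation dimension, not `rnn_hidden_dim`. The normalizer should therefore be constructed with the original `num_teacher_obs`, i.e. before it is reassigned to `rnn_hidden_dim`.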
rsl_rl/rsl_rl/modules/student_teacher_recurrent.py
Lines 80 to 93 in cf71aa6
# teacher | |
if self.teacher_recurrent: | |
self.memory_t = Memory( | |
num_teacher_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim | |
) | |
num_teacher_obs = rnn_hidden_dim | |
self.teacher = MLP(num_teacher_obs, num_actions, teacher_hidden_dims, activation) | |
# teacher observation normalization | |
self.teacher_obs_normalization = teacher_obs_normalization | |
if teacher_obs_normalization: | |
self.teacher_obs_normalizer = EmpiricalNormalization(num_teacher_obs) | |
else: | |
self.teacher_obs_normalizer = torch.nn.Identity() |
rsl_rl/rsl_rl/modules/student_teacher_recurrent.py
Lines 161 to 168 in cf71aa6
def evaluate(self, obs): | |
obs = self.get_teacher_obs(obs) | |
obs = self.teacher_obs_normalizer(obs) | |
with torch.no_grad(): | |
if self.teacher_recurrent: | |
self.memory_t.eval() | |
obs = self.memory_t(obs).squeeze(0) | |
return self.teacher(obs) |
Metadata
Metadata
Assignees
Labels
No labels