diff --git a/rsl_rl/modules/student_teacher_recurrent.py b/rsl_rl/modules/student_teacher_recurrent.py index 964a2dcd..bba4bd34 100644 --- a/rsl_rl/modules/student_teacher_recurrent.py +++ b/rsl_rl/modules/student_teacher_recurrent.py @@ -82,8 +82,7 @@ def __init__( self.memory_t = Memory( num_teacher_obs, type=rnn_type, num_layers=rnn_num_layers, hidden_size=rnn_hidden_dim ) - num_teacher_obs = rnn_hidden_dim - self.teacher = MLP(num_teacher_obs, num_actions, teacher_hidden_dims, activation) + self.teacher = MLP(rnn_hidden_dim, num_actions, teacher_hidden_dims, activation) # teacher observation normalization self.teacher_obs_normalization = teacher_obs_normalization