Skip to content

Commit e53134b

Browse files
authored
Adds support for log-std parameter in ActorCritic (#67)
* Fixes gradient propogation through std-dev * adds support for log std * renames to noise_std_type * fixes value error * adds back option for noise std type
1 parent 1e61178 commit e53134b

File tree

3 files changed

+21
-3
lines changed

3 files changed

+21
-3
lines changed

config/dummy_config.yaml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ policy:
7171
actor_hidden_dims: [128, 128, 128]
7272
critic_hidden_dims: [128, 128, 128]
7373
init_noise_std: 1.0
74+
noise_std_type: "scalar" # 'scalar' or 'log'
7475
# only needed for `ActorCriticRecurrent`
7576
# rnn_type: 'lstm'
7677
# rnn_hidden_size: 512

rsl_rl/modules/actor_critic.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def __init__(
2424
critic_hidden_dims=[256, 256, 256],
2525
activation="elu",
2626
init_noise_std=1.0,
27+
noise_std_type: str = "scalar",
2728
**kwargs,
2829
):
2930
if kwargs:
@@ -64,7 +65,15 @@ def __init__(
6465
print(f"Critic MLP: {self.critic}")
6566

6667
# Action noise
67-
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
68+
self.noise_std_type = noise_std_type
69+
if self.noise_std_type == "scalar":
70+
self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
71+
elif self.noise_std_type == "log":
72+
self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
73+
else:
74+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
75+
76+
# Action distribution (populated in update_distribution)
6877
self.distribution = None
6978
# disable args validation for speedup
7079
Normal.set_default_validate_args(False)
@@ -100,8 +109,16 @@ def entropy(self):
100109
return self.distribution.entropy().sum(dim=-1)
101110

102111
def update_distribution(self, observations):
112+
# compute mean
103113
mean = self.actor(observations)
104-
std = self.std.expand_as(mean)
114+
# compute standard deviation
115+
if self.noise_std_type == "scalar":
116+
std = self.std.expand_as(mean)
117+
elif self.noise_std_type == "log":
118+
std = torch.exp(self.log_std).expand_as(mean)
119+
else:
120+
raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
121+
# create distribution
105122
self.distribution = Normal(mean, std)
106123

107124
def act(self, observations, **kwargs):

rsl_rl/runners/on_policy_runner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -261,7 +261,7 @@ def log(self, locs: dict, width: int = 80, pad: int = 35):
261261
else:
262262
self.writer.add_scalar("Episode/" + key, value, locs["it"])
263263
ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n"""
264-
mean_std = self.alg.actor_critic.std.mean()
264+
mean_std = self.alg.actor_critic.action_std.mean()
265265
fps = int(self.num_steps_per_env * self.env.num_envs / (locs["collection_time"] + locs["learn_time"]))
266266

267267
# -- Losses

0 commit comments

Comments
 (0)