diff --git a/brax/envs/__init__.py b/brax/envs/__init__.py
index f3cd0b5c2..baa53e4c8 100644
--- a/brax/envs/__init__.py
+++ b/brax/envs/__init__.py
@@ -22,6 +22,7 @@
 from brax.envs import acrobot
 from brax.envs import ant
 from brax.envs import fast
+from brax.envs import fast_differentiable
 from brax.envs import fetch
 from brax.envs import grasp
 from brax.envs import half_cheetah
@@ -45,6 +46,7 @@
     'acrobot': acrobot.Acrobot,
     'ant': functools.partial(ant.Ant, use_contact_forces=True),
     'fast': fast.Fast,
+    'fast_differentiable': fast_differentiable.FastDifferentiable,
     'fetch': fetch.Fetch,
     'grasp': grasp.Grasp,
     'halfcheetah': half_cheetah.Halfcheetah,
diff --git a/brax/envs/fast_differentiable.py b/brax/envs/fast_differentiable.py
new file mode 100644
index 000000000..8d6969f25
--- /dev/null
+++ b/brax/envs/fast_differentiable.py
@@ -0,0 +1,53 @@
+# Copyright 2022 The Brax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Gotta go fast! This trivial Env is meant for unit testing."""
+
+import brax
+from brax.envs import env
+import jax.numpy as jnp
+
+
+class FastDifferentiable(env.Env):
+  """Trains an agent to go fast."""
+
+  def __init__(self, **kwargs):
+    super().__init__(config='dt: .02', **kwargs)
+
+  def reset(self, rng: jnp.ndarray) -> env.State:
+    zero = jnp.zeros(1)
+    qp = brax.QP(pos=zero, vel=zero, rot=zero, ang=zero)
+    obs = jnp.zeros(2)
+    reward, done = jnp.zeros(2)
+    return env.State(qp, obs, reward, done)
+
+  def step(self, state: env.State, action: jnp.ndarray) -> env.State:
+    vel = state.qp.vel + action * (action > 0) * self.sys.config.dt
+    pos = state.qp.pos + vel * self.sys.config.dt
+
+    qp = state.qp.replace(pos=pos, vel=vel)
+    obs = jnp.array([pos[0], vel[0]])
+    reward = pos[0]
+    #reward = 1.0
+    #reward = action[0]
+
+    return state.replace(qp=qp, obs=obs, reward=reward)
+
+  @property
+  def observation_size(self):
+    return 2
+
+  @property
+  def action_size(self):
+    return 1
diff --git a/brax/training/agents/apg/train_test.py b/brax/training/agents/apg/train_test.py
index 94d88c16f..28992c03e 100644
--- a/brax/training/agents/apg/train_test.py
+++ b/brax/training/agents/apg/train_test.py
@@ -30,15 +30,14 @@ class APGTest(parameterized.TestCase):
   def testTrain(self):
     """Test APG with a simple env."""
     _, _, metrics = apg.train(
-        envs.get_environment('fast'),
+        envs.get_environment('fast_differentiable'),
         episode_length=128,
         num_envs=64,
         num_evals=200,
         learning_rate=3e-3,
         normalize_observations=True,
     )
-    # TODO: Can you make this 135?
-    self.assertGreater(metrics['eval/episode_reward'], 50)
+    self.assertGreater(metrics['eval/episode_reward'], 135)
 
   @parameterized.parameters(True, False)
   def testNetworkEncoding(self, normalize_observations):
diff --git a/brax/training/agents/ppo/networks.py b/brax/training/agents/ppo/networks.py
index 4631cb4a0..084dee301 100644
--- a/brax/training/agents/ppo/networks.py
+++ b/brax/training/agents/ppo/networks.py
@@ -66,7 +66,8 @@ def make_ppo_networks(
     .identity_observation_preprocessor,
     policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
     value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
-    activation: networks.ActivationFn = linen.swish) -> PPONetworks:
+    activation: networks.ActivationFn = linen.swish,
+    layer_norm: bool = False) -> PPONetworks:
   """Make PPO networks with preprocessor."""
   parametric_action_distribution = distribution.NormalTanhDistribution(
       event_size=action_size)
@@ -75,12 +76,14 @@
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=policy_hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=layer_norm)
   value_network = networks.make_value_network(
       observation_size,
       preprocess_observations_fn=preprocess_observations_fn,
       hidden_layer_sizes=value_hidden_layer_sizes,
-      activation=activation)
+      activation=activation,
+      layer_norm=layer_norm)
 
   return PPONetworks(
       policy_network=policy_network,
diff --git a/brax/training/agents/shac/__init__.py b/brax/training/agents/shac/__init__.py
new file mode 100644
index 000000000..6d7c8bbba
--- /dev/null
+++ b/brax/training/agents/shac/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2022 The Brax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/brax/training/agents/shac/losses.py b/brax/training/agents/shac/losses.py
new file mode 100644
index 000000000..692536c3d
--- /dev/null
+++ b/brax/training/agents/shac/losses.py
@@ -0,0 +1,210 @@
+# Copyright 2022 The Brax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Short-Horizon Actor Critic.
+
+See: https://arxiv.org/pdf/2204.07137.pdf
+"""
+
+from typing import Any, Tuple
+
+from brax.training import types
+from brax.training.agents.shac import networks as shac_networks
+from brax.training.types import Params
+import flax
+import jax
+import jax.numpy as jnp
+
+
+@flax.struct.dataclass
+class SHACNetworkParams:
+  """Contains training state for the learner."""
+  policy: Params
+  value: Params
+
+
+def compute_shac_policy_loss(
+    policy_params: Params,
+    value_params: Params,
+    normalizer_params: Any,
+    data: types.Transition,
+    rng: jnp.ndarray,
+    shac_network: shac_networks.SHACNetworks,
+    entropy_cost: float = 1e-4,
+    discounting: float = 0.9,
+    reward_scaling: float = 1.0) -> Tuple[jnp.ndarray, types.Metrics]:
+  """Computes SHAC policy loss.
+
+  This implements Eq. 5 of 2204.07137.
+
+  Args:
+    policy_params: Policy network parameters.
+    value_params: Value network parameters.
+    normalizer_params: Parameters of the normalizer.
+    data: Transition with leading dimension [B, T]. Extra fields required
+      are ['state_extras']['truncation'], ['policy_extras']['raw_action']
+      and ['policy_extras']['log_prob'].
+    rng: Random key.
+    shac_network: SHAC networks.
+    entropy_cost: entropy cost.
+    discounting: discounting.
+    reward_scaling: reward multiplier.
+
+  Returns:
+    A tuple (loss, metrics).
+  """
+
+  parametric_action_distribution = shac_network.parametric_action_distribution
+  policy_apply = shac_network.policy_network.apply
+  value_apply = shac_network.value_network.apply
+
+  # Put the time dimension first.
+  data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data)
+
+  # This is a redundant computation with the critic loss function,
+  # but there isn't a straightforward way to get these values when
+  # they are used in that step.
+  values = value_apply(normalizer_params, value_params, data.observation)
+  terminal_values = value_apply(normalizer_params, value_params, data.next_observation[-1])
+
+  rewards = data.reward * reward_scaling
+  truncation = data.extras['state_extras']['truncation']
+  termination = (1 - data.discount) * (1 - truncation)
+
+  # Append terminal values to get [v1, ..., v_t+1]
+  values_t_plus_1 = jnp.concatenate(
+      [values[1:], jnp.expand_dims(terminal_values, 0)], axis=0)
+
+  # jax implementation of https://github.com/NVlabs/DiffRL/blob/a4c0dd1696d3c3b885ce85a3cb64370b580cb913/algorithms/shac.py#L227
+  def sum_step(carry, target_t):
+    gam, rew_acc = carry
+    reward, termination = target_t
+
+    # clean up gamma and rew_acc for done envs, otherwise update
+    rew_acc = jnp.where(termination, 0, rew_acc + gam * reward)
+    gam = jnp.where(termination, 1.0, gam * discounting)
+
+    return (gam, rew_acc), (gam, rew_acc)
+
+  rew_acc = jnp.zeros_like(terminal_values)
+  gam = jnp.ones_like(terminal_values)
+  (gam, last_rew_acc), (gam_acc, rew_acc) = jax.lax.scan(sum_step, (gam, rew_acc),
+                                                         (rewards, termination))
+
+  policy_loss = jnp.sum(-last_rew_acc - gam * terminal_values)
+  # For trials that are truncated (i.e. hit the episode length), include the value of the
+  # terminal state; otherwise the trial terminated and should receive zero additional value.
+  policy_loss = policy_loss + jnp.sum((-rew_acc - gam_acc * jnp.where(truncation, values_t_plus_1, 0)) * termination)
+  policy_loss = policy_loss / values.shape[0] / values.shape[1]
+
+
+  # Entropy reward
+  policy_logits = policy_apply(normalizer_params, policy_params,
+                               data.observation)
+  entropy = jnp.mean(parametric_action_distribution.entropy(policy_logits, rng))
+  entropy_loss = entropy_cost * -entropy
+
+  total_loss = policy_loss + entropy_loss
+
+  return total_loss, {
+      'policy_loss': policy_loss,
+      'entropy_loss': entropy_loss
+  }
+
+
+def compute_shac_critic_loss(
+    params: Params,
+    normalizer_params: Any,
+    data: types.Transition,
+    shac_network: shac_networks.SHACNetworks,
+    discounting: float = 0.9,
+    reward_scaling: float = 1.0,
+    lambda_: float = 0.95,
+    td_lambda: bool = True) -> Tuple[jnp.ndarray, types.Metrics]:
+  """Computes SHAC critic loss.
+
+  This implements Eq. 7 of 2204.07137 and
+  https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py#L349
+
+  Args:
+    params: Value network parameters.
+    normalizer_params: Parameters of the normalizer.
+    data: Transition with leading dimension [B, T]. Extra fields required
+      are ['state_extras']['truncation'], ['policy_extras']['raw_action']
+      and ['policy_extras']['log_prob'].
+    shac_network: SHAC networks.
+    discounting: discounting.
+    reward_scaling: reward multiplier.
+    lambda_: Lambda for TD value updates.
+    td_lambda: whether to use a TD-Lambda value target.
+
+  Returns:
+    A tuple (loss, metrics).
+  """
+
+  value_apply = shac_network.value_network.apply
+
+  data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 0, 1), data)
+
+  values = value_apply(normalizer_params, params, data.observation)
+  terminal_value = value_apply(normalizer_params, params, data.next_observation[-1])
+
+  rewards = data.reward * reward_scaling
+  truncation = data.extras['state_extras']['truncation']
+  termination = (1 - data.discount) * (1 - truncation)
+
+  # Append terminal values to get [v1, ..., v_t+1]
+  values_t_plus_1 = jnp.concatenate(
+      [values[1:], jnp.expand_dims(terminal_value, 0)], axis=0)
+
+  # compute target values
+  if td_lambda:
+
+    def compute_v_st(carry, target_t):
+      Ai, Bi, lam = carry
+      reward, vtp1, termination = target_t
+
+      reward = reward * termination
+
+      lam = lam * lambda_ * (1 - termination) + termination
+      Ai = (1 - termination) * (lam * discounting * Ai + discounting * vtp1 + (1. - lam) / (1. - lambda_) * reward)
+      Bi = discounting * (vtp1 * termination + Bi * (1.0 - termination)) + reward
+      vs = (1.0 - lambda_) * Ai + lam * Bi
+
+      return (Ai, Bi, lam), (vs)
+
+    Ai = jnp.ones_like(terminal_value)
+    Bi = jnp.zeros_like(terminal_value)
+    lam = jnp.ones_like(terminal_value)
+    (_, _, _), (vs) = jax.lax.scan(compute_v_st, (Ai, Bi, lam),
+                                   (rewards, values_t_plus_1, termination),
+                                   length=int(termination.shape[0]),
+                                   reverse=True)
+
+  else:
+    vs = rewards + discounting * values_t_plus_1
+
+  target_values = jax.lax.stop_gradient(vs)
+
+  v_loss = jnp.mean((target_values - values) ** 2)
+
+  total_loss = v_loss
+  return total_loss, {
+      'total_loss': total_loss,
+      'policy_loss': 0,
+      'v_loss': v_loss,
+      'entropy_loss': 0
+  }
diff --git a/brax/training/agents/shac/networks.py b/brax/training/agents/shac/networks.py
new file mode 100644
index 000000000..bd30d702d
--- /dev/null
+++ b/brax/training/agents/shac/networks.py
@@ -0,0 +1,92 @@
+# Copyright 2022 The Brax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SHAC networks."""
+
+from typing import Sequence, Tuple
+
+from brax.training import distribution
+from brax.training import networks
+from brax.training import types
+from brax.training.types import PRNGKey
+import flax
+from flax import linen
+
+
+@flax.struct.dataclass
+class SHACNetworks:
+  policy_network: networks.FeedForwardNetwork
+  value_network: networks.FeedForwardNetwork
+  parametric_action_distribution: distribution.ParametricDistribution
+
+
+def make_inference_fn(shac_networks: SHACNetworks):
+  """Creates params and inference function for the SHAC agent."""
+
+  def make_policy(params: types.PolicyParams,
+                  deterministic: bool = False) -> types.Policy:
+    policy_network = shac_networks.policy_network
+    parametric_action_distribution = shac_networks.parametric_action_distribution
+
+    def policy(observations: types.Observation,
+               key_sample: PRNGKey) -> Tuple[types.Action, types.Extra]:
+      logits = policy_network.apply(*params, observations)
+      if deterministic:
+        return shac_networks.parametric_action_distribution.mode(logits), {}
+      raw_actions = parametric_action_distribution.sample_no_postprocessing(
+          logits, key_sample)
+      log_prob = parametric_action_distribution.log_prob(logits, raw_actions)
+      postprocessed_actions = parametric_action_distribution.postprocess(
+          raw_actions)
+      return postprocessed_actions, {
+          'log_prob': log_prob,
+          'raw_action': raw_actions
+      }
+
+
+    return policy
+
+  return make_policy
+
+
+def make_shac_networks(
+    observation_size: int,
+    action_size: int,
+    preprocess_observations_fn: types.PreprocessObservationFn = types
+    .identity_observation_preprocessor,
+    policy_hidden_layer_sizes: Sequence[int] = (32,) * 4,
+    value_hidden_layer_sizes: Sequence[int] = (256,) * 5,
+    activation: networks.ActivationFn = linen.elu,
+    layer_norm: bool = True) -> SHACNetworks:
+  """Make SHAC networks with preprocessor."""
+  parametric_action_distribution = distribution.NormalTanhDistribution(
+      event_size=action_size)
+  policy_network = networks.make_policy_network(
+      parametric_action_distribution.param_size,
+      observation_size,
+      preprocess_observations_fn=preprocess_observations_fn,
+      hidden_layer_sizes=policy_hidden_layer_sizes,
+      activation=activation,
+      layer_norm=layer_norm)
+  value_network = networks.make_value_network(
+      observation_size,
+      preprocess_observations_fn=preprocess_observations_fn,
+      hidden_layer_sizes=value_hidden_layer_sizes,
+      activation=activation,
+      layer_norm=layer_norm)
+
+  return SHACNetworks(
+      policy_network=policy_network,
+      value_network=value_network,
+      parametric_action_distribution=parametric_action_distribution)
diff --git a/brax/training/agents/shac/train.py b/brax/training/agents/shac/train.py
new file mode 100644
index 000000000..a8bfeafd4
--- /dev/null
+++ b/brax/training/agents/shac/train.py
@@ -0,0 +1,384 @@
+# Copyright 2022 The Brax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Short-Horizon Actor Critic.
+
+See: https://arxiv.org/pdf/2204.07137.pdf
+and https://github.com/NVlabs/DiffRL/blob/main/algorithms/shac.py
+"""
+
+import functools
+import time
+from typing import Callable, Optional, Tuple
+
+from absl import logging
+from brax import envs
+from brax.envs import wrappers
+from brax.training import acting
+from brax.training import gradients
+from brax.training import pmap
+from brax.training import types
+from brax.training.acme import running_statistics
+from brax.training.acme import specs
+from brax.training.agents.shac import losses as shac_losses
+from brax.training.agents.shac import networks as shac_networks
+from brax.training.types import Params
+from brax.training.types import PRNGKey
+import flax
+import jax
+import jax.numpy as jnp
+import optax
+
+InferenceParams = Tuple[running_statistics.NestedMeanStd, Params]
+Metrics = types.Metrics
+
+_PMAP_AXIS_NAME = 'i'
+
+
+@flax.struct.dataclass
+class TrainingState:
+  """Contains training state for the learner."""
+  policy_optimizer_state: optax.OptState
+  policy_params: Params
+  value_optimizer_state: optax.OptState
+  value_params: Params
+  target_value_params: Params
+  normalizer_params: running_statistics.RunningStatisticsState
+  env_steps: jnp.ndarray
+
+
+def _unpmap(v):
+  return jax.tree_util.tree_map(lambda x: x[0], v)
+
+
+def train(environment: envs.Env,
+          num_timesteps: int,
+          episode_length: int,
+          action_repeat: int = 1,
+          num_envs: int = 1,
+          max_devices_per_host: Optional[int] = None,
+          num_eval_envs: int = 128,
+          actor_learning_rate: float = 1e-3,
+          critic_learning_rate: float = 1e-4,
+          entropy_cost: float = 1e-4,
+          discounting: float = 0.9,
+          seed: int = 0,
+          unroll_length: int = 10,
+          batch_size: int = 32,
+          num_minibatches: int = 16,
+          num_updates_per_batch: int = 2,
+          num_evals: int = 1,
+          normalize_observations: bool = False,
+          reward_scaling: float = 1.,
+          tau: float = 0.005,  # this is 1 - alpha from the original paper
+          lambda_: float = .95,
+          td_lambda: bool = True,
+          deterministic_eval: bool = False,
+          network_factory: types.NetworkFactory[
+              shac_networks.SHACNetworks] = shac_networks.make_shac_networks,
+          progress_fn: Callable[[int, Metrics], None] = lambda *args: None,
+          eval_env: Optional[envs.Env] = None):
+  """SHAC training."""
+  assert batch_size * num_minibatches % num_envs == 0
+  xt = time.time()
+
+  process_count = jax.process_count()
+  process_id = jax.process_index()
+  local_device_count = jax.local_device_count()
+  local_devices_to_use = local_device_count
+  if max_devices_per_host:
+    local_devices_to_use = min(local_devices_to_use, max_devices_per_host)
+  logging.info(
+      'Device count: %d, process count: %d (id %d), local device count: %d, '
+      'devices to be used count: %d', jax.device_count(), process_count,
+      process_id, local_device_count, local_devices_to_use)
+  device_count = local_devices_to_use * process_count
+
+  # The number of environment steps executed for every training step.
+  env_step_per_training_step = (
+      batch_size * unroll_length * num_minibatches * action_repeat)
+  num_evals_after_init = max(num_evals - 1, 1)
+  # The number of training_step calls per training_epoch call.
+  # equal to ceil(num_timesteps / (num_evals * env_step_per_training_step))
+  num_training_steps_per_epoch = -(
+      -num_timesteps // (num_evals_after_init * env_step_per_training_step))
+
+  assert num_envs % device_count == 0
+  env = environment
+
+  env = wrappers.wrap_for_training(
+      env, episode_length=episode_length, action_repeat=action_repeat)
+
+  reset_fn = jax.jit(jax.vmap(env.reset))
+
+  normalize = lambda x, y: x
+  if normalize_observations:
+    normalize = running_statistics.normalize
+  shac_network = network_factory(
+      env.observation_size,
+      env.action_size,
+      preprocess_observations_fn=normalize)
+  make_policy = shac_networks.make_inference_fn(shac_network)
+
+  policy_optimizer = optax.chain(
+      optax.clip(1.0),
+      optax.adam(learning_rate=actor_learning_rate, b1=0.7, b2=0.95)
+  )
+  value_optimizer = optax.chain(
+      optax.clip(1.0),
+      optax.adam(learning_rate=critic_learning_rate, b1=0.7, b2=0.95)
+  )
+
+  value_loss_fn = functools.partial(
+      shac_losses.compute_shac_critic_loss,
+      shac_network=shac_network,
+      discounting=discounting,
+      reward_scaling=reward_scaling,
+      lambda_=lambda_,
+      td_lambda=td_lambda)
+
+  value_gradient_update_fn = gradients.gradient_update_fn(
+      value_loss_fn, value_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)
+
+  policy_loss_fn = functools.partial(
+      shac_losses.compute_shac_policy_loss,
+      shac_network=shac_network,
+      entropy_cost=entropy_cost,
+      discounting=discounting,
+      reward_scaling=reward_scaling)
+
+  def rollout_loss_fn(policy_params, value_params, normalizer_params, state, key):
+    policy = make_policy((normalizer_params, policy_params))
+
+    key, key_loss = jax.random.split(key)
+
+    def f(carry, unused_t):
+      current_state, current_key = carry
+      current_key, next_key = jax.random.split(current_key)
+      next_state, data = acting.generate_unroll(
+          env,
+          current_state,
+          policy,
+          current_key,
+          unroll_length,
+          extra_fields=('truncation',))
+      return (next_state, next_key), data
+
+    (state, _), data = jax.lax.scan(
+        f, (state, key), (),
+        length=batch_size * num_minibatches // num_envs)
+
+    # Has leading dimensions (batch_size * num_minibatches, unroll_length)
+    data = jax.tree_util.tree_map(lambda x: jnp.swapaxes(x, 1, 2), data)
+    data = jax.tree_util.tree_map(lambda x: jnp.reshape(x, (-1,) + x.shape[2:]),
+                                  data)
+    assert data.discount.shape[1:] == (unroll_length,)
+
+    loss, metrics = policy_loss_fn(policy_params, value_params,
+                                   normalizer_params, data, key_loss)
+
+    return loss, (state, data, metrics)
+
+  policy_gradient_update_fn = gradients.gradient_update_fn(
+      rollout_loss_fn, policy_optimizer, pmap_axis_name=_PMAP_AXIS_NAME, has_aux=True)
+  policy_gradient_update_fn = jax.jit(policy_gradient_update_fn)
+
+  def minibatch_step(
+      carry, data: types.Transition,
+      normalizer_params: running_statistics.RunningStatisticsState):
+    optimizer_state, params, key = carry
+    key, key_loss = jax.random.split(key)
+    (_, metrics), params, optimizer_state = value_gradient_update_fn(
+        params,
+        normalizer_params,
+        data,
+        optimizer_state=optimizer_state)
+
+    return (optimizer_state, params, key), metrics
+
+  def critic_sgd_step(carry, unused_t, data: types.Transition,
+                      normalizer_params: running_statistics.RunningStatisticsState):
+    optimizer_state, params, key = carry
+    key, key_perm, key_grad = jax.random.split(key, 3)
+
+    def convert_data(x: jnp.ndarray):
+      x = jax.random.permutation(key_perm, x)
+      x = jnp.reshape(x, (num_minibatches, -1) + x.shape[1:])
+      return x
+
+    shuffled_data = jax.tree_util.tree_map(convert_data, data)
+    (optimizer_state, params, _), metrics = jax.lax.scan(
+        functools.partial(minibatch_step, normalizer_params=normalizer_params),
+        (optimizer_state, params, key_grad),
+        shuffled_data,
+        length=num_minibatches)
+    return (optimizer_state, params, key), metrics
+
+  def training_step(
+      carry: Tuple[TrainingState, envs.State, PRNGKey],
+      unused_t) -> Tuple[Tuple[TrainingState, envs.State, PRNGKey], Metrics]:
+    training_state, state, key = carry
+    key_sgd, key_generate_unroll, new_key = jax.random.split(key, 3)
+
+    (policy_loss, (state, data, policy_metrics)), policy_params, policy_optimizer_state = policy_gradient_update_fn(
+        training_state.policy_params, training_state.target_value_params,
+        training_state.normalizer_params, state, key_generate_unroll,
+        optimizer_state=training_state.policy_optimizer_state)
+
+    # Update normalization params and normalize observations.
+    normalizer_params = running_statistics.update(
+        training_state.normalizer_params,
+        data.observation,
+        pmap_axis_name=_PMAP_AXIS_NAME)
+
+    (value_optimizer_state, value_params, _), metrics = jax.lax.scan(
+        functools.partial(
+            critic_sgd_step, data=data, normalizer_params=normalizer_params),
+        (training_state.value_optimizer_state, training_state.value_params, key_sgd), (),
+        length=num_updates_per_batch)
+
+    target_value_params = jax.tree_util.tree_map(
+        lambda x, y: x * (1 - tau) + y * tau, training_state.target_value_params,
+        value_params)
+
+    metrics.update(policy_metrics)
+
+    new_training_state = TrainingState(
+        policy_optimizer_state=policy_optimizer_state,
+        policy_params=policy_params,
+        value_optimizer_state=value_optimizer_state,
+        value_params=value_params,
+        target_value_params=target_value_params,
+        normalizer_params=normalizer_params,
+        env_steps=training_state.env_steps + env_step_per_training_step)
+    return (new_training_state, state, new_key), metrics
+
+  def training_epoch(training_state: TrainingState, state: envs.State,
+                     key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:
+    (training_state, state, _), loss_metrics = jax.lax.scan(
+        training_step, (training_state, state, key), (),
+        length=num_training_steps_per_epoch)
+    loss_metrics = jax.tree_util.tree_map(jnp.mean, loss_metrics)
+    return training_state, state, loss_metrics
+
+  training_epoch = jax.pmap(training_epoch, axis_name=_PMAP_AXIS_NAME)
+
+  # Note that this is NOT a pure jittable method.
+  def training_epoch_with_timing(
+      training_state: TrainingState, env_state: envs.State,
+      key: PRNGKey) -> Tuple[TrainingState, envs.State, Metrics]:
+    nonlocal training_walltime
+    t = time.time()
+    (training_state, env_state,
+     metrics) = training_epoch(training_state, env_state, key)
+    metrics = jax.tree_util.tree_map(jnp.mean, metrics)
+    jax.tree_util.tree_map(lambda x: x.block_until_ready(), metrics)
+
+    epoch_training_time = time.time() - t
+    training_walltime += epoch_training_time
+    sps = (num_training_steps_per_epoch *
+           env_step_per_training_step) / epoch_training_time
+    metrics = {
+        'training/sps': sps,
+        'training/walltime': training_walltime,
+        **{f'training/{name}': value for name, value in metrics.items()}
+    }
+    return training_state, env_state, metrics
+
+  key = jax.random.PRNGKey(seed)
+  global_key, local_key = jax.random.split(key)
+  del key
+  local_key = jax.random.fold_in(local_key, process_id)
+  local_key, key_env, eval_key = jax.random.split(local_key, 3)
+  # key_networks should be global, so that networks are initialized the same
+  # way for different processes.
+  key_policy, key_value = jax.random.split(global_key)
+  del global_key
+
+  policy_init_params = shac_network.policy_network.init(key_policy)
+  value_init_params = shac_network.value_network.init(key_value)
+  training_state = TrainingState(
+      policy_optimizer_state=policy_optimizer.init(policy_init_params),
+      policy_params=policy_init_params,
+      value_optimizer_state=value_optimizer.init(value_init_params),
+      value_params=value_init_params,
+      target_value_params=value_init_params,
+      normalizer_params=running_statistics.init_state(
+          specs.Array((env.observation_size,), jnp.float32)),
+      env_steps=0)
+  training_state = jax.device_put_replicated(
+      training_state,
+      jax.local_devices()[:local_devices_to_use])
+
+  key_envs = jax.random.split(key_env, num_envs // process_count)
+  key_envs = jnp.reshape(key_envs,
+                         (local_devices_to_use, -1) + key_envs.shape[1:])
+  env_state = reset_fn(key_envs)
+
+  if not eval_env:
+    eval_env = env
+  else:
+    eval_env = wrappers.wrap_for_training(
+        eval_env, episode_length=episode_length, action_repeat=action_repeat)
+
+  evaluator = acting.Evaluator(
+      eval_env,
+      functools.partial(make_policy, deterministic=deterministic_eval),
+      num_eval_envs=num_eval_envs,
+      episode_length=episode_length,
+      action_repeat=action_repeat,
+      key=eval_key)
+
+  # Run initial eval
+  if process_id == 0 and num_evals > 1:
+    metrics = evaluator.run_evaluation(
+        _unpmap(
+            (training_state.normalizer_params, training_state.policy_params)),
+        training_metrics={})
+    logging.info(metrics)
+    progress_fn(0, metrics)
+
+  training_walltime = 0
+  current_step = 0
+  for it in range(num_evals_after_init):
+    logging.info('starting iteration %s %s', it, time.time() - xt)
+
+    # optimization
+    epoch_key, local_key = jax.random.split(local_key)
+    epoch_keys = jax.random.split(epoch_key, local_devices_to_use)
+    (training_state, env_state,
+     training_metrics) = training_epoch_with_timing(training_state, env_state,
+                                                    epoch_keys)
+    current_step = int(_unpmap(training_state.env_steps))
+
+    if process_id == 0:
+      # Run evals.
+      metrics = evaluator.run_evaluation(
+          _unpmap(
+              (training_state.normalizer_params, training_state.policy_params)),
+          training_metrics)
+      logging.info(metrics)
+      progress_fn(current_step, metrics)
+
+  total_steps = current_step
+  assert total_steps >= num_timesteps
+
+  # If there were no mistakes, the training_state should still be identical on all
+  # devices.
+  pmap.assert_is_replicated(training_state)
+  params = _unpmap(
+      (training_state.normalizer_params, training_state.policy_params))
+  logging.info('total steps: %s', total_steps)
+  pmap.synchronize_hosts()
+  return (make_policy, params, metrics)
diff --git a/brax/training/agents/shac/train_test.py b/brax/training/agents/shac/train_test.py
new file mode 100644
index 000000000..b3d5c79ad
--- /dev/null
+++ b/brax/training/agents/shac/train_test.py
@@ -0,0 +1,79 @@
+# Copyright 2022 The Brax Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""SHAC tests."""
+import pickle
+
+from absl.testing import absltest
+from absl.testing import parameterized
+from brax import envs
+from brax.training.acme import running_statistics
+from brax.training.agents.shac import networks as shac_networks
+from brax.training.agents.shac import train as shac
+import jax
+
+
+class SHACTest(parameterized.TestCase):
+  """Tests for SHAC module."""
+
+
+  def testTrain(self):
+    """Test SHAC with a simple env."""
+    _, _, metrics = shac.train(
+        envs.get_environment('fast_differentiable'),
+        num_timesteps=2**15,
+        episode_length=128,
+        num_envs=64,
+        actor_learning_rate=1.5e-2,
+        critic_learning_rate=1e-3,
+        entropy_cost=1e-2,
+        discounting=0.95,
+        unroll_length=10,
+        batch_size=64,
+        num_minibatches=8,
+        num_updates_per_batch=1,
+        normalize_observations=True,
+        seed=2,
+        reward_scaling=10)
+    self.assertGreater(metrics['eval/episode_reward'], 135)
+
+  @parameterized.parameters(True, False)
+  def testNetworkEncoding(self, normalize_observations):
+    env = envs.get_environment('fast')
+    original_inference, params, _ = shac.train(
+        env,
+        num_timesteps=128,
+        episode_length=128,
+        num_envs=128,
+        normalize_observations=normalize_observations)
+    normalize_fn = lambda x, y: x
+    if normalize_observations:
+      normalize_fn = running_statistics.normalize
+    shac_network = shac_networks.make_shac_networks(env.observation_size,
+                                                    env.action_size, normalize_fn)
+    inference = shac_networks.make_inference_fn(shac_network)
+    byte_encoding = pickle.dumps(params)
+    decoded_params = pickle.loads(byte_encoding)
+
+    # Compute one action.
+    state = env.reset(jax.random.PRNGKey(0))
+    original_action = original_inference(decoded_params)(
+        state.obs, jax.random.PRNGKey(0))[0]
+    action = inference(decoded_params)(state.obs, jax.random.PRNGKey(0))[0]
+    self.assertSequenceEqual(original_action, action)
+    env.step(state, action)
+
+
+if __name__ == '__main__':
+  absltest.main()
diff --git a/brax/training/networks.py b/brax/training/networks.py
index 5856360a8..404e73a9e 100644
--- a/brax/training/networks.py
+++ b/brax/training/networks.py
@@ -41,6 +41,7 @@ class MLP(linen.Module):
   kernel_init: Initializer = jax.nn.initializers.lecun_uniform()
   activate_final: bool = False
   bias: bool = True
+  layer_norm: bool = False
 
   @linen.compact
   def __call__(self, data: jnp.ndarray):
@@ -54,6 +55,8 @@ def __call__(self, data: jnp.ndarray):
               hidden)
       if i != len(self.layer_sizes) - 1 or self.activate_final:
         hidden = self.activation(hidden)
+        if self.layer_norm:
+          hidden = linen.LayerNorm()(hidden)
     return hidden
 
 
@@ -86,11 +89,13 @@ def make_policy_network(
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
-    activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
+    activation: ActivationFn = linen.relu,
+    layer_norm: bool = False) -> FeedForwardNetwork:
   """Creates a policy network."""
   policy_module = MLP(
       layer_sizes=list(hidden_layer_sizes) + [param_size],
       activation=activation,
+      layer_norm=layer_norm,
       kernel_init=jax.nn.initializers.lecun_uniform())
 
   def apply(processor_params, policy_params, obs):
@@ -107,11 +112,13 @@ def make_value_network(
     preprocess_observations_fn: types.PreprocessObservationFn = types
     .identity_observation_preprocessor,
     hidden_layer_sizes: Sequence[int] = (256, 256),
-    activation: ActivationFn = linen.relu) -> FeedForwardNetwork:
+    activation: ActivationFn = linen.relu,
+    layer_norm: bool = False) -> FeedForwardNetwork:
   """Creates a policy network."""
   value_module = MLP(
       layer_sizes=list(hidden_layer_sizes) + [1],
       activation=activation,
+      layer_norm=layer_norm,
      kernel_init=jax.nn.initializers.lecun_uniform())
 
   def apply(processor_params, policy_params, obs):
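For reference, a minimal usage sketch of the new SHAC trainer, mirroring the hyperparameters exercised in brax/training/agents/shac/train_test.py above; the num_evals value and the progress callback below are illustrative assumptions, not part of this change:

    from brax import envs
    from brax.training.agents.shac import train as shac

    # Train a short-horizon actor critic policy on the 'fast_differentiable'
    # toy env registered by this change.
    make_policy, params, metrics = shac.train(
        envs.get_environment('fast_differentiable'),
        num_timesteps=2**15,
        episode_length=128,
        num_envs=64,
        actor_learning_rate=1.5e-2,
        critic_learning_rate=1e-3,
        entropy_cost=1e-2,
        discounting=0.95,
        unroll_length=10,
        batch_size=64,
        num_minibatches=8,
        num_updates_per_batch=1,
        normalize_observations=True,
        seed=2,
        reward_scaling=10,
        num_evals=4,  # assumption: any value > 1 enables the initial and periodic evals
        progress_fn=lambda step, m: print(step, m['eval/episode_reward']))

    # make_policy(params) builds a policy function, as in
    # shac_networks.make_inference_fn; metrics holds the final evaluation stats.
    print('final eval reward:', metrics['eval/episode_reward'])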