[RLlib] Attention Net prep PR #3. #12450
```diff
@@ -193,13 +193,22 @@ def postprocess_ppo_gae(
         last_r = 0.0
     # Trajectory has been truncated -> last r=VF estimate of last obs.
     else:
-        next_state = []
-        for i in range(policy.num_state_tensors()):
-            next_state.append(sample_batch["state_out_{}".format(i)][-1])
-        last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
-                               sample_batch[SampleBatch.ACTIONS][-1],
-                               sample_batch[SampleBatch.REWARDS][-1],
-                               *next_state)
+        # Input dict is provided to us automatically via the Model's
+        # requirements. It's a single-timestep (last one in trajectory)
+        # input_dict.
+        if policy.config["_use_trajectory_view_api"]:
+            # Create an input dict according to the Model's requirements.
+            input_dict = policy.model.get_input_dict(sample_batch, index=-1)
+            last_r = policy._value(**input_dict)
+        # TODO: (sven) Remove once trajectory view API is all-algo default.
+        else:
+            next_state = []
+            for i in range(policy.num_state_tensors()):
+                next_state.append(sample_batch["state_out_{}".format(i)][-1])
+            last_r = policy._value(sample_batch[SampleBatch.NEXT_OBS][-1],
+                                   sample_batch[SampleBatch.ACTIONS][-1],
+                                   sample_batch[SampleBatch.REWARDS][-1],
+                                   *next_state)

     # Adds the policy logits, VF preds, and advantages to the batch,
     # using GAE ("generalized advantage estimation") or not.
```

Contributor (Author): Ask the Model for the input dict (from the given SampleBatch) at index=-1.

Contributor: Is the comment inaccurate now?

Contributor (Author): No, this is still valid.
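For readers unfamiliar with the trajectory view API, the point of `policy.model.get_input_dict(sample_batch, index=-1)` is that the model's own view requirements, not hard-coded policy code, decide which columns (and at which time offsets) end up in the value call. Below is a minimal, self-contained sketch of that idea; `VIEW_SPECS` and `build_input_dict` are hypothetical stand-ins for illustration, not RLlib's actual classes or signatures.

```python
import numpy as np

# Hypothetical stand-in for a model's view requirements:
# model input name -> (source column in the sample batch, time shift).
VIEW_SPECS = {
    "obs": ("obs", 0),
    "prev_actions": ("actions", -1),
    "prev_rewards": ("rewards", -1),
    "state_in_0": ("state_out_0", -1),
}


def build_input_dict(sample_batch, index=-1, view_specs=VIEW_SPECS):
    """Pull a single-timestep input dict out of a trajectory batch.

    index=-1 selects the last timestep, which is what PPO's postprocessing
    needs to bootstrap the value of a truncated rollout.
    """
    length = len(next(iter(sample_batch.values())))
    t = index % length  # Normalize negative indices (-1 -> last step).
    input_dict = {}
    for input_name, (col, shift) in view_specs.items():
        # Apply the requirement's time shift, clipping at trajectory start.
        src_t = max(t + shift, 0)
        # Keep a leading batch dim of 1: models expect batched inputs.
        input_dict[input_name] = sample_batch[col][src_t:src_t + 1]
    return input_dict


# Toy trajectory of length 4.
batch = {
    "obs": np.arange(8.0).reshape(4, 2),
    "actions": np.array([0, 1, 1, 0]),
    "rewards": np.array([0.1, 0.2, 0.3, 0.4]),
    "state_out_0": np.zeros((4, 3)),
}
print(build_input_dict(batch, index=-1))
```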
```diff
@@ -208,7 +217,9 @@ def postprocess_ppo_gae(
         last_r,
         policy.config["gamma"],
         policy.config["lambda"],
-        use_gae=policy.config["use_gae"])
+        use_gae=policy.config["use_gae"],
+        use_critic=policy.config.get("use_critic", True))

     return batch
```
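This hunk simply forwards a `use_critic` flag (defaulting to True) into the advantage computation. As a rough illustration of what such flags typically control, here is a simplified, self-contained sketch of GAE versus plain discounted returns, with and without a value baseline; it is not RLlib's exact `compute_advantages` implementation.

```python
import numpy as np


def discount_cumsum(x, gamma):
    """Discounted cumulative sum along a trajectory (computed back to front)."""
    out = np.zeros_like(x, dtype=np.float64)
    running = 0.0
    for t in reversed(range(len(x))):
        running = x[t] + gamma * running
        out[t] = running
    return out


def toy_advantages(rewards, vf_preds, last_r, gamma=0.99, lam=0.95,
                   use_gae=True, use_critic=True):
    """Simplified advantage estimation (a sketch, not RLlib's exact code)."""
    if use_gae:
        # GAE: discounted sum of TD residuals, bootstrapped with last_r.
        vpred_t = np.append(vf_preds, last_r)
        deltas = rewards + gamma * vpred_t[1:] - vpred_t[:-1]
        return discount_cumsum(deltas, gamma * lam)
    # Plain discounted returns, bootstrapped with last_r.
    returns = discount_cumsum(np.append(rewards, last_r), gamma)[:-1]
    # With a critic, subtract the value baseline; without one
    # (use_critic=False), the raw returns serve as advantages.
    return returns - vf_preds if use_critic else returns


rew = np.array([1.0, 0.0, 1.0])
vf = np.array([0.5, 0.4, 0.6])
print(toy_advantages(rew, vf, last_r=0.7))                    # GAE
print(toy_advantages(rew, vf, last_r=0.7, use_gae=False))     # returns - baseline
print(toy_advantages(rew, vf, last_r=0.7, use_gae=False,
                     use_critic=False))                       # raw returns
```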
```diff
@@ -292,25 +303,40 @@ def __init__(self, obs_space, action_space, config):
         # observation.
         if config["use_gae"]:

-            @make_tf_callable(self.get_session())
-            def value(ob, prev_action, prev_reward, *state):
-                model_out, _ = self.model({
-                    SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
-                    SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
-                        [prev_action]),
-                    SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
-                        [prev_reward]),
-                    "is_training": tf.convert_to_tensor([False]),
-                }, [tf.convert_to_tensor([s]) for s in state],
-                   tf.convert_to_tensor([1]))
-                # [0] = remove the batch dim.
-                return self.model.value_function()[0]
+            # Input dict is provided to us automatically via the Model's
+            # requirements. It's a single-timestep (last one in trajectory)
+            # input_dict.
+            if config["_use_trajectory_view_api"]:
+
+                @make_tf_callable(self.get_session())
+                def value(**input_dict):
+                    model_out, _ = self.model.from_batch(
+                        input_dict, is_training=False)
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]
+
+            # TODO: (sven) Remove once trajectory view API is all-algo default.
+            else:
+
+                @make_tf_callable(self.get_session())
+                def value(ob, prev_action, prev_reward, *state):
+                    model_out, _ = self.model({
+                        SampleBatch.CUR_OBS: tf.convert_to_tensor([ob]),
+                        SampleBatch.PREV_ACTIONS: tf.convert_to_tensor(
+                            [prev_action]),
+                        SampleBatch.PREV_REWARDS: tf.convert_to_tensor(
+                            [prev_reward]),
+                        "is_training": tf.convert_to_tensor([False]),
+                    }, [tf.convert_to_tensor([s]) for s in state],
+                       tf.convert_to_tensor([1]))
+                    # [0] = remove the batch dim.
+                    return self.model.value_function()[0]

         # When not doing GAE, we do not require the value function's output.
         else:

             @make_tf_callable(self.get_session())
-            def value(ob, prev_action, prev_reward, *state):
+            def value(*args, **kwargs):
                 return tf.constant(0.0)

         self._value = value
```

Contributor: What is the reason for this change? Could we get the previous code to work without adding this if branch?

Contributor (Author): See above: we want the input dict to be determined by what the model needs as inputs.
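The motivation given in the thread ("we want the input dict to be determined by what the model needs as inputs") is essentially about decoupling the call site from any particular model's input signature. Here is a toy, framework-free sketch of the difference between the old positional `value(ob, prev_action, prev_reward, *state)` convention and the new `value(**input_dict)` convention; the class and attribute names below are made up for illustration.

```python
# Hypothetical toy models with different input needs. With a positional
# signature, the caller must know exactly which tensors each model wants;
# with an input dict driven by the model itself, the call site stays generic.

class FrameStackModel:
    input_names = ["obs", "prev_n_obs"]

    def value(self, obs, prev_n_obs):
        return float(sum(obs) + 0.1 * sum(sum(o) for o in prev_n_obs))


class RecurrentModel:
    input_names = ["obs", "state_in_0"]

    def value(self, obs, state_in_0):
        return float(sum(obs) + sum(state_in_0))


def bootstrap_value(model, sample_batch):
    # The model declares what it needs; the policy just forwards it.
    input_dict = {name: sample_batch[name][-1] for name in model.input_names}
    return model.value(**input_dict)


batch = {
    "obs": [[0.1, 0.2], [0.3, 0.4]],
    "prev_n_obs": [[[0.0, 0.0]], [[0.1, 0.2]]],
    "state_in_0": [[0.0, 0.0], [0.5, 0.5]],
}
print(bootstrap_value(FrameStackModel(), batch))  # 0.73
print(bootstrap_value(RecurrentModel(), batch))   # 1.7
```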
A second file is changed as well (imports and the RNN model's initial state):

```diff
@@ -1,9 +1,6 @@
-from gym.spaces import Box
-
 from ray.rllib.models.modelv2 import ModelV2
 from ray.rllib.models.preprocessors import get_preprocessor
 from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
-from ray.rllib.policy.view_requirement import ViewRequirement
 from ray.rllib.utils.annotations import override
 from ray.rllib.utils.framework import try_import_torch
```
```diff
@@ -25,17 +22,13 @@ def __init__(self, obs_space, action_space, num_outputs, model_config,
         self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
         self.n_agents = model_config["n_agents"]

-        self.inference_view_requirements.update({
-            "state_in_0": ViewRequirement(
-                "state_out_0",
-                data_rel_pos=-1,
-                space=Box(-1.0, 1.0, (self.n_agents, self.rnn_hidden_dim)))
-        })
-
     @override(ModelV2)
     def get_initial_state(self):
         # Place hidden states on same device as model.
-        return [self.fc1.weight.new(1, self.rnn_hidden_dim).zero_().squeeze(0)]
+        return [
+            self.fc1.weight.new(self.n_agents,
+                                self.rnn_hidden_dim).zero_().squeeze(0)
+        ]

     @override(ModelV2)
     def forward(self, input_dict, hidden_state, seq_lens):
```

Contributor (Author): Make sure this stays backward-compatible even w/o specifying this here.
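The `get_initial_state()` change allocates the zero hidden state per agent, i.e. with shape `(n_agents, rnn_hidden_dim)` instead of `(rnn_hidden_dim,)`. A small PyTorch sketch of the shape difference, with illustrative sizes:

```python
import torch.nn as nn

# Illustrative sizes only.
n_agents, rnn_hidden_dim, obs_size = 3, 8, 4
fc1 = nn.Linear(obs_size, rnn_hidden_dim)

# Old behavior: squeeze(0) drops the leading 1 -> shape (rnn_hidden_dim,).
old_state = fc1.weight.new(1, rnn_hidden_dim).zero_().squeeze(0)

# New behavior: squeeze(0) is a no-op for n_agents > 1
# -> shape (n_agents, rnn_hidden_dim), one hidden state row per agent.
new_state = fc1.weight.new(n_agents, rnn_hidden_dim).zero_().squeeze(0)

print(old_state.shape)  # torch.Size([8])
print(new_state.shape)  # torch.Size([3, 8])
# Using weight.new(...) keeps the state on the same device/dtype as the model.
```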
Review comment: Re-use PPO's function. We should probably do the same for A3C and PG.