[RLlib] Attention Net prep PR #3. #12450
Conversation
ericl left a comment
@sven1977 main question is the input dict option. Maybe I'm missing something, but it doesn't seem like a necessary change. Correct me if this is wrong.
rllib/agents/ppo/ppo_tf_policy.py (outdated)
# (abs_pos=-1). It's only used if the trajectory is not finished yet and we
# have to rely on the last value function output as a reward estimation.
return {
    "_value_input_dict": ViewRequirement(is_input_dict=True, abs_pos=-1)
Wondering if there are better names for this
# Can specify either abs_index or shift
abs_index=-1
shift=x
I think the notion of "index" is clearer in Python (-1 index means end). Also, we get to keep shift.
Fair enough, I'll rename these two:
abs_pos -> index
data_rel_pos -> shift
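For illustration, a minimal sketch of what a `ViewRequirement` with the renamed fields might look like (field names and defaults here are assumptions based on this thread, not the final RLlib signature):

```python
# Illustrative stand-in for rllib/policy/view_requirement.py; the real class
# may have additional fields and different defaults.
class ViewRequirement:
    def __init__(self, data_col=None, space=None, index=None, shift=0,
                 used_for_training=True):
        # SampleBatch column to read the data from (defaults to the view col).
        self.data_col = data_col
        # gym.Space of the data; used e.g. to create dummy/padding values.
        self.space = space
        # Absolute position within the trajectory, e.g. -1 == last timestep
        # (previously called `abs_pos`).
        self.index = index
        # Offset relative to the current timestep, e.g. -1 == previous step
        # (previously called `data_rel_pos`).
        self.shift = shift
        # If False, the column is not copied into the final train batch.
        self.used_for_training = used_for_training
```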
On the input_dict:
It's not strictly necessary for this PR, but since this is a preparatory PR (meant to make the attention-net PR smaller), I decided to add it here already. The attention-net PR needs this feature so that we don't have to boilerplate/hardcode the attention logic inside e.g. PPO's postprocessing fn (that function shouldn't have to worry about whether the model is an RNN or an attention net; it shouldn't need to know). A rough sketch of the intent follows.
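A hedged sketch of the idea (the `ViewRequirement` kwargs mirror the pre-rename names in the diff above; the postprocessing side and the `compute_last_value` helper are placeholders):

```python
# Only the ViewRequirement kwargs come from the diff above (pre-rename
# names); everything else is a placeholder for illustration.
from ray.rllib.policy.view_requirement import ViewRequirement


def extra_postprocess_view_requirements():
    # "Hand me a ready-made, single-timestep input dict for the last step
    # of the (possibly unfinished) trajectory."
    return {
        "_value_input_dict": ViewRequirement(is_input_dict=True, abs_pos=-1),
    }


def postprocess_fn(policy, sample_batch):
    # The sample collector has already assembled this input dict according
    # to the model's own view requirements (RNN states, attention memory,
    # ...), so this function never branches on the model type.
    input_dict = sample_batch["_value_input_dict"]
    last_r = policy.compute_last_value(input_dict)  # hypothetical helper
    # ... continue with GAE/advantage computation using last_r ...
    return sample_batch
```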
src/ray/raylet/node_manager.cc (outdated)
object_manager_.FreeObjects(object_ids,
                            /*local_only=*/false);
},
on_objects_spilled),
Revert this change?
rllib/policy/view_requirement.py (outdated)
used_for_training (bool): Whether the data will be used for
    training. If False, the column will not be copied into the
    final train batch.
is_input_dict (bool): Whether the "view" of this requirement is an
This seems odd, what is the reason we need it? Are there any cleaner alternatives?
This option will be necessary for attention nets. We shouldn't have attention net or RNN-specific code in the postprocessing fn (e.g. of PPO). Instead, it's like saying: "I need an input_dict here, provide one given the model's requirements".
return self.model.value_function()[0]
# Input dict is provided to us automatically via the policy-defined
# "view". It's a single-timestep (last one in trajectory)
# input_dict.
What is the reason for this change? Could we get the previous code to work without adding this if branch?
see above: We want the input dict to be determined by what the model needs as inputs.
    use_critic=policy.config["use_critic"])
else:
    batch = sample_batch
sample_batch = postprocess_ppo_gae(policy, sample_batch,
Re-use PPO's function.
We should probably do the same for A3C and PG.
    sample_batch[SampleBatch.ACTIONS][-1],
    sample_batch[SampleBatch.REWARDS][-1],
    *next_state)
# Input dict is provided to us automatically via the Model's
Ask the Model for the input dict (from the given SampleBatch) at index=-1.
- In prep for attention nets, which have special requirements here (different from RNNs and from non-recurrent models).
- Removes boilerplate input-dict-creation code.
A hedged sketch of the resulting postprocessing code is below.
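Roughly, the postprocessing side then reduces to something like this (hedged sketch; `get_input_dict` is the experimental helper added in this PR, and the `_value` helper name is assumed):

```python
# Hedged sketch only; exact signatures may differ from the PR.
from ray.rllib.policy.sample_batch import SampleBatch


def last_value_estimate(policy, sample_batch):
    """Bootstrap value for a possibly unfinished trajectory."""
    if sample_batch[SampleBatch.DONES][-1]:
        return 0.0
    # Let the Model assemble whatever a single-step forward pass at the last
    # timestep needs (prev. actions/rewards, RNN states, attention memory,
    # ...), instead of hand-building the input dict here.
    input_dict = policy.model.get_input_dict(sample_batch, index=-1)
    return policy._value(**input_dict)  # helper name is an assumption
```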
Is the comment inaccurate now?
No, this is still valid.
self.fc2 = nn.Linear(self.rnn_hidden_dim, num_outputs)
self.n_agents = model_config["n_agents"]

self.inference_view_requirements.update({
make sure this stays backward-compatible even w/o specifying this here.
})
self.buffers[SampleBatch.OBS].append(init_obs)
self.buffers[SampleBatch.EPS_ID].append(episode_id)
self.episode_id = episode_id
Don't have to "collect" these. They are always the same for the same agent anyways.
def build(self, view_requirements: Dict[str, ViewRequirement]) -> \
        SampleBatch:
def build(self, view_requirements: Dict[str, ViewRequirement],
Maybe rename this to model_view_requirements for clarity.
Also: Model.inference_view_requirements -> Model.view_requirements.
# Python primitive or dict (e.g. INFOs).
if isinstance(data, (int, float, bool, str, dict)):
    self.buffers[col] = [0 for _ in range(shift)]
    self.buffers[col] = [data for _ in range(shift)]
Important for custom initial state values. Cannot assume always 0 here.
# not view_requirements[view_col].used_for_training:
#     continue
self.buffers[view_col].extend(data)
# 1) If col is not in view_requirements, we must have a direct
I think this solves it:
- If, after postprocessing, some column is not in the view-reqs, we must be dealing with a direct base-Policy child (w/o auto-view-requirement handling) -> leave it as is.
- If we do have it in the view-reqs AND used_for_training is False -> we must have gone through auto-detection, so it's safe to remove it here (this column won't be needed for training).
In code, the rule looks roughly like the sketch below.
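```python
# Illustrative only; names do not match the actual collector code.
def keep_column_in_train_batch(col, view_requirements):
    # 1) Column unknown to the view requirements: we are dealing with a
    #    direct base-Policy child without auto view-requirement handling,
    #    so leave the data alone (keep it).
    if col not in view_requirements:
        return True
    # 2) Column known and marked used_for_training=False: auto-detection
    #    has run, so it is safe to drop the column from the train batch.
    return view_requirements[col].used_for_training
```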
    data_col: view_req.space.sample()
})
data_list.append(buffers[k][data_col][time_indices])
if data_col == SampleBatch.EPS_ID:
same as above: episode_id is always the same within one agent's collector. No need to collect an extra buffer here.
fake_sampler: bool = False,
spaces: Optional[Dict[PolicyID, Tuple[gym.spaces.Space,
                                      gym.spaces.Space]]] = None,
_use_trajectory_view_api: bool = True,
Pass this into RolloutWorker explicitly now (it was derived from policy_config before, which is problematic because that could be a partial config dict).
# inherited from base `Policy` class. At this point here, the Policy
# must have it's Model (if any) defined and ready to output an initial
# state.
for pol in self.policy_map.values():
Do the auto internal state -> view req here after policy has been created. This covers direct child Policies of the base Policy class, which don't have an auto-view-req mechanism.
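A hedged sketch of that step (the method name appears in this PR's diff further below; the surrounding logic here is assumed):

```python
# Hedged sketch; not the actual RolloutWorker code.
def add_state_view_reqs_after_policy_creation(policy_map):
    for pol in policy_map.values():
        # If the Policy (or its Model) returns a non-empty initial state,
        # derive "state_in_0", "state_in_1", ... view requirements from it.
        # This also covers direct base-Policy subclasses that have no auto
        # view-requirement mechanism of their own.
        if pol.get_initial_state():
            pol._update_model_inference_view_requirements_from_init_state()
```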
return self.time_major is True

# TODO: (sven) Experimental method.
def get_input_dict(self, sample_batch,
The Model is able to create an input_dict for a single-step forward pass from an agent's trajectory batch.
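Conceptually, the helper might do something along these lines (hedged sketch; the actual experimental method in this PR may differ):

```python
import numpy as np


def get_input_dict_sketch(model, sample_batch, index=-1):
    """Build a single-timestep input dict from an agent's trajectory batch,
    honoring the model's own view requirements (state cols, shifts, etc.)."""
    input_dict = {}
    for view_col, view_req in model.inference_view_requirements.items():
        data_col = getattr(view_req, "data_col", None) or view_col
        if data_col in sample_batch:
            # Pick the single requested timestep and keep a batch dim of 1.
            input_dict[view_col] = np.expand_dims(
                sample_batch[data_col][index], axis=0)
        else:
            # Fall back to a dummy value sampled from the declared space.
            input_dict[view_col] = np.expand_dims(view_req.space.sample(), 0)
    return input_dict
```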
    action_distribution=action_dist,
    timestep=timestep,
    explore=explore)
if self.config["_use_trajectory_view_api"]:
Not needed.
mo = re.match("state_in_(\d+)", view_col)
if mo is not None:
    input_dict[view_col] = self._state_inputs[int(mo.group(1))]
dummy_batch[view_col] = np.zeros_like(
Better do all these in one call below.
rllib/policy/dynamic_tf_policy.py (outdated)
batch_for_postproc = UsageTrackingDict(sb)
batch_for_postproc.count = sb.count
logger.info("Testing `postprocess_trajectory` w/ dummy batch.")
self.exploration.postprocess_trajectory(
Have to also call the exploration's postprocessing here (it may access fields in the batch that we need to track, e.g. for curiosity).
# Just like torch Policy does.
self._optimizer = optimizers[0] if optimizers else None

self._initialize_loss_from_dummy_batch(
Moved this here for consistency (same behavior as TorchPolicy). Also fixes a problem with curiosity where we do need the optimizer before loss init.
def _update_model_inference_view_requirements_from_init_state(self):
    """Uses this Model's initial state to auto-add necessary ViewReqs.
    """Uses Model's (or this Policy's) init state to add needed ViewReqs.
Make this more robust against Policies that don't have a model, but do return something from get_initial_state().
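The fallback could look roughly like this (hedged sketch; the exact `ViewRequirement` kwargs are assumptions):

```python
# Hedged sketch; ViewRequirement kwargs and naming are assumptions.
from ray.rllib.policy.view_requirement import ViewRequirement


def state_view_reqs_from_init_state(policy):
    # Prefer the Model's initial state, but fall back to the Policy's own
    # get_initial_state() for Policies that have no Model at all.
    model = getattr(policy, "model", None)
    init_state = (model.get_initial_state() if model is not None
                  else policy.get_initial_state())
    reqs = {}
    for i in range(len(init_state or [])):
        # Each "state_in_i" at step t is read from "state_out_i" of step t-1.
        reqs["state_in_{}".format(i)] = ViewRequirement(
            data_col="state_out_{}".format(i), shift=-1)
    return reqs
```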
from ray.rllib.policy.policy import Policy, LEARNER_STATS_KEY
from ray.rllib.policy.sample_batch import SampleBatch
from ray.rllib.policy.torch_policy import TorchPolicy
from ray.rllib.policy.view_requirement import ViewRequirement
Not needed anymore in any Policy templates.
infos={},
new_obs=obs_batch[0])
batch = builder.build_and_reset(episode=None)
env_id = episodes[0].env_id
Changed test case to use new SampleCollector.
    sample_batch[SampleBatch.ACTIONS][-1],
    sample_batch[SampleBatch.REWARDS][-1],
    *next_state)
# Input dict is provided to us automatically via the Model's
Is the comment inaccurate now?
    view-col to data-col in them).
inference_view_requirements (Dict[str, ViewRequirement]: The view
    requirements dict needed to build an input dict for a ModelV2
    forward call.
This argument doesn't seem to be used, can we remove it?
done.
Looks good, but please resolve comments before merging.
👍 Will do.
The current attention-net trajectory view PR (#11729) is too large (>1000 lines added).
Therefore, I'm moving smaller preparatory and cleanup changes into ~2 pre-PRs. This is the third of these; please only review it once the 2nd one (#12449) has been merged.
Why are these changes needed?
Related issue number
Checks
I've run scripts/format.sh to lint the changes in this PR.