
Commit d6d2dee

[RLlib] New ConnectorV2 API #2: SingleAgentEpisode enhancements. (#41075)
1 parent: 42c8e0b

19 files changed: +2613 −1300 lines

rllib/BUILD

Lines changed: 27 additions & 10 deletions
@@ -683,6 +683,16 @@ py_test(
     args = ["--dir=tuned_examples/ppo"]
 )

+py_test(
+    name = "test_memory_leak_ppo_new_stack",
+    tags = ["team:rllib", "memory_leak_tests"],
+    main = "utils/tests/run_memory_leak_tests.py",
+    size = "large",
+    srcs = ["utils/tests/run_memory_leak_tests.py"],
+    data = ["tuned_examples/ppo/memory-leak-test-ppo-new-stack.py"],
+    args = ["--dir=tuned_examples/ppo", "--to-check=rollout_worker"]
+)
+
 py_test(
     name = "test_memory_leak_sac",
     tags = ["team:rllib", "memory_leak_tests"],
@@ -772,12 +782,12 @@ py_test(
     srcs = ["env/tests/test_multi_agent_env.py"]
 )

-py_test(
-    name = "env/tests/test_multi_agent_episode",
-    tags = ["team:rllib", "env"],
-    size = "medium",
-    srcs = ["env/tests/test_multi_agent_episode.py"]
-)
+# py_test(
+#     name = "env/tests/test_multi_agent_episode",
+#     tags = ["team:rllib", "env"],
+#     size = "medium",
+#     srcs = ["env/tests/test_multi_agent_episode.py"]
+# )

 sh_test(
     name = "env/tests/test_remote_inference_cartpole",
@@ -818,19 +828,26 @@ sh_test(
 # )

 py_test(
-    name = "env/tests/test_single_agent_gym_env_runner",
+    name = "env/tests/test_single_agent_env_runner",
     tags = ["team:rllib", "env"],
     size = "medium",
-    srcs = ["env/tests/test_single_agent_gym_env_runner.py"]
+    srcs = ["env/tests/test_single_agent_env_runner.py"]
 )

 py_test(
     name = "env/tests/test_single_agent_episode",
     tags = ["team:rllib", "env"],
-    size = "medium",
+    size = "small",
     srcs = ["env/tests/test_single_agent_episode.py"]
 )

+py_test(
+    name = "env/tests/test_lookback_buffer",
+    tags = ["team:rllib", "env"],
+    size = "small",
+    srcs = ["env/tests/test_lookback_buffer.py"]
+)
+
 py_test(
     name = "env/wrappers/tests/test_exception_wrapper",
     tags = ["team:rllib", "env"],
@@ -1332,7 +1349,6 @@ py_test(
 # Tag: utils
 # --------------------------------------------------------------------

-# Checkpoint Utils
 py_test(
     name = "test_checkpoint_utils",
     tags = ["team:rllib", "utils"],
@@ -2947,6 +2963,7 @@ py_test(
 py_test_module_list(
     files = [
         "env/wrappers/tests/test_kaggle_wrapper.py",
+        "env/tests/test_multi_agent_episode.py",
         "examples/env/tests/test_cliff_walking_wall_env.py",
         "examples/env/tests/test_coin_game_non_vectorized_env.py",
         "examples/env/tests/test_coin_game_vectorized_env.py",

rllib/algorithms/dreamerv3/utils/env_runner.py

Lines changed: 41 additions & 38 deletions
@@ -32,6 +32,9 @@
 _, tf, _ = try_import_tf()


+# TODO (sven): Use SingleAgentEnvRunner instead of this as soon as we have the new
+# ConnectorV2 example classes to make Atari work properly with these (w/o requiring the
+# classes at the bottom of this file here, e.g. `ActionClip`).
 class DreamerV3EnvRunner(EnvRunner):
     """An environment runner to collect data from vectorized gymnasium environments."""

@@ -144,6 +147,7 @@ def __init__(

         self._needs_initial_reset = True
         self._episodes = [None for _ in range(self.num_envs)]
+        self._states = [None for _ in range(self.num_envs)]

         # TODO (sven): Move metrics temp storage and collection out of EnvRunner
         # and RolloutWorkers. These classes should not continue tracking some data
@@ -254,10 +258,8 @@ def _sample_timesteps(

             # Set initial obs and states in the episodes.
             for i in range(self.num_envs):
-                self._episodes[i].add_initial_observation(
-                    initial_observation=obs[i],
-                    initial_state={k: s[i] for k, s in states.items()},
-                )
+                self._episodes[i].add_env_reset(observation=obs[i])
+                self._states[i] = {k: s[i] for k, s in states.items()}
         # Don't reset existing envs; continue in already started episodes.
         else:
             # Pick up stored observations and states from previous timesteps.
@@ -268,7 +270,9 @@ def _sample_timesteps(
             states = {
                 k: np.stack(
                     [
-                        initial_states[k][i] if eps.states is None else eps.states[k]
+                        initial_states[k][i]
+                        if self._states[i] is None
+                        else self._states[i][k]
                         for i, eps in enumerate(self._episodes)
                     ]
                 )
@@ -278,7 +282,7 @@ def _sample_timesteps(
             # to 1.0, otherwise 0.0.
             is_first = np.zeros((self.num_envs,))
             for i, eps in enumerate(self._episodes):
-                if eps.states is None:
+                if len(eps) == 0:
                     is_first[i] = 1.0

         # Loop through env for n timesteps.
@@ -319,37 +323,39 @@ def _sample_timesteps(
                 if terminateds[i] or truncateds[i]:
                     # Finish the episode with the actual terminal observation stored in
                     # the info dict.
-                    self._episodes[i].add_timestep(
-                        infos["final_observation"][i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
-                        is_terminated=terminateds[i],
-                        is_truncated=truncateds[i],
+                    self._episodes[i].add_env_step(
+                        observation=infos["final_observation"][i],
+                        action=actions[i],
+                        reward=rewards[i],
+                        terminated=terminateds[i],
+                        truncated=truncateds[i],
                     )
+                    self._states[i] = s
                     # Reset h-states to the model's initial ones b/c we are starting a
                     # new episode.
                     for k, v in self.module.get_initial_state().items():
                         states[k][i] = v.numpy()
                     is_first[i] = True
                     done_episodes_to_return.append(self._episodes[i])
                     # Create a new episode object.
-                    self._episodes[i] = SingleAgentEpisode(
-                        observations=[obs[i]], states=s
-                    )
+                    self._episodes[i] = SingleAgentEpisode(observations=[obs[i]])
                 else:
-                    self._episodes[i].add_timestep(
-                        obs[i], actions[i], rewards[i], state=s
+                    self._episodes[i].add_env_step(
+                        observation=obs[i],
+                        action=actions[i],
+                        reward=rewards[i],
                     )
                     is_first[i] = False

+                self._states[i] = s
+
         # Return done episodes ...
         self._done_episodes_for_metrics.extend(done_episodes_to_return)
         # ... and all ongoing episode chunks. Also, make sure, we return
         # a copy and start new chunks so that callers of this function
         # don't alter our ongoing and returned Episode objects.
         ongoing_episodes = self._episodes
-        self._episodes = [eps.create_successor() for eps in self._episodes]
+        self._episodes = [eps.cut() for eps in self._episodes]
         for eps in ongoing_episodes:
             self._ongoing_episodes_for_metrics[eps.id_].append(eps)

@@ -385,10 +391,9 @@ def _sample_episodes(
         render_images = [e.render() for e in self.env.envs]

         for i in range(self.num_envs):
-            episodes[i].add_initial_observation(
-                initial_observation=obs[i],
-                initial_state={k: s[i] for k, s in states.items()},
-                initial_render_image=render_images[i],
+            episodes[i].add_env_reset(
+                observation=obs[i],
+                render_image=render_images[i],
             )

         eps = 0
@@ -419,19 +424,17 @@ def _sample_episodes(
             render_images = [e.render() for e in self.env.envs]

             for i in range(self.num_envs):
-                s = {k: s[i] for k, s in states.items()}
                 # The last entry in self.observations[i] is already the reset
                 # obs of the new episode.
                 if terminateds[i] or truncateds[i]:
                     eps += 1

-                    episodes[i].add_timestep(
-                        infos["final_observation"][i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
-                        is_terminated=terminateds[i],
-                        is_truncated=truncateds[i],
+                    episodes[i].add_env_step(
+                        observation=infos["final_observation"][i],
+                        action=actions[i],
+                        reward=rewards[i],
+                        terminated=terminateds[i],
+                        truncated=truncateds[i],
                     )
                     done_episodes_to_return.append(episodes[i])

@@ -448,15 +451,15 @@ def _sample_episodes(

                    episodes[i] = SingleAgentEpisode(
                        observations=[obs[i]],
-                        states=s,
-                        render_images=[render_images[i]],
+                        render_images=(
+                            [render_images[i]] if with_render_data else None
+                        ),
                     )
                 else:
-                    episodes[i].add_timestep(
-                        obs[i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
+                    episodes[i].add_env_step(
+                        observation=obs[i],
+                        action=actions[i],
+                        reward=rewards[i],
                         render_image=render_images[i],
                     )
                     is_first[i] = False
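For orientation, below is a minimal usage sketch of the SingleAgentEpisode methods as renamed in this commit (add_env_reset, add_env_step, cut). The CartPole environment, the step count, and the import path ray.rllib.env.single_agent_episode are illustrative assumptions, not part of the diff:

import gymnasium as gym
from ray.rllib.env.single_agent_episode import SingleAgentEpisode  # assumed import path

env = gym.make("CartPole-v1")
obs, _ = env.reset()

# Start an (empty) episode chunk and log the reset observation.
episode = SingleAgentEpisode()
episode.add_env_reset(observation=obs)

terminated = truncated = False
for _ in range(5):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, _ = env.step(action)
    # One env step: observation, action, reward, and the two done flags
    # (keyword names follow the add_env_step signature shown in the diff above).
    episode.add_env_step(
        observation=obs,
        action=action,
        reward=reward,
        terminated=terminated,
        truncated=truncated,
    )
    if terminated or truncated:
        break

# For an ongoing episode, cut() (the replacement for create_successor()) starts
# a new chunk that continues from the last observation, so the chunk already
# handed back to callers is not mutated further.
if not (terminated or truncated):
    continuation = episode.cut()
    print(len(episode), len(continuation))

Note also that, per the diff above, the recurrent model states are no longer stored on the episode objects but in the runner's own self._states list.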
