 _, tf, _ = try_import_tf()


+# TODO (sven): Use SingleAgentEnvRunner instead of this as soon as we have the new
+#  ConnectorV2 example classes to make Atari work properly with these (w/o requiring the
+#  classes at the bottom of this file here, e.g. `ActionClip`).
 class DreamerV3EnvRunner(EnvRunner):
     """An environment runner to collect data from vectorized gymnasium environments."""

@@ -144,6 +147,7 @@ def __init__(

         self._needs_initial_reset = True
         self._episodes = [None for _ in range(self.num_envs)]
+        self._states = [None for _ in range(self.num_envs)]

         # TODO (sven): Move metrics temp storage and collection out of EnvRunner
         #  and RolloutWorkers. These classes should not continue tracking some data
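In the new episode API the model's recurrent state is no longer stored on the episode object itself, so the runner keeps a parallel `self._states` list with one cached state dict per vectorized sub-env. A minimal, self-contained sketch of that caching pattern (the names, the `"h"` key, and the shapes below are illustrative, not taken from the module):

import numpy as np

num_envs = 4

# One episode slot and one cached model-state dict per sub-env.
episodes = [None for _ in range(num_envs)]
states = [None for _ in range(num_envs)]  # filled after each env step

def batch_states(initial_state, cached):
    # Stack per-env states into one batch, falling back to the model's
    # initial state for sub-envs that have not produced a state yet.
    return {
        k: np.stack(
            [initial_state[k] if cached[i] is None else cached[i][k] for i in range(num_envs)]
        )
        for k in initial_state
    }

initial = {"h": np.zeros(8, dtype=np.float32)}
print(batch_states(initial, states)["h"].shape)  # -> (4, 8)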
@@ -254,10 +258,8 @@ def _sample_timesteps(

             # Set initial obs and states in the episodes.
             for i in range(self.num_envs):
-                self._episodes[i].add_initial_observation(
-                    initial_observation=obs[i],
-                    initial_state={k: s[i] for k, s in states.items()},
-                )
+                self._episodes[i].add_env_reset(observation=obs[i])
+                self._states[i] = {k: s[i] for k, s in states.items()}
         # Don't reset existing envs; continue in already started episodes.
         else:
             # Pick up stored observations and states from previous timesteps.
@@ -268,7 +270,9 @@ def _sample_timesteps(
             states = {
                 k: np.stack(
                     [
-                        initial_states[k][i] if eps.states is None else eps.states[k]
+                        initial_states[k][i]
+                        if self._states[i] is None
+                        else self._states[i][k]
                         for i, eps in enumerate(self._episodes)
                     ]
                 )
@@ -278,7 +282,7 @@ def _sample_timesteps(
             # to 1.0, otherwise 0.0.
             is_first = np.zeros((self.num_envs,))
             for i, eps in enumerate(self._episodes):
-                if eps.states is None:
+                if len(eps) == 0:
                     is_first[i] = 1.0

         # Loop through env for n timesteps.
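With the state gone from the episode object, the "first timestep after a reset" check changes from `eps.states is None` to `len(eps) == 0`: a freshly reset episode holds only its initial observation and counts zero completed env steps. A small sketch of that behavior, assuming the new `SingleAgentEpisode` API as used in this diff (the import path may vary across Ray versions):

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

eps = SingleAgentEpisode()
eps.add_env_reset(observation=0)
print(len(eps))  # 0 -> this sub-env gets is_first = 1.0

eps.add_env_step(observation=1, action=0, reward=1.0)
print(len(eps))  # 1 -> no longer the first timestep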
@@ -319,37 +323,39 @@ def _sample_timesteps(
                 if terminateds[i] or truncateds[i]:
                     # Finish the episode with the actual terminal observation stored in
                     # the info dict.
-                    self._episodes[i].add_timestep(
-                        infos["final_observation"][i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
-                        is_terminated=terminateds[i],
-                        is_truncated=truncateds[i],
+                    self._episodes[i].add_env_step(
+                        observation=infos["final_observation"][i],
+                        action=actions[i],
+                        reward=rewards[i],
+                        terminated=terminateds[i],
+                        truncated=truncateds[i],
                     )
+                    self._states[i] = s
                     # Reset h-states to the model's initial ones b/c we are starting a
                     # new episode.
                     for k, v in self.module.get_initial_state().items():
                         states[k][i] = v.numpy()
                     is_first[i] = True
                     done_episodes_to_return.append(self._episodes[i])
                     # Create a new episode object.
-                    self._episodes[i] = SingleAgentEpisode(
-                        observations=[obs[i]], states=s
-                    )
+                    self._episodes[i] = SingleAgentEpisode(observations=[obs[i]])
                 else:
-                    self._episodes[i].add_timestep(
-                        obs[i], actions[i], rewards[i], state=s
+                    self._episodes[i].add_env_step(
+                        observation=obs[i],
+                        action=actions[i],
+                        reward=rewards[i],
                     )
                     is_first[i] = False

+                self._states[i] = s
+
         # Return done episodes ...
         self._done_episodes_for_metrics.extend(done_episodes_to_return)
         # ... and all ongoing episode chunks. Also, make sure, we return
         # a copy and start new chunks so that callers of this function
         # don't alter our ongoing and returned Episode objects.
         ongoing_episodes = self._episodes
-        self._episodes = [eps.create_successor() for eps in self._episodes]
+        self._episodes = [eps.cut() for eps in self._episodes]
         for eps in ongoing_episodes:
             self._ongoing_episodes_for_metrics[eps.id_].append(eps)

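`create_successor()` is replaced by `cut()`: the runner hands the collected chunk to the caller and keeps sampling into a fresh chunk that continues from the last observation, so caller and runner never mutate the same episode object. A hedged sketch of that hand-off, again assuming the `SingleAgentEpisode` API used above (import path and exact semantics may differ by Ray version):

from ray.rllib.env.single_agent_episode import SingleAgentEpisode

eps = SingleAgentEpisode()
eps.add_env_reset(observation=0)
for t in range(3):
    eps.add_env_step(observation=t + 1, action=t, reward=1.0)

returned_chunk = eps  # goes back to the caller (3 env steps collected)
eps = eps.cut()       # new chunk, continues from the last observation
print(len(returned_chunk), len(eps))  # expected: 3 0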
@@ -385,10 +391,9 @@ def _sample_episodes(
             render_images = [e.render() for e in self.env.envs]

         for i in range(self.num_envs):
-            episodes[i].add_initial_observation(
-                initial_observation=obs[i],
-                initial_state={k: s[i] for k, s in states.items()},
-                initial_render_image=render_images[i],
+            episodes[i].add_env_reset(
+                observation=obs[i],
+                render_image=render_images[i],
             )

         eps = 0
@@ -419,19 +424,17 @@ def _sample_episodes(
                 render_images = [e.render() for e in self.env.envs]

             for i in range(self.num_envs):
-                s = {k: s[i] for k, s in states.items()}
                 # The last entry in self.observations[i] is already the reset
                 # obs of the new episode.
                 if terminateds[i] or truncateds[i]:
                     eps += 1

-                    episodes[i].add_timestep(
-                        infos["final_observation"][i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
-                        is_terminated=terminateds[i],
-                        is_truncated=truncateds[i],
+                    episodes[i].add_env_step(
+                        observation=infos["final_observation"][i],
+                        action=actions[i],
+                        reward=rewards[i],
+                        terminated=terminateds[i],
+                        truncated=truncateds[i],
                     )
                     done_episodes_to_return.append(episodes[i])

@@ -448,15 +451,15 @@ def _sample_episodes(

                     episodes[i] = SingleAgentEpisode(
                         observations=[obs[i]],
-                        states=s,
-                        render_images=[render_images[i]],
+                        render_images=(
+                            [render_images[i]] if with_render_data else None
+                        ),
                     )
                 else:
-                    episodes[i].add_timestep(
-                        obs[i],
-                        actions[i],
-                        rewards[i],
-                        state=s,
+                    episodes[i].add_env_step(
+                        observation=obs[i],
+                        action=actions[i],
+                        reward=rewards[i],
                         render_image=render_images[i],
                     )
                     is_first[i] = False
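In `_sample_episodes`, render frames are now attached only when the caller asked for them: the new episode is created with `render_images` set to the first frame or `None`, and later frames are added via `add_env_step(..., render_image=...)`. A small sketch of that conditional wiring (the dummy frame and the `with_render_data` flag here are illustrative stand-ins):

import numpy as np
from ray.rllib.env.single_agent_episode import SingleAgentEpisode  # path may vary by Ray version

with_render_data = True
frame = np.zeros((64, 64, 3), dtype=np.uint8)  # stand-in for env.render() output

episode = SingleAgentEpisode(
    observations=[0],
    # Keep frames only when render data was requested.
    render_images=[frame] if with_render_data else None,
)

step_kwargs = dict(observation=1, action=0, reward=1.0)
if with_render_data:
    step_kwargs["render_image"] = frame
episode.add_env_step(**step_kwargs)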