Skip to content

Commit ae0f58e

Browse files
committed
improved DirectMARLEnv APIs
1 parent e571091 commit ae0f58e

File tree

1 file changed

+13
-61
lines changed

1 file changed

+13
-61
lines changed

source/isaaclab/isaaclab/envs/direct_marl_env.py

Lines changed: 13 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -229,19 +229,9 @@ def num_envs(self) -> int:
229229

230230
@property
231231
def num_agents(self) -> int:
232-
"""Number of current agents.
233-
234-
The number of current agents may change as the environment progresses (e.g.: agents can be added or removed).
232+
"""Number of agents, as defined by ``possible_agents`` in the environment configuration.
235233
"""
236-
return len(self.agents)
237-
238-
@property
239-
def max_num_agents(self) -> int:
240-
"""Number of all possible agents the environment can generate.
241-
242-
This value remains constant as the environment progresses.
243-
"""
244-
return len(self.possible_agents)
234+
return len(self.cfg.possible_agents)
245235

246236
@property
247237
def unwrapped(self) -> DirectMARLEnv:
@@ -328,7 +318,6 @@ def reset(
328318

329319
# update observations and the list of current agents (sorted as in possible_agents)
330320
self.obs_dict = self._get_observations()
331-
self.agents = [agent for agent in self.possible_agents if agent in self.obs_dict]
332321

333322
# return observations
334323
return self.obs_dict, self.extras
@@ -410,7 +399,6 @@ def step(self, actions: dict[AgentID, ActionType]) -> EnvStepReturn:
410399

411400
# update observations and the list of current agents (sorted as in possible_agents)
412401
self.obs_dict = self._get_observations()
413-
self.agents = [agent for agent in self.possible_agents if agent in self.obs_dict]
414402

415403
# add observation noise
416404
# note: we apply no noise to the state space (since it is used for centralized training or critic networks)
@@ -422,7 +410,7 @@ def step(self, actions: dict[AgentID, ActionType]) -> EnvStepReturn:
422410
# return observations, rewards, resets and extras
423411
return self.obs_dict, self.reward_dict, self.terminated_dict, self.time_out_dict, self.extras
424412

425-
def state(self) -> StateType | None:
413+
def state(self) -> dict[AgentID, torch.Tensor]:
426414
"""Returns the state for the environment.
427415
428416
The state-space is used for centralized training or asymmetric actor-critic architectures. It is configured
@@ -431,18 +419,7 @@ def state(self) -> StateType | None:
431419
Returns:
432420
The states for the environment, keyed by agent ID, as computed by :meth:`_get_states`.
433421
"""
434-
if not self.cfg.state_space:
435-
return None
436-
# concatenate and return the observations as state
437-
# FIXME: This implementation assumes the spaces are fundamental ones. Fix it to support composite spaces
438-
if isinstance(self.cfg.state_space, int) and self.cfg.state_space < 0:
439-
self.state_buf = torch.cat(
440-
[self.obs_dict[agent].reshape(self.num_envs, -1) for agent in self.cfg.possible_agents], dim=-1
441-
)
442-
# compute and return custom environment state
443-
else:
444-
self.state_buf = self._get_states()
445-
return self.state_buf
422+
return self._get_states()
446423

447424
@staticmethod
448425
def seed(seed: int = -1) -> int:
@@ -597,43 +574,18 @@ def set_debug_vis(self, debug_vis: bool) -> bool:
597574

598575
def _configure_env_spaces(self):
599576
"""Configure the spaces for the environment."""
600-
self.agents = self.cfg.possible_agents
601-
self.possible_agents = self.cfg.possible_agents
602-
603-
# show deprecation message and overwrite configuration
604-
if self.cfg.num_actions is not None:
605-
omni.log.warn("DirectMARLEnvCfg.num_actions is deprecated. Use DirectMARLEnvCfg.action_spaces instead.")
606-
if isinstance(self.cfg.action_spaces, type(MISSING)):
607-
self.cfg.action_spaces = self.cfg.num_actions
608-
if self.cfg.num_observations is not None:
609-
omni.log.warn(
610-
"DirectMARLEnvCfg.num_observations is deprecated. Use DirectMARLEnvCfg.observation_spaces instead."
611-
)
612-
if isinstance(self.cfg.observation_spaces, type(MISSING)):
613-
self.cfg.observation_spaces = self.cfg.num_observations
614-
if self.cfg.num_states is not None:
615-
omni.log.warn("DirectMARLEnvCfg.num_states is deprecated. Use DirectMARLEnvCfg.state_space instead.")
616-
if isinstance(self.cfg.state_space, type(MISSING)):
617-
self.cfg.state_space = self.cfg.num_states
618-
619-
# set up observation and action spaces
620577
self.observation_spaces = {
621-
agent: spec_to_gym_space(self.cfg.observation_spaces[agent]) for agent in self.cfg.possible_agents
578+
agent_name: spec_to_gym_space(space) \
579+
for agent_name, space in self.cfg.observation_spaces.items()
622580
}
623581
self.action_spaces = {
624-
agent: spec_to_gym_space(self.cfg.action_spaces[agent]) for agent in self.cfg.possible_agents
582+
agent_name: spec_to_gym_space(space) \
583+
for agent_name, space in self.cfg.action_spaces.items()
584+
}
585+
self.state_spaces = {
586+
agent_name: spec_to_gym_space(space) \
587+
for agent_name, space in self.cfg.state_space.items()
625588
}
626-
627-
# set up state space
628-
if not self.cfg.state_space:
629-
self.state_space = None
630-
if isinstance(self.cfg.state_space, int) and self.cfg.state_space < 0:
631-
self.state_space = gym.spaces.flatten_space(
632-
gym.spaces.Tuple([self.observation_spaces[agent] for agent in self.cfg.possible_agents])
633-
)
634-
else:
635-
self.state_space = spec_to_gym_space(self.cfg.state_space)
636-
637589
# instantiate actions (needed for tasks for which the observations computation is dependent on the actions)
638590
self.actions = {
639591
agent: sample_space(self.action_spaces[agent], self.sim.device, batch_size=self.num_envs, fill_value=0)
@@ -713,7 +665,7 @@ def _get_observations(self) -> dict[AgentID, ObsType]:
713665
raise NotImplementedError(f"Please implement the '_get_observations' method for {self.__class__.__name__}.")
714666

715667
@abstractmethod
716-
def _get_states(self) -> StateType:
668+
def _get_states(self) -> dict[AgentID, torch.Tensor]:
717669
"""Compute and return the states for the environment.
718670
719671
This method is only called (and therefore has to be implemented) when the :attr:`DirectMARLEnvCfg.state_space`

0 commit comments

Comments
 (0)