diff --git a/source/isaaclab/isaaclab/envs/mdp/events.py b/source/isaaclab/isaaclab/envs/mdp/events.py index 17c5f582d1e..5a6ec632d08 100644 --- a/source/isaaclab/isaaclab/envs/mdp/events.py +++ b/source/isaaclab/isaaclab/envs/mdp/events.py @@ -596,14 +596,16 @@ def randomize(data: torch.Tensor, params: tuple[float, float]) -> torch.Tensor: actuator_indices = slice(None) if isinstance(actuator.joint_indices, slice): global_indices = slice(None) + elif isinstance(actuator.joint_indices, torch.Tensor): + global_indices = actuator.joint_indices.to(self.asset.device) else: - global_indices = torch.tensor(actuator.joint_indices, device=self.asset.device) + raise TypeError("Actuator joint indices must be a slice or a torch.Tensor.") elif isinstance(actuator.joint_indices, slice): # we take the joints defined in the asset config - global_indices = actuator_indices = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device) + global_indices = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device) else: # we take the intersection of the actuator joints and the asset config joints - actuator_joint_indices = torch.tensor(actuator.joint_indices, device=self.asset.device) + actuator_joint_indices = actuator.joint_indices asset_joint_ids = torch.tensor(self.asset_cfg.joint_ids, device=self.asset.device) # the indices of the joints in the actuator that have to be randomized actuator_indices = torch.nonzero(torch.isin(actuator_joint_indices, asset_joint_ids)).view(-1)