From c204de86b6150f70eb7ff543cec3d66a2df0ff32 Mon Sep 17 00:00:00 2001
From: swyo
Date: Mon, 17 Jun 2024 13:15:27 +0900
Subject: [PATCH 1/3] fix example for mps; issue #19981

---
 examples/fabric/reinforcement_learning/train_fabric.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py
index 1f3f83f3f2025..e865010d91bd9 100644
--- a/examples/fabric/reinforcement_learning/train_fabric.py
+++ b/examples/fabric/reinforcement_learning/train_fabric.py
@@ -146,7 +146,7 @@ def main(args: argparse.Namespace):
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
             done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
-            rewards[step] = torch.tensor(reward, device=device).view(-1)
+            rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32 if device.type == 'mps' else None).view(-1)
             next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:

From 74f8a351751fcfc7c602f89de11698d0ccbfebc7 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Mon, 17 Jun 2024 04:20:55 +0000
Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 examples/fabric/reinforcement_learning/train_fabric.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py
index e865010d91bd9..d07784d1d6962 100644
--- a/examples/fabric/reinforcement_learning/train_fabric.py
+++ b/examples/fabric/reinforcement_learning/train_fabric.py
@@ -146,7 +146,9 @@ def main(args: argparse.Namespace):
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
             done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
-            rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32 if device.type == 'mps' else None).view(-1)
+            rewards[step] = torch.tensor(
+                reward, device=device, dtype=torch.float32 if device.type == "mps" else None
+            ).view(-1)
             next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:

From 26006e7acf1ceaa5cf032426b7e5b847f76ea2df Mon Sep 17 00:00:00 2001
From: swyo
Date: Fri, 21 Jun 2024 17:15:25 +0900
Subject: [PATCH 3/3] apply feedback: type casting unconditionally

---
 examples/fabric/reinforcement_learning/train_fabric.py      | 2 +-
 .../fabric/reinforcement_learning/train_fabric_decoupled.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/fabric/reinforcement_learning/train_fabric.py b/examples/fabric/reinforcement_learning/train_fabric.py
index e865010d91bd9..74b9b378371d3 100644
--- a/examples/fabric/reinforcement_learning/train_fabric.py
+++ b/examples/fabric/reinforcement_learning/train_fabric.py
@@ -146,7 +146,7 @@ def main(args: argparse.Namespace):
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
             done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
-            rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32 if device.type == 'mps' else None).view(-1)
+            rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
             next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:

diff --git a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
index bbc09c977efcf..3849ae0f96a3c 100644
--- a/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
+++ b/examples/fabric/reinforcement_learning/train_fabric_decoupled.py
@@ -135,7 +135,7 @@ def player(args, world_collective: TorchCollective, player_trainer_collective: TorchCollective):
             # Single environment step
             next_obs, reward, done, truncated, info = envs.step(action.cpu().numpy())
             done = torch.logical_or(torch.tensor(done), torch.tensor(truncated))
-            rewards[step] = torch.tensor(reward, device=device).view(-1)
+            rewards[step] = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
             next_obs, next_done = torch.tensor(next_obs, device=device), done.to(device)
 
             if "final_info" in info:
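
A minimal sketch of the failure this series works around, assuming an Apple-silicon host where the MPS backend is available; the reward array below only mimics the float64 NumPy rewards that Gymnasium vector envs typically return, and the variable names are illustrative rather than copied from the example scripts:

    import numpy as np
    import torch

    # Rewards from Gymnasium vector envs usually arrive as float64 NumPy arrays.
    reward = np.zeros(4, dtype=np.float64)

    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

    try:
        # Without an explicit dtype this tries to build a float64 tensor,
        # which the MPS backend does not support and rejects.
        torch.tensor(reward, device=device).view(-1)
    except (TypeError, RuntimeError) as err:
        print(f"float64 copy rejected on {device}: {err}")

    # The approach from PATCH 3/3: cast to float32 unconditionally,
    # which works the same on CPU, CUDA, and MPS.
    rewards_step = torch.tensor(reward, device=device, dtype=torch.float32).view(-1)
    print(rewards_step.dtype)  # torch.float32

The unconditional cast is presumably harmless on other accelerators, since the preallocated rewards buffer in the examples is float32 and the assignment would down-cast the values anyway.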