Skip to content

Commit 92b6254

Browse files
Authored by: imstevenpmwork, xiangyang-95, Copilot, pre-commit-ci[bot], jadechoghari
feat(utils): add support for Intel XPU backend (#2233)
* feat: add support for Intel XPU backend in device selection * Update src/lerobot/utils/utils.py Co-authored-by: Copilot <[email protected]> Signed-off-by: Lim Xiang Yang <[email protected]> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: update is_amp_available to include xpu as a valid device * Update src/lerobot/utils/utils.py Co-authored-by: Copilot <[email protected]> Signed-off-by: Lim Xiang Yang <[email protected]> * Update src/lerobot/utils/utils.py Co-authored-by: Copilot <[email protected]> Signed-off-by: Lim Xiang Yang <[email protected]> * fix: remove unused return and add comments on fp64 fallback handling * fix(utils): return dtype in case xpu has fp64 --------- Signed-off-by: Lim Xiang Yang <[email protected]> Co-authored-by: Lim, Xiang Yang <[email protected]> Co-authored-by: Lim Xiang Yang <[email protected]> Co-authored-by: Copilot <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Jade Choghari <[email protected]>
1 parent 79137f5 commit 92b6254

File tree

2 files changed

+26
-3
lines changed

2 files changed

+26
-3
lines changed

src/lerobot/utils/utils.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,9 @@ def auto_select_torch_device() -> torch.device:
4545
elif torch.backends.mps.is_available():
4646
logging.info("Metal backend detected, using mps.")
4747
return torch.device("mps")
48+
elif torch.xpu.is_available():
49+
logging.info("Intel XPU backend detected, using xpu.")
50+
return torch.device("xpu")
4851
else:
4952
logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
5053
return torch.device("cpu")
@@ -61,6 +64,9 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
6164
case "mps":
6265
assert torch.backends.mps.is_available()
6366
device = torch.device("mps")
67+
case "xpu":
68+
assert torch.xpu.is_available()
69+
device = torch.device("xpu")
6470
case "cpu":
6571
device = torch.device("cpu")
6672
if log:
@@ -81,6 +87,21 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
8187
device = device.type
8288
if device == "mps" and dtype == torch.float64:
8389
return torch.float32
90+
if device == "xpu" and dtype == torch.float64:
91+
if hasattr(torch.xpu, "get_device_capability"):
92+
device_capability = torch.xpu.get_device_capability()
93+
# NOTE: Some Intel XPU devices do not support double precision (FP64).
94+
# The `has_fp64` flag is returned by `torch.xpu.get_device_capability()`
95+
# when available; if False, we fall back to float32 for compatibility.
96+
if not device_capability.get("has_fp64", False):
97+
logging.warning(f"Device {device} does not support float64, using float32 instead.")
98+
return torch.float32
99+
else:
100+
logging.warning(
101+
f"Device {device} capability check failed. Assuming no support for float64, using float32 instead."
102+
)
103+
return torch.float32
104+
return dtype
84105
else:
85106
return dtype
86107

@@ -91,14 +112,16 @@ def is_torch_device_available(try_device: str) -> bool:
91112
return torch.cuda.is_available()
92113
elif try_device == "mps":
93114
return torch.backends.mps.is_available()
115+
elif try_device == "xpu":
116+
return torch.xpu.is_available()
94117
elif try_device == "cpu":
95118
return True
96119
else:
97-
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
120+
raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.")
98121

99122

100123
def is_amp_available(device: str):
101-
if device in ["cuda", "cpu"]:
124+
if device in ["cuda", "xpu", "cpu"]:
102125
return True
103126
elif device == "mps":
104127
return False

tests/async_inference/test_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -389,7 +389,7 @@ def test_raw_observation_to_observation_device_handling():
389389
# Check that all expected keys produce tensors (device placement handled by preprocessor later)
390390
for key, value in observation.items():
391391
if isinstance(value, torch.Tensor):
392-
assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
392+
assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device"
393393

394394

395395
def test_raw_observation_to_observation_deterministic():

0 commit comments

Comments (0)