Merged
src/lerobot/utils/utils.py (27 changes: 25 additions & 2 deletions)
@@ -45,6 +45,9 @@ def auto_select_torch_device() -> torch.device:
     elif torch.backends.mps.is_available():
         logging.info("Metal backend detected, using mps.")
         return torch.device("mps")
+    elif torch.xpu.is_available():
+        logging.info("Intel XPU backend detected, using xpu.")
+        return torch.device("xpu")
     else:
         logging.warning("No accelerated backend detected. Using default cpu, this will be slow.")
         return torch.device("cpu")
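For reference, a quick way to exercise the new branch on an Intel GPU machine. This is a minimal sketch, assuming a PyTorch build that exposes the `torch.xpu` backend (2.4 or newer); it is not part of the PR itself:

```python
import torch

from lerobot.utils.utils import auto_select_torch_device

# With no CUDA or MPS device present but an Intel GPU available,
# auto-selection should now resolve to the XPU backend.
device = auto_select_torch_device()
print(device)                     # e.g. device(type='xpu')
print(torch.xpu.is_available())   # the probe used by the new branch
```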
@@ -61,6 +64,9 @@ def get_safe_torch_device(try_device: str, log: bool = False) -> torch.device:
         case "mps":
             assert torch.backends.mps.is_available()
             device = torch.device("mps")
+        case "xpu":
+            assert torch.xpu.is_available()
+            device = torch.device("xpu")
         case "cpu":
             device = torch.device("cpu")
             if log:
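Callers that pass an explicit device string get the same guard as for CUDA and MPS: requesting an unavailable XPU fails fast. A short usage sketch (assumptions as above):

```python
from lerobot.utils.utils import get_safe_torch_device

# Raises AssertionError if torch.xpu.is_available() is False,
# mirroring the existing behavior for "cuda" and "mps".
device = get_safe_torch_device("xpu", log=True)
```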
@@ -81,6 +87,21 @@ def get_safe_dtype(dtype: torch.dtype, device: str | torch.device):
         device = device.type
     if device == "mps" and dtype == torch.float64:
         return torch.float32
+    if device == "xpu" and dtype == torch.float64:
+        if hasattr(torch.xpu, "get_device_capability"):
+            device_capability = torch.xpu.get_device_capability()
+            # NOTE: Some Intel XPU devices do not support double precision (FP64).
+            # The `has_fp64` flag is returned by `torch.xpu.get_device_capability()`
+            # when available; if False, we fall back to float32 for compatibility.
+            if not device_capability.get("has_fp64", False):
+                logging.warning(f"Device {device} does not support float64, using float32 instead.")
+                return torch.float32
+        else:
+            logging.warning(
+                f"Device {device} capability check failed. Assuming no support for float64, using float32 instead."
+            )
+            return torch.float32
+        return dtype
     else:
         return dtype
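As the NOTE comment says, not every Intel GPU implements FP64, and the capability dict's `has_fp64` flag drives the fallback. A sketch of what a caller sees (assuming an XPU-enabled PyTorch build; not part of the PR):

```python
import torch

from lerobot.utils.utils import get_safe_dtype

# On XPU parts without FP64 units (has_fp64 is False), float64 is
# downgraded to float32 with a warning; FP64-capable devices keep float64.
dtype = get_safe_dtype(torch.float64, "xpu")
print(dtype)  # torch.float32 or torch.float64, depending on the device
```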

@@ -91,14 +112,16 @@ def is_torch_device_available(try_device: str) -> bool:
         return torch.cuda.is_available()
     elif try_device == "mps":
         return torch.backends.mps.is_available()
+    elif try_device == "xpu":
+        return torch.xpu.is_available()
     elif try_device == "cpu":
         return True
     else:
-        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps or cpu.")
+        raise ValueError(f"Unknown device {try_device}. Supported devices are: cuda, mps, xpu or cpu.")


 def is_amp_available(device: str):
-    if device in ["cuda", "cpu"]:
+    if device in ["cuda", "xpu", "cpu"]:
         return True
     elif device == "mps":
         return False
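Since `is_amp_available` now reports True for `xpu`, downstream training code can run mixed-precision forward passes on Intel GPUs. A minimal sketch, assuming a PyTorch build where `torch.autocast` accepts `device_type="xpu"` (not something this PR adds):

```python
import torch

# bfloat16 is the usual autocast dtype on Intel GPUs.
model = torch.nn.Linear(16, 4).to("xpu")
x = torch.randn(8, 16, device="xpu")
with torch.autocast(device_type="xpu", dtype=torch.bfloat16):
    y = model(x)
print(y.dtype)  # torch.bfloat16: the linear layer ran under autocast
```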
tests/async_inference/test_helpers.py (2 changes: 1 addition & 1 deletion)
@@ -389,7 +389,7 @@ def test_raw_observation_to_observation_device_handling():
     # Check that all expected keys produce tensors (device placement handled by preprocessor later)
     for key, value in observation.items():
         if isinstance(value, torch.Tensor):
-            assert value.device.type in ["cpu", "cuda", "mps"], f"Tensor {key} on unexpected device"
+            assert value.device.type in ["cpu", "cuda", "mps", "xpu"], f"Tensor {key} on unexpected device"


 def test_raw_observation_to_observation_deterministic():
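The updated assertion only widens the set of accepted device types; a test that actually requires an XPU device would typically guard itself with a skip marker instead. A hypothetical example (standard pytest usage, not part of this PR):

```python
import pytest
import torch


@pytest.mark.skipif(
    not (hasattr(torch, "xpu") and torch.xpu.is_available()),
    reason="requires an Intel XPU device",
)
def test_tensor_lands_on_xpu():
    t = torch.ones(2, 2, device="xpu")
    assert t.device.type == "xpu"
```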