diff --git a/src/ATen/native/xpu/sycl/Shape.cpp b/src/ATen/native/xpu/sycl/Shape.cpp
index 12bd0ba66d..c8e4236401 100644
--- a/src/ATen/native/xpu/sycl/Shape.cpp
+++ b/src/ATen/native/xpu/sycl/Shape.cpp
@@ -394,6 +394,7 @@ void cat_out_kernel(
         kHalf,
         kBool,
         kBFloat16,
+        AT_EXPAND(AT_FLOAT8_TYPES),
         AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
   } else {
     offset = 0;
diff --git a/src/ATen/native/xpu/sycl/TensorCompareKernels.cpp b/src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
index 562f994a62..1de7da7674 100644
--- a/src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
+++ b/src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
@@ -1,4 +1,5 @@
 #include
+#include <ATen/Dispatch_v2.h>
 #include
 #include
 #include
@@ -78,10 +79,16 @@ struct ClampScalarFunctor {
 };
 
 void where_kernel(TensorIterator& iter) {
-  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-      kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_xpu", [&] {
-        gpu_kernel(iter, WhereFunctor<scalar_t>());
-      });
+  AT_DISPATCH_V2(
+      iter.dtype(),
+      "where_xpu",
+      [&] { gpu_kernel(iter, WhereFunctor<scalar_t>()); },
+      kComplexHalf,
+      kHalf,
+      kBFloat16,
+      kBool,
+      AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
+      AT_EXPAND(AT_FLOAT8_TYPES));
 }
 
 void isposinf_kernel(TensorIteratorBase& iter) {
diff --git a/test/regressions/test_cat.py b/test/regressions/test_cat.py
index ac5d60808c..e2e4218e56 100644
--- a/test/regressions/test_cat.py
+++ b/test/regressions/test_cat.py
@@ -4,6 +4,67 @@
 
 
 class TestTorchMethod(TestCase):
+    # Define float8 dtypes for the focused test
+    FLOAT8_DTYPES = (
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2,
+        torch.float8_e5m2fnuz,
+        torch.float8_e8m0fnu,
+    )
+
+    def _create_input_tensors(self, shape, dtype, memory_format=None):
+        # Always generate random data using a CPU-compatible dtype (float32)
+        # to avoid the "not implemented" error for float8 on CPU.
+        tensor = torch.randn(shape, dtype=torch.float32)
+
+        # Convert to the target testing dtype
+        tensor = tensor.to(dtype)
+
+        # Apply memory format if specified
+        if memory_format is not None:
+            tensor = tensor.to(memory_format=memory_format)
+
+        return tensor
+
+    def _test_cat_float8_core(self, tensors, dim, dtype):
+        """Core function to test torch.cat for float8, using tolerances."""
+
+        # --- CPU Reference Calculation (High Precision) ---
+        # Convert inputs to float32 on CPU for golden reference calculation
+        ref_tensors = [t.cpu().to(torch.float32) for t in tensors]
+
+        # Calculate CPU reference result
+        res_cpu = torch.cat(ref_tensors, dim=dim)
+
+        # --- XPU Calculation ---
+        # Convert inputs to XPU
+        xpu_tensors = [t.xpu() for t in tensors]
+        res_xpu = torch.cat(xpu_tensors, dim=dim)
+
+        # Float8 is lossy, use higher tolerance (rtol=1e-2, atol=1e-2)
+        rtol = 1e-2
+        atol = 1e-2
+
+        # Convert XPU result to float32 on CPU before comparison to match res_cpu's dtype.
+        res_xpu_f32_on_cpu = res_xpu.cpu().to(torch.float32)
+
+        self.assertEqual(res_cpu, res_xpu_f32_on_cpu, rtol=rtol, atol=atol)
+
+    def test_cat_float8_simple(self):
+        """Test torch.cat correctness across float8 dtypes using simple tensors."""
+        for dtype in self.FLOAT8_DTYPES:
+            with self.subTest(dtype=dtype):
+                # Use simple 3D shape (2, 4, 3) and concatenate along dim 1
+                user_cpu1 = self._create_input_tensors([2, 4, 3], dtype=dtype)
+                user_cpu2 = self._create_input_tensors([2, 2, 3], dtype=dtype)
+                user_cpu3 = self._create_input_tensors([2, 6, 3], dtype=dtype)
+
+                tensors = (user_cpu1, user_cpu2, user_cpu3)
+                dim = 1
+
+                self._test_cat_float8_core(tensors, dim, dtype)
+
     def test_cat_8d(self, dtype=torch.float):
         input1 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)
         input2 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)
diff --git a/test/regressions/test_where.py b/test/regressions/test_where.py
new file mode 100644
index 0000000000..4cf4b79394
--- /dev/null
+++ b/test/regressions/test_where.py
@@ -0,0 +1,92 @@
+# Owner(s): ["module: intel"]
+import torch
+from torch.testing._internal.common_utils import TestCase
+
+
+class TestTorchWhereMethod(TestCase):
+    # Define float8 dtypes
+    FLOAT8_DTYPES = (
+        torch.float8_e5m2,
+        torch.float8_e4m3fn,
+        torch.float8_e4m3fnuz,
+        torch.float8_e5m2fnuz,
+        torch.float8_e8m0fnu,
+    )
+
+    # Define the set of all dtypes to be tested
+    TEST_DTYPES = (
+        torch.float32,
+        torch.float64,
+        torch.half,
+        torch.bfloat16,
+    ) + FLOAT8_DTYPES
+
+    def _test_where_fn(self, dtype):
+        """Core function to test torch.where(condition, x, y) correctness."""
+
+        # 1. Input Tensors (x and y)
+        x = torch.tensor([[10.0, 20.0], [30.0, 40.0]], dtype=dtype)
+        y = torch.tensor([[-1.0, -2.0], [-3.0, -4.0]], dtype=dtype)
+        # Condition must be bool
+        condition = torch.tensor([[True, False], [False, True]], dtype=torch.bool)
+
+        # --- 1. CPU Reference Calculation and Tolerance Setting ---
+
+        if dtype in self.FLOAT8_DTYPES:
+            # FP8: Use float32 as reference type for comparison
+            x_ref = x.cpu().to(torch.float32)
+            y_ref = y.cpu().to(torch.float32)
+            rtol = 1e-2
+            atol = 1e-2
+        else:
+            # Non-FP8: Use original dtype as reference type
+            x_ref = x.cpu()
+            y_ref = y.cpu()
+            rtol = 1e-5
+            atol = 1e-5
+
+        condition_ref = condition.cpu()
+        res_ref = torch.where(condition_ref, x_ref, y_ref)
+
+        # --- 2. XPU Operation (Default) ---
+        x_xpu = x.xpu()
+        y_xpu = y.xpu()
+        condition_xpu = condition.xpu()
+
+        res_xpu = torch.where(condition_xpu, x_xpu, y_xpu)
+
+        # Prepare XPU result for comparison (must match res_ref dtype)
+        if dtype in self.FLOAT8_DTYPES:
+            # FP8: Convert XPU result to float32
+            res_xpu_to_compare = res_xpu.cpu().to(torch.float32)
+        else:
+            # Non-FP8: Pull to CPU, keeping original dtype
+            res_xpu_to_compare = res_xpu.cpu()
+
+        # Compare: res_ref vs res_xpu_to_compare
+        self.assertEqual(res_ref, res_xpu_to_compare, rtol=rtol, atol=atol)
+
+        # --- 3. Test the version with out= argument ---
+
+        # Create output tensor on XPU
+        res_xpu_out = torch.empty_like(res_xpu, dtype=dtype).xpu()
+        torch.where(condition_xpu, x_xpu, y_xpu, out=res_xpu_out)
+
+        # Prepare XPU 'out' result for comparison
+        if dtype in self.FLOAT8_DTYPES:
+            # FP8: Convert XPU result to float32
+            res_xpu_out_to_compare = res_xpu_out.cpu().to(torch.float32)
+        else:
+            # Non-FP8: Pull to CPU, keeping original dtype
+            res_xpu_out_to_compare = res_xpu_out.cpu()
+
+        # Compare: res_ref vs res_xpu_out_to_compare
+        self.assertEqual(res_ref, res_xpu_out_to_compare, rtol=rtol, atol=atol)
+
+    def test_where(self):
+        """Test torch.where() correctness across all supported dtypes, including float8."""
+        for dtype in self.TEST_DTYPES:
+            # Use string conversion for better subTest reporting
+            dtype_name = str(dtype).split(".")[-1]
+            with self.subTest(dtype=dtype_name):
+                self._test_where_fn(dtype)