1 change: 1 addition & 0 deletions src/ATen/native/xpu/sycl/Shape.cpp
@@ -394,6 +394,7 @@ void cat_out_kernel(
kHalf,
kBool,
kBFloat16,
AT_EXPAND(AT_FLOAT8_TYPES),
AT_EXPAND(AT_BAREBONES_UNSIGNED_TYPES));
} else {
offset = 0;
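A quick illustration of what this one-line addition enables: with AT_FLOAT8_TYPES in the dispatch list, torch.cat on float8 tensors dispatches to the XPU cat kernel. The sketch below is illustrative only and assumes a build with XPU and float8 dtype support; randn has no float8 variant, so the tensors are created in float32 and cast, as the regression test in this PR also does.

import torch

a = torch.randn(2, 3).to(torch.float8_e4m3fn).xpu()
b = torch.randn(4, 3).to(torch.float8_e4m3fn).xpu()
out = torch.cat((a, b), dim=0)  # dtype was not in the XPU cat dispatch list before this change
print(out.shape, out.dtype)     # torch.Size([6, 3]) torch.float8_e4m3fn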
15 changes: 11 additions & 4 deletions src/ATen/native/xpu/sycl/TensorCompareKernels.cpp
@@ -1,4 +1,5 @@
#include <ATen/Dispatch.h>
#include <ATen/Dispatch_v2.h>
#include <ATen/NumericUtils.h>
#include <ATen/native/TensorCompare.h>
#include <ATen/native/TensorIterator.h>
@@ -78,10 +79,16 @@ struct ClampScalarFunctor {
};

void where_kernel(TensorIterator& iter) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
kComplexHalf, kHalf, kBFloat16, kBool, iter.dtype(), "where_xpu", [&] {
gpu_kernel(iter, WhereFunctor<scalar_t>());
});
AT_DISPATCH_V2(
iter.dtype(),
"where_xpu",
[&] { gpu_kernel(iter, WhereFunctor<scalar_t>()); },
kComplexHalf,
kHalf,
kBFloat16,
kBool,
AT_EXPAND(AT_ALL_TYPES_AND_COMPLEX),
AT_EXPAND(AT_FLOAT8_TYPES));
}

void isposinf_kernel(TensorIteratorBase& iter) {
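The where_kernel change above migrates from AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4 to AT_DISPATCH_V2, which takes the kernel lambda before the variadic dtype list and so lets AT_FLOAT8_TYPES be appended alongside the existing types. A minimal sketch of the user-visible effect, assuming an XPU build with float8 support:

import torch

cond = torch.tensor([[True, False], [False, True]]).xpu()
x = torch.full((2, 2), 2.0).to(torch.float8_e5m2).xpu()
y = torch.zeros(2, 2).to(torch.float8_e5m2).xpu()
res = torch.where(cond, x, y)       # float8 inputs now reach WhereFunctor on XPU
print(res.cpu().to(torch.float32))  # widen to float32 for printing/comparison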
61 changes: 61 additions & 0 deletions test/regressions/test_cat.py
@@ -4,6 +4,67 @@


class TestTorchMethod(TestCase):
# Define float8 dtypes for the focused test
FLOAT8_DTYPES = (
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
)

def _create_input_tensors(self, shape, dtype, memory_format=None):
# Always generate random data using a CPU-compatible dtype (float32)
# to avoid the "not implemented" error for float8 on CPU.
tensor = torch.randn(shape, dtype=torch.float32)

# Convert to the target testing dtype
tensor = tensor.to(dtype)

# Apply memory format if specified
if memory_format is not None:
tensor = tensor.to(memory_format=memory_format)

return tensor

def _test_cat_float8_core(self, tensors, dim, dtype):
"""Core function to test torch.cat for float8, using tolerances."""

# --- CPU Reference Calculation (High Precision) ---
# Convert inputs to float32 on CPU for golden reference calculation
ref_tensors = [t.cpu().to(torch.float32) for t in tensors]

# Calculate CPU reference result
res_cpu = torch.cat(ref_tensors, dim=dim)

# --- XPU Calculation ---
# Convert inputs to XPU
xpu_tensors = [t.xpu() for t in tensors]
res_xpu = torch.cat(xpu_tensors, dim=dim)

# Float8 is lossy, use higher tolerance (rtol=1e-2, atol=1e-2)
rtol = 1e-2
atol = 1e-2

# Convert XPU result to float32 on CPU before comparison to match res_cpu's dtype.
res_xpu_f32_on_cpu = res_xpu.cpu().to(torch.float32)

self.assertEqual(res_cpu, res_xpu_f32_on_cpu, rtol=rtol, atol=atol)

def test_cat_float8_simple(self):
"""Test torch.cat correctness across float8 dtypes using simple tensors."""
for dtype in self.FLOAT8_DTYPES:
with self.subTest(dtype=dtype):
# Use simple 3D shape (2, 4, 3) and concatenate along dim 1
user_cpu1 = self._create_input_tensors([2, 4, 3], dtype=dtype)
user_cpu2 = self._create_input_tensors([2, 2, 3], dtype=dtype)
user_cpu3 = self._create_input_tensors([2, 6, 3], dtype=dtype)

tensors = (user_cpu1, user_cpu2, user_cpu3)
dim = 1

self._test_cat_float8_core(tensors, dim, dtype)

def test_cat_8d(self, dtype=torch.float):
input1 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)
input2 = torch.randn([256, 8, 8, 3, 3, 3, 3], dtype=dtype)
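On the tolerance choice in _test_cat_float8_core (rtol=atol=1e-2): float8 formats carry at most three mantissa bits, so merely round-tripping data through float8 already shifts values noticeably, and comparing the XPU result against a float32 CPU reference needs loose tolerances. A small illustrative round trip (hypothetical values), runnable on CPU alone:

import torch

x = torch.tensor([0.1234, 1.2345, 12.345])
rt = x.to(torch.float8_e4m3fn).to(torch.float32)  # quantize to float8, widen back
print(rt)                          # values snap to nearby float8-representable points
print(((rt - x) / x).abs().max())  # relative error on the order of a few percent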
92 changes: 92 additions & 0 deletions test/regressions/test_where.py
@@ -0,0 +1,92 @@
# Owner(s): ["module: intel"]
import torch
from torch.testing._internal.common_utils import TestCase


class TestTorchWhereMethod(TestCase):
# Define float8 dtypes
FLOAT8_DTYPES = (
torch.float8_e5m2,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
)

# Define the set of all dtypes to be tested
TEST_DTYPES = (
torch.float32,
torch.float64,
torch.half,
torch.bfloat16,
) + FLOAT8_DTYPES

def _test_where_fn(self, dtype):
"""Core function to test torch.where(condition, x, y) correctness."""

# Input tensors (x and y)
x = torch.tensor([[10.0, 20.0], [30.0, 40.0]], dtype=dtype)
y = torch.tensor([[-1.0, -2.0], [-3.0, -4.0]], dtype=dtype)
# Condition must be bool
condition = torch.tensor([[True, False], [False, True]], dtype=torch.bool)

# --- 1. CPU Reference Calculation and Tolerance Setting ---

if dtype in self.FLOAT8_DTYPES:
# FP8: Use float32 as reference type for comparison
x_ref = x.cpu().to(torch.float32)
y_ref = y.cpu().to(torch.float32)
rtol = 1e-2
atol = 1e-2
else:
# Non-FP8: Use original dtype as reference type
x_ref = x.cpu()
y_ref = y.cpu()
rtol = 1e-5
atol = 1e-5

condition_ref = condition.cpu()
res_ref = torch.where(condition_ref, x_ref, y_ref)

# --- 2. XPU Operation (Default) ---
x_xpu = x.xpu()
y_xpu = y.xpu()
condition_xpu = condition.xpu()

res_xpu = torch.where(condition_xpu, x_xpu, y_xpu)

# Prepare XPU result for comparison (must match res_ref dtype)
if dtype in self.FLOAT8_DTYPES:
# FP8: Convert XPU result to float32
res_xpu_to_compare = res_xpu.cpu().to(torch.float32)
else:
# Non-FP8: Pull to CPU, keeping original dtype
res_xpu_to_compare = res_xpu.cpu()

# Compare: res_ref vs res_xpu_to_compare
self.assertEqual(res_ref, res_xpu_to_compare, rtol=rtol, atol=atol)

# --- 3. Test the version with out= argument ---

# Create output tensor on XPU
res_xpu_out = torch.empty_like(res_xpu, dtype=dtype).xpu()
torch.where(condition_xpu, x_xpu, y_xpu, out=res_xpu_out)

# Prepare XPU 'out' result for comparison
if dtype in self.FLOAT8_DTYPES:
# FP8: Convert XPU result to float32
res_xpu_out_to_compare = res_xpu_out.cpu().to(torch.float32)
else:
# Non-FP8: Pull to CPU, keeping original dtype
res_xpu_out_to_compare = res_xpu_out.cpu()

# Compare: res_ref vs res_xpu_out_to_compare
self.assertEqual(res_ref, res_xpu_out_to_compare, rtol=rtol, atol=atol)

def test_where(self):
"""Test torch.where() correctness across all supported dtypes, including float8."""
for dtype in self.TEST_DTYPES:
# Use string conversion for better subTest reporting
dtype_name = str(dtype).split(".")[-1]
with self.subTest(dtype=dtype_name):
self._test_where_fn(dtype)