diff --git a/test/xpu/test_nn_xpu.py b/test/xpu/test_nn_xpu.py
index 4ff4bcef2..72d2647af 100644
--- a/test/xpu/test_nn_xpu.py
+++ b/test/xpu/test_nn_xpu.py
@@ -16,6 +16,7 @@ from torch.testing._internal.common_device_type import (
     dtypes,
     instantiate_device_type_tests,
+    largeTensorTest,
 )
 from torch.testing._internal.common_dtype import get_all_math_dtypes, integral_types
 from torch.testing._internal.common_utils import (
@@ -3786,6 +3787,39 @@ def _test_cross_entropy_loss_2d_out_of_bounds_class_index(self):
     )
 
 
+@dtypes(torch.float, torch.half)
+@largeTensorTest("20GB")
+@largeTensorTest("64GB", "cpu")
+def _test_warp_softmax_64bit_indexing(self, device, dtype):
+    def run_test(*shape):
+        x = torch.randn(shape, device=device, dtype=torch.float16, requires_grad=True)
+        y = F.log_softmax(x, dim=-1, dtype=dtype)
+        y.backward(y)
+        with torch.no_grad():
+            xx = x.cpu().requires_grad_()
+        yy = F.log_softmax(xx.float(), dim=-1).to(dtype)
+        yy.backward(yy)
+        # workaround to reduce memory usage vs. self.assertEqual, see #84944
+        rtol, atol = torch.testing._comparison.get_tolerances(
+            dtype, rtol=None, atol=None
+        )
+        self.assertTrue(torch.allclose(y.cpu(), yy, rtol=rtol, atol=atol))
+        # x is half
+        rtol, _ = torch.testing._comparison.get_tolerances(
+            torch.half, rtol=None, atol=None
+        )
+        self.assertTrue(torch.allclose(x.grad.cpu(), xx.grad, rtol=rtol, atol=1e-3))
+
+    run_test(
+        1100000000, 2
+    )  # Illegal memory access https://github.com/pytorch/pytorch/issues/52715
+    run_test(
+        2200000000, 1
+    )  # invalid configuration argument https://github.com/pytorch/pytorch/issues/52716
+
+
+TestNNDeviceType.test_warp_softmax_64bit_indexing = _test_warp_softmax_64bit_indexing
+
 TestNNDeviceType.test_cross_entropy_loss_2d_out_of_bounds_class_index = (
     _test_cross_entropy_loss_2d_out_of_bounds_class_index
 )
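
For context, here is a minimal sketch (not part of the diff) of the registration pattern the change relies on: the test body is defined at module level and assigned onto the generic device-type class before instantiate_device_type_tests generates the per-device test classes, so decorators such as @dtypes and @largeTensorTest are picked up at instantiation time. The class and test names below are hypothetical stand-ins, not the ones in test_nn_xpu.py.

    import torch
    from torch.testing._internal.common_device_type import (
        dtypes,
        instantiate_device_type_tests,
    )
    from torch.testing._internal.common_utils import TestCase, run_tests


    class TestExampleDeviceType(TestCase):
        pass


    @dtypes(torch.float, torch.half)
    def _test_example(self, device, dtype):
        # Hypothetical stand-in for _test_warp_softmax_64bit_indexing.
        x = torch.zeros(4, device=device, dtype=dtype)
        self.assertEqual(x.sum().item(), 0.0)


    # The assignment must happen before instantiate_device_type_tests runs,
    # mirroring the ordering in the diff above: the instantiation machinery
    # only sees attributes that already exist on the generic class.
    TestExampleDeviceType.test_example = _test_example

    instantiate_device_type_tests(TestExampleDeviceType, globals(), only_for="xpu")

    if __name__ == "__main__":
        run_tests()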