From 4ec9e1620fa76497c7941bf60dd6450811937ed0 Mon Sep 17 00:00:00 2001
From: Alex Morehead
Date: Wed, 10 Sep 2025 14:27:57 -0700
Subject: [PATCH] Update test_utils.py

---
 torchtnt/utils/test_utils.py | 21 +++++++++++++--------
 1 file changed, 13 insertions(+), 8 deletions(-)

diff --git a/torchtnt/utils/test_utils.py b/torchtnt/utils/test_utils.py
index 638fe120f2..a9c05c03c9 100644
--- a/torchtnt/utils/test_utils.py
+++ b/torchtnt/utils/test_utils.py
@@ -93,11 +93,16 @@ def captured_output() -> Generator[Tuple[TextIO, TextIO], None, None]:
 
 
 """Decorator for tests to ensure running on a GPU."""
-skip_if_not_gpu: Callable[..., Callable[..., object]] = unittest.skipUnless(
-    torch.cuda.is_available(), "Skipping test since GPU is not available"
-)
-
-"""Decorator for tests to ensure running when distributed is available."""
-skip_if_not_distributed: Callable[..., Callable[..., object]] = unittest.skipUnless(
-    torch.distributed.is_available(), "Skipping test since distributed is not available"
-)
+def skip_if_not_gpu(func: Callable) -> Callable:
+    """Decorator that checks for GPU availability at decoration time."""
+    return unittest.skipUnless(
+        torch.cuda.is_available(),
+        "Skipping test since GPU is not available"
+    )(func)
+
+def skip_if_not_distributed(func: Callable) -> Callable:
+    """Decorator that checks for distributed availability at decoration time."""
+    return unittest.skipUnless(
+        torch.distributed.is_available(),
+        "Skipping test since distributed is not available"
+    )(func)
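
For context, a minimal usage sketch of the reworked decorators, assuming they
are imported from torchtnt.utils.test_utils as in the patch above; the test
class and method names below are hypothetical, for illustration only:

    import unittest

    from torchtnt.utils.test_utils import skip_if_not_distributed, skip_if_not_gpu

    class ExampleTest(unittest.TestCase):  # hypothetical test case
        @skip_if_not_gpu
        def test_needs_gpu(self) -> None:
            # Skipped unless torch.cuda.is_available() was True when the
            # decorator was applied (i.e., at class-body execution time).
            ...

        @skip_if_not_distributed
        def test_needs_distributed(self) -> None:
            # Skipped unless torch.distributed.is_available() was True at
            # decoration time.
            ...

Note that because each wrapper calls unittest.skipUnless(...) when it is
applied, the availability check still runs at import/decoration time, just as
with the previous module-level skipUnless objects; the rework changes the
decorators into plain functions without changing when the condition is
evaluated.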