diff --git a/torchtnt/utils/test_utils.py b/torchtnt/utils/test_utils.py index 638fe120f2..a9c05c03c9 100644 --- a/torchtnt/utils/test_utils.py +++ b/torchtnt/utils/test_utils.py @@ -93,11 +93,16 @@ def captured_output() -> Generator[Tuple[TextIO, TextIO], None, None]: """Decorator for tests to ensure running on a GPU.""" -skip_if_not_gpu: Callable[..., Callable[..., object]] = unittest.skipUnless( - torch.cuda.is_available(), "Skipping test since GPU is not available" -) - -"""Decorator for tests to ensure running when distributed is available.""" -skip_if_not_distributed: Callable[..., Callable[..., object]] = unittest.skipUnless( - torch.distributed.is_available(), "Skipping test since distributed is not available" -) +def skip_if_not_gpu(func: Callable) -> Callable: + """Decorator that checks for GPU availability at decoration time.""" + return unittest.skipUnless( + torch.cuda.is_available(), + "Skipping test since GPU is not available" + )(func) + +def skip_if_not_distributed(func: Callable) -> Callable: + """Decorator that checks for distributed availability at decoration time.""" + return unittest.skipUnless( + torch.distributed.is_available(), + "Skipping test since distributed is not available" + )(func)