diff --git a/tests/unittest/conftest.py b/tests/unittest/conftest.py
index 259244222e2..58255435330 100644
--- a/tests/unittest/conftest.py
+++ b/tests/unittest/conftest.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 # # Force resource release after test
 import pytest
+import torch
 import tqdm
 
 
@@ -79,3 +80,15 @@ def pytest_sessionstart(session):
     # To counter TransformerEngine v2.3's lazy_compile deferral,
     # which will cause Pytest thinks there's a thread leakage.
     import torch._inductor.async_compile  # noqa: F401
+
+
+@pytest.fixture(autouse=True)
+def torch_empty_cache() -> None:
+    """Release PyTorch's cached CUDA memory ahead of every test.
+
+    Applied automatically to each test (autouse) to lower the chance of
+    out-of-memory failures; nothing is done when CUDA is unavailable.
+    """
+    if not torch.cuda.is_available():
+        return
+    torch.cuda.empty_cache()