Skip to content

Commit 83d8cdd

Browse files
skip MT tests in test_quant_api
1 parent 2ed51c5 commit 83d8cdd

File tree

1 file changed: +7 additions, −2 deletions

test/quantization/test_quant_api.py

Lines changed: 7 additions & 2 deletions
Original line | Diff line | Change
@@ -709,8 +709,9 @@ def reset_memory():
 709 709          self.assertLess(memory_streaming, memory_baseline)
 710 710
 711 711  class TestMultiTensorFlow(TestCase):
 712     -
 713 712
     713 +    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     714 +    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 714 715      def test_multitensor_add_tensors(self):
 715 716          from torchao.quantization.GPTQ_MT import MultiTensor
 716 717          tensor1 = torch.randn(3, 3)
tensor1 = torch.randn(3, 3)
@@ -721,6 +722,8 @@ def test_multitensor_add_tensors(self):
 721 722          self.assertTrue(torch.equal(mt.values[0], tensor1))
 722 723          self.assertTrue(torch.equal(mt.values[1], tensor2))
 723 724
     725 +    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     726 +    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 724 727      def test_multitensor_pad_unpad(self):
 725 728          from torchao.quantization.GPTQ_MT import MultiTensor
 726 729          tensor1 = torch.randn(3, 3)
@@ -729,7 +732,9 @@ def test_multitensor_pad_unpad(self):
 729 732          self.assertEqual(mt.count, 3)
 730 733          mt.unpad()
 731 734          self.assertEqual(mt.count, 1)
 732     -
     735 +
     736 +    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
     737 +    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 733 738      def test_multitensor_inplace_operation(self):
 734 739          from torchao.quantization.GPTQ_MT import MultiTensor
 735 740          tensor1 = torch.ones(3, 3)

0 commit comments

Comments (0)