@@ -709,8 +709,9 @@ def reset_memory():
         self.assertLess(memory_streaming, memory_baseline)


 class TestMultiTensorFlow(TestCase):
-
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_add_tensors(self):
         from torchao.quantization.GPTQ_MT import MultiTensor
         tensor1 = torch.randn(3, 3)
@@ -721,6 +722,8 @@ def test_multitensor_add_tensors(self):
         self.assertTrue(torch.equal(mt.values[0], tensor1))
         self.assertTrue(torch.equal(mt.values[1], tensor2))

+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_pad_unpad(self):
         from torchao.quantization.GPTQ_MT import MultiTensor
         tensor1 = torch.randn(3, 3)
@@ -729,7 +732,9 @@ def test_multitensor_pad_unpad(self):
         self.assertEqual(mt.count, 3)
         mt.unpad()
         self.assertEqual(mt.count, 1)
-
+
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_multitensor_inplace_operation(self):
         from torchao.quantization.GPTQ_MT import MultiTensor
         tensor1 = torch.ones(3, 3)
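For context, a minimal sketch of the usage pattern these tests exercise. The constructor form and the method names add_tensors and pad_to_length are assumptions inferred from the test names and assertions in this diff; only values, count, and unpad() appear verbatim above, so exact signatures may differ in GPTQ_MT.

    import torch
    from torchao.quantization.GPTQ_MT import MultiTensor

    # Wrap one tensor, then append a second; values holds them in order.
    tensor1 = torch.randn(3, 3)
    tensor2 = torch.randn(3, 3)
    mt = MultiTensor(tensor1)      # constructor form assumed, not shown in the diff
    mt.add_tensors(tensor2)        # method name inferred from test_multitensor_add_tensors
    assert torch.equal(mt.values[0], tensor1)
    assert torch.equal(mt.values[1], tensor2)

    # Padding replicates entries up to a target count; unpad drops the copies,
    # matching the 3 -> 1 transition asserted in test_multitensor_pad_unpad.
    mt2 = MultiTensor(torch.randn(3, 3))
    mt2.pad_to_length(3)           # method name inferred from the test, signature assumed
    assert mt2.count == 3
    mt2.unpad()
    assert mt2.count == 1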