@@ -247,6 +247,10 @@ def test_mlp_contiguous_relu_compile_cutlass(self):
     @unittest.skipIf(IS_WINDOWS, "torch.compile not supported on windows")
     @unittest.skipIf("cusparselt" not in SEMI_STRUCTURED_SUPPORTED_BACKENDS, "cusparselt not supported on this machine")
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_sp24_compile(self) -> None:
         x = torch.randn([1024, 512], device="cuda", dtype=torch.float16, requires_grad=True)

@@ -576,6 +580,10 @@ def setUp(self):

     @training_dtypes
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_prune_dense_static_sort(self, dtype) -> None:
         # Ideally we would like to clone and compare, but that won't work because the sorting order will be different
         # instead we pass the pruned matrix to the CUDA implementation and preserve the sparsity pattern.
@@ -621,6 +629,10 @@ def test_prune_dense_static_sort(self, dtype) -> None:
     @training_dtypes
     @parametrize_backends
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_pruning_algo_largest_abs_values_greedy(self, dtype, backend) -> None:
         inp = torch.tensor(
             [[4, 3, 2, 1], [-1, -3, 0.6, 0.5], [1, 2, 3, 4], [10, 2, -1, 5]],
@@ -658,6 +670,10 @@ def test_gemm(self, dtype) -> None:
     @training_dtypes
     @parametrize_backends
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:
         M, N = 128, 256
         # Construct x to make sure we always have exactly 8 elements per 4x4 tile
@@ -692,6 +708,10 @@ def test_pack_both_ways_meta_correctness(self, dtype, backend) -> None:

     @training_dtypes
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_pack_both_ways_id(self, dtype) -> None:
         N = 512
         torch.manual_seed(0)
@@ -729,6 +749,10 @@ def test_pack_both_ways_id(self, dtype) -> None:

     @training_dtypes
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_pack_both_ways_edge_case1(self, dtype) -> None:
         # In this case, the heuristic will keep 7 values out of 16
         # instead of 8. let's see how the kernel handles this
@@ -754,6 +778,10 @@ def test_pack_both_ways_edge_case1(self, dtype) -> None:

     @training_dtypes
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_sp24_apply(self, dtype) -> None:
         M, N = 256, 1024
         x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -770,6 +798,10 @@ def test_sp24_apply(self, dtype) -> None:

     @training_dtypes
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_sp24_apply_dense(self, dtype) -> None:
         M, N = 256, 1024
         x = torch.randn([M, N], dtype=dtype, device="cuda")
@@ -808,6 +840,10 @@ def test_sp24_apply_dense(self, dtype) -> None:

     @training_dtypes
     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_sp24_matmuls(self, dtype) -> None:
         M, N, K = 64, 256, 1024
         a = torch.randn([M, K], device="cuda", dtype=dtype)
@@ -843,6 +879,10 @@ def test_sp24_matmuls(self, dtype) -> None:
         )

     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_sp24_matmuls_mat_vec(self) -> None:
         a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
         b = torch.randn([128], device="cuda", dtype=torch.float16)
@@ -853,6 +893,10 @@ def test_sp24_matmuls_mat_vec(self) -> None:
         torch.testing.assert_close(a_s @ b, (a * a_m) @ b, **atol_rtol_kw[a.dtype])

     @unittest.skipIf(TEST_WITH_ROCM, "Not supported on ROCm")
+    @unittest.skipIf(
+        "RelWithAssert" in torch.__config__.show(),
+        "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
+    )
     def test_sp24_matmuls_bmm(self) -> None:
         a = torch.randn([64, 128], device="cuda", dtype=torch.float16)
         b = torch.randn([5, 6, 128], device="cuda", dtype=torch.float16)
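Note: the diff stamps the identical `unittest.skipIf` guard onto every test. A minimal sketch of how the same guard could be hoisted into one shared module-level decorator, assuming the test file's existing imports; the name `skip_if_assert_enabled_build` is hypothetical, not from this PR:

import unittest

import torch

# Hypothetical shared decorator (illustration only, not in the PR).
# torch.__config__.show() returns PyTorch's build-configuration string,
# so this condition matches RelWithAssert builds only.
skip_if_assert_enabled_build = unittest.skipIf(
    "RelWithAssert" in torch.__config__.show(),
    "failing in debug build, see https://github.com/pytorch/pytorch/pull/165158 for context",
)


class ExampleTests(unittest.TestCase):
    @skip_if_assert_enabled_build
    def test_sp24_compile(self) -> None:
        ...  # test body elided

Because unittest.skipIf returns a reusable decorator object, evaluating the condition once at import time and applying the stored decorator per test is equivalent to repeating the inline call.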