@@ -741,13 +741,15 @@ def block_predicate_cache_write_output_buf() -> None:
741741
742742########## Testcases for cache_read ##########
743743
744+ use_block_name = tvm .testing .parameter (by_dict = {"block_obj" : False , "block_name" : True })
744745
745- def test_cache_read_elementwise ():
746+
747+ def test_cache_read_elementwise (use_block_name ):
746748 sch = tir .Schedule (elementwise , debug_mask = "all" )
747749 block_b = sch .get_block ("B" )
748750 block_c = sch .get_block ("C" )
749- cached_a = sch .cache_read (block_b , 0 , "global" )
750- cached_b = sch .cache_read (block_c , 0 , "local" )
751+ cached_a = sch .cache_read ("B" if use_block_name else block_b , 0 , "global" )
752+ cached_b = sch .cache_read ("C" if use_block_name else block_c , 0 , "local" )
751753 assert sch .get (cached_a ) == sch .get (sch .get_block ("A_global" ))
752754 assert sch .get (cached_b ) == sch .get (sch .get_block ("B_local" ))
753755 assert sch .get (block_b ) == sch .get (sch .get_block ("B" ))
@@ -756,87 +758,87 @@ def test_cache_read_elementwise():
756758 verify_trace_roundtrip (sch = sch , mod = elementwise )
757759
758760
759- def test_cache_read_under_scope ():
761+ def test_cache_read_under_scope (use_block_name ):
760762 sch = tir .Schedule (access_under_scope , debug_mask = "all" )
761- block_b = sch .get_block ("B" )
762- block_c = sch .get_block ("C" )
763+ block_b = "B" if use_block_name else sch .get_block ("B" )
764+ block_c = "C" if use_block_name else sch .get_block ("C" )
763765 sch .cache_read (block_b , 0 , "local" )
764766 sch .cache_read (block_c , 0 , "global" )
765767 tvm .ir .assert_structural_equal (cache_read_under_scope , sch .mod ["main" ])
766768 verify_trace_roundtrip (sch = sch , mod = access_under_scope )
767769
768770
769- def test_cache_read_opaque_access ():
771+ def test_cache_read_opaque_access (use_block_name ):
770772 sch = tir .Schedule (opaque_access , debug_mask = "all" )
771- block = sch .get_block ("load_store" )
773+ block = "load_store" if use_block_name else sch .get_block ("load_store" )
772774 sch .cache_read (block , 0 , "global" )
773775 tvm .ir .assert_structural_equal (cache_read_opaque_access , sch .mod ["main" ])
774776 verify_trace_roundtrip (sch = sch , mod = opaque_access )
775777
776778
777- def test_cache_read_location ():
779+ def test_cache_read_location (use_block_name ):
778780 sch = tir .Schedule (func_multi_consumer , debug_mask = "all" )
779- block_b = sch .get_block ("B" )
781+ block_b = "B" if use_block_name else sch .get_block ("B" )
780782 sch .cache_read (block_b , 0 , "global" )
781783 tvm .ir .assert_structural_equal (cache_read_multi_consumer , sch .mod ["main" ])
782784 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
783785
784786
785- def test_continuous_cache_read ():
787+ def test_continuous_cache_read (use_block_name ):
786788 sch = tir .Schedule (elementwise , debug_mask = "all" )
787- block_c = sch .get_block ("C" )
789+ block_c = "C" if use_block_name else sch .get_block ("C" )
788790 sch .cache_read (block_c , 0 , "shared" )
789791 sch .cache_read (block_c , 0 , "local" )
790792 tvm .ir .assert_structural_equal (continuous_cache_read , sch .mod ["main" ])
791793 verify_trace_roundtrip (sch = sch , mod = elementwise )
792794
793795
794- def test_cache_read_with_block_predicate ():
796+ def test_cache_read_with_block_predicate (use_block_name ):
795797 sch = tir .Schedule (func_with_block_predicate , debug_mask = "all" )
796- block = sch .get_block ("consumer" )
798+ block = "consumer" if use_block_name else sch .get_block ("consumer" )
797799 sch .cache_read (block , 0 , "shared" )
798800 tvm .ir .assert_structural_equal (block_predicate_cache_read , sch .mod ["main" ])
799801 verify_trace_roundtrip (sch = sch , mod = func_with_block_predicate )
800802
801803
802- def test_cache_read_non_int32_shape ():
804+ def test_cache_read_non_int32_shape (use_block_name ):
803805 sch = tir .Schedule (elementwise_shape_int64 , debug_mask = "all" )
804- block_b = sch .get_block ("B" )
806+ block_b = "B" if use_block_name else sch .get_block ("B" )
805807 sch .cache_read (block_b , 0 , "global" )
806808 tvm .ir .assert_structural_equal (cache_read_shape_int64 , sch .mod ["main" ])
807809 verify_trace_roundtrip (sch = sch , mod = elementwise_shape_int64 )
808810
809811
810- def test_cache_read_fail_multi_producer ():
812+ def test_cache_read_fail_multi_producer (use_block_name ):
811813 sch = tir .Schedule (func_multi_producer , debug_mask = "all" )
812- block_b = sch .get_block ("B" )
814+ block_b = "B" if use_block_name else sch .get_block ("B" )
813815 with pytest .raises (tvm .tir .ScheduleError ):
814816 sch .cache_read (block_b , 0 , "global" )
815817
816818
817- def test_cache_read_fail_index_out_of_bound ():
819+ def test_cache_read_fail_index_out_of_bound (use_block_name ):
818820 sch = tir .Schedule (elementwise , debug_mask = "all" )
819- block_b = sch .get_block ("B" )
821+ block_b = "B" if use_block_name else sch .get_block ("B" )
820822 with pytest .raises (tvm .tir .ScheduleError ):
821823 sch .cache_read (block_b , 1 , "global" )
822824
823825
824- def test_cache_read_fail_invalid_storage_scope ():
826+ def test_cache_read_fail_invalid_storage_scope (use_block_name ):
825827 sch = tir .Schedule (elementwise , debug_mask = "all" )
826- block_b = sch .get_block ("B" )
828+ block_b = "B" if use_block_name else sch .get_block ("B" )
827829 with pytest .raises (tvm .tir .ScheduleError ):
828830 sch .cache_read (block_b , 0 , "test_scope" )
829831
830832
831833########## Testcases for cache_write ##########
832834
833835
834- def test_cache_write_elementwise ():
836+ def test_cache_write_elementwise (use_block_name ):
835837 sch = tir .Schedule (elementwise , debug_mask = "all" )
836838 block_b = sch .get_block ("B" )
837839 block_c = sch .get_block ("C" )
838- cached_b = sch .cache_write (block_b , 0 , "local" )
839- cached_c = sch .cache_write (block_c , 0 , "global" )
840+ cached_b = sch .cache_write ("B" if use_block_name else block_b , 0 , "local" )
841+ cached_c = sch .cache_write ("C" if use_block_name else block_c , 0 , "global" )
840842 assert sch .get (cached_b ) == sch .get (sch .get_block ("B_local" ))
841843 assert sch .get (cached_c ) == sch .get (sch .get_block ("C_global" ))
842844 assert sch .get (block_b ) == sch .get (sch .get_block ("B" ))
@@ -845,10 +847,10 @@ def test_cache_write_elementwise():
845847 verify_trace_roundtrip (sch = sch , mod = elementwise )
846848
847849
848- def test_cache_write_under_scope ():
850+ def test_cache_write_under_scope (use_block_name ):
849851 sch = tir .Schedule (access_under_scope , debug_mask = "all" )
850- block_a = sch .get_block ("A" )
851- block_b = sch .get_block ("B" )
852+ block_a = "A" if use_block_name else sch .get_block ("A" )
853+ block_b = "B" if use_block_name else sch .get_block ("B" )
852854 block_scope = sch .get_block ("scope" )
853855 sch .cache_write (block_a , 0 , "local" )
854856 sch .cache_write (block_b , 0 , "global" )
@@ -857,70 +859,70 @@ def test_cache_write_under_scope():
857859 verify_trace_roundtrip (sch = sch , mod = access_under_scope )
858860
859861
860- def test_cache_write_opaque_access ():
862+ def test_cache_write_opaque_access (use_block_name ):
861863 sch = tir .Schedule (opaque_access , debug_mask = "all" )
862- block_store = sch .get_block ("load_store" )
863- block_opaque = sch .get_block ("opaque" )
864- block_match_buffer = sch .get_block ("match_buffer" )
864+ block_store = "load_store" if use_block_name else sch .get_block ("load_store" )
865+ block_opaque = "opaque" if use_block_name else sch .get_block ("opaque" )
866+ block_match_buffer = "match_buffer" if use_block_name else sch .get_block ("match_buffer" )
865867 sch .cache_write (block_store , 0 , "global" )
866868 sch .cache_write (block_opaque , 0 , "global" )
867869 sch .cache_write (block_match_buffer , 0 , "global" )
868870 tvm .ir .assert_structural_equal (cache_write_opaque_access , sch .mod ["main" ])
869871 verify_trace_roundtrip (sch = sch , mod = opaque_access )
870872
871873
872- def test_cache_write_location ():
874+ def test_cache_write_location (use_block_name ):
873875 sch = tir .Schedule (func_multi_consumer , debug_mask = "all" )
874- block_a = sch .get_block ("A" )
876+ block_a = "A" if use_block_name else sch .get_block ("A" )
875877 sch .cache_write (block_a , 0 , "global" )
876878 tvm .ir .assert_structural_equal (cache_write_multi_consumer , sch .mod ["main" ])
877879 verify_trace_roundtrip (sch = sch , mod = func_multi_consumer )
878880
879881
880- def test_continuous_cache_write ():
882+ def test_continuous_cache_write (use_block_name ):
881883 sch = tir .Schedule (elementwise , debug_mask = "all" )
882- block_b = sch .get_block ("B" )
884+ block_b = "B" if use_block_name else sch .get_block ("B" )
883885 sch .cache_write (block_b , 0 , "shared" )
884886 sch .cache_write (block_b , 0 , "local" )
885887 tvm .ir .assert_structural_equal (continuous_cache_write , sch .mod ["main" ])
886888 verify_trace_roundtrip (sch = sch , mod = elementwise )
887889
888890
889- def test_cache_write_with_block_predicate ():
891+ def test_cache_write_with_block_predicate (use_block_name ):
890892 # cache write for intermediate buffer
891893 sch = tir .Schedule (func_with_block_predicate , debug_mask = "all" )
892- block = sch .get_block ("producer" )
894+ block = "producer" if use_block_name else sch .get_block ("producer" )
893895 sch .cache_write (block , 0 , "shared" )
894896 tvm .ir .assert_structural_equal (block_predicate_cache_write_intermediate_buf , sch .mod ["main" ])
895897 verify_trace_roundtrip (sch = sch , mod = func_with_block_predicate )
896898 # cache write for external buffer
897899 sch = tir .Schedule (func_with_block_predicate , debug_mask = "all" )
898- block = sch .get_block ("consumer" )
900+ block = "consumer" if use_block_name else sch .get_block ("consumer" )
899901 sch .cache_write (block , 0 , "shared" )
900902 tvm .ir .assert_structural_equal (block_predicate_cache_write_output_buf , sch .mod ["main" ])
901903 verify_trace_roundtrip (sch = sch , mod = func_with_block_predicate )
902904
903905
904- def test_cache_write_fail_multi_producer ():
906+ def test_cache_write_fail_multi_producer (use_block_name ):
905907 sch = tir .Schedule (func_multi_producer , debug_mask = "all" )
906- block_a0 = sch .get_block ("A0" )
907- block_a1 = sch .get_block ("A1" )
908+ block_a0 = "A0" if use_block_name else sch .get_block ("A0" )
909+ block_a1 = "A1" if use_block_name else sch .get_block ("A1" )
908910 with pytest .raises (tvm .tir .ScheduleError ):
909911 sch .cache_write (block_a0 , 0 , "global" )
910912 with pytest .raises (tvm .tir .ScheduleError ):
911913 sch .cache_write (block_a1 , 0 , "global" )
912914
913915
914- def test_cache_write_fail_index_out_of_bound ():
916+ def test_cache_write_fail_index_out_of_bound (use_block_name ):
915917 sch = tir .Schedule (elementwise , debug_mask = "all" )
916- block_b = sch .get_block ("B" )
918+ block_b = "B" if use_block_name else sch .get_block ("B" )
917919 with pytest .raises (tvm .tir .ScheduleError ):
918920 sch .cache_write (block_b , 1 , "global" )
919921
920922
921- def test_cache_write_fail_invalid_storage_scope ():
923+ def test_cache_write_fail_invalid_storage_scope (use_block_name ):
922924 sch = tir .Schedule (elementwise , debug_mask = "all" )
923- block_b = sch .get_block ("B" )
925+ block_b = "B" if use_block_name else sch .get_block ("B" )
924926 with pytest .raises (tvm .tir .ScheduleError ):
925927 sch .cache_write (block_b , 0 , "test_scope" )
926928
0 commit comments