@@ -946,103 +946,83 @@ def test_free_variables():
946946 )
947947
948948
949- def test_padding () :
949+ class TestPadding :
950950 x = tvm .tir .Var ("x" , "int32" )
951951 y = tvm .tir .Var ("y" , "int32" )
952952 fld = tvm .tir .floordiv
953953 flm = tvm .tir .floormod
954954
955- # left padding only, offset divisible
956- sum = 64 + y
957- dom_map = var_dom ([(y , 192 )])
958- assert_iter_sum_pattern (
959- {fld (sum , 32 ): (6 , 2 , 1 ), flm (sum , 32 ): (32 , 0 , 1 )},
960- dom_map ,
961- check_level = "bijective" ,
962- )
963-
964- # left padding only, offset non-divisible
965- sum = 80 + y
966- dom_map = var_dom ([(y , 176 )])
967- assert_iter_sum_pattern (
968- {fld (sum , 32 ): (6 , 2 , 1 )},
969- dom_map ,
970- )
971- assert_iter_sum_pattern (
972- {flm (fld (sum , 2 ), 16 ): (16 , 0 , 1 ), flm (sum , 2 ): (2 , 0 , 1 )},
973- dom_map ,
974- )
975- assert_iter_sum_failure ({fld (sum , 32 ), flm (sum , 32 )}, dom_map )
976- assert_iter_sum_failure ({fld (sum , 32 ), fld (sum , 4 )}, dom_map )
977-
978- # right padding only, offset divisible
979- sum = x * 32 + y * 8
980- dom_map = var_dom ([(x , 5 ), (y , 4 )])
981- assert_iter_sum_pattern (
982- {fld (sum , 16 ): (10 , 0 , 1 ), flm (sum , 16 ): (2 , 0 , 8 )},
983- dom_map ,
984- )
985- assert_iter_sum_failure ({fld (sum , 5 )}, dom_map )
986-
987- # right padding only, offset non-divisible
988- dom_map = var_dom ([(x , 26 )])
989- assert_iter_sum_pattern (
990- {fld (x , 15 ): (2 , 0 , 1 )},
991- dom_map ,
992- )
993- assert_iter_sum_pattern (
994- {flm (fld (x , 3 ), 5 ): (5 , 0 , 1 ), flm (x , 3 ): (3 , 0 , 1 )},
995- dom_map ,
996- )
997-
998- # padding constants on both side
999- sum = x + 71
1000- dom_map = var_dom ([(x , 45 )])
1001- assert_iter_sum_pattern ({fld (sum , 32 ): (2 , 2 , 1 )}, dom_map )
1002- assert_iter_sum_pattern (
1003- {flm (fld (x , 4 ), 8 ): (8 , 0 , 1 ), flm (x , 4 ): (4 , 0 , 1 )},
1004- dom_map ,
1005- )
1006-
1007- # padding for free iteration part
1008- sum = x * 360 + y
1009- dom_map = var_dom ([(y , 360 )])
1010- assert_iter_sum_pattern ({fld (sum , 16 ): (23 , fld (x * 360 - flm (x , 2 ) * 8 , 16 ), 1 )}, dom_map )
1011- assert_iter_sum_pattern ({flm (x * 360 + y , 16 ): (16 , 0 , 1 )}, dom_map )
1012-
1013- # multiple split with same mark offset, could
1014- # be surjective on missing (padded // LCM)
1015- assert_iter_sum_pattern (
1016- {
1017- flm (x + 10 , 3 ): (3 , 0 ),
1018- flm (fld (x + 10 , 3 ), 4 ): (4 , 0 ),
1019- flm (fld (fld (x + 10 , 3 ), 4 ), 5 ): (5 , 0 ),
1020- },
1021- var_dom ([(x , 240 )]),
1022- )
1023- assert_iter_sum_failure (
1024- {
1025- flm (x + 10 , 3 ),
1026- flm (fld (x + 10 , 3 ), 4 ),
1027- flm (fld (fld (x + 10 , 3 ), 4 ), 5 ),
1028- fld (fld (fld (x + 10 , 3 ), 4 ), 5 ),
1029- },
1030- var_dom ([(x , 240 )]),
1031- )
1032-
1033- # different offsets on splits
1034- assert_iter_sum_pattern (
1035- {
1036- flm (x + 1 , 3 ): (3 , 0 ),
1037- flm (fld (x + 10 , 3 ) + 2 , 4 ): (4 , 0 ),
1038- flm (fld (fld (x + 10 , 3 ), 4 ) + 3 , 5 ): (5 , 0 ),
1039- },
1040- var_dom ([(x , 240 )]),
955+ positive_test_case = tvm .testing .parameter (
956+ # left padding only, offset divisible
957+ ({y : 192 }, {fld (64 + y , 32 ): (6 , 2 , 1 ), flm (64 + y , 32 ): (32 , 0 , 1 )}, "bijective" ),
958+ # left padding only, offset non-divisible
959+ ({y : 176 }, {fld (80 + y , 32 ): (6 , 2 , 1 )}),
960+ ({y : 176 }, {flm (fld (80 + y , 2 ), 16 ): (16 , 0 , 1 ), flm (80 + y , 2 ): (2 , 0 , 1 )}),
961+ # right padding only, offset divisible
962+ ({x : 5 , y : 4 }, {fld (x * 32 + y * 8 , 16 ): (10 , 0 , 1 ), flm (x * 32 + y * 8 , 16 ): (2 , 0 , 8 )}),
963+ # right padding only, offset non-divisible
964+ ({x : 26 }, {fld (x , 15 ): (2 , 0 , 1 )}),
965+ ({x : 26 }, {flm (fld (x , 3 ), 5 ): (5 , 0 , 1 ), flm (x , 3 ): (3 , 0 , 1 )}),
966+ # padding constants on both side
967+ ({x : 45 }, {fld (x + 71 , 32 ): (2 , 2 , 1 )}),
968+ ({x : 45 }, {flm (fld (x , 4 ), 8 ): (8 , 0 , 1 ), flm (x , 4 ): (4 , 0 , 1 )}),
969+ # padding for free iteration part
970+ ({y : 360 }, {fld (x * 360 + y , 16 ): (23 , fld (x * 360 - flm (x , 2 ) * 8 , 16 ), 1 )}),
971+ ({y : 360 }, {flm (x * 360 + y , 16 ): (16 , 0 , 1 )}),
972+ # multiple split with same mark offset, could
973+ # be surjective on missing (padded // LCM)
974+ (
975+ {x : 240 },
976+ {
977+ flm (x + 10 , 3 ): (3 , 0 ),
978+ flm (fld (x + 10 , 3 ), 4 ): (4 , 0 ),
979+ flm (fld (fld (x + 10 , 3 ), 4 ), 5 ): (5 , 0 ),
980+ },
981+ ),
982+ # different offsets on splits
983+ (
984+ {x : 240 },
985+ {
986+ flm (x + 1 , 3 ): (3 , 0 ),
987+ flm (fld (x + 10 , 3 ) + 2 , 4 ): (4 , 0 ),
988+ flm (fld (fld (x + 10 , 3 ), 4 ) + 3 , 5 ): (5 , 0 ),
989+ },
990+ ),
1041991 )
1042992
1043- # original extent is smaller than the divident
1044- # it is not surjective wrt to the region [0, 16)
1045- assert_iter_sum_failure ({flm (x , 16 )}, var_dom ([(x , 3 )]))
993+ negative_test_case = tvm .testing .parameter (
994+ # left padding only, offset non-divisible
995+ ({y : 176 }, {fld (80 + y , 32 ), flm (80 + y , 32 )}),
996+ ({y : 176 }, {fld (80 + y , 32 ), fld (80 + y , 4 )}),
997+ # right padding only, offset divisible
998+ ({x : 5 , y : 4 }, {fld (x * 32 + y * 8 , 5 )}),
999+ # multiple split with same mark offset, could
1000+ # be surjective on missing (padded // LCM)
1001+ (
1002+ {x : 240 },
1003+ {
1004+ flm (x + 10 , 3 ),
1005+ flm (fld (x + 10 , 3 ), 4 ),
1006+ flm (fld (fld (x + 10 , 3 ), 4 ), 5 ),
1007+ fld (fld (fld (x + 10 , 3 ), 4 ), 5 ),
1008+ },
1009+ ),
1010+ # original extent is smaller than the divident
1011+ # it is not surjective wrt to the region [0, 16)
1012+ ({x : 3 }, {flm (x , 16 )}),
1013+ )
1014+
1015+ def test_padding (self , positive_test_case ):
1016+ iter_extent , mapped_iterators , * args = positive_test_case
1017+ check_level = args [0 ] if args else "surjective"
1018+ dom_map = {var : tvm .ir .Range (0 , ext ) for var , ext in iter_extent .items ()}
1019+ assert_iter_sum_pattern (mapped_iterators , dom_map , check_level = check_level )
1020+
1021+ def test_padding_error (self , negative_test_case ):
1022+ iter_extent , mapped_iterators , * args = negative_test_case
1023+ check_level = args [0 ] if args else "surjective"
1024+ dom_map = {var : tvm .ir .Range (0 , ext ) for var , ext in iter_extent .items ()}
1025+ assert_iter_sum_failure (mapped_iterators , dom_map , check_level = check_level )
10461026
10471027
10481028def test_overlapped_fuse ():
0 commit comments