Skip to content

Commit f975992

Browse files
authored
[UnitTest] Parametrized test_arith_iter_affine_map::test_padding (#13774)
Parametrization helped in the debugging of #13530, but is not otherwise related to that PR.
1 parent 287597b commit f975992

File tree

1 file changed

+70
-90
lines changed

1 file changed

+70
-90
lines changed

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 70 additions & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -946,103 +946,83 @@ def test_free_variables():
946946
)
947947

948948

949-
def test_padding():
949+
class TestPadding:
950950
x = tvm.tir.Var("x", "int32")
951951
y = tvm.tir.Var("y", "int32")
952952
fld = tvm.tir.floordiv
953953
flm = tvm.tir.floormod
954954

955-
# left padding only, offset divisible
956-
sum = 64 + y
957-
dom_map = var_dom([(y, 192)])
958-
assert_iter_sum_pattern(
959-
{fld(sum, 32): (6, 2, 1), flm(sum, 32): (32, 0, 1)},
960-
dom_map,
961-
check_level="bijective",
962-
)
963-
964-
# left padding only, offset non-divisible
965-
sum = 80 + y
966-
dom_map = var_dom([(y, 176)])
967-
assert_iter_sum_pattern(
968-
{fld(sum, 32): (6, 2, 1)},
969-
dom_map,
970-
)
971-
assert_iter_sum_pattern(
972-
{flm(fld(sum, 2), 16): (16, 0, 1), flm(sum, 2): (2, 0, 1)},
973-
dom_map,
974-
)
975-
assert_iter_sum_failure({fld(sum, 32), flm(sum, 32)}, dom_map)
976-
assert_iter_sum_failure({fld(sum, 32), fld(sum, 4)}, dom_map)
977-
978-
# right padding only, offset divisible
979-
sum = x * 32 + y * 8
980-
dom_map = var_dom([(x, 5), (y, 4)])
981-
assert_iter_sum_pattern(
982-
{fld(sum, 16): (10, 0, 1), flm(sum, 16): (2, 0, 8)},
983-
dom_map,
984-
)
985-
assert_iter_sum_failure({fld(sum, 5)}, dom_map)
986-
987-
# right padding only, offset non-divisible
988-
dom_map = var_dom([(x, 26)])
989-
assert_iter_sum_pattern(
990-
{fld(x, 15): (2, 0, 1)},
991-
dom_map,
992-
)
993-
assert_iter_sum_pattern(
994-
{flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)},
995-
dom_map,
996-
)
997-
998-
# padding constants on both side
999-
sum = x + 71
1000-
dom_map = var_dom([(x, 45)])
1001-
assert_iter_sum_pattern({fld(sum, 32): (2, 2, 1)}, dom_map)
1002-
assert_iter_sum_pattern(
1003-
{flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)},
1004-
dom_map,
1005-
)
1006-
1007-
# padding for free iteration part
1008-
sum = x * 360 + y
1009-
dom_map = var_dom([(y, 360)])
1010-
assert_iter_sum_pattern({fld(sum, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}, dom_map)
1011-
assert_iter_sum_pattern({flm(x * 360 + y, 16): (16, 0, 1)}, dom_map)
1012-
1013-
# multiple split with same mark offset, could
1014-
# be surjective on missing (padded // LCM)
1015-
assert_iter_sum_pattern(
1016-
{
1017-
flm(x + 10, 3): (3, 0),
1018-
flm(fld(x + 10, 3), 4): (4, 0),
1019-
flm(fld(fld(x + 10, 3), 4), 5): (5, 0),
1020-
},
1021-
var_dom([(x, 240)]),
1022-
)
1023-
assert_iter_sum_failure(
1024-
{
1025-
flm(x + 10, 3),
1026-
flm(fld(x + 10, 3), 4),
1027-
flm(fld(fld(x + 10, 3), 4), 5),
1028-
fld(fld(fld(x + 10, 3), 4), 5),
1029-
},
1030-
var_dom([(x, 240)]),
1031-
)
1032-
1033-
# different offsets on splits
1034-
assert_iter_sum_pattern(
1035-
{
1036-
flm(x + 1, 3): (3, 0),
1037-
flm(fld(x + 10, 3) + 2, 4): (4, 0),
1038-
flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0),
1039-
},
1040-
var_dom([(x, 240)]),
955+
positive_test_case = tvm.testing.parameter(
956+
# left padding only, offset divisible
957+
({y: 192}, {fld(64 + y, 32): (6, 2, 1), flm(64 + y, 32): (32, 0, 1)}, "bijective"),
958+
# left padding only, offset non-divisible
959+
({y: 176}, {fld(80 + y, 32): (6, 2, 1)}),
960+
({y: 176}, {flm(fld(80 + y, 2), 16): (16, 0, 1), flm(80 + y, 2): (2, 0, 1)}),
961+
# right padding only, offset divisible
962+
({x: 5, y: 4}, {fld(x * 32 + y * 8, 16): (10, 0, 1), flm(x * 32 + y * 8, 16): (2, 0, 8)}),
963+
# right padding only, offset non-divisible
964+
({x: 26}, {fld(x, 15): (2, 0, 1)}),
965+
({x: 26}, {flm(fld(x, 3), 5): (5, 0, 1), flm(x, 3): (3, 0, 1)}),
966+
# padding constants on both side
967+
({x: 45}, {fld(x + 71, 32): (2, 2, 1)}),
968+
({x: 45}, {flm(fld(x, 4), 8): (8, 0, 1), flm(x, 4): (4, 0, 1)}),
969+
# padding for free iteration part
970+
({y: 360}, {fld(x * 360 + y, 16): (23, fld(x * 360 - flm(x, 2) * 8, 16), 1)}),
971+
({y: 360}, {flm(x * 360 + y, 16): (16, 0, 1)}),
972+
# multiple split with same mark offset, could
973+
# be surjective on missing (padded // LCM)
974+
(
975+
{x: 240},
976+
{
977+
flm(x + 10, 3): (3, 0),
978+
flm(fld(x + 10, 3), 4): (4, 0),
979+
flm(fld(fld(x + 10, 3), 4), 5): (5, 0),
980+
},
981+
),
982+
# different offsets on splits
983+
(
984+
{x: 240},
985+
{
986+
flm(x + 1, 3): (3, 0),
987+
flm(fld(x + 10, 3) + 2, 4): (4, 0),
988+
flm(fld(fld(x + 10, 3), 4) + 3, 5): (5, 0),
989+
},
990+
),
1041991
)
1042992

1043-
# original extent is smaller than the divident
1044-
# it is not surjective wrt to the region [0, 16)
1045-
assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
993+
negative_test_case = tvm.testing.parameter(
994+
# left padding only, offset non-divisible
995+
({y: 176}, {fld(80 + y, 32), flm(80 + y, 32)}),
996+
({y: 176}, {fld(80 + y, 32), fld(80 + y, 4)}),
997+
# right padding only, offset divisible
998+
({x: 5, y: 4}, {fld(x * 32 + y * 8, 5)}),
999+
# multiple split with same mark offset, could
1000+
# be surjective on missing (padded // LCM)
1001+
(
1002+
{x: 240},
1003+
{
1004+
flm(x + 10, 3),
1005+
flm(fld(x + 10, 3), 4),
1006+
flm(fld(fld(x + 10, 3), 4), 5),
1007+
fld(fld(fld(x + 10, 3), 4), 5),
1008+
},
1009+
),
1010+
# original extent is smaller than the divident
1011+
# it is not surjective wrt to the region [0, 16)
1012+
({x: 3}, {flm(x, 16)}),
1013+
)
1014+
1015+
def test_padding(self, positive_test_case):
1016+
iter_extent, mapped_iterators, *args = positive_test_case
1017+
check_level = args[0] if args else "surjective"
1018+
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
1019+
assert_iter_sum_pattern(mapped_iterators, dom_map, check_level=check_level)
1020+
1021+
def test_padding_error(self, negative_test_case):
1022+
iter_extent, mapped_iterators, *args = negative_test_case
1023+
check_level = args[0] if args else "surjective"
1024+
dom_map = {var: tvm.ir.Range(0, ext) for var, ext in iter_extent.items()}
1025+
assert_iter_sum_failure(mapped_iterators, dom_map, check_level=check_level)
10461026

10471027

10481028
def test_overlapped_fuse():

0 commit comments

Comments
 (0)