Commit 3e7a2ad

support overlapped itersum (#12039)
1 parent: 7bf5fa4

8 files changed: +176 -58 lines changed
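
Background for the change: an "overlapped" iter sum such as x * 7 + y with x < 3 and y < 8 walks the range [0, 22) but revisits some points, because the stride 7 of x is smaller than the extent 8 of y. Such a mapping is surjective onto its range without being bijective, and before this commit DetectIterMap rejected it. The sketch below is not part of the commit; it assumes the Python bindings at this revision expose tvm.arith.detect_iter_map with a check_level argument and an IterMapLevel enum mirroring the C++ DetectIterMap touched here.

    import tvm
    from tvm import tir
    from tvm.arith import IterMapLevel, detect_iter_map

    x = tir.Var("x", "int32")
    y = tir.Var("y", "int32")
    dom = {x: tvm.ir.Range(3), y: tvm.ir.Range(8)}  # x in [0, 3), y in [0, 8)

    # Accepted under surjective checking: one fused iterator of extent 22.
    res = detect_iter_map([x * 7 + y], dom, check_level=IterMapLevel.Surjective)
    assert len(res.indices) == 1

    # Rejected under bijective checking: some points are visited more than once.
    res = detect_iter_map([x * 7 + y], dom, check_level=IterMapLevel.Bijective)
    assert len(res.indices) == 0 and len(res.errors) > 0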

src/arith/iter_affine_map.cc

Lines changed: 66 additions & 25 deletions
@@ -177,8 +177,12 @@ class IterMapRewriter : public ExprMutator {
   using Parent = ExprMutator;
 
   explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
-                           bool simplify_trivial_iterators, Array<String>* errors)
-      : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) {
+                           IterMapLevel check_level, bool simplify_trivial_iterators,
+                           Array<String>* errors)
+      : analyzer_(analyzer),
+        check_level_(check_level),
+        errors_(*errors),
+        padding_predicate_(const_false()) {
     for (auto kv : input_iters) {
       const Var& var = kv.first;
       const Range& vrng = kv.second;
@@ -419,6 +423,8 @@ class IterMapRewriter : public ExprMutator {
 
   // Internal analyzer
   Analyzer* analyzer_;
+  // Iter map check level
+  IterMapLevel check_level_;
   // Error messages for each unresolved expression.
   Array<String>& errors_;
   // The var map
@@ -651,7 +657,7 @@ class IterMapRewriter : public ExprMutator {
       if (predicate_induced_max.defined())
         predicate_induced_max = predicate_induced_max.value() - base;
     }
-    Optional<IterSumExpr> opt = TryFuseIters(expr);
+    Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
     ICHECK(!opt.defined() || opt.value()->args.size() == 1);
     // scale should be 1
     if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
@@ -702,7 +708,7 @@ class IterMapRewriter : public ExprMutator {
   IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
     // We are normalizing a regular iter
     if (expr->args.size() < 1) return expr;
-    Optional<IterSumExpr> opt = TryFuseIters(expr);
+    Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
     if (opt.defined()) {
       return opt.value();
     } else {
@@ -735,9 +741,10 @@ class IterMapRewriter : public ExprMutator {
    *   return a corresponding IterSumExpr with extra offset if needed.
    *   Try to normalize IterSum into a fused IterMark
    * \param expr The input sum.
+   * \param check_level The check level if iter mapping.
    * \return The sum with the fused IterMark and extra offset if succeed.
    */
-  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
+  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
     // select the iterators in order
     std::vector<bool> visited(expr->args.size(), false);
     std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -758,14 +765,42 @@ class IterMapRewriter : public ExprMutator {
     }
     // check if it can be remapped into a fused pattern.
     PrimExpr expected_extra_base = 0;
+    PrimExpr tail_extent = 0;
     PrimExpr expected_scale = base_scale.value();
     for (size_t i = 0; i < expr->args.size();) {
-      // find j such that expr->args[j] has expected scale
-      size_t j = i == 0 ? base_index : 0;
-      for (; j < expr->args.size(); ++j) {
-        if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
+      // find position such that expr->args[j] match expected scale
+      int j = i == 0 ? base_index : expr->args.size() - 1;
+
+      size_t matched_pos = expr->args.size();
+      PrimExpr matched_scale{nullptr};
+      bool is_exact_match{false};
+
+      for (; j >= 0; --j) {
+        if (visited[j]) {
+          continue;
+        }
+        const PrimExpr& cur_scale = expr->args[j]->scale;
+
+        // for bijective mapping, the matched scale must equal to expected scale
+        if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
+          matched_pos = j;
+          matched_scale = cur_scale;
+          is_exact_match = true;
+          break;
+        }
+        if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) {
+          // find the closest scale which is less or equal to expected scale
+          if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
+              analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
+            if (matched_pos == expr->args.size() ||
+                analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
+              matched_pos = j;
+              matched_scale = cur_scale;
+            }
+          }
+        }
       }
-      if (j == expr->args.size()) {
+      if (matched_pos == expr->args.size()) {
         return NullOpt;
       }
       // look for the longest constrained iter started from expr->args[j]
@@ -775,8 +810,8 @@ class IterMapRewriter : public ExprMutator {
       // otherwise we expect the scale of i to be 2*5=10
       Optional<IterSumExpr> constraint_to_match;
       for (const IterSumExpr& iter : constrained_iters_flattened_) {
-        if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
-          // find a predicate started from expr->args[j]
+        if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) {
+          // find a predicate started from match position
           if (!constraint_to_match ||
               constraint_to_match.value()->args.size() < iter->args.size()) {
             constraint_to_match = iter;
@@ -793,7 +828,7 @@ class IterMapRewriter : public ExprMutator {
         size_t k = 0;
         for (; k < expr->args.size(); ++k) {
           if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
-            if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale))
+            if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
               break;
           }
         }
@@ -806,20 +841,25 @@ class IterMapRewriter : public ExprMutator {
         auto iter = sum_fuse_map_.find(constraint_to_match.value());
         ICHECK(iter != sum_fuse_map_.end());
         const IterMarkWithOffset& iter_matched = iter->second;
-        grouped_iters.emplace_back(iter_matched.mark, expected_scale);
-        expected_extra_base += iter_matched.offset * expected_scale;
-        expected_scale *= iter_matched.mark->extent;
+        grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value()));
+        expected_extra_base += iter_matched.offset * matched_scale;
+        if (!is_exact_match) {
+          tail_extent += expected_scale - matched_scale;
+        }
+        expected_scale = matched_scale * iter_matched.mark->extent;
         // move forward
         i += constraint_to_match.value()->args.size();
       } else {
         // constraint_to_match not found, skip this iterator
-        visited[j] = true;
-        IterSplitExpr arg = expr->args[j];
-        arg.CopyOnWrite()->scale =
-            analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value()));
+        visited[matched_pos] = true;
+        IterSplitExpr arg = expr->args[matched_pos];
+        arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value()));
         flattened_iters.push_back(arg);
         grouped_iters.push_back(arg);
-        expected_scale *= expr->args[j]->extent;
+        if (!is_exact_match) {
+          tail_extent += expected_scale - matched_scale;
+        }
+        expected_scale = matched_scale * expr->args[matched_pos]->extent;
         ++i;
       }
     }
@@ -843,7 +883,8 @@ class IterMapRewriter : public ExprMutator {
                          expr->base + expected_extra_base);
     } else {
       // new iter, form a new mark
-      IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
+      IterMark mark =
+          IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent);
       sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
       flattened_map_[structured_form] = flattened_form;
       return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
@@ -1086,8 +1127,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
       constraints.begin(), constraints.end(),
       [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
 
-  IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators,
-                           &result->errors);
+  IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
+                           simplify_trivial_iterators, &result->errors);
   // Step0.0: rewrite constraints in the order from size-small ones to size-big ones
   for (const IterConstraint& constraint : constraints) {
     auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
@@ -1281,7 +1322,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
   } else if (sum->args.size() == 1) {
     return sum;
   }
-  auto opt_fused = TryFuseIters(sum);
+  auto opt_fused = TryFuseIters(sum, check_level_);
   if (!opt_fused) {
     ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend)
                       << ", can't be written as a single fused IterSum";

tests/python/unittest/test_arith_intset.py

Lines changed: 2 additions & 5 deletions
@@ -323,10 +323,6 @@ def do_test_point_access(point, predicates, var_dom, expect):
 
 
 def test_region_lower_bound_unfusable():
-    # This test is designed to trigger an error in DetectIterMap,
-    # resulting from a numerator which required multiple input
-    # variables. The bug resulted in an exception being thrown,
-    # rather than a return value of None.
     var_dom = {
         tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
         tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
@@ -336,7 +332,8 @@ def test_region_lower_bound_unfusable():
         tvm.ir.Range.from_min_extent((i + j) // 2, 1),
     ]
     result = tvm.arith.estimate_region_lower_bound(region, var_dom, predicate=True)
-    assert result is None
+    assert result[0].min_value == 0
+    assert result[0].max_value == 5
 
 
 def test_union_lower_bound():
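
The new expectation follows directly from overlapped fusion: with i < 8 and j < 4, i + j now fuses into a single surjective iterator of extent 8 + 4 - 1 = 11 (values 0 through 10), so (i + j) // 2 is bounded by [0, 5]. Previously the fusion failed and estimate_region_lower_bound returned None, which is what the removed assertion checked.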

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 57 additions & 1 deletion
@@ -61,7 +61,6 @@ def assert_iter_sum_pattern(
     )
     indices = res.indices
     assert len(indices) == len(keys), res.errors
-    print(indices)
     for i, input_iter in enumerate(keys):
         spec = expect_dict[input_iter]
         (
@@ -446,6 +445,13 @@ def test_predicate():
         predicate=xo * 129 + xi < 128,
     )
 
+    # strided iteration predicate
+    assert_iter_sum_pattern(
+        {xo * 16 + xi * 4: (10, 0, 4)},
+        var_dom([(xo, 3), (xi, 4)]),
+        predicate=xo * 4 + xi < 10,
+    )
+
 
 def convert_division(divisions):
     if divisions is None or len(divisions) == 0:
@@ -1010,5 +1016,55 @@ def test_padding():
     assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
 
 
+def test_overlapped_fuse():
+    x = tvm.tir.Var("x", "int32")
+    y = tvm.tir.Var("y", "int32")
+    z = tvm.tir.Var("z", "int32")
+    a = tvm.tir.Var("x", "int32")
+    b = tvm.tir.Var("y", "int32")
+
+    # non-bijective fuse of two
+    assert_iter_sum_pattern(
+        {
+            x * 7 + y: (22, 0, 1),
+        },
+        var_dom([(x, 3), (y, 8)]),
+        check_level="surjective",
+    )
+    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective")
+
+    # non-bijective fuse of three
+    assert_iter_sum_pattern(
+        {
+            x * 18 + y * 7 + z: (40, 0, 1),
+        },
+        var_dom([(x, 2), (y, 3), (z, 8)]),
+        check_level="surjective",
+    )
+    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective")
+
+    # negative scale fusion is not allowed
+    assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
+    assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
+
+    # with predicate
+    assert_iter_sum_pattern(
+        {
+            a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
+        },
+        var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
+        predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
+        check_level="surjective",
+    )
+
+    # stride=1 kernel
+    assert_iter_sum_pattern(
+        {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective"
+    )
+
+    # do not allow both strided and overlapped
+    assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective")
+
+
 if __name__ == "__main__":
     tvm.testing.main()
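
A note on the last failure case: 5 * x + 2 * y has no unit-scale base iterator (its smallest scale is 2), and the overlapped-match branch added to TryFuseIters only fires when base_scale equals 1, so the expression is rejected even at the surjective level. To run just the new test, the usual pytest selection applies (command not part of the commit):

    pytest tests/python/unittest/test_arith_iter_affine_map.py::test_overlapped_fuse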

tests/python/unittest/test_meta_schedule_space_cpu.py

Lines changed: 13 additions & 13 deletions
@@ -48,11 +48,11 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
         for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1, 1, 8):
             for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                 with T.block("conv1d_nlc"):
-                    n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
-                    l = T.axis.spatial(128, i1_1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
-                    co = T.axis.spatial(128, (i2_0 * 8 + i2_1_1) * 8 + i2_2 + i2_3)
+                    n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0)
+                    l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 + i1_2 * 2 + i1_3)
+                    co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8 + i2_2)
                     rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                    rc = T.axis.reduce(64, i4_0 + i4_1)
+                    rc = T.axis.reduce(64, i4_1 + i4_0)
                     T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
                     T.writes(conv1d_nlc_global[n, l, co])
                     T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -89,11 +89,11 @@ def c1d_1(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
                     PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32")
             for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                 with T.block("conv1d_nlc"):
-                    n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
-                    l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
-                    co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
+                    n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                    l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
+                    co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
                     rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                    rc = T.axis.reduce(64, i4_0 + i4_1)
+                    rc = T.axis.reduce(64, i4_1 + i4_0)
                     T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
                     T.writes(conv1d_nlc_global[n, l, co])
                     T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -107,7 +107,7 @@ def c1d_1(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
                 T.reads(conv1d_nlc_global[v0, v1, v2])
                 T.writes(conv1d_nlc[v0, v1, v2])
                 conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2]
-
+
     @T.prim_func
     def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
         # function attr dict
@@ -119,11 +119,11 @@ def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
         T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
         for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
             with T.block("conv1d_nlc"):
-                n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
-                l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
-                co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
+                n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
+                l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
+                co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
                 rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                rc = T.axis.reduce(64, i4_0 + i4_1)
+                rc = T.axis.reduce(64, i4_1 + i4_0)
                 T.reads(inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc], weight[rl, rc, co])
                 T.writes(conv1d_nlc[n, l, co])
                 T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})

tests/python/unittest/test_meta_schedule_space_cuda.py

Lines changed: 6 additions & 6 deletions
@@ -47,7 +47,7 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
                 for ax0_ax1_ax2_fused in T.serial(260):
                     with T.block("PadInput_shared"):
                         v0 = T.axis.spatial(1, 0)
-                        v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
+                        v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4)
                         v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4)
                         T.reads(inputs[v0, v1 - 1, v2])
                         T.writes(PadInput_shared[v0, v1, v2])
@@ -64,11 +64,11 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
                         weight_shared[v0, v1, v2] = weight[v0, v1, v2]
                 for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
                     with T.block("conv1d_nlc"):
-                        n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
-                        l = T.axis.spatial(128, (i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 + i1_4)
-                        co = T.axis.spatial(128, (((0 * 2 + i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + i2_4)
-                        rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + i3_2)
-                        rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) * 2 + i4_2)
+                        n = T.axis.spatial(1, i0_4 + i0_3)
+                        l = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4)
+                        co = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4)
+                        rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3 + i3_2)
+                        rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2 + i4_2)
                         T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co])
                         T.writes(conv1d_nlc_local[n, l, co])
                         T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
