Skip to content

Commit 8c53f62

Browse files
[TIR] Allow compute_at create block predicate for non-trivial bounds and support floordiv pattern (#9527)
* allow generate block predicate in compute_at schedule * revert #9880 and add more testcases
1 parent 8247724 commit 8c53f62

File tree

8 files changed

+374
-72
lines changed

8 files changed

+374
-72
lines changed

include/tvm/arith/int_set.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,11 @@ class IntSet : public ObjectRef {
9393
bool CanProveNonPositive() const;
9494
/*! \return Whether the set is proved to be larger than or equal to 0 */
9595
bool CanProveNonNegative() const;
96+
/*! \return Whether the set has an upper bound. */
97+
bool HasUpperBound() const;
98+
/*! \return Whether the set has a lower bound. */
99+
bool HasLowerBound() const;
100+
96101
/*!
97102
* \brief The single point value, call only if IsSinglePoint is true
98103
* \return The point value.

src/arith/int_set.cc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,20 @@ bool IntSet::CanProveNonNegative() const {
574574
return false;
575575
}
576576

577+
bool IntSet::HasLowerBound() const {
578+
if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
579+
return s_int->HasLowerBound();
580+
}
581+
return false;
582+
}
583+
584+
bool IntSet::HasUpperBound() const {
585+
if (const IntervalSetNode* s_int = (*this).as<IntervalSetNode>()) {
586+
return s_int->HasUpperBound();
587+
}
588+
return false;
589+
}
590+
577591
SignType IntSet::GetSignType() const {
578592
if (CanProvePositive()) {
579593
return kPositive;

src/arith/iter_affine_map.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,7 @@ class IterMapRewriter : public ExprMutator {
503503
if (predicate_induced_min.defined()) predicate_induced_min = predicate_induced_min - base;
504504
if (predicate_induced_max.defined()) predicate_induced_max = predicate_induced_max - base;
505505
}
506+
if (expr->args.size() < 1) return expr;
506507
Optional<IterSumExpr> opt = TryFuseIters(expr);
507508
ICHECK(!opt.defined() || opt.value()->args.size() == 1);
508509
// scale should be 1

src/tir/analysis/block_access_region_detector.cc

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -145,22 +145,10 @@ void BlockReadWriteDetector::VisitExpr_(const LoadNode* op) {
145145
ExprVisitor::VisitExpr_(op);
146146
}
147147

148-
arith::IntSet BlockReadWriteDetector::RelaxAccessIndex(const PrimExpr& index) {
149-
arith::IntSet relaxed = arith::EvalSet(index, dom_map_);
150-
if (!hint_map_.empty()) {
151-
// take non-relaxed var bound hints into considerations
152-
// eg, if i * 4 + j with i >= 10 and j in [0, 4), only j in domain scope
153-
// then the index region can be relaxed to [i*4, i*4+4) ^ [40, inf)
154-
arith::IntSet hint_bound = arith::EvalSet(relaxed, hint_map_);
155-
relaxed = arith::Intersect({relaxed, hint_bound});
156-
}
157-
return relaxed;
158-
}
159-
160148
void BlockReadWriteDetector::VisitExpr_(const BufferLoadNode* op) {
161149
std::vector<arith::IntSet> relaxed_region;
162150
for (const PrimExpr& index : op->indices) {
163-
relaxed_region.push_back(RelaxAccessIndex(index));
151+
relaxed_region.push_back(arith::EvalSet(index, dom_map_));
164152
}
165153
Update(&read_buffers_, &read_regions_, op->buffer, relaxed_region);
166154
ExprVisitor::VisitExpr_(op);
@@ -213,7 +201,7 @@ void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
213201
void BlockReadWriteDetector::VisitStmt_(const BufferStoreNode* op) {
214202
std::vector<arith::IntSet> relaxed_region;
215203
for (const PrimExpr& index : op->indices) {
216-
relaxed_region.push_back(RelaxAccessIndex(index));
204+
relaxed_region.push_back(arith::EvalSet(index, dom_map_));
217205
}
218206
Update(&writes_buffers_, &write_regions_, op->buffer, relaxed_region);
219207
StmtVisitor::VisitStmt_(op);

src/tir/schedule/primitive/compute_at.cc

Lines changed: 132 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,49 @@ int FindInsertionPoint(
163163
return split.first_consumer_position;
164164
}
165165

166+
/*!
167+
* \brief Represent the iteration domain to fully cover the required region of Intersect(dom, bound)
168+
* The bound region is not necessarily intersected with the dom region directly; instead we try to generate
169+
* extra predicates for a non-trivial bound. Domain info objects can also be unioned with each other.
170+
*/
171+
struct BlockVarDomainInfo {
172+
arith::IntSet dom{arith::IntSet::Nothing()}; // dom is ensured to be bounded
173+
arith::IntSet bound{arith::IntSet::Nothing()};
174+
175+
/*! \brief Relaxed union operation */
176+
void Union(const BlockVarDomainInfo& other) {
177+
// just relax (d0 ^ b0) v (d1 ^ b1) to (d0 v d1) ^ (b0 v b1)
178+
dom = arith::Union({dom, other.dom});
179+
bound = arith::Union({bound, other.bound});
180+
}
181+
182+
/*! \brief Simplify domain info */
183+
void Simplify(arith::Analyzer* analyzer) {
184+
auto to_simplified = [analyzer](const arith::IntSet& set) {
185+
PrimExpr min = set.HasLowerBound() ? analyzer->Simplify(set.min()) : set.min();
186+
PrimExpr max = set.HasUpperBound() ? analyzer->Simplify(set.max()) : set.max();
187+
return arith::IntSet::Interval(min, max);
188+
};
189+
// if no dom is specified, try to use bound as dom
190+
if (dom.IsNothing()) {
191+
if (bound.HasLowerBound() && bound.HasUpperBound()) {
192+
bound = to_simplified(bound);
193+
std::swap(dom, bound);
194+
}
195+
return;
196+
}
197+
// simplify intsets
198+
dom = to_simplified(dom);
199+
bound = to_simplified(bound);
200+
// if we can prove the dom is within bound, remove bound
201+
auto intersect = to_simplified(arith::Intersect({dom, bound}));
202+
if (analyzer->CanProveEqual(dom.min(), intersect.min()) &&
203+
analyzer->CanProveEqual(dom.max(), intersect.max())) {
204+
bound = arith::IntSet::Nothing();
205+
}
206+
}
207+
};
208+
166209
/*!
167210
* \brief A helper to reconstruct the block scope where the given block is moved under the given
168211
* loop, and the given block's induced loop nest is regenerated to satisfy the required region.
@@ -179,29 +222,44 @@ class ScopeReconstructor : private StmtMutator {
179222
* \param insert_position The position among the subtrees where the block and its induced loop
180223
* nest is inserted
181224
* \param iter_doms The domain of each block var
225+
* \param analyzer The arithmetic analyzer
182226
* \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
183227
*/
184-
void MakeNewLoop(int insert_position, std::vector<Range> iter_doms, bool preserve_unit_loops) {
228+
void MakeNewLoop(int insert_position, std::vector<BlockVarDomainInfo> iter_doms,
229+
arith::Analyzer* analyzer, bool preserve_unit_loops) {
185230
int n_iters = iter_doms.size();
186231
Array<Var> loop_vars;
187232
Array<PrimExpr> loop_extents;
188233
Array<PrimExpr> iter_values;
189234
loop_vars.reserve(n_iters);
190235
loop_extents.reserve(n_iters);
191236
iter_values.reserve(n_iters);
237+
PrimExpr predicate = const_true();
192238
for (int i = 0; i < n_iters; ++i) {
193-
const Range& iter_dom = iter_doms[i];
239+
Range iter_dom = iter_doms[i].dom.CoverRange(block_->iter_vars[i]->dom);
194240
if (preserve_unit_loops || !is_one(iter_dom->extent)) {
195241
Var var("ax" + std::to_string(loop_vars.size()), DataType::Int(32));
196242
loop_vars.push_back(var);
197243
loop_extents.push_back(iter_dom->extent);
198244
iter_values.push_back(iter_dom->min + var);
245+
analyzer->Bind(var, Range::FromMinExtent(0, iter_dom->extent));
199246
} else {
200247
iter_values.push_back(iter_dom->min);
201248
}
249+
const arith::IntSet& pred_bound = iter_doms[i].bound;
250+
if (!pred_bound.IsNothing()) {
251+
if (pred_bound.HasLowerBound()) {
252+
PrimExpr lower_bound = iter_values[i] >= pred_bound.min();
253+
predicate = predicate && lower_bound;
254+
}
255+
if (pred_bound.HasUpperBound()) {
256+
PrimExpr upper_bound = iter_values[i] < pred_bound.max() + 1;
257+
predicate = predicate && upper_bound;
258+
}
259+
}
202260
}
203261
this->new_block_realize_ =
204-
BlockRealize(std::move(iter_values), const_true(), std::move(block_));
262+
BlockRealize(std::move(iter_values), analyzer->Simplify(predicate), std::move(block_));
205263
Stmt new_subtree = this->new_block_realize_;
206264
for (int i = static_cast<int>(loop_vars.size()) - 1; i >= 0; --i) {
207265
const Var& loop_var = loop_vars[i];
@@ -310,39 +368,74 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
310368
* domain
311369
* \param provided The provided integer set to cover the required domain
312370
* \param required The required domain to be covered
313-
* \param iter_doms The result iteration domains to be updated
314371
* \param analyzer The arithmetic analyzer
315372
*/
316-
void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required,
317-
std::unordered_map<const VarNode*, std::vector<arith::IntSet>>* iter_doms,
318-
arith::Analyzer* analyzer) {
373+
std::pair<Var, arith::IntSet> SolveBlockVarDomain(const arith::IntSet& provided,
374+
const arith::IntSet& required,
375+
arith::Analyzer* analyzer) {
319376
PrimExpr provided_min = analyzer->Simplify(provided.min());
320-
PrimExpr provided_extent = analyzer->Simplify(provided.max() - provided_min + 1);
377+
PrimExpr provided_max = analyzer->Simplify(provided.max());
321378
PrimExpr required_min = analyzer->Simplify(required.min());
322-
PrimExpr required_extent = analyzer->Simplify(required.max() - required_min + 1);
323-
PrimExpr dom_min{nullptr}, dom_extent{nullptr};
379+
PrimExpr required_max = analyzer->Simplify(required.max());
380+
PrimExpr dom_min{nullptr}, dom_max{nullptr};
324381
Var dom_var{ObjectPtr<VarNode>{nullptr}};
325382
arith::PVar<Var> p_v;
326383
arith::PVar<PrimExpr> p_e;
327384
if ((p_v * p_e).Match(provided_min) || (p_e * p_v).Match(provided_min)) {
328385
PrimExpr e = p_e.Eval();
329386
dom_var = p_v.Eval();
330387
dom_min = floordiv(required_min, e);
331-
dom_extent = analyzer->Simplify((required_extent + e - 1) / e);
332-
} else if (analyzer->CanProveEqual(provided_extent, 1) && p_v.Match(provided_min)) {
333-
dom_var = p_v.Eval();
334-
dom_min = required_min;
335-
dom_extent = required_extent;
336-
} else {
337-
ICHECK(false) << "ValueError: BufferRegion pattern match failed";
388+
dom_max = floordiv(required_max, e);
389+
} else if (analyzer->CanProveEqual(provided_min, provided_max)) {
390+
if (p_v.Match(provided_min)) {
391+
dom_var = p_v.Eval();
392+
dom_min = required_min;
393+
dom_max = required_max;
394+
} else {
395+
arith::PVar<PrimExpr> p_f;
396+
if ((floordiv(p_v, p_f)).Match(provided_min)) {
397+
// a <= (x // fac) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
398+
PrimExpr fac = p_f.Eval();
399+
if (analyzer->CanProveGreaterEqual(fac, 1)) {
400+
dom_var = p_v.Eval();
401+
dom_min = required_min * fac;
402+
dom_max = analyzer->Simplify(required_max * fac + fac - 1);
403+
}
404+
} else if ((floormod(p_v, p_f).Match(provided_min))) {
405+
// generally the domain of (x % fac) enforces no constraints on the domain of x
406+
dom_var = p_v.Eval();
407+
return std::make_pair(dom_var, arith::IntSet::Nothing());
408+
}
409+
}
338410
}
339-
auto it = iter_doms->find(dom_var.get());
411+
ICHECK(dom_var.defined()) << "ValueError: BufferRegion pattern match failed: " << provided_min;
412+
return std::make_pair(dom_var, arith::IntSet::Interval(dom_min, dom_max));
413+
}
414+
415+
/*!
416+
* \brief Calculate and update the iteration domain info to fully cover the required domain
417+
* \param provided The provided integer set to cover the required domain
418+
* \param required The required domain to be covered
419+
* \param required_bound The additional region bound of the required domain to be covered
420+
* \param iter_doms The result iteration domains to be updated
421+
* \param analyzer The arithmetic analyzer
422+
*/
423+
void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& required,
424+
const arith::IntSet& required_bound,
425+
std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms,
426+
arith::Analyzer* analyzer) {
427+
auto var_with_dom = SolveBlockVarDomain(provided, required, analyzer);
428+
auto var_with_bound = SolveBlockVarDomain(provided, required_bound, analyzer);
429+
const Var& var = var_with_dom.first;
430+
const auto& var_dom = var_with_dom.second;
431+
const auto& var_bound = var_with_bound.second;
432+
ICHECK(var.same_as(var_with_bound.first));
433+
auto it = iter_doms->find(var.get());
340434
if (it != iter_doms->end()) {
341-
std::vector<arith::IntSet>& doms = it->second;
342-
doms.push_back(arith::IntSet::FromMinExtent(dom_min, dom_extent));
435+
it->second.Union({var_dom, var_bound});
343436
} else {
344-
ICHECK(analyzer->CanProveEqual(provided_min, required_min));
345-
ICHECK(analyzer->CanProveEqual(provided_extent, required_extent));
437+
ICHECK(analyzer->CanProveEqual(provided.min(), required.min()));
438+
ICHECK(analyzer->CanProveEqual(provided.max(), required.max()));
346439
}
347440
}
348441

@@ -352,19 +445,19 @@ void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& re
352445
* \param provided_regions The region provided by one iteration instance of the block vars
353446
* \param required_regions The region required to be covered
354447
* \param analyzer The arithmetic analyzer
355-
* \return A list of iteration domain corresponding to the given list of block vars
448+
* \return A list of iteration domain info corresponding to the given list of block vars
356449
*/
357-
std::vector<Range> CalculateBlockVarDomain(
450+
std::vector<BlockVarDomainInfo> CalculateBlockVarDomain(
358451
const Array<IterVar>& iter_vars,
359452
std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions,
360453
std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions,
361454
arith::Analyzer* analyzer) {
362455
int n_iters = iter_vars.size();
363456
// Step 1. Construct the mapping from block var to their iteration domain (initialized to empty)
364-
std::unordered_map<const VarNode*, std::vector<arith::IntSet>> iter_doms;
457+
std::unordered_map<const VarNode*, BlockVarDomainInfo> iter_doms;
365458
iter_doms.reserve(n_iters);
366459
for (const IterVar& iter_var : iter_vars) {
367-
iter_doms[iter_var->var.get()] = {};
460+
iter_doms[iter_var->var.get()] = BlockVarDomainInfo();
368461
}
369462
// Step 2. For each buffer, update the domain according to the provided and required regions
370463
for (const auto& kv : provided_regions) {
@@ -384,23 +477,23 @@ std::vector<Range> CalculateBlockVarDomain(
384477
for (int i = 0; i < ndim; ++i) {
385478
arith::IntSet provided = provided_region[i];
386479
arith::IntSet required = required_region[i];
387-
required = arith::Intersect(
388-
{std::move(required), arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i])});
389-
UpdateBlockVarDomain(provided, required, &iter_doms, analyzer);
480+
arith::IntSet required_bound = arith::IntSet::FromMinExtent(Integer(0), buffer->shape[i]);
481+
UpdateBlockVarDomain(provided, required, required_bound, &iter_doms, analyzer);
390482
}
391483
}
392484
// Union the iter var domains, put them in the same order of block vars, and return
393-
std::vector<Range> result;
485+
std::vector<BlockVarDomainInfo> result;
394486
result.reserve(n_iters);
395487
for (const IterVar& iter_var : iter_vars) {
396-
const std::vector<arith::IntSet>& doms = iter_doms.at(iter_var->var.get());
397-
arith::IntSet dom = arith::IntSet::FromRange(iter_var->dom);
398-
if (!doms.empty()) {
399-
dom = arith::Intersect({std::move(dom), arith::Union(doms)});
488+
BlockVarDomainInfo& info = iter_doms.at(iter_var->var.get());
489+
if (info.bound.IsNothing()) {
490+
info.bound = arith::IntSet::FromRange(iter_var->dom);
491+
} else {
492+
info.bound = arith::Intersect({info.bound, arith::IntSet::FromRange(iter_var->dom)});
400493
}
401-
PrimExpr min = analyzer->Simplify(dom.min());
402-
PrimExpr extent = analyzer->Simplify(dom.max() - min + 1);
403-
result.push_back(Range::FromMinExtent(min, extent));
494+
info.Simplify(analyzer);
495+
ICHECK(!info.dom.IsNothing());
496+
result.push_back(info);
404497
}
405498
return result;
406499
}
@@ -498,14 +591,14 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
498591
/*consumer_srefs=*/std::move(consumer_srefs),
499592
/*provided_regions=*/&provided_regions, /*required_regions=*/&required_regions);
500593
// Step 5. Calculate the iteration domain for each block var
501-
std::vector<Range> iter_doms =
594+
std::vector<BlockVarDomainInfo> iter_doms =
502595
CalculateBlockVarDomain(/*iter_vars=*/block->iter_vars,
503596
/*provided_regions=*/std::move(provided_regions),
504597
/*required_regions=*/std::move(required_regions),
505598
/*analyzer=*/analyzer);
506599
// Step 6. Create the new scope according to the iteration domain
507600
reconstructor.MakeNewLoop(/*insert_position=*/insert_position, /*iter_doms=*/std::move(iter_doms),
508-
/*preserve_unit_loops=*/preserve_unit_loops);
601+
/*analyzer=*/analyzer, /*preserve_unit_loops=*/preserve_unit_loops);
509602
Block new_scope_root = Downcast<Block>(reconstructor(scope_root));
510603

511604
// Step 7. Do the actual replacement

tests/python/unittest/test_arith_iter_affine_map.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,14 @@ def test_predicate():
263263
)
264264
assert len(res) == 0
265265

266+
# irrelevant predicate
267+
res = tvm.arith.detect_iter_map(
268+
[i + j],
269+
var_dom([(i, 1)]),
270+
j <= 24,
271+
)
272+
assert_iter_sum_pattern(res[0], 1, j)
273+
266274
# constraint on nested fused iters
267275
res = tvm.arith.detect_iter_map(
268276
[i * 8 + j * 2 + k],

tests/python/unittest/test_tir_analysis_get_block_access_region.py

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -138,29 +138,15 @@ def access_of_padding_pattern() -> None:
138138
for i, j in T.grid(32, 32):
139139
with T.block("padding"):
140140
vi, vj = T.axis.remap("SS", [i, j])
141-
T.reads(
142-
[
143-
X[
144-
T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
145-
T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
146-
]
147-
]
148-
)
141+
T.reads([X[vi - 2, vj - 2]])
149142
T.writes([X_pad[vi, vj]])
150143
X_pad[vi, vj] = T.if_then_else(
151144
2 <= vi and vi < 30 and 2 <= vj and vj < 30, X[vi - 2, vj - 2], 0.0, dtype="float32"
152145
)
153146
with T.block("padding_reverse"):
154147
vi, vj = T.axis.remap("SS", [i, j])
155-
T.reads([X_pad[T.max(vi, 2) : T.min(vi, 29) + 1, T.max(vj, 2) : T.min(vj, 29) + 1]])
156-
T.writes(
157-
[
158-
Y[
159-
T.max(vi - 2, 0) : T.min(vi - 2, 27) + 1,
160-
T.max(vj - 2, 0) : T.min(vj - 2, 27) + 1,
161-
]
162-
]
163-
)
148+
T.reads([X_pad[vi, vj]])
149+
T.writes([Y[vi - 2, vj - 2]])
164150
if 2 <= vi and vi < 30 and 2 <= vj and vj < 30:
165151
Y[vi - 2, vj - 2] = X_pad[vi, vj]
166152

0 commit comments

Comments
 (0)