@@ -163,6 +163,49 @@ int FindInsertionPoint(
163163 return split.first_consumer_position ;
164164}
165165
166+ /* !
167+ * \brief Represent the iteration domain to fully cover the required region of Intersect(dom, bound)
168+ * The bound region may not get directly intersected with dom region, instead we try to generate
169+ * extra predicates for non-trivial bound. The domain info class can also union with each other.
170+ */
171+ struct BlockVarDomainInfo {
172+ arith::IntSet dom{arith::IntSet::Nothing ()}; // dom is ensured to be bounded
173+ arith::IntSet bound{arith::IntSet::Nothing ()};
174+
175+ /* ! \brief Relaxed union operation */
176+ void Union (const BlockVarDomainInfo& other) {
177+ // just relax (d0 ^ b0) v (d1 ^ b1) to (d0 v d1) ^ (b0 v b1)
178+ dom = arith::Union ({dom, other.dom });
179+ bound = arith::Union ({bound, other.bound });
180+ }
181+
182+ /* ! \brief Simplify domain info */
183+ void Simplify (arith::Analyzer* analyzer) {
184+ auto to_simplified = [analyzer](const arith::IntSet& set) {
185+ PrimExpr min = set.HasLowerBound () ? analyzer->Simplify (set.min ()) : set.min ();
186+ PrimExpr max = set.HasUpperBound () ? analyzer->Simplify (set.max ()) : set.max ();
187+ return arith::IntSet::Interval (min, max);
188+ };
189+ // if no dom specified, try use bound as dom
190+ if (dom.IsNothing ()) {
191+ if (bound.HasLowerBound () && bound.HasUpperBound ()) {
192+ bound = to_simplified (bound);
193+ std::swap (dom, bound);
194+ }
195+ return ;
196+ }
197+ // simplify intsets
198+ dom = to_simplified (dom);
199+ bound = to_simplified (bound);
200+ // if can proof the dom is within bound, remove bound
201+ auto intersect = to_simplified (arith::Intersect ({dom, bound}));
202+ if (analyzer->CanProveEqual (dom.min (), intersect.min ()) &&
203+ analyzer->CanProveEqual (dom.max (), intersect.max ())) {
204+ bound = arith::IntSet::Nothing ();
205+ }
206+ }
207+ };
208+
166209/* !
167210 * \brief A helper to reconstruct the block scope where the given block is moved under the given
168211 * loop, and the given block's induced loop nest is regenerated to satisfy the required region.
@@ -179,29 +222,44 @@ class ScopeReconstructor : private StmtMutator {
179222 * \param insert_position The position among the subtrees where the block and its induced loop
180223 * nest is inserted
181224 * \param iter_doms The domain of each block var
225+ * \param analyzer The arithmetic analyzer
182226 * \param preserve_unit_loops Whether to generate unit loops where the loop extent is 1
183227 */
184- void MakeNewLoop (int insert_position, std::vector<Range> iter_doms, bool preserve_unit_loops) {
228+ void MakeNewLoop (int insert_position, std::vector<BlockVarDomainInfo> iter_doms,
229+ arith::Analyzer* analyzer, bool preserve_unit_loops) {
185230 int n_iters = iter_doms.size ();
186231 Array<Var> loop_vars;
187232 Array<PrimExpr> loop_extents;
188233 Array<PrimExpr> iter_values;
189234 loop_vars.reserve (n_iters);
190235 loop_extents.reserve (n_iters);
191236 iter_values.reserve (n_iters);
237+ PrimExpr predicate = const_true ();
192238 for (int i = 0 ; i < n_iters; ++i) {
193- const Range& iter_dom = iter_doms[i];
239+ Range iter_dom = iter_doms[i]. dom . CoverRange (block_-> iter_vars [i]-> dom ) ;
194240 if (preserve_unit_loops || !is_one (iter_dom->extent )) {
195241 Var var (" ax" + std::to_string (loop_vars.size ()), DataType::Int (32 ));
196242 loop_vars.push_back (var);
197243 loop_extents.push_back (iter_dom->extent );
198244 iter_values.push_back (iter_dom->min + var);
245+ analyzer->Bind (var, Range::FromMinExtent (0 , iter_dom->extent ));
199246 } else {
200247 iter_values.push_back (iter_dom->min );
201248 }
249+ const arith::IntSet& pred_bound = iter_doms[i].bound ;
250+ if (!pred_bound.IsNothing ()) {
251+ if (pred_bound.HasLowerBound ()) {
252+ PrimExpr lower_bound = iter_values[i] >= pred_bound.min ();
253+ predicate = predicate && lower_bound;
254+ }
255+ if (pred_bound.HasUpperBound ()) {
256+ PrimExpr upper_bound = iter_values[i] < pred_bound.max () + 1 ;
257+ predicate = predicate && upper_bound;
258+ }
259+ }
202260 }
203261 this ->new_block_realize_ =
204- BlockRealize (std::move (iter_values), const_true ( ), std::move (block_));
262+ BlockRealize (std::move (iter_values), analyzer-> Simplify (predicate ), std::move (block_));
205263 Stmt new_subtree = this ->new_block_realize_ ;
206264 for (int i = static_cast <int >(loop_vars.size ()) - 1 ; i >= 0 ; --i) {
207265 const Var& loop_var = loop_vars[i];
@@ -310,39 +368,74 @@ void RelaxBufferRegions(const Map<Var, PrimExpr>& binding,
310368 * domain
311369 * \param provided The provided integer set to cover the required domain
312370 * \param required The required domain to be covered
313- * \param iter_doms The result iteration domains to be updated
314371 * \param analyzer The arithmetic analyzer
315372 */
316- void UpdateBlockVarDomain ( const arith::IntSet& provided, const arith::IntSet& required ,
317- std::unordered_map< const VarNode*, std::vector< arith::IntSet>>* iter_doms ,
318- arith::Analyzer* analyzer) {
373+ std::pair<Var, arith::IntSet> SolveBlockVarDomain ( const arith::IntSet& provided ,
374+ const arith::IntSet& required ,
375+ arith::Analyzer* analyzer) {
319376 PrimExpr provided_min = analyzer->Simplify (provided.min ());
320- PrimExpr provided_extent = analyzer->Simplify (provided.max () - provided_min + 1 );
377+ PrimExpr provided_max = analyzer->Simplify (provided.max ());
321378 PrimExpr required_min = analyzer->Simplify (required.min ());
322- PrimExpr required_extent = analyzer->Simplify (required.max () - required_min + 1 );
323- PrimExpr dom_min{nullptr }, dom_extent {nullptr };
379+ PrimExpr required_max = analyzer->Simplify (required.max ());
380+ PrimExpr dom_min{nullptr }, dom_max {nullptr };
324381 Var dom_var{ObjectPtr<VarNode>{nullptr }};
325382 arith::PVar<Var> p_v;
326383 arith::PVar<PrimExpr> p_e;
327384 if ((p_v * p_e).Match (provided_min) || (p_e * p_v).Match (provided_min)) {
328385 PrimExpr e = p_e.Eval ();
329386 dom_var = p_v.Eval ();
330387 dom_min = floordiv (required_min, e);
331- dom_extent = analyzer->Simplify ((required_extent + e - 1 ) / e);
332- } else if (analyzer->CanProveEqual (provided_extent, 1 ) && p_v.Match (provided_min)) {
333- dom_var = p_v.Eval ();
334- dom_min = required_min;
335- dom_extent = required_extent;
336- } else {
337- ICHECK (false ) << " ValueError: BufferRegion pattern match failed" ;
388+ dom_max = floordiv (required_max, e);
389+ } else if (analyzer->CanProveEqual (provided_min, provided_max)) {
390+ if (p_v.Match (provided_min)) {
391+ dom_var = p_v.Eval ();
392+ dom_min = required_min;
393+ dom_max = required_max;
394+ } else {
395+ arith::PVar<PrimExpr> p_f;
396+ if ((floordiv (p_v, p_f)).Match (provided_min)) {
397+ // a <= (x // factor) <= b, fac > 0 ==> (a * fac) <= x <= (b * fac + fac - 1)
398+ PrimExpr fac = p_f.Eval ();
399+ if (analyzer->CanProveGreaterEqual (fac, 1 )) {
400+ dom_var = p_v.Eval ();
401+ dom_min = required_min * fac;
402+ dom_max = analyzer->Simplify (required_max * fac + fac - 1 );
403+ }
404+ } else if ((floormod (p_v, p_f).Match (provided_min))) {
405+ // generally domain of (x % fac) enforce no constraints to domain of x
406+ dom_var = p_v.Eval ();
407+ return std::make_pair (dom_var, arith::IntSet::Nothing ());
408+ }
409+ }
338410 }
339- auto it = iter_doms->find (dom_var.get ());
411+ ICHECK (dom_var.defined ()) << " ValueError: BufferRegion pattern match failed: " << provided_min;
412+ return std::make_pair (dom_var, arith::IntSet::Interval (dom_min, dom_max));
413+ }
414+
415+ /* !
416+ * \brief Calculate and update the iteration domain info to fully cover the required domain
417+ * \param provided The provided integer set to cover the required domain
418+ * \param required The required domain to be covered
419+ * \param required_bound The additional region bound of the required domain to be covered
420+ * \param iter_doms The result iteration domains to be updated
421+ * \param analyzer The arithmetic analyzer
422+ */
423+ void UpdateBlockVarDomain (const arith::IntSet& provided, const arith::IntSet& required,
424+ const arith::IntSet& required_bound,
425+ std::unordered_map<const VarNode*, BlockVarDomainInfo>* iter_doms,
426+ arith::Analyzer* analyzer) {
427+ auto var_with_dom = SolveBlockVarDomain (provided, required, analyzer);
428+ auto var_with_bound = SolveBlockVarDomain (provided, required_bound, analyzer);
429+ const Var& var = var_with_dom.first ;
430+ const auto & var_dom = var_with_dom.second ;
431+ const auto & var_bound = var_with_bound.second ;
432+ ICHECK (var.same_as (var_with_bound.first ));
433+ auto it = iter_doms->find (var.get ());
340434 if (it != iter_doms->end ()) {
341- std::vector<arith::IntSet>& doms = it->second ;
342- doms.push_back (arith::IntSet::FromMinExtent (dom_min, dom_extent));
435+ it->second .Union ({var_dom, var_bound});
343436 } else {
344- ICHECK (analyzer->CanProveEqual (provided_min, required_min ));
345- ICHECK (analyzer->CanProveEqual (provided_extent, required_extent ));
437+ ICHECK (analyzer->CanProveEqual (provided. min (), required. min () ));
438+ ICHECK (analyzer->CanProveEqual (provided. max (), required. max () ));
346439 }
347440}
348441
@@ -352,19 +445,19 @@ void UpdateBlockVarDomain(const arith::IntSet& provided, const arith::IntSet& re
352445 * \param provided_regions The region provided by one iteration instance of the block vars
353446 * \param required_regions The region required to be covered
354447 * \param analyzer The arithmetic analyzer
355- * \return A list of iteration domain corresponding to the given list of block vars
448+ * \return A list of iteration domain info corresponding to the given list of block vars
356449 */
357- std::vector<Range > CalculateBlockVarDomain (
450+ std::vector<BlockVarDomainInfo > CalculateBlockVarDomain (
358451 const Array<IterVar>& iter_vars,
359452 std::unordered_map<const BufferNode*, std::vector<NDIntSet>> provided_regions,
360453 std::unordered_map<const BufferNode*, std::vector<NDIntSet>> required_regions,
361454 arith::Analyzer* analyzer) {
362455 int n_iters = iter_vars.size ();
363456 // Step 1. Construct the mapping from block var to their iteration domain (initialized to empty)
364- std::unordered_map<const VarNode*, std::vector<arith::IntSet> > iter_doms;
457+ std::unordered_map<const VarNode*, BlockVarDomainInfo > iter_doms;
365458 iter_doms.reserve (n_iters);
366459 for (const IterVar& iter_var : iter_vars) {
367- iter_doms[iter_var->var .get ()] = {} ;
460+ iter_doms[iter_var->var .get ()] = BlockVarDomainInfo () ;
368461 }
369462 // Step 2. For each buffer, update the domain according to the provided and required regions
370463 for (const auto & kv : provided_regions) {
@@ -384,23 +477,23 @@ std::vector<Range> CalculateBlockVarDomain(
384477 for (int i = 0 ; i < ndim; ++i) {
385478 arith::IntSet provided = provided_region[i];
386479 arith::IntSet required = required_region[i];
387- required = arith::Intersect (
388- {std::move (required), arith::IntSet::FromMinExtent (Integer (0 ), buffer->shape [i])});
389- UpdateBlockVarDomain (provided, required, &iter_doms, analyzer);
480+ arith::IntSet required_bound = arith::IntSet::FromMinExtent (Integer (0 ), buffer->shape [i]);
481+ UpdateBlockVarDomain (provided, required, required_bound, &iter_doms, analyzer);
390482 }
391483 }
392484 // Union the iter var domains, put them in the same order of block vars, and return
393- std::vector<Range > result;
485+ std::vector<BlockVarDomainInfo > result;
394486 result.reserve (n_iters);
395487 for (const IterVar& iter_var : iter_vars) {
396- const std::vector<arith::IntSet>& doms = iter_doms.at (iter_var->var .get ());
397- arith::IntSet dom = arith::IntSet::FromRange (iter_var->dom );
398- if (!doms.empty ()) {
399- dom = arith::Intersect ({std::move (dom), arith::Union (doms)});
488+ BlockVarDomainInfo& info = iter_doms.at (iter_var->var .get ());
489+ if (info.bound .IsNothing ()) {
490+ info.bound = arith::IntSet::FromRange (iter_var->dom );
491+ } else {
492+ info.bound = arith::Intersect ({info.bound , arith::IntSet::FromRange (iter_var->dom )});
400493 }
401- PrimExpr min = analyzer-> Simplify (dom. min () );
402- PrimExpr extent = analyzer-> Simplify ( dom.max () - min + 1 );
403- result.push_back (Range::FromMinExtent (min, extent) );
494+ info. Simplify (analyzer );
495+ ICHECK (!info. dom .IsNothing () );
496+ result.push_back (info );
404497 }
405498 return result;
406499}
@@ -498,14 +591,14 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
498591 /* consumer_srefs=*/ std::move (consumer_srefs),
499592 /* provided_regions=*/ &provided_regions, /* required_regions=*/ &required_regions);
500593 // Step 5. Calculate the iteration domain for each block var
501- std::vector<Range > iter_doms =
594+ std::vector<BlockVarDomainInfo > iter_doms =
502595 CalculateBlockVarDomain (/* iter_vars=*/ block->iter_vars ,
503596 /* provided_regions=*/ std::move (provided_regions),
504597 /* required_regions=*/ std::move (required_regions),
505598 /* analyzer=*/ analyzer);
506599 // Step 6. Create the new scope according to the iteration domain
507600 reconstructor.MakeNewLoop (/* insert_position=*/ insert_position, /* iter_doms=*/ std::move (iter_doms),
508- /* preserve_unit_loops=*/ preserve_unit_loops);
601+ /* analyzer= */ analyzer, /* preserve_unit_loops=*/ preserve_unit_loops);
509602 Block new_scope_root = Downcast<Block>(reconstructor (scope_root));
510603
511604 // Step 7. Do the actual replacement
0 commit comments