Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 10 additions & 7 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,7 @@ class IterMapRewriter : public ExprMutator {
if (expr->IsInstance<IterMapExprNode>()) {
ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in "
<< "IterMapRewriter using DirectMutate. "
<< "Indirect return occurred in " << tvm::PrettyPrint(input_expr);
<< "Indirect return occurred in " << input_expr;
}
return expr;
}
Expand All @@ -324,8 +324,11 @@ class IterMapRewriter : public ExprMutator {
PrimExpr VisitExpr_(const FloorModNode* op) final;

private:
// Preprocessing common to both FloorDiv and FloorMod
IterSumExpr PreprocessDividend(IterMapExpr dividend);
/* \brief Preprocessing common to both FloorDiv and FloorMod
*
* \param dividend The dividend to be manipulated.
*/
IterSumExpr PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend);

// Create an iterator that represents the expression (split+base), with
// padding such that the iterator's extents are evenly divisible by
Expand Down Expand Up @@ -1238,14 +1241,14 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
}
}

IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend) {
IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend) {
if (dividend->IsInstance<IterSplitExprNode>()) {
auto split = Downcast<IterSplitExpr>(dividend);
return IterSumExpr({split}, make_zero(split.dtype()));
} else if (dividend->IsInstance<IterSumExprNode>()) {
auto opt_fused = TryFuseIters(Downcast<IterSumExpr>(dividend));
if (!opt_fused) {
ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(dividend)
ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend)
<< ", can't be written as a single fused IterSum";
return IterSumExpr();
}
Expand Down Expand Up @@ -1495,7 +1498,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
return GetRef<PrimExpr>(op);
}

IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a));
IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a), op->a);
if (!preprocessed.defined()) {
return GetRef<PrimExpr>(op);
}
Expand Down Expand Up @@ -1580,7 +1583,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
return GetRef<PrimExpr>(op);
}

IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a));
IterSumExpr preprocessed = PreprocessDividend(Downcast<IterMapExpr>(a), op->a);
if (!preprocessed.defined()) {
return GetRef<PrimExpr>(op);
}
Expand Down