diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index a012b6e80c08..587de531f28f 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -308,7 +308,7 @@ class IterMapRewriter : public ExprMutator { if (expr->IsInstance()) { ErrorLogger(this) << "IterMapExpr or subclasses should only result from calls in " << "IterMapRewriter using DirectMutate. " - << "Indirect return occurred in " << tvm::PrettyPrint(input_expr); + << "Indirect return occurred in " << input_expr; } return expr; } @@ -324,8 +324,11 @@ class IterMapRewriter : public ExprMutator { PrimExpr VisitExpr_(const FloorModNode* op) final; private: - // Preprocessing common to both FloorDiv and FloorMod - IterSumExpr PreprocessDividend(IterMapExpr dividend); + /* \brief Preprocessing common to both FloorDiv and FloorMod + * + * \param dividend The dividend to be manipulated. + */ + IterSumExpr PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend); // Create an iterator that represents the expression (split+base), with // padding such that the iterator's extents are evenly divisible by @@ -1238,14 +1241,14 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) { } } -IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend) { +IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr original_dividend) { if (dividend->IsInstance()) { auto split = Downcast(dividend); return IterSumExpr({split}, make_zero(split.dtype())); } else if (dividend->IsInstance()) { auto opt_fused = TryFuseIters(Downcast(dividend)); if (!opt_fused) { - ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(dividend) + ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend) << ", can't be written as a single fused IterSum"; return IterSumExpr(); } @@ -1495,7 +1498,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) { return GetRef(op); } - IterSumExpr preprocessed = PreprocessDividend(Downcast(a)); + IterSumExpr preprocessed = PreprocessDividend(Downcast(a), op->a); if (!preprocessed.defined()) { return GetRef(op); } @@ -1580,7 +1583,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) { return GetRef(op); } - IterSumExpr preprocessed = PreprocessDividend(Downcast(a)); + IterSumExpr preprocessed = PreprocessDividend(Downcast(a), op->a); if (!preprocessed.defined()) { return GetRef(op); }