@@ -308,7 +308,7 @@ class IterMapRewriter : public ExprMutator {
308308 if (expr->IsInstance <IterMapExprNode>()) {
309309 ErrorLogger (this ) << " IterMapExpr or subclasses should only result from calls in "
310310 << " IterMapRewriter using DirectMutate. "
311- << " Indirect return occurred in " << tvm::PrettyPrint ( input_expr) ;
311+ << " Indirect return occurred in " << input_expr;
312312 }
313313 return expr;
314314 }
@@ -324,8 +324,11 @@ class IterMapRewriter : public ExprMutator {
324324 PrimExpr VisitExpr_ (const FloorModNode* op) final ;
325325
326326 private:
327- // Preprocessing common to both FloorDiv and FloorMod
328- IterSumExpr PreprocessDividend (IterMapExpr dividend);
327+ /* \brief Preprocessing common to both FloorDiv and FloorMod
328+ *
329+ * \param dividend The dividend to be manipulated.
330+ */
331+ IterSumExpr PreprocessDividend (IterMapExpr dividend, PrimExpr original_dividend);
329332
330333 // Create an iterator that represents the expression (split+base), with
331334 // padding such that the iterator's extents are evenly divisible by
@@ -1238,14 +1241,14 @@ PrimExpr IterMapRewriter::VisitExpr_(const MulNode* op) {
12381241 }
12391242}
12401243
1241- IterSumExpr IterMapRewriter::PreprocessDividend (IterMapExpr dividend) {
1244+ IterSumExpr IterMapRewriter::PreprocessDividend (IterMapExpr dividend, PrimExpr original_dividend ) {
12421245 if (dividend->IsInstance <IterSplitExprNode>()) {
12431246 auto split = Downcast<IterSplitExpr>(dividend);
12441247 return IterSumExpr ({split}, make_zero (split.dtype ()));
12451248 } else if (dividend->IsInstance <IterSumExprNode>()) {
12461249 auto opt_fused = TryFuseIters (Downcast<IterSumExpr>(dividend));
12471250 if (!opt_fused) {
1248- ErrorLogger (this ) << " Dividend " << tvm::PrettyPrint (dividend )
1251+ ErrorLogger (this ) << " Dividend " << tvm::PrettyPrint (original_dividend )
12491252 << " , can't be written as a single fused IterSum" ;
12501253 return IterSumExpr ();
12511254 }
@@ -1495,7 +1498,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorDivNode* op) {
14951498 return GetRef<PrimExpr>(op);
14961499 }
14971500
1498- IterSumExpr preprocessed = PreprocessDividend (Downcast<IterMapExpr>(a));
1501+ IterSumExpr preprocessed = PreprocessDividend (Downcast<IterMapExpr>(a), op-> a );
14991502 if (!preprocessed.defined ()) {
15001503 return GetRef<PrimExpr>(op);
15011504 }
@@ -1580,7 +1583,7 @@ PrimExpr IterMapRewriter::VisitExpr_(const FloorModNode* op) {
15801583 return GetRef<PrimExpr>(op);
15811584 }
15821585
1583- IterSumExpr preprocessed = PreprocessDividend (Downcast<IterMapExpr>(a));
1586+ IterSumExpr preprocessed = PreprocessDividend (Downcast<IterMapExpr>(a), op-> a );
15841587 if (!preprocessed.defined ()) {
15851588 return GetRef<PrimExpr>(op);
15861589 }
0 commit comments