diff --git a/src/relay/ir/dataflow_matcher.cc b/src/relay/ir/dataflow_matcher.cc index 3e86e1c8eaf9..9d117adbbcaf 100644 --- a/src/relay/ir/dataflow_matcher.cc +++ b/src/relay/ir/dataflow_matcher.cc @@ -851,24 +851,32 @@ Expr PatternRewriter::Rewrite(const Array& callbacks, const E std::unordered_map done; do { last = post; + // We don't have to call InferType if previous pass has not modified anything + // We can just take previous typed state of the expression + bool types_invalidated = true; for (auto callback : callbacks) { if (!done[callback]) { auto before = post; + auto post_typed = post; callback_ = callback; - if (callback_->require_type) { - post = InferTypeWithModule(post, mod_); + if (callback_->require_type && types_invalidated) { + post_typed = InferTypeWithModule(post, mod_); } auto grouper = PatternGrouper(); - groups_ = grouper.GroupMatches(callback_->pattern, post); + groups_ = grouper.GroupMatches(callback_->pattern, post_typed); gid_assignments_ = grouper.GetGIDAssignments(); memo_.clear(); VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre); - post = this->VisitExpr(post); + post = this->VisitExpr(post_typed); VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post); count++; - if (callback_->rewrite_once) { - bool current_equal = (*structural_equal)(before, post, false, true); - if (!current_equal) { + bool current_equal = (*structural_equal)(before, post, false, true); + if (callback_->require_type && current_equal) { + types_invalidated = false; + post = post_typed; + } else { + types_invalidated = true; + if (callback_->rewrite_once) { done[callback] = true; } }