Skip to content

Commit 050b23f

Browse files
authored
[Relay]Disable InferType if it was done and no changes after previous pass (#17585)
Disable InferType if it was done and no changes after previous pass This optimizatin allows to speedup PatternRewriter transformations by reusing of preious type inferred expression instead of perform InferType multiple times
1 parent 8083520 commit 050b23f

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

src/relay/ir/dataflow_matcher.cc

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -851,24 +851,32 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
851851
std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> done;
852852
do {
853853
last = post;
854+
// We don't have to call InferType if previous pass has not modified anything
855+
// We can just take previous typed state of the expression
856+
bool types_invalidated = true;
854857
for (auto callback : callbacks) {
855858
if (!done[callback]) {
856859
auto before = post;
860+
auto post_typed = post;
857861
callback_ = callback;
858-
if (callback_->require_type) {
859-
post = InferTypeWithModule(post, mod_);
862+
if (callback_->require_type && types_invalidated) {
863+
post_typed = InferTypeWithModule(post, mod_);
860864
}
861865
auto grouper = PatternGrouper();
862-
groups_ = grouper.GroupMatches(callback_->pattern, post);
866+
groups_ = grouper.GroupMatches(callback_->pattern, post_typed);
863867
gid_assignments_ = grouper.GetGIDAssignments();
864868
memo_.clear();
865869
VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
866-
post = this->VisitExpr(post);
870+
post = this->VisitExpr(post_typed);
867871
VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
868872
count++;
869-
if (callback_->rewrite_once) {
870-
bool current_equal = (*structural_equal)(before, post, false, true);
871-
if (!current_equal) {
873+
bool current_equal = (*structural_equal)(before, post, false, true);
874+
if (callback_->require_type && current_equal) {
875+
types_invalidated = false;
876+
post = post_typed;
877+
} else {
878+
types_invalidated = true;
879+
if (callback_->rewrite_once) {
872880
done[callback] = true;
873881
}
874882
}

0 commit comments

Comments
 (0)