Skip to content

Commit 00395ae

Browse files
authored
[Relax][Bugfix] Provide the full Expr to pattern-match rewriter (#16828)
* [Relax][Bugfix] Provide the full Expr to pattern-match rewriter This resolves a bug that was introduced in #16732. If a rewriter function returned a no-op, and the pattern-match continued, then the `matches` provided to the rewriter function in subsequent calls would contain a variable to which the matched expression was bound, not the matched expression itself. (e.g. For a match of `C = R.add(A,B)`, passing `C` to the rewriter instead of `R.add(A,B)`.) This bug was caused by incorrect re-wrapping of `OrPattern` in `ExprPatternRewriter`. Prior to #16732, all pattern-match results were populated by `ExtractMatchExpr`, and contained the result after applying `TryGetValOfVar`. When re-wrapping the result of an `OrPattern`, #16732 populated the additional matches with the result before applying `TryGetValOfVar`. This commit fixes the bug by applying `TryGetValOfVar`. * Update with PR link of bugfix
1 parent 3f615dc commit 00395ae

File tree

2 files changed

+44
-2
lines changed

2 files changed

+44
-2
lines changed

src/relax/ir/dataflow_matcher.cc

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1190,8 +1190,17 @@ class ExprPatternRewriter : ExprMutator {
11901190

11911191
if (auto opt_matches = ExtractMatchedExpr(pattern, expr, bindings_)) {
11921192
auto matches = opt_matches.value();
1193-
for (const auto& pat : *matches_top_level) {
1194-
matches.Set(pat, expr);
1193+
1194+
// Append any additional matches that from the unwrapped
1195+
// `OrPattern`. When matching against `pat = pat_lhs |
1196+
// pat_rhs`, we call `ExtractMatchedExpr` on `pat_lhs` and
1197+
// `pat_rhs` separately. The top-level `pat` is never seen by
1198+
// `ExtractMatchedExpr`, and must be re-added afterward.
1199+
if (matches_top_level->size()) {
1200+
auto matched_expr = TryGetValOfVar(expr, bindings_);
1201+
for (const auto& pat : *matches_top_level) {
1202+
matches.Set(pat, matched_expr);
1203+
}
11951204
}
11961205

11971206
Expr rewritten_expr = rewriter_func_(expr, matches);

tests/python/relax/test_dataflow_pattern.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1952,5 +1952,38 @@ def expected():
19521952
tvm.ir.assert_structural_equal(expected, after)
19531953

19541954

1955+
def test_backtrack_for_no_op_rewriter_does_not_match_on_var():
1956+
"""The matches should always contain the bound value
1957+
1958+
This is a regression test. In versions from
1959+
https://github.com/apache/tvm/pull/16732 to
1960+
https://github.com/apache/tvm/pull/16828, the `rewrite_call`
1961+
function could erroneously call the rewriter with `expr` and
1962+
`matches[pat]` set to a variable (`C`) instead of the value to
1963+
which it is bound (`R.add(A,B)`).
1964+
"""
1965+
pat_a = is_op("relax.add")(wildcard(), wildcard())
1966+
pat_b = is_op("relax.add")(wildcard(), wildcard())
1967+
pat = pat_a | pat_b
1968+
1969+
def rewriter(expr, matches):
1970+
assert isinstance(matches[pat], rx.Call)
1971+
return expr
1972+
1973+
@R.function(private=True)
1974+
def before():
1975+
with R.dataflow():
1976+
A = R.ones([64, 128], "int32")
1977+
B = R.zeros([64, 128], "int32")
1978+
C = R.add(A, B)
1979+
1980+
R.output(C)
1981+
return C
1982+
1983+
expected = before
1984+
after = rewrite_call(pat, rewriter, before)
1985+
tvm.ir.assert_structural_equal(expected, after)
1986+
1987+
19551988
if __name__ == "__main__":
19561989
tvm.testing.main()

0 commit comments

Comments
 (0)