Skip to content

Commit d4ca123

Browse files
sisleyliBin Li
andauthored
[BugFix] Support rewrite_once when the number of callbacks > 1 (#14344)
* [BugFix] Support rewrite_once when the number of callbacks > 1 * callbacks_map -> done, swapping false and true --------- Co-authored-by: Bin Li <[email protected]>
1 parent 50b3ae4 commit d4ca123

File tree

2 files changed

+94
-22
lines changed

2 files changed

+94
-22
lines changed

src/relay/ir/dataflow_matcher.cc

Lines changed: 24 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -796,24 +796,35 @@ Expr PatternRewriter::Rewrite(const Array<DFPatternCallback>& callbacks, const E
796796
bool equal = true;
797797
static auto* structural_equal = runtime::Registry::Get("node.StructuralEqual");
798798
ICHECK(structural_equal) << "node.StructuralEqual is not registered.";
799+
// Keep track of callbacks that have finished rewriting
800+
std::unordered_map<DFPatternCallback, bool, ObjectPtrHash, ObjectPtrEqual> done;
799801
do {
800802
last = post;
801803
for (auto callback : callbacks) {
802-
callback_ = callback;
803-
if (callback_->require_type) {
804-
post = InferTypeWithModule(post, mod_);
805-
}
806-
auto grouper = PatternGrouper();
807-
groups_ = grouper.GroupMatches(callback_->pattern, post);
808-
gid_assignments_ = grouper.GetGIDAssignments();
809-
memo_.clear();
810-
VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
811-
post = this->VisitExpr(post);
812-
VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
813-
count++;
804+
if (!done[callback]) {
805+
auto before = post;
806+
callback_ = callback;
807+
if (callback_->require_type) {
808+
post = InferTypeWithModule(post, mod_);
809+
}
810+
auto grouper = PatternGrouper();
811+
groups_ = grouper.GroupMatches(callback_->pattern, post);
812+
gid_assignments_ = grouper.GetGIDAssignments();
813+
memo_.clear();
814+
VLOG(1) << "pre rewritten:" << std::endl << PrettyPrint(pre);
815+
post = this->VisitExpr(post);
816+
VLOG(1) << "post rewritten:" << std::endl << PrettyPrint(post);
817+
count++;
818+
if (callback_->rewrite_once) {
819+
bool current_equal = (*structural_equal)(before, post, false, true);
820+
if (!current_equal) {
821+
done[callback] = true;
822+
}
823+
}
824+
}
814825
}
815826
equal = (*structural_equal)(last, post, false, true);
816-
} while (!equal && count < 100 && !callback_->rewrite_once);
827+
} while (!equal && count < 100);
817828
if (count >= 100) {
818829
LOG(FATAL) << "Observed 100 rewrite passes, possible conflicting passes?";
819830
}

tests/python/relay/test_dataflow_pattern.py

Lines changed: 70 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1804,22 +1804,83 @@ def callback(self, pre, post, node_map):
18041804
if new_args:
18051805
return relay.op.concatenate(relay.expr.Tuple(new_args), axis=0)
18061806
else:
1807-
return concat_args
1807+
return concat_args[0]
18081808

18091809
x = relay.var("x")
18101810
y = relay.var("y")
18111811
z = relay.var("z")
18121812
concat = relay.op.concatenate(relay.expr.Tuple([x, y, z]), axis=0)
18131813

1814-
# Let the rewriter run recursively
1815-
out = rewrite(ConcatRewriter(False), concat)
1816-
expected = relay.expr.Tuple([x])
1817-
assert tvm.ir.structural_equal(out, expected)
1814+
def test_one_callback():
1815+
# Let the rewriter run recursively
1816+
out = rewrite(ConcatRewriter(False), concat)
1817+
expected = x
1818+
assert tvm.ir.structural_equal(out, expected)
1819+
1820+
# Run the rewriter once
1821+
out = rewrite(ConcatRewriter(True), concat)
1822+
expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0)
1823+
assert tvm.ir.structural_equal(out, expected)
1824+
1825+
def test_multi_callbacks():
1826+
# This class recursively add a nn.relu operator after nn.softmax
1827+
class OneMoreReluRewriter(DFPatternCallback):
1828+
def __init__(self, rewrite_once):
1829+
super().__init__(rewrite_once=rewrite_once)
1830+
self.pattern = is_op("nn.softmax")(None)
1831+
1832+
def callback(self, pre, post, node_map):
1833+
return relay.nn.relu(post)
1834+
1835+
def before():
1836+
# Before:
1837+
# x y z
1838+
# | | |
1839+
# concat
1840+
# |
1841+
# softmax
1842+
return relay.nn.softmax(concat)
1843+
1844+
def once_concat():
1845+
# ConcatRewrite once, OneMoreReluRewrite once
1846+
# Expected:
1847+
# x y
1848+
# | |
1849+
# concat
1850+
# |
1851+
# softmax
1852+
# |
1853+
# relu
1854+
return relay.nn.relu(
1855+
relay.nn.softmax(relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0))
1856+
)
1857+
1858+
def recursive_concat():
1859+
# ConcatRewrite recursively, OneMoreReluRewrite once
1860+
# Expected:
1861+
# x
1862+
# |
1863+
# softmax
1864+
# |
1865+
# relu
1866+
return relay.nn.relu(relay.nn.softmax(x))
1867+
1868+
# Run ConcatRewriter once, OneMoreReluRewriter once
1869+
out = rewrite(
1870+
[OneMoreReluRewriter(True), ConcatRewriter(True)],
1871+
before(),
1872+
)
1873+
assert tvm.ir.structural_equal(out, once_concat())
1874+
1875+
# Run ConcatRewriter recursively, OneMoreReluRewriter once
1876+
out = rewrite(
1877+
[OneMoreReluRewriter(True), ConcatRewriter(False)],
1878+
before(),
1879+
)
1880+
assert tvm.ir.structural_equal(out, recursive_concat())
18181881

1819-
# Run the rewriter once
1820-
out = rewrite(ConcatRewriter(True), concat)
1821-
expected = relay.op.concatenate(relay.expr.Tuple([x, y]), axis=0)
1822-
assert tvm.ir.structural_equal(out, expected)
1882+
test_one_callback()
1883+
test_multi_callbacks()
18231884

18241885

18251886
def test_matched_outside_but_dominated():

0 commit comments

Comments
 (0)