Skip to content

Commit ceb8e22

Browse files
authored
[Relax] Improve CanonicalizeBindings in DataflowVar edge case (#16783)
* [Relax] Improve CanonicalizeBindings in DataflowVar edge case If there is a trivial binding of `Var = DataflowVar`, but the non-dataflow variable is never used outside the dataflow block in which is is declared, then we should keep the name of the upstream `DataflowVar`, as it is more likely to be the human-readable name (e.g. a function parameter). * Update comment for used/not used Var * ci bump
1 parent 2f88977 commit ceb8e22

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

src/relax/transform/canonicalize_bindings.cc

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,18 +91,21 @@ class CanonicalizePlanner : public ExprVisitor {
9191
bound_to = opt.value();
9292
}
9393

94-
if (bound_var.as<DataflowVarNode>() || !bound_to.as<DataflowVarNode>()) {
94+
if (bound_var.as<DataflowVarNode>() || !bound_to.as<DataflowVarNode>() ||
95+
!visitor.used_outside_home_dataflow_.count(bound_var)) {
9596
// Case 1: Var = Var
9697
// Case 2: DataflowVar = Var
9798
// Case 3: DataflowVar = DataflowVar
99+
// Case 4a: Var = DataflowVar, where the Var is not used
100+
// outside the DataflowBlock containing the binding
98101
//
99-
// For these three cases, the trivial binding can be
100-
// unwrapped, using the bound variable directly at the point
101-
// of use.
102+
// For these four cases, the trivial binding can be unwrapped,
103+
// using the bound variable directly at the point of use.
102104
plan.replace_usage.Set(bound_var->vid, bound_to);
103105
plan.bindings_to_remove.insert(bound_var->vid);
104106
} else {
105-
// Case 4: Var = DataflowVar
107+
// Case 4b: Var = DataflowVar, where the Var is used somewhere
108+
// outside the DataflowBlock containing the binding
106109
//
107110
// Replacing a Var with a DataflowVar could result in illegal
108111
// use of a DataflowVar outside of a DataflowBlock. Instead,

tests/python/relax/test_transform_canonicalize_bindings.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,5 +977,39 @@ def main():
977977
verify(TestChainAssignments, Expected)
978978

979979

980+
def test_trivial_binding_of_replaced_non_dataflow_var():
981+
@I.ir_module
982+
class Before:
983+
@R.function
984+
def main(param_tuple: R.Tuple([R.Tensor])):
985+
with R.dataflow():
986+
A = param_tuple[0]
987+
B = A
988+
C = R.add(A, B)
989+
R.output(A, B, C)
990+
return C
991+
992+
@I.ir_module
993+
class Expected:
994+
@R.function
995+
def main(param_tuple: R.Tuple([R.Tensor])):
996+
with R.dataflow():
997+
A = param_tuple[0]
998+
C = R.add(A, A)
999+
R.output(C)
1000+
return C
1001+
1002+
After = CanonicalizeBindings()(Before)
1003+
tvm.ir.assert_structural_equal(After, Expected)
1004+
1005+
def _get_binding_names(mod):
1006+
return [binding.var.name_hint for binding in mod["main"].body.blocks[0].bindings]
1007+
1008+
expected_names = _get_binding_names(Expected)
1009+
after_names = _get_binding_names(After)
1010+
1011+
assert after_names == expected_names
1012+
1013+
9801014
if __name__ == "__main__":
9811015
tvm.testing.main()

0 commit comments

Comments
 (0)