diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py index 77970d1b667f..f492c743173c 100644 --- a/python/tvm/relay/expr_functor.py +++ b/python/tvm/relay/expr_functor.py @@ -249,7 +249,10 @@ def visit_constructor(self, con): return con def visit_match(self, m): - return Match(self.visit(m.data), [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses]) + return Match( + self.visit(m.data), + [Clause(c.lhs, self.visit(c.rhs)) for c in m.clauses], + complete=m.complete) def visit_ref_create(self, r): return RefCreate(self.visit(r.value)) diff --git a/tests/python/relay/test_expr_functor.py b/tests/python/relay/test_expr_functor.py index 4b0adff57ff9..5c923655d7b7 100644 --- a/tests/python/relay/test_expr_functor.py +++ b/tests/python/relay/test_expr_functor.py @@ -125,6 +125,16 @@ def test_match(): p = relay.prelude.Prelude() check_visit(p.mod[p.map]) + +def test_match_completeness(): + p = relay.prelude.Prelude() + for completeness in [True, False]: + match_expr = relay.adt.Match(p.nil, [], complete=completeness) + result_expr = ExprMutator().visit(match_expr) + # ensure the mutator doesn't mangle the completeness flag + assert result_expr.complete == completeness + + if __name__ == "__main__": test_constant() test_tuple() @@ -139,3 +149,4 @@ def test_match(): test_ref_write() test_memo() test_match() + test_match_completeness()