Skip to content

Commit 63c722b

Browse files
committed
try fix
1 parent 9132cb6 commit 63c722b

File tree

3 files changed

+12
-45
lines changed

3 files changed

+12
-45
lines changed

src/relay/backend/build_module.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ class RelayBuildModule : public runtime::ModuleNode {
300300
}
301301
});
302302
pass_seqs.push_back(transform::EliminateCommonSubexpr(fskip));
303+
pass_seqs.push_back(transform::InferType());
303304
pass_seqs.push_back(transform::SimplifyExpr());
304305
pass_seqs.push_back(transform::CombineParallelConv2D(3));
305306
pass_seqs.push_back(transform::CombineParallelDense(3));

src/relay/transforms/simplify_expr.cc

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -76,14 +76,14 @@ class SimplifyReshape : public DFPatternRewrite {
7676
};
7777

7878
/*!
79-
* \brief SimplifyCastLike matches the pattern of cast data to the same dtype.
79+
* \brief SimplifyCast matches the pattern of cast data to the same dtype.
8080
*/
81-
class SimplifyCastLike : public DFPatternRewrite {
81+
class SimplifyCast : public DFPatternRewrite {
8282
public:
83-
SimplifyCastLike() {
83+
SimplifyCast() {
8484
data_pat_ = IsWildcard();
8585
like_pat_ = IsWildcard();
86-
pattern_ = IsOp("cast_like")({data_pat_, like_pat_});
86+
pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
8787
}
8888

8989
Expr Callback(const Expr& pre, const Expr& post,
@@ -102,32 +102,6 @@ class SimplifyCastLike : public DFPatternRewrite {
102102
DFPattern like_pat_;
103103
};
104104

105-
/*!
106-
* \brief SimplifyCast matches the pattern of cast data to the same dtype.
107-
*/
108-
class SimplifyCast : public DFPatternRewrite {
109-
public:
110-
SimplifyCast() {
111-
data_pat_ = IsWildcard();
112-
pattern_ = IsOp("cast")({data_pat_});
113-
}
114-
115-
Expr Callback(const Expr& pre, const Expr& post,
116-
const Map<DFPattern, Array<Expr>>& node_map) const override {
117-
const CallNode* call = pre.as<CallNode>();
118-
auto attrs = call->attrs.as<CastAttrs>();
119-
auto data = node_map[data_pat_][0];
120-
const TensorTypeNode* data_ty = data->checked_type().as<TensorTypeNode>();
121-
if (attrs->dtype == data_ty->dtype) {
122-
return data;
123-
}
124-
return post;
125-
}
126-
127-
protected:
128-
DFPattern data_pat_;
129-
};
130-
131105
/*!
132106
* \brief SimplifyTranspose matches the pattern of consecutive transpose op,
133107
* and merges or cancels them.
@@ -510,7 +484,6 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
510484
composer.AddRewrite<EliminateIdentityRewrite>();
511485
composer.AddRewrite<SimplifyReshape>();
512486
composer.AddRewrite<SimplifyTranspose>();
513-
composer.AddRewrite<SimplifyCastLike>();
514487
composer.AddRewrite<SimplifyCast>();
515488
composer.AddRewrite<FullElementwise>();
516489
return RewritePatterns(composer.MakeCallbacks(), expr, mod);

tests/python/relay/test_pass_simplify_expr.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -236,25 +236,18 @@ def check(x, y=None, do_nothing=False):
236236
check(id_op(const, x), id_op(op_like(x), x))
237237

238238

239-
def test_simplify_cast_like():
240-
dtype = "int32"
241-
data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
242-
dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype)
243-
expr = relay.cast_like(data, dtype_like)
244-
245-
expected = run_infer_type(data)
246-
actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
247-
assert tvm.ir.structural_equal(actual, expected)
248-
249-
250239
def test_simplify_cast():
251240
dtype = "int32"
252241
data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
253-
expr = relay.cast(data, dtype)
242+
expr1 = relay.cast(data, dtype)
243+
dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype)
244+
expr2 = relay.cast_like(data, dtype_like)
254245

255246
expected = run_infer_type(data)
256-
actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
257-
assert tvm.ir.structural_equal(actual, expected)
247+
actual1 = run_opt_pass(expr1, relay.transform.SimplifyExpr())
248+
assert tvm.ir.structural_equal(actual1, expected)
249+
actual2 = run_opt_pass(expr2, relay.transform.SimplifyExpr())
250+
assert tvm.ir.structural_equal(actual2, expected)
258251

259252

260253
def test_concretize_reshape_like():

0 commit comments

Comments
 (0)