Skip to content

Commit 0cb658c

Browse files
committed
try fix
1 parent 9132cb6 commit 0cb658c

File tree

2 files changed

+14
-48
lines changed

2 files changed

+14
-48
lines changed

src/relay/transforms/simplify_expr.cc

Lines changed: 7 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -75,57 +75,31 @@ class SimplifyReshape : public DFPatternRewrite {
7575
DFPattern x_;
7676
};
7777

78-
/*!
79-
* \brief SimplifyCastLike matches the pattern of cast data to the same dtype.
80-
*/
81-
class SimplifyCastLike : public DFPatternRewrite {
82-
public:
83-
SimplifyCastLike() {
84-
data_pat_ = IsWildcard();
85-
like_pat_ = IsWildcard();
86-
pattern_ = IsOp("cast_like")({data_pat_, like_pat_});
87-
}
88-
89-
Expr Callback(const Expr& pre, const Expr& post,
90-
const Map<DFPattern, Array<Expr>>& node_map) const override {
91-
auto data = node_map[data_pat_][0];
92-
const TensorTypeNode* data_ty = data->checked_type().as<TensorTypeNode>();
93-
const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
94-
if (like_ty->dtype == data_ty->dtype) {
95-
return data;
96-
}
97-
return post;
98-
}
99-
100-
protected:
101-
DFPattern data_pat_;
102-
DFPattern like_pat_;
103-
};
104-
10578
/*!
10679
* \brief SimplifyCast matches the pattern of cast data to the same dtype.
10780
*/
10881
class SimplifyCast : public DFPatternRewrite {
10982
public:
11083
SimplifyCast() {
11184
data_pat_ = IsWildcard();
112-
pattern_ = IsOp("cast")({data_pat_});
85+
like_pat_ = IsWildcard();
86+
pattern_ = IsOp("cast_like")({data_pat_, like_pat_}) || IsOp("cast")({data_pat_});
11387
}
11488

11589
Expr Callback(const Expr& pre, const Expr& post,
11690
const Map<DFPattern, Array<Expr>>& node_map) const override {
11791
const CallNode* call = pre.as<CallNode>();
118-
auto attrs = call->attrs.as<CastAttrs>();
119-
auto data = node_map[data_pat_][0];
120-
const TensorTypeNode* data_ty = data->checked_type().as<TensorTypeNode>();
121-
if (attrs->dtype == data_ty->dtype) {
122-
return data;
92+
const TensorTypeNode* data_ty = call->args[0]->checked_type().as<TensorTypeNode>();
93+
const TensorTypeNode* like_ty = pre->checked_type().as<TensorTypeNode>();
94+
if (like_ty->dtype == data_ty->dtype) {
95+
return node_map[data_pat_][0];
12396
}
12497
return post;
12598
}
12699

127100
protected:
128101
DFPattern data_pat_;
102+
DFPattern like_pat_;
129103
};
130104

131105
/*!
@@ -510,7 +484,6 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
510484
composer.AddRewrite<EliminateIdentityRewrite>();
511485
composer.AddRewrite<SimplifyReshape>();
512486
composer.AddRewrite<SimplifyTranspose>();
513-
composer.AddRewrite<SimplifyCastLike>();
514487
composer.AddRewrite<SimplifyCast>();
515488
composer.AddRewrite<FullElementwise>();
516489
return RewritePatterns(composer.MakeCallbacks(), expr, mod);

tests/python/relay/test_pass_simplify_expr.py

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -236,25 +236,18 @@ def check(x, y=None, do_nothing=False):
236236
check(id_op(const, x), id_op(op_like(x), x))
237237

238238

239-
def test_simplify_cast_like():
240-
dtype = "int32"
241-
data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
242-
dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype)
243-
expr = relay.cast_like(data, dtype_like)
244-
245-
expected = run_infer_type(data)
246-
actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
247-
assert tvm.ir.structural_equal(actual, expected)
248-
249-
250239
def test_simplify_cast():
251240
dtype = "int32"
252241
data = relay.var("data", shape=(3, 4, 5), dtype=dtype)
253-
expr = relay.cast(data, dtype)
242+
expr1 = relay.cast(data, dtype)
243+
dtype_like = relay.var("dtype_like", shape=(2, 2, 2), dtype=dtype)
244+
expr2 = relay.cast_like(data, dtype_like)
254245

255246
expected = run_infer_type(data)
256-
actual = run_opt_pass(expr, relay.transform.SimplifyExpr())
257-
assert tvm.ir.structural_equal(actual, expected)
247+
actual1 = run_opt_pass(expr1, relay.transform.SimplifyExpr())
248+
assert tvm.ir.structural_equal(actual1, expected)
249+
actual2 = run_opt_pass(expr2, relay.transform.SimplifyExpr())
250+
assert tvm.ir.structural_equal(actual2, expected)
258251

259252

260253
def test_concretize_reshape_like():

0 commit comments

Comments
 (0)