@@ -75,57 +75,31 @@ class SimplifyReshape : public DFPatternRewrite {
7575 DFPattern x_;
7676};
7777
78- /* !
79- * \brief SimplifyCastLike matches the pattern of cast data to the same dtype.
80- */
81- class SimplifyCastLike : public DFPatternRewrite {
82- public:
83- SimplifyCastLike () {
84- data_pat_ = IsWildcard ();
85- like_pat_ = IsWildcard ();
86- pattern_ = IsOp (" cast_like" )({data_pat_, like_pat_});
87- }
88-
89- Expr Callback (const Expr& pre , const Expr& post ,
90- const Map<DFPattern, Array<Expr>>& node_map) const override {
91- auto data = node_map[data_pat_][0 ];
92- const TensorTypeNode* data_ty = data->checked_type ().as <TensorTypeNode>();
93- const TensorTypeNode* like_ty = pre ->checked_type ().as <TensorTypeNode>();
94- if (like_ty->dtype == data_ty->dtype ) {
95- return data;
96- }
97- return post ;
98- }
99-
100- protected:
101- DFPattern data_pat_;
102- DFPattern like_pat_;
103- };
104-
10578/* !
10679 * \brief SimplifyCast matches the pattern of cast data to the same dtype.
10780 */
10881class SimplifyCast : public DFPatternRewrite {
10982 public:
11083 SimplifyCast () {
11184 data_pat_ = IsWildcard ();
112- pattern_ = IsOp (" cast" )({data_pat_});
85+ like_pat_ = IsWildcard ();
86+ pattern_ = IsOp (" cast_like" )({data_pat_, like_pat_}) || IsOp (" cast" )({data_pat_});
11387 }
11488
11589 Expr Callback (const Expr& pre , const Expr& post ,
11690 const Map<DFPattern, Array<Expr>>& node_map) const override {
11791 const CallNode* call = pre .as <CallNode>();
118- auto attrs = call->attrs .as <CastAttrs>();
119- auto data = node_map[data_pat_][0 ];
120- const TensorTypeNode* data_ty = data->checked_type ().as <TensorTypeNode>();
121- if (attrs->dtype == data_ty->dtype ) {
122- return data;
92+ const TensorTypeNode* data_ty = call->args [0 ]->checked_type ().as <TensorTypeNode>();
93+ const TensorTypeNode* like_ty = pre ->checked_type ().as <TensorTypeNode>();
94+ if (like_ty->dtype == data_ty->dtype ) {
95+ return node_map[data_pat_][0 ];
12396 }
12497 return post ;
12598 }
12699
127100 protected:
128101 DFPattern data_pat_;
102+ DFPattern like_pat_;
129103};
130104
131105/* !
@@ -510,7 +484,6 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
510484 composer.AddRewrite <EliminateIdentityRewrite>();
511485 composer.AddRewrite <SimplifyReshape>();
512486 composer.AddRewrite <SimplifyTranspose>();
513- composer.AddRewrite <SimplifyCastLike>();
514487 composer.AddRewrite <SimplifyCast>();
515488 composer.AddRewrite <FullElementwise>();
516489 return RewritePatterns (composer.MakeCallbacks (), expr, mod);
0 commit comments