@@ -76,14 +76,14 @@ class SimplifyReshape : public DFPatternRewrite {
7676};
7777
7878/* !
79- * \brief SimplifyCastLike matches the pattern of cast data to the same dtype.
79+ * \brief SimplifyCast matches the pattern of cast data to the same dtype.
8080 */
81- class SimplifyCastLike : public DFPatternRewrite {
81+ class SimplifyCast : public DFPatternRewrite {
8282 public:
83- SimplifyCastLike () {
83+ SimplifyCast () {
8484 data_pat_ = IsWildcard ();
8585 like_pat_ = IsWildcard ();
86- pattern_ = IsOp (" cast_like" )({data_pat_, like_pat_});
86+ pattern_ = IsOp (" cast_like" )({data_pat_, like_pat_}) || IsOp ( " cast " )({data_pat_}) ;
8787 }
8888
8989 Expr Callback (const Expr& pre , const Expr& post ,
@@ -102,32 +102,6 @@ class SimplifyCastLike : public DFPatternRewrite {
102102 DFPattern like_pat_;
103103};
104104
105- /* !
106- * \brief SimplifyCast matches the pattern of cast data to the same dtype.
107- */
108- class SimplifyCast : public DFPatternRewrite {
109- public:
110- SimplifyCast () {
111- data_pat_ = IsWildcard ();
112- pattern_ = IsOp (" cast" )({data_pat_});
113- }
114-
115- Expr Callback (const Expr& pre , const Expr& post ,
116- const Map<DFPattern, Array<Expr>>& node_map) const override {
117- const CallNode* call = pre .as <CallNode>();
118- auto attrs = call->attrs .as <CastAttrs>();
119- auto data = node_map[data_pat_][0 ];
120- const TensorTypeNode* data_ty = data->checked_type ().as <TensorTypeNode>();
121- if (attrs->dtype == data_ty->dtype ) {
122- return data;
123- }
124- return post ;
125- }
126-
127- protected:
128- DFPattern data_pat_;
129- };
130-
131105/* !
132106 * \brief SimplifyTranspose matches the pattern of consecutive transpose op,
133107 * and merges or cancels them.
@@ -510,7 +484,6 @@ Expr SimplifyExpr(const Expr& expr, const IRModule& mod) {
510484 composer.AddRewrite <EliminateIdentityRewrite>();
511485 composer.AddRewrite <SimplifyReshape>();
512486 composer.AddRewrite <SimplifyTranspose>();
513- composer.AddRewrite <SimplifyCastLike>();
514487 composer.AddRewrite <SimplifyCast>();
515488 composer.AddRewrite <FullElementwise>();
516489 return RewritePatterns (composer.MakeCallbacks (), expr, mod);
0 commit comments