2929#include < tvm/support/logging.h>
3030
3131#include " ../op/tensor/transform.h"
32+ #include " pattern_utils.h"
3233
3334namespace tvm {
3435namespace relay {
3536
37+ class SimplifyPattern {
38+ public:
39+ virtual Expr callback (const Expr& pre , const Expr& post ,
40+ const Map<DFPattern, Array<Expr>>& node_map) const = 0;
41+
42+ DFPattern pattern () const { return pattern_; }
43+
44+ protected:
45+ /* ! \brief Pattern for rewriting */
46+ DFPattern pattern_;
47+ };
48+
3649/* !
3750 * \brief SimplifyReshape matches the pattern of consecutive reshape or reverse_reshape ops,
3851 * and merges into one reshape op.
3952 */
40- class SimplifyReshape {
53+ class SimplifyReshape : public SimplifyPattern {
4154 public:
4255 SimplifyReshape () {
43- x_ = WildcardPattern (make_object<WildcardPatternNode>() );
56+ x_ = IsWildcard ( );
4457 auto reshape1 = IsOp (" reshape" ) || IsOp (" contrib_reverse_reshape" );
4558 auto reshape2 = IsOp (" reshape" ) || IsOp (" contrib_reverse_reshape" );
4659 pattern_ = reshape1 ({reshape2 ({x_})});
4760 }
4861
49- Expr callback (const Expr& pre , const Expr& post , const Map<DFPattern, Array<Expr>>& node_map) {
62+ Expr callback (const Expr& pre , const Expr& post ,
63+ const Map<DFPattern, Array<Expr>>& node_map) const override {
5064 auto x = node_map[x_][0 ];
5165 bool const_shape = true ;
5266 Array<Integer> newshape;
@@ -63,13 +77,82 @@ class SimplifyReshape {
6377 return post ;
6478 }
6579
66- DFPattern pattern () const { return pattern_; }
67-
6880 private:
6981 /* ! \brief Pattern input */
7082 DFPattern x_;
71- /* ! \brief Pattern for consecutive reshape or reverse_reshape ops */
72- DFPattern pattern_;
83+ };
84+
85+ /* !
86+ * \brief FullArgwhere finds full followed by argwhere and turns it into an Arange op
87+ */
88+ class FullElementwise : public SimplifyPattern {
89+ public:
90+ FullElementwise () {
91+ x_ = IsWildcard ();
92+ data_ = IsWildcard ();
93+ value_ = IsConstant ();
94+
95+ full_ = IsOp (" full" )({value_}) || IsOp (" full_like" )({data_, value_});
96+ ones_ = IsOp (" ones" )({}) || IsOp (" ones_like" )({data_});
97+ zeros_ = IsOp (" zeros" )({}) || IsOp (" zeros_like" )({data_});
98+
99+ Map<String, ObjectRef> attrs;
100+ attrs.Set (" TOpPattern" , Integer (static_cast <int >(kBroadcast )));
101+ DFPattern op = IsWildcard ().HasAttr (attrs);
102+ DFPattern full = full_ || ones_ || zeros_;
103+ pattern_ = op ({full, x_}) || op ({x_, full});
104+ }
105+
106+ Expr callback (const Expr& pre , const Expr& post ,
107+ const Map<DFPattern, Array<Expr>>& node_map) const override {
108+ const CallNode* call = pre .as <CallNode>();
109+ ICHECK (call);
110+ Type pre_type = pre ->checked_type_ ;
111+ ICHECK (pre_type.as <TensorTypeNode>());
112+ auto dtype = pre_type.as <TensorTypeNode>()->dtype ;
113+ auto x = node_map[x_][0 ];
114+ bool is_left = post .as <CallNode>()->args [1 ] == x;
115+ Type x_type;
116+ if (is_left) {
117+ x_type = call->args [1 ]->checked_type_ ;
118+ } else {
119+ x_type = call->args [0 ]->checked_type_ ;
120+ }
121+
122+ if (StructuralEqual ()(x_type, pre_type)) {
123+ Expr value;
124+ if (node_map.count (full_)) {
125+ value = node_map[value_][0 ];
126+ ICHECK (IsConstScalar (value));
127+ } else if (node_map.count (ones_)) {
128+ value = MakeConstantScalar (dtype, 1 );
129+ } else if (node_map.count (zeros_)) {
130+ value = MakeConstantScalar (dtype, 0 );
131+ } else {
132+ ICHECK (false ) << " Didn't find a full op while matching full + elementwise" ;
133+ }
134+ if (is_left) {
135+ return Call (call->op , {value, x}, call->attrs , call->type_args , call->span );
136+ } else {
137+ return Call (call->op , {x, value}, call->attrs , call->type_args , call->span );
138+ }
139+ }
140+ return post ;
141+ }
142+
143+ private:
144+ /* ! \brief binary argument */
145+ DFPattern x_;
146+ /* ! \brief data ops get shape from */
147+ DFPattern data_;
148+ /* ! \brief constant input */
149+ DFPattern value_;
150+ /* ! \brief full op */
151+ DFPattern full_;
152+ /* ! \brief ones op */
153+ DFPattern ones_;
154+ /* ! \brief zeros op */
155+ DFPattern zeros_;
73156};
74157
75158/* !
@@ -78,22 +161,24 @@ class SimplifyReshape {
78161class ExprSimplifier {
79162 public:
80163 explicit ExprSimplifier (IRModule mod) : mod_(mod) {
81- auto reshape_func = [this ](TVMArgs args, TVMRetValue* rv) {
164+ CreateCallback (SimplifyReshape ());
165+ CreateCallback (FullElementwise ());
166+ }
167+ template <typename T>
168+ void CreateCallback (const T& pattern) {
169+ auto func = [pattern](TVMArgs args, TVMRetValue* rv) {
82170 Expr pre = args[0 ];
83171 Expr post = args[1 ];
84172 Map<DFPattern, Array<Expr>> node_map = args[2 ];
85- *rv = simplify_reshape_ .callback (pre , post , node_map);
173+ *rv = pattern .callback (pre , post , node_map);
86174 };
87- callbacks_.push_back (
88- DFPatternCallback (simplify_reshape_.pattern (), PackedFunc (reshape_func), true ));
175+ callbacks_.push_back (DFPatternCallback (pattern.pattern (), PackedFunc (func), true ));
89176 }
90177
91178 Expr Simplify (const Expr& expr) { return RewritePatterns (callbacks_, expr, mod_); }
92179
93180 private:
94181 IRModule mod_;
95- /* ! \brief Simplify reshape pattern */
96- SimplifyReshape simplify_reshape_;
97182 /* ! \brief Callbacks for expr simplification */
98183 Array<DFPatternCallback> callbacks_;
99184};
0 commit comments