@@ -33,8 +33,8 @@ namespace relax {
3333
3434using Expr = RelayExpr;
3535using ExprNode = RelayExprNode;
36- using relay::Id;
3736using relay::Call;
37+ using relay::Id;
3838using relay::Tuple;
3939using relay::TupleGetItem;
4040
@@ -53,8 +53,7 @@ class ShapeExprNode : public ExprNode {
5353 }
5454
5555 bool SEqualReduce (const ShapeExprNode* other, SEqualReducer equal) const {
56- return equal (values, other->values ) &&
57- equal (checked_type_, other->checked_type_ ) &&
56+ return equal (values, other->values ) && equal (checked_type_, other->checked_type_ ) &&
5857 equal (shape_, other->shape_ );
5958 }
6059
@@ -72,15 +71,15 @@ class ShapeExprNode : public ExprNode {
7271
7372class ShapeExpr : public Expr {
7473 public:
75- TVM_DLL ShapeExpr (Array<PrimExpr> values);
74+ TVM_DLL explicit ShapeExpr (Array<PrimExpr> values, Span span = Span() );
7675 TVM_DEFINE_OBJECT_REF_METHODS (ShapeExpr, Expr, ShapeExprNode);
7776};
7877
79-
8078/* ! \brief The variable class for all Relax bindings. */
8179class VarNode : public ExprNode {
8280 public:
83- /* ! \brief The identifier of the variable, is used for comparing stable equality across transformations. */
81+ /* ! \brief The identifier of the variable, which is used for comparing stable equality across
82+ * transformations. */
8483 Id vid;
8584 /* ! \brief The type annotation, used by binding sites and parameter declarations. */
8685 runtime::Optional<Type> type_annotation;
@@ -97,11 +96,9 @@ class VarNode : public ExprNode {
9796 }
9897
9998 bool SEqualReduce (const VarNode* other, SEqualReducer equal) const {
100- return equal (vid, other->vid ) &&
101- equal (type_annotation, other->type_annotation ) &&
99+ return equal (vid, other->vid ) && equal (type_annotation, other->type_annotation ) &&
102100 // Do we use the analysis information in equality?
103- equal (checked_type_, other->checked_type_ ) &&
104- equal (shape_, other->shape_ );
101+ equal (checked_type_, other->checked_type_ ) && equal (shape_, other->shape_ );
105102 }
106103
107104 void SHashReduce (SHashReducer hash_reduce) const {
@@ -120,16 +117,12 @@ class VarNode : public ExprNode {
120117
121118class Var : public Expr {
122119 public:
123- TVM_DLL Var (String name_hint,
124- runtime::Optional<Expr> shape_annotation,
125- runtime::Optional<Type> type_annotation,
126- Span span = Span())
127- : Var(Id(name_hint), shape_annotation, type_annotation, span) {}
128-
129- TVM_DLL Var (Id vid,
130- runtime::Optional<Expr> shape_annotation,
131- runtime::Optional<Type> type_annotation,
132- Span span = Span());
120+ TVM_DLL explicit Var (String name_hint, runtime::Optional<Expr> shape_annotation,
121+ runtime::Optional<Type> type_annotation, Span span = Span())
122+ : Var(Id(name_hint), shape_annotation, type_annotation, span) {}
123+
124+ TVM_DLL explicit Var (Id vid, runtime::Optional<Expr> shape_annotation,
125+ runtime::Optional<Type> type_annotation, Span span = Span());
133126 TVM_DEFINE_OBJECT_REF_METHODS (Var, Expr, VarNode);
134127};
135128
@@ -147,10 +140,8 @@ class DataflowVarNode : public VarNode {
147140 }
148141
149142 bool SEqualReduce (const DataflowVarNode* other, SEqualReducer equal) const {
150- return equal (vid, other->vid ) &&
151- equal (type_annotation, other->type_annotation ) &&
152- equal (shape_, other->shape_ ) &&
153- equal (checked_type_, other->checked_type_ );
143+ return equal (vid, other->vid ) && equal (type_annotation, other->type_annotation ) &&
144+ equal (shape_, other->shape_ ) && equal (checked_type_, other->checked_type_ );
154145 }
155146
156147 void SHashReduce (SHashReducer hash_reduce) const {
@@ -168,15 +159,22 @@ class DataflowVarNode : public VarNode {
168159
169160class DataflowVar : public Var {
170161 public:
171- using Var::Var; // inherit constructors from Var
162+ TVM_DLL explicit DataflowVar (String name_hint, runtime::Optional<Expr> shape_annotation,
163+ runtime::Optional<Type> type_annotation, Span span = Span())
164+ : DataflowVar(Id(name_hint), shape_annotation, type_annotation, span) {}
165+
166+ TVM_DLL explicit DataflowVar (Id vid, runtime::Optional<Expr> shape_annotation,
167+ runtime::Optional<Type> type_annotation, Span span = Span());
168+
172169 TVM_DEFINE_OBJECT_REF_METHODS (DataflowVar, Var, DataflowVarNode);
173170};
174171
175-
176172/* ! \brief The base class of a variable binding in Relax. */
177173class BindingNode : public Object {
178174 public:
179- void VisitAttrs (AttrVisitor* v) {}
175+ mutable Span span;
176+
177+ void VisitAttrs (AttrVisitor* v) { v->Visit (" span" , &span); }
180178 bool SEqualReduce (const BindingNode* other, SEqualReducer equal) const { return true ; }
181179 void SHashReduce (SHashReducer hash_reduce) const {}
182180
@@ -188,10 +186,10 @@ class BindingNode : public Object {
188186
189187class Binding : public ObjectRef {
190188 public:
189+ TVM_DLL explicit Binding (Span span);
191190 TVM_DEFINE_OBJECT_REF_METHODS (Binding, ObjectRef, BindingNode);
192191};
193192
194-
195193/* ! \brief Symbolic shape match, binds the variables of the LHS with the rhs. */
196194class MatchShape ;
197195class MatchShapeNode : public BindingNode {
@@ -202,6 +200,7 @@ class MatchShapeNode : public BindingNode {
202200 void VisitAttrs (AttrVisitor* v) {
203201 v->Visit (" pattern" , &pattern);
204202 v->Visit (" value" , &value);
203+ v->Visit (" span" , &span);
205204 }
206205
207206 bool SEqualReduce (const MatchShapeNode* other, SEqualReducer equal) const {
@@ -221,7 +220,7 @@ class MatchShapeNode : public BindingNode {
221220
222221class MatchShape : public Binding {
223222 public:
224- TVM_DLL MatchShape (Array<PrimExpr> pattern, Expr value);
223+ TVM_DLL explicit MatchShape (Array<PrimExpr> pattern, Expr value, Span span = Span() );
225224 TVM_DEFINE_OBJECT_REF_METHODS (MatchShape, Binding, MatchShapeNode);
226225};
227226
@@ -234,6 +233,7 @@ class VarBindingNode : public BindingNode {
234233 void VisitAttrs (AttrVisitor* v) {
235234 v->Visit (" var" , &var);
236235 v->Visit (" value" , &value);
236+ v->Visit (" span" , &span);
237237 }
238238
239239 bool SEqualReduce (const VarBindingNode* other, SEqualReducer equal) const {
@@ -251,23 +251,28 @@ class VarBindingNode : public BindingNode {
251251
252252class VarBinding : public Binding {
253253 public:
254- TVM_DLL VarBinding (Var var, Expr value);
254+ TVM_DLL explicit VarBinding (Var var, Expr value, Span span = Span() );
255255 TVM_DEFINE_OBJECT_REF_METHODS (VarBinding, Binding, VarBindingNode);
256256};
257257
258-
259258class BindingBlock ;
260259
261260class BindingBlockNode : public Object {
262261 public:
262+ mutable Span span;
263263 Array<Binding> bindings;
264+
264265 void VisitAttrs (AttrVisitor* v) {
266+ v->Visit (" span" , &span);
265267 v->Visit (" bindings" , &bindings);
266268 }
269+
267270 bool SEqualReduce (const BindingBlockNode* other, SEqualReducer equal) const {
268271 return equal (bindings, other->bindings );
269272 }
273+
270274 void SHashReduce (SHashReducer hash_reduce) const { hash_reduce (bindings); }
275+
271276 static constexpr const char * _type_key = " relax.expr.BindingBlock" ;
272277 static constexpr const bool _type_has_method_sequal_reduce = true ;
273278 static constexpr const bool _type_has_method_shash_reduce = true ;
@@ -276,21 +281,17 @@ class BindingBlockNode : public Object {
276281
277282class BindingBlock : public ObjectRef {
278283 public:
279- TVM_DLL BindingBlock (Array<Binding> bindings);
284+ TVM_DLL explicit BindingBlock (Array<Binding> bindings, Span span = Span() );
280285 TVM_DEFINE_OBJECT_REF_METHODS (BindingBlock, ObjectRef, BindingBlockNode);
281286};
282287
283-
284288class DataflowBlock ;
285289class DataflowBlockNode : public BindingBlockNode {
286290 public:
287- void VisitAttrs (AttrVisitor* v) {
288- v->Visit (" bindings" , &bindings);
289- }
290291 bool SEqualReduce (const DataflowBlockNode* other, SEqualReducer equal) const {
291292 return equal (bindings, other->bindings );
292293 }
293- void SHashReduce (SHashReducer hash_reduce) const { hash_reduce (bindings); }
294+
294295 static constexpr const char * _type_key = " relax.expr.DataflowBlock" ;
295296 static constexpr const bool _type_has_method_sequal_reduce = true ;
296297 static constexpr const bool _type_has_method_shash_reduce = true ;
@@ -299,7 +300,7 @@ class DataflowBlockNode : public BindingBlockNode {
299300
300301class DataflowBlock : public BindingBlock {
301302 public:
302- TVM_DLL DataflowBlock (Array<Binding> bindings);
303+ TVM_DLL explicit DataflowBlock (Array<Binding> bindings, Span span = Span() );
303304 TVM_DEFINE_OBJECT_REF_METHODS (DataflowBlock, BindingBlock, DataflowBlockNode);
304305};
305306
@@ -340,11 +341,10 @@ class SeqExprNode : public ExprNode {
340341
341342class SeqExpr : public Expr {
342343 public:
343- TVM_DLL SeqExpr (Array<BindingBlock> blocks, Expr body);
344+ TVM_DLL explicit SeqExpr (Array<BindingBlock> blocks, Expr body, Span span = Span() );
344345 TVM_DEFINE_OBJECT_REF_METHODS (SeqExpr, Expr, SeqExprNode);
345346};
346347
347-
348348/* ! \brief A Relax function, eventually to replace the current Relay function definition. */
349349class FunctionNode : public BaseFuncNode {
350350 public:
@@ -372,8 +372,7 @@ class FunctionNode : public BaseFuncNode {
372372
373373 bool SEqualReduce (const FunctionNode* other, SEqualReducer equal) const {
374374 equal->MarkGraphNode ();
375- return equal.DefEqual (params, other->params ) &&
376- equal (body, other->body ) &&
375+ return equal.DefEqual (params, other->params ) && equal (body, other->body ) &&
377376 equal (ret_type, other->ret_type ) && equal (checked_type_, other->checked_type_ ) &&
378377 equal (shape_, other->shape_ );
379378 }
@@ -396,12 +395,11 @@ class FunctionNode : public BaseFuncNode {
396395
397396class Function : public Expr {
398397 public:
399- TVM_DLL Function (runtime::Optional<GlobalVar> name, Array<Var> params,
400- Expr body, Type ret_type);
398+ TVM_DLL explicit Function (runtime::Optional<GlobalVar> name, Array<Var> params, Expr body ,
399+ Type ret_type, Span span = Span() );
401400 TVM_DEFINE_OBJECT_REF_METHODS (Function, Expr, FunctionNode);
402401};
403402
404-
405403/* ! \brief The extern function, which can represent packed function. */
406404class ExternFuncNode : public BaseFuncNode {
407405 public:
@@ -410,15 +408,14 @@ class ExternFuncNode : public BaseFuncNode {
410408
411409 void VisitAttrs (AttrVisitor* v) {
412410 v->Visit (" global_symbol" , &global_symbol);
411+ v->Visit (" span" , &span);
413412 }
414413
415414 bool SEqualReduce (const ExternFuncNode* other, SEqualReducer equal) const {
416415 return equal (global_symbol, other->global_symbol );
417416 }
418417
419- void SHashReduce (SHashReducer hash_reduce) const {
420- hash_reduce (global_symbol);
421- }
418+ void SHashReduce (SHashReducer hash_reduce) const { hash_reduce (global_symbol); }
422419
423420 static constexpr const char * _type_key = " relax.expr.ExternFunc" ;
424421 static constexpr const bool _type_has_method_sequal_reduce = true ;
@@ -428,7 +425,7 @@ class ExternFuncNode : public BaseFuncNode {
428425
429426class ExternFunc : public Expr {
430427 public:
431- TVM_DLL ExternFunc (String global_symbol);
428+ TVM_DLL ExternFunc (String global_symbol, Span span = Span() );
432429 TVM_DEFINE_OBJECT_REF_METHODS (ExternFunc, Expr, ExternFuncNode);
433430};
434431
0 commit comments