Skip to content

Commit 6c3e79c

Browse files
altanhelectriclilies
authored andcommitted
Relax IR Parser (apache#6)
* Copy jared's frontend * Remove some extraneous code + add TODOs * Skeleton AST * Added more skeleton AST, worked on parsing shape annotations. Something is wrong with span_to_span * Fix spans * Type annotations parsing correctly * some match_shape support * More bug fixes! Some stuff parses. Importing into tests is messed up. We probably need to restructure this code as well. * refactor parser and fill out more stubs * some parser tests * yolo dataflow * checkpoint for rebase * hook up AST * add inline TIR parsing * some cleanup * support call_packed parsing to ExternFunc call * remove stub ops * improve docstrings * address nits * support coercing tuples to ShapeExpr when possible for call_dps Co-authored-by: electriclilies <[email protected]>
1 parent 6bac561 commit 6c3e79c

File tree

17 files changed

+1535
-228
lines changed

17 files changed

+1535
-228
lines changed

include/tvm/relax/expr.h

Lines changed: 46 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,8 @@ namespace relax {
3333

3434
using Expr = RelayExpr;
3535
using ExprNode = RelayExprNode;
36-
using relay::Id;
3736
using relay::Call;
37+
using relay::Id;
3838
using relay::Tuple;
3939
using relay::TupleGetItem;
4040

@@ -53,8 +53,7 @@ class ShapeExprNode : public ExprNode {
5353
}
5454

5555
bool SEqualReduce(const ShapeExprNode* other, SEqualReducer equal) const {
56-
return equal(values, other->values) &&
57-
equal(checked_type_, other->checked_type_) &&
56+
return equal(values, other->values) && equal(checked_type_, other->checked_type_) &&
5857
equal(shape_, other->shape_);
5958
}
6059

@@ -72,15 +71,15 @@ class ShapeExprNode : public ExprNode {
7271

7372
class ShapeExpr : public Expr {
7473
public:
75-
TVM_DLL ShapeExpr(Array<PrimExpr> values);
74+
TVM_DLL explicit ShapeExpr(Array<PrimExpr> values, Span span = Span());
7675
TVM_DEFINE_OBJECT_REF_METHODS(ShapeExpr, Expr, ShapeExprNode);
7776
};
7877

79-
8078
/*! \brief The variable class for all Relax bindings. */
8179
class VarNode : public ExprNode {
8280
public:
83-
/*! \brief The identifier of the variable, is used for comparing stable equality across transformations. */
81+
/*! \brief The identifier of the variable, which is used for comparing stable equality across
82+
* transformations. */
8483
Id vid;
8584
/*! \brief The type annotation, used by binding sites and parameter declarations. */
8685
runtime::Optional<Type> type_annotation;
@@ -97,11 +96,9 @@ class VarNode : public ExprNode {
9796
}
9897

9998
bool SEqualReduce(const VarNode* other, SEqualReducer equal) const {
100-
return equal(vid, other->vid) &&
101-
equal(type_annotation, other->type_annotation) &&
99+
return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
102100
// Do we use the analysis information in equality?
103-
equal(checked_type_, other->checked_type_) &&
104-
equal(shape_, other->shape_);
101+
equal(checked_type_, other->checked_type_) && equal(shape_, other->shape_);
105102
}
106103

107104
void SHashReduce(SHashReducer hash_reduce) const {
@@ -120,16 +117,12 @@ class VarNode : public ExprNode {
120117

121118
class Var : public Expr {
122119
public:
123-
TVM_DLL Var(String name_hint,
124-
runtime::Optional<Expr> shape_annotation,
125-
runtime::Optional<Type> type_annotation,
126-
Span span = Span())
127-
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}
128-
129-
TVM_DLL Var(Id vid,
130-
runtime::Optional<Expr> shape_annotation,
131-
runtime::Optional<Type> type_annotation,
132-
Span span = Span());
120+
TVM_DLL explicit Var(String name_hint, runtime::Optional<Expr> shape_annotation,
121+
runtime::Optional<Type> type_annotation, Span span = Span())
122+
: Var(Id(name_hint), shape_annotation, type_annotation, span) {}
123+
124+
TVM_DLL explicit Var(Id vid, runtime::Optional<Expr> shape_annotation,
125+
runtime::Optional<Type> type_annotation, Span span = Span());
133126
TVM_DEFINE_OBJECT_REF_METHODS(Var, Expr, VarNode);
134127
};
135128

@@ -147,10 +140,8 @@ class DataflowVarNode : public VarNode {
147140
}
148141

149142
bool SEqualReduce(const DataflowVarNode* other, SEqualReducer equal) const {
150-
return equal(vid, other->vid) &&
151-
equal(type_annotation, other->type_annotation) &&
152-
equal(shape_, other->shape_) &&
153-
equal(checked_type_, other->checked_type_);
143+
return equal(vid, other->vid) && equal(type_annotation, other->type_annotation) &&
144+
equal(shape_, other->shape_) && equal(checked_type_, other->checked_type_);
154145
}
155146

156147
void SHashReduce(SHashReducer hash_reduce) const {
@@ -168,15 +159,22 @@ class DataflowVarNode : public VarNode {
168159

169160
class DataflowVar : public Var {
170161
public:
171-
using Var::Var; // inherit constructors from Var
162+
TVM_DLL explicit DataflowVar(String name_hint, runtime::Optional<Expr> shape_annotation,
163+
runtime::Optional<Type> type_annotation, Span span = Span())
164+
: DataflowVar(Id(name_hint), shape_annotation, type_annotation, span) {}
165+
166+
TVM_DLL explicit DataflowVar(Id vid, runtime::Optional<Expr> shape_annotation,
167+
runtime::Optional<Type> type_annotation, Span span = Span());
168+
172169
TVM_DEFINE_OBJECT_REF_METHODS(DataflowVar, Var, DataflowVarNode);
173170
};
174171

175-
176172
/*! \brief The base class of a variable binding in Relax. */
177173
class BindingNode : public Object {
178174
public:
179-
void VisitAttrs(AttrVisitor* v) {}
175+
mutable Span span;
176+
177+
void VisitAttrs(AttrVisitor* v) { v->Visit("span", &span); }
180178
bool SEqualReduce(const BindingNode* other, SEqualReducer equal) const { return true; }
181179
void SHashReduce(SHashReducer hash_reduce) const {}
182180

@@ -188,10 +186,10 @@ class BindingNode : public Object {
188186

189187
class Binding : public ObjectRef {
190188
public:
189+
TVM_DLL explicit Binding(Span span);
191190
TVM_DEFINE_OBJECT_REF_METHODS(Binding, ObjectRef, BindingNode);
192191
};
193192

194-
195193
/*! \brief Symbolic shape match, binds the variables of the LHS with the rhs. */
196194
class MatchShape;
197195
class MatchShapeNode : public BindingNode {
@@ -202,6 +200,7 @@ class MatchShapeNode : public BindingNode {
202200
void VisitAttrs(AttrVisitor* v) {
203201
v->Visit("pattern", &pattern);
204202
v->Visit("value", &value);
203+
v->Visit("span", &span);
205204
}
206205

207206
bool SEqualReduce(const MatchShapeNode* other, SEqualReducer equal) const {
@@ -221,7 +220,7 @@ class MatchShapeNode : public BindingNode {
221220

222221
class MatchShape : public Binding {
223222
public:
224-
TVM_DLL MatchShape(Array<PrimExpr> pattern, Expr value);
223+
TVM_DLL explicit MatchShape(Array<PrimExpr> pattern, Expr value, Span span = Span());
225224
TVM_DEFINE_OBJECT_REF_METHODS(MatchShape, Binding, MatchShapeNode);
226225
};
227226

@@ -234,6 +233,7 @@ class VarBindingNode : public BindingNode {
234233
void VisitAttrs(AttrVisitor* v) {
235234
v->Visit("var", &var);
236235
v->Visit("value", &value);
236+
v->Visit("span", &span);
237237
}
238238

239239
bool SEqualReduce(const VarBindingNode* other, SEqualReducer equal) const {
@@ -251,23 +251,28 @@ class VarBindingNode : public BindingNode {
251251

252252
class VarBinding : public Binding {
253253
public:
254-
TVM_DLL VarBinding(Var var, Expr value);
254+
TVM_DLL explicit VarBinding(Var var, Expr value, Span span = Span());
255255
TVM_DEFINE_OBJECT_REF_METHODS(VarBinding, Binding, VarBindingNode);
256256
};
257257

258-
259258
class BindingBlock;
260259

261260
class BindingBlockNode : public Object {
262261
public:
262+
mutable Span span;
263263
Array<Binding> bindings;
264+
264265
void VisitAttrs(AttrVisitor* v) {
266+
v->Visit("span", &span);
265267
v->Visit("bindings", &bindings);
266268
}
269+
267270
bool SEqualReduce(const BindingBlockNode* other, SEqualReducer equal) const {
268271
return equal(bindings, other->bindings);
269272
}
273+
270274
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }
275+
271276
static constexpr const char* _type_key = "relax.expr.BindingBlock";
272277
static constexpr const bool _type_has_method_sequal_reduce = true;
273278
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -276,21 +281,17 @@ class BindingBlockNode : public Object {
276281

277282
class BindingBlock : public ObjectRef {
278283
public:
279-
TVM_DLL BindingBlock(Array<Binding> bindings);
284+
TVM_DLL explicit BindingBlock(Array<Binding> bindings, Span span = Span());
280285
TVM_DEFINE_OBJECT_REF_METHODS(BindingBlock, ObjectRef, BindingBlockNode);
281286
};
282287

283-
284288
class DataflowBlock;
285289
class DataflowBlockNode : public BindingBlockNode {
286290
public:
287-
void VisitAttrs(AttrVisitor* v) {
288-
v->Visit("bindings", &bindings);
289-
}
290291
bool SEqualReduce(const DataflowBlockNode* other, SEqualReducer equal) const {
291292
return equal(bindings, other->bindings);
292293
}
293-
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(bindings); }
294+
294295
static constexpr const char* _type_key = "relax.expr.DataflowBlock";
295296
static constexpr const bool _type_has_method_sequal_reduce = true;
296297
static constexpr const bool _type_has_method_shash_reduce = true;
@@ -299,7 +300,7 @@ class DataflowBlockNode : public BindingBlockNode {
299300

300301
class DataflowBlock : public BindingBlock {
301302
public:
302-
TVM_DLL DataflowBlock(Array<Binding> bindings);
303+
TVM_DLL explicit DataflowBlock(Array<Binding> bindings, Span span = Span());
303304
TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlock, BindingBlock, DataflowBlockNode);
304305
};
305306

@@ -340,11 +341,10 @@ class SeqExprNode : public ExprNode {
340341

341342
class SeqExpr : public Expr {
342343
public:
343-
TVM_DLL SeqExpr(Array<BindingBlock> blocks, Expr body);
344+
TVM_DLL explicit SeqExpr(Array<BindingBlock> blocks, Expr body, Span span = Span());
344345
TVM_DEFINE_OBJECT_REF_METHODS(SeqExpr, Expr, SeqExprNode);
345346
};
346347

347-
348348
/*! \brief A Relax function, eventually to replace the current Relay function definition. */
349349
class FunctionNode : public BaseFuncNode {
350350
public:
@@ -372,8 +372,7 @@ class FunctionNode : public BaseFuncNode {
372372

373373
bool SEqualReduce(const FunctionNode* other, SEqualReducer equal) const {
374374
equal->MarkGraphNode();
375-
return equal.DefEqual(params, other->params) &&
376-
equal(body, other->body) &&
375+
return equal.DefEqual(params, other->params) && equal(body, other->body) &&
377376
equal(ret_type, other->ret_type) && equal(checked_type_, other->checked_type_) &&
378377
equal(shape_, other->shape_);
379378
}
@@ -396,12 +395,11 @@ class FunctionNode : public BaseFuncNode {
396395

397396
class Function : public Expr {
398397
public:
399-
TVM_DLL Function(runtime::Optional<GlobalVar> name, Array<Var> params,
400-
Expr body, Type ret_type);
398+
TVM_DLL explicit Function(runtime::Optional<GlobalVar> name, Array<Var> params, Expr body,
399+
Type ret_type, Span span = Span());
401400
TVM_DEFINE_OBJECT_REF_METHODS(Function, Expr, FunctionNode);
402401
};
403402

404-
405403
/*! \brief The extern function, which can represent packed function. */
406404
class ExternFuncNode : public BaseFuncNode {
407405
public:
@@ -410,15 +408,14 @@ class ExternFuncNode : public BaseFuncNode {
410408

411409
void VisitAttrs(AttrVisitor* v) {
412410
v->Visit("global_symbol", &global_symbol);
411+
v->Visit("span", &span);
413412
}
414413

415414
bool SEqualReduce(const ExternFuncNode* other, SEqualReducer equal) const {
416415
return equal(global_symbol, other->global_symbol);
417416
}
418417

419-
void SHashReduce(SHashReducer hash_reduce) const {
420-
hash_reduce(global_symbol);
421-
}
418+
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(global_symbol); }
422419

423420
static constexpr const char* _type_key = "relax.expr.ExternFunc";
424421
static constexpr const bool _type_has_method_sequal_reduce = true;
@@ -428,7 +425,7 @@ class ExternFuncNode : public BaseFuncNode {
428425

429426
class ExternFunc : public Expr {
430427
public:
431-
TVM_DLL ExternFunc(String global_symbol);
428+
TVM_DLL ExternFunc(String global_symbol, Span span = Span());
432429
TVM_DEFINE_OBJECT_REF_METHODS(ExternFunc, Expr, ExternFuncNode);
433430
};
434431

include/tvm/relax/type.h

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ namespace relax {
3939

4040
class ShapeTypeNode : public TypeNode {
4141
public:
42-
void VisitAttrs(tvm::AttrVisitor* v) {}
42+
43+
void VisitAttrs(tvm::AttrVisitor* v) {
44+
v->Visit("span", &span);
45+
}
4346

4447
bool SEqualReduce(const ShapeTypeNode* other, SEqualReducer equal) const {
4548
return true;
@@ -53,16 +56,9 @@ class ShapeTypeNode : public TypeNode {
5356

5457
class ShapeType : public Type {
5558
public:
56-
explicit ShapeType();
57-
explicit ShapeType(runtime::ObjectPtr<runtime::Object> n) : Type(n) {}
58-
TVM_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(ShapeType);
59-
const ShapeTypeNode* operator->() const {
60-
return static_cast<const ShapeTypeNode*>(data_.get());
61-
}
62-
const ShapeTypeNode* get() const {
63-
return operator->();
64-
}
65-
using ContainerType = ShapeTypeNode;
59+
TVM_DLL ShapeType(Span span = Span());
60+
61+
TVM_DEFINE_OBJECT_REF_METHODS(ShapeType, Type, ShapeTypeNode);
6662
};
6763

6864
class DynTensorTypeNode : public BaseTensorTypeNode {
@@ -108,11 +104,34 @@ class DynTensorType : public Type {
108104
* \param shape The shape of the tensor.
109105
* \param dtype The runtime dtype of the tensor's elements.
110106
*/
111-
TVM_DLL DynTensorType(int rank, DataType dtype);
107+
TVM_DLL DynTensorType(int rank, DataType dtype, Span span = Span());
112108

113109
TVM_DEFINE_OBJECT_REF_METHODS(DynTensorType, Type, DynTensorTypeNode);
114110
};
115111

112+
class DimTypeNode : public TypeNode {
113+
public:
114+
void VisitAttrs(tvm::AttrVisitor* v) {
115+
v->Visit("span", &span);
116+
}
117+
118+
bool SEqualReduce(const DimTypeNode* other, SEqualReducer equal) const {
119+
return true;
120+
}
121+
122+
void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(0); }
123+
124+
static constexpr const char* _type_key = "relax.DimType";
125+
TVM_DECLARE_FINAL_OBJECT_INFO(DimTypeNode, TypeNode);
126+
};
127+
128+
class DimType : public Type {
129+
public:
130+
TVM_DLL DimType(Span span = Span());
131+
132+
TVM_DEFINE_OBJECT_REF_METHODS(DimType, Type, DimTypeNode);
133+
};
134+
116135
} // namespace relax
117136
} // namespace tvm
118137
#endif // TVM_RELAX_TYPE_H_

include/tvm/relay/expr.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ class TupleNode : public ExprNode {
131131
v->Visit("virtual_device_", &virtual_device_);
132132
v->Visit("span", &span);
133133
v->Visit("_checked_type_", &checked_type_);
134+
v->Visit("shape_", &shape_);
134135
}
135136

136137
bool SEqualReduce(const TupleNode* other, SEqualReducer equal) const {

python/tvm/ir/type.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import tvm
2020
import tvm._ffi
2121

22+
from . import Span
2223
from .base import Node
2324
from . import _ffi_api
2425

@@ -166,8 +167,8 @@ class TupleType(Type):
166167
The fields in the tuple
167168
"""
168169

169-
def __init__(self, fields):
170-
self.__init_handle_by_constructor__(_ffi_api.TupleType, fields)
170+
def __init__(self, fields, span: Span = None):
171+
self.__init_handle_by_constructor__(_ffi_api.TupleType, fields, span)
171172

172173

173174
@tvm._ffi.register_object("TypeConstraint")

0 commit comments

Comments
 (0)