1919#include < tvm/relax/expr.h>
2020
2121namespace tvm {
22+
23+ RelayExpr RelayExprNode::shape () const {
24+ if (this ->shape_ .defined ()) {
25+ return Downcast<RelayExpr>(this ->shape_ );
26+ }
27+ static const Op& op = Op::Get (" relax.shape_of" );
28+ RelayExpr self = GetRef<RelayExpr>(this );
29+ return relay::Call (op, {self}, {}, {});
30+ }
31+
32+ TVM_REGISTER_GLOBAL (" ir.RelayExprShape" )
33+ .set_body_typed([](RelayExpr expr) {
34+ return expr->shape ();
35+ });
36+
2237namespace relax {
2338
2439using tvm::runtime::Optional;
2540
26-
2741TVM_REGISTER_NODE_TYPE (ShapeExprNode);
2842
2943ShapeExpr::ShapeExpr (Array<PrimExpr> values) {
@@ -41,7 +55,7 @@ TVM_REGISTER_GLOBAL("relax.ShapeExpr")
4155TVM_REGISTER_NODE_TYPE (VarNode);
4256
4357Var::Var (Id vid,
44- Optional<Array<PrimExpr> > shape_annotation,
58+ Optional<Expr > shape_annotation,
4559 Optional<Type> type_annotation,
4660 Span span) {
4761 ObjectPtr<VarNode> n = make_object<VarNode>();
@@ -54,7 +68,7 @@ Var::Var(Id vid,
5468
5569TVM_REGISTER_GLOBAL (" relax.Var" )
5670.set_body_typed([](String name_hint,
57- Optional<Array<PrimExpr> > shape_annotation,
71+ Optional<Expr > shape_annotation,
5872 Optional<Type> type_annotation) {
5973 return Var (name_hint, shape_annotation, type_annotation);
6074});
@@ -64,7 +78,7 @@ TVM_REGISTER_NODE_TYPE(DataflowVarNode);
6478
6579TVM_REGISTER_GLOBAL (" relax.DataflowVar" )
6680.set_body_typed([](String name_hint,
67- Optional<Array<PrimExpr> > shape_annotation,
81+ Optional<Expr > shape_annotation,
6882 Optional<Type> type_annotation) {
6983 return DataflowVar (name_hint, shape_annotation, type_annotation);
7084});
0 commit comments