66#include < tvm/relay/expr_functor.h>
77#include < tvm/relay/op_attr_types.h>
88#include < tvm/relay/interpreter.h>
9+ #include < tvm/relay/attrs/transform.h>
910
1011namespace tvm {
1112namespace relay {
@@ -71,6 +72,7 @@ class ConstantFolder : public ExprMutator {
7172
7273 Expr VisitExpr_ (const CallNode* call) final {
7374 static auto op_stateful = Op::GetAttr<TOpIsStateful>(" TOpIsStateful" );
75+ auto origin_args = call->args ;
7476 Expr res = ExprMutator::VisitExpr_ (call);
7577 call = res.as <CallNode>();
7678 // We don't constant fold function with zero arguments.
@@ -81,6 +83,10 @@ class ConstantFolder : public ExprMutator {
8183 if (op == nullptr ) return res;
8284 // skip stateful ops.
8385 if (op_stateful.get (GetRef<Op>(op), false )) return res;
86+ // Try to evaluate shape_of op
87+ if (call->op .same_as (Op::Get (" shape_of" ))) {
88+ return EvaluateShapeOf (res, origin_args, call->attrs );
89+ }
8490 bool all_const_args = true ;
8591 for (Expr arg : call->args ) {
8692 if (!checker_.Check (arg)) {
@@ -132,6 +138,42 @@ class ConstantFolder : public ExprMutator {
132138 expr = InferType (expr, Module (nullptr ));
133139 return ValueToExpr (executor_ (expr));
134140 }
141+ // Evaluate shape_of op
142+ Expr EvaluateShapeOf (Expr expr, Array<Expr> args, Attrs attrs) {
143+ Expr input = args[0 ];
144+ const auto * param = attrs.as <ShapeOfAttrs>();
145+ CHECK (param != nullptr );
146+ tvm::Array<IndexExpr> ishape;
147+ if (const ConstantNode* op = input.as <ConstantNode>()) {
148+ ishape = op->tensor_type ()->shape ;
149+ } else if (input->checked_type_ .defined ()) {
150+ ishape = input->checked_type ().as <TensorTypeNode>()->shape ;
151+ } else {
152+ return expr;
153+ }
154+ // Get the constant shape
155+ DLContext ctx;
156+ ctx.device_type = kDLCPU ;
157+ ctx.device_id = 0 ;
158+ auto val = runtime::NDArray::Empty (
159+ {(int64_t )ishape.size ()}, Type2TVMType (Int (32 )), ctx);
160+ int32_t * dims = static_cast <int32_t *>(val->data );
161+ using ::tvm::ir::IntImm;
162+ for (size_t i = 0 ; i < ishape.size (); ++i) {
163+ if (const IntImm* dim = ishape[i].as <IntImm>()) {
164+ dims[i] = dim->value ;
165+ } else {
166+ return expr;
167+ }
168+ }
169+ Expr shape = ValueToExpr (TensorValueNode::make (val));
170+ // Cast the constant into correct dtype
171+ auto cast_attrs = make_node<CastAttrs>();
172+ cast_attrs->dtype = param->dtype ;
173+ static const Op& cast_op = Op::Get (" cast" );
174+ Expr ret = CallNode::make (cast_op, {shape}, Attrs (cast_attrs), {});
175+ return ConstEvaluate (ret);
176+ }
135177};
136178
137179
0 commit comments