File tree Expand file tree Collapse file tree 4 files changed +39
-9
lines changed
python/tvm/relay/frontend
tests/python/frontend/tensorflow Expand file tree Collapse file tree 4 files changed +39
-9
lines changed Original file line number Diff line number Diff line change @@ -238,15 +238,15 @@ class ExprMutator
238238void PostOrderVisit (const Expr& node, std::function<void (const Expr&)> fvisit);
239239
240240/*
241- * \brief Bind function parameters or free variables.
241+ * \brief Bind function parameters, free variables or CallNode .
242242 *
243243 * Parameter binding can only happen if expr is a Function.
244244 * binds cannot change internal arguments of internal functions.
245245 *
246- * \param expr The function to be binded .
246+ * \param expr The function to be bound .
247247 * \param binds The map of arguments to
248248 */
249- Expr Bind (const Expr& expr, const tvm::Map<Var , Expr>& binds);
249+ Expr Bind (const Expr& expr, const tvm::Map<Expr , Expr>& binds);
250250
251251} // namespace relay
252252} // namespace tvm
Original file line number Diff line number Diff line change @@ -1605,9 +1605,12 @@ def _while_loop(self):
16051605 loop_vars = []
16061606 bind_map = {}
16071607 for i , var in enumerate (self .loop_vars ):
1608- assert isinstance (var , _expr .Var ), repr (var )
1609- v = tvm .relay .var ("loop_var" + str (i ),
1610- type_annotation = var .type_annotation )
1608+ if not isinstance (var , _expr .Var ):
1609+ var_type = ir_pass .infer_type (var ).checked_type
1610+ else :
1611+ var_type = var .type_annotation
1612+
1613+ v = tvm .relay .var ("loop_var" + str (i ), type_annotation = var_type )
16111614 loop_vars .append (v )
16121615 bind_map [var ] = v
16131616
Original file line number Diff line number Diff line change @@ -355,7 +355,7 @@ TVM_REGISTER_API("relay._ir_pass.post_order_visit")
355355// Implement bind.
356356class ExprBinder : public ExprMutator {
357357 public:
358- explicit ExprBinder (const tvm::Map<Var , Expr>& args_map)
358+ explicit ExprBinder (const tvm::Map<Expr , Expr>& args_map)
359359 : args_map_(args_map) {
360360 }
361361
@@ -383,11 +383,21 @@ class ExprBinder : public ExprMutator {
383383 }
384384 }
385385
386+ Expr VisitExpr_ (const CallNode* op) final {
387+ auto id = GetRef<Call>(op);
388+ auto it = args_map_.find (id);
389+ if (it != args_map_.end ()) {
390+ return (*it).second ;
391+ } else {
392+ return ExprMutator::VisitExpr_ (op);
393+ }
394+ }
395+
386396 private:
387- const tvm::Map<Var , Expr>& args_map_;
397+ const tvm::Map<Expr , Expr>& args_map_;
388398};
389399
390- Expr Bind (const Expr& expr, const tvm::Map<Var , Expr>& args_map) {
400+ Expr Bind (const Expr& expr, const tvm::Map<Expr , Expr>& args_map) {
391401 if (const FunctionNode* func = expr.as <FunctionNode>()) {
392402 Expr new_body = ExprBinder (args_map).Mutate (func->body );
393403 Array<Var> new_params;
Original file line number Diff line number Diff line change @@ -51,6 +51,23 @@ def b(i): return tf.add(i, 1)
5151 check_equal (graph , tf_out )
5252
5353
54+ def test_callnode_loop_vars ():
55+ graph = tf .Graph ()
56+ with graph .as_default ():
57+ i = tf .add (tf .constant (0 ), 1 )
58+
59+ def c (i ): return tf .less (i , 10 )
60+
61+ def b (i ): return tf .add (i , 1 )
62+
63+ r = tf .while_loop (c , b , [i ])
64+
65+ with tf .Session () as sess :
66+ tf_out = sess .run (r )
67+
68+ check_equal (graph , tf_out )
69+
70+
5471def test_loop_2_vars ():
5572 graph = tf .Graph ()
5673 with graph .as_default ():
You can’t perform that action at this time.
0 commit comments