Commit 70c1f3f

Author: Li Xiaoquan (committed)
[Relay][Tensorflow] Allow an op as loop var.
Allow binding a CallNode to a var so that the result of an op can be used as a loop variable.
Parent: c64a33e
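
The case this enables, in TensorFlow terms (the same pattern as the new test added below): the initial value of a while-loop variable is itself the result of an op, which the frontend sees as a Relay CallNode rather than a Var. A minimal sketch, assuming the TF 1.x API used in the test suite:

    import tensorflow as tf

    # The loop variable starts out as an op result, not a constant or placeholder.
    i = tf.add(tf.constant(0), 1)

    def c(i): return tf.less(i, 10)

    def b(i): return tf.add(i, 1)

    # Before this commit the frontend asserted isinstance(var, _expr.Var)
    # while converting this loop and rejected the graph.
    r = tf.while_loop(c, b, [i])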

4 files changed: +39 −9 lines changed


include/tvm/relay/expr_functor.h

Lines changed: 3 additions & 3 deletions
@@ -238,15 +238,15 @@ class ExprMutator
 void PostOrderVisit(const Expr& node, std::function<void(const Expr&)> fvisit);

 /*
- * \brief Bind function parameters or free variables.
+ * \brief Bind function parameters, free variables or CallNode.
  *
  * Parameter binding can only happen if expr is a Function.
  * binds cannot change internal arguments of internal functions.
  *
- * \param expr The function to be binded.
+ * \param expr The function to be bound.
  * \param binds The map of arguments to
  */
-Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds);
+Expr Bind(const Expr& expr, const tvm::Map<Expr, Expr>& binds);

 }  // namespace relay
 }  // namespace tvm

python/tvm/relay/frontend/tensorflow.py

Lines changed: 6 additions & 3 deletions
@@ -1605,9 +1605,12 @@ def _while_loop(self):
         loop_vars = []
         bind_map = {}
         for i, var in enumerate(self.loop_vars):
-            assert isinstance(var, _expr.Var), repr(var)
-            v = tvm.relay.var("loop_var" + str(i),
-                              type_annotation=var.type_annotation)
+            if not isinstance(var, _expr.Var):
+                var_type = ir_pass.infer_type(var).checked_type
+            else:
+                var_type = var.type_annotation
+
+            v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type)
             loop_vars.append(v)
             bind_map[var] = v

src/relay/ir/expr_functor.cc

Lines changed: 13 additions & 3 deletions
@@ -355,7 +355,7 @@ TVM_REGISTER_API("relay._ir_pass.post_order_visit")
 // Implement bind.
 class ExprBinder : public ExprMutator {
  public:
-  explicit ExprBinder(const tvm::Map<Var, Expr>& args_map)
+  explicit ExprBinder(const tvm::Map<Expr, Expr>& args_map)
     : args_map_(args_map) {
   }

@@ -383,11 +383,21 @@ class ExprBinder : public ExprMutator {
     }
   }

+  Expr VisitExpr_(const CallNode* op) final {
+    auto id = GetRef<Call>(op);
+    auto it = args_map_.find(id);
+    if (it != args_map_.end()) {
+      return (*it).second;
+    } else {
+      return ExprMutator::VisitExpr_(op);
+    }
+  }
+
  private:
-  const tvm::Map<Var, Expr>& args_map_;
+  const tvm::Map<Expr, Expr>& args_map_;
 };

-Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& args_map) {
+Expr Bind(const Expr& expr, const tvm::Map<Expr, Expr>& args_map) {
   if (const FunctionNode* func = expr.as<FunctionNode>()) {
     Expr new_body = ExprBinder(args_map).Mutate(func->body);
     Array<Var> new_params;
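
With the signature relaxed from Map<Var, Expr> to Map<Expr, Expr>, the binds map may key on a CallNode as well as a Var, so an op result embedded in an expression can be swapped for a fresh variable. A minimal sketch of that use, assuming the C++ Bind above is reachable from Python as tvm.relay.bind:

    import tvm
    from tvm import relay

    x = relay.var("x", shape=(1,), dtype="float32")
    call = relay.add(x, relay.const(1.0))          # a CallNode produced by an op
    body = relay.multiply(call, relay.const(2.0))

    fresh = relay.var("loop_var0", shape=(1,), dtype="float32")

    # Every occurrence of `call` inside `body` is rewritten to `fresh` by the
    # new ExprBinder::VisitExpr_(const CallNode*) override.
    rebound = relay.bind(body, {call: fresh})

This is the substitution the TensorFlow frontend relies on: bind_map maps each original loop expression, Var or CallNode, to a fresh loop_var, and only with this change can a CallNode key be honored.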

tests/python/frontend/tensorflow/test_control_flow.py

Lines changed: 17 additions & 0 deletions
@@ -51,6 +51,23 @@ def b(i): return tf.add(i, 1)
     check_equal(graph, tf_out)


+def test_callnode_loop_vars():
+    graph = tf.Graph()
+    with graph.as_default():
+        i = tf.add(tf.constant(0), 1)
+
+        def c(i): return tf.less(i, 10)
+
+        def b(i): return tf.add(i, 1)
+
+        r = tf.while_loop(c, b, [i])
+
+        with tf.Session() as sess:
+            tf_out = sess.run(r)
+
+    check_equal(graph, tf_out)
+
+
 def test_loop_2_vars():
     graph = tf.Graph()
     with graph.as_default():
