Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 30 additions & 5 deletions python/tvm/relay/frontend/tensorflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from .. import ir_pass
from .. import expr as _expr
from .. import op as _op
from ..expr_functor import ExprMutator

__all__ = ['from_tensorflow']

Expand Down Expand Up @@ -1414,6 +1415,27 @@ def _get_abs_layer_name(node):
# 1.x.
_control_flow_nodes = ['Merge', 'Switch', 'NextIteration', 'Exit', 'Enter', 'LoopCond']

class RewriteSubgraph(ExprMutator):
    """
    A helper class to rewrite expr in while loop function to variable.

    Parameters
    ----------
    rewrite_map : Dict[expr, expr]
        A dictionary containing a set of expr to var mappings.
    """
    def __init__(self, rewrite_map):
        # Use super() for consistency with the Python 3 style call in visit().
        super().__init__()
        self.rewrite_map = rewrite_map

    def visit(self, expr):
        # Substitute the mapped variable directly when a rewrite exists;
        # otherwise fall back to the default recursive traversal.
        if expr in self.rewrite_map:
            return self.rewrite_map[expr]
        return super().visit(expr)

def rewrite_subgraph(expr, rewrites):
    """Apply the expr-to-var substitutions in *rewrites* to *expr*."""
    mutator = RewriteSubgraph(rewrites)
    return mutator.visit(expr)

def _in_while_loop(control_flow_node_map, op_name):
"""
Check if a given control flow operator is part of a while loop execution
Expand Down Expand Up @@ -1594,14 +1616,17 @@ def _while_loop(self):
loop_vars = []
bind_map = {}
for i, var in enumerate(self.loop_vars):
assert isinstance(var, _expr.Var), repr(var)
v = tvm.relay.var("loop_var" + str(i),
type_annotation=var.type_annotation)
if not isinstance(var, _expr.Var):
var_type = ir_pass.infer_type(var).checked_type
else:
var_type = var.type_annotation

v = tvm.relay.var("loop_var" + str(i), type_annotation=var_type)
loop_vars.append(v)
bind_map[var] = v
Copy link
Member

@jroesch jroesch Apr 24, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think instead of calling bind which has specific semantics we should just add code like this:

class RewriteSubgraph(ExprMutator):
    def __init__(self, rewrite_map): 
        self.rewrite_map = rewrite_map

    def visit(self, expr):
       if expr in self.rewrite_map:
            return self.rewrite_map[expr]
       else: 
            return super().visit(expr)

def rewrite_subgraph(expr, rewrites):
     return RewriteSubgraph(rewrites).visit(expr)

then replace the call with bind to this, this will handle the general case and not just calls.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Do you think it should be a common function, or should it just be kept in the TF frontend?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it is reasonable to just put it in the TF pass for now. We could also put in the frontend/common file where utilities are kept.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've modified it as requested.


self.cond = tvm.relay.bind(self.cond, bind_map)
self.body = [tvm.relay.bind(b, bind_map) for b in self.body]
self.cond = rewrite_subgraph(self.cond, bind_map)
self.body = [rewrite_subgraph(b, bind_map) for b in self.body]

cond = tvm.relay.op.min(self.cond)

Expand Down
18 changes: 18 additions & 0 deletions tests/python/frontend/tensorflow/test_control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,23 @@ def b(i): return tf.add(i, 1)
check_equal(graph, tf_out)


def test_callnode_loop_vars():
    """while_loop whose initial loop variable is a CallNode rather than a Var."""
    graph = tf.Graph()
    with graph.as_default():
        # The start value is produced by an op (tf.add), so the frontend must
        # infer the loop-variable type instead of reading a Var annotation.
        init = tf.add(tf.constant(0), 1)

        def cond(x):
            return tf.less(x, 10)

        def body(x):
            return tf.add(x, 1)

        result = tf.while_loop(cond, body, [init])

    with tf.Session() as sess:
        tf_out = sess.run(result)

    check_equal(graph, tf_out)


def test_loop_2_vars():
graph = tf.Graph()
with graph.as_default():
Expand Down Expand Up @@ -288,6 +305,7 @@ def condition(x):
test_loop_3_vars()
test_loop_conditions()
test_loop_bodies()
test_callnode_loop_vars()

# tf.cond
test_vanilla_cond()
Expand Down