Skip to content

Commit 9f8bef1

Browse files
committed
Add support and testing for tf.assert (as no-op) and tf.no_op to TF Relay frontend.
1 parent 6f9d028 commit 9f8bef1

File tree

3 files changed

+132
-2
lines changed

3 files changed

+132
-2
lines changed

python/tvm/relay/frontend/tensorflow.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -436,6 +436,24 @@ def _impl(inputs, attr, params):
436436
return AttrCvt(op_name="copy", ignores=['message'])(inputs, attr)
437437
return _impl
438438

439+
def _assert():
440+
# ToDo: In general people want asserts to be gone from TensorFlow graphs
441+
# when they are optimizing them, so converting it to a no-op is
442+
# reasonable. However, it would be nice to have the option to keep them
443+
# once Relay gets a Halt or Assert op.
444+
return _no_op()
445+
446+
def _no_op():
447+
def _impl(inputs, attr, params):
448+
# ToDo: This should really be an op that returns nothing, which could
449+
# be represented as an empty tuple. It turns out that TVM
450+
# infrastructure doesn't like running functions that return None and
451+
# also don't like running functions that return an empty tuple. So it
452+
# doesn't work, but it should be made to work and then this could be
453+
# improved. In the mean time, it is hard to imagine a case where it
454+
# matters in any real way that a no-op is converted to a constant 0.
455+
return tvm.relay.const(0)
456+
return _impl
439457

440458
def _matmul():
441459
def _impl(inputs, attr, params):
@@ -1319,6 +1337,7 @@ def _impl(inputs, attr, params):
13191337
'All' : _reduce('all'),
13201338
'ArgMax' : _argx(_op.argmax, 'argmax'),
13211339
'ArgMin' : _argx(_op.argmin, 'argmin'),
1340+
'Assert' : _assert(),
13221341
'AvgPool' : _pooling('avg_pool'),
13231342
'BatchMatMul' : _batch_matmul(),
13241343
'BatchMatMulV2' : _batch_matmul(),
@@ -1377,6 +1396,7 @@ def _impl(inputs, attr, params):
13771396
'Mod' : _elemwise('mod'),
13781397
'Mul' : _elemwise('multiply'),
13791398
'Neg' : AttrCvt('negative'),
1399+
'NoOp' : _no_op(),
13801400
'NotEqual' : _broadcast('not_equal'),
13811401
'OneHot' : _one_hot(),
13821402
'Pack' : _pack(),
@@ -2189,8 +2209,11 @@ def _parse_param(self, key, value, name, shape):
21892209
if np_array.dtype == np.dtype(object):
21902210
# Object types are generally tensorflow DT_STRING (DecodeJpeg op).
21912211
# Just leave it as placeholder.
2192-
self._nodes[name] = [_expr.var(name, shape=shape[name], dtype='uint8')]
2193-
2212+
if shape:
2213+
var_shape = shape[name]
2214+
else:
2215+
var_shape = value.tensor.tensor_shape
2216+
self._nodes[name] = [_expr.var(name, shape=var_shape, dtype='uint8')]
21942217
return
21952218

21962219
array_ndim = len(np_array.shape)
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Unit tests for converting TensorFlow debugging ops to Relay."""
18+
import tensorflow as tf
19+
import numpy as np
20+
from tvm import relay
21+
from tvm.relay.frontend.tensorflow import from_tensorflow
22+
23+
def run_relay(graph):
24+
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
25+
ex = relay.create_executor('debug', mod=mod)
26+
return ex.evaluate()(**params)
27+
28+
def test_assert_true():
29+
g = tf.Graph()
30+
with g.as_default():
31+
x = tf.placeholder(tf.float32, shape=())
32+
assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"])
33+
34+
with tf.Session() as sess:
35+
x_value = np.random.rand()
36+
assert sess.run(assert_op, feed_dict={x: x_value}) is None
37+
38+
# In TVM, tf.assert is converted to a no-op which is actually a 0,
39+
# though it should probably be none or an empty tuple.
40+
np.testing.assert_allclose(0, run_relay(g).asnumpy())
41+
42+
def test_assert_false():
43+
g = tf.Graph()
44+
with g.as_default():
45+
assert_op = tf.Assert(tf.constant(False), ["it failed"])
46+
47+
with tf.Session() as sess:
48+
try:
49+
print(sess.run(assert_op))
50+
assert False # TF should have thrown an exception
51+
except tf.errors.InvalidArgumentError as e:
52+
assert "it failed" in e.message
53+
54+
# In TVM, tf.assert is converted to a no-op which is actually a 0,
55+
# though it should probably be none or an empty tuple. For the same
56+
# reason, there will not be an error here, even though the assertion
57+
# argument is false.
58+
np.testing.assert_allclose(0, run_relay(g).asnumpy())
59+
60+
61+
if __name__ == "__main__":
62+
test_assert_true()
63+
test_assert_false()
64+
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Unit tests for converting TensorFlow debugging ops to Relay."""
18+
import tensorflow as tf
19+
import numpy as np
20+
from tvm import relay
21+
from tvm.relay.frontend.tensorflow import from_tensorflow
22+
23+
def run_relay(graph):
24+
mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True))
25+
ex = relay.create_executor('debug', mod=mod)
26+
return ex.evaluate()(**params)
27+
28+
def test_no_op():
29+
g = tf.Graph()
30+
with g.as_default():
31+
no_op = tf.no_op()
32+
with tf.Session() as sess:
33+
# In TF, the type of a no-op is None.
34+
assert sess.run(no_op) is None
35+
36+
# In TVM, no-op is currently translated to 0, though it should
37+
# probably be none or an empty tuple.
38+
np.testing.assert_allclose(0, run_relay(g).asnumpy())
39+
40+
41+
if __name__ == "__main__":
42+
test_no_op()
43+

0 commit comments

Comments
 (0)