-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Add support for tf.assert (as no-op) and tf.no_op to TF Relay frontend. #4172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,93 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Unit tests for converting TensorFlow debugging ops to Relay.""" | ||
| import tensorflow as tf | ||
| import numpy as np | ||
| from tvm import relay | ||
| from tvm.relay.frontend.tensorflow import from_tensorflow | ||
|
|
||
| def run_relay(graph, *vars): | ||
| mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) | ||
| ex = relay.create_executor('debug', mod=mod) | ||
| return ex.evaluate()(*vars) | ||
|
|
||
| def test_assert_true(): | ||
| g = tf.Graph() | ||
| with g.as_default(): | ||
| x = tf.placeholder(tf.float32, shape=()) | ||
| assert_op = tf.Assert(tf.less_equal(x, x), ["it failed"]) | ||
|
|
||
| with tf.Session() as sess: | ||
| x_value = np.random.rand() | ||
| assert sess.run(assert_op, feed_dict={x: x_value}) is None | ||
|
|
||
| # In TVM, tf.assert is converted to a no-op which is actually a 0, | ||
| # though it should probably be none or an empty tuple. | ||
| # | ||
| # ToDo: It appears that the frontend converter gets confused here and | ||
| # entirely eliminates all operands from main(). Likely because x <= x | ||
| # is always true, so the placeholder can be eliminated. But TF doesn't | ||
| # do that, it's happening in Relay, and that optimization shouldn't | ||
| # affect the arity of the main function. We should have to pass in | ||
| # x_value here. | ||
| np.testing.assert_allclose(0, run_relay(g).asnumpy()) | ||
|
|
||
| def test_assert_true_var_capture(): | ||
| g = tf.Graph() | ||
| with g.as_default(): | ||
| x = tf.placeholder(tf.float32, shape=()) | ||
|
|
||
| # It turns out that tf.assert() creates a large and complex subgraph if | ||
| # you capture a variable as part of the error message. So we need to | ||
| # test that, too. | ||
| assert_op = tf.Assert(tf.less_equal(x, x), ["it failed", x]) | ||
|
|
||
| with tf.Session() as sess: | ||
| x_value = np.random.rand() | ||
| assert sess.run(assert_op, feed_dict={x: x_value}) is None | ||
|
|
||
| # ToDo: The frontend converter gets confused here as well, thinking | ||
| # that it needs to be told what x is twice. It also notes the output of | ||
| # the graph as a boolean, which is not correct - as you can see above, | ||
| # TF believes that the value of this graph is None. In addition, the | ||
| # arity of the translated function should be 1, not 2. | ||
| np.testing.assert_allclose(True, run_relay(g, x_value, x_value).asnumpy()) | ||
|
|
||
| def test_assert_false(): | ||
| g = tf.Graph() | ||
| with g.as_default(): | ||
| assert_op = tf.Assert(tf.constant(False), ["it failed"]) | ||
|
|
||
| with tf.Session() as sess: | ||
| try: | ||
| print(sess.run(assert_op)) | ||
| assert False # TF should have thrown an exception | ||
| except tf.errors.InvalidArgumentError as e: | ||
| assert "it failed" in e.message | ||
|
|
||
| # In TVM, tf.assert is converted to a no-op which is actually a 0, | ||
| # though it should probably be none or an empty tuple. For the same | ||
| # reason, there should not be an error here, even though the assertion | ||
| # argument is false. | ||
| np.testing.assert_allclose(0, run_relay(g).asnumpy()) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_assert_true() | ||
| test_assert_true_var_capture() | ||
| test_assert_false() | ||
|
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| # Licensed to the Apache Software Foundation (ASF) under one | ||
| # or more contributor license agreements. See the NOTICE file | ||
| # distributed with this work for additional information | ||
| # regarding copyright ownership. The ASF licenses this file | ||
| # to you under the Apache License, Version 2.0 (the | ||
| # "License"); you may not use this file except in compliance | ||
| # with the License. You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, | ||
| # software distributed under the License is distributed on an | ||
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| # KIND, either express or implied. See the License for the | ||
| # specific language governing permissions and limitations | ||
| # under the License. | ||
| """Unit tests for converting TensorFlow debugging ops to Relay.""" | ||
| import tensorflow as tf | ||
| import numpy as np | ||
| from tvm import relay | ||
| from tvm.relay.frontend.tensorflow import from_tensorflow | ||
|
|
||
| def run_relay(graph): | ||
| mod, params = from_tensorflow(graph.as_graph_def(add_shapes=True)) | ||
| ex = relay.create_executor('debug', mod=mod) | ||
| return ex.evaluate()(**params) | ||
|
|
||
| def test_no_op(): | ||
| g = tf.Graph() | ||
| with g.as_default(): | ||
| no_op = tf.no_op() | ||
| with tf.Session() as sess: | ||
| # In TF, the type of a no-op is None. | ||
| assert sess.run(no_op) is None | ||
|
|
||
| # In TVM, no-op is currently translated to 0, though it should | ||
| # probably be none or an empty tuple. | ||
| np.testing.assert_allclose(0, run_relay(g).asnumpy()) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_no_op() | ||
|
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Trying to understand the semantics of a no-op. Can there be the TF graph which looks something like this?
op1 --> NoOp --> op2
Where NoOp basically means connecting op1 and op2.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That is possible with control dependencies. The reason no_op can appear when you call tf.assert() is that it is used as a control dependency. Quite unnecessarily, but it's there.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In that case, will this implementation make sense?
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
My assumption here is that inputs are the value inputs, not control edges. If that understanding is correct, then there would never be inputs to no_op, since no_op doesn't have any operands and doesn't yield a value in TF. I had thought that this part of the code is not concerned with translating control edges, which are a concern that can apply to many Graphdef ops and isn't a specific concern for a specific op. Though maybe that's not correct?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Understood. I have also seen similar behavior with TF no-op where it doesn't have any input operands. So, current implementation should be enough.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As background, you can see this situation in this Tensorboard screenshot:
https://pasteboard.co/IDcKMM4.png
Which is the graph generated from the second assert test in this PR.