-
Notifications
You must be signed in to change notification settings - Fork 3.7k
[microNPU] Add support for TFLite FULLY_CONNECTED #10345
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
18 commits
Select commit
Hold shift + click to select a range
8c0ea73
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm dab5c5e
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm b79fec2
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm ef1d576
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm d80302a
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm 7b177bd
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm a4741b9
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm 83d9ee1
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm 97cd5ce
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm 3bc81a5
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm ae6827c
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm 18bd546
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm efbe30e
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm fcca2d2
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm 06aa4d9
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm 04fd825
Merge branch 'apache:main' into pers-FCsupport
dchauhan-arm 82dc516
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm f924021
[microNPU] Add support for TFLite FULLY_CONNECTED
dchauhan-arm File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2421,5 +2421,108 @@ def verify(ext_func): | |
| verify(mod["tvmgen_default_ethos_u_main_0"]) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("ifm_shape", [(1, 14), (1, 151)]) | ||
| @pytest.mark.parametrize("ofm_channels", [32, 64]) | ||
| @pytest.mark.parametrize("use_bias", [True, False]) | ||
| @pytest.mark.parametrize("activation_function", ["RELU", "NONE"]) | ||
| def test_tflite_fully_connected( | ||
| ifm_shape, | ||
| ofm_channels, | ||
| use_bias, | ||
| activation_function, | ||
| ): | ||
| dtype = "int8" | ||
|
|
||
| def create_tflite_graph(): | ||
| class Model(tf.Module): | ||
| @tf.function | ||
| def fully_connected(self, x): | ||
| bias_shape = ofm_channels | ||
| bias = tf.constant(np.random.uniform(size=bias_shape), dtype=tf.float32) | ||
| w = tf.constant( | ||
| np.random.uniform(size=[ifm_shape[1], ofm_channels]), | ||
| dtype=tf.float32, | ||
| ) | ||
| x = tf.matmul(x, w) | ||
| if use_bias: | ||
| x = tf.nn.bias_add(x, bias) | ||
| if activation_function: | ||
| x = tf.nn.relu(x) | ||
| return x | ||
|
|
||
| model = Model() | ||
| concrete_func = model.fully_connected.get_concrete_function( | ||
| tf.TensorSpec(ifm_shape, dtype=tf.float32) | ||
| ) | ||
| # Convert the model | ||
| def representative_dataset(): | ||
| for _ in range(100): | ||
| data = np.random.rand(*tuple(ifm_shape)) | ||
| yield [data.astype(np.float32)] | ||
|
|
||
| converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func]) | ||
| converter.optimizations = [tf.lite.Optimize.DEFAULT] | ||
| converter.representative_dataset = representative_dataset | ||
| converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8] | ||
| converter.inference_input_type = tf.int8 | ||
| converter.inference_output_type = tf.int8 | ||
| tflite_model = converter.convert() | ||
| return tflite_model | ||
|
|
||
| def verify(ext_func): | ||
| op = ext_func.body.args[0] | ||
| ofm_channels = op.attrs.ofm_channels | ||
|
|
||
| # check IFM | ||
| ifm = op.args[0].checked_type | ||
| assert list(ifm.shape) == [1, 1] + list(ifm_shape) | ||
| assert str(ifm.dtype) == dtype | ||
|
|
||
| # check OFM | ||
| ofm = op.checked_type | ||
| assert list(ofm.shape) == [1, 1, 1, ofm_channels] | ||
| assert str(ofm.dtype) == dtype | ||
|
|
||
| # check weights | ||
| weights_ohwi = op.args[1].data.asnumpy() | ||
| assert str(weights_ohwi.dtype) == dtype | ||
| assert list(weights_ohwi.shape) == [ofm_channels, 1, 1, ifm_shape[1]] | ||
|
|
||
| # Check that scale_bias matches weight tensor | ||
| assert list(op.args[2].checked_type.shape)[0] == ofm_channels | ||
|
|
||
| assert list(op.attrs.padding) == [0, 0, 0, 0] | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: might also be worth checking the op name is an NPU convolution here as well |
||
| assert list(op.attrs.strides) == [1, 1] | ||
| assert list(op.attrs.dilation) == [1, 1] | ||
| if activation_function == "RELU": | ||
| assert str(op.attrs.activation) == "CLIP" | ||
|
|
||
| fc_pattern_table = [ | ||
| ( | ||
| ethosu.FullyConnectedParams.composite_name, | ||
| ethosu.qnn_fc_pattern(), | ||
| lambda pat: ethosu.FullyConnectedParams(pat).is_valid(), | ||
| ) | ||
| ] | ||
|
|
||
| tflite_graph = create_tflite_graph() | ||
| tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0) | ||
|
|
||
| mod, fc_params = relay.frontend.from_tflite( | ||
| tflite_model, | ||
| shape_dict={"input": ifm_shape}, | ||
| dtype_dict={"input": dtype}, | ||
| ) | ||
|
|
||
| mod["main"] = bind_params_by_name(mod["main"], fc_params) | ||
| mod = partition_ethosu_by_table(mod, fc_pattern_table) | ||
|
|
||
| mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite( | ||
| legalize.FullyConnectedRewriter(), mod["tvmgen_default_ethos_u_main_0"] | ||
| ) | ||
|
|
||
| verify(mod["tvmgen_default_ethos_u_main_0"]) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| pytest.main([__file__]) | ||
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect there isn't a test case that exercises this case since on line 1700 this pass runs after the no op legalizer, so the last reshape won't have a following identity op and will fall over in TE