|
23 | 23 | from __future__ import print_function |
24 | 24 | from functools import partial |
25 | 25 | from distutils.version import LooseVersion |
26 | | - |
| 26 | +import platform |
27 | 27 | import os |
28 | 28 | import tempfile |
29 | 29 | import typing |
@@ -1092,35 +1092,56 @@ def test_forward_quantized_convolution(): |
1092 | 1092 | ) |
1093 | 1093 |
|
1094 | 1094 | _test_tflite2_quantized_convolution( |
1095 | | - (1, 16, 10, 10), |
1096 | | - (3, 3), |
1097 | | - 2, |
| 1095 | + (2, 32, 28, 28), |
| 1096 | + (1, 1), |
| 1097 | + 16, |
1098 | 1098 | data_format="NCWH", |
1099 | 1099 | int_quant_dtype=int_quant_dtype, |
1100 | | - groups=2, |
| 1100 | + groups=8, |
1101 | 1101 | ) |
1102 | 1102 |
|
| 1103 | + if platform.machine() == "aarch64": |
| 1104 | + pytest.skip( |
| 1105 | + reason=( |
| 1106 | + "Grouped convolution type inference error for `arm_cpu`. " |
| 1107 | + "See https://github.com/apache/tvm/issues/16532" |
| 1108 | + ) |
| 1109 | + ) |
| 1110 | + |
1103 | 1111 | _test_tflite2_quantized_convolution( |
1104 | | - (2, 32, 28, 28), |
1105 | | - (1, 1), |
1106 | | - 16, |
| 1112 | + (1, 16, 10, 10), |
| 1113 | + (3, 3), |
| 1114 | + 2, |
1107 | 1115 | data_format="NCWH", |
1108 | 1116 | int_quant_dtype=int_quant_dtype, |
1109 | | - groups=8, |
| 1117 | + groups=2, |
1110 | 1118 | ) |
1111 | 1119 |
|
1112 | 1120 |
|
1113 | 1121 | def test_forward_quantized_depthwise_convolution(): |
| 1122 | + """Test qnn.conv2d depthwise compiled with TVM against TFLite reference.""" |
1114 | 1123 | for int_quant_dtype in [tf.int8, tf.int16]: |
1115 | | - _test_tflite2_quantized_depthwise_convolution( |
1116 | | - [1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1], "SAME", "NHWC", 1, int_quant_dtype |
1117 | | - ) |
1118 | 1124 | _test_tflite2_quantized_depthwise_convolution( |
1119 | 1125 | [1, 17, 17, 12], [3, 3, 12, 1], [1, 1], [2, 2], "VALID", "NHWC", 1, int_quant_dtype |
1120 | 1126 | ) |
1121 | 1127 | _test_tflite2_quantized_depthwise_convolution( |
1122 | 1128 | [1, 24, 24, 3], [7, 7, 3, 8], [1, 1], [2, 2], "SAME", "NHWC", 8, int_quant_dtype |
1123 | 1129 | ) |
| 1130 | + _test_tflite2_quantized_depthwise_convolution( |
| 1131 | + [1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1], "SAME", "NHWC", 1, tf.int8 |
| 1132 | + ) |
| 1133 | + |
| 1134 | + if platform.machine() == "aarch64": |
| 1135 | + pytest.skip( |
| 1136 | + reason=( |
| 1137 | + "Tensor intrinsic data type mismatch error. " |
| 1138 | + "See https://github.com/apache/tvm/issues/16533" |
| 1139 | + ) |
| 1140 | + ) |
| 1141 | + |
| 1142 | + _test_tflite2_quantized_depthwise_convolution( |
| 1143 | + [1, 8, 8, 128], [1, 1, 128, 1], [1, 1], [1, 1], "SAME", "NHWC", 1, tf.int16 |
| 1144 | + ) |
1124 | 1145 |
|
1125 | 1146 |
|
1126 | 1147 | def _test_tflite2_quantized_depthwise_convolution( |
@@ -5090,6 +5111,10 @@ def test_forward_qnn_mobilenet_v3_net(): |
5090 | 5111 | tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) |
5091 | 5112 |
|
5092 | 5113 |
|
| 5114 | +@pytest.mark.skipif( |
| 5115 | + platform.machine() == "aarch64", |
| 5116 | + reason="Fails with an output mismatch. See https://github.com/apache/tvm/issues/16534", |
| 5117 | +) |
5093 | 5118 | def test_forward_tflite2_qnn_resnet50(): |
5094 | 5119 | """Test the Quantized TFLite version 2.1.0 Resnet50 model.""" |
5095 | 5120 | if package_version.parse(tf.VERSION) >= package_version.parse("2.1.0"): |
@@ -5186,6 +5211,11 @@ def test_forward_tflite_float16(): |
5186 | 5211 | tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) |
5187 | 5212 |
|
5188 | 5213 |
|
| 5214 | +@pytest.mark.skipif( |
| 5215 | + platform.machine() == "aarch64", |
| 5216 | + reason="Fails during leagalization due to int16 datatype. " |
| 5217 | + "See https://github.com/apache/tvm/issues/16535", |
| 5218 | +) |
5189 | 5219 | def test_forward_mobilenet_int16(): |
5190 | 5220 | """Test int16 quantized model""" |
5191 | 5221 | # MobilenetV2 |
@@ -5228,6 +5258,11 @@ def representative_dataset(): |
5228 | 5258 | tvm.testing.assert_allclose(tvm_sorted_labels, tflite_sorted_labels) |
5229 | 5259 |
|
5230 | 5260 |
|
| 5261 | +@pytest.mark.skipif( |
| 5262 | + platform.machine() == "aarch64", |
| 5263 | + reason="Fails during leagalization due to int16 datatype. " |
| 5264 | + "See https://github.com/apache/tvm/issues/16535", |
| 5265 | +) |
5231 | 5266 | def test_forward_ds_cnn_int16(): |
5232 | 5267 | """Test DS_CNN int16 quantized model""" |
5233 | 5268 | tflite_model_file = download_testdata( |
|
0 commit comments