Skip to content

Commit 65eb3fe

Browse files
committed
[CI] Apply linting rules to AOT tests
This enables pylint against the AOT test cases. One issue I found was with the `tvm.testing.parameter` which breaks the naming convention rules in pylint (constants are upper case and function parameters are lower case). It may be worth a syntax similar to: ``` tvm.testing.parameter("enable_usmp", [True, False]) ```
1 parent 81b42e6 commit 65eb3fe

File tree

5 files changed

+244
-190
lines changed

5 files changed

+244
-190
lines changed

tests/lint/pylint.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,4 +21,5 @@ python3 -m pylint python/tvm --rcfile="$(dirname "$0")"/pylintrc
2121
python3 -m pylint vta/python/vta --rcfile="$(dirname "$0")"/pylintrc
2222
python3 -m pylint tests/python/unittest/test_tvmscript_type.py --rcfile="$(dirname "$0")"/pylintrc
2323
python3 -m pylint tests/python/contrib/test_cmsisnn --rcfile="$(dirname "$0")"/pylintrc
24+
python3 -m pylint tests/python/relay/aot/*.py --rcfile="$(dirname "$0")"/pylintrc
2425

tests/python/relay/aot/test_c_device_api.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,38 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17+
"""AOT with C Device API Tests"""
1718

18-
import sys
19+
import re
1920
from collections import OrderedDict
2021

2122
import numpy as np
2223
import pytest
23-
import re
24-
import tvm.testing
2524

25+
import tvm.testing
2626
from tvm import relay
2727
from tvm.ir.module import IRModule
2828
from tvm.testing.aot import AOTTestModel, generate_ref_data, compile_models
2929
from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER
3030

3131

32-
@pytest.fixture
33-
def device_api_main_func():
32+
@pytest.fixture(name="device_api_main_func")
33+
def fixture_device_api_main_func():
34+
"""Test function generator which generates C Device API calls"""
35+
3436
# Ideally we should have a sample Target registered here
3537
# but we're going to re-use this for now
3638
pytest.importorskip("ethosu.vela")
39+
40+
# pylint: disable=import-outside-toplevel
3741
import tensorflow as tf
3842
import tflite.Model
3943

4044
from tests.python.contrib.test_ethosu.infra import create_test_runner, generate_ref_data_tflite
4145
from tvm.relay.op.contrib.ethosu import partition_for_ethosu
4246

47+
# pylint: enable=import-outside-toplevel
48+
4349
tf.config.run_functions_eagerly(True)
4450

4551
class Model(tf.Module):
@@ -97,8 +103,9 @@ def compile_to_main_func(interface_api="c", use_unpacked_api=True):
97103
return compile_to_main_func
98104

99105

100-
@pytest.fixture
101-
def non_device_api_main_func():
106+
@pytest.fixture(name="non_device_api_main_func")
107+
def fixture_non_device_api_main_func():
108+
"""Test function generator which does not generate C Device API calls"""
102109
x = relay.var("x", shape=(10, 10))
103110
y = relay.var("y", shape=(1, 10))
104111
func = relay.Function([x, y], relay.multiply(x, y))
@@ -151,7 +158,7 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
151158
# We dont need to check exact input and output var names in this test.
152159
# Hence, using a regex to cover any legal I/O name.
153160
regex = re.compile(
154-
'tir\.tvm_check_return\(0, -1, tir\.call_extern\("tvmgen_default_ethos_u_main_0", \w+, \w+, device_context_ethos_u\)\)'
161+
r'tir\.tvm_check_return\(0, -1, tir\.call_extern\("tvmgen_default_ethos_u_main_0", \w+, \w+, device_context_ethos_u\)\)' # pylint: disable=line-too-long
155162
)
156163
assert regex.match(str(main_func.body[1][0][0][1]))
157164
# Close Device
@@ -171,7 +178,9 @@ def test_device_api_hooks_unpacked_api(device_api_main_func):
171178

172179

173180
@pytest.mark.skip(
174-
"Skipping this test as this is incorrectly using Arm(R) Ethos(TM)-U NPU with packed calling convention which is not supported by the NPU codegen's TIR to Runtime Hook. We need to use a different target to test this feature"
181+
"Skipping this test as this is incorrectly using Arm(R) Ethos(TM)-U NPU "
182+
+ "with packed calling convention which is not supported by the NPU codegen's "
183+
+ "TIR to Runtime Hook. We need to use a different target to test this feature"
175184
)
176185
def test_device_api_hooks_packed_api(device_api_main_func):
177186
"""Check for Device API hooks with packed internal calls"""
@@ -236,11 +245,12 @@ def test_without_device_api_packed_api(non_device_api_main_func):
236245
"""Test a graph without the Device API with the packed internal calls"""
237246

238247
main_func = non_device_api_main_func(interface_api="packed", use_unpacked_api=False)
248+
239249
assert str(main_func.body) == (
240250
'tir.tvm_call_cpacked("tvmgen_default_fused_multiply", '
241-
"tir.tvm_stack_make_array(x_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
242-
"tir.tvm_stack_make_array(y_buffer_var, tir.tvm_stack_make_shape(1, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
243-
"tir.tvm_stack_make_array(output_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), "
251+
"tir.tvm_stack_make_array(x_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " # pylint: disable=line-too-long
252+
"tir.tvm_stack_make_array(y_buffer_var, tir.tvm_stack_make_shape(1, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " # pylint: disable=line-too-long
253+
"tir.tvm_stack_make_array(output_buffer_var, tir.tvm_stack_make_shape(10, 10), tir.reinterpret((uint64)0), (uint32)2, float32(0), 0), " # pylint: disable=line-too-long
244254
"tir.reinterpret((uint64)0))\n"
245255
)
246256

tests/python/relay/aot/test_cpp_aot.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,9 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
17+
"""AOT with C++ Runtime Tests"""
1818

1919
import re
20-
import sys
2120
import textwrap
2221

2322
import numpy as np
@@ -28,13 +27,10 @@
2827
from tvm import relay
2928
from tvm.relay import backend, testing
3029
from tvm.testing.aot import generate_ref_data
31-
from tvm.micro.testing.aot_test_utils import AOT_DEFAULT_RUNNER
3230

3331

3432
def test_error_c_interface():
35-
interface_api = "c"
36-
use_unpacked_api = False
37-
test_runner = AOT_DEFAULT_RUNNER
33+
"""Checks that an error occurs when using the packed API in combination with C interface"""
3834

3935
two = relay.add(relay.const(1), relay.const(1))
4036
func = relay.Function([], two)
@@ -53,12 +49,11 @@ def test_error_c_interface():
5349
)
5450

5551

56-
enable_usmp = tvm.testing.parameter(True, False)
57-
target_kind = tvm.testing.parameter("c", "llvm")
58-
59-
52+
@pytest.mark.parametrize("enable_usmp", [True, False])
53+
@pytest.mark.parametrize("target_kind", ["c", "llvm"])
6054
def test_conv2d(enable_usmp, target_kind):
61-
RELAY_MODEL = textwrap.dedent(
55+
"""Tests compilation of convolutions"""
56+
relay_model = textwrap.dedent(
6257
"""\
6358
#[version = "0.0.5"]
6459
def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5), int8]) {
@@ -86,7 +81,7 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
8681
}
8782
"""
8883
)
89-
ir_mod = tvm.parser.fromtext(RELAY_MODEL)
84+
ir_mod = tvm.parser.fromtext(relay_model)
9085

9186
main_func = ir_mod["main"]
9287
shape_dict = {p.name_hint: p.checked_type.concrete_shape for p in main_func.params}
@@ -119,7 +114,10 @@ def @main(%data : Tensor[(1, 3, 64, 64), uint8], %weight : Tensor[(3, 3, 5, 5),
119114
assert (runner.get_output(0).asnumpy() == list(ref_outputs.values())[0]).all()
120115

121116

117+
@pytest.mark.parametrize("enable_usmp", [True, False])
118+
@pytest.mark.parametrize("target_kind", ["c", "llvm"])
122119
def test_mobilenet(enable_usmp, target_kind):
120+
"""Full network test with Mobilenet"""
123121
ir_mod, params = testing.mobilenet.get_workload(batch_size=1)
124122
data_shape = [int(x) for x in ir_mod["main"].checked_type.arg_types[0].shape]
125123
data = np.random.uniform(size=data_shape).astype("float32")
@@ -147,10 +145,11 @@ def test_mobilenet(enable_usmp, target_kind):
147145

148146

149147
def test_module_list():
150-
x = tvm.relay.var("x", tvm.relay.TensorType([1], dtype="float32"))
151-
expr = tvm.relay.add(x, tvm.relay.Constant(tvm.nd.array(np.array([1], dtype="float32"))))
148+
"""Checks the correct list of module names is generated"""
149+
input_x = tvm.relay.var("x", tvm.relay.TensorType([1], dtype="float32"))
150+
expr = tvm.relay.add(input_x, tvm.relay.Constant(tvm.nd.array(np.array([1], dtype="float32"))))
152151
mod = tvm.relay.build(
153-
tvm.IRModule.from_expr(tvm.relay.Function([x], expr)),
152+
tvm.IRModule.from_expr(tvm.relay.Function([input_x], expr)),
154153
target="c",
155154
executor=tvm.relay.backend.Executor("aot", {"interface-api": "packed"}),
156155
mod_name="unusual_module_name_fred",
@@ -177,6 +176,7 @@ def test_create_executor():
177176

178177

179178
def test_pass_wrong_device_arg():
179+
"""Ensure an error is generated if the incorrect number of devices are passed"""
180180
x = tvm.relay.var("x", tvm.relay.TensorType([1], dtype="float32"))
181181
expr = tvm.relay.add(x, tvm.relay.Constant(tvm.nd.array(np.array([1], dtype="float32"))))
182182
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
@@ -191,12 +191,12 @@ def test_pass_wrong_device_arg():
191191
mod.export_library(test_so_path, cc="gcc", options=["-std=c11"])
192192
loaded_mod = tvm.runtime.load_module(test_so_path)
193193

194-
with pytest.raises(tvm.TVMError) as cm:
194+
with pytest.raises(tvm.TVMError) as error:
195195
tvm.runtime.executor.AotModule(loaded_mod["default"](tvm.cpu(0), tvm.cpu(0)))
196196

197197
assert (
198198
"Check failed: devices_.size() == 1 (2 vs. 1) : Expect exactly 1 device passed."
199-
in str(cm.exception)
199+
in str(error.exception)
200200
)
201201
# TODO write asserts for # and type of device.
202202

0 commit comments

Comments
 (0)