Skip to content

Commit 279c240

Browse files
committed
[UnitTests] Expose TVM pytest helpers as plugin
Previously, pytest helper utilities such as automatic parametrization of `target`/`dev`, or `tvm.testing.parameter` were only available for tests within the `${TVM_HOME}/tests` directory. This PR extracts the helper utilities into an importable plugin, which can be used in external tests (e.g. one-off debugging).
1 parent ade2d4d commit 279c240

File tree

5 files changed

+87
-43
lines changed

5 files changed

+87
-43
lines changed

conftest.py

Lines changed: 1 addition & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,5 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
import pytest
18-
from pytest import ExitCode
1917

20-
import tvm
21-
import tvm.testing
22-
23-
24-
def pytest_configure(config):
25-
print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets())))
26-
print("pytest marker:", config.option.markexpr)
27-
28-
29-
@pytest.fixture
30-
def dev(target):
31-
return tvm.device(target)
32-
33-
34-
def pytest_generate_tests(metafunc):
35-
tvm.testing._auto_parametrize_target(metafunc)
36-
tvm.testing._parametrize_correlated_parameters(metafunc)
37-
38-
39-
def pytest_collection_modifyitems(config, items):
40-
tvm.testing._count_num_fixture_uses(items)
41-
tvm.testing._remove_global_fixture_definitions(items)
42-
43-
44-
def pytest_sessionfinish(session, exitstatus):
45-
# Don't exit with an error if we select a subset of tests that doesn't
46-
# include anything
47-
if session.config.option.markexpr != "":
48-
if exitstatus == ExitCode.NO_TESTS_COLLECTED:
49-
session.exitstatus = ExitCode.OK
18+
pytest_plugins = ["tvm.testing.plugin"]

pytest.ini renamed to python/tvm/testing/__init__.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,6 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
[pytest]
18-
markers =
19-
gpu: mark a test as requiring a gpu
20-
tensorcore: mark a test as requiring a tensorcore
21-
cuda: mark a test as requiring cuda
22-
opencl: mark a test as requiring opencl
23-
rocm: mark a test as requiring rocm
24-
vulkan: mark a test as requiring vulkan
25-
metal: mark a test as requiring metal
26-
llvm: mark a test as requiring llvm
17+
"""Namespace for tvm testing utilities"""
18+
19+
from .testing import *

python/tvm/testing/plugin.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
"""Pytest plugin for using tvm testing extensions.
19+
20+
TVM provides utilities for testing across all supported targets, and
21+
to more easily parametrize across many inputs. For more information
22+
on usage of these features, see documentation in the tvm.testing
23+
module.
24+
25+
These are enabled by default in all pytests provided by tvm, but may
26+
be useful externally for one-off testing. To enable, add the
27+
following line to the test script, or to the conftest.py in the same
28+
directory as the test scripts.
29+
30+
pytest_plugins = ['tvm.testing.plugin']
31+
32+
"""
33+
34+
import pytest
35+
36+
import tvm.testing.testing
37+
38+
39+
def pytest_configure(config):
40+
"""Runs at pytest configure time, defines marks to be used later."""
41+
markers = {
42+
"gpu": "mark a test as requiring a gpu",
43+
"tensorcore": "mark a test as requiring a tensorcore",
44+
"cuda": "mark a test as requiring cuda",
45+
"opencl": "mark a test as requiring opencl",
46+
"rocm": "mark a test as requiring rocm",
47+
"vulkan": "mark a test as requiring vulkan",
48+
"metal": "mark a test as requiring metal",
49+
"llvm": "mark a test as requiring llvm",
50+
}
51+
for markername, desc in markers.items():
52+
config.addinivalue_line("markers", "{}: {}".format(markername, desc))
53+
54+
print("enabled targets:", "; ".join(map(lambda x: x[0], tvm.testing.enabled_targets())))
55+
print("pytest marker:", config.option.markexpr)
56+
57+
58+
def pytest_generate_tests(metafunc):
59+
"""Called once per unit test, modifies/parametrizes it as needed."""
60+
tvm.testing.testing._auto_parametrize_target(metafunc)
61+
tvm.testing.testing._parametrize_correlated_parameters(metafunc)
62+
63+
64+
def pytest_collection_modifyitems(config, items):
65+
"""Called after all tests are chosen, currently used for bookkeeping."""
66+
# pylint: disable=unused-argument
67+
tvm.testing.testing._count_num_fixture_uses(items)
68+
tvm.testing.testing._remove_global_fixture_definitions(items)
69+
70+
71+
@pytest.fixture
72+
def dev(target):
73+
"""Give access to the device to tests that need it."""
74+
return tvm.device(target)
75+
76+
77+
def pytest_sessionfinish(session, exitstatus):
78+
# Don't exit with an error if we select a subset of tests that doesn't
79+
# include anything
80+
if session.config.option.markexpr != "":
81+
if exitstatus == pytest.ExitCode.NO_TESTS_COLLECTED:
82+
session.exitstatus = pytest.ExitCode.OK
File renamed without changes.

tests/python/unittest/test_tvm_testing_features.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ def test_num_uses_cached(self):
183183
class TestAutomaticMarks:
184184
@staticmethod
185185
def check_marks(request, target):
186-
parameter = tvm.testing._pytest_target_params([target])[0]
186+
parameter = tvm.testing.testing._pytest_target_params([target])[0]
187187
required_marks = [decorator.mark for decorator in parameter.marks]
188188
applied_marks = list(request.node.iter_markers())
189189

0 commit comments

Comments
 (0)