Skip to content

Commit f16286e

Browse files
authored
[Hexagon] Improved ergonomics of HexagonLauncher in unit tests. (#10581)
* [Hexagon] Improved ergonomics of HexagonLauncher in unit tests. The goal of this commit is to reduce/eliminate common code required through unit tests that interact with Hexagon hardware. - New testing fixtures in `tests/python/contrib/test_hexagon`. A test running on hexagon hardware should only need to use the `hexagon_session` fixture. - `rpc_server_port`: Iterates through port numbers, selecting an unused port for each unit test. Avoids needing to explicitly specify unique ports for each unit test. - `tvm_tracker`: Starts a tracker on use, exits after test. Avoids needing to manually start a tracker prior to running the unit test. - `hexagon_launcher`: Starts a `HexagonLauncher` server on use, stops server after test. Avoids needing to call `start_server()` and `stop_server()` in each test. - `hexagon_session`: Starts a hexagon session using `hexagon_laucnehr.start_session()`, exits after test. - Added `Session.upload` function, which delegates to `HexagonLauncher.upload`. Avoids needing to interact with both the launcher and the session. - Allowed `tvm.IRModule` as argument passed to `Session.load_module`, which will automatically save/upload the module, then load it. Avoids needing to handle save/upload of temporary files in each unit test. * Added default port for tracker if not already set. * Pass through None from hexagon_launcher to hexagon_session. * Updated launcher to use external tracker if specified. * Avoid setting up the local tracker unless required. * Declare previous_port as global, instead of list. * Corrected type hints. * Docstring updates
1 parent 14084f4 commit f16286e

File tree

5 files changed

+290
-223
lines changed

5 files changed

+290
-223
lines changed

python/tvm/contrib/hexagon/build.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -195,28 +195,37 @@ def start_session(self) -> Session:
195195
"timeout": 0,
196196
"key": self.HEXAGON_REMOTE_DEVICE_KEY,
197197
}
198-
return Session(hexagon_remote_kw)
198+
return Session(self, hexagon_remote_kw)
199199

200-
def load_module(self, module_name: Union[str, pathlib.Path], session: Session):
200+
def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module], session: Session):
201201
"""Load TVM module.
202202
203203
Parameters
204204
----------
205-
module_name : str or pathlib.Path
206-
Name of the module to load. It must be either a bare file name
207-
(without any path components), or a full path in the remote
208-
system. If it is a file name, the file must be placed in the
209-
remote workspace.
205+
module : Union[str, pathlib.Path, tvm.runtime.Module]
206+
207+
The module to load. If `module` is a
208+
`tvm.runtime.Module`, it will be uploaded to the remote
209+
session and loaded.
210+
211+
If the object passed is a string or pathlib.Path, it must
212+
be either a bare file name (without any path components),
213+
or a full path in the remote system. If it is a file name,
214+
the file must already have been uploaded to the remote,
215+
and be placed in the remote workspace.
216+
210217
session : Session
218+
211219
Remote session. The session must be established (via __enter__)
212220
prior to calling this function.
213221
214222
Returns
215223
-------
216224
TVMModule :
217225
TVM module object.
226+
218227
"""
219-
return session.load_module(module_name)
228+
return session.load_module(module)
220229

221230
def get_graph_executor(
222231
self, graph_json: str, module_name: Union[str, pathlib.Path], session: Session

python/tvm/contrib/hexagon/session.py

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,10 @@
1919

2020
import os
2121
import pathlib
22+
import tempfile
2223
from typing import Union
24+
25+
import tvm
2326
from tvm import rpc as _rpc
2427

2528

@@ -28,19 +31,28 @@ class Session:
2831
2932
Parameters
3033
----------
34+
launcher : HexagonLauncherRPC
35+
The launcher from which this session was started.
36+
3137
remote_kw : dict
3238
Remote configs for RPC tracker.
3339
3440
session_name : str
3541
Hexagon RPC session name.
42+
43+
remote_stack_size_bytes : int
44+
The stack size of the remote device, to be passed to
45+
tvm.contrib.hexagon.create_hexagon_session.
3646
"""
3747

3848
def __init__(
3949
self,
50+
launcher: "HexagonLauncherRPC",
4051
remote_kw: dict,
4152
session_name: str = "hexagon-rpc",
4253
remote_stack_size_bytes: int = 128 * 1024,
4354
):
55+
self._launcher = launcher
4456
self._session_name = session_name
4557
self._remote_stack_size_bytes = remote_stack_size_bytes
4658
self._remote_kw = remote_kw
@@ -74,6 +86,53 @@ def __enter__(self):
7486
def __exit__(self, exc_type, exc_value, exc_traceback):
7587
pass
7688

77-
def load_module(self, path: Union[str, pathlib.Path]):
78-
assert isinstance(path, (str, pathlib.Path)), "Invalid path type:" + str(type(path))
79-
return self._rpc.get_function("tvm.hexagon.load_module")(str(path))
89+
def upload(self, local_path: Union[str, pathlib.Path], remote_filename: str):
90+
"""Upload a local file to the remote workspace.
91+
92+
Parameters
93+
----------
94+
local_path : str or pathlib.Path
95+
Path to the local file to be copied.
96+
remote_filename : str
97+
Name of the file in the remote workspace.
98+
"""
99+
self._launcher.upload(local_path, remote_filename)
100+
101+
def load_module(self, module: Union[str, pathlib.Path, tvm.runtime.Module]):
102+
"""Load TVM module.
103+
104+
Parameters
105+
----------
106+
module : Union[str, pathlib.Path, tvm.runtime.Module]
107+
108+
The module to load. If `module` is a
109+
`tvm.runtime.Module`, it will be uploaded to the remote
110+
session and loaded.
111+
112+
If the object passed is a string or pathlib.Path, it must
113+
be either a bare file name (without any path components),
114+
or a full path in the remote system. If it is a file name,
115+
the file must already have been uploaded to the remote,
116+
and be placed in the remote workspace.
117+
118+
session : Session
119+
120+
Remote session. The session must be established (via __enter__)
121+
prior to calling this function.
122+
123+
Returns
124+
-------
125+
TVMModule :
126+
TVM module object.
127+
"""
128+
if isinstance(module, tvm.runtime.Module):
129+
with tempfile.TemporaryDirectory() as temp_dir:
130+
temp_dir = pathlib.Path(temp_dir)
131+
binary_name = "test_binary.so"
132+
binary_path = temp_dir / binary_name
133+
module.save(str(binary_path))
134+
self.upload(binary_path, binary_name)
135+
module = binary_name
136+
137+
assert isinstance(module, (str, pathlib.Path)), "Invalid path type:" + str(type(module))
138+
return self._rpc.get_function("tvm.hexagon.load_module")(str(module))

tests/python/contrib/test_hexagon/conftest.py

Lines changed: 123 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,15 @@
1919
values from testing parameters """
2020

2121
import os
22+
import random
23+
import socket
24+
from typing import Optional
25+
2226
import pytest
2327

2428
import tvm
25-
from tvm import rpc
29+
import tvm.rpc.tracker
30+
from tvm.contrib.hexagon.build import HexagonLauncher
2631

2732
HEXAGON_TOOLCHAIN = "HEXAGON_TOOLCHAIN"
2833
TVM_TRACKER_HOST = "TVM_TRACKER_HOST"
@@ -59,27 +64,135 @@ def requires_hexagon_toolchain(*args):
5964

6065

6166
@tvm.testing.fixture
62-
def android_serial_number() -> str:
63-
return os.getenv(ANDROID_SERIAL_NUMBER, default=None)
67+
def android_serial_number() -> Optional[str]:
68+
serial = os.getenv(ANDROID_SERIAL_NUMBER, default="")
69+
# Setting ANDROID_SERIAL_NUMBER to an empty string should be
70+
# equivalent to having it unset.
71+
if not serial.strip():
72+
serial = None
73+
return serial
74+
75+
76+
# NOTE on server ports:
77+
# These tests use different port numbers for the RPC server (7070 + ...).
78+
# The reason is that an RPC session cannot be gracefully closed without
79+
# triggering TIME_WAIT state on the server socket. This prevents another
80+
# server to bind to the same port until the wait time elapses.
81+
82+
listen_port_min = 2000 # Well above the privileged ports (1024 or lower)
83+
listen_port_max = 9000 # Below the search range end (port_end=9199) of RPC server
84+
previous_port = None
85+
86+
87+
def get_free_port():
88+
# https://stackoverflow.com/a/52872579/2689797
89+
def is_port_in_use(port: int) -> bool:
90+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
91+
return s.connect_ex(("localhost", port)) == 0
92+
93+
global previous_port
94+
if previous_port is None:
95+
port = random.randint(listen_port_min, listen_port_max)
96+
else:
97+
port = previous_port + 1
98+
99+
while is_port_in_use(port):
100+
port = port + 1 if port < listen_port_max else listen_port_min
101+
102+
previous_port = port
103+
return port
64104

65105

66-
@tvm.testing.fixture
67-
def tvm_tracker_host() -> str:
68-
return os.getenv(TVM_TRACKER_HOST, default=None)
106+
@pytest.fixture(scope="session")
107+
def _tracker_info() -> (str, int):
108+
env_tracker_host = os.getenv(TVM_TRACKER_HOST, default="")
109+
env_tracker_port = os.getenv(TVM_TRACKER_PORT, default="")
110+
111+
if env_tracker_host or env_tracker_port:
112+
# A tracker is already running, and we should connect to it
113+
# when running tests.
114+
assert env_tracker_host, "TVM_TRACKER_PORT is defined, but TVM_TRACKER_HOST is not"
115+
assert env_tracker_port, "TVM_TRACKER_HOST is defined, but TVM_TRACKER_PORT is not"
116+
env_tracker_port = int(env_tracker_port)
117+
118+
try:
119+
tvm.rpc.connect_tracker(env_tracker_host, env_tracker_port)
120+
except RuntimeError as exc:
121+
message = (
122+
"Could not connect to external tracker "
123+
"specified by $TVM_TRACKER_HOST and $TVM_TRACKER_PORT "
124+
f"({env_tracker_host}:{env_tracker_port})"
125+
)
126+
raise RuntimeError(message) from exc
127+
128+
yield (env_tracker_host, env_tracker_port)
129+
130+
else:
131+
# No tracker is provided to the tests, so we should start one
132+
# for the tests to use.
133+
tracker = tvm.rpc.tracker.Tracker("127.0.0.1", get_free_port())
134+
try:
135+
yield (tracker.host, tracker.port)
136+
finally:
137+
tracker.terminate()
138+
139+
140+
@pytest.fixture(scope="session")
141+
def tvm_tracker_host(_tracker_info) -> str:
142+
host, port = _tracker_info
143+
return host
144+
145+
146+
@pytest.fixture(scope="session")
147+
def tvm_tracker_port(_tracker_info) -> int:
148+
host, port = _tracker_info
149+
return port
69150

70151

71152
@tvm.testing.fixture
72-
def tvm_tracker_port() -> int:
73-
port = os.getenv(TVM_TRACKER_PORT, default=None)
74-
port = int(port) if port else None
75-
return port
153+
def rpc_server_port() -> int:
154+
return get_free_port()
76155

77156

78157
@tvm.testing.fixture
79158
def adb_server_socket() -> str:
80159
return os.getenv(ADB_SERVER_SOCKET, default="tcp:5037")
81160

82161

162+
@tvm.testing.fixture
163+
def hexagon_launcher(request, android_serial_number, rpc_server_port, adb_server_socket):
164+
if android_serial_number is None:
165+
yield None
166+
else:
167+
# Requesting these fixtures sets up a local tracker, if one
168+
# hasn't been provided to us. Delaying the evaluation of
169+
# these fixtures avoids starting a tracker unless necessary.
170+
tvm_tracker_host = request.getfixturevalue("tvm_tracker_host")
171+
tvm_tracker_port = request.getfixturevalue("tvm_tracker_port")
172+
173+
rpc_info = {
174+
"rpc_tracker_host": tvm_tracker_host,
175+
"rpc_tracker_port": tvm_tracker_port,
176+
"rpc_server_port": rpc_server_port,
177+
"adb_server_socket": adb_server_socket,
178+
}
179+
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
180+
launcher.start_server()
181+
try:
182+
yield launcher
183+
finally:
184+
launcher.stop_server()
185+
186+
187+
@tvm.testing.fixture
188+
def hexagon_session(hexagon_launcher):
189+
if hexagon_launcher is None:
190+
yield None
191+
else:
192+
with hexagon_launcher.start_session() as session:
193+
yield session
194+
195+
83196
# If the execution aborts while an RPC server is running, the python
84197
# code that is supposed to shut it dowm will never execute. This will
85198
# keep pytest from terminating (indefinitely), so add a cleanup

tests/python/contrib/test_hexagon/test_cache_read_write.py

Lines changed: 16 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ def intrin_func(ins, outs):
6363

6464

6565
@requires_hexagon_toolchain
66-
def test_cache_read_write(
67-
android_serial_number, tvm_tracker_host, tvm_tracker_port, adb_server_socket
68-
):
66+
def test_cache_read_write(hexagon_session):
6967
size = 128
7068
outer_shape = (size,)
7169
factor = 16
@@ -105,37 +103,24 @@ def test_cache_read_write(
105103
func = tvm.build(
106104
s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
107105
)
108-
temp = utils.tempdir()
109-
dso_binary = "test_binary.so"
110-
dso_binary_path = temp.relpath(dso_binary)
111-
func.save(dso_binary_path)
112106

113-
if not android_serial_number:
107+
if hexagon_session is None:
114108
pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
115109

116-
rpc_info = {
117-
"rpc_tracker_host": tvm_tracker_host,
118-
"rpc_tracker_port": tvm_tracker_port,
119-
"rpc_server_port": 7070,
120-
"adb_server_socket": adb_server_socket,
121-
}
122-
launcher = HexagonLauncher(serial_number=android_serial_number, rpc_info=rpc_info)
123-
launcher.upload(dso_binary_path, dso_binary)
124-
launcher.start_server()
125-
126-
with launcher.start_session() as sess:
127-
mod = launcher.load_module(dso_binary, sess)
128-
xt = tvm.nd.array(
129-
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
130-
)
131-
yt = tvm.nd.array(
132-
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
133-
)
134-
zt = tvm.nd.array(
135-
np.random.randint(-128, high=127, size=size, dtype=x.dtype), device=sess.device
136-
)
137-
mod["dmacpy"](xt, yt, zt)
138-
launcher.stop_server()
110+
mod = hexagon_session.load_module(func)
111+
xt = tvm.nd.array(
112+
np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
113+
device=hexagon_session.device,
114+
)
115+
yt = tvm.nd.array(
116+
np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
117+
device=hexagon_session.device,
118+
)
119+
zt = tvm.nd.array(
120+
np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
121+
device=hexagon_session.device,
122+
)
123+
mod["dmacpy"](xt, yt, zt)
139124

140125
ref = xt.numpy() + yt.numpy()
141126
np.testing.assert_equal(zt.numpy(), ref)

0 commit comments

Comments
 (0)