Skip to content

Commit 47a2c38

Browse files
committed
Fix zephyr/test_zephyr_armv7m
1 parent d13e2b6 commit 47a2c38

File tree

3 files changed

+125
-137
lines changed

3 files changed

+125
-137
lines changed

tests/micro/zephyr/test_utils.py

Lines changed: 80 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import pathlib
2121
import tarfile
2222
import tempfile
23-
from typing import Union
23+
import logging
2424

2525
import numpy as np
2626

@@ -31,12 +31,19 @@
3131

3232
import tvm.micro
3333
from tvm.micro import export_model_library_format
34-
from tvm.micro.testing import mlf_extract_workspace_size_bytes
34+
from tvm.micro.model_library_format import generate_c_interface_header
35+
from tvm.micro.testing import (
36+
mlf_extract_workspace_size_bytes,
37+
aot_transport_init_wait,
38+
aot_transport_find_message,
39+
)
3540

3641
TEMPLATE_PROJECT_DIR = pathlib.Path(tvm.micro.get_microtvm_template_projects("zephyr"))
3742

3843
BOARDS = TEMPLATE_PROJECT_DIR / "boards.json"
3944

45+
_LOG = logging.getLogger(__name__)
46+
4047

4148
def zephyr_boards() -> dict:
4249
"""Returns a dict mapping board to target model"""
@@ -68,29 +75,32 @@ def has_fpu(board: str):
6875
return board in fpu_boards
6976

7077

71-
def build_project(temp_dir, zephyr_board, west_cmd, mod, build_config, extra_files_tar=None):
78+
def build_project(
79+
temp_dir, zephyr_board, west_cmd, mod, build_config, simd=False, extra_files_tar=None
80+
):
7281
project_dir = temp_dir / "project"
7382

7483
with tempfile.TemporaryDirectory() as tar_temp_dir:
7584
model_tar_path = pathlib.Path(tar_temp_dir) / "model.tar"
7685
export_model_library_format(mod, model_tar_path)
7786

7887
workspace_size = mlf_extract_workspace_size_bytes(model_tar_path)
88+
project_options = {
89+
"extra_files_tar": extra_files_tar,
90+
"project_type": "aot_demo",
91+
"west_cmd": west_cmd,
92+
"verbose": bool(build_config.get("debug")),
93+
"zephyr_board": zephyr_board,
94+
"compile_definitions": [
95+
# TODO(mehrdadh): It fails without offset.
96+
f"-DWORKSPACE_SIZE={workspace_size + 128}",
97+
],
98+
}
99+
if simd:
100+
project_options["config_main_stack_size"] = 1536
101+
79102
project = tvm.micro.project.generate_project_from_mlf(
80-
str(TEMPLATE_PROJECT_DIR),
81-
project_dir,
82-
model_tar_path,
83-
{
84-
"extra_files_tar": extra_files_tar,
85-
"project_type": "aot_demo",
86-
"west_cmd": west_cmd,
87-
"verbose": bool(build_config.get("debug")),
88-
"zephyr_board": zephyr_board,
89-
"compile_definitions": [
90-
# TODO(mehrdadh): It fails without offset.
91-
f"-DWORKSPACE_SIZE={workspace_size + 128}",
92-
],
93-
},
103+
str(TEMPLATE_PROJECT_DIR), project_dir, model_tar_path, project_options
94104
)
95105
project.build()
96106
return project, project_dir
@@ -167,3 +177,56 @@ def loadCMSIS(temp_dir):
167177
urlretrieve(file_url, f"{temp_path}/{file_name}")
168178
except HTTPError as e:
169179
print(f"Failed to download {file_url}: {e}")
180+
181+
182+
def run_model(project):
183+
project.flash()
184+
185+
with project.transport() as transport:
186+
aot_transport_init_wait(transport)
187+
transport.write(b"infer%", timeout_sec=5)
188+
result_line = aot_transport_find_message(transport, "result", timeout_sec=60)
189+
190+
result_line = result_line.strip("\n")
191+
result_line = result_line.split(":")
192+
result = int(result_line[1])
193+
time = int(result_line[2])
194+
_LOG.info(f"Result: {result}\ttime: {time} ms")
195+
196+
return result, time
197+
198+
199+
def generate_project(
200+
temp_dir, board, west_cmd, lowered, build_config, sample, output_shape, output_type, load_cmsis
201+
):
202+
with tempfile.NamedTemporaryFile() as tar_temp_file:
203+
with tarfile.open(tar_temp_file.name, "w:gz") as tf:
204+
with tempfile.TemporaryDirectory() as tar_temp_dir:
205+
model_files_path = os.path.join(tar_temp_dir, "include")
206+
os.mkdir(model_files_path)
207+
if load_cmsis:
208+
loadCMSIS(model_files_path)
209+
tf.add(
210+
model_files_path, arcname=os.path.relpath(model_files_path, tar_temp_dir)
211+
)
212+
header_path = generate_c_interface_header(
213+
lowered.libmod_name, ["input_1"], ["output"], [], 0, model_files_path
214+
)
215+
tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))
216+
217+
create_header_file("input_data", sample, "include", tf)
218+
create_header_file(
219+
"output_data", np.zeros(shape=output_shape, dtype=output_type), "include", tf
220+
)
221+
222+
project, project_dir = build_project(
223+
temp_dir,
224+
board,
225+
west_cmd,
226+
lowered,
227+
build_config,
228+
simd=load_cmsis,
229+
extra_files_tar=tar_temp_file.name,
230+
)
231+
232+
return project, project_dir

tests/micro/zephyr/test_zephyr_aot.py

Lines changed: 16 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
from tvm.relay.backend import Executor, Runtime
3434

3535
from tvm.contrib.download import download_testdata
36-
from tvm.micro.model_library_format import generate_c_interface_header
3736
from tvm.micro.testing import aot_transport_init_wait, aot_transport_find_message
3837

3938
import test_utils
@@ -78,41 +77,19 @@ def test_tflite(temp_dir, board, west_cmd, tvm_debug):
7877
sample_path = download_testdata(sample_url, "keyword_spotting_int8_6.pyc.npy", module="data")
7978
sample = np.load(sample_path)
8079

81-
with tempfile.NamedTemporaryFile() as tar_temp_file:
82-
with tarfile.open(tar_temp_file.name, "w:gz") as tf:
83-
with tempfile.TemporaryDirectory() as tar_temp_dir:
84-
model_files_path = os.path.join(tar_temp_dir, "include")
85-
os.mkdir(model_files_path)
86-
header_path = generate_c_interface_header(
87-
lowered.libmod_name, ["input_1"], ["output"], [], 0, model_files_path
88-
)
89-
tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))
90-
91-
test_utils.create_header_file("input_data", sample, "include", tf)
92-
test_utils.create_header_file(
93-
"output_data", np.zeros(shape=output_shape, dtype="int8"), "include", tf
94-
)
95-
96-
project, _ = test_utils.build_project(
97-
temp_dir,
98-
board,
99-
west_cmd,
100-
lowered,
101-
build_config,
102-
extra_files_tar=tar_temp_file.name,
103-
)
80+
project, _ = test_utils.generate_project(
81+
temp_dir,
82+
board,
83+
west_cmd,
84+
lowered,
85+
build_config,
86+
sample,
87+
output_shape,
88+
"int8",
89+
load_cmsis=False,
90+
)
10491

105-
project.flash()
106-
with project.transport() as transport:
107-
aot_transport_init_wait(transport)
108-
transport.write(b"infer%", timeout_sec=5)
109-
result_line = aot_transport_find_message(transport, "result", timeout_sec=60)
110-
111-
result_line = result_line.strip("\n")
112-
result_line = result_line.split(":")
113-
result = int(result_line[1])
114-
time = int(result_line[2])
115-
logging.info(f"Result: {result}\ttime: {time} ms")
92+
result, time = test_utils.run_model(project)
11693
assert result == 6
11794

11895

@@ -140,31 +117,10 @@ def test_qemu_make_fail(temp_dir, board, west_cmd, tvm_debug):
140117
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
141118
lowered = relay.build(ir_mod, target, executor=executor, runtime=runtime)
142119

143-
# Generate input/output header files
144-
with tempfile.NamedTemporaryFile() as tar_temp_file:
145-
with tarfile.open(tar_temp_file.name, "w:gz") as tf:
146-
with tempfile.TemporaryDirectory() as tar_temp_dir:
147-
model_files_path = os.path.join(tar_temp_dir, "include")
148-
os.mkdir(model_files_path)
149-
header_path = generate_c_interface_header(
150-
lowered.libmod_name, ["input_1"], ["output"], [], 0, model_files_path
151-
)
152-
tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))
153-
test_utils.create_header_file(
154-
"input_data", np.zeros(shape=shape, dtype=dtype), "include", tf
155-
)
156-
test_utils.create_header_file(
157-
"output_data", np.zeros(shape=shape, dtype=dtype), "include", tf
158-
)
159-
160-
project, project_dir = test_utils.build_project(
161-
temp_dir,
162-
board,
163-
west_cmd,
164-
lowered,
165-
build_config,
166-
extra_files_tar=tar_temp_file.name,
167-
)
120+
sample = np.zeros(shape=shape, dtype=dtype)
121+
project, project_dir = test_utils.generate_project(
122+
temp_dir, board, west_cmd, lowered, build_config, sample, shape, dtype, load_cmsis=False
123+
)
168124

169125
file_path = (
170126
pathlib.Path(project_dir) / "build" / "zephyr" / "CMakeFiles" / "run.dir" / "build.make"

tests/micro/zephyr/test_zephyr_armv7m.py

Lines changed: 29 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
from json import load
1819
import logging
1920
import os
2021
import pathlib
@@ -32,8 +33,6 @@
3233
from tvm import relay
3334

3435
from tvm.contrib.download import download_testdata
35-
from tvm.micro.model_library_format import generate_c_interface_header
36-
from tvm.micro.testing import aot_transport_init_wait, aot_transport_find_message
3736
from tvm.relay.backend import Executor, Runtime
3837

3938
import test_utils
@@ -103,59 +102,6 @@ def _apply_desired_layout_no_simd(relay_mod):
103102
return seq(relay_mod)
104103

105104

106-
def _generate_project(temp_dir, board, west_cmd, lowered, build_config, sample, output_shape):
107-
108-
with tempfile.NamedTemporaryFile() as tar_temp_file:
109-
with tarfile.open(tar_temp_file.name, "w:gz") as tf:
110-
with tempfile.TemporaryDirectory() as tar_temp_dir:
111-
model_files_path = os.path.join(tar_temp_dir, "include")
112-
os.mkdir(model_files_path)
113-
test_utils.loadCMSIS(model_files_path)
114-
tf.add(model_files_path, arcname=os.path.relpath(model_files_path, tar_temp_dir))
115-
header_path = generate_c_interface_header(
116-
lowered.libmod_name, ["input_1"], ["output"], [], model_files_path
117-
)
118-
tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir))
119-
120-
test_utils.create_header_file("input_data", sample, "include", tf)
121-
test_utils.create_header_file(
122-
"output_data", np.zeros(shape=output_shape, dtype="float32"), "include", tf
123-
)
124-
125-
project, _ = test_utils.build_project(
126-
temp_dir,
127-
board,
128-
west_cmd,
129-
lowered,
130-
build_config,
131-
extra_files_tar=tar_temp_file.name,
132-
)
133-
134-
return project
135-
136-
137-
def _run_model(temp_dir, board, west_cmd, lowered, build_config, sample, output_shape):
138-
139-
project = _generate_project(
140-
temp_dir, board, west_cmd, lowered, build_config, sample, output_shape
141-
)
142-
143-
project.flash()
144-
145-
with project.transport() as transport:
146-
aot_transport_init_wait(transport)
147-
transport.write(b"infer%", timeout_sec=5)
148-
result_line = aot_transport_find_message(transport, "result", timeout_sec=60)
149-
150-
result_line = result_line.strip("\n")
151-
result_line = result_line.split(":")
152-
result = int(result_line[1])
153-
time = int(result_line[2])
154-
_LOG.info(f"Result: {result}\ttime: {time} ms")
155-
156-
return result, time
157-
158-
159105
@tvm.testing.requires_micro
160106
def test_armv7m_intrinsic(temp_dir, board, west_cmd, tvm_debug):
161107
"""Testing a ARM v7m SIMD extension."""
@@ -165,6 +111,7 @@ def test_armv7m_intrinsic(temp_dir, board, west_cmd, tvm_debug):
165111
"stm32f746xx_disco",
166112
"nucleo_f746zg",
167113
"nucleo_l4r5zi",
114+
"nrf5340dk_nrf5340_cpuapp",
168115
]:
169116
pytest.skip(msg="Platform does not support ARM v7m SIMD extenion.")
170117

@@ -196,16 +143,38 @@ def test_armv7m_intrinsic(temp_dir, board, west_cmd, tvm_debug):
196143
os.makedirs(temp_dir_no_simd, exist_ok=True)
197144

198145
with tvm.transform.PassContext(opt_level=3, config={"tir.disable_vectorize": True}):
199-
lowered_simd = relay.build(relay_mod_simd, target_simd, params=params)
146+
lowered_simd = relay.build(
147+
relay_mod_simd, target_simd, params=params, runtime=runtime, executor=executor
148+
)
200149
lowered_no_simd = relay.build(
201150
relay_mod_no_simd, target, params=params, runtime=runtime, executor=executor
202151
)
203-
result_simd, time_simd = _run_model(
204-
temp_dir_simd, board, west_cmd, lowered_simd, build_config, sample, output_shape
152+
153+
simd_project, _ = test_utils.generate_project(
154+
temp_dir_simd,
155+
board,
156+
west_cmd,
157+
lowered_simd,
158+
build_config,
159+
sample,
160+
output_shape,
161+
"float32",
162+
load_cmsis=True,
205163
)
206-
result_no_simd, time_no_simd = _run_model(
207-
temp_dir_no_simd, board, west_cmd, lowered_no_simd, build_config, sample, output_shape
164+
result_simd, time_simd = test_utils.run_model(simd_project)
165+
166+
no_simd_project, _ = test_utils.generate_project(
167+
temp_dir_no_simd,
168+
board,
169+
west_cmd,
170+
lowered_no_simd,
171+
build_config,
172+
sample,
173+
output_shape,
174+
"float32",
175+
load_cmsis=False,
208176
)
177+
result_no_simd, time_no_simd = test_utils.run_model(no_simd_project)
209178

210179
assert result_no_simd == result_simd == 2
211180

0 commit comments

Comments
 (0)