Skip to content

Commit 55071db

Browse files
committed
EEMBC emulation test harness
1 parent 05ff8d1 commit 55071db

File tree

1 file changed

+125
-3
lines changed

1 file changed

+125
-3
lines changed

tests/micro/common/test_mlperftiny.py

Lines changed: 125 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import os
18+
import re
19+
import logging
1820
from urllib.parse import urlparse
21+
import struct
1922

2023
import pytest
2124
import tensorflow as tf
@@ -37,19 +40,22 @@
3740
create_header_file,
3841
mlf_extract_workspace_size_bytes,
3942
)
43+
from tvm.micro.testing.utils import aot_transport_find_message
4044

4145
MLPERF_TINY_MODELS = {
4246
"kws": {
4347
"name": "Keyword Spotting",
4448
"index": 1,
4549
"url": "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/keyword_spotting/trained_models/kws_ref_model.tflite",
4650
"sample": "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/keyword_spotting_int8_6.pyc.npy",
51+
"sample_label": 6,
4752
},
4853
"vww": {
4954
"name": "Visual Wake Words",
5055
"index": 2,
5156
"url": "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite",
5257
"sample": "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/visual_wake_word_int8_1.npy",
58+
"sample_label": 1,
5359
},
5460
# Note: The reason we use quantized model with float32 I/O is
5561
# that TVM does not handle the int8 I/O correctly and accuracy
@@ -67,9 +73,14 @@
6773
"index": 4,
6874
"url": "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/image_classification/trained_models/pretrainedResnet_quant.tflite",
6975
"sample": "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy",
76+
"sample_label": 0,
7077
},
7178
}
7279

80+
EEMBC_READY_MSG = "m-ready"
81+
EEMBC_RESULT_MSG = "m-results"
82+
EEMBC_NAME_MSG = "m-name"
83+
7384

7485
def mlperftiny_get_module(model_name: str):
7586
model_url = MLPERF_TINY_MODELS[model_name]["url"]
@@ -114,11 +125,17 @@ def mlperftiny_get_module(model_name: str):
114125
return relay_mod, params, model_info
115126

116127

117-
def get_test_data(model_name: str) -> list:
128+
def get_test_data(model_name: str, project_type: str) -> list:
118129
sample_url = MLPERF_TINY_MODELS[model_name]["sample"]
119130
url = urlparse(sample_url)
120131
sample_path = download_testdata(sample_url, os.path.basename(url.path), module="data")
121-
return [np.load(sample_path)]
132+
sample = np.load(sample_path)
133+
if project_type == "mlperftiny" and model_name != "ad":
134+
sample = sample.astype(np.uint8)
135+
sample_label = None
136+
if "sample_label" in MLPERF_TINY_MODELS[model_name].keys():
137+
sample_label = MLPERF_TINY_MODELS[model_name]["sample_label"]
138+
return [sample], [sample_label]
122139

123140

124141
def predict_ad_labels_aot(session, aot_executor, input_data, runs_per_sample=1):
@@ -148,6 +165,91 @@ def predict_ad_labels_aot(session, aot_executor, input_data, runs_per_sample=1):
148165
yield np.mean(errors), np.median(slice_runtimes)
149166

150167

168+
def _eembc_get_name(device_transport) -> str:
169+
"""Get device name."""
170+
device_transport.write(b"name%", timeout_sec=5)
171+
name_message = aot_transport_find_message(device_transport, EEMBC_NAME_MSG, timeout_sec=5)
172+
m = re.search(r"\[([A-Za-z0-9_]+)\]", name_message)
173+
return m.group(1)
174+
175+
176+
def _eembc_infer(transport, warmup: int, infer: int, timeout: int):
177+
"""Send EEMBC infer command."""
178+
cmd = f"infer {warmup} {infer}%".encode("UTF-8")
179+
transport.write(cmd, timeout_sec=timeout)
180+
181+
182+
def _eembc_write_sample(device_transport, data: list, timeout: int):
183+
"""Write a sample with EEMBC compatible format."""
184+
cmd = f"db load {len(data)}%".encode("UTF-8")
185+
logging.debug(f"transport write: {cmd}")
186+
device_transport.write(cmd, timeout)
187+
aot_transport_find_message(device_transport, EEMBC_READY_MSG, timeout_sec=timeout)
188+
for item in data:
189+
if isinstance(item, float):
190+
ba = bytearray(struct.pack("<f", item))
191+
hex_array = ["%02x" % b for b in ba]
192+
else:
193+
hex_val = format(item, "x")
194+
# make sure hex value is in HH format
195+
if len(hex_val) < 2:
196+
hex_val = "0" + hex_val
197+
elif len(hex_val) > 2:
198+
raise ValueError(f"Hex value not in HH format: {hex_val}")
199+
hex_array = [hex_val]
200+
201+
for hex_val in hex_array:
202+
cmd = f"db {hex_val}%".encode("UTF-8")
203+
logging.debug(f"transport write: {cmd}")
204+
device_transport.write(cmd, timeout)
205+
aot_transport_find_message(device_transport, EEMBC_READY_MSG, timeout_sec=timeout)
206+
207+
208+
def _eembc_test_dataset(device_transport, dataset, timeout):
209+
"""Run test dataset compatible with EEMBC format."""
210+
num_correct = 0
211+
total = 0
212+
samples, labels = dataset
213+
i_counter = 0
214+
for sample in samples:
215+
label = labels[i_counter]
216+
logging.info(f"Writing Sample {i_counter}")
217+
_eembc_write_sample(device_transport, sample.flatten().tolist(), timeout)
218+
_eembc_infer(device_transport, 1, 0, timeout)
219+
results = aot_transport_find_message(
220+
device_transport, EEMBC_RESULT_MSG, timeout_sec=timeout
221+
)
222+
223+
m = re.search(r"m\-results\-\[([A-Za-z0-9_,.]+)\]", results)
224+
results = m.group(1).split(",")
225+
results_val = [float(x) for x in results]
226+
results_val = np.array(results_val)
227+
228+
if np.argmax(results_val) == label:
229+
num_correct += 1
230+
total += 1
231+
i_counter += 1
232+
return float(num_correct / total)
233+
234+
235+
def _eembc_test_dataset_ad(device_transport, dataset, timeout):
236+
"""Run test dataset compatible with EEMBC format for AD model."""
237+
samples, _ = dataset
238+
result_output = np.zeros(samples[0].shape[0])
239+
240+
for slice in range(0, 40):
241+
_eembc_write_sample(device_transport, samples[0][slice, :].flatten().tolist(), timeout)
242+
_eembc_infer(device_transport, 1, 0, timeout)
243+
results = aot_transport_find_message(
244+
device_transport, EEMBC_RESULT_MSG, timeout_sec=timeout
245+
)
246+
m = re.search(r"m\-results\-\[([A-Za-z0-9_,.]+)\]", results)
247+
results = m.group(1).split(",")
248+
results_val = [float(x) for x in results]
249+
result_output[slice] = np.array(results_val)
250+
return np.average(result_output)
251+
252+
151253
@pytest.mark.parametrize("model_name", ["kws", "vww", "ad", "ic"])
152254
@pytest.mark.parametrize("project_type", ["host_driven", "mlperftiny"])
153255
@tvm.testing.requires_micro
@@ -177,6 +279,7 @@ def test_mlperftiny_models(platform, board, workspace_dir, serial_number, model_
177279
else:
178280
predictor = predict_labels_aot
179281

282+
samples, labels = get_test_data(model_name, project_type)
180283
if project_type == "host_driven":
181284
with create_aot_session(
182285
platform,
@@ -201,7 +304,7 @@ def test_mlperftiny_models(platform, board, workspace_dir, serial_number, model_
201304
args = {
202305
"session": session,
203306
"aot_executor": aot_executor,
204-
"input_data": get_test_data(model_name),
307+
"input_data": samples,
205308
"runs_per_sample": 10,
206309
}
207310
predicted_labels, runtimes = zip(*predictor(**args))
@@ -267,6 +370,10 @@ def test_mlperftiny_models(platform, board, workspace_dir, serial_number, model_
267370
for i in range(len(input_shape)):
268371
input_total_size *= input_shape[i]
269372

373+
# float input
374+
if model_name == "ad":
375+
input_total_size *= 4
376+
270377
template_project_path = pathlib.Path(tvm.micro.get_microtvm_template_projects(platform))
271378
project_options.update(
272379
{
@@ -295,6 +402,21 @@ def test_mlperftiny_models(platform, board, workspace_dir, serial_number, model_
295402
template_project_path, workspace_dir / "project", model_tar_path, project_options
296403
)
297404
project.build()
405+
project.flash()
406+
with project.transport() as transport:
407+
aot_transport_find_message(transport, EEMBC_READY_MSG, timeout_sec=200)
408+
print(f"Testing {model_name} on {_eembc_get_name(transport)} using EEMBC.")
409+
assert _eembc_get_name(transport) == "microTVM"
410+
if model_name != "ad":
411+
accuracy = _eembc_test_dataset(transport, [samples, labels], 100)
412+
print(f"Model {model_name} accuracy: {accuracy}")
413+
else:
414+
mean_error = _eembc_test_dataset_ad(transport, [samples, None], 100)
415+
print(
416+
f"""Model {model_name} mean error: {mean_error}.
417+
Note that this is not the final accuracy number.
418+
To calculate that, you need to use sklearn.metrics.roc_auc_score function."""
419+
)
298420

299421

300422
if __name__ == "__main__":

0 commit comments

Comments
 (0)