1515# specific language governing permissions and limitations
1616# under the License.
1717import os
18+ import re
19+ import logging
1820from urllib .parse import urlparse
21+ import struct
1922
2023import pytest
2124import tensorflow as tf
3740 create_header_file ,
3841 mlf_extract_workspace_size_bytes ,
3942)
43+ from tvm .micro .testing .utils import aot_transport_find_message
4044
4145MLPERF_TINY_MODELS = {
4246 "kws" : {
4347 "name" : "Keyword Spotting" ,
4448 "index" : 1 ,
4549 "url" : "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/keyword_spotting/trained_models/kws_ref_model.tflite" ,
4650 "sample" : "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/keyword_spotting_int8_6.pyc.npy" ,
51+ "sample_label" : 6 ,
4752 },
4853 "vww" : {
4954 "name" : "Visual Wake Words" ,
5055 "index" : 2 ,
5156 "url" : "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/visual_wake_words/trained_models/vww_96_int8.tflite" ,
5257 "sample" : "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/visual_wake_word_int8_1.npy" ,
58+ "sample_label" : 1 ,
5359 },
5460 # Note: The reason we use quantized model with float32 I/O is
5561 # that TVM does not handle the int8 I/O correctly and accuracy
6773 "index" : 4 ,
6874 "url" : "https://github.com/mlcommons/tiny/raw/bceb91c5ad2e2deb295547d81505721d3a87d578/benchmark/training/image_classification/trained_models/pretrainedResnet_quant.tflite" ,
6975 "sample" : "https://github.com/tlc-pack/web-data/raw/main/testdata/microTVM/data/image_classification_int8_0.npy" ,
76+ "sample_label" : 0 ,
7077 },
7178}
7279
80+ EEMBC_READY_MSG = "m-ready"
81+ EEMBC_RESULT_MSG = "m-results"
82+ EEMBC_NAME_MSG = "m-name"
83+
7384
7485def mlperftiny_get_module (model_name : str ):
7586 model_url = MLPERF_TINY_MODELS [model_name ]["url" ]
@@ -114,11 +125,17 @@ def mlperftiny_get_module(model_name: str):
114125 return relay_mod , params , model_info
115126
116127
117- def get_test_data (model_name : str ) -> list :
128+ def get_test_data (model_name : str , project_type : str ) -> list :
118129 sample_url = MLPERF_TINY_MODELS [model_name ]["sample" ]
119130 url = urlparse (sample_url )
120131 sample_path = download_testdata (sample_url , os .path .basename (url .path ), module = "data" )
121- return [np .load (sample_path )]
132+ sample = np .load (sample_path )
133+ if project_type == "mlperftiny" and model_name != "ad" :
134+ sample = sample .astype (np .uint8 )
135+ sample_label = None
136+ if "sample_label" in MLPERF_TINY_MODELS [model_name ].keys ():
137+ sample_label = MLPERF_TINY_MODELS [model_name ]["sample_label" ]
138+ return [sample ], [sample_label ]
122139
123140
124141def predict_ad_labels_aot (session , aot_executor , input_data , runs_per_sample = 1 ):
@@ -148,6 +165,91 @@ def predict_ad_labels_aot(session, aot_executor, input_data, runs_per_sample=1):
148165 yield np .mean (errors ), np .median (slice_runtimes )
149166
150167
168+ def _eembc_get_name (device_transport ) -> str :
169+ """Get device name."""
170+ device_transport .write (b"name%" , timeout_sec = 5 )
171+ name_message = aot_transport_find_message (device_transport , EEMBC_NAME_MSG , timeout_sec = 5 )
172+ m = re .search (r"\[([A-Za-z0-9_]+)\]" , name_message )
173+ return m .group (1 )
174+
175+
176+ def _eembc_infer (transport , warmup : int , infer : int , timeout : int ):
177+ """Send EEMBC infer command."""
178+ cmd = f"infer { warmup } { infer } %" .encode ("UTF-8" )
179+ transport .write (cmd , timeout_sec = timeout )
180+
181+
182+ def _eembc_write_sample (device_transport , data : list , timeout : int ):
183+ """Write a sample with EEMBC compatible format."""
184+ cmd = f"db load { len (data )} %" .encode ("UTF-8" )
185+ logging .debug (f"transport write: { cmd } " )
186+ device_transport .write (cmd , timeout )
187+ aot_transport_find_message (device_transport , EEMBC_READY_MSG , timeout_sec = timeout )
188+ for item in data :
189+ if isinstance (item , float ):
190+ ba = bytearray (struct .pack ("<f" , item ))
191+ hex_array = ["%02x" % b for b in ba ]
192+ else :
193+ hex_val = format (item , "x" )
194+ # make sure hex value is in HH format
195+ if len (hex_val ) < 2 :
196+ hex_val = "0" + hex_val
197+ elif len (hex_val ) > 2 :
198+ raise ValueError (f"Hex value not in HH format: { hex_val } " )
199+ hex_array = [hex_val ]
200+
201+ for hex_val in hex_array :
202+ cmd = f"db { hex_val } %" .encode ("UTF-8" )
203+ logging .debug (f"transport write: { cmd } " )
204+ device_transport .write (cmd , timeout )
205+ aot_transport_find_message (device_transport , EEMBC_READY_MSG , timeout_sec = timeout )
206+
207+
208+ def _eembc_test_dataset (device_transport , dataset , timeout ):
209+ """Run test dataset compatible with EEMBC format."""
210+ num_correct = 0
211+ total = 0
212+ samples , labels = dataset
213+ i_counter = 0
214+ for sample in samples :
215+ label = labels [i_counter ]
216+ logging .info (f"Writing Sample { i_counter } " )
217+ _eembc_write_sample (device_transport , sample .flatten ().tolist (), timeout )
218+ _eembc_infer (device_transport , 1 , 0 , timeout )
219+ results = aot_transport_find_message (
220+ device_transport , EEMBC_RESULT_MSG , timeout_sec = timeout
221+ )
222+
223+ m = re .search (r"m\-results\-\[([A-Za-z0-9_,.]+)\]" , results )
224+ results = m .group (1 ).split ("," )
225+ results_val = [float (x ) for x in results ]
226+ results_val = np .array (results_val )
227+
228+ if np .argmax (results_val ) == label :
229+ num_correct += 1
230+ total += 1
231+ i_counter += 1
232+ return float (num_correct / total )
233+
234+
235+ def _eembc_test_dataset_ad (device_transport , dataset , timeout ):
236+ """Run test dataset compatible with EEMBC format for AD model."""
237+ samples , _ = dataset
238+ result_output = np .zeros (samples [0 ].shape [0 ])
239+
240+ for slice in range (0 , 40 ):
241+ _eembc_write_sample (device_transport , samples [0 ][slice , :].flatten ().tolist (), timeout )
242+ _eembc_infer (device_transport , 1 , 0 , timeout )
243+ results = aot_transport_find_message (
244+ device_transport , EEMBC_RESULT_MSG , timeout_sec = timeout
245+ )
246+ m = re .search (r"m\-results\-\[([A-Za-z0-9_,.]+)\]" , results )
247+ results = m .group (1 ).split ("," )
248+ results_val = [float (x ) for x in results ]
249+ result_output [slice ] = np .array (results_val )
250+ return np .average (result_output )
251+
252+
151253@pytest .mark .parametrize ("model_name" , ["kws" , "vww" , "ad" , "ic" ])
152254@pytest .mark .parametrize ("project_type" , ["host_driven" , "mlperftiny" ])
153255@tvm .testing .requires_micro
@@ -177,6 +279,7 @@ def test_mlperftiny_models(platform, board, workspace_dir, serial_number, model_
177279 else :
178280 predictor = predict_labels_aot
179281
282+ samples , labels = get_test_data (model_name , project_type )
180283 if project_type == "host_driven" :
181284 with create_aot_session (
182285 platform ,
@@ -201,7 +304,7 @@ def test_mlperftiny_models(platform, board, workspace_dir, serial_number, model_
201304 args = {
202305 "session" : session ,
203306 "aot_executor" : aot_executor ,
204- "input_data" : get_test_data ( model_name ) ,
307+ "input_data" : samples ,
205308 "runs_per_sample" : 10 ,
206309 }
207310 predicted_labels , runtimes = zip (* predictor (** args ))
@@ -267,6 +370,10 @@ def test_mlperftiny_models(platform, board, workspace_dir, serial_number, model_
267370 for i in range (len (input_shape )):
268371 input_total_size *= input_shape [i ]
269372
373+ # float input
374+ if model_name == "ad" :
375+ input_total_size *= 4
376+
270377 template_project_path = pathlib .Path (tvm .micro .get_microtvm_template_projects (platform ))
271378 project_options .update (
272379 {
@@ -295,6 +402,21 @@ def test_mlperftiny_models(platform, board, workspace_dir, serial_number, model_
295402 template_project_path , workspace_dir / "project" , model_tar_path , project_options
296403 )
297404 project .build ()
405+ project .flash ()
406+ with project .transport () as transport :
407+ aot_transport_find_message (transport , EEMBC_READY_MSG , timeout_sec = 200 )
408+ print (f"Testing { model_name } on { _eembc_get_name (transport )} using EEMBC." )
409+ assert _eembc_get_name (transport ) == "microTVM"
410+ if model_name != "ad" :
411+ accuracy = _eembc_test_dataset (transport , [samples , labels ], 100 )
412+ print (f"Model { model_name } accuracy: { accuracy } " )
413+ else :
414+ mean_error = _eembc_test_dataset_ad (transport , [samples , None ], 100 )
415+ print (
416+ f"""Model { model_name } mean error: { mean_error } .
417+ Note that this is not the final accuracy number.
418+ To calculate that, you need to use sklearn.metrics.roc_auc_score function."""
419+ )
298420
299421
300422if __name__ == "__main__" :
0 commit comments