@@ -39,7 +39,38 @@ def build_static_yaml():
3939 timeout: 0
4040 random_seed: 9527
4141 """
42- with open ("static_yaml.yaml" , "w" , encoding = "utf-8" ) as f :
42+ with open ("static.yaml" , "w" , encoding = "utf-8" ) as f :
43+ f .write (fake_yaml )
44+
45+ def build_benchmark_yaml ():
46+ fake_yaml = """
47+ model:
48+ name: imagenet
49+ framework: onnxrt_qlinearops
50+
51+ evaluation:
52+ performance:
53+ warmup: 1
54+ iteration: 10
55+ configs:
56+ num_of_instance: 1
57+ dataloader:
58+ batch_size: 1
59+ dataset:
60+ ImageFolder:
61+ root: /path/to/evaluation/dataset/
62+ accuracy:
63+ metric:
64+ topk: 1
65+
66+ tuning:
67+ accuracy_criterion:
68+ relative: 0.01
69+ exit_policy:
70+ timeout: 0
71+ random_seed: 9527
72+ """
73+ with open ("benchmark.yaml" , "w" , encoding = "utf-8" ) as f :
4374 f .write (fake_yaml )
4475
4576def build_dynamic_yaml ():
@@ -65,7 +96,7 @@ def build_dynamic_yaml():
6596 timeout: 0
6697 random_seed: 9527
6798 """
68- with open ("dynamic_yaml .yaml" , "w" , encoding = "utf-8" ) as f :
99+ with open ("dynamic .yaml" , "w" , encoding = "utf-8" ) as f :
69100 f .write (fake_yaml )
70101
71102def build_non_MSE_yaml ():
@@ -101,7 +132,7 @@ def build_non_MSE_yaml():
101132 timeout: 0
102133 random_seed: 9527
103134 """
104- with open ("non_MSE_yaml .yaml" , "w" , encoding = "utf-8" ) as f :
135+ with open ("non_MSE .yaml" , "w" , encoding = "utf-8" ) as f :
105136 f .write (fake_yaml )
106137
107138def eval_func (model ):
@@ -139,16 +170,18 @@ def setUpClass(self):
139170 build_static_yaml ()
140171 build_dynamic_yaml ()
141172 build_non_MSE_yaml ()
173+ build_benchmark_yaml ()
142174 export_onnx_model (self .mb_v2_model , self .mb_v2_export_path )
143175 self .mb_v2_model = onnx .load (self .mb_v2_export_path )
144176 export_onnx_model (self .rn50_model , self .rn50_export_path )
145177 self .rn50_model = onnx .load (self .rn50_export_path )
146178
147179 @classmethod
148180 def tearDownClass (self ):
149- os .remove ("static_yaml.yaml" )
150- os .remove ("dynamic_yaml.yaml" )
151- os .remove ("non_MSE_yaml.yaml" )
181+ os .remove ("static.yaml" )
182+ os .remove ("dynamic.yaml" )
183+ os .remove ("non_MSE.yaml" )
184+ os .remove ("benchmark.yaml" )
152185 os .remove (self .mb_v2_export_path )
153186 os .remove (self .rn50_export_path )
154187 shutil .rmtree ("./saved" , ignore_errors = True )
@@ -167,23 +200,31 @@ def test_adaptor(self):
167200 adaptor = FRAMEWORKS [framework ](framework_specific_info )
168201 adaptor .inspect_tensor (self .rn50_model , self .cv_dataloader , ["Conv" ])
169202
170- def test_quantizate (self ):
203+ def test_adaptor (self ):
171204 from lpot .experimental import Quantization , common
172- for fake_yaml in ["static_yaml .yaml" , "dynamic_yaml .yaml" ]:
205+ for fake_yaml in ["static .yaml" , "dynamic .yaml" ]:
173206 quantizer = Quantization (fake_yaml )
174207 quantizer .calib_dataloader = self .cv_dataloader
175208 quantizer .eval_dataloader = self .cv_dataloader
176209 quantizer .model = common .Model (self .rn50_model )
177210 q_model = quantizer ()
178211 eval_func (q_model )
179- for fake_yaml in ["non_MSE_yaml .yaml" ]:
212+ for fake_yaml in ["non_MSE .yaml" ]:
180213 quantizer = Quantization (fake_yaml )
181214 quantizer .calib_dataloader = self .cv_dataloader
182215 quantizer .eval_dataloader = self .cv_dataloader
183216 quantizer .model = common .Model (self .mb_v2_model )
184217 q_model = quantizer ()
185218 eval_func (q_model )
186219
220+ from lpot .experimental import Benchmark , common
221+ for mode in ["performance" , "accuracy" ]:
222+ fake_yaml = "benchmark.yaml"
223+ evaluator = Benchmark (fake_yaml )
224+ evaluator .b_dataloader = self .cv_dataloader
225+ evaluator .model = common .Model (self .rn50_model )
226+ evaluator (mode )
227+
187228
188229if __name__ == "__main__" :
189230 unittest .main ()
0 commit comments