Skip to content

Commit b874a81

Browse files
committed
add_ort_benchmark_UT
1 parent 5601a84 commit b874a81

File tree

1 file changed

+50
-9
lines changed

1 file changed

+50
-9
lines changed

test/test_adaptor_onnxrt.py

Lines changed: 50 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,38 @@ def build_static_yaml():
3939
timeout: 0
4040
random_seed: 9527
4141
"""
42-
with open("static_yaml.yaml", "w", encoding="utf-8") as f:
42+
with open("static.yaml", "w", encoding="utf-8") as f:
43+
f.write(fake_yaml)
44+
45+
def build_benchmark_yaml():
46+
fake_yaml = """
47+
model:
48+
name: imagenet
49+
framework: onnxrt_qlinearops
50+
51+
evaluation:
52+
performance:
53+
warmup: 1
54+
iteration: 10
55+
configs:
56+
num_of_instance: 1
57+
dataloader:
58+
batch_size: 1
59+
dataset:
60+
ImageFolder:
61+
root: /path/to/evaluation/dataset/
62+
accuracy:
63+
metric:
64+
topk: 1
65+
66+
tuning:
67+
accuracy_criterion:
68+
relative: 0.01
69+
exit_policy:
70+
timeout: 0
71+
random_seed: 9527
72+
"""
73+
with open("benchmark.yaml", "w", encoding="utf-8") as f:
4374
f.write(fake_yaml)
4475

4576
def build_dynamic_yaml():
@@ -65,7 +96,7 @@ def build_dynamic_yaml():
6596
timeout: 0
6697
random_seed: 9527
6798
"""
68-
with open("dynamic_yaml.yaml", "w", encoding="utf-8") as f:
99+
with open("dynamic.yaml", "w", encoding="utf-8") as f:
69100
f.write(fake_yaml)
70101

71102
def build_non_MSE_yaml():
@@ -101,7 +132,7 @@ def build_non_MSE_yaml():
101132
timeout: 0
102133
random_seed: 9527
103134
"""
104-
with open("non_MSE_yaml.yaml", "w", encoding="utf-8") as f:
135+
with open("non_MSE.yaml", "w", encoding="utf-8") as f:
105136
f.write(fake_yaml)
106137

107138
def eval_func(model):
@@ -139,16 +170,18 @@ def setUpClass(self):
139170
build_static_yaml()
140171
build_dynamic_yaml()
141172
build_non_MSE_yaml()
173+
build_benchmark_yaml()
142174
export_onnx_model(self.mb_v2_model, self.mb_v2_export_path)
143175
self.mb_v2_model = onnx.load(self.mb_v2_export_path)
144176
export_onnx_model(self.rn50_model, self.rn50_export_path)
145177
self.rn50_model = onnx.load(self.rn50_export_path)
146178

147179
@classmethod
148180
def tearDownClass(self):
149-
os.remove("static_yaml.yaml")
150-
os.remove("dynamic_yaml.yaml")
151-
os.remove("non_MSE_yaml.yaml")
181+
os.remove("static.yaml")
182+
os.remove("dynamic.yaml")
183+
os.remove("non_MSE.yaml")
184+
os.remove("benchmark.yaml")
152185
os.remove(self.mb_v2_export_path)
153186
os.remove(self.rn50_export_path)
154187
shutil.rmtree("./saved", ignore_errors=True)
@@ -167,23 +200,31 @@ def test_adaptor(self):
167200
adaptor = FRAMEWORKS[framework](framework_specific_info)
168201
adaptor.inspect_tensor(self.rn50_model, self.cv_dataloader, ["Conv"])
169202

170-
def test_quantizate(self):
203+
def test_adaptor(self):
171204
from lpot.experimental import Quantization, common
172-
for fake_yaml in ["static_yaml.yaml", "dynamic_yaml.yaml"]:
205+
for fake_yaml in ["static.yaml", "dynamic.yaml"]:
173206
quantizer = Quantization(fake_yaml)
174207
quantizer.calib_dataloader = self.cv_dataloader
175208
quantizer.eval_dataloader = self.cv_dataloader
176209
quantizer.model = common.Model(self.rn50_model)
177210
q_model = quantizer()
178211
eval_func(q_model)
179-
for fake_yaml in ["non_MSE_yaml.yaml"]:
212+
for fake_yaml in ["non_MSE.yaml"]:
180213
quantizer = Quantization(fake_yaml)
181214
quantizer.calib_dataloader = self.cv_dataloader
182215
quantizer.eval_dataloader = self.cv_dataloader
183216
quantizer.model = common.Model(self.mb_v2_model)
184217
q_model = quantizer()
185218
eval_func(q_model)
186219

220+
from lpot.experimental import Benchmark, common
221+
for mode in ["performance", "accuracy"]:
222+
fake_yaml = "benchmark.yaml"
223+
evaluator = Benchmark(fake_yaml)
224+
evaluator.b_dataloader = self.cv_dataloader
225+
evaluator.model = common.Model(self.rn50_model)
226+
evaluator(mode)
227+
187228

188229
if __name__ == "__main__":
189230
unittest.main()

0 commit comments

Comments
 (0)