@@ -16,6 +16,12 @@ class TestMobileNetV3(unittest.TestCase):
1616 mv3 = models .mobilenetv3 .mobilenet_v3_small (pretrained = True )
1717 mv3 = mv3 .eval ()
1818 model_inputs = (torch .ones (1 , 3 , 224 , 224 ),)
19+ dynamic_shapes = (
20+ {
21+ 2 : torch .export .Dim ("height" , min = 224 , max = 455 ),
22+ 3 : torch .export .Dim ("width" , min = 224 , max = 455 ),
23+ },
24+ )
1925
2026 all_operators = {
2127 "executorch_exir_dialects_edge__ops_aten__native_batch_norm_legit_no_training_default" ,
@@ -33,7 +39,7 @@ class TestMobileNetV3(unittest.TestCase):
3339
3440 def test_fp32_mv3 (self ):
3541 (
36- Tester (self .mv3 , self .model_inputs )
42+ Tester (self .mv3 , self .model_inputs , dynamic_shapes = self . dynamic_shapes )
3743 .export ()
3844 .to_edge ()
3945 .check (list (self .all_operators ))
@@ -42,7 +48,7 @@ def test_fp32_mv3(self):
4248 .check_not (list (self .all_operators ))
4349 .to_executorch ()
4450 .serialize ()
45- .run_method_and_compare_outputs ()
51+ .run_method_and_compare_outputs (num_runs = 5 )
4652 )
4753
4854 def test_qs8_mv3 (self ):
@@ -52,7 +58,7 @@ def test_qs8_mv3(self):
5258 ops_after_lowering = self .all_operators
5359
5460 (
55- Tester (self .mv3 , self .model_inputs )
61+ Tester (self .mv3 , self .model_inputs , dynamic_shapes = self . dynamic_shapes )
5662 .quantize (Quantize (calibrate = False ))
5763 .export ()
5864 .to_edge ()
@@ -62,5 +68,5 @@ def test_qs8_mv3(self):
6268 .check_not (list (ops_after_lowering ))
6369 .to_executorch ()
6470 .serialize ()
65- .run_method_and_compare_outputs ()
71+ .run_method_and_compare_outputs (num_runs = 5 )
6672 )
0 commit comments