@@ -15,6 +15,29 @@ class TestViT(unittest.TestCase):
1515 vit = models .vision_transformer .vit_b_16 (weights = "IMAGENET1K_V1" )
1616 vit = vit .eval ()
1717 model_inputs = (torch .ones (1 , 3 , 224 , 224 ),)
18+ dynamic_shapes = (
19+ {
20+ 2 : torch .export .Dim ("height" , min = 224 , max = 455 ),
21+ 3 : torch .export .Dim ("width" , min = 224 , max = 455 ),
22+ },
23+ )
24+
25+ class DynamicViT (torch .nn .Module ):
26+ def __init__ (self ):
27+ super ().__init__ ()
28+ self .vit = models .vision_transformer .vit_b_16 (weights = "IMAGENET1K_V1" )
29+ self .vit = self .vit .eval ()
30+
31+ def forward (self , x ):
32+ x = torch .nn .functional .interpolate (
33+ x ,
34+ size = (224 , 224 ),
35+ mode = "bilinear" ,
36+ align_corners = True ,
37+ antialias = False ,
38+ )
39+ return self .vit (x )
40+
1841 all_operators = {
1942 "executorch_exir_dialects_edge__ops_aten_expand_copy_default" ,
2043 "executorch_exir_dialects_edge__ops_aten_cat_default" ,
@@ -34,7 +57,7 @@ class TestViT(unittest.TestCase):
3457 "executorch_exir_dialects_edge__ops_aten_bmm_default" ,
3558 }
3659
37- def test_fp32_vit (self ):
60+ def _test_exported_vit (self , tester , check_nots = [] ):
3861 lowerable_xnn_operators = self .all_operators - {
3962 "executorch_exir_dialects_edge__ops_aten_expand_copy_default" ,
4063 "executorch_exir_dialects_edge__ops_aten_gelu_default" ,
@@ -48,14 +71,33 @@ def test_fp32_vit(self):
4871 "executorch_exir_dialects_edge__ops_aten_bmm_default" ,
4972 }
5073 (
51- Tester (self .vit , self .model_inputs )
52- .export ()
74+ tester .export ()
5375 .to_edge ()
5476 .check (list (self .all_operators ))
5577 .partition ()
5678 .check (["torch.ops.higher_order.executorch_call_delegate" ])
5779 .check_not (list (lowerable_xnn_operators ))
80+ .check_not (check_nots )
5881 .to_executorch ()
5982 .serialize ()
6083 .run_method_and_compare_outputs ()
6184 )
85+
86+ def test_fp32_vit (self ):
87+ self ._test_exported_vit (Tester (self .vit , self .model_inputs ))
88+
89+ def test_dynamic_vit (self ):
90+ bilinear_ops = {
91+ "executorch_exir_dialects_edge__ops_aten_sub_Tensor" ,
92+ "executorch_exir_dialects_edge__ops_aten_mul_Tensor" ,
93+ "executorch_exir_dialects_edge__ops_aten_index_Tensor" ,
94+ "executorch_exir_dialects_edge__ops_aten_arange_start_step" ,
95+ "executorch_exir_dialects_edge__ops_aten__to_copy_default" ,
96+ "executorch_exir_dialects_edge__ops_aten_add_Tensor" ,
97+ "executorch_exir_dialects_edge__ops_aten_clamp_default" ,
98+ }
99+
100+ self ._test_exported_vit (
101+ Tester (self .DynamicViT (), self .model_inputs , self .dynamic_shapes ),
102+ bilinear_ops ,
103+ )
0 commit comments