@@ -15,6 +15,29 @@ class TestViT(unittest.TestCase):
1515 vit = models .vision_transformer .vit_b_16 (weights = "IMAGENET1K_V1" )
1616 vit = vit .eval ()
1717 model_inputs = (torch .ones (1 , 3 , 224 , 224 ),)
18+ dynamic_shapes = (
19+ {
20+ 2 : torch .export .Dim ("height" , min = 224 , max = 455 ),
21+ 3 : torch .export .Dim ("width" , min = 224 , max = 455 ),
22+ },
23+ )
24+
25+ class DynamicViT (torch .nn .Module ):
26+ def __init__ (self ):
27+ super ().__init__ ()
28+ self .vit = models .vision_transformer .vit_b_16 (weights = "IMAGENET1K_V1" )
29+ self .vit = self .vit .eval ()
30+
31+ def forward (self , x ):
32+ x = torch .nn .functional .interpolate (
33+ x ,
34+ size = (224 , 224 ),
35+ mode = "bilinear" ,
36+ align_corners = True ,
37+ antialias = False ,
38+ )
39+ return self .vit (x )
40+
1841 all_operators = {
1942 "executorch_exir_dialects_edge__ops_aten_expand_copy_default" ,
2043 "executorch_exir_dialects_edge__ops_aten_cat_default" ,
@@ -34,7 +57,8 @@ class TestViT(unittest.TestCase):
3457 "executorch_exir_dialects_edge__ops_aten_bmm_default" ,
3558 }
3659
37- def test_fp32_vit (self ):
60+ def _test_exported_vit (self , tester , check_nots = None ):
61+ check_nots = check_nots or []
3862 lowerable_xnn_operators = self .all_operators - {
3963 "executorch_exir_dialects_edge__ops_aten_expand_copy_default" ,
4064 "executorch_exir_dialects_edge__ops_aten_gelu_default" ,
@@ -48,14 +72,33 @@ def test_fp32_vit(self):
4872 "executorch_exir_dialects_edge__ops_aten_bmm_default" ,
4973 }
5074 (
51- Tester (self .vit , self .model_inputs )
52- .export ()
75+ tester .export ()
5376 .to_edge ()
5477 .check (list (self .all_operators ))
5578 .partition ()
5679 .check (["torch.ops.higher_order.executorch_call_delegate" ])
5780 .check_not (list (lowerable_xnn_operators ))
81+ .check_not (check_nots )
5882 .to_executorch ()
5983 .serialize ()
6084 .run_method_and_compare_outputs ()
6185 )
86+
87+ def test_fp32_vit (self ):
88+ self ._test_exported_vit (Tester (self .vit , self .model_inputs ))
89+
90+ def test_dynamic_vit (self ):
91+ bilinear_ops = {
92+ "executorch_exir_dialects_edge__ops_aten_sub_Tensor" ,
93+ "executorch_exir_dialects_edge__ops_aten_mul_Tensor" ,
94+ "executorch_exir_dialects_edge__ops_aten_index_Tensor" ,
95+ "executorch_exir_dialects_edge__ops_aten_arange_start_step" ,
96+ "executorch_exir_dialects_edge__ops_aten__to_copy_default" ,
97+ "executorch_exir_dialects_edge__ops_aten_add_Tensor" ,
98+ "executorch_exir_dialects_edge__ops_aten_clamp_default" ,
99+ }
100+
101+ self ._test_exported_vit (
102+ Tester (self .DynamicViT (), self .model_inputs , self .dynamic_shapes ),
103+ bilinear_ops ,
104+ )
0 commit comments