@@ -81,9 +81,15 @@ def forward(self, x):
8181 z = torch .add (y , z )
8282 return z
8383
84- def _test_conv1d (self , module , inputs , conv_count , quantized = False ):
84+ def _test_conv1d (
85+ self , module , inputs , conv_count , quantized = False , dynamic_shape = None
86+ ):
8587 (
86- (Tester (module , inputs ).quantize () if quantized else Tester (module , inputs ))
88+ (
89+ Tester (module , inputs , dynamic_shape ).quantize ()
90+ if quantized
91+ else Tester (module , inputs )
92+ )
8793 .export ()
8894 .check_count ({"torch.ops.aten.convolution.default" : conv_count })
8995 .to_edge ()
@@ -101,21 +107,41 @@ def _test_conv1d(self, module, inputs, conv_count, quantized=False):
101107 )
102108
103109 def test_fp16_conv1d (self ):
104- inputs = (torch .randn (1 , 2 , 4 ).to (torch .float16 ),)
105- self ._test_conv1d (self .Conv1d (dtype = torch .float16 ), inputs , conv_count = 1 )
110+ inputs = (torch .randn (2 , 2 , 4 ).to (torch .float16 ),)
111+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
112+ self ._test_conv1d (
113+ self .Conv1d (dtype = torch .float16 ),
114+ inputs ,
115+ conv_count = 1 ,
116+ dynamic_shape = dynamic_shapes ,
117+ )
106118
107119 def test_fp32_conv1d (self ):
108- inputs = (torch .randn (1 , 2 , 4 ),)
109- self ._test_conv1d (self .Conv1d (), inputs , 1 )
120+ inputs = (torch .randn (2 , 2 , 4 ),)
121+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
122+ self ._test_conv1d (self .Conv1d (), inputs , 1 , dynamic_shape = dynamic_shapes )
110123
111124 def test_fp32_conv1d_batchnorm_seq (self ):
112- inputs = (torch .randn (1 , 2 , 4 ),)
113- self ._test_conv1d (self .Conv1dBatchNormSequential (), inputs , 2 )
125+ inputs = (torch .randn (2 , 2 , 4 ),)
126+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
127+ self ._test_conv1d (
128+ self .Conv1dBatchNormSequential (), inputs , 2 , dynamic_shape = dynamic_shapes
129+ )
114130
115131 def test_qs8_conv1d (self ):
116- inputs = (torch .randn (1 , 2 , 4 ),)
117- self ._test_conv1d (self .Conv1d (), inputs , 1 , quantized = True )
132+ inputs = (torch .randn (2 , 2 , 4 ),)
133+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
134+ self ._test_conv1d (
135+ self .Conv1d (), inputs , 1 , quantized = True , dynamic_shape = dynamic_shapes
136+ )
118137
119138 def test_qs8_conv1d_batchnorm_seq (self ):
120- inputs = (torch .randn (1 , 2 , 4 ),)
121- self ._test_conv1d (self .Conv1dBatchNormSequential (), inputs , 2 , quantized = True )
139+ inputs = (torch .randn (2 , 2 , 4 ),)
140+ dynamic_shapes = ({0 : torch .export .Dim ("batch" , min = 2 , max = 10 )},)
141+ self ._test_conv1d (
142+ self .Conv1dBatchNormSequential (),
143+ inputs ,
144+ 2 ,
145+ quantized = True ,
146+ dynamic_shape = dynamic_shapes ,
147+ )
0 commit comments