 class TestLinear(unittest.TestCase):
     def test_fp16_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-                dtype=torch.float16,
-                atol=5e-2,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    dtype=torch.float16,
+                    atol=5e-2,
+                )
 
     def test_fp32_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
 
     def test_fp32_addmm(self):
         """
@@ -63,24 +67,71 @@ def forward(self, x):
             uses_bias=True,
         )
 
+    def test_fp32_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
+
+    def test_qs8_linear_fused_relu(self):
+        class LinearReluModule(torch.nn.Module):
+            def __init__(self, in_size, out_size, use_bias):
+                super().__init__()
+                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
+
+            def forward(self, x):
+                return torch.nn.functional.relu(self.linear(x))
+
+        for use_bias in (True, False):
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: LinearReluModule(
+                        in_size,
+                        out_size,
+                        use_bias,  # noqa
+                    ),
+                    num_batch_dims=num_batch_dims,
+                    uses_bias=use_bias,
+                    quant=True,
+                )
+
     def test_qs8_linear(self):
         for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: torch.nn.Linear(
-                    in_size, out_size, bias=use_bias  # noqa
-                ),
-                uses_bias=use_bias,
-            )
+            for num_batch_dims in range(1, 3):
+                self._test_linear(
+                    lambda in_size, out_size: torch.nn.Linear(
+                        in_size, out_size, bias=use_bias  # noqa
+                    ),
+                    uses_bias=use_bias,
+                    num_batch_dims=num_batch_dims,
+                )
 
     @unittest.skip("XNNPACK currently only supports per-channel dynamic quantization.")
     def test_qd8_per_tensor_linear(self):
         for uses_bias in (False, True):
             inputs = (torch.randn(2, 4),)
             module = torch.nn.Linear(4, 5, bias=uses_bias)
+            dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)
 
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=dynamic_shapes,
                 is_per_channel=False,
                 uses_bias=uses_bias,
             )
@@ -93,6 +144,7 @@ def test_qd8_per_channel_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=uses_bias,
             )
@@ -114,7 +166,7 @@ def test_qd8_per_channel_4w_linear(self):
         qconfig = self._get_4b_dqconfig()
         input_channels = [2, 63]
         output_channels = [1, 8, 127]
-        batches = [1, 2]
+        batches = [2, 2]
         use_bias = [False, True]
 
         for bs, bias, ipc, opc in product(
@@ -129,13 +181,14 @@ def test_qd8_per_channel_4w_linear(self):
             self._test_dqlinear(
                 module,
                 inputs,
+                dynamic_shapes=({0: torch.export.Dim("batch", max=100)},),
                 is_per_channel=True,
                 uses_bias=bias,
                 qconfig=qconfig,
             )
 
     def test_qd8_per_channel_linear_parallel(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         output_size = 5
 
@@ -165,17 +218,39 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim}, {0: batch_dim})
 
         self._test_dqlinear(
             ParallelLinear(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
         )
 
+    def test_qd8_per_channel_linear_with_two_batch(self):
+        in_size = 2
+        input_size = 4
+        output_size = 5
+
+        linear = torch.nn.Linear(input_size, output_size)
+        inputs = (torch.randn(2, in_size, input_size, dtype=torch.float),)
+        batch_dim = torch.export.Dim("batch", max=100)
+        dynamic_shapes = ({0: batch_dim, 1: batch_dim},)
+
+        self._test_dqlinear(
+            linear,
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=1,
+            is_per_channel=True,
+            uses_bias=True,
+        )
+
     def test_qd8_per_channel_linear_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -203,17 +278,20 @@ def forward(self, x):
                 return b
 
         inputs = (torch.rand(in_size, input_size, dtype=torch.float),)
+        dynamic_shapes = ({0: torch.export.Dim("batch", max=100)},)
 
         self._test_dqlinear(
             LinearSequential(),
             inputs,
+            dynamic_shapes=dynamic_shapes,
             linear_count=2,
             is_per_channel=True,
             uses_bias=True,
+            atol=1e-1,
         )
 
     def test_qd8_per_channel_linear_parellel_and_sequential(self):
-        in_size = 1
+        in_size = 2
         input_size = 4
         intermediate_size = 5
         output_size = 3
@@ -252,50 +330,21 @@ def forward(self, x, y):
             torch.rand(in_size, input_size, dtype=torch.float),
             torch.rand(in_size, input_size, dtype=torch.float),
         )
+        dynamic_shapes = (
+            {0: torch.export.Dim("batch", max=100)},
+            {0: torch.export.Dim("batch2", max=100)},
+        )
 
         self._test_dqlinear(
-            LinearModule(), inputs, linear_count=3, is_per_channel=True, uses_bias=True
+            LinearModule(),
+            inputs,
+            dynamic_shapes=dynamic_shapes,
+            linear_count=3,
+            is_per_channel=True,
+            uses_bias=True,
+            atol=1e-1,
         )
 
-    def test_fp32_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-            )
-
-    def test_qs8_linear_fused_relu(self):
-        class LinearReluModule(torch.nn.Module):
-            def __init__(self, in_size, out_size, use_bias):
-                super().__init__()
-                self.linear = torch.nn.Linear(in_size, out_size, bias=use_bias)
-
-            def forward(self, x):
-                return torch.nn.functional.relu(self.linear(x))
-
-        for use_bias in (True, False):
-            self._test_linear(
-                lambda in_size, out_size: LinearReluModule(
-                    in_size,
-                    out_size,
-                    use_bias,  # noqa
-                ),
-                uses_bias=use_bias,
-                quant=True,
-            )
-
     class ManualDQLinear(torch.nn.Module):
         def __init__(
             self,
@@ -676,6 +725,7 @@ def _test_linear(
         self,
         make_module,
         uses_bias,
+        num_batch_dims=1,
         quant=False,
         dtype: torch.dtype = torch.float,
         atol=1e-03,
@@ -692,7 +742,7 @@ def _test_linear(
             )
         )
 
-        in_sizes = [1, 4, 4]
+        in_sizes = [3, 4, 4]
         input_sizes = [4, 37, 17]
         output_sizes = [4, 17, 37]
 
@@ -704,11 +754,19 @@ def _test_linear(
             in_size = int(in_sizes[i])
             input_size = int(input_sizes[i])
             output_size = int(output_sizes[i])
+            input_shape = [in_size] * num_batch_dims + [input_size]
+            print(f"Testing input_shape {input_shape} with {output_size} out_channels")
 
             module = make_module(input_size, output_size).eval().to(dtype)
-            inputs = (torch.randn(in_size, input_size).to(dtype),)
+            inputs = (torch.randn(input_shape).to(dtype),)
+            dynamic_shape = {}
+            for i in range(num_batch_dims):
+                dynamic_shape[i] = torch.export.Dim(f"batch{i}", min=2, max=in_size)
+
+            dynamic_shape = (dynamic_shape,)
+            print(dynamic_shape)
 
-            tester = Tester(module, inputs)
+            tester = Tester(module, inputs, dynamic_shapes=dynamic_shape)
 
             if quant:
                 tester.quantize()
@@ -736,10 +794,12 @@ def _test_dqlinear(
         self,
         module,
         inputs,
+        dynamic_shapes,
         linear_count=1,
         is_per_channel=False,
         uses_bias=False,
         qconfig: Optional[QuantizationConfig] = None,
+        atol=5e-02,
     ):
         aten_op, edge_op = (
             (
@@ -758,13 +818,12 @@ def _test_dqlinear(
             is_dynamic=True,
         )
 
-        tester = Tester(module, inputs)
+        tester = Tester(module, inputs, dynamic_shapes=dynamic_shapes)
         tester.quantize(Quantize(quantization_config=quant_config))
 
         tester.export()
         tester.check_count({aten_op: linear_count})
         tester.check(["torch.ops.quantized_decomposed"])
-        tester.dump_artifact()
         tester.to_edge()
         tester.check_count({edge_op: linear_count})
 
@@ -776,4 +835,4 @@ def _test_dqlinear(
 
         tester.to_executorch()
         tester.serialize()
-        tester.run_method_and_compare_outputs(atol=5e-02)
+        tester.run_method_and_compare_outputs(atol=atol)
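
A minimal standalone sketch (not part of this diff) of the torch.export dynamic-shape mechanism the new tests rely on: one dict per positional input maps a dimension index to a torch.export.Dim, mirroring the dynamic_shapes tuples added above. The module and sizes here are illustrative only.

    import torch

    # One dict per positional input; keys are dim indices, values are Dim objects.
    batch = torch.export.Dim("batch", max=100)
    dynamic_shapes = ({0: batch},)

    linear = torch.nn.Linear(4, 5).eval()
    inputs = (torch.randn(2, 4),)

    # Export with a symbolic batch dimension instead of specializing on size 2.
    exported = torch.export.export(linear, inputs, dynamic_shapes=dynamic_shapes)

    # The exported program now accepts any batch size within the declared range.
    print(exported.module()(torch.randn(7, 4)).shape)  # torch.Size([7, 5])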