-from typing import Iterable, List, NamedTuple, Optional, Sequence, Tuple
+from typing import Any, Iterable, List, NamedTuple, Optional, Sequence, Tuple
 
 import torch
 
@@ -18,6 +18,11 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
     # is the dynamic batch dimension. Otherwise, we use the additional
     # inputs to determine the batch dimension.
     if additional_inputs is None:
+        batch_dims = None
+        if not isinstance(inputs, torch.Tensor) and len(inputs) > 1:
+            bs = inputs[0].size(0)
+            if not all(x.size(0) == bs for x in inputs):
+                batch_dims = InputTensorSpec.find_batch_size_dim(inputs)
         return InputTensorSpec.from_tensors_with_dynamic_batch_size(
             inputs,
             (
@@ -26,6 +31,7 @@ def generate_input_specs(inputs, lower_setting, additional_inputs=None):
                 lower_setting.max_batch_size,
             ),
             lower_setting.opt_profile_replica,
+            batch_dims,
        )
    else:
        batch_dims = []
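
The new branch above only computes batch_dims when the sample inputs disagree on their leading dimension; otherwise it stays None and the leading-dim default applies downstream. A minimal sketch of that decision (not part of the diff; the tensor shapes are invented for illustration):

    import torch

    # First dims agree (8 == 8): batch_dims stays None.
    inputs = [torch.randn(8, 16), torch.randn(8, 4)]
    bs = inputs[0].size(0)
    assert all(x.size(0) == bs for x in inputs)

    # First dims disagree (8 vs 4): fall back to find_batch_size_dim,
    # which here picks 8 (the most frequent dim value across the shapes)
    # and returns [0, 1].
    inputs = [torch.randn(8, 16), torch.randn(4, 8)]
    bs = inputs[0].size(0)
    assert not all(x.size(0) == bs for x in inputs)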
@@ -147,25 +153,69 @@ def from_tensors_with_dynamic_batch_size(
             A list of InputTensorSpec named tuples with dynamic ranges.
         """
         if batch_dims is None:
-            batch_dims = [0] * len(tensors)
+            batch_dims = cls.find_batch_size_dim(tensors)
 
         input_specs = []
         batch_size = tensors[0].size(batch_dims[0])
 
         for i, tensor in enumerate(tensors):
             batch_dim = batch_dims[i]
-            assert batch_size == tensor.size(
-                batch_dim
-            ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
-            shape = list(tensor.shape)
-            shape[batch_dim] = -1
-            shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
-            input_specs.append(
-                cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
-            )
+            if batch_dim == -1:
+                input_specs.append(cls.from_tensor(tensor))
+            else:
+                shape = list(tensor.shape)
+                assert batch_size == tensor.size(
+                    batch_dim
+                ), f"The {i}th tensor (shape: {tensor.shape}) doesn't have the correct batch size: {batch_size}."
+                shape[batch_dim] = -1
+                shape_ranges: List[ShapeRange] = [tuple(tuple(shape[0:batch_dim] + [bs] + shape[batch_dim + 1 :]) for bs in batch_size_range)] * opt_profile_replica  # type: ignore[list-item]
+                input_specs.append(
+                    cls(tuple(shape), tensor.dtype, tensor.device, shape_ranges)
+                )
 
         return input_specs
 
+    @classmethod
+    # pyre-ignore [2]: Parameter `inputs` must have a type other than `Any`
+    def find_batch_size_dim(cls, inputs: Any) -> List[int]:
+        if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
+            return [0]
+        shapes = [i.shape for i in inputs]
+        frequency_map = {}
+        first_dims = set()
+        for shape in shapes:
+            if len(shape) < 2:
+                # Skip rank-1 tensors: in the MRS model they carry no batch_size info.
+                continue
+            # Dedup dimension values within a single tensor's shape.
+            first_dims.add(shape[0])
+            shape = set(shape)
+            for i in shape:
+                frequency_map[i] = frequency_map.get(i, 0) + 1
+
+        if len(first_dims) == 1:
+            # The first dim is the same in every input: use it as the batch size.
+            batch_size = first_dims.pop()
+        elif frequency_map:
+            # First dims differ: use the most frequent dim value as the batch size.
+            sorted_frequency = sorted(frequency_map.items(), key=lambda x: -x[1])
+            batch_size = sorted_frequency[0][0]
+        else:
+            # No dims to sort: no batch size.
+            batch_size = -1
+
+        bs_dim = []
+        for i in inputs:
+            # Default batch size dim is -1, indicating no batch size.
+            dim = -1
+            for index, val in enumerate(i.shape):
+                if val == batch_size:
+                    dim = index
+                    break
+            bs_dim.append(dim)
+
+        return bs_dim
+
     def to_random_tensor(self, id=1):
         shape = tuple(self.shape)
         if len(get_dynamic_dims(shape)):
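
As a sanity check on the new heuristic end to end, a small hedged example (shapes are invented; only find_batch_size_dim and from_tensor from the code above are assumed):

    import torch

    inputs = [torch.randn(8, 3), torch.randn(8, 5), torch.randn(2)]
    # The rank-1 tensor is skipped, every remaining first dim is 8, so 8
    # becomes the batch size:
    # InputTensorSpec.find_batch_size_dim(inputs) == [0, 0, -1]

    # In from_tensors_with_dynamic_batch_size, the -1 entry routes the
    # rank-1 tensor through cls.from_tensor(tensor) (a static spec), while
    # the other two tensors get dynamic shape_ranges over axis 0.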