@@ -1235,11 +1235,34 @@ def sample(
12351235 next_tokens = self .sampler (logits , sampling_metadata )
12361236 return next_tokens
12371237
def unpack_data(self,
                image_data: Union[List[torch.Tensor], torch.Tensor],
                padding_value: Union[int, float] = 0) -> torch.Tensor:
    """Return ``image_data`` as a single tensor, padding a list if needed.

    Args:
        image_data: Either an already-batched tensor, which is returned
            unchanged, or a non-empty list of tensors whose shapes agree
            on every dimension except the first.
        padding_value: Fill value for the positions added along dim 1 of
            the result when a list entry is shorter than the longest one.

    Returns:
        ``image_data`` itself when it is a tensor. Otherwise a new tensor
        of shape ``(len(image_data), max_length, *trailing_dims)`` where
        ``max_length`` is the largest first-dimension size in the list;
        its dtype and device are taken from the first list entry.

    Raises:
        AssertionError: If the list is empty, contains a non-tensor
            element, or its tensors disagree on the trailing dimensions.
    """
    if isinstance(image_data, torch.Tensor):
        return image_data

    # Raise explicitly rather than using ``assert`` so the validation still
    # runs under ``python -O``. Every element is checked, not only the
    # first, and an empty list is reported here instead of surfacing as an
    # IndexError on ``image_data[0]``.
    if len(image_data) == 0 or not all(
            isinstance(t, torch.Tensor) for t in image_data):
        raise AssertionError("Image data is not properly batched.")

    first = image_data[0]
    trailing_dims = first.shape[1:]
    for idx, t in enumerate(image_data):
        # Without this check a broadcastable mismatch (e.g. a trailing
        # size of 1) would be silently expanded by the slice assignment
        # below instead of being rejected.
        if t.shape[1:] != trailing_dims:
            raise AssertionError(
                f"Image data item {idx} has trailing dims "
                f"{tuple(t.shape[1:])}, expected {tuple(trailing_dims)}.")

    bsz = len(image_data)
    max_length = max(t.size(0) for t in image_data)
    output_tensor = torch.full((bsz, max_length, *trailing_dims),
                               padding_value,
                               dtype=first.dtype,
                               device=first.device)
    for i, t in enumerate(image_data):
        # Copy the real data; the remainder of the row keeps padding_value.
        output_tensor[i, :t.size(0)] = t
    return output_tensor
1262+
12381263 def _parse_and_validate_image_input (self , ** kwargs : object ):
12391264 # tensor with the same shape will be batched together by
12401265 # MultiModalKwargs.batch, so pixel_values here can be:
1241- # - List[List[torch.Tensor]]:
1242- # with shape (num_tiles, 3, image_res, image_res)
12431266 # - List[torch.Tensor]:
12441267 # with shape (num_image, num_tiles, 3, image_res, image_res)
12451268 # - torch.Tensor:
@@ -1274,10 +1297,9 @@ def _parse_and_validate_image_input(self, **kwargs: object):
12741297
12751298 return MllamaImagePixelInputs (
12761299 type = "pixel_values" ,
1277- data = pixel_values ,
1278- aspect_ratio_ids = aspect_ratio_ids ,
1279- aspect_ratio_mask = aspect_ratio_mask ,
1280- )
1300+ data = self .unpack_data (pixel_values ),
1301+ aspect_ratio_ids = self .unpack_data (aspect_ratio_ids ),
1302+ aspect_ratio_mask = self .unpack_data (aspect_ratio_mask ))
12811303
12821304 if image_embeds is not None :
12831305 raise NotImplementedError
0 commit comments