diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py
index 260d2c109387..c688655887e2 100644
--- a/tests/models/encoder_decoder/vision_language/test_mllama.py
+++ b/tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -212,7 +212,7 @@ def _run_test(
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=4096,
-                     max_num_seqs=2,
+                     max_num_seqs=3,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 68d5298dfc9b..6a2e20840fcf 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1235,11 +1235,34 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def unpack_data(self,
+                    image_data: Union[List[torch.Tensor], torch.Tensor],
+                    padding_value=0) -> torch.Tensor:
+        if isinstance(image_data, torch.Tensor):
+            # torch.Tensor
+            return image_data
+        else:
+            assert isinstance(
+                image_data[0],
+                torch.Tensor), "Image data is not properly batched."
+            # List[torch.Tensor]
+            bsz = len(image_data)
+            max_length = max(t.size(0) for t in image_data)
+            trailing_dims = image_data[0].shape[1:]
+            for data in image_data:
+                cur_trailing_dims = data.shape[1:]
+                assert cur_trailing_dims == trailing_dims
+            output_tensor = torch.full((bsz, max_length, *trailing_dims),
+                                       padding_value,
+                                       dtype=image_data[0].dtype,
+                                       device=image_data[0].device)
+            for i, t in enumerate(image_data):
+                output_tensor[i, :t.size(0)] = t
+            return output_tensor
+
     def _parse_and_validate_image_input(self, **kwargs: object):
         # tensor with the same shape will be batched together by
         # MultiModalKwargs.batch, so pixel_values here can be:
-        #   - List[List[torch.Tensor]]:
-        #       with shape (num_tiles, 3, image_res, image_res)
         #   - List[torch.Tensor]:
         #       with shape (num_image, num_tiles, 3, image_res, image_res)
         #   - torch.Tensor:
@@ -1274,10 +1297,9 @@ def _parse_and_validate_image_input(self, **kwargs: object):
 
             return MllamaImagePixelInputs(
                 type="pixel_values",
-                data=pixel_values,
-                aspect_ratio_ids=aspect_ratio_ids,
-                aspect_ratio_mask=aspect_ratio_mask,
-            )
+                data=self.unpack_data(pixel_values),
+                aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
+                aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))
 
         if image_embeds is not None:
             raise NotImplementedError
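
Note: below is a minimal standalone sketch of the padding behavior that
unpack_data implements, using made-up tensor shapes (the real inputs are
image tiles of shape (num_tiles, 3, image_res, image_res)):

    import torch

    # Two hypothetical images with different tile counts. Because their
    # shapes differ, MultiModalKwargs.batch cannot stack them and leaves
    # them as a List[torch.Tensor].
    a = torch.ones(2, 3, 4, 4)  # (num_tiles=2, 3, image_res, image_res)
    b = torch.ones(4, 3, 4, 4)  # (num_tiles=4, 3, image_res, image_res)

    # unpack_data pads each tensor along dim 0 to the max length, then
    # stacks them into a single (bsz, max_num_tiles, ...) tensor.
    image_data = [a, b]
    bsz = len(image_data)
    max_length = max(t.size(0) for t in image_data)
    out = torch.full((bsz, max_length, *a.shape[1:]), 0.0)
    for i, t in enumerate(image_data):
        out[i, :t.size(0)] = t

    assert out.shape == (2, 4, 3, 4, 4)
    # Tiles beyond each image's own count stay at padding_value (0).
    assert (out[0, 2:] == 0).all()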