@@ -212,7 +212,7 @@ def _run_test(
with vllm_runner(model,
dtype=dtype,
max_model_len=4096,
-            max_num_seqs=2,
+            max_num_seqs=3,
tensor_parallel_size=tensor_parallel_size,
distributed_executor_backend=distributed_executor_backend,
limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
34 changes: 28 additions & 6 deletions vllm/model_executor/models/mllama.py
@@ -1235,11 +1235,34 @@ def sample(
next_tokens = self.sampler(logits, sampling_metadata)
return next_tokens

def unpack_data(self,
image_data: Union[List[torch.Tensor], torch.Tensor],
padding_value=0) -> torch.Tensor:
if isinstance(image_data, torch.Tensor):
# torch.Tensor
return image_data
else:
assert isinstance(
image_data[0],
torch.Tensor), "Image data is not properly batched."
# List[torch.Tensor]
bsz = len(image_data)
max_length = max(t.size(0) for t in image_data)
trailing_dims = image_data[0].shape[1:]
for data in image_data:
cur_trailing_dims = data.shape[1:]
assert cur_trailing_dims == trailing_dims
output_tensor = torch.full((bsz, max_length, *trailing_dims),
padding_value,
dtype=image_data[0].dtype,
device=image_data[0].device)
for i, t in enumerate(image_data):
output_tensor[i, :t.size(0)] = t
return output_tensor

def _parse_and_validate_image_input(self, **kwargs: object):
# tensor with the same shape will be batched together by
# MultiModalKwargs.batch, so pixel_values here can be:
Collaborator:

@DarkLight1337 @ywang96 I think it is a common problem that images have different sizes, so we need to pad them from a list of tensors with different shapes into one tensor (see these code comments for details). Are there any utility functions for this in vLLM?

Member:

This is usually done as part of the HF processor. If the HF processor doesn't do this, you can apply it manually like in Pixtral-HF.

Collaborator:

Thanks! Given that, I think it is OK to implement the unpacking in mllama.py.

# - List[List[torch.Tensor]]:
# with shape (num_tiles, 3, image_res, image_res)
# - List[torch.Tensor]:
# with shape (num_image, num_tiles, 3, image_res, image_res)
# - torch.Tensor:
@@ -1274,10 +1297,9 @@ def _parse_and_validate_image_input(self, **kwargs: object):

return MllamaImagePixelInputs(
type="pixel_values",
-            data=pixel_values,
-            aspect_ratio_ids=aspect_ratio_ids,
-            aspect_ratio_mask=aspect_ratio_mask,
-        )
+            data=self.unpack_data(pixel_values),
+            aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
+            aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))

if image_embeds is not None:
raise NotImplementedError
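
For reference, here is a minimal standalone sketch of the padding behaviour that unpack_data implements. The shapes are illustrative only (image_res is shrunk to 4 to keep the example small), and torch.nn.utils.rnn.pad_sequence is shown merely as an equivalent for this particular case, not as what the PR actually calls:

import torch
from torch.nn.utils.rnn import pad_sequence

# Two images with a different number of tiles; image_res is shrunk to 4
# here purely for illustration.
a = torch.randn(2, 3, 4, 4)   # 2 tiles
b = torch.randn(5, 3, 4, 4)   # 5 tiles

# Pad along the leading (num_tiles) dimension and stack into one batch,
# which is what unpack_data does with padding_value=0.
batched = pad_sequence([a, b], batch_first=True, padding_value=0.0)
print(batched.shape)          # torch.Size([2, 5, 3, 4, 4])

# The original tiles are preserved; the shorter entry is zero-padded.
assert torch.equal(batched[0, :2], a)
assert torch.equal(batched[1], b)
assert torch.all(batched[0, 2:] == 0)

Note that pad_sequence only covers the List[torch.Tensor] case; unpack_data additionally passes an already-batched torch.Tensor through unchanged and asserts that all entries share the same trailing dimensions.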