From 0c7b94309474150137c585516f665e59abe3a88f Mon Sep 17 00:00:00 2001
From: yan ma
Date: Mon, 17 Mar 2025 00:04:30 +0800
Subject: [PATCH 1/4] fix mllama multi-image

Signed-off-by: yan ma
---
 .../vision_language/test_mllama.py   |  2 +-
 vllm/model_executor/models/mllama.py | 84 ++++++++++++++++++-
 2 files changed, 81 insertions(+), 5 deletions(-)

diff --git a/tests/models/encoder_decoder/vision_language/test_mllama.py b/tests/models/encoder_decoder/vision_language/test_mllama.py
index 260d2c109387..c688655887e2 100644
--- a/tests/models/encoder_decoder/vision_language/test_mllama.py
+++ b/tests/models/encoder_decoder/vision_language/test_mllama.py
@@ -212,7 +212,7 @@ def _run_test(
     with vllm_runner(model,
                      dtype=dtype,
                      max_model_len=4096,
-                     max_num_seqs=2,
+                     max_num_seqs=3,
                      tensor_parallel_size=tensor_parallel_size,
                      distributed_executor_backend=distributed_executor_backend,
                      limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 68d5298dfc9b..cd4c7cc0b687 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1235,6 +1235,83 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
+    def unpack_data(self, image_data: Union[List[List[torch.Tensor]],
+                                            List[torch.Tensor], torch.Tensor]):
+        if isinstance(image_data, torch.Tensor):
+            # torch.Tensor
+            return image_data
+        elif isinstance(image_data[0], torch.Tensor):
+            bsz = len(image_data)
+            # List[torch.Tensor]
+            if image_data[0].ndim == 1:
+                # input: [tensor([6, 6], device='cuda:0'),
+                #         tensor([6], device='cuda:0')]
+                # output: tensor([[6, 6], [6, 0]], device='cuda:0')
+                max_num_elements = max(tensor.numel() for tensor in image_data)
+                output_tensor = torch.zeros(bsz,
+                                            max_num_elements,
+                                            device=image_data[0].device,
+                                            dtype=image_data[0].dtype)
+                for b in range(bsz):
+                    original_data = image_data[b]
+                    num_original = original_data.numel()
+                    output_tensor[b, :num_original] = original_data
+                return output_tensor
+
+            assert image_data[0].ndim == 2
+            # input: [tensor([[1, 1, 1, 1], [1, 1, 1, 1]], device='cuda:0'),
+            #         tensor([[1, 1, 1, 1]], device='cuda:0')]
+            # output:
+            # tensor([[[1, 1, 1, 1],
+            #          [1, 1, 1, 1]],
+            #         [[1, 1, 1, 1],
+            #          [0, 0, 0, 0]]], device='cuda:0')
+            bsz = len(image_data)
+            max_num_elements = max(tensor.shape[0] for tensor in image_data)
+            output_tensor = torch.zeros(bsz,
+                                        max_num_elements,
+                                        image_data[0].shape[1],
+                                        device=image_data[0].device,
+                                        dtype=image_data[0].dtype)
+            for b in range(bsz):
+                original_data = image_data[b]
+                num_original = original_data.shape[0]
+                output_tensor[b, :num_original, :] = original_data
+            return output_tensor
+        else:
+            return image_data
+
+    def unpack_pixel_values(self, pixel_values: Union[List[List[torch.Tensor]],
+                                                      List[torch.Tensor],
+                                                      torch.Tensor]):
+        if isinstance(pixel_values, torch.Tensor):
+            return pixel_values
+        else:
+            pixel_values_unpacked = []
+            for b in range(len(pixel_values)):
+                pixel_values_unpacked_b = []
+                for i in range(len(pixel_values[b])):
+                    pixel_values_unpacked_b.append(pixel_values[b][i])
+                pixel_values_unpacked.append(pixel_values_unpacked_b)
+
+            max_num_images = max([len(x) for x in pixel_values_unpacked])
+            max_num_chunks = max(
+                max([len(x) for x in y]) for y in pixel_values_unpacked)
+            bsz = len(pixel_values_unpacked)
+            out_images = torch.zeros(bsz,
+                                     max_num_images,
+                                     max_num_chunks,
+                                     3,
+                                     self.image_size,
+                                     self.image_size,
+                                     device=pixel_values[0].device,
+                                     dtype=pixel_values[0].dtype)
+            for b in range(len(pixel_values_unpacked)):
+                for i in range(len(pixel_values_unpacked[b])):
+                    img = pixel_values_unpacked[b][i]
+                    out_images[b, i, :img.shape[0]] = img
+            return out_images
+
     def _parse_and_validate_image_input(self, **kwargs: object):
         # tensor with the same shape will be batched together by
         # MultiModalKwargs.batch, so pixel_values here can be:
@@ -1274,10 +1351,9 @@ def _parse_and_validate_image_input(self, **kwargs: object):
 
             return MllamaImagePixelInputs(
                 type="pixel_values",
-                data=pixel_values,
-                aspect_ratio_ids=aspect_ratio_ids,
-                aspect_ratio_mask=aspect_ratio_mask,
-            )
+                data=self.unpack_pixel_values(pixel_values),
+                aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
+                aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))
 
         if image_embeds is not None:
             raise NotImplementedError
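The padding scheme introduced above can be exercised on its own. The sketch below is illustrative only (plain PyTorch, outside vLLM; the helper name pad_ragged and the demo values are made up): it zero-pads a ragged list of per-request tensors to one batched tensor, reproducing the [tensor([6, 6]), tensor([6])] -> [[6, 6], [6, 0]] example from the comments in unpack_data.

import torch

def pad_ragged(tensors):
    # Zero-pad along dim 0 so every entry matches the longest one, as
    # unpack_data does for aspect_ratio_ids / aspect_ratio_mask.
    bsz = len(tensors)
    max_len = max(t.shape[0] for t in tensors)
    out = torch.zeros(bsz, max_len, *tensors[0].shape[1:],
                      dtype=tensors[0].dtype)
    for b, t in enumerate(tensors):
        out[b, :t.shape[0]] = t
    return out

print(pad_ragged([torch.tensor([6, 6]), torch.tensor([6])]))
# tensor([[6, 6],
#         [6, 0]])
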
From ac060661162aebb7fe5739090cc5b7d05f487484 Mon Sep 17 00:00:00 2001
From: yan ma
Date: Fri, 21 Mar 2025 22:14:13 +0800
Subject: [PATCH 2/4] refine unpack

Signed-off-by: yan ma
---
 vllm/model_executor/models/mllama.py | 92 ++++++----------------------
 1 file changed, 19 insertions(+), 73 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index cd4c7cc0b687..6a2e20840fcf 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -1235,88 +1235,34 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def unpack_data(self, image_data: Union[List[List[torch.Tensor]],
-                                            List[torch.Tensor], torch.Tensor]):
+    def unpack_data(self,
+                    image_data: Union[List[torch.Tensor], torch.Tensor],
+                    padding_value=0) -> torch.Tensor:
         if isinstance(image_data, torch.Tensor):
             # torch.Tensor
             return image_data
-        elif isinstance(image_data[0], torch.Tensor):
-            bsz = len(image_data)
+        else:
+            assert isinstance(
+                image_data[0],
+                torch.Tensor), "Image data is not properly batched."
             # List[torch.Tensor]
-            if image_data[0].ndim == 1:
-                # input: [tensor([6, 6], device='cuda:0'),
-                #         tensor([6], device='cuda:0')]
-                # output: tensor([[6, 6], [6, 0]], device='cuda:0')
-                max_num_elements = max(tensor.numel() for tensor in image_data)
-                output_tensor = torch.zeros(bsz,
-                                            max_num_elements,
-                                            device=image_data[0].device,
-                                            dtype=image_data[0].dtype)
-                for b in range(bsz):
-                    original_data = image_data[b]
-                    num_original = original_data.numel()
-                    output_tensor[b, :num_original] = original_data
-                return output_tensor
-
-            assert image_data[0].ndim == 2
-            # input: [tensor([[1, 1, 1, 1], [1, 1, 1, 1]], device='cuda:0'),
-            #         tensor([[1, 1, 1, 1]], device='cuda:0')]
-            # output:
-            # tensor([[[1, 1, 1, 1],
-            #          [1, 1, 1, 1]],
-            #         [[1, 1, 1, 1],
-            #          [0, 0, 0, 0]]], device='cuda:0')
             bsz = len(image_data)
-            max_num_elements = max(tensor.shape[0] for tensor in image_data)
-            output_tensor = torch.zeros(bsz,
-                                        max_num_elements,
-                                        image_data[0].shape[1],
-                                        device=image_data[0].device,
-                                        dtype=image_data[0].dtype)
-            for b in range(bsz):
-                original_data = image_data[b]
-                num_original = original_data.shape[0]
-                output_tensor[b, :num_original, :] = original_data
+            max_length = max(t.size(0) for t in image_data)
+            trailing_dims = image_data[0].shape[1:]
+            for data in image_data:
+                cur_trailing_dims = data.shape[1:]
+                assert cur_trailing_dims == trailing_dims
+            output_tensor = torch.full((bsz, max_length, *trailing_dims),
+                                       padding_value,
+                                       dtype=image_data[0].dtype,
+                                       device=image_data[0].device)
+            for i, t in enumerate(image_data):
+                output_tensor[i, :t.size(0)] = t
             return output_tensor
-        else:
-            return image_data
-
-    def unpack_pixel_values(self, pixel_values: Union[List[List[torch.Tensor]],
-                                                      List[torch.Tensor],
-                                                      torch.Tensor]):
-        if isinstance(pixel_values, torch.Tensor):
-            return pixel_values
-        else:
-            pixel_values_unpacked = []
-            for b in range(len(pixel_values)):
-                pixel_values_unpacked_b = []
-                for i in range(len(pixel_values[b])):
-                    pixel_values_unpacked_b.append(pixel_values[b][i])
-                pixel_values_unpacked.append(pixel_values_unpacked_b)
-
-            max_num_images = max([len(x) for x in pixel_values_unpacked])
-            max_num_chunks = max(
-                max([len(x) for x in y]) for y in pixel_values_unpacked)
-            bsz = len(pixel_values_unpacked)
-            out_images = torch.zeros(bsz,
-                                     max_num_images,
-                                     max_num_chunks,
-                                     3,
-                                     self.image_size,
-                                     self.image_size,
-                                     device=pixel_values[0].device,
-                                     dtype=pixel_values[0].dtype)
-            for b in range(len(pixel_values_unpacked)):
-                for i in range(len(pixel_values_unpacked[b])):
-                    img = pixel_values_unpacked[b][i]
-                    out_images[b, i, :img.shape[0]] = img
-            return out_images
 
     def _parse_and_validate_image_input(self, **kwargs: object):
         # tensor with the same shape will be batched together by
         # MultiModalKwargs.batch, so pixel_values here can be:
-        # - List[List[torch.Tensor]]:
-        #     with shape (num_tiles, 3, image_res, image_res)
         # - List[torch.Tensor]:
         #     with shape (num_image, num_tiles, 3, image_res, image_res)
         # - torch.Tensor:
@@ -1351,7 +1297,7 @@ def _parse_and_validate_image_input(self, **kwargs: object):
 
             return MllamaImagePixelInputs(
                 type="pixel_values",
-                data=self.unpack_pixel_values(pixel_values),
+                data=self.unpack_data(pixel_values),
                 aspect_ratio_ids=self.unpack_data(aspect_ratio_ids),
                 aspect_ratio_mask=self.unpack_data(aspect_ratio_mask))
 
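Patch 2 collapses the two branches into a single padding routine built on torch.full, so the fill value is configurable and any number of trailing dimensions is handled uniformly. For the List[torch.Tensor] case this gives the same result as torch.nn.utils.rnn.pad_sequence with batch_first=True; the snippet below is a standalone check of that equivalence (the function name unpack_like and the demo tensors are illustrative, not taken from the patch):

import torch
from torch.nn.utils.rnn import pad_sequence

def unpack_like(image_data, padding_value=0):
    # Same scheme as the refined unpack_data: pad dim 0 to the batch max,
    # keep trailing dims, fill the tail with padding_value.
    bsz = len(image_data)
    max_length = max(t.size(0) for t in image_data)
    trailing_dims = image_data[0].shape[1:]
    out = torch.full((bsz, max_length, *trailing_dims), padding_value,
                     dtype=image_data[0].dtype)
    for i, t in enumerate(image_data):
        out[i, :t.size(0)] = t
    return out

masks = [torch.ones(2, 4, dtype=torch.long), torch.ones(1, 4, dtype=torch.long)]
assert torch.equal(unpack_like(masks),
                   pad_sequence(masks, batch_first=True, padding_value=0))
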
From 8f9a1ce3f3ccedc799237379f9915666ce68043a Mon Sep 17 00:00:00 2001
From: yan ma
Date: Sat, 29 Mar 2025 23:32:42 +0800
Subject: [PATCH 3/4] fix UT

Signed-off-by: yan ma
---
 vllm/model_executor/models/mllama.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 6a2e20840fcf..b78596754a94 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -898,7 +898,8 @@ def forward(
         q = self.q_norm(q)
 
         if attention_mask is not None:
-            output = self._attention_with_mask(q, k, v, attention_mask,
+            output = self._attention_with_mask(q.contiguous(), k.contiguous(),
+                                               v.contiguous(), attention_mask,
                                                kv_range_for_decode)
         else:
             output = self.attn(
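Patch 3 passes q, k and v through .contiguous() before the masked-attention path (and the next patch reverts it). The snippet below is a generic PyTorch illustration, not vLLM code, of what that call does: a transposed view shares storage and is not contiguous, and .contiguous() materialises a dense row-major copy without changing values.

import torch

q = torch.randn(8, 64).t()   # transposed view, non-contiguous
print(q.is_contiguous())     # False
qc = q.contiguous()          # row-major copy of the same values
print(qc.is_contiguous())    # True
print(torch.equal(q, qc))    # True
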
From fd879710dcc17bbc17bf613f0224fbbc2c07b3fc Mon Sep 17 00:00:00 2001
From: yan ma
Date: Mon, 31 Mar 2025 16:29:55 +0800
Subject: [PATCH 4/4] Revert "fix UT"

This reverts commit 8f9a1ce3f3ccedc799237379f9915666ce68043a.

Signed-off-by: yan ma
---
 vllm/model_executor/models/mllama.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index b78596754a94..6a2e20840fcf 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -898,8 +898,7 @@ def forward(
         q = self.q_norm(q)
 
         if attention_mask is not None:
-            output = self._attention_with_mask(q.contiguous(), k.contiguous(),
-                                               v.contiguous(), attention_mask,
+            output = self._attention_with_mask(q, k, v, attention_mask,
                                                kv_range_for_decode)
         else:
             output = self.attn(