diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py
index dec2e0acab6b..d368c0c4dde0 100644
--- a/vllm/multimodal/inputs.py
+++ b/vllm/multimodal/inputs.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import partial
 from itertools import accumulate
 from typing import (
@@ -167,11 +167,42 @@ class PlaceholderRange:
     between `offset` and `offset + length` to assign embeddings to.
     """
 
-    def get_num_embeds(self) -> int:
+    num_embeds: int = field(init=False)
+    """
+    The number of positions that actually result in an output from the encoder.
+    """
+
+    def __post_init__(self):
         if self.is_embed is None:
-            return self.length
+            object.__setattr__(self, "num_embeds", self.length)
+        else:
+            num_embeds = int(self.is_embed.sum().item())
+            object.__setattr__(self, "num_embeds", num_embeds)
+
+            # Remove leading & trailing False in `is_embed` for easier scheduling
+            if num_embeds > 0:
+                true_indices = torch.nonzero(self.is_embed, as_tuple=True)[0]
+                first_true_index = true_indices[0].item()
+                last_true_index = true_indices[-1].item()
 
-        return int(self.is_embed.sum().item())
+                start_trim_count = first_true_index
+                new_length = last_true_index - first_true_index + 1
+
+                object.__setattr__(self, "offset", self.offset + start_trim_count)
+                object.__setattr__(self, "length", new_length)
+
+                object.__setattr__(
+                    self,
+                    "is_embed",
+                    self.is_embed[first_true_index : last_true_index + 1],
+                )
+            else:
+                # Seems impossible in practice, but handle an all-False mask anyway.
+                object.__setattr__(self, "length", 0)
+                object.__setattr__(self, "is_embed", self.is_embed[0:0])
+
+    def get_num_embeds(self) -> int:
+        return self.num_embeds
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
diff --git a/vllm/v1/request.py b/vllm/v1/request.py
index 864b0eb7fa41..c4ab5f6e8f7b 100644
--- a/vllm/v1/request.py
+++ b/vllm/v1/request.py
@@ -192,7 +192,7 @@ def get_finished_reason(self) -> FinishReason | None:
 
     def get_num_encoder_tokens(self, input_id: int) -> int:
         assert input_id < len(self.mm_features)
-        num_tokens = self.mm_features[input_id].mm_position.length
+        num_tokens = self.mm_features[input_id].mm_position.num_embeds
         return num_tokens
 
     def record_event(
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 9e394dbb592e..66d7a74cf4d7 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -148,9 +148,7 @@
     MultiModalBudget,
     add_kv_sharing_layers_to_kv_cache_groups,
     bind_kv_cache,
-    gather_mm_placeholders,
     sanity_check_mm_encoder_outputs,
-    scatter_mm_placeholders,
 )
 
 if TYPE_CHECKING:
@@ -1774,10 +1772,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
 
         # Cache the encoder outputs by mm_hash
         for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
-            self.encoder_cache[mm_hash] = scatter_mm_placeholders(
-                output,
-                is_embed=pos_info.is_embed,
-            )
+            self.encoder_cache[mm_hash] = output
 
     def _gather_mm_embeddings(
         self,
@@ -1828,7 +1823,29 @@ def _gather_mm_embeddings(
                 encoder_output = self.encoder_cache.get(mm_hash, None)
                 assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
 
-                if (is_embed := pos_info.is_embed) is not None:
+                is_embed = pos_info.is_embed
+
+                # Retrieve the `encoder_output` slice based on the `is_embed` mask
+                encoder_output_slice_start = start_idx
+                encoder_output_slice_end = end_idx
+                if is_embed is not None:
+                    num_encoder_output_before_start = is_embed[:start_idx].sum().item()
+                    num_encoder_output_selected = (
+                        is_embed[start_idx:end_idx].sum().item()
+                    )
+
+                    encoder_output_slice_start = num_encoder_output_before_start
+                    encoder_output_slice_end = (
+                        num_encoder_output_before_start + num_encoder_output_selected
+                    )
+
+                mm_embeds_item = encoder_output[
+                    encoder_output_slice_start:encoder_output_slice_end
+                ]
+                mm_embeds_req.append(mm_embeds_item)
+
+                # Fill in the `is_mm_embed` mask for this range
+                if is_embed is not None:
                     is_embed = is_embed[start_idx:end_idx]
 
                 req_start_pos = req_start_idx + start_pos - num_computed_tokens
@@ -1836,12 +1853,6 @@ def _gather_mm_embeddings(
                     True if is_embed is None else is_embed
                 )
 
-                mm_embeds_item = gather_mm_placeholders(
-                    encoder_output[start_idx:end_idx],
-                    is_embed=is_embed,
-                )
-                mm_embeds_req.append(mm_embeds_item)
-
                 if self.is_multimodal_pruning_enabled and self.uses_mrope:
                     assert req_state.mrope_positions is not None
                     should_sync_mrope_positions = True
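Illustration (not part of the patch): a minimal standalone sketch of the two index manipulations the diff relies on, written against plain `torch` with hypothetical helper names (`trim_placeholder`, `embed_slice`) that do not exist in vLLM. The first mirrors what `PlaceholderRange.__post_init__` now does, trimming leading and trailing `False` entries from `is_embed` while shifting `offset` and `length`; the second mirrors the cumulative-sum slicing in `_gather_mm_embeddings`, which is needed because the encoder cache now stores outputs packed (one row per `True` position) instead of scattered to the full placeholder length.

```python
import torch

# Hypothetical standalone helpers; the names below are illustrative only and
# do not exist in vLLM.


def trim_placeholder(offset: int, length: int, is_embed: torch.Tensor):
    """Drop leading/trailing False entries from the mask, shifting offset/length."""
    assert is_embed.numel() == length  # mask covers the whole placeholder
    true_indices = torch.nonzero(is_embed, as_tuple=True)[0]
    if true_indices.numel() == 0:
        # All-False mask: collapse to an empty range.
        return offset, 0, is_embed[0:0]
    first, last = int(true_indices[0]), int(true_indices[-1])
    return offset + first, last - first + 1, is_embed[first : last + 1]


def embed_slice(is_embed, start_idx: int, end_idx: int):
    """Map a token range inside the placeholder to a slice of the packed encoder output."""
    if is_embed is None:
        return start_idx, end_idx
    before = int(is_embed[:start_idx].sum())
    selected = int(is_embed[start_idx:end_idx].sum())
    return before, before + selected


mask = torch.tensor([False, False, True, True, False, True, False])
offset, length, mask = trim_placeholder(offset=10, length=7, is_embed=mask)
print(offset, length, mask.tolist())  # 12 4 [True, True, False, True]

# The encoder produced one row per True entry (3 rows in total).
encoder_output = torch.arange(3 * 2, dtype=torch.float32).reshape(3, 2)
lo, hi = embed_slice(mask, start_idx=1, end_idx=4)
print(encoder_output[lo:hi])  # rows 1..2: embeddings for the 2nd and 3rd True positions
```

The packed cache layout is presumably also why `get_num_encoder_tokens` switches from `mm_position.length` to `mm_position.num_embeds`: the scheduler then budgets actual encoder output rows rather than placeholder token positions.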