39 changes: 35 additions & 4 deletions vllm/multimodal/inputs.py
@@ -4,7 +4,7 @@
 from abc import ABC, abstractmethod
 from collections import UserDict, defaultdict
 from collections.abc import Mapping, Sequence
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import partial
 from itertools import accumulate
 from typing import (
@@ -167,11 +167,42 @@ class PlaceholderRange:
     between `offset` and `offset + length` to assign embeddings to.
     """
 
-    def get_num_embeds(self) -> int:
+    num_embeds: int = field(init=False)
+    """
+    The number of positions that actually result in an output from the encoder.
+    """
+
+    def __post_init__(self):
         if self.is_embed is None:
-            return self.length
+            object.__setattr__(self, "num_embeds", self.length)
+        else:
+            num_embeds = int(self.is_embed.sum().item())
+            object.__setattr__(self, "num_embeds", num_embeds)
+
+            # Remove leading & trailing False in `is_embed` for easier scheduling
+            if num_embeds > 0:
+                true_indices = torch.nonzero(self.is_embed, as_tuple=True)[0]
+                first_true_index = true_indices[0].item()
+                last_true_index = true_indices[-1].item()
 
-        return int(self.is_embed.sum().item())
+                start_trim_count = first_true_index
+                new_length = last_true_index - first_true_index + 1
+
+                object.__setattr__(self, "offset", self.offset + start_trim_count)
+                object.__setattr__(self, "length", new_length)
+
+                object.__setattr__(
+                    self,
+                    "is_embed",
+                    self.is_embed[first_true_index : last_true_index + 1],
+                )
+            else:
+                # Unreachable in practice: num_embeds == 0 means an all-False mask
+                object.__setattr__(self, "length", 0)
+                object.__setattr__(self, "is_embed", self.is_embed[0:0])
+
+    def get_num_embeds(self) -> int:
+        return self.num_embeds
 
     def __eq__(self, other: object) -> bool:
         if not isinstance(other, self.__class__):
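For intuition, here is a standalone sketch of the trimming rule that `__post_init__` now applies (plain torch, not the vLLM class itself; the mask and offset are made up): a mask of F F T F T F shifts `offset` by 2, shrinks `length` from 6 to 3, and leaves `num_embeds` at 2.

import torch

def trim_placeholder(offset: int, is_embed: torch.Tensor):
    # Drop leading/trailing False entries and shift the offset accordingly,
    # mirroring the __post_init__ logic above.
    true_indices = torch.nonzero(is_embed, as_tuple=True)[0]
    first = int(true_indices[0])
    last = int(true_indices[-1])
    return offset + first, last - first + 1, is_embed[first : last + 1]

mask = torch.tensor([False, False, True, False, True, False])
offset, length, trimmed = trim_placeholder(10, mask)
assert (offset, length) == (12, 3)
assert int(trimmed.sum()) == 2  # num_embeds is unchanged by trimming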
2 changes: 1 addition & 1 deletion vllm/v1/request.py
@@ -192,7 +192,7 @@ def get_finished_reason(self) -> FinishReason | None:
 
     def get_num_encoder_tokens(self, input_id: int) -> int:
         assert input_id < len(self.mm_features)
-        num_tokens = self.mm_features[input_id].mm_position.length
+        num_tokens = self.mm_features[input_id].mm_position.num_embeds
         return num_tokens
 
     def record_event(
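The switch from `length` to `num_embeds` matters whenever the mask has gaps: after trimming, `length` still counts every position in the window, while only the True positions produce encoder output. A minimal illustration with a made-up mask:

import torch

is_embed = torch.tensor([True, False, True])  # trimmed window with a gap
length = is_embed.numel()         # 3: what this method used to return
num_embeds = int(is_embed.sum())  # 2: actual encoder tokens to budget for
assert length != num_embeds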
37 changes: 24 additions & 13 deletions vllm/v1/worker/gpu_model_runner.py
@@ -148,9 +148,7 @@
     MultiModalBudget,
     add_kv_sharing_layers_to_kv_cache_groups,
     bind_kv_cache,
-    gather_mm_placeholders,
     sanity_check_mm_encoder_outputs,
-    scatter_mm_placeholders,
 )
 
 if TYPE_CHECKING:
@@ -1774,10 +1772,7 @@ def _execute_mm_encoder(self, scheduler_output: "SchedulerOutput"):
 
         # Cache the encoder outputs by mm_hash
         for (mm_hash, pos_info), output in zip(mm_hashes_pos, encoder_outputs):
-            self.encoder_cache[mm_hash] = scatter_mm_placeholders(
-                output,
-                is_embed=pos_info.is_embed,
-            )
+            self.encoder_cache[mm_hash] = output
 
     def _gather_mm_embeddings(
         self,
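With scattering removed, the cache entry is now the encoder output itself, one row per True position, rather than a placeholder-length tensor padded at non-embed positions. A rough sketch of the difference (hypothetical shapes; `scatter_mm_placeholders` is approximated here with boolean indexing):

import torch

hidden = 8
is_embed = torch.tensor([True, False, True, True, False])
encoder_output = torch.randn(int(is_embed.sum()), hidden)  # 3 compact rows

# Old layout (roughly what scatter_mm_placeholders produced): padded to the
# full placeholder length, with filler rows at non-embed positions.
scattered = encoder_output.new_zeros(is_embed.numel(), hidden)
scattered[is_embed] = encoder_output

# New layout: cache the compact tensor directly.
cached = encoder_output
assert cached.shape[0] == int(is_embed.sum())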
@@ -1828,20 +1823,36 @@ def _gather_mm_embeddings(
             encoder_output = self.encoder_cache.get(mm_hash, None)
             assert encoder_output is not None, f"Encoder cache miss for {mm_hash}."
 
-            if (is_embed := pos_info.is_embed) is not None:
+            is_embed = pos_info.is_embed
+
+            # retrieve `encoder_output` slice based on `is_embed` mask
+            encoder_output_slice_start = start_idx
+            encoder_output_slice_end = end_idx
+            if is_embed is not None:
+                num_encoder_output_before_start = is_embed[:start_idx].sum().item()
+                num_encoder_output_selected = (
+                    is_embed[start_idx:end_idx].sum().item()
+                )
+
+                encoder_output_slice_start = num_encoder_output_before_start
+                encoder_output_slice_end = (
+                    num_encoder_output_before_start + num_encoder_output_selected
+                )
+
+            mm_embeds_item = encoder_output[
+                encoder_output_slice_start:encoder_output_slice_end
+            ]
+            mm_embeds_req.append(mm_embeds_item)
+
+            # append `is_mm_embed` mask
+            if is_embed is not None:
                 is_embed = is_embed[start_idx:end_idx]
 
             req_start_pos = req_start_idx + start_pos - num_computed_tokens
             is_mm_embed[req_start_pos + start_idx : req_start_pos + end_idx] = (
                 True if is_embed is None else is_embed
             )
 
-            mm_embeds_item = gather_mm_placeholders(
-                encoder_output[start_idx:end_idx],
-                is_embed=is_embed,
-            )
-            mm_embeds_req.append(mm_embeds_item)
-
             if self.is_multimodal_pruning_enabled and self.uses_mrope:
                 assert req_state.mrope_positions is not None
                 should_sync_mrope_positions = True
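Because the cache is compact, the scheduled token window [start_idx, end_idx) has to be remapped to encoder-output rows by counting True entries, which is what the new slice arithmetic above does. A minimal sketch with made-up values:

import torch

is_embed = torch.tensor([True, True, False, True, True, False, True])
encoder_output = torch.randn(int(is_embed.sum()), 8)  # 5 compact rows

start_idx, end_idx = 2, 6  # scheduled token window within the placeholder

# Row r of encoder_output corresponds to the r-th True position, so the
# window maps to rows [True count before start, start + True count inside).
rows_before = int(is_embed[:start_idx].sum())           # 2
rows_selected = int(is_embed[start_idx:end_idx].sum())  # 2

mm_embeds_item = encoder_output[rows_before : rows_before + rows_selected]
assert mm_embeds_item.shape == (2, 8)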