
Commit ceca6af

mgoin authored and lulmer committed
[V1][TPU] TPU multimodal model support for ragged attention (vllm-project#14158)
Signed-off-by: Michael Goin <[email protected]>
Signed-off-by: Louis Ulmer <[email protected]>
1 parent 2b559ee commit ceca6af

File tree

2 files changed: +194 -30 lines

vllm/v1/worker/tpu_model_runner.py

Lines changed: 193 additions & 29 deletions
@@ -15,14 +15,18 @@
 from vllm.attention.layer import Attention
 from vllm.config import VllmConfig
 from vllm.forward_context import get_forward_context, set_forward_context
+from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
 from vllm.model_executor.model_loader import get_model
+from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
+from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
 from vllm.utils import LayerBlockType, cdiv, is_pin_memory_available
 from vllm.v1.attention.backends.pallas import (NUM_KV_PAGES_PER_BLOCK,
                                                NUM_QUERIES_PER_BLOCK,
                                                PallasAttentionBackend,
                                                PallasMetadata)
+from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
 from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import LogprobsTensors, ModelRunnerOutput
@@ -72,8 +76,10 @@ def __init__(
         self.block_size = cache_config.block_size
         self.max_model_len = model_config.max_model_len
         self.max_num_blocks_per_req = cdiv(self.max_model_len, self.block_size)
-        self.max_num_tokens = scheduler_config.max_num_batched_tokens
-        self.max_num_reqs = scheduler_config.max_num_seqs
+        self.max_num_tokens = _get_padded_number(
+            scheduler_config.max_num_batched_tokens, NUM_QUERIES_PER_BLOCK)
+        self.max_num_reqs = _get_padded_number(scheduler_config.max_num_seqs,
+                                               NUM_QUERIES_PER_BLOCK)

         # Model-related.
         self.num_attn_layers = model_config.get_num_layers_by_block_type(
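
Note: _get_padded_number (added at the bottom of this file) rounds a value up to the next multiple, so both limits above become multiples of NUM_QUERIES_PER_BLOCK. A minimal standalone sketch of the rounding; the block size of 16 is assumed purely for illustration:

    def _get_padded_number(n: int, multiple: int) -> int:
        return ((n + multiple - 1) // multiple) * multiple

    NUM_QUERIES_PER_BLOCK = 16  # assumed value, for the example only
    print(_get_padded_number(1000, NUM_QUERIES_PER_BLOCK))  # 1008
    print(_get_padded_number(1024, NUM_QUERIES_PER_BLOCK))  # 1024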
@@ -84,25 +90,38 @@ def __init__(
         self.head_size = model_config.get_head_size()
         self.hidden_size = model_config.get_hidden_size()

+        # Multi-modal data support
+        self.input_registry = INPUT_REGISTRY
+        self.mm_registry = MULTIMODAL_REGISTRY
+        self.uses_mrope = model_config.uses_mrope
+        # TODO: Support M-RoPE (e.g, Qwen2-VL)
+        assert not self.uses_mrope, "TPU does not support M-RoPE yet."
+
+        encoder_compute_budget, encoder_cache_size = compute_encoder_budget(
+            model_config=model_config,
+            scheduler_config=scheduler_config,
+        )
+        self.max_num_encoder_input_tokens = encoder_compute_budget
+        self.encoder_cache_size = encoder_cache_size
+
+        # Lazy initialization
+        # self.model: nn.Module  # Set after load_model
+        self.kv_caches: list[torch.Tensor] = []
+        # req_id -> (input_id -> encoder_output)
+        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
+
+        # Request states.
+        self.requests: dict[str, CachedRequestState] = {}
         # Persistent batch.
         self.input_batch = InputBatch(
             max_num_reqs=self.max_num_reqs,
             max_model_len=self.max_model_len,
             max_num_blocks_per_req=self.max_num_blocks_per_req,
             device=self.device,
             pin_memory=self.pin_memory,
-            vocab_size=self.model_config.get_vocab_size(),
+            vocab_size=model_config.get_vocab_size(),
         )

-        # Request states.
-        self.requests: dict[str, CachedRequestState] = {}
-
-        # req_id -> (input_id -> encoder_output)
-        self.encoder_cache: dict[str, dict[int, torch.Tensor]] = {}
-
-        # KV caches for forward pass
-        self.kv_caches: list[tuple[torch.Tensor, torch.Tensor]] = []
-
         # Cached torch/numpy tensor
         # The pytorch tensor and numpy array share the same buffer.
         # Sometimes the numpy op is faster so we create both.
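
Note: the encoder_cache introduced here is a two-level mapping, req_id -> input_id -> encoder output. A small illustrative sketch of how it is written by _execute_encoder and later read by _gather_encoder_outputs (the tensor shape is an assumption for the example):

    import torch

    # req_id -> (input_id -> encoder_output)
    encoder_cache: dict[str, dict[int, torch.Tensor]] = {}

    # _execute_encoder caches one output per scheduled multimodal item.
    encoder_cache.setdefault("req-0", {})[0] = torch.zeros(576, 4096)  # assumed shape

    # _gather_encoder_outputs later slices out only the rows needed this step.
    chunk = encoder_cache["req-0"][0][0:128]  # shape [128, 4096]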
@@ -164,6 +183,7 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
         # Remove finished requests from the cached states.
         for req_id in scheduler_output.finished_req_ids:
             self.requests.pop(req_id, None)
+            self.encoder_cache.pop(req_id, None)

         # Remove the finished requests from the persistent batch.
         # NOTE(woosuk): There could be an edge case where finished_req_ids and
@@ -177,6 +197,14 @@ def _update_states(self, scheduler_output: "SchedulerOutput") -> bool:
             if req_index is not None:
                 removed_req_indices.append(req_index)

+        # Free the cached encoder outputs.
+        for req_id, input_id in scheduler_output.free_encoder_input_ids:
+            encoder_outputs = self.encoder_cache.get(req_id)
+            if encoder_outputs is not None:
+                encoder_outputs.pop(input_id, None)
+                if not encoder_outputs:
+                    self.encoder_cache.pop(req_id, None)
+
         # Remove the unscheduled requests from the persistent batch.
         # NOTE(woosuk): The unscheduled requests are either preempted requests
         # or running requests that are not scheduled in this step. We remove
@@ -426,6 +454,92 @@ def _prepare_inputs(self, scheduler_output: "SchedulerOutput"):
         logits_indices = query_start_loc[1:] - 1
         return attn_metadata, logits_indices

+    def _execute_encoder(self, scheduler_output: "SchedulerOutput"):
+        scheduled_encoder_inputs = scheduler_output.scheduled_encoder_inputs
+        if not scheduled_encoder_inputs:
+            return
+
+        # Batch the multi-modal inputs.
+        mm_inputs: list[MultiModalKwargs] = []
+        req_input_ids: list[tuple[str, int]] = []
+        for req_id, encoder_input_ids in scheduled_encoder_inputs.items():
+            req_state = self.requests[req_id]
+            for input_id in encoder_input_ids:
+                mm_inputs.append(req_state.mm_inputs[input_id])
+                req_input_ids.append((req_id, input_id))
+
+        # Batch mm inputs as much as we can: if a request in the batch has
+        # multiple modalities or a different modality than the previous one,
+        # we process it separately to preserve item order.
+        # FIXME(ywang96): This is a hacky way to deal with multiple modalities
+        # in the same batch while still being able to benefit from batching
+        # multimodal inputs. The proper solution should be reordering the
+        # encoder outputs.
+        grouped_mm_inputs_list = group_mm_inputs_by_modality(mm_inputs)
+
+        encoder_outputs = []
+        for grouped_mm_inputs in grouped_mm_inputs_list:
+            batched_mm_inputs = MultiModalKwargs.batch(grouped_mm_inputs)
+            batched_mm_inputs = MultiModalKwargs.as_kwargs(batched_mm_inputs,
+                                                           device=self.device)
+
+            # Run the encoder.
+            # `curr_group_outputs` is either of the following:
+            # 1. A tensor of shape (num_items, feature_size, hidden_size)
+            # in case feature_size is fixed across all multimodal items.
+            # 2. A list or tuple (length: num_items) of tensors, each of shape
+            # (feature_size, hidden_size) in case the feature size is dynamic
+            # depending on the input multimodal items.
+            curr_group_outputs = self.model.get_multimodal_embeddings(
+                **batched_mm_inputs)
+
+            for output in curr_group_outputs:
+                encoder_outputs.append(output)
+
+        # Cache the encoder outputs.
+        for (req_id, input_id), output in zip(req_input_ids, encoder_outputs):
+            if req_id not in self.encoder_cache:
+                self.encoder_cache[req_id] = {}
+            self.encoder_cache[req_id][input_id] = output
+
+    def _gather_encoder_outputs(
+        self,
+        scheduler_output: "SchedulerOutput",
+    ) -> list[torch.Tensor]:
+        encoder_outputs: list[torch.Tensor] = []
+        for req_id in self.input_batch.req_ids:
+            num_scheduled_tokens = scheduler_output.num_scheduled_tokens[
+                req_id]
+            req_state = self.requests[req_id]
+            num_computed_tokens = req_state.num_computed_tokens
+            mm_positions = req_state.mm_positions
+            for i, pos_info in enumerate(mm_positions):
+                start_pos = pos_info["offset"]
+                num_encoder_tokens = pos_info["length"]
+
+                # The encoder output is needed if the two ranges overlap:
+                # [num_computed_tokens,
+                #  num_computed_tokens + num_scheduled_tokens) and
+                # [start_pos, start_pos + num_encoder_tokens)
+                if start_pos >= num_computed_tokens + num_scheduled_tokens:
+                    # The encoder output is not needed in this step.
+                    break
+                if start_pos + num_encoder_tokens <= num_computed_tokens:
+                    # The encoder output is already processed and stored
+                    # in the decoder's KV cache.
+                    continue
+
+                start_idx = max(num_computed_tokens - start_pos, 0)
+                end_idx = min(
+                    num_computed_tokens - start_pos + num_scheduled_tokens,
+                    num_encoder_tokens)
+                assert start_idx < end_idx
+                assert req_id in self.encoder_cache
+                assert i in self.encoder_cache[req_id]
+                encoder_output = self.encoder_cache[req_id][i]
+                encoder_outputs.append(encoder_output[start_idx:end_idx])
+        return encoder_outputs
+
     @torch.no_grad()
     def execute_model(
         self,
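
Note: the slicing in _gather_encoder_outputs intersects the window of tokens scheduled this step with each multimodal item's placeholder span. A standalone sketch of that arithmetic with made-up numbers, mirroring the branches above:

    def needed_slice(num_computed_tokens: int, num_scheduled_tokens: int,
                     start_pos: int, num_encoder_tokens: int):
        if start_pos >= num_computed_tokens + num_scheduled_tokens:
            return None  # item starts after this step's window
        if start_pos + num_encoder_tokens <= num_computed_tokens:
            return None  # item is already in the decoder's KV cache
        start_idx = max(num_computed_tokens - start_pos, 0)
        end_idx = min(num_computed_tokens - start_pos + num_scheduled_tokens,
                      num_encoder_tokens)
        return start_idx, end_idx

    # An image occupies positions [100, 676); 300 tokens are already computed
    # and 256 are scheduled, so only rows [200, 456) of its output are needed.
    print(needed_slice(300, 256, 100, 576))  # (200, 456)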
@@ -434,16 +548,42 @@ def execute_model(
         # Update cached state
         self._update_states(scheduler_output)

+        if self.is_multimodal_model:
+            # Run the multimodal encoder if any.
+            self._execute_encoder(scheduler_output)
+            encoder_outputs = self._gather_encoder_outputs(scheduler_output)
+        else:
+            encoder_outputs = []
+
         # Prepare inputs
         attn_metadata, logits_indices = self._prepare_inputs(scheduler_output)
         total_num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens

+        if self.is_multimodal_model:
+            # NOTE(woosuk): To unify token ids and soft tokens (vision
+            # embeddings), we always use embeddings (rather than token ids)
+            # as input to the multimodal model, even when the input is text.
+            if encoder_outputs:
+                inputs_embeds = self.model.get_input_embeddings(
+                    self.input_ids, encoder_outputs)
+            else:
+                inputs_embeds = self.model.get_input_embeddings(self.input_ids)
+            input_ids = None
+        else:
+            # For text-only models, we use token ids as input.
+            # While it is possible to use embeddings as input just like the
+            # multimodal models, it is not desirable for performance since
+            # then the embedding layer is not included in the CUDA graph.
+            input_ids = self.input_ids
+            inputs_embeds = None
+
         # Run the decoder
         with set_forward_context(attn_metadata, self.vllm_config):
             hidden_states = self.model(
-                token_ids=self.input_ids,
-                position_ids=self.position_ids,
+                input_ids=input_ids,
+                positions=self.position_ids,
                 kv_caches=self.kv_caches,
+                inputs_embeds=inputs_embeds,
             )
         hidden_states = hidden_states[:total_num_scheduled_tokens]
         num_reqs = self.input_batch.num_reqs
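
Note: to make the "always use embeddings" note above concrete, a generic sketch of merging text embeddings with encoder outputs into a single [num_tokens, hidden_size] input; this is only the idea, not the actual vLLM get_input_embeddings API:

    import torch

    def merge_embeddings(text_embeds: torch.Tensor,
                         encoder_chunks: list[torch.Tensor],
                         is_mm_token: torch.Tensor) -> torch.Tensor:
        # Place encoder outputs at the multimodal placeholder positions.
        merged = text_embeds.clone()
        if encoder_chunks:
            merged[is_mm_token] = torch.cat(encoder_chunks, dim=0)
        return merged

    # 8 tokens with hidden size 4; positions 2..5 are image placeholders.
    text_embeds = torch.zeros(8, 4)
    is_mm_token = torch.tensor([0, 0, 1, 1, 1, 1, 0, 0], dtype=torch.bool)
    out = merge_embeddings(text_embeds, [torch.ones(4, 4)], is_mm_token)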
@@ -538,14 +678,21 @@ def load_model(self) -> None:
                                    fullgraph=True,
                                    dynamic=False)

-    def dummy_run(
+    def _dummy_run(
         self,
         kv_caches,
         num_tokens: int,
     ) -> None:
-        input_ids = torch.zeros(num_tokens,
-                                dtype=torch.int32,
-                                device=self.device)
+        if self.is_multimodal_model:
+            input_ids = None
+            inputs_embeds = torch.zeros((num_tokens, self.hidden_size),
+                                        dtype=self.dtype,
+                                        device=self.device)
+        else:
+            input_ids = torch.zeros((num_tokens),
+                                    dtype=torch.int32,
+                                    device=self.device)
+            inputs_embeds = None
         position_ids = torch.zeros(num_tokens,
                                    dtype=torch.int32,
                                    device=self.device)
@@ -571,7 +718,10 @@ def dummy_run(
             num_seqs=num_tokens,
         )

-        torch._dynamo.mark_dynamic(input_ids, 0)
+        if self.is_multimodal_model:
+            torch._dynamo.mark_dynamic(inputs_embeds, 0)
+        else:
+            torch._dynamo.mark_dynamic(input_ids, 0)
         torch._dynamo.mark_dynamic(position_ids, 0)
         torch._dynamo.mark_dynamic(attn_metadata.slot_mapping, 0)
         torch._dynamo.mark_dynamic(attn_metadata.block_tables, 0)
@@ -580,7 +730,12 @@ def dummy_run(

         with set_forward_context(attn_metadata, self.vllm_config, 0):
             assert self.model is not None
-            self.model(input_ids, position_ids, kv_caches)
+            self.model(
+                input_ids=input_ids,
+                positions=position_ids,
+                kv_caches=kv_caches,
+                inputs_embeds=inputs_embeds,
+            )

     def capture_model(self) -> None:
         """Compile the model."""
@@ -590,11 +745,11 @@ def capture_model(self) -> None:
         start = time.perf_counter()
         num_tokens = 16
         while True:
-            self.dummy_run(self.kv_caches, num_tokens)
+            self._dummy_run(self.kv_caches, num_tokens)
             logger.info(" -- num_tokens: %d", num_tokens)
             xm.mark_step()
             xm.wait_device_ops()
-            if num_tokens >= self.scheduler_config.max_num_batched_tokens:
+            if num_tokens >= self.max_num_tokens:
                 break
             num_tokens *= 2
         end = time.perf_counter()
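
Note: with the loop now bounded by the padded self.max_num_tokens, the precompiled token counts are powers of two from 16 up to the first value at or above that limit. A quick sketch of the enumeration, assuming a padded limit of 8192 purely for illustration:

    def compile_buckets(max_num_tokens: int) -> list[int]:
        buckets, num_tokens = [], 16
        while True:
            buckets.append(num_tokens)
            if num_tokens >= max_num_tokens:
                break
            num_tokens *= 2
        return buckets

    print(compile_buckets(8192))
    # [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192]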
@@ -647,17 +802,20 @@ def __init__(self, model: nn.Module):

     def forward(
         self,
-        token_ids: torch.Tensor,
-        position_ids: torch.Tensor,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
         kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
+        inputs_embeds: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         """Executes the forward pass of the model and samples the next token.

         Args:
-            token_ids: The input token IDs of shape [num_tokens].
-            position_ids: The input position IDs of shape [num_tokens].
+            input_ids: The input token IDs of shape [num_tokens].
+            positions: The input position IDs of shape [num_tokens].
             kv_caches: The key and value caches. They can be None during the
                 memory profiling at initialization.
+            inputs_embeds: The input embeddings of shape [num_tokens,
+                hidden_size]. It is used for multimodal models.
         """
         # Skip this in memory profiling at initialization.
         if kv_caches[0][0].numel() > 0:
@@ -684,9 +842,9 @@ def forward(

         assert self.model is not None
         hidden_states = self.model(
-            token_ids,
-            position_ids,
-            kv_caches,
+            input_ids=input_ids,
+            positions=positions,
+            inputs_embeds=inputs_embeds,
         )

         return hidden_states
@@ -699,6 +857,12 @@ def compute_logits(
         logits = self.model.compute_logits(hidden_states, sampling_metadata)
         return logits

+    def get_multimodal_embeddings(self, *args, **kwargs):
+        return self.model.get_multimodal_embeddings(*args, **kwargs)
+
+    def get_input_embeddings(self, *args, **kwargs):
+        return self.model.get_input_embeddings(*args, **kwargs)
+

 def _get_padded_number(n: int, multiple: int) -> int:
     return ((n + multiple - 1) // multiple) * multiple

vllm/v1/worker/tpu_worker.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def determine_available_memory(self) -> int:
             self.vllm_config.compilation_config.static_forward_context,
             runner_kv_caches)

-        self.model_runner.dummy_run(
+        self.model_runner._dummy_run(
             runner_kv_caches,
             num_tokens=self.scheduler_config.max_num_batched_tokens,
         )

0 commit comments
