From e188540f2e3a7492cafad35c79a55505b5ab4016 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Feb 2025 01:46:37 -0800 Subject: [PATCH 1/2] [V1][PP] Cache Intermediate Tensors Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 40 ++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 12b7ce18fbc2..f154421143e1 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -2,7 +2,7 @@ import gc import time -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, cast +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union, cast import numpy as np import torch @@ -91,6 +91,10 @@ def __init__( self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() + # Parallelism related. + self.tp_size = parallel_config.tensor_parallel_size + self.pp_size = parallel_config.pipeline_parallel_size + # Multi-modal data support self.input_registry = INPUT_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY @@ -149,6 +153,7 @@ def __init__( self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) + # self.intermediate_tensors # Set after load_model # Only relevant for models using M-RoPE (e.g, Qwen2-VL) if self.uses_mrope: @@ -869,7 +874,7 @@ def execute_model( self, scheduler_output: "SchedulerOutput", intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> ModelRunnerOutput: + ) -> Union[ModelRunnerOutput, torch.Tensor]: batch_changed = self._update_states(scheduler_output) if self.is_multimodal_model: @@ -919,6 +924,14 @@ def execute_model( else: positions = self.positions[:num_input_tokens] + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + intermediate_tensors = IntermediateTensors({ + k: v[:num_input_tokens] + for k, v in self.intermediate_tensors.items() + }) + # Run the decoder. # Use persistent buffers for CUDA graphs. with set_forward_context(attn_metadata, self.vllm_config): @@ -931,7 +944,9 @@ def execute_model( inputs_embeds=inputs_embeds, ) if not get_pp_group().is_last_rank: + # For mid-pipeline stages, return the hidden states. return hidden_states + hidden_states = hidden_states[:num_scheduled_tokens] sample_hidden_states = hidden_states[logits_indices] logits = self.model.compute_logits(sample_hidden_states, None) @@ -1118,12 +1133,21 @@ def _dummy_run( positions = self.mrope_positions[:, :num_tokens] else: positions = self.positions[:num_tokens] - intermediate_tensors = None - if not get_pp_group().is_first_rank: - intermediate_tensors = self.model.make_empty_intermediate_tensors( - batch_size=num_tokens, - dtype=self.model_config.dtype, - device=self.device) + + if get_pp_group().is_first_rank: + intermediate_tensors = None + else: + if not hasattr(self, "intermediate_tensors"): + self.intermediate_tensors = ( + self.model.make_empty_intermediate_tensors( + batch_size=self.max_num_tokens, + dtype=self.model_config.dtype, + device=self.device)) + intermediate_tensors = IntermediateTensors({ + k: v[:num_tokens] + for k, v in self.intermediate_tensors.items() + }) + with set_forward_context(None, self.vllm_config): hidden_states = model( input_ids=input_ids, From 766541d582284f888a93fa3a5ceac11fcb230d44 Mon Sep 17 00:00:00 2001 From: Woosuk Kwon Date: Sun, 16 Feb 2025 01:47:48 -0800 Subject: [PATCH 2/2] minor Signed-off-by: Woosuk Kwon --- vllm/v1/worker/gpu_model_runner.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index f154421143e1..d3995b619d31 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -91,10 +91,6 @@ def __init__( self.head_size = model_config.get_head_size() self.hidden_size = model_config.get_hidden_size() - # Parallelism related. - self.tp_size = parallel_config.tensor_parallel_size - self.pp_size = parallel_config.pipeline_parallel_size - # Multi-modal data support self.input_registry = INPUT_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY