diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py
index 5b6741d74efc..24abd5aff40a 100644
--- a/tests/distributed/test_pipeline_parallel.py
+++ b/tests/distributed/test_pipeline_parallel.py
@@ -40,10 +40,23 @@ class PPTestOptions(NamedTuple):
 @dataclass
 class PPTestSettings:
     parallel_setups: List[ParallelSetup]
+    # NOTE: the length of distributed_backends and
+    # vllm_major_versions should be the same, and they
+    # are first zipped together to iterate over all
+    # test settings.
     distributed_backends: List[str]
+    # vllm major version: "0" for V0, "1" for V1
+    vllm_major_versions: List[str]
     task: TaskOption
     test_options: PPTestOptions
 
+    def __post_init__(self):
+        if len(self.distributed_backends) != len(self.vllm_major_versions):
+            raise ValueError(
+                f"Length mismatch: distributed_backends "
+                f"({len(self.distributed_backends)}) != "
+                f"vllm_major_versions ({len(self.vllm_major_versions)})")
+
     @staticmethod
     def detailed(
         *,
@@ -79,7 +92,9 @@ def detailed(
                               eager_mode=True,
                               chunked_prefill=False),
             ],
-            distributed_backends=["mp", "ray"],
+            # only ray is supported for V1
+            distributed_backends=["mp", "ray", "ray"],
+            vllm_major_versions=["0", "0", "1"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        trust_remote_code=trust_remote_code,
@@ -108,6 +123,7 @@ def fast(
                               chunked_prefill=False),
             ],
             distributed_backends=["mp"],
+            vllm_major_versions=["0"],
             task=task,
             test_options=PPTestOptions(multi_node_only=multi_node_only,
                                        trust_remote_code=trust_remote_code,
@@ -120,8 +136,9 @@ def iter_params(self, model_name: str):
         opts = self.test_options
 
         for parallel_setup in self.parallel_setups:
-            for distributed_backend in self.distributed_backends:
-                yield (model_name, parallel_setup, distributed_backend,
+            for backend, vllm_major_version in zip(self.distributed_backends,
+                                                   self.vllm_major_versions):
+                yield (model_name, parallel_setup, backend, vllm_major_version,
                        self.task, opts)
@@ -244,6 +261,7 @@ def _compare_tp(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available: int,
@@ -296,10 +314,13 @@ def _compare_tp(
     if hf_overrides:
         common_args.extend(["--hf-overrides", hf_overrides])
 
-    if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2
-            and chunked_prefill):
-        # Test Ray ADAG for a subset of the tests
+    specific_case = tp_size == 2 and pp_size == 2 and chunked_prefill
+    if distributed_backend == "ray" and (vllm_major_version == "1"
+                                         or specific_case):
+        # For V1, test Ray ADAG for all the tests
+        # For V0, test Ray ADAG for a subset of the tests
         pp_env = {
+            "VLLM_USE_V1": vllm_major_version,
             "VLLM_USE_RAY_COMPILED_DAG": "1",
             "VLLM_USE_RAY_SPMD_WORKER": "1",
             "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1",
@@ -348,8 +369,8 @@
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in TEXT_GENERATION_MODELS.items()
         for params in settings.iter_params(model_name)
@@ -361,6 +382,7 @@ def test_tp_language_generation(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
@@ -368,6 +390,7 @@ def test_tp_language_generation(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
@@ -375,8 +398,8 @@ def test_tp_language_generation(
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in EMBEDDING_MODELS.items()
         for params in settings.iter_params(model_name)
@@ -388,6 +411,7 @@ def test_tp_language_embedding(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
@@ -395,6 +419,7 @@ def test_tp_language_embedding(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
@@ -402,8 +427,8 @@ def test_tp_language_embedding(
 @pytest.mark.parametrize(
-    ("model_name", "parallel_setup", "distributed_backend", "task",
-     "test_options"),
+    ("model_name", "parallel_setup", "distributed_backend",
+     "vllm_major_version", "task", "test_options"),
     [
         params for model_name, settings in MULTIMODAL_MODELS.items()
         for params in settings.iter_params(model_name)
@@ -415,6 +440,7 @@ def test_tp_multimodal_generation(
     model_name: str,
     parallel_setup: ParallelSetup,
     distributed_backend: str,
+    vllm_major_version: str,
     task: TaskOption,
     test_options: PPTestOptions,
     num_gpus_available,
@@ -422,6 +448,7 @@ def test_tp_multimodal_generation(
     _compare_tp(model_name,
                 parallel_setup,
                 distributed_backend,
+                vllm_major_version,
                 task,
                 test_options,
                 num_gpus_available,
diff --git a/vllm/executor/ray_utils.py b/vllm/executor/ray_utils.py
index 7b30155971a6..1300205ba647 100644
--- a/vllm/executor/ray_utils.py
+++ b/vllm/executor/ray_utils.py
@@ -35,7 +35,7 @@
 class RayWorkerWrapper(WorkerWrapperBase):
     """Ray wrapper for vllm.worker.Worker, allowing Worker to be
-    lazliy initialized after Ray sets CUDA_VISIBLE_DEVICES."""
+    lazily initialized after Ray sets CUDA_VISIBLE_DEVICES."""
 
     def __init__(self, *args, **kwargs) -> None:
         super().__init__(*args, **kwargs)
@@ -118,7 +118,14 @@ def execute_model(
     ) -> "ModelRunnerOutput":
         self.setup_device_if_necessary()
         assert self.worker is not None, "Worker is not initialized"
-        output = self.worker.model_runner.execute_model(scheduler_output)
+        if isinstance(scheduler_output, tuple):
+            scheduler_output, intermediate_tensors = scheduler_output
+        else:
+            scheduler_output, intermediate_tensors = scheduler_output, None
+        output = self.worker.model_runner.execute_model(
+            scheduler_output, intermediate_tensors)
+        if isinstance(output, IntermediateTensors):
+            output = scheduler_output, output
         return output
 
     def override_env_vars(self, vars: Dict[str, str]):
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 6888f1a3e182..81e118999cf6 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -424,7 +424,8 @@ def is_kv_cache_type_uniform(kv_cache_spec: KVCacheSpec) -> bool:
 def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
                                       kv_cache_spec: KVCacheSpec,
-                                      available_memory: int) -> KVCacheConfig:
+                                      available_memory: int,
+                                      num_layers: int) -> KVCacheConfig:
     """
     Generates the KV cache configuration for a model with one type of KV cache.
     Divide the available memory equally among all layers.
@@ -433,6 +434,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
         vllm_config: The global VllmConfig
         kv_cache_spec: The kv cache spec of the model
         available_memory: Memory available for KV cache in bytes.
+        num_layers: The number of layers in the model.
 
     Returns:
         The generated KVCacheConfig
@@ -442,7 +444,7 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     assert len(page_sizes) == 1
     page_size = page_sizes.pop()
 
-    num_blocks = int(available_memory // page_size // len(kv_cache_spec))
+    num_blocks = int(available_memory // page_size // num_layers)
     num_blocks = max(num_blocks, 0)
 
     if vllm_config.cache_config.num_gpu_blocks_override is not None:
@@ -472,25 +474,36 @@ def _get_kv_cache_config_uniform_type(vllm_config: VllmConfig,
     return kv_cache_config
 
 
-def get_kv_cache_config(vllm_config: VllmConfig, kv_cache_spec: KVCacheSpec,
-                        available_memory: int) -> KVCacheConfig:
+def get_kv_cache_configs(vllm_config: VllmConfig,
+                         kv_cache_specs: List[KVCacheSpec],
+                         available_memory: int) -> List[KVCacheConfig]:
     """
     Generates the KV cache configuration for a model
     TODO: support hybrid models with more than one type of KV cache.
 
     Args:
         vllm_config: The global VllmConfig
-        kv_cache_spec: The kv cache spec of the model
+        kv_cache_specs: The kv cache specs of the model
         available_memory: Memory available for KV cache in bytes.
 
     Returns:
-        The generated KVCacheConfig
+        The generated KVCacheConfigs
     """
-    check_enough_kv_cache_memory(vllm_config, kv_cache_spec, available_memory)
-    if is_kv_cache_type_uniform(kv_cache_spec):
-        # KV cache of all layers are the same, which is true for most models.
-        # Allocate the same amount of memory for each layer.
-        return _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
-                                                 available_memory)
-    else:
-        raise NotImplementedError
+    # Use the max number of layers to conservatively determine
+    # the number of blocks.
+    num_layers = max(len(kv_cache_spec) for kv_cache_spec in kv_cache_specs)
+    kv_cache_configs = []
+    for kv_cache_spec in kv_cache_specs:
+        check_enough_kv_cache_memory(vllm_config, kv_cache_spec,
+                                     available_memory)
+        if is_kv_cache_type_uniform(kv_cache_spec):
+            # KV cache of all layers are the same, which is true for
+            # most models. Allocate the same amount of memory for
+            # each layer.
+            kv_cache_configs.append(
+                _get_kv_cache_config_uniform_type(vllm_config, kv_cache_spec,
+                                                  available_memory,
+                                                  num_layers))
+        else:
+            raise NotImplementedError
+    return kv_cache_configs
diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py
index f3d40aa1e9cb..a341a6075460 100644
--- a/vllm/v1/engine/core.py
+++ b/vllm/v1/engine/core.py
@@ -17,7 +17,7 @@
 from vllm.transformers_utils.config import (
     maybe_register_config_serialize_by_value)
 from vllm.utils import get_exception_traceback, zmq_socket_ctx
-from vllm.v1.core.kv_cache_utils import get_kv_cache_config
+from vllm.v1.core.kv_cache_utils import get_kv_cache_configs
 from vllm.v1.core.scheduler import Scheduler
 from vllm.v1.engine import (EngineCoreOutputs, EngineCoreProfile,
                             EngineCoreRequest, EngineCoreRequestType,
@@ -71,20 +71,25 @@ def _initialize_kv_caches(self,
         start = time.time()
 
         # Get all kv cache needed by the model
-        kv_cache_spec = self.model_executor.get_kv_cache_spec()
+        kv_cache_specs = self.model_executor.get_kv_cache_specs()
 
         # Profiles the peak memory usage of the model to determine how much
         # memory can be allocated for kv cache.
-        availble_gpu_memory = self.model_executor.determine_available_memory()
+        available_gpu_memory = self.model_executor.determine_available_memory()
 
         # Get the kv cache tensor size
-        kv_cache_config = get_kv_cache_config(vllm_config, kv_cache_spec,
-                                              availble_gpu_memory)
-        num_gpu_blocks = kv_cache_config.num_blocks
+        kv_cache_configs = get_kv_cache_configs(vllm_config, kv_cache_specs,
+                                                available_gpu_memory)
+        num_gpu_blocks_set = set(config.num_blocks
+                                 for config in kv_cache_configs)
+        assert len(num_gpu_blocks_set) == 1, (
+            f"num_gpu_blocks need to be the same across workers, "
+            f"but they are different: {num_gpu_blocks_set}")
+        num_gpu_blocks = num_gpu_blocks_set.pop()
         num_cpu_blocks = 0
 
         # Initialize kv cache and warmup the execution
-        self.model_executor.initialize(kv_cache_config)
+        self.model_executor.initialize(kv_cache_configs)
 
         elapsed = time.time() - start
         logger.info(("init engine (profile, create kv cache, "
diff --git a/vllm/v1/executor/abstract.py b/vllm/v1/executor/abstract.py
index 093be09ae11b..d1ffc891ad69 100644
--- a/vllm/v1/executor/abstract.py
+++ b/vllm/v1/executor/abstract.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 
-from typing import Type
+from typing import List, Type
 
 from vllm.config import VllmConfig
 from vllm.executor.executor_base import ExecutorBase
@@ -48,12 +48,12 @@ def get_class(vllm_config: VllmConfig) -> Type["Executor"]:
                 f"{distributed_executor_backend}")
         return executor_class
 
-    def initialize(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize(self, kv_cache_configs: List[KVCacheConfig]) -> None:
         """
         Initialize the KV caches and begin the model execution loop of the
         underlying workers.
         """
-        self.collective_rpc("initialize_cache", args=(kv_cache_config, ))
+        self.collective_rpc("initialize_cache", args=(kv_cache_configs, ))
         self.collective_rpc("compile_or_warm_up_model")
 
     def determine_available_memory(self) -> int:  # in bytes
@@ -63,11 +63,9 @@ def determine_available_memory(self) -> int:  # in bytes
         # operators can be applied to all workers.
         return min(output)
 
-    def get_kv_cache_spec(self) -> KVCacheSpec:
+    def get_kv_cache_specs(self) -> List[KVCacheSpec]:
         output = self.collective_rpc("get_kv_cache_spec")
-        for x in output:
-            assert x == output[0]
-        return output[0]
+        return output
 
     def execute_model(
         self,
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index fdbca70bda71..3a8ed5e1b47a 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -12,7 +12,7 @@
 from vllm.attention.backends.abstract import AttentionType
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
-from vllm.distributed.parallel_state import graph_capture
+from vllm.distributed.parallel_state import get_pp_group, graph_capture
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import init_logger
@@ -21,6 +21,7 @@
 from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
 from vllm.multimodal.utils import group_mm_inputs_by_modality
 from vllm.sampling_params import SamplingType
+from vllm.sequence import IntermediateTensors
 from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
                         LayerBlockType, cdiv, is_pin_memory_available)
 from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend,
@@ -773,6 +774,7 @@ def get_model(self) -> nn.Module:
     def execute_model(
         self,
         scheduler_output: "SchedulerOutput",
+        intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> ModelRunnerOutput:
         batch_changed = self._update_states(scheduler_output)
@@ -831,8 +833,11 @@ def execute_model(
                 positions=positions,
                 kv_caches=self.kv_caches,
                 attn_metadata=None,
+                intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
+        if not get_pp_group().is_last_rank:
+            return hidden_states
         hidden_states = hidden_states[:num_scheduled_tokens]
         sample_hidden_states = hidden_states[logits_indices]
         logits = self.model.compute_logits(sample_hidden_states, None)
@@ -1007,12 +1012,19 @@ def _dummy_run(
             positions = self.mrope_positions[:, :num_tokens]
         else:
             positions = self.positions[:num_tokens]
+        intermediate_tensors = None
+        if not get_pp_group().is_first_rank:
+            intermediate_tensors = self.model.make_empty_intermediate_tensors(
+                batch_size=num_tokens,
+                dtype=self.model_config.dtype,
+                device=self.device)
         with set_forward_context(None, self.vllm_config):
             hidden_states = model(
                 input_ids=input_ids,
                 positions=positions,
                 kv_caches=kv_caches,
                 attn_metadata=None,
+                intermediate_tensors=intermediate_tensors,
                 inputs_embeds=inputs_embeds,
             )
         return hidden_states
@@ -1142,6 +1154,8 @@ def profile_run(self) -> None:
         # Trigger compilation for general shape.
         hidden_states = self._dummy_run(self.max_num_tokens, dummy_kv_caches)
+        if not get_pp_group().is_last_rank:
+            return hidden_states
         hidden_states = hidden_states[logit_indices]
         logits = self.model.compute_logits(hidden_states, None)
         # TODO(woosuk): Consider the memory usage of the sampler.
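
Reviewer note (not part of the patch): the ray_utils.py and gpu_model_runner.py hunks above together define the pipeline-parallel hand-off. On every rank except the last pipeline stage, GPUModelRunner.execute_model returns the model's IntermediateTensors instead of sampler output, and RayWorkerWrapper.execute_model re-packs them with the scheduler output into a tuple so the next stage in the Ray compiled DAG can unpack both. The following standalone Python sketch mirrors that convention with made-up stand-in classes (FakeSchedulerOutput, FakeIntermediateTensors, FakeModelRunnerOutput, StageWorker); it is illustrative only and uses no vLLM APIs.

# Standalone sketch of the stage-to-stage hand-off convention used above.
# All class names here are illustrative stand-ins, not vLLM APIs.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FakeSchedulerOutput:
    num_scheduled_tokens: int


@dataclass
class FakeIntermediateTensors:
    hidden_states: list  # activations handed to the next pipeline stage


@dataclass
class FakeModelRunnerOutput:
    sampled: list


class StageWorker:
    """Mimics the RayWorkerWrapper.execute_model tuple convention."""

    def __init__(self, is_last_rank: bool):
        self.is_last_rank = is_last_rank

    def _run_model(self, sched: FakeSchedulerOutput,
                   inter: Optional[FakeIntermediateTensors]):
        # Stand-in for GPUModelRunner.execute_model: non-last ranks return
        # intermediate tensors instead of sampled tokens.
        prev = (inter.hidden_states if inter is not None else
                [0.0] * sched.num_scheduled_tokens)
        hidden = [h + 1.0 for h in prev]
        if not self.is_last_rank:
            return FakeIntermediateTensors(hidden)
        return FakeModelRunnerOutput(sampled=[int(h) for h in hidden])

    def execute_model(self, inp):
        # Unpack the (scheduler_output, intermediate_tensors) tuple coming
        # from the previous stage; the first stage gets a bare scheduler
        # output.
        if isinstance(inp, tuple):
            sched, inter = inp
        else:
            sched, inter = inp, None
        out = self._run_model(sched, inter)
        # Re-pack so the next stage also sees the scheduler output.
        if isinstance(out, FakeIntermediateTensors):
            return sched, out
        return out


if __name__ == "__main__":
    stages = [StageWorker(is_last_rank=False), StageWorker(is_last_rank=True)]
    msg = FakeSchedulerOutput(num_scheduled_tokens=4)
    for stage in stages:
        msg = stage.execute_model(msg)
    print(msg)  # FakeModelRunnerOutput(sampled=[2, 2, 2, 2])

Keeping the scheduler output inside the tuple means every edge of the compiled DAG carries a single object of a predictable shape; only the final stage emits a ModelRunnerOutput.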
diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py
index 0adb69073397..d7574a7d1f8c 100644
--- a/vllm/v1/worker/gpu_worker.py
+++ b/vllm/v1/worker/gpu_worker.py
@@ -2,7 +2,7 @@
 """A GPU worker class."""
 import gc
 import os
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch.distributed
@@ -195,8 +195,9 @@ def determine_available_memory(self) -> int:
     def get_kv_cache_spec(self) -> KVCacheSpec:
         return self.model_runner.get_kv_cache_spec()
 
-    def initialize_cache(self, kv_cache_config: KVCacheConfig) -> None:
+    def initialize_cache(self, kv_cache_configs: List[KVCacheConfig]) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
+        kv_cache_config = kv_cache_configs[self.rank]
         if self.vllm_config.model_config.enable_sleep_mode:
             allocator = CuMemAllocator.get_instance()
             context = allocator.use_memory_pool(tag="kv_cache")
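
Reviewer note (not part of the patch): the KV-cache sizing path now works per worker. get_kv_cache_specs returns one KVCacheSpec per worker (with pipeline parallelism each rank holds a different subset of layers, so the specs can differ), get_kv_cache_configs divides the available memory by the maximum layer count across all ranks so that every rank arrives at the same num_blocks, EngineCore._initialize_kv_caches asserts that the block counts agree, and each worker then selects kv_cache_configs[self.rank] in initialize_cache. Below is a toy sketch of that sizing rule; the helper name num_blocks_per_rank and the numbers are made up for this note and are not vLLM APIs.

# Toy illustration of the sizing rule in get_kv_cache_configs; the function
# name and the example numbers below are invented for this note.
from typing import List


def num_blocks_per_rank(available_memory: int, page_size: int,
                        layers_per_rank: List[int]) -> List[int]:
    # Divide by the max layer count across ranks so every rank computes the
    # same (conservative) block count, matching the engine-level assertion.
    max_layers = max(layers_per_rank)
    return [available_memory // page_size // max_layers
            for _ in layers_per_rank]


if __name__ == "__main__":
    # e.g. a 30-layer model split 16/14 across two pipeline stages, with
    # 20 GiB usable per GPU and a 2 MiB page per layer per block.
    blocks = num_blocks_per_rank(available_memory=20 * 1024**3,
                                 page_size=2 * 1024**2,
                                 layers_per_rank=[16, 14])
    assert len(set(blocks)) == 1  # same num_gpu_blocks on every rank
    print(blocks)  # [640, 640]

Dividing by the maximum layer count is conservative: a rank that holds fewer layers could in principle fit more blocks, but a uniform block count is exactly what the assertion in EngineCore._initialize_kv_caches requires.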