2 changes: 1 addition & 1 deletion tests/v1/test_utils.py
@@ -3,7 +3,7 @@

 import torch

-from vllm.v1.utils import bind_kv_cache
+from vllm.v1.worker.utils import bind_kv_cache


 def test_bind_kv_cache():
48 changes: 0 additions & 48 deletions vllm/v1/utils.py
@@ -4,7 +4,6 @@
 import multiprocessing
 import time
 import weakref
-from collections import defaultdict
 from collections.abc import Sequence
 from multiprocessing import connection
 from multiprocessing.process import BaseProcess
@@ -14,14 +13,12 @@
 import torch

 from vllm.logger import init_logger
-from vllm.model_executor.models.utils import extract_layer_index
 from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
                                   usage_message)
 from vllm.utils import (get_open_port, get_open_zmq_ipc_path, get_tcp_uri,
                         kill_process_tree)

 if TYPE_CHECKING:
-    from vllm.attention.layer import Attention
     from vllm.v1.engine.coordinator import DPCoordinator
     from vllm.v1.engine.utils import (CoreEngineActorManager,
                                       CoreEngineProcManager)
@@ -275,51 +272,6 @@ def shutdown(procs: list[BaseProcess]):
             kill_process_tree(pid)


-def bind_kv_cache(
-    kv_caches: dict[str, torch.Tensor],
-    forward_context: dict[str, "Attention"],
-    runner_kv_caches: list[torch.Tensor],
-) -> None:
-    """
-    Bind the allocated KV cache to both ModelRunner and forward context so
-    that the KV cache can be used in the forward pass.
-
-    This function:
-      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
-         kv_caches.
-      2) Associates each attention layer in the `forward_context` with its
-         corresponding KV cache in kv_caches.
-
-    Args:
-        kv_caches: The allocated kv_caches with layer names as keys.
-        forward_context: The global forward context containing all Attention
-            layers with layer names as keys.
-        runner_kv_caches: The kv_cache declared by ModelRunner.
-    """
-    # Bind kv_caches to ModelRunner
-    assert len(runner_kv_caches) == 0
-
-    # Convert kv_caches dict to a list of tensors in the order of layer_index.
-    index2name = defaultdict(list)
-    for layer_name in kv_caches:
-        index2name[extract_layer_index(layer_name)].append(layer_name)
-
-    for layer_index in sorted(index2name.keys()):
-        layer_names = index2name[layer_index]
-        if len(layer_names) > 1:
-            # One typical case is encoder-decoder model, e.g., bart.
-            # The cross attention and self attention in the same decoder layer
-            # has different layer_name but the same layer_index.
-            raise NotImplementedError
-        layer_name = layer_names[0]
-        runner_kv_caches.append(kv_caches[layer_name])
-
-    # Bind kv_caches to forward context
-    for layer_name, kv_cache in kv_caches.items():
-        # NOTE: Use list because of v0 PP virtual engine.
-        forward_context[layer_name].kv_cache = [kv_cache]
-
-
 def copy_slice(from_tensor: torch.Tensor, to_tensor: torch.Tensor,
                length: int) -> torch.Tensor:
     """
4 changes: 2 additions & 2 deletions vllm/v1/worker/gpu_model_runner.py
@@ -62,13 +62,13 @@
 from vllm.v1.spec_decode.medusa import MedusaProposer
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 from vllm.v1.spec_decode.ngram_proposer import NgramProposer
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.block_table import BlockTable
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin

 from ..sample.logits_processor import LogitsProcessorManager
-from .utils import (gather_mm_placeholders, initialize_kv_cache_for_kv_sharing,
+from .utils import (bind_kv_cache, gather_mm_placeholders,
+                    initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs, scatter_mm_placeholders)

 if TYPE_CHECKING:
3 changes: 1 addition & 2 deletions vllm/v1/worker/tpu_model_runner.py
@@ -42,11 +42,10 @@
                              LogprobsTensors, ModelRunnerOutput)
 from vllm.v1.sample.tpu.metadata import TPUSupportedSamplingMetadata
 from vllm.v1.sample.tpu.sampler import Sampler as TPUSampler
-from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
 from vllm.v1.worker.tpu_input_batch import CachedRequestState, InputBatch

-from .utils import (initialize_kv_cache_for_kv_sharing,
+from .utils import (bind_kv_cache, initialize_kv_cache_for_kv_sharing,
                     sanity_check_mm_encoder_outputs)

 if TYPE_CHECKING:
3 changes: 2 additions & 1 deletion vllm/v1/worker/tpu_worker.py
@@ -25,8 +25,9 @@
 from vllm.v1.kv_cache_interface import (AttentionSpec, KVCacheConfig,
                                         KVCacheSpec)
 from vllm.v1.outputs import ModelRunnerOutput
-from vllm.v1.utils import bind_kv_cache, report_usage_stats
+from vllm.v1.utils import report_usage_stats
 from vllm.v1.worker.tpu_model_runner import TPUModelRunner
+from vllm.v1.worker.utils import bind_kv_cache

 logger = init_logger(__name__)

52 changes: 51 additions & 1 deletion vllm/v1/worker/utils.py
@@ -1,12 +1,17 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import Optional
+from collections import defaultdict
+from typing import TYPE_CHECKING, Optional

 import torch

 from vllm.model_executor.models.interfaces import MultiModalEmbeddings
+from vllm.model_executor.models.utils import extract_layer_index
 from vllm.v1.kv_cache_interface import KVCacheGroupSpec

+if TYPE_CHECKING:
+    from vllm.attention.layer import Attention
+

 def sanity_check_mm_encoder_outputs(
     mm_embeddings: MultiModalEmbeddings,
@@ -110,3 +115,48 @@ def initialize_kv_cache_for_kv_sharing(
         kv_caches[layer_name] = kv_caches[target_layer_name]
         group_idx = layer_to_kv_cache_group_idx[target_layer_name]
         kv_cache_groups[group_idx].layer_names.append(layer_name)
+
+
+def bind_kv_cache(
+    kv_caches: dict[str, torch.Tensor],
+    forward_context: dict[str, "Attention"],
+    runner_kv_caches: list[torch.Tensor],
+) -> None:
+    """
+    Bind the allocated KV cache to both ModelRunner and forward context so
+    that the KV cache can be used in the forward pass.
+
+    This function:
+      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
+         kv_caches.
+      2) Associates each attention layer in the `forward_context` with its
+         corresponding KV cache in kv_caches.
+
+    Args:
+        kv_caches: The allocated kv_caches with layer names as keys.
+        forward_context: The global forward context containing all Attention
+            layers with layer names as keys.
+        runner_kv_caches: The kv_cache declared by ModelRunner.
+    """
+    # Bind kv_caches to ModelRunner
+    assert len(runner_kv_caches) == 0
+
+    # Convert kv_caches dict to a list of tensors in the order of layer_index.
+    index2name = defaultdict(list)
+    for layer_name in kv_caches:
+        index2name[extract_layer_index(layer_name)].append(layer_name)
+
+    for layer_index in sorted(index2name.keys()):
+        layer_names = index2name[layer_index]
+        if len(layer_names) > 1:
+            # One typical case is encoder-decoder model, e.g., bart.
+            # The cross attention and self attention in the same decoder layer
+            # has different layer_name but the same layer_index.
+            raise NotImplementedError
+        layer_name = layer_names[0]
+        runner_kv_caches.append(kv_caches[layer_name])
+
+    # Bind kv_caches to forward context
+    for layer_name, kv_cache in kv_caches.items():
+        # NOTE: Use list because of v0 PP virtual engine.
+        forward_context[layer_name].kv_cache = [kv_cache]
Comment on lines +120 to +162
Contributor
medium

Consider making runner_kv_caches an optional parameter to avoid unnecessary computation when the caller doesn't need it. This improves the function's cohesion and efficiency.

def bind_kv_cache(
    kv_caches: dict[str, torch.Tensor],
    forward_context: dict[str, "Attention"],
    runner_kv_caches: Optional[list[torch.Tensor]] = None,
) -> None:
    """
    Bind the allocated KV cache to both ModelRunner and forward context so
    that the KV cache can be used in the forward pass.

    This function:
      1) Fills the ModelRunner's kv cache list (`runner_kv_caches`) with
         kv_caches.
      2) Associates each attention layer in the `forward_context` with its
         corresponding KV cache in kv_caches.

    Args:
        kv_caches: The allocated kv_caches with layer names as keys.
        forward_context: The global forward context containing all Attention
        layers with layer names as keys.
        runner_kv_caches: The kv_cache declared by ModelRunner.
    """
    if runner_kv_caches is not None:
        # Bind kv_caches to ModelRunner
        assert len(runner_kv_caches) == 0

        # Convert kv_caches dict to a list of tensors in the order of layer_index.
        index2name = defaultdict(list)
        for layer_name in kv_caches:
            index2name[extract_layer_index(layer_name)].append(layer_name)

        for layer_index in sorted(index2name.keys()):
            layer_names = index2name[layer_index]
            if len(layer_names) > 1:
                # One typical case is encoder-decoder model, e.g., bart.
                # The cross attention and self attention in the same decoder layer
                # has different layer_name but the same layer_index.
                raise NotImplementedError
            layer_name = layer_names[0]
            runner_kv_caches.append(kv_caches[layer_name])

    # Bind kv_caches to forward context
    for layer_name, kv_cache in kv_caches.items():
        # NOTE: Use list because of v0 PP virtual engine.
        forward_context[layer_name].kv_cache = [kv_cache]

Contributor (Author)

This suggested change is out of scope for this PR; we can do it in the next PR.
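
As a side note for readers tracking the move, below is a minimal, hypothetical usage sketch of bind_kv_cache imported from its new location. It is not part of this PR: the layer names, tensor shapes, and the SimpleNamespace objects standing in for real Attention layers are illustrative assumptions only, chosen so that each name contains the single numeric index that extract_layer_index expects.

from types import SimpleNamespace

import torch

from vllm.v1.worker.utils import bind_kv_cache

# Hypothetical KV caches keyed by layer name; shapes are arbitrary here.
kv_caches = {
    "model.layers.0.self_attn.attn": torch.zeros(2, 16, 8, 64),
    "model.layers.1.self_attn.attn": torch.zeros(2, 16, 8, 64),
}

# Stand-ins for the Attention modules held by the forward context.
forward_context = {name: SimpleNamespace(kv_cache=None) for name in kv_caches}

# The runner's per-layer cache list must start out empty.
runner_kv_caches: list[torch.Tensor] = []

bind_kv_cache(kv_caches, forward_context, runner_kv_caches)

# runner_kv_caches is now ordered by layer index, and each layer's
# kv_cache attribute holds a one-element list wrapping its tensor.
assert len(runner_kv_caches) == 2
layer0 = "model.layers.0.self_attn.attn"
assert forward_context[layer0].kv_cache[0] is kv_caches[layer0]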