4 changes: 1 addition & 3 deletions vllm/attention/backends/flashinfer.py
@@ -2,7 +2,6 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import dataclasses
-import os
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
@@ -50,8 +49,7 @@
from vllm.worker.model_runner import (ModelInputForGPUBuilder,
ModelInputForGPUWithSamplingMetadata)

-FLASHINFER_KV_CACHE_LAYOUT: str = os.getenv("FLASHINFER_KV_CACHE_LAYOUT",
-                                             "NHD").upper()
+FLASHINFER_KV_CACHE_LAYOUT: str = envs.VLLM_KV_CACHE_LAYOUT or "NHD"


class FlashInferBackend(AttentionBackend):
9 changes: 5 additions & 4 deletions vllm/distributed/kv_transfer/kv_connector/utils.py
@@ -3,7 +3,6 @@
"""
KV cache helper for store.
"""

import torch

import vllm.envs as envs
@@ -94,15 +93,17 @@ def put_kv_to_cache(self, model_executable: torch.nn.Module, keys, values,


def get_kv_connector_cache_layout():
+    # NOTE (NickLucche) When running disaggregated PD with NIXL, HND layout is
+    # used for faster transfer.
    vllm_config = get_current_vllm_config()
    kv_config = vllm_config.kv_transfer_config
-    if vllm_config.model_config is None:
-        logger.warning("Unable to detect current VLLM config. " \
+    if vllm_config.model_config is None or kv_config is None:
+        logger.warning_once("Unable to detect current VLLM config. " \
            "Defaulting to NHD kv cache layout.")
    else:
        use_mla = vllm_config.model_config.use_mla
        if not use_mla and kv_config.kv_connector == "NixlConnector":
-            logger.info("NixlConnector detected. Setting KV cache " \
+            logger.info_once("NixlConnector detected. Setting KV cache " \
                "layout to HND for better xfer performance.")
            return "HND"
    return "NHD"
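Condensed, the decision above is: HND only when a non-MLA model is served through the NixlConnector, NHD in every other case. A minimal stand-alone sketch of that decision (hypothetical helper name, not vLLM API):

```python
# Hypothetical condensation of get_kv_connector_cache_layout()'s branch logic:
# HND is chosen only for NixlConnector with a non-MLA model.
from typing import Optional


def connector_cache_layout(kv_connector: Optional[str], use_mla: bool) -> str:
    if kv_connector == "NixlConnector" and not use_mla:
        return "HND"
    return "NHD"


assert connector_cache_layout("NixlConnector", use_mla=False) == "HND"
assert connector_cache_layout("NixlConnector", use_mla=True) == "NHD"
assert connector_cache_layout(None, use_mla=False) == "NHD"
```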
11 changes: 11 additions & 0 deletions vllm/envs.py
@@ -128,6 +128,7 @@
VLLM_TOOL_PARSE_REGEX_TIMEOUT_SECONDS: int = 1
VLLM_SLEEP_WHEN_IDLE: bool = False
VLLM_MQ_MAX_CHUNK_BYTES_MB: int = 16
VLLM_KV_CACHE_LAYOUT: Optional[str] = None


def get_default_cache_root():
@@ -879,6 +880,16 @@ def get_vllm_port() -> Optional[int]:
# processes via zmq.
"VLLM_MQ_MAX_CHUNK_BYTES_MB":
lambda: int(os.getenv("VLLM_MQ_MAX_CHUNK_BYTES_MB", "16")),

    # KV Cache layout used throughout vllm.
    # Some common values are:
    # - NHD
    # - HND
    # Where N=number of tokens in a block (block_size), H=num_heads and
    # D=head_size. The default value will leave the layout choice to the
    # backend. Mind that backends may only implement and support a subset of
    # all possible layouts.
    "VLLM_KV_CACHE_LAYOUT":
    lambda: os.getenv("VLLM_KV_CACHE_LAYOUT", None)
}

# --8<-- [end:env-vars-definition]
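Usage sketch for the new variable (not part of this diff): `vllm.envs` resolves these entries lazily, so exporting the variable before vLLM reads it is enough, and an unset variable yields `None`, leaving the choice to the backend. Note that, unlike the removed `FLASHINFER_KV_CACHE_LAYOUT` handling, the lambda does not call `.upper()`, so the value appears to be case-sensitive. From the shell, the equivalent is `VLLM_KV_CACHE_LAYOUT=HND vllm serve <model>`.

```python
# Minimal sketch: selecting the HND layout through the new environment
# variable. When the variable is unset, the accessor returns None and the
# backend keeps its default layout.
import os

os.environ["VLLM_KV_CACHE_LAYOUT"] = "HND"

import vllm.envs as envs

print(envs.VLLM_KV_CACHE_LAYOUT)  # "HND"; None when the variable is not set
```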
12 changes: 5 additions & 7 deletions vllm/v1/attention/backends/flash_attn.py
@@ -16,13 +16,12 @@
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
get_flash_attn_version)
from vllm.config import VllmConfig, get_layers_from_vllm_config
-from vllm.distributed.kv_transfer.kv_connector.utils import (
-    get_kv_connector_cache_layout)
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

@@ -73,16 +72,15 @@ def get_kv_cache_shape(

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
-        # NOTE When running disaggregated PD with NIXL, HND layout is used for
-        # faster transfer. `stride_order` indicates the permutation that gets
+        # `stride_order` indicates the permutation that gets
        # us from `get_kv_cache_shape` to the actual memory layout we want.
-        cache_layout = get_kv_connector_cache_layout()
+        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
-            raise ValueError("Unknown cache layout format %s.", cache_layout)
+            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order


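The two stride orders above map the logical `(num_blocks, 2, block_size, num_heads, head_size)` shape onto the desired physical layout; a short PyTorch sketch with illustrative sizes (not vLLM defaults) makes the HND case concrete:

```python
# Illustrative check of the HND stride order (0, 1, 3, 2, 4): it swaps the
# token (block_size) and head dimensions of the logical KV-cache shape.
import torch

num_blocks, block_size, num_heads, head_size = 4, 16, 8, 128
logical_shape = (num_blocks, 2, block_size, num_heads, head_size)

hnd_order = (0, 1, 3, 2, 4)
physical_shape = tuple(logical_shape[i] for i in hnd_order)
print(physical_shape)  # (4, 2, 8, 16, 128): heads outer, tokens inner

# Allocating in the physical order and permuting back gives a tensor whose
# shape is still the logical one but whose strides follow the HND layout.
kv_cache = torch.empty(physical_shape).permute(*hnd_order)
assert tuple(kv_cache.shape) == logical_shape
```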
27 changes: 21 additions & 6 deletions vllm/v1/attention/backends/flashinfer.py
@@ -19,7 +19,8 @@
from vllm.logger import init_logger
from vllm.v1.attention.backends.flash_attn import use_cascade_attention
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+                                              CommonAttentionMetadata,
+                                              get_kv_cache_layout)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.v1.worker.block_table import BlockTable

@@ -66,6 +67,19 @@ def get_kv_cache_shape(
) -> tuple[int, ...]:
return (num_blocks, 2, block_size, num_kv_heads, head_size)

    @staticmethod
    def get_kv_cache_stride_order() -> tuple[int, ...]:
        # `stride_order` indicates the permutation that gets us from
        # `get_kv_cache_shape` to the actual memory layout we want.
        cache_layout = get_kv_cache_layout()
        if cache_layout == "NHD":
            stride_order = (0, 1, 2, 3, 4)
        elif cache_layout == "HND":
            stride_order = (0, 1, 3, 2, 4)
        else:
            raise ValueError(f"Unknown cache layout format {cache_layout}.")
        return stride_order


@dataclass
class PerLayerParameters:
@@ -290,7 +304,7 @@ def _get_workspace_buffer(self):
def _get_prefill_wrapper(self):
if self._prefill_wrapper is None:
self._prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-self._get_workspace_buffer(), "NHD")
+self._get_workspace_buffer(), get_kv_cache_layout())
return self._prefill_wrapper

def _get_decode_wrapper(self):
@@ -303,14 +317,14 @@ def _get_decode_wrapper(self):
num_qo_heads // num_kv_heads > 4)
self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self._get_workspace_buffer(),
-"NHD",
+get_kv_cache_layout(),
use_tensor_cores=use_tensor_cores)
return self._decode_wrapper

def _get_cascade_wrapper(self):
if self._cascade_wrapper is None:
self._cascade_wrapper = MultiLevelCascadeAttentionWrapper(
-2, self._get_workspace_buffer(), "NHD")
+2, self._get_workspace_buffer(), get_kv_cache_layout())
return self._cascade_wrapper

def _plan(self, attn_metadata: FlashInferMetadata):
@@ -621,6 +635,7 @@ def forward(
num_decode_tokens = attn_metadata.num_decode_tokens
num_prefill_tokens = attn_metadata.num_prefill_tokens

stride_order = FlashInferBackend.get_kv_cache_stride_order()
# Regular attention (common case).
# Decodes are at the front and prefills are at the back,
# according to reorder_batch()
@@ -635,7 +650,7 @@
assert prefill_wrapper._sm_scale == self.scale
prefill_wrapper.run(
prefill_query,
-kv_cache,
+kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[num_decode_tokens:],
@@ -651,7 +666,7 @@
assert decode_wrapper._sm_scale == self.scale
decode_wrapper.run(
decode_query,
-kv_cache,
+kv_cache.permute(*stride_order),
k_scale=layer._k_scale_float,
v_scale=layer._v_scale_float,
out=output[:num_decode_tokens],
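The `kv_cache.permute(*stride_order)` calls above are zero-copy: `permute` only returns a re-ordered view, so the tensor handed to the wrappers can match the layout string they were constructed with without moving data. A stand-alone sketch with illustrative sizes:

```python
# Sketch of why the permute above is cheap: it yields a strided view over the
# same storage, re-ordered to the HND dimension order, with no data movement.
import torch

stride_order = (0, 1, 3, 2, 4)  # HND case of get_kv_cache_stride_order()
kv_cache = torch.empty(4, 2, 16, 8, 128)  # (blocks, 2, tokens, heads, head_size)

hnd_view = kv_cache.permute(*stride_order)
print(tuple(hnd_view.shape))                       # (4, 2, 8, 16, 128)
print(hnd_view.data_ptr() == kv_cache.data_ptr())  # True: same storage, no copy
```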
21 changes: 21 additions & 0 deletions vllm/v1/attention/backends/utils.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import abc
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar
@@ -12,6 +13,13 @@
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch

import vllm.envs as envs
from vllm.distributed.kv_transfer.kv_connector.utils import (
get_kv_connector_cache_layout)
from vllm.logger import init_logger

logger = init_logger(__name__)


@dataclass
class CommonAttentionMetadata:
@@ -119,3 +127,16 @@ def validate_kv_sharing_target(current_layer_name, target_layer_name,
raise ValueError(
error_msg +
f"must be the same type as the current layer ({expected}).")


@functools.lru_cache
def get_kv_cache_layout():
    # Override with format specified by the user.
    cache_layout = envs.VLLM_KV_CACHE_LAYOUT
    if cache_layout is None:
        cache_layout = get_kv_connector_cache_layout()
    else:
        logger.info_once("`VLLM_KV_CACHE_LAYOUT` environment variable " \
            "detected. Setting KV cache layout to %s.", cache_layout)

    return cache_layout
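The precedence implemented here — an explicit `VLLM_KV_CACHE_LAYOUT` wins, otherwise the connector-derived default applies — can be sketched without vLLM imports (the helper names below are stand-ins, not vLLM API). Because of `functools.lru_cache`, the real helper resolves the layout once per process, so changing the environment variable after the first call has no effect.

```python
# Stand-alone sketch of the precedence in get_kv_cache_layout(); the helper
# names here are hypothetical stand-ins for the real vLLM functions.
from functools import lru_cache
from typing import Optional


def connector_default_layout() -> str:
    # Stand-in for get_kv_connector_cache_layout(): "HND" only for a
    # NixlConnector serving a non-MLA model, otherwise "NHD".
    return "NHD"


@lru_cache
def resolve_kv_cache_layout(env_override: Optional[str]) -> str:
    # An explicit user override takes priority; lru_cache mirrors the real
    # helper, caching the resolved value per override.
    if env_override is None:
        return connector_default_layout()
    return env_override


assert resolve_kv_cache_layout(None) == "NHD"
assert resolve_kv_cache_layout("HND") == "HND"
```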