Merged
230 changes: 230 additions & 0 deletions tests/v1/test_deferred_writer.py
@@ -0,0 +1,230 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import pytest
import torch

from vllm.v1.kv_cache.deferred import DeferredWriteManager, ShouldFallback
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_model_runner import GPUModelRunner


def _make_metadata(draft_token_ids: list[int], per_request: list[int]) -> SpecDecodeMetadata:
total = len(draft_token_ids)
cu = torch.tensor(per_request, dtype=torch.int32)
cu = torch.cumsum(cu, dim=0)
return SpecDecodeMetadata(
draft_token_ids=torch.tensor(draft_token_ids, dtype=torch.int32),
num_draft_tokens=list(per_request),
cu_num_draft_tokens=cu,
target_logits_indices=torch.zeros(total, dtype=torch.int32),
bonus_logits_indices=torch.zeros(len(per_request), dtype=torch.int32),
logits_indices=torch.zeros(total + len(per_request), dtype=torch.int32),
)


def test_deferred_manager_commit_partial_acceptance():
manager = DeferredWriteManager()
assert manager.begin_window([2])

writes: list[tuple[torch.Tensor, torch.Tensor]] = []

def writer(key, value, key_cache, value_cache, slot_mapping, *_):
writes.append((key.clone(), slot_mapping.clone()))

key = torch.arange(4, dtype=torch.float32).view(2, 1, 2)
value = torch.arange(4, dtype=torch.float32).view(2, 1, 2)
slot_mapping = torch.tensor([3, 7], dtype=torch.int32)
key_cache = torch.empty_like(key)
value_cache = torch.empty_like(value)

manager.stage_layer(
layer_id="layer0",
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=slot_mapping,
kv_cache_dtype="fp16",
k_scale=None,
v_scale=None,
writer=writer,
)

mask = torch.tensor([True, False])
manager.commit(mask)

assert len(writes) == 1
committed_key, committed_slots = writes[0]
assert committed_key.shape[0] == 1
assert committed_slots.tolist() == [3]
window_metrics = manager.pop_last_window_metrics()
assert window_metrics == {
"mode": "stage",
"committed": 1,
"rejected": 1,
"fallback": 0,
}


def test_deferred_manager_cancel_flush_writes_all():
manager = DeferredWriteManager()
assert manager.begin_window([1, 1])

writes: list[tuple[str, torch.Tensor]] = []

def writer(key, value, *_args): # pragma: no cover - signature compatibility
writes.append(("commit", key.clone()))

key = torch.randn(1, 1, 2)
value = torch.randn(1, 1, 2)
slot_mapping = torch.tensor([5], dtype=torch.int32)
key_cache = torch.empty_like(key)
value_cache = torch.empty_like(value)

manager.stage_layer(
layer_id="layer0",
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=slot_mapping,
kv_cache_dtype="fp16",
k_scale=None,
v_scale=None,
writer=writer,
)
manager.stage_layer(
layer_id="layer1",
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=slot_mapping,
kv_cache_dtype="fp16",
k_scale=None,
v_scale=None,
writer=writer,
)

manager.cancel_and_flush("test_cancel")
assert len(writes) == 2
assert all(tensor.shape[0] == 1 for _tag, tensor in writes)
window_metrics = manager.pop_last_window_metrics()
assert window_metrics is not None
assert window_metrics.get("fallback") == 1


def test_build_acceptance_mask_matches_expected():
metadata = _make_metadata([10, 11, 20], [2, 1])
sampled = torch.tensor(
[
[10, 99, 0], # second token rejected
[20, 0, 0],
],
dtype=torch.int32,
)

runner = GPUModelRunner.__new__(GPUModelRunner)
mask = runner._build_nwor_acceptance_mask(metadata, sampled)
expected = torch.tensor([True, False, True], dtype=torch.bool)
assert torch.equal(mask.cpu(), expected)


def test_nwor_disabled_env(monkeypatch):
monkeypatch.setenv("VLLM_DISABLE_NWOR", "1")

runner = GPUModelRunner.__new__(GPUModelRunner)
runner.speculative_config = object()
runner._deferred_write_manager = DeferredWriteManager()

metadata = _make_metadata([1, 2], [2])
runner._maybe_begin_nwor_window(metadata)

assert not runner._deferred_write_manager.window_active


def test_fp8_staging_slices_quant_scales():
manager = DeferredWriteManager()
assert manager.begin_window([2])

    recorded: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None]] = []

    def writer(key, value, key_cache, value_cache, slot_mapping,
               kv_cache_dtype, k_scale, v_scale):
        recorded.append((key.clone(), value.clone(), slot_mapping.clone(),
                         k_scale.clone() if k_scale is not None else None))

key = torch.arange(4, dtype=torch.float32).view(2, 1, 2)
value = torch.arange(4, dtype=torch.float32).view(2, 1, 2)
slot_mapping = torch.tensor([3, 7], dtype=torch.int32)
key_cache = torch.empty_like(key, dtype=torch.uint8)
value_cache = torch.empty_like(value, dtype=torch.uint8)
k_scale = torch.tensor([0.5, 0.7], dtype=torch.float32)
v_scale = torch.tensor([0.6, 0.9], dtype=torch.float32)

manager.stage_layer(
layer_id="layer0",
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=slot_mapping,
kv_cache_dtype="fp8",
k_scale=k_scale,
v_scale=v_scale,
writer=writer,
)

manager.commit(torch.tensor([True, False]))

assert len(recorded) == 1
committed_key, committed_value, slots, committed_k_scale = recorded[0]
assert committed_key.shape[0] == 1
assert torch.equal(slots, torch.tensor([3], dtype=torch.int32))
assert committed_k_scale is None or committed_k_scale.shape[0] == 1
window_metrics = manager.pop_last_window_metrics()
assert window_metrics == {
"mode": "stage",
"committed": 1,
"rejected": 1,
"fallback": 0,
}


def test_nwor_immediate_mode_skips_window():
manager = DeferredWriteManager(mode="immediate")
assert not manager.begin_window([2])
assert manager.get_mode() == "immediate"


def test_commit_failure_triggers_fallback_metrics():
manager = DeferredWriteManager()
assert manager.begin_window([1])

key = torch.randn(1, 1, 2)
value = torch.randn(1, 1, 2)
slot_mapping = torch.tensor([0], dtype=torch.int32)
key_cache = torch.empty_like(key)
value_cache = torch.empty_like(value)

def writer(*_args, **_kwargs):
raise RuntimeError("forced failure")

manager.stage_layer(
layer_id="layer0",
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=slot_mapping,
kv_cache_dtype="fp16",
k_scale=None,
v_scale=None,
writer=writer,
)

with pytest.raises(ShouldFallback):
manager.commit(torch.tensor([True]))

window_metrics = manager.pop_last_window_metrics()
assert window_metrics is not None
assert window_metrics.get("fallback") == 1
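
The tests above exercise the full `DeferredWriteManager` surface: opening a window, staging per-layer writes, committing against an acceptance mask, flushing on cancellation, and reading per-window metrics. For orientation, a minimal sketch of that interface as a `typing.Protocol` follows; the method names and keyword arguments are taken directly from the test calls, while the docstrings and return-type details are assumptions rather than the PR's actual definitions.

from collections.abc import Callable
from typing import Any, Protocol

import torch

# Writer callback signature, mirroring reshape_and_cache_flash:
# (key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, k_scale, v_scale)
KVWriter = Callable[..., None]


class DeferredWriteManagerLike(Protocol):
    """Sketch of the interface exercised by tests/v1/test_deferred_writer.py."""

    window_active: bool

    def begin_window(self, num_draft_tokens: list[int]) -> bool:
        """Open a staging window; returns False when the mode is "immediate"."""

    def stage_layer(
        self,
        *,
        layer_id: str,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: torch.Tensor,
        value_cache: torch.Tensor,
        slot_mapping: torch.Tensor,
        kv_cache_dtype: str,
        k_scale: torch.Tensor | None,
        v_scale: torch.Tensor | None,
        writer: KVWriter,
    ) -> None:
        """Record one layer's pending KV-cache write instead of issuing it."""

    def commit(self, accepted_mask: torch.Tensor) -> None:
        """Write only the accepted rows; raises ShouldFallback if a writer fails."""

    def cancel_and_flush(self, reason: str) -> None:
        """Give up on selectivity and write everything that was staged."""

    def get_mode(self) -> str:
        """Return "stage" or "immediate"."""

    def pop_last_window_metrics(self) -> dict[str, Any] | None:
        """Return {"mode", "committed", "rejected", "fallback"} for the last window."""
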
6 changes: 6 additions & 0 deletions vllm/envs.py
@@ -198,6 +198,8 @@
VLLM_ALLREDUCE_USE_SYMM_MEM: bool = True
VLLM_TUNED_CONFIG_FOLDER: str | None = None
VLLM_DISABLE_PAD_FOR_CUDAGRAPH: bool = False
VLLM_DISABLE_NWOR: bool = False
VLLM_NWOR_MODE: str = "stage"
VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
VLLM_NVTX_SCOPES_FOR_PROFILING: bool = False
@@ -1309,6 +1311,10 @@ def get_vllm_port() -> int | None:
"VLLM_DISABLE_PAD_FOR_CUDAGRAPH": lambda: bool(
int(os.getenv("VLLM_DISABLE_PAD_FOR_CUDAGRAPH", "0"))
),
# Disable No-Write-On-Reject staging for speculative decoding if set to 1.
"VLLM_DISABLE_NWOR": lambda: bool(int(os.getenv("VLLM_DISABLE_NWOR", "0"))),
# Select NWOR mode: "stage" (default) or "immediate" to bypass staging.
"VLLM_NWOR_MODE": lambda: os.getenv("VLLM_NWOR_MODE", "stage"),
# Used to force set up loopback IP
"VLLM_LOOPBACK_IP": lambda: os.getenv("VLLM_LOOPBACK_IP", ""),
# Used to set the process name prefix for vLLM processes.
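
A small usage sketch for the two new knobs, assuming the usual lazy resolution `vllm.envs` applies to every entry in this table (values are read from `os.environ` at attribute-access time); the concrete values below are illustrative, not from the PR:

import os

# Hypothetical launch-side configuration; the values mirror the defaults above.
os.environ["VLLM_DISABLE_NWOR"] = "0"   # "1" would disable NWOR staging entirely
os.environ["VLLM_NWOR_MODE"] = "stage"  # or "immediate" to bypass staging

import vllm.envs as envs

assert envs.VLLM_DISABLE_NWOR is False
assert envs.VLLM_NWOR_MODE in ("stage", "immediate")
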
6 changes: 6 additions & 0 deletions vllm/model_executor/models/llama_eagle.py
@@ -121,6 +121,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name not in params_dict:
logger.debug("Skipping unmatched weight %s", name)
break
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
@@ -130,6 +133,9 @@ def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
if get_pp_group().world_size == 1 and "embed_tokens." in name:
continue

if name not in params_dict:
logger.debug("Skipping unmatched weight %s", name)
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
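
Both loading paths in `llama_eagle.py` now tolerate checkpoint tensors that have no destination parameter in the draft model, logging and skipping instead of raising `KeyError` on `params_dict[name]`. A standalone illustration of the same guard, with hypothetical names not taken from the PR:

import logging

import torch

logger = logging.getLogger(__name__)


def copy_matched_weights(
    params_dict: dict[str, torch.nn.Parameter],
    weights: list[tuple[str, torch.Tensor]],
) -> set[str]:
    """Copy checkpoint tensors into parameters, skipping unmatched names."""
    loaded: set[str] = set()
    for name, tensor in weights:
        if name not in params_dict:
            # Same behavior as the diff: log at debug level and move on.
            logger.debug("Skipping unmatched weight %s", name)
            continue
        params_dict[name].data.copy_(tensor)
        loaded.add(name)
    return loaded
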
26 changes: 17 additions & 9 deletions vllm/v1/attention/backends/flash_attn.py
@@ -34,6 +34,7 @@
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.utils import cdiv
from vllm.v1.kv_cache import record_or_write_kv_cache
from vllm.v1.attention.backends.utils import (
AttentionCGSupport,
AttentionMetadataBuilder,
@@ -533,15 +534,22 @@ def forward(
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
layer_id = getattr(
layer,
"layer_name",
getattr(layer, "layer_id", layer.__class__.__name__),
)
record_or_write_kv_cache(
writer=reshape_and_cache_flash,
layer_id=layer_id,
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
)

if self.kv_cache_dtype.startswith("fp8"):
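
This backend, and the two below, now route KV-cache writes through `record_or_write_kv_cache` rather than calling the reshape-and-cache op directly. The helper's implementation is outside this excerpt; the sketch below shows one plausible shape of the dispatch, inferred from the call sites and the `DeferredWriteManager` tests. The module-level `_CURRENT_MANAGER` handle is an assumption; the tests only show the manager stored on the GPU model runner.

from collections.abc import Callable

import torch

# Assumed registration point for the active manager (not shown in the PR).
_CURRENT_MANAGER = None


def record_or_write_kv_cache(
    *,
    writer: Callable[..., None],
    layer_id: str,
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor | None,
    v_scale: torch.Tensor | None,
) -> None:
    """Stage the write inside an active NWOR window, otherwise write immediately."""
    manager = _CURRENT_MANAGER
    if manager is not None and manager.window_active:
        # NWOR window open: defer the write; it is committed (or dropped)
        # once draft-token acceptance is known.
        manager.stage_layer(
            layer_id=layer_id,
            key=key,
            value=value,
            key_cache=key_cache,
            value_cache=value_cache,
            slot_mapping=slot_mapping,
            kv_cache_dtype=kv_cache_dtype,
            k_scale=k_scale,
            v_scale=v_scale,
            writer=writer,
        )
    else:
        # No active window (or NWOR disabled): identical to the old direct call.
        writer(key, value, key_cache, value_cache,
               slot_mapping, kv_cache_dtype, k_scale, v_scale)
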
34 changes: 20 additions & 14 deletions vllm/v1/attention/backends/flashinfer.py
@@ -50,6 +50,7 @@
infer_global_hyperparameters,
split_decodes_and_prefills,
)
from vllm.v1.kv_cache import record_or_write_kv_cache
from vllm.v1.kv_cache_interface import AttentionSpec

FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024
@@ -922,20 +923,25 @@ def forward(
if self.kv_sharing_target_layer_name is None:
# Reshape the input keys and values and store them in the cache.
# Skip this if sharing KV cache with an earlier attention layer.
# NOTE(woosuk): Here, key and value are padded while slot_mapping is
# not padded. However, we don't need to do key[:num_actual_tokens]
# and value[:num_actual_tokens] because the reshape_and_cache_flash
# op uses the slot_mapping's shape to determine the number of
# actual tokens.
torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
kv_cache[:, 0],
kv_cache[:, 1],
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
# NOTE(woosuk): key/value are padded while slot_mapping is not.
key_cache = kv_cache[:, 0]
value_cache = kv_cache[:, 1]
layer_id = getattr(
layer,
"layer_name",
getattr(layer, "layer_id", layer.__class__.__name__),
)
record_or_write_kv_cache(
writer=torch.ops._C_cache_ops.reshape_and_cache_flash,
layer_id=layer_id,
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
)

# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
26 changes: 17 additions & 9 deletions vllm/v1/attention/backends/flex_attention.py
@@ -33,6 +33,7 @@
AttentionMetadataBuilder,
CommonAttentionMetadata,
)
from vllm.v1.kv_cache import record_or_write_kv_cache
from vllm.v1.kv_cache_interface import AttentionSpec

logger = init_logger(__name__)
@@ -810,15 +811,22 @@ def forward(
assert self.attn_type == AttentionType.DECODER
key_cache, value_cache = kv_cache.unbind(0)

torch.ops._C_cache_ops.reshape_and_cache_flash(
key,
value,
key_cache,
value_cache,
attn_metadata.slot_mapping,
self.kv_cache_dtype,
layer._k_scale,
layer._v_scale,
layer_id = getattr(
layer,
"layer_name",
getattr(layer, "layer_id", layer.__class__.__name__),
)
record_or_write_kv_cache(
writer=torch.ops._C_cache_ops.reshape_and_cache_flash,
layer_id=layer_id,
key=key,
value=value,
key_cache=key_cache,
value_cache=value_cache,
slot_mapping=attn_metadata.slot_mapping,
kv_cache_dtype=self.kv_cache_dtype,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
)

# View out the block_size dim