Commit 4d19b7b

Subclass attn metadata for cross-decoder layers to propagate logits_indices

Signed-off-by: Yong Hoon Shin <[email protected]>

1 parent 94df2f1 commit 4d19b7b
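The idea behind propagating logits_indices to the cross-decoder layers: in a YOCO-style KV-sharing setup the cross-decoder layers read KV caches written by the self-decoder layers, so during prefill they only need to produce useful hidden states for the positions that will actually be sampled. A minimal, illustrative sketch of that truncation (names, shapes, and the layer interface are assumptions for illustration, not vLLM's implementation):

import torch

def truncated_prefill_sketch(hidden_states: torch.Tensor,
                             logits_indices: torch.Tensor,
                             cross_decoder_layers) -> torch.Tensor:
    # hidden_states: [num_tokens, hidden_size] coming out of the self-decoder
    # layers, whose KV caches the cross-decoder layers reuse.
    # Only rows that will be sampled need to flow through the cross-decoder
    # layers, so gather them first.
    truncated = hidden_states[logits_indices]  # [num_logits, hidden_size]
    for layer in cross_decoder_layers:
        # each layer attends against the shared (already written) KV cache
        truncated = layer(truncated)
    # Scatter the results back so downstream code still sees the full shape;
    # positions that are never sampled keep their old values.
    out = hidden_states.clone()
    out[logits_indices] = truncated
    return out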

File tree

9 files changed: +264 additions, -56 deletions

Lines changed: 130 additions & 0 deletions
@@ -0,0 +1,130 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import gc
+import random
+from typing import Optional, Union
+
+import pytest
+import torch
+
+from vllm import LLM, SamplingParams
+from vllm.config import CompilationConfig, CompilationLevel
+from vllm.forward_context import get_forward_context
+from vllm.model_executor.models.gemma3n import Gemma3nForConditionalGeneration
+from vllm.model_executor.models.registry import ModelRegistry
+from vllm.sequence import IntermediateTensors
+
+from ...utils import fork_new_process_for_each_test
+
+
+class TestGemma3nForConditionalGeneration(Gemma3nForConditionalGeneration):
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ) -> Union[torch.Tensor, IntermediateTensors]:
+        hidden_states = self.model(input_ids, positions, intermediate_tensors,
+                                   inputs_embeds, **kwargs)
+        attn_metadata = get_forward_context().attn_metadata
+        # attn_metadata is None during dummy runs
+        if attn_metadata is not None:
+            assert isinstance(attn_metadata, dict)  # true in V1
+            # Layer 20 is a cross-decoder layer in YOCO
+            layer_attn_metadata = attn_metadata[
+                'model.language_model.layers.20.self_attn.attn']
+            if hasattr(layer_attn_metadata, 'logits_indices_padded'):
+                # This field is only set when
+                # enable_kv_sharing_truncated_prefill is set to True
+                assert self.cache_config.enable_kv_sharing_truncated_prefill
+                logits_indices_padded = (
+                    layer_attn_metadata.logits_indices_padded)
+                assert logits_indices_padded is not None
+                num_logits_indices = layer_attn_metadata.num_logits_indices
+                assert num_logits_indices > 0
+
+                logits_hs = hidden_states[logits_indices_padded]
+                hidden_states = torch.randn_like(hidden_states)
+                gen_indices = logits_indices_padded[:num_logits_indices]
+                # Only set logits for logits_indices to valid values
+                hidden_states[gen_indices] = logits_hs[:num_logits_indices]
+
+        return hidden_states
+
+
+@pytest.fixture
+def test_prompts():
+    """
+    Adapted from tests/v1/e2e/test_spec_decode.py
+    """
+    prompt_types = ["repeat", "sentence"]
+    # Setting higher num prompts increases the chance of numerics mismatch
+    # due to matrix multiplication numerics depending on batch dimension
+    num_prompts = 10
+    prompts = []
+
+    random.seed(0)
+    random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
+
+    for kind in random_prompt_type_choices:
+        word_choices = ["test", "temp", "hello", "where"]
+        word = random.choice(word_choices)
+        if kind == "repeat":
+            prompt = f"""please repeat the word '{word}' 10 times."""
+        elif kind == "sentence":
+            prompt = f"""please give a ten-word sentence that
+            uses the word {word} at least once."""
+        else:
+            raise ValueError(f"Unknown prompt type: {kind}")
+        prompts.append(prompt)
+
+    return prompts
+
+
+@fork_new_process_for_each_test
+@pytest.mark.parametrize("enforce_eager", [True, False])
+def test_kv_sharing_truncated_prefill(
+    monkeypatch: pytest.MonkeyPatch,
+    enforce_eager: bool,
+    test_prompts: list[str],
+):
+    ModelRegistry.register_model("Gemma3nForConditionalGeneration",
+                                 TestGemma3nForConditionalGeneration)
+    sampling_params = SamplingParams(temperature=0.0, max_tokens=100)
+    compilation_config = CompilationConfig(
+        # This allows vLLM compilation backend to handle allocating and
+        # managing buffers for cudagraph
+        cudagraph_copy_inputs=True,
+        level=CompilationLevel.PIECEWISE
+        if not enforce_eager else CompilationLevel.NO_COMPILATION)
+
+    with monkeypatch.context() as m:
+        m.setenv("VLLM_USE_V1", "1")
+
+        llm = LLM(
+            model="google/gemma-3n-E2B-it",
+            enforce_eager=enforce_eager,
+            compilation_config=compilation_config,
+        )
+        ref_responses = llm.generate(test_prompts, sampling_params)
+
+        del llm
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        llm = LLM(model="google/gemma-3n-E2B-it",
+                  enforce_eager=enforce_eager,
+                  compilation_config=compilation_config,
+                  enable_kv_sharing_truncated_prefill=True)
+        optimized_responses = llm.generate(test_prompts, sampling_params)
+
+        misses = 0
+
+        for ref_response, optimized_response in zip(ref_responses,
+                                                    optimized_responses):
+            if (ref_response.outputs[0].text !=
+                    optimized_response.outputs[0].text):
+                misses += 1
+
+        assert misses == 0

vllm/config.py

Lines changed: 4 additions & 0 deletions
@@ -1684,6 +1684,10 @@ class CacheConfig:
     num_cpu_blocks: Optional[int] = field(default=None, init=False)
     """The number of blocks to allocate for CPU memory."""

+    enable_kv_sharing_truncated_prefill: bool = False
+    """Skip prefill for tokens where applicable in YOCO-like KV-sharing
+    setups (e.g. Gemma3n)"""
+
     def compute_hash(self) -> str:
         """
         WARNING: Whenever a new field is added to this config,

vllm/engine/arg_utils.py

Lines changed: 8 additions & 0 deletions
@@ -438,6 +438,9 @@ class EngineArgs:
     # DEPRECATED
     enable_prompt_adapter: bool = False

+    enable_kv_sharing_truncated_prefill: bool = \
+        CacheConfig.enable_kv_sharing_truncated_prefill
+
     def __post_init__(self):
         # support `EngineArgs(compilation_config={...})`
         # without having to manually construct a
@@ -686,6 +689,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                                  **cache_kwargs["cpu_offload_gb"])
         cache_group.add_argument("--calculate-kv-scales",
                                  **cache_kwargs["calculate_kv_scales"])
+        cache_group.add_argument(
+            "--enable-kv-sharing-truncated-prefill",
+            **cache_kwargs["enable_kv_sharing_truncated_prefill"])

         # Multimodal related configs
         multimodal_kwargs = get_kwargs(MultiModalConfig)
@@ -1056,6 +1062,8 @@ def create_engine_config(
             prefix_caching_hash_algo=self.prefix_caching_hash_algo,
             cpu_offload_gb=self.cpu_offload_gb,
             calculate_kv_scales=self.calculate_kv_scales,
+            enable_kv_sharing_truncated_prefill=self.
+            enable_kv_sharing_truncated_prefill,
         )

         # Get the current placement group if Ray is initialized and
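With the flag wired through EngineArgs as above (and through the LLM constructor in the next file), enabling it should look roughly like this; a hedged usage sketch where the model name is only an example:

# CLI sketch, assuming the flag registered above:
#   vllm serve google/gemma-3n-E2B-it --enable-kv-sharing-truncated-prefill

# Python API sketch, mirroring the keyword argument added to LLM below:
from vllm import LLM

llm = LLM(model="google/gemma-3n-E2B-it",
          enable_kv_sharing_truncated_prefill=True)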

vllm/entrypoints/llm.py

Lines changed: 3 additions & 0 deletions
@@ -193,6 +193,7 @@ def __init__(
         override_pooler_config: Optional[PoolerConfig] = None,
         compilation_config: Optional[Union[int, dict[str, Any],
                                            CompilationConfig]] = None,
+        enable_kv_sharing_truncated_prefill: bool = False,
         **kwargs,
     ) -> None:
         """LLM constructor."""
@@ -266,6 +267,8 @@ def __init__(
             mm_processor_kwargs=mm_processor_kwargs,
             override_pooler_config=override_pooler_config,
             compilation_config=compilation_config_instance,
+            enable_kv_sharing_truncated_prefill=\
+                enable_kv_sharing_truncated_prefill,
             **kwargs,
         )

vllm/envs.py

Lines changed: 0 additions & 5 deletions
@@ -143,7 +143,6 @@
     VLLM_USE_CUDNN_PREFILL: bool = False
     VLLM_ENABLE_CUDAGRAPH_GC: bool = False
     VLLM_LOOPBACK_IP: str = ""
-    VLLM_COMPUTE_PADDED_LOGITS_INDICES: bool = False


 def get_default_cache_root():
@@ -992,10 +991,6 @@ def get_vllm_port() -> Optional[int]:
     # The default value is "VLLM".
     "VLLM_PROCESS_NAME_PREFIX":
     lambda: os.getenv("VLLM_PROCESS_NAME_PREFIX", "VLLM"),
-
-    # Enable computing and propagating cudagraph padded logits indices
-    "VLLM_COMPUTE_PADDED_LOGITS_INDICES":
-    lambda: bool(int(os.getenv("VLLM_COMPUTE_PADDED_LOGITS_INDICES", "0"))),
 }

 # --8<-- [end:env-vars-definition]

vllm/forward_context.py

Lines changed: 0 additions & 3 deletions
@@ -95,7 +95,6 @@ class ForwardContext:
     # set dynamically for each forward pass
     dp_metadata: Optional[DPMetadata] = None
     skip_cuda_graphs: bool = False
-    logits_indices_padded: Optional[torch.Tensor] = None


 _forward_context: Optional[ForwardContext] = None
@@ -117,7 +116,6 @@ def set_forward_context(
     num_tokens: Optional[int] = None,
     num_tokens_across_dp: Optional[torch.Tensor] = None,
     skip_cuda_graphs: bool = False,
-    logits_indices_padded: Optional[torch.Tensor] = None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -143,7 +141,6 @@ def set_forward_context(
         attn_metadata=attn_metadata,
         dp_metadata=dp_metadata,
         skip_cuda_graphs=skip_cuda_graphs,
-        logits_indices_padded=logits_indices_padded,
     )

     try:

vllm/model_executor/models/gemma3n.py

Lines changed: 21 additions & 1 deletion
@@ -581,6 +581,15 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             lambda prefix: Gemma3nDecoderLayer(
                 config, cache_config, quant_config, prefix=prefix),
             prefix=f"{prefix}.layers")
+
+        first_kv_shared_layer_idx = (config.num_hidden_layers -
+                                     config.num_kv_shared_layers)
+        # Layer idx 0-19 are self-decoder layers in You Only Cache Once (YOCO)
+        self.self_decoder_layers = self.layers[:first_kv_shared_layer_idx]
+        # Layer idx 20-34 are cross-decoder layers in YOCO
+        # Refer to YOCO paper https://arxiv.org/abs/2405.05254
+        self.cross_decoder_layers = self.layers[first_kv_shared_layer_idx:]
+
         self.norm = RMSNorm(
             config.hidden_size,
             eps=config.rms_norm_eps,
@@ -646,7 +655,17 @@ def forward(
             hidden_states = torch.stack(hidden_states, dim=0)

         # Transformer blocks.
-        for layer_idx, layer in enumerate(self.layers):
+        for layer_idx, layer in enumerate(self.self_decoder_layers):
+            # [altup_num_inputs, num_tokens, hidden_size]
+            hidden_states = layer(
+                positions=positions,
+                hidden_states=hidden_states,
+                per_layer_input=per_layer_inputs[:, layer_idx, :],
+                **kwargs,
+            )
+
+        for layer_idx, layer in enumerate(self.cross_decoder_layers,
+                                          start=len(self.self_decoder_layers)):
             # [altup_num_inputs, num_tokens, hidden_size]
             hidden_states = layer(
                 positions=positions,
@@ -771,6 +790,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         del lora_config  # Unused.
         super().__init__()
         self.config = config
+        self.cache_config = vllm_config.cache_config
         self.model = Gemma3nModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
         self.logits_processor = LogitsProcessor(
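For concreteness, the comments above put the self-decoder at layer indices 0-19 and the cross-decoder at 20-34, which implies num_hidden_layers = 35 and num_kv_shared_layers = 15 for this model. A small sanity check of the index arithmetic (the config values are inferred from those comments, not read from the checkpoint):

# Illustrative only: values inferred from the comments in the diff above.
num_hidden_layers = 35     # assumed total decoder layers in Gemma3n
num_kv_shared_layers = 15  # assumed KV-shared (cross-decoder) layers

first_kv_shared_layer_idx = num_hidden_layers - num_kv_shared_layers
assert first_kv_shared_layer_idx == 20

layers = list(range(num_hidden_layers))
self_decoder_layers = layers[:first_kv_shared_layer_idx]   # indices 0-19
cross_decoder_layers = layers[first_kv_shared_layer_idx:]  # indices 20-34
assert cross_decoder_layers[0] == 20 and cross_decoder_layers[-1] == 34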

vllm/v1/attention/backends/utils.py

Lines changed: 15 additions & 2 deletions
@@ -3,8 +3,8 @@
 import abc
 import functools
 from abc import abstractmethod
-from dataclasses import dataclass
-from typing import TYPE_CHECKING, ClassVar, Generic, Optional, TypeVar
+from dataclasses import dataclass, make_dataclass
+from typing import TYPE_CHECKING, Any, ClassVar, Generic, Optional, TypeVar

 import numpy as np
 import torch
@@ -501,3 +501,16 @@ def reorder_batch_to_split_decodes_and_prefills(
             modified_batch = True

     return modified_batch
+
+
+def subclass_attention_metadata(
+    name_prefix: str,
+    metadata_cls: Any,
+    fields: list[tuple[str, Any, Any]],
+) -> Any:
+    """
+    Return a new subclass of `metadata_cls` with additional fields
+    """
+    name: str = name_prefix + metadata_cls.__name__  # type: ignore
+    Wrapped = make_dataclass(name, fields, bases=(metadata_cls, ))
+    return Wrapped
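This helper appears to be what the commit title refers to: the attention metadata for cross-decoder layers can be wrapped in a dynamically created dataclass subclass that also carries the logits-index fields. A minimal sketch of how it might be used (the field names match the test above; the base metadata class and the values are illustrative stand-ins, not vLLM's real backend metadata):

from dataclasses import dataclass
from typing import Optional

import torch

from vllm.v1.attention.backends.utils import subclass_attention_metadata


@dataclass
class DummyAttentionMetadata:  # stand-in for a real backend's metadata class
    num_actual_tokens: int


# Build a subclass that can additionally carry the truncated-prefill fields.
TruncatedPrefillMetadata = subclass_attention_metadata(
    name_prefix="TruncatedPrefill",
    metadata_cls=DummyAttentionMetadata,
    fields=[
        ("logits_indices_padded", Optional[torch.Tensor], None),
        ("num_logits_indices", int, 0),
    ],
)

meta = TruncatedPrefillMetadata(
    num_actual_tokens=8,
    logits_indices_padded=torch.tensor([3, 7, 7, 7]),  # padded for cudagraph
    num_logits_indices=2,
)
assert isinstance(meta, DummyAttentionMetadata)
assert hasattr(meta, "logits_indices_padded")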

0 commit comments