25 changes: 18 additions & 7 deletions tests/basic_correctness/test_basic_correctness.py
@@ -14,11 +14,12 @@
from vllm.platforms import current_platform
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

from ..conftest import VllmRunner
from ..models.utils import check_outputs_equal
from ..utils import multi_gpu_test

MODELS = [
"facebook/opt-125m",
"google/gemma-2-2b-it",
"meta-llama/Llama-3.2-1B",
]

@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
@pytest.mark.parametrize("enforce_eager", [False, True])
def test_models(
hf_runner,
vllm_runner,
example_prompts,
model: str,
backend: str,
dtype: str,
@@ -54,15 +53,27 @@ def test_models(
if backend == "FLASHINFER" and current_platform.is_rocm():
pytest.skip("Flashinfer does not support ROCm/HIP.")

if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
pytest.skip(
"XFORMERS does not support gemma2 with full context length.")

os.environ["VLLM_ATTENTION_BACKEND"] = backend

# 5042 tokens for gemma2
# gemma2 has alternating sliding window size of 4096
# we need a prompt with more than 4096 tokens to test the sliding window
prompt = "The following numbers of the sequence " + ", ".join(
str(i) for i in range(1024)) + " are:"
example_prompts = [prompt]

with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)

with vllm_runner(model,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7) as vllm_model:
with VllmRunner(model,
max_model_len=8192,
dtype=dtype,
enforce_eager=enforce_eager,
gpu_memory_utilization=0.7) as vllm_model:
vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

check_outputs_equal(
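Aside (not part of the diff): the comments in the test claim the constructed prompt is roughly 5042 tokens, comfortably past Gemma 2's 4096-token sliding window. A minimal sketch of how one could sanity-check that assumption with the Hugging Face tokenizer; the model name and the 4096 threshold are taken from the test above, the exact count depends on the tokenizer.

```python
# Rough check that the prompt built in the test exceeds the 4096-token window.
from transformers import AutoTokenizer

prompt = ("The following numbers of the sequence "
          + ", ".join(str(i) for i in range(1024)) + " are:")

tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-2b-it")
num_tokens = len(tokenizer(prompt).input_ids)
print(num_tokens)         # roughly 5000 tokens
assert num_tokens > 4096  # otherwise the sliding window is never exercised
```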
12 changes: 10 additions & 2 deletions vllm/attention/layer.py
@@ -40,18 +40,26 @@ def __init__(
quant_config: Optional[QuantizationConfig] = None,
blocksparse_params: Optional[Dict[str, Any]] = None,
logits_soft_cap: Optional[float] = None,
per_layer_sliding_window: Optional[int] = None,
prefix: str = "",
) -> None:
super().__init__()
if per_layer_sliding_window is not None:
# per-layer sliding window
sliding_window = per_layer_sliding_window
elif cache_config is not None:
# model-level sliding window
sliding_window = cache_config.sliding_window
else:
sliding_window = None

if cache_config is not None:
kv_cache_dtype = cache_config.cache_dtype
block_size = cache_config.block_size
sliding_window = cache_config.sliding_window
is_attention_free = cache_config.is_attention_free
else:
kv_cache_dtype = "auto"
block_size = 16
sliding_window = None
is_attention_free = False
if num_kv_heads is None:
num_kv_heads = num_heads
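For readers skimming the hunk above: the new `per_layer_sliding_window` argument takes precedence over the model-level value from `cache_config`. A minimal sketch of that resolution order, assuming a hypothetical helper name (`_resolve_sliding_window` is not in the PR; the logic lives inline in `Attention.__init__`).

```python
from typing import Optional

from vllm.config import CacheConfig


def _resolve_sliding_window(
        per_layer_sliding_window: Optional[int],
        cache_config: Optional[CacheConfig]) -> Optional[int]:
    """Mirror of the precedence added to Attention.__init__."""
    if per_layer_sliding_window is not None:
        # An explicit per-layer window (e.g. gemma2's sliding-window layers) wins.
        return per_layer_sliding_window
    if cache_config is not None:
        # Otherwise fall back to the model-level window from the cache config.
        return cache_config.sliding_window
    # No cache config at all: no sliding window.
    return None
```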
29 changes: 20 additions & 9 deletions vllm/config.py
@@ -233,15 +233,26 @@ def __init__(
(self.hf_text_config.model_type in ["gemma2"]))

if (not self.disable_sliding_window and has_interleaved_attention):
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)

print_warning_once(
f"{self.hf_text_config.model_type} has interleaved attention, "
"which is currently not supported by vLLM. Disabling sliding "
"window and capping the max length to the sliding window size "
f"({sliding_window_len_min}).")
self.disable_sliding_window = True
if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
sliding_window_len_min = get_min_sliding_window(
self.hf_text_config.sliding_window)

print_warning_once(
f"{self.hf_text_config.model_type} has interleaved "
"attention, which is currently not supported by the "
"XFORMERS backend. Disabling sliding window and capping "
"the max length to the sliding window size "
f"({sliding_window_len_min}).")
self.disable_sliding_window = True
else:
# for a model with interleaved attention,
# the scheduler and the model treat it as full attention
# (i.e., not dropping any tokens outside the window).
# only the attention layer itself is aware of the sliding
# window, and use the window size to compute the attention.
self.hf_text_config.interleaved_sliding_window = sliding_window
delattr(self.hf_text_config, "sliding_window")
sliding_window = None

self.max_model_len = _get_and_verify_max_len(
hf_config=self.hf_text_config,
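The practical effect of this hunk: only the XFORMERS backend still falls back to full attention capped at the window size; other backends keep the full context length and leave the window to the attention layers via `interleaved_sliding_window`. An illustrative usage sketch, assuming `google/gemma-2-2b-it` with its 8192-token context and 4096-token interleaved window as in the test above; it is not code from the PR.

```python
import os

from vllm import LLM

# Non-XFORMERS backends keep gemma2's full context; each attention layer
# applies the 4096-token interleaved window internally.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASH_ATTN"
llm = LLM(model="google/gemma-2-2b-it", max_model_len=8192)

# With XFORMERS the warning above fires instead: the sliding window is
# disabled and the usable length is capped at the window size (4096).
```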
13 changes: 7 additions & 6 deletions vllm/model_executor/models/gemma2.py
@@ -143,19 +143,20 @@ def __init__(self,
is_neox_style=True,
)

# FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
use_sliding_window = (layer_idx % 2 == 1
and config.sliding_window is not None)
del use_sliding_window # Unused.
# reference:
# https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa
use_sliding_window = (layer_idx % 2 == 0 and
config.interleaved_sliding_window is not None)
sliding_window = config.interleaved_sliding_window if \
use_sliding_window else None
self.attn = Attention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
cache_config=cache_config,
quant_config=quant_config,
logits_soft_cap=attn_logits_soft_cap,
per_layer_sliding_window=sliding_window,
prefix=f"{prefix}.attn")

def forward(
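A small illustration of the corrected layer parity (not part of the diff; it assumes a 4096-token interleaved window, as stated in the test comments): even-indexed layers now receive the per-layer window, matching the linked Hugging Face reference, while odd-indexed layers keep global attention. The removed FIXME comment had the parity as odd layers and then ignored the window entirely.

```python
# Which layers get a window under the new parity check.
interleaved_sliding_window = 4096

for layer_idx in range(4):
    use_sliding_window = (layer_idx % 2 == 0
                          and interleaved_sliding_window is not None)
    sliding_window = (interleaved_sliding_window
                      if use_sliding_window else None)
    print(layer_idx, sliding_window)
# 0 4096   <- sliding-window layer
# 1 None   <- global-attention layer
# 2 4096
# 3 None
```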