diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py
index 7f16baa65a64..fcba253d159f 100644
--- a/tests/basic_correctness/test_basic_correctness.py
+++ b/tests/basic_correctness/test_basic_correctness.py
@@ -14,11 +14,12 @@
 from vllm.platforms import current_platform
 from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
 
+from ..conftest import VllmRunner
 from ..models.utils import check_outputs_equal
 from ..utils import multi_gpu_test
 
 MODELS = [
-    "facebook/opt-125m",
+    "google/gemma-2-2b-it",
     "meta-llama/Llama-3.2-1B",
 ]
@@ -42,8 +43,6 @@ def test_vllm_gc_ed():
 @pytest.mark.parametrize("enforce_eager", [False, True])
 def test_models(
     hf_runner,
-    vllm_runner,
-    example_prompts,
     model: str,
     backend: str,
     dtype: str,
@@ -54,15 +53,27 @@ def test_models(
     if backend == "FLASHINFER" and current_platform.is_rocm():
         pytest.skip("Flashinfer does not support ROCm/HIP.")
 
+    if backend == "XFORMERS" and model == "google/gemma-2-2b-it":
+        pytest.skip(
+            "XFORMERS does not support gemma2 with full context length.")
+
     os.environ["VLLM_ATTENTION_BACKEND"] = backend
 
+    # 5042 tokens for gemma2
+    # gemma2 has alternating sliding window size of 4096
+    # we need a prompt with more than 4096 tokens to test the sliding window
+    prompt = "The following numbers of the sequence " + ", ".join(
+        str(i) for i in range(1024)) + " are:"
+    example_prompts = [prompt]
+
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
-    with vllm_runner(model,
-                     dtype=dtype,
-                     enforce_eager=enforce_eager,
-                     gpu_memory_utilization=0.7) as vllm_model:
+    with VllmRunner(model,
+                    max_model_len=8192,
+                    dtype=dtype,
+                    enforce_eager=enforce_eager,
+                    gpu_memory_utilization=0.7) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)
 
     check_outputs_equal(
diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 8acbeaf12b0c..cb4dedf481c7 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -40,18 +40,26 @@ def __init__(
         quant_config: Optional[QuantizationConfig] = None,
         blocksparse_params: Optional[Dict[str, Any]] = None,
         logits_soft_cap: Optional[float] = None,
+        per_layer_sliding_window: Optional[int] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
+        if per_layer_sliding_window is not None:
+            # per-layer sliding window
+            sliding_window = per_layer_sliding_window
+        elif cache_config is not None:
+            # model-level sliding window
+            sliding_window = cache_config.sliding_window
+        else:
+            sliding_window = None
+
         if cache_config is not None:
             kv_cache_dtype = cache_config.cache_dtype
             block_size = cache_config.block_size
-            sliding_window = cache_config.sliding_window
             is_attention_free = cache_config.is_attention_free
         else:
             kv_cache_dtype = "auto"
             block_size = 16
-            sliding_window = None
             is_attention_free = False
         if num_kv_heads is None:
             num_kv_heads = num_heads
diff --git a/vllm/config.py b/vllm/config.py
index bb02c2ad4c7d..730b069e076f 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -233,15 +233,26 @@
             (self.hf_text_config.model_type in ["gemma2"]))
 
         if (not self.disable_sliding_window and has_interleaved_attention):
-            sliding_window_len_min = get_min_sliding_window(
-                self.hf_text_config.sliding_window)
-
-            print_warning_once(
-                f"{self.hf_text_config.model_type} has interleaved attention, "
-                "which is currently not supported by vLLM. Disabling sliding "
-                "window and capping the max length to the sliding window size "
-                f"({sliding_window_len_min}).")
-            self.disable_sliding_window = True
+            if envs.VLLM_ATTENTION_BACKEND == "XFORMERS":
+                sliding_window_len_min = get_min_sliding_window(
+                    self.hf_text_config.sliding_window)
+
+                print_warning_once(
+                    f"{self.hf_text_config.model_type} has interleaved "
+                    "attention, which is currently not supported by the "
+                    "XFORMERS backend. Disabling sliding window and capping "
+                    "the max length to the sliding window size "
+                    f"({sliding_window_len_min}).")
+                self.disable_sliding_window = True
+            else:
+                # for a model with interleaved attention,
+                # the scheduler and the model treat it as full attention
+                # (i.e., not dropping any tokens outside the window).
+                # only the attention layer itself is aware of the sliding
+                # window, and use the window size to compute the attention.
+                self.hf_text_config.interleaved_sliding_window = sliding_window
+                delattr(self.hf_text_config, "sliding_window")
+                sliding_window = None
 
         self.max_model_len = _get_and_verify_max_len(
             hf_config=self.hf_text_config,
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index 839130364ef4..9309cced61bb 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -143,12 +143,12 @@ def __init__(self,
             is_neox_style=True,
         )
 
-        # FIXME(woosuk): While Gemma 2 uses sliding window attention for every
-        # odd layer, vLLM currently ignores it and uses global attention for
-        # all layers.
-        use_sliding_window = (layer_idx % 2 == 1
-                              and config.sliding_window is not None)
-        del use_sliding_window  # Unused.
+        # reference:
+        # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312  # noqa
+        use_sliding_window = (layer_idx % 2 == 0 and
+                              config.interleaved_sliding_window is not None)
+        sliding_window = config.interleaved_sliding_window if \
+            use_sliding_window else None
         self.attn = Attention(self.num_heads,
                               self.head_dim,
                               self.scaling,
@@ -156,6 +156,7 @@ def __init__(self,
                               cache_config=cache_config,
                               quant_config=quant_config,
                               logits_soft_cap=attn_logits_soft_cap,
+                              per_layer_sliding_window=sliding_window,
                               prefix=f"{prefix}.attn")
 
     def forward(
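
Taken together: for interleaved-attention models on non-XFORMERS backends, the config rewrite moves the window size from `sliding_window` to `interleaved_sliding_window`, the scheduler keeps treating the model as full attention, and each Gemma 2 layer decides per layer whether to pass a window to `Attention` via the new `per_layer_sliding_window` argument. A minimal sketch of that per-layer selection follows; the helper name and standalone arguments are illustrative, not part of the patch:

    from typing import Optional

    def per_layer_window(layer_idx: int,
                         interleaved_sliding_window: Optional[int]
                         ) -> Optional[int]:
        # Illustrative only: even-indexed layers use the interleaved window
        # (mirroring the HF reference linked in the gemma2.py hunk); the
        # remaining layers run full attention and pass no per-layer window.
        use_sliding_window = (layer_idx % 2 == 0
                              and interleaved_sliding_window is not None)
        return interleaved_sliding_window if use_sliding_window else None

Inside `Attention`, a non-None `per_layer_sliding_window` takes precedence over `cache_config.sliding_window`, which remains None for these models after the config rewrite.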