diff --git a/.buildkite/check-wheel-size.py b/.buildkite/check-wheel-size.py index 68aff793ae6a..76f6d7aeca0d 100644 --- a/.buildkite/check-wheel-size.py +++ b/.buildkite/check-wheel-size.py @@ -5,11 +5,11 @@ import sys import zipfile -# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB -# Note that we have 400 MiB quota, please use it wisely. -# See https://github.com/pypi/support/issues/3792 . +# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB +# Note that we have 800 MiB quota, please use it wisely. +# See https://github.com/pypi/support/issues/6326 . # Please also sync the value with the one in Dockerfile. -VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400)) +VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450)) def print_top_10_largest_files(zip_file): diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 49defccbb1fa..3d32121f13ac 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -38,7 +38,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f + GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn diff --git a/docker/Dockerfile b/docker/Dockerfile index 75e8fa49f86c..afbf29bac582 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ # Check the size of the wheel if RUN_WHEEL_CHECK is true COPY .buildkite/check-wheel-size.py check-wheel-size.py # sync the default value with .buildkite/check-wheel-size.py -ARG VLLM_MAX_SIZE_MB=400 +ARG VLLM_MAX_SIZE_MB=450 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB ARG RUN_WHEEL_CHECK=true RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \ diff --git a/tests/kernels/attention/test_attention_selector.py b/tests/kernels/attention/test_attention_selector.py index aea166da3af2..3c2aaabacae8 100644 --- a/tests/kernels/attention/test_attention_selector.py +++ b/tests/kernels/attention/test_attention_selector.py @@ -22,7 +22,7 @@ def clear_cache(): # Define MLA and non-MLA backends separately DEVICE_MLA_BACKENDS = { - "cuda": ["TRITON_MLA", "FLASHMLA"], + "cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"], "hip": ["TRITON_MLA", "ROCM_AITER_MLA"], "cpu": [], } @@ -98,21 +98,14 @@ def test_env( with patch("vllm.attention.selector.current_platform", RocmPlatform()): if use_mla: - # Validate HIP MLA backend-block_size combinations - valid_combination = ( - (name == "TRITON_MLA" and block_size != 1) - or (name == "ROCM_AITER_MLA" and block_size == 1)) - - if valid_combination: - backend = get_attn_backend(16, - torch.float16, - torch.float16, - block_size, - False, - use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name - assert backend.get_name() == expected - else: + # ROCm MLA backend logic: + # - TRITON_MLA: supported when block_size != 1 + # - ROCM_AITER_MLA: supported when block_size == 1 + # If backend is forced but doesn't match block_size, + # should raise ValueError + + if name == "TRITON_MLA" and block_size == 1: + # TRITON_MLA doesn't support block_size == 1 with pytest.raises(ValueError) as exc_info: get_attn_backend(16, torch.float16, @@ -122,6 +115,27 @@ def test_env( use_mla=use_mla) assert f"The selected backend, 
{name}" in str( exc_info.value) + elif name == "ROCM_AITER_MLA" and block_size != 1: + # ROCM_AITER_MLA only supports block_size == 1 + with pytest.raises(ValueError) as exc_info: + get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + assert f"The selected backend, {name}" in str( + exc_info.value) + else: + # Valid backend-block_size combination + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected else: backend = get_attn_backend(16, torch.float16, @@ -136,16 +150,57 @@ def test_env( with patch("vllm.attention.selector.current_platform", CudaPlatform()): if use_mla: - if name == "FLASHMLA" and block_size == 64: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) - - # only on cuda platforms with specific capability. - is_supported, _ = is_flashmla_supported() - - if not is_supported: - # if platform is not supported then skip this case. - pytest.skip() + # CUDA MLA backend logic: + # - CUTLASS_MLA: only supported with block_size == 128 + # and Blackwell GPUs (SM 10.0), V1 only + # - FLASHMLA: only supported with block_size == 64 + # - FLASH_ATTN_MLA: V1 only + # - TRITON_MLA: fallback for other cases + + if name == "CUTLASS_MLA": + if not use_v1: + # CUTLASS_MLA only supported on V1 engine + pytest.skip( + "CUTLASS_MLA only supported on V1 engine") + elif block_size != 128: + # CUTLASS_MLA only supports block_size == 128 + pytest.skip( + "CUTLASS_MLA only supports block_size 128") + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = "CUTLASS_MLA_VLLM_V1" + assert backend.get_name() == expected + elif name == "FLASHMLA": + if block_size != 64: + # FlashMLA only supports block_size == 64 + pytest.skip("FlashMLA only supports block_size 64") + else: + from vllm.attention.backends.flashmla import ( + is_flashmla_supported) + is_supported, _ = is_flashmla_supported() + if not is_supported: + pytest.skip( + "FlashMLA not supported on this platform") + else: + backend = get_attn_backend(16, + torch.float16, + torch.float16, + block_size, + False, + use_mla=use_mla) + expected = f"{name}_VLLM_V1" if use_v1 else name + assert backend.get_name() == expected + elif name == "FLASH_ATTN_MLA": + if not use_v1: + # FlashAttention MLA only supported on V1 engine + pytest.skip( + "FlashAttention MLA only supported on V1 engine" + ) else: backend = get_attn_backend(16, torch.float16, @@ -153,9 +208,10 @@ def test_env( block_size, False, use_mla=use_mla) - expected = f"{name}_VLLM_V1" if use_v1 else name + expected = "FLASH_ATTN_MLA" assert backend.get_name() == expected else: + # TRITON_MLA or other fallback backend = get_attn_backend(16, torch.float16, torch.float16, diff --git a/tests/v1/attention/test_attention_backends.py b/tests/v1/attention/test_attention_backends.py index e4c07aae0ebe..1ae8b91c347a 100644 --- a/tests/v1/attention/test_attention_backends.py +++ b/tests/v1/attention/test_attention_backends.py @@ -70,22 +70,6 @@ def _convert_dtype_to_torch(dtype): } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - """Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - 2, # K and V - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.num_kv_heads, - kv_cache_spec.head_size, - 
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( k_contexts: list[torch.Tensor], v_contexts: list[torch.Tensor], diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index 24070358799e..e7cd116fdc83 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -15,7 +15,7 @@ from vllm.v1.kv_cache_interface import FullAttentionSpec BACKENDS_TO_TEST = [ - _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, + _Backend.CUTLASS_MLA, _Backend.FLASHMLA_VLLM_V1, _Backend.FLASH_ATTN_MLA, _Backend.TRITON_MLA_VLLM_V1 ] @@ -69,20 +69,6 @@ def _convert_dtype_to_torch(dtype): } -def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec, - device: torch.device, - num_blocks: int = 100) -> torch.Tensor: - """Create a dummy KV cache tensor for testing.""" - kv_cache = torch.randn( - num_blocks, - kv_cache_spec.block_size, - kv_cache_spec.head_size, # latent dimension - dtype=_convert_dtype_to_torch(kv_cache_spec.dtype), - device=device, - ) - return kv_cache - - def create_and_prepopulate_kv_cache( kv_c_contexts: list[torch.Tensor], k_pe_contexts: list[torch.Tensor], @@ -315,7 +301,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): # 2. Generate data and compute SDPA reference output for MLA all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] - all_sdpa_outputs = [] + all_sdpa_outputs: list[list[torch.Tensor]] = [] kv_c_contexts, k_pe_contexts = [], [] # Create shared MLA weight matrices for consistency across all sequences @@ -331,6 +317,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): device=device) kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + for i, backend in enumerate(BACKENDS_TO_TEST): + all_sdpa_outputs.append([]) + for i in range(batch_size): s_len = seq_lens[i] q_len = query_lens[i] @@ -358,85 +347,93 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): dtype=dtype, device=device) - # Determine if this is decode (single token) - # or prefill (multiple tokens) - is_decode = q_len == 1 + # Determine if this is decode or prefill + is_decode = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + builder_cls, _ = get_attention_backend(backend) + is_decode.append(q_len <= builder_cls.reorder_batch_threshold) # Split q into nope and rope components q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) - if is_decode: - # Decode path: MQA-style attention in latent space - # Transform q_nope to latent space: q_nope @ W_UK - # q_nope: [1, num_heads, qk_nope_head_dim] - # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] - ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, - W_UK) # [1, num_heads, kv_lora_rank] - - # Build MQA attention inputs - # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] - q_mqa = torch.cat([ql_nope, q_pe], dim=-1) - # K: [s_len, kv_lora_rank + qk_rope_head_dim] - # (broadcasted to all heads) - k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) - k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) - # V: [s_len, kv_lora_rank] (broadcasted to all heads) - v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) - - # SDPA expects (N, H, L, D) - q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) - - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, k_sdpa_in, v_sdpa_in, 
is_causal=False, scale=scale) - sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze( - 0) # [1, num_heads, kv_lora_rank] - - # Project back to output space: sdpa_out @ W_UV - sdpa_out_i = torch.einsum("qnl,lnv->qnv", sdpa_out_i, W_UV) - sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) - else: - # Prefill path: MHA-style attention with full sequence - # Apply kv_b_proj to the full kv_c tensor - kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, - kv_b_proj_weight) - k_nope_full, v_full = kv_nope_full.split( - [qk_nope_head_dim, v_head_dim], dim=-1) - - # Build attention inputs for full sequence - q_mha = torch.cat([q_nope, q_pe], - dim=-1) # [q_len, num_heads, total_dim] - k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) - k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) - - # Create custom attention mask: - # - Query tokens can attend to all context tokens - # - Query tokens can only attend to query tokens up to their pos - attn_mask = torch.ones(q_len, - s_len, - dtype=torch.bool, - device=device) - # Apply causal mask only to the query portion (context_len onwards) - causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) - attn_mask[:, context_len:] = causal_mask - - # SDPA expects (N, H, L, D) - q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) - k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) - v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) - - # Single attention call with custom mask - sdpa_out_i = torch.nn.functional.scaled_dot_product_attention( - q_sdpa_in, - k_sdpa_in, - v_sdpa_in, - attn_mask=attn_mask, - scale=scale) - sdpa_out_i = sdpa_out_i.transpose(1, 2).squeeze(0) - sdpa_out_i = sdpa_out_i.flatten(start_dim=-2) - - all_sdpa_outputs.append(sdpa_out_i) + ####################################################### + # Decode path: MQA-style attention in latent space + # Transform q_nope to latent space: q_nope @ W_UK + # q_nope: [1, num_heads, qk_nope_head_dim] + # W_UK: [kv_lora_rank, num_heads, qk_nope_head_dim] + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, + W_UK) # [1, num_heads, kv_lora_rank] + + # Build MQA attention inputs + # Q: [1, num_heads, kv_lora_rank + qk_rope_head_dim] + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + # K: [s_len, kv_lora_rank + qk_rope_head_dim] + # (broadcasted to all heads) + k_mqa = torch.cat([kv_c_full, k_pe_full.squeeze(1)], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_q_heads, -1) + # V: [s_len, kv_lora_rank] (broadcasted to all heads) + v_mqa = kv_c_full.unsqueeze(1).expand(-1, num_q_heads, -1) + + # Create custom attention mask for decode path: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their position + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out_i_decode = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out_i_decode = sdpa_out_i_decode.transpose(1, 2).squeeze( + 0) # [1, num_heads, kv_lora_rank] + + # Project back to output space: sdpa_out @ W_UV + sdpa_out_i_decode = torch.einsum("qnl,lnv->qnv", sdpa_out_i_decode, + W_UV) + sdpa_out_i_decode = 
sdpa_out_i_decode.flatten(start_dim=-2) + + ####################################################### + # Prefill path: MHA-style attention with full sequence + # Apply kv_b_proj to the full kv_c tensor + kv_nope_full = torch.einsum("sl,lnh->snh", kv_c_full, kv_b_proj_weight) + k_nope_full, v_full = kv_nope_full.split( + [qk_nope_head_dim, v_head_dim], dim=-1) + + # Build attention inputs for full sequence + q_mha = torch.cat([q_nope, q_pe], + dim=-1) # [q_len, num_heads, total_dim] + k_pe_full_expanded = k_pe_full.expand(-1, num_q_heads, -1) + k_full = torch.cat([k_nope_full, k_pe_full_expanded], dim=-1) + + # Create custom attention mask: + # - Query tokens can attend to all context tokens + # - Query tokens can only attend to query tokens up to their pos + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + # Apply causal mask only to the query portion (context_len onwards) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, context_len:] = causal_mask + + # SDPA expects (N, H, L, D) + q_sdpa_in = q_mha.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_full.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_full.unsqueeze(0).transpose(1, 2) + + # Single attention call with custom mask + sdpa_out_i_prefill = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out_i_prefill = sdpa_out_i_prefill.transpose(1, 2).squeeze(0) + sdpa_out_i_prefill = sdpa_out_i_prefill.flatten(start_dim=-2) + + for i, backend in enumerate(BACKENDS_TO_TEST): + if is_decode[i]: + all_sdpa_outputs[i].append(sdpa_out_i_decode) + else: + all_sdpa_outputs[i].append(sdpa_out_i_prefill) # Inputs for vLLM MLA backends are just the new tokens all_q_vllm.append(q_c) @@ -451,7 +448,9 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): query_vllm = torch.cat(all_q_vllm, dim=0) kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) - sdpa_output = torch.cat(all_sdpa_outputs, dim=0) + sdpa_outputs = [] + for i, backend in enumerate(BACKENDS_TO_TEST): + sdpa_outputs.append(torch.cat(all_sdpa_outputs[i], dim=0)) # Create mock kv_b_proj using the same weights as reference implementation from vllm.model_executor.layers.linear import ColumnParallelLinear @@ -486,7 +485,7 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): randomize_blocks=True) # 4. 
Run vLLM backends and compare - for backend_name in BACKENDS_TO_TEST: + for i, backend_name in enumerate(BACKENDS_TO_TEST): backend_output = run_attention_backend( backend_name, kv_cache_spec, ["placeholder"], vllm_config, device, common_attn_metadata, query_vllm, kv_c_vllm, k_pe_vllm, kv_cache, @@ -494,12 +493,12 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): mock_kv_b_proj) # Check shape and dtype consistency - assert backend_output.shape == sdpa_output.shape, ( + assert backend_output.shape == sdpa_outputs[i].shape, ( f"[{backend_name}] shape {backend_output.shape} != " - f"SDPA shape {sdpa_output.shape}") - assert backend_output.dtype == sdpa_output.dtype, ( + f"SDPA shape {sdpa_outputs[i].shape}") + assert backend_output.dtype == sdpa_outputs[i].dtype, ( f"[{backend_name}] dtype {backend_output.dtype} != " - f"SDPA dtype {sdpa_output.dtype}") + f"SDPA dtype {sdpa_outputs[i].dtype}") assert torch.isfinite(backend_output).all(), ( f"[{backend_name}] produced non-finite values") @@ -508,12 +507,13 @@ def test_backend_correctness(dist_init, batch_spec_name: str, model: str): rtol = 1e-2 atol = 5e-1 - max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item() + max_diff = torch.max(torch.abs(backend_output - + sdpa_outputs[i])).item() max_rel_diff = torch.max( - torch.abs(backend_output - sdpa_output) / - torch.abs(sdpa_output)).item() + torch.abs(backend_output - sdpa_outputs[i]) / + torch.abs(sdpa_outputs[i])).item() all_close = torch.allclose(backend_output, - sdpa_output, + sdpa_outputs[i], rtol=rtol, atol=atol) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index 6a08cdc56f73..5c49566240df 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -139,6 +139,8 @@ def get_attention_backend(backend_name: _Backend): "vllm.v1.attention.backends.mla.cutlass_mla.CutlassMLABackend", _Backend.FLASHMLA_VLLM_V1: "vllm.v1.attention.backends.mla.flashmla.FlashMLABackend", + _Backend.FLASH_ATTN_MLA: + "vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend", _Backend.TRITON_MLA_VLLM_V1: "vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend", } diff --git a/vllm/attention/utils/fa_utils.py b/vllm/attention/utils/fa_utils.py index f8b00565f051..dc0af7e28e3e 100644 --- a/vllm/attention/utils/fa_utils.py +++ b/vllm/attention/utils/fa_utils.py @@ -68,5 +68,18 @@ def flash_attn_supports_fp8() -> bool: current_platform.get_device_capability().major == 9 +def flash_attn_supports_mla(): + from vllm.platforms import current_platform + if current_platform.is_cuda(): + try: + from vllm.vllm_flash_attn.flash_attn_interface import ( + is_fa_version_supported) + return is_fa_version_supported(3) \ + and current_platform.get_device_capability()[0] == 9 + except (ImportError, AssertionError): + pass + return False + + def is_flash_attn_varlen_func_available() -> bool: return current_platform.is_cuda() or current_platform.is_xpu() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index d4dd545dd43a..71ee90040f37 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -1488,6 +1488,8 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool: "TRITON_MLA", "CUTLASS_MLA", "FLASHMLA", + "FLASHMLA_VLLM_V1", + "FLASH_ATTN_MLA", "FLASHINFER", "FLASHINFER_VLLM_V1", "ROCM_AITER_MLA", diff --git a/vllm/envs.py b/vllm/envs.py index 1232bd7bf963..56adb83e8de1 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -463,6 +463,7 @@ def get_vllm_port() -> Optional[int]: # - "ROCM_FLASH": use 
ROCmFlashAttention # - "FLASHINFER": use flashinfer # - "FLASHMLA": use FlashMLA + # - "FLASH_ATTN_MLA": use FlashAttention for MLA "VLLM_ATTENTION_BACKEND": lambda: os.getenv("VLLM_ATTENTION_BACKEND", None), diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 5cbb7346436e..614ff241fd3d 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -223,9 +223,30 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, if use_mla: # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here - if selected_backend == _Backend.CUTLASS_MLA or ( - cls.is_device_capability(100) and selected_backend is None - and block_size == 128): + + from vllm.attention.ops.flashmla import is_flashmla_supported + from vllm.attention.utils.fa_utils import flash_attn_supports_mla + + use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( + selected_backend is None and cls.is_device_capability(100) + and block_size == 128) + use_flashmla = selected_backend in [ + _Backend.FLASHMLA, _Backend.FLASHMLA_VLLM_V1 + ] or (selected_backend is None and is_flashmla_supported()[0]) + use_flashattn = selected_backend == _Backend.FLASH_ATTN_MLA or ( + selected_backend is None and flash_attn_supports_mla()) + use_triton = selected_backend == _Backend.TRITON_MLA or ( + selected_backend is None) + + def _get_version(name, import_suffix) -> str: + if use_v1: + logger.info_once(f"Using {name} backend on V1 engine.") + return f"vllm.v1.attention.backends.mla.{import_suffix}" + else: + logger.info_once(f"Using {name} backend.") + return f"vllm.attention.backends.{import_suffix}" + + if use_cutlassmla: if use_v1: logger.info_once("Using Cutlass MLA backend on V1 engine.") return ("vllm.v1.attention.backends.mla." @@ -233,36 +254,27 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype, else: logger.warning( "Cutlass MLA backend is only supported on V1 engine") - if selected_backend == _Backend.TRITON_MLA or block_size != 64: - if use_v1: - logger.info_once("Using Triton MLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "triton_mla.TritonMLABackend") - else: - logger.info("Using Triton MLA backend.") - return "vllm.attention.backends.triton_mla.TritonMLABackend" - else: - from vllm.attention.backends.flashmla import ( - is_flashmla_supported) - if not is_flashmla_supported()[0]: - logger.warning( - "FlashMLA backend is not supported due to %s", - is_flashmla_supported()[1]) - elif block_size != 64: + if use_flashmla: + if block_size != 64: logger.warning( "FlashMLA backend is not supported for block size %d" " (currently only supports block size 64).", block_size) else: - if use_v1: - logger.info_once( - "Using FlashMLA backend on V1 engine.") - return ("vllm.v1.attention.backends.mla." - "flashmla.FlashMLABackend") - else: - logger.info("Using FlashMLA backend.") - return ("vllm.attention.backends." - "flashmla.FlashMLABackend") + return _get_version("FlashMLA", "flashmla.FlashMLABackend") + if use_flashattn: + if use_v1: + logger.info_once( + "Using FlashAttention MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla." 
+ "flashattn_mla.FlashAttnMLABackend") + else: + logger.warning( + "FlashAttention MLA backend is only supported on V1 " + "engine.") + if use_triton: + return _get_version("Triton MLA", + "triton_mla.TritonMLABackend") if use_v1: FLASHINFER_V1 = "vllm.v1.attention.backends.flashinfer.FlashInferBackend" # noqa: E501 FLEX_ATTENTION_V1 = "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501 diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index ad12f7f788cf..cb620542b89f 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -52,9 +52,10 @@ class _Backend(enum.Enum): FLASHINFER_VLLM_V1 = enum.auto() TRITON_MLA = enum.auto() # Supported by V1 TRITON_MLA_VLLM_V1 = enum.auto() - FLASHMLA_VLLM_V1 = enum.auto() - FLASHMLA = enum.auto() # Supported by V1 CUTLASS_MLA = enum.auto() + FLASHMLA = enum.auto() # Supported by V1 + FLASHMLA_VLLM_V1 = enum.auto() + FLASH_ATTN_MLA = enum.auto() # Supported by V1 PALLAS = enum.auto() PALLAS_VLLM_V1 = enum.auto() IPEX = enum.auto() diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index 2f275b8b23b1..fc1738579787 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -317,7 +317,8 @@ def build(self, num_reqs = common_attn_metadata.num_reqs num_actual_tokens = common_attn_metadata.num_actual_tokens num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens =\ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) page_size = self.page_size max_q_len = common_attn_metadata.max_query_len diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index f08b6d7f177c..ac0034b5dcf0 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -52,8 +52,9 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) attn_metadata = LinearAttentionMetadata( num_prefills=num_prefills, diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 97a1aa86dda0..7cbfa2c2c9a5 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -50,8 +50,9 @@ def build( query_start_loc.device) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) has_initial_states = None padded_decodes = num_decodes diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index ed30884fdbc9..f3e6cd7430e0 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -115,8 +115,9 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) # Compute seq_idx, 
chunk_indices and chunk_offsets for prefill only if num_prefills > 0: diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 9f93b50b075b..b4c9aae254ea 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -578,11 +578,13 @@ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata): prefill.prefill_main = self._fi_prefill_main prefill.prefill_chunks = self._fi_prefill_chunks - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor): + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor) -> MLACommonDecodeMetadata: return MLACommonDecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, ) def build_for_cudagraph_capture( @@ -618,6 +620,7 @@ def build(self, query_start_loc = common_attn_metadata.query_start_loc query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu seq_lens = common_attn_metadata.seq_lens + seq_lens_cpu = common_attn_metadata.seq_lens_cpu query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] @@ -625,7 +628,8 @@ def build(self, query_seq_lens_cpu) num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \ - split_decodes_and_prefills(common_attn_metadata) + split_decodes_and_prefills(common_attn_metadata, + decode_threshold=self.reorder_batch_threshold) assert num_decodes + num_prefills == num_reqs assert num_decode_tokens + num_prefill_tokens == num_tokens @@ -725,7 +729,10 @@ def build(self, if num_decodes > 0: decode_metadata = self._build_decode( block_table_tensor=block_table_tensor[:num_decodes, ...], - seq_lens=seq_lens[:num_decodes], + seq_lens_cpu=seq_lens_cpu[:num_decodes], + seq_lens_device=seq_lens[:num_decodes], + query_start_loc_cpu=query_start_loc_cpu[:num_decodes + 1], + query_start_loc_device=query_start_loc[:num_decodes + 1], ) attn_metadata = self.metadata_cls( diff --git a/vllm/v1/attention/backends/mla/flashattn_mla.py b/vllm/v1/attention/backends/mla/flashattn_mla.py new file mode 100644 index 000000000000..0e08307ddf84 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashattn_mla.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import (AttentionLayer, AttentionType, + is_quantized_kv_cache) +from vllm.attention.utils.fa_utils import (flash_attn_supports_mla, + get_flash_attn_version) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.v1.attention.backends.mla.common import (MLACommonBackend, + MLACommonDecodeMetadata, + MLACommonImpl, + MLACommonMetadata, + MLACommonMetadataBuilder) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.vllm_flash_attn import flash_attn_varlen_func, get_scheduler_metadata + +logger = init_logger(__name__) + + +class FlashAttnMLABackend(MLACommonBackend): + + @staticmethod + def get_name() -> str: + return "FLASH_ATTN_MLA" + + @staticmethod + def get_metadata_cls() -> type["FlashAttnMLAMetadata"]: + return FlashAttnMLAMetadata + + @staticmethod + def get_builder_cls() -> type["FlashAttnMLAMetadataBuilder"]: + return FlashAttnMLAMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashAttnMLAImpl"]: + 
return FlashAttnMLAImpl + + +@dataclass +class FlashAttnMLADecodeMetadata(MLACommonDecodeMetadata): + query_start_loc: torch.Tensor + max_query_len: int + max_seq_len: int + scheduler_metadata: Optional[torch.Tensor] = None + + +@dataclass +class FlashAttnMLAMetadata(MLACommonMetadata[FlashAttnMLADecodeMetadata]): + pass + + +class FlashAttnMLAMetadataBuilder( + MLACommonMetadataBuilder[FlashAttnMLAMetadata]): + reorder_batch_threshold: ClassVar[int] = 512 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + super().__init__(kv_cache_spec, layer_names, vllm_config, device, + FlashAttnMLAMetadata) + self.fa_aot_schedule = (get_flash_attn_version() == 3) + + def _schedule_decode(self, num_reqs, cu_query_lens, max_query_len, seqlens, + max_seq_len, causal): + if self.fa_aot_schedule: + return get_scheduler_metadata( + batch_size=num_reqs, + max_seqlen_q=max_query_len, + max_seqlen_k=max_seq_len, + num_heads_q=self.num_heads, + num_heads_kv=1, + headdim=self.mla_dims.qk_rope_head_dim, + cache_seqlens=seqlens, + qkv_dtype=self.kv_cache_spec.dtype, + headdim_v=self.mla_dims.kv_lora_rank, + page_size=self.page_size, + cu_seqlens_q=cu_query_lens, + causal=causal, + ) + return None + + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor + ) -> FlashAttnMLADecodeMetadata: + query_lens_cpu = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]) + max_query_len = query_lens_cpu.max().item() + max_seq_len = seq_lens_cpu.max().item() + + scheduler_metadata = self._schedule_decode( + num_reqs=seq_lens_cpu.numel(), + cu_query_lens=query_start_loc_device, + max_query_len=max_query_len, + seqlens=seq_lens_device, + max_seq_len=max_seq_len, + causal=True, + ) + + return FlashAttnMLADecodeMetadata( + block_table=block_table_tensor, + seq_lens=seq_lens_device, + query_start_loc=query_start_loc_device, + max_query_len=max_query_len, + max_seq_len=max_seq_len, + scheduler_metadata=scheduler_metadata, + ) + + +class FlashAttnMLAImpl(MLACommonImpl[FlashAttnMLAMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + + assert flash_attn_supports_mla(), \ + "FlashAttnMLA is not supported on this device" + + unsupported_features = [alibi_slopes, sliding_window, logits_soft_cap] + if any(unsupported_features): + raise NotImplementedError( + "FlashAttnMLAImpl does not support one of the following: " + "alibi_slopes, sliding_window, logits_soft_cap") + + if attn_type != AttentionType.DECODER: + raise NotImplementedError("Encoder self-attention and " + "encoder/decoder cross-attention " + "are not implemented for " + "FlashAttnMLAImpl") + + if is_quantized_kv_cache(self.kv_cache_dtype): + raise NotImplementedError( + "FlashAttnMLA V1 with FP8 KV cache not yet supported") + + def _forward_decode( + self, + q_nope: torch.Tensor, + q_pe: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + attn_metadata: FlashAttnMLAMetadata, + 
layer: AttentionLayer, + ) -> torch.Tensor: + assert kv_c_and_k_pe_cache.numel() > 0 + assert attn_metadata.decode is not None + + if self.kv_cache_dtype.startswith("fp8"): + raise NotImplementedError( + "FP8 FlashAttention MLA not yet supported") + + kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank] + k_pe_cache = kv_c_and_k_pe_cache[..., self.kv_lora_rank:] + + o = flash_attn_varlen_func( + q=q_pe, + k=k_pe_cache.unsqueeze(-2), # Add head dim of 1 + v=kv_c_cache.unsqueeze(-2), # Add head dim of 1 + q_v=q_nope, + max_seqlen_q=attn_metadata.decode.max_query_len, + cu_seqlens_q=attn_metadata.decode.query_start_loc, + max_seqlen_k=attn_metadata.decode.max_seq_len, + seqused_k=attn_metadata.decode.seq_lens, + block_table=attn_metadata.decode.block_table, + softmax_scale=self.scale, + causal=True, + fa_version=3, # only version 3 is supported + scheduler_metadata=attn_metadata.decode.scheduler_metadata, + ) + + return self._v_up_proj(o) diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 1c50144d4790..df617ab7a8ea 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -85,11 +85,13 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], device=self.device, dtype=torch.int32) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> FlashMLADecodeMetadata: + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor) -> FlashMLADecodeMetadata: tile_scheduler_metadata, num_splits = \ get_mla_metadata( - seq_lens, + seq_lens_device, self.num_q_heads, 1, # MQA for the decode path ) @@ -123,7 +125,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, return FlashMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, tile_scheduler_metadata=tile_scheduler_metadata, num_splits=num_splits, ) diff --git a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py index 870cc600388e..42670093daa9 100644 --- a/vllm/v1/attention/backends/mla/rocm_aiter_mla.py +++ b/vllm/v1/attention/backends/mla/rocm_aiter_mla.py @@ -104,12 +104,14 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], dtype=torch.int32, device=device) - def _build_decode(self, block_table_tensor: torch.Tensor, - seq_lens: torch.Tensor) -> AiterMLADecodeMetadata: + def _build_decode( + self, block_table_tensor: torch.Tensor, seq_lens_cpu: torch.Tensor, + seq_lens_device: torch.Tensor, query_start_loc_cpu: torch.Tensor, + query_start_loc_device: torch.Tensor) -> AiterMLADecodeMetadata: page_size = self.kv_cache_spec.block_size - block_table_bounds = (seq_lens + page_size - 1) // page_size + block_table_bounds = (seq_lens_device + page_size - 1) // page_size device = self.device - num_reqs = seq_lens.size(0) + num_reqs = seq_lens_device.size(0) mask = (torch.arange(block_table_tensor.size(1), dtype=block_table_tensor.dtype, @@ -117,7 +119,7 @@ def _build_decode(self, block_table_tensor: torch.Tensor, < block_table_bounds.unsqueeze(1)) paged_kv_indices = block_table_tensor[mask] - paged_kv_last_page_len = seq_lens % page_size + paged_kv_last_page_len = seq_lens_device % page_size paged_kv_last_page_len = torch.where(paged_kv_last_page_len == 0, page_size, paged_kv_last_page_len) @@ -156,7 +158,7 @@ def _build_decode(self, 
block_table_tensor: torch.Tensor, attn_metadata = AiterMLADecodeMetadata( block_table=block_table_tensor, - seq_lens=seq_lens, + seq_lens=seq_lens_device, paged_kv_indptr=paged_kv_indptr, paged_kv_indices=paged_kv_indices, paged_kv_last_page_len=paged_kv_last_page_len, diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index d80ced8ec876..fcbf0c7b5356 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -58,8 +58,9 @@ def build(self, state_indices_tensor = common_attn_metadata.block_table_tensor[:, 0] num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) has_initial_states = None if num_prefills > 0: #[batch,] @@ -78,4 +79,4 @@ def build(self, has_initial_states=has_initial_states, state_indices_tensor=state_indices_tensor, ) - return attn_metadata \ No newline at end of file + return attn_metadata diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index 7f888c113574..c59ff32cf7c2 100644 --- a/vllm/v1/attention/backends/xformers.py +++ b/vllm/v1/attention/backends/xformers.py @@ -3,7 +3,7 @@ """Attention layer with XFormersAttention.""" from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING, ClassVar, Optional import torch @@ -197,6 +197,8 @@ def decode_metadata(self) -> Optional["XFormersAttentionMetadata"]: class XFormersAttentionMetadataBuilder( AttentionMetadataBuilder[XFormersAttentionMetadata]): + reorder_batch_threshold: ClassVar[int] = 1 + def __init__( self, kv_cache_spec: AttentionSpec, @@ -212,9 +214,10 @@ def __init__( def reorder_batch(self, input_batch: "InputBatch", scheduler_output: "SchedulerOutput") -> bool: - return reorder_batch_to_split_decodes_and_prefills(input_batch, - scheduler_output, - decode_threshold=1) + return reorder_batch_to_split_decodes_and_prefills( + input_batch, + scheduler_output, + decode_threshold=self.reorder_batch_threshold) def build( self, @@ -223,8 +226,9 @@ def build( fast_build: bool = False, ) -> XFormersAttentionMetadata: num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = ( - split_decodes_and_prefills(common_attn_metadata, - decode_threshold=1)) + split_decodes_and_prefills( + common_attn_metadata, + decode_threshold=self.reorder_batch_threshold)) num_actual_tokens = common_attn_metadata.num_actual_tokens q_start_loc = common_attn_metadata.query_start_loc
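
Taken together, the refactored `get_attn_backend_cls` branch in `vllm/platforms/cuda.py` resolves MLA backends in a fixed priority order: CUTLASS_MLA (Blackwell, block size 128), then FlashMLA (block size 64), then the new FlashAttention MLA (FA3 on SM 9.0), with Triton MLA as the unconditional fallback. Below is a minimal standalone sketch of that ordering, with the platform checks replaced by plain booleans; `is_blackwell`, `flashmla_ok`, and `fa3_mla_ok` are illustrative parameters standing in for `cls.is_device_capability(100)`, `is_flashmla_supported()[0]`, and `flash_attn_supports_mla()`, not vLLM APIs.

# Illustrative sketch of the MLA backend priority introduced in the cuda.py hunk above.
from typing import Optional


def pick_mla_backend(selected: Optional[str],
                     block_size: int,
                     is_blackwell: bool,
                     flashmla_ok: bool,
                     fa3_mla_ok: bool) -> str:
    """Return the MLA backend name following the priority order in the patch."""
    use_cutlassmla = selected == "CUTLASS_MLA" or (
        selected is None and is_blackwell and block_size == 128)
    use_flashmla = selected in ("FLASHMLA", "FLASHMLA_VLLM_V1") or (
        selected is None and flashmla_ok)
    use_flashattn = selected == "FLASH_ATTN_MLA" or (
        selected is None and fa3_mla_ok)

    if use_cutlassmla:
        return "CUTLASS_MLA"
    if use_flashmla and block_size == 64:  # FlashMLA only supports 64-token blocks
        return "FLASHMLA"
    if use_flashattn:
        return "FLASH_ATTN_MLA"
    return "TRITON_MLA"  # unconditional fallback


# Example: an unforced selection on Hopper (SM 9.0) with 16-token blocks skips
# FlashMLA (wrong block size) and lands on the new FlashAttention MLA backend.
assert pick_mla_backend(None, 16, False, True, True) == "FLASH_ATTN_MLA"

The backend can also be forced explicitly, e.g. by setting `VLLM_ATTENTION_BACKEND=FLASH_ATTN_MLA` in the environment, matching the new option documented in the `vllm/envs.py` hunk above.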
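
The metadata builders above also replace the hard-coded `decode_threshold=1` with the builder's `reorder_batch_threshold` (512 for `FlashAttnMLAMetadataBuilder`), so any request whose new query length is at or below the threshold takes the decode path. A minimal sketch of that split, assuming the batch has already been reordered so decode-sized requests come first; the helper below is illustrative only and is not vLLM's `split_decodes_and_prefills`.

# Illustrative partitioning behind split_decodes_and_prefills(...,
# decode_threshold=reorder_batch_threshold). Assumes decode-sized requests
# already precede prefills in the batch.
def split_by_threshold(query_lens: list[int],
                       decode_threshold: int = 1) -> tuple[int, int, int, int]:
    """Return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)."""
    num_decodes = 0
    for q_len in query_lens:
        if q_len > decode_threshold:
            break  # first prefill found; everything after it is treated as prefill
        num_decodes += 1
    num_decode_tokens = sum(query_lens[:num_decodes])
    num_prefill_tokens = sum(query_lens[num_decodes:])
    return (num_decodes, len(query_lens) - num_decodes,
            num_decode_tokens, num_prefill_tokens)


# With threshold 1 only single-token requests count as decodes; with the
# FlashAttention MLA threshold of 512, short queries (e.g. speculative or
# chunked steps) also take the decode path.
assert split_by_threshold([1, 1, 8, 256], decode_threshold=1) == (2, 2, 2, 264)
assert split_by_threshold([1, 1, 8, 256], decode_threshold=512) == (4, 0, 266, 0)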