
Commit 35a7704

init

Signed-off-by: NickLucche <[email protected]>

1 parent 322cb02 · commit 35a7704

File tree

17 files changed: +74, -30 lines


tests/v1/attention/test_mla_backends.py (1 addition & 1 deletion)

```diff
@@ -61,7 +61,7 @@
 
 BACKEND_BLOCK_SIZES = {}
 for backend in BACKENDS_TO_TEST:
-    supported_sizes = backend.get_class().supported_kernel_block_sizes
+    supported_sizes = backend.get_class().get_supported_kernel_block_size()
     if supported_sizes:
         default_size = supported_sizes[0]
         block_size = (
```

tests/v1/worker/test_gpu_model_runner.py (3 additions & 1 deletion)

```diff
@@ -185,7 +185,9 @@ def _make_mock_backend_for_kernel_block_size(
     supported_sizes: list[int | MultipleOf],
 ):
     class _MockBackend:
-        supported_kernel_block_sizes = supported_sizes
+        @staticmethod
+        def get_supported_kernel_block_size():
+            return supported_sizes
 
     return _MockBackend()
```
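
Since the supported sizes are now exposed through a static method rather than a class attribute, test doubles have to mock a callable. A minimal standalone sketch of that pattern (the `MultipleOf` stub and `make_mock_backend` factory below are illustrative stand-ins, not the actual test module):

```python
# Illustrative stand-in for the mock-backend factory above; MultipleOf is
# stubbed here rather than imported from vLLM.
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


def make_mock_backend(supported_sizes: list[int | MultipleOf]):
    class _MockBackend:
        @staticmethod
        def get_supported_kernel_block_size():
            # The sizes are now returned by a method, mirroring the new
            # AttentionBackend hook, instead of being read from an attribute.
            return supported_sizes

    return _MockBackend()


backend = make_mock_backend([MultipleOf(16), 64])
assert backend.get_supported_kernel_block_size() == [MultipleOf(16), 64]
```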

vllm/attention/backends/abstract.py (7 additions & 3 deletions)

```diff
@@ -46,9 +46,12 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(1)]
     supported_kv_cache_dtypes: ClassVar[list["CacheDType"]] = ["auto"]
 
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        return [MultipleOf(1)]
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:
@@ -115,10 +118,11 @@ def supports_block_size(cls, block_size: int | None) -> bool:
         if block_size not in valid_sizes:
             return False
 
-        if not cls.supported_kernel_block_sizes:
+        supported_kernel_block_sizes = cls.get_supported_kernel_block_size()
+        if not supported_kernel_block_sizes:
             return True
 
-        for supported_size in cls.supported_kernel_block_sizes:
+        for supported_size in supported_kernel_block_sizes:
             if isinstance(supported_size, MultipleOf):
                 supported_size = supported_size.base
             # With hybrid_blocks feature, the framework-level block size
```
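
To make the contract concrete, here is a self-contained sketch of how the new hook is consumed. The `MultipleOf` stub and the simplified `supports_block_size` below are stand-ins, not the vLLM source; the real check in abstract.py additionally handles the hybrid-blocks case shown above. The idea: exact integer entries must match the block size, while `MultipleOf(n)` accepts any multiple of `n`.

```python
# Self-contained sketch of the new hook; these are stand-in classes, not vLLM's.
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


class Backend:
    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        # Default hook: any block size is acceptable.
        return [MultipleOf(1)]

    @classmethod
    def supports_block_size(cls, block_size: int) -> bool:
        supported = cls.get_supported_kernel_block_size()
        if not supported:
            return True
        for size in supported:
            if isinstance(size, MultipleOf):
                if block_size % size.base == 0:
                    return True
            elif block_size == size:
                return True
        return False


class FixedSizeBackend(Backend):
    @staticmethod
    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
        # Override: only these exact kernel block sizes are supported.
        return [16, 32, 64]


assert Backend.supports_block_size(48)             # MultipleOf(1) matches anything
assert FixedSizeBackend.supports_block_size(32)    # exact match
assert not FixedSizeBackend.supports_block_size(48)
```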

vllm/v1/attention/backends/flash_attn.py (13 additions & 6 deletions)

```diff
@@ -32,7 +32,7 @@
     get_scheduler_metadata,
     reshape_and_cache_flash,
 )
-from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.config import VllmConfig, get_current_vllm_config, get_layers_from_vllm_config
 from vllm.config.cache import CacheDType
 from vllm.distributed.parallel_state import get_dcp_group
 from vllm.logger import init_logger
@@ -56,11 +56,18 @@
 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # NOTE(tdoublep): while in principle, FA supports
-    # MultipleOf(16), these are the block sizes that do not
-    # suffer from the NaN propagation problem described here:
-    # https://github.com/Dao-AILab/flash-attention/issues/1974
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
+
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        vllm_config = get_current_vllm_config()
+        model_config = vllm_config.model_config
+        if model_config and model_config.is_hybrid:
+            # NOTE(tdoublep): while in principle, FA supports
+            # MultipleOf(16), these are the block sizes that do not
+            # suffer from the NaN propagation problem described here:
+            # https://github.com/Dao-AILab/flash-attention/issues/1974
+            return [16, 32, 64]
+        return [MultipleOf(16)]
 
     @staticmethod
     def get_name() -> str:
```
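
The behavioral change here is that the answer now depends on the config that is current at call time: hybrid models get the restricted [16, 32, 64] list (avoiding the FA NaN issue linked above), while everything else gets MultipleOf(16). Below is a hedged standalone sketch of that branch, with stub config objects standing in for vLLM's `get_current_vllm_config()`:

```python
# Standalone sketch of the config-dependent override above; ModelConfig and
# the module-level "current config" are stubs, not vLLM's actual objects.
from dataclasses import dataclass


@dataclass(frozen=True)
class MultipleOf:
    base: int


@dataclass
class ModelConfig:
    is_hybrid: bool


_current_model_config: ModelConfig | None = None  # stand-in for ambient config


def get_supported_kernel_block_size() -> list[int | MultipleOf]:
    model_config = _current_model_config
    if model_config and model_config.is_hybrid:
        # Hybrid models: only block sizes known to avoid the NaN propagation
        # problem (https://github.com/Dao-AILab/flash-attention/issues/1974).
        return [16, 32, 64]
    # Otherwise any multiple of 16 is accepted.
    return [MultipleOf(16)]


_current_model_config = ModelConfig(is_hybrid=True)
assert get_supported_kernel_block_size() == [16, 32, 64]

_current_model_config = ModelConfig(is_hybrid=False)
assert get_supported_kernel_block_size() == [MultipleOf(16)]
```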

vllm/v1/attention/backends/flashinfer.py (6 additions & 6 deletions)

```diff
@@ -16,7 +16,6 @@
 from flashinfer.decode import _get_range_buf, trtllm_batch_decode_with_kv_cache
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor
-from typing_extensions import override
 
 from vllm import envs
 from vllm.attention.backends.abstract import (
@@ -275,17 +274,19 @@ def run(
 class FlashInferBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    # Note: Not sure for all platforms,
-    # but on Blackwell, only support a page size of
-    # 16, 32, 64
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [16, 32, 64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
         "fp8_e5m2",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        # Note: Not sure for all platforms, but on Blackwell,
+        # only support a page size of 16, 32, 64.
+        return [16, 32, 64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER"
@@ -558,7 +559,6 @@ def __init__(
         )
 
     @classmethod
-    @override
     def get_cudagraph_support(
         cls: type["FlashInferMetadataBuilder"],
         vllm_config: VllmConfig,
```

vllm/v1/attention/backends/mla/cutlass_mla.py (4 additions & 1 deletion)

```diff
@@ -36,13 +36,16 @@ class CutlassMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
 
 class CutlassMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [128]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        return [128]
+
     @staticmethod
     def get_name() -> str:
         return "CUTLASS_MLA"
```

vllm/v1/attention/backends/mla/flashattn_mla.py (4 additions & 1 deletion)

```diff
@@ -41,9 +41,12 @@
 
 class FlashAttnMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [MultipleOf(16)]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto"]
 
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        return [MultipleOf(16)]
+
     @staticmethod
     def get_name() -> str:
         return "FLASH_ATTN_MLA"
```

vllm/v1/attention/backends/mla/flashinfer_mla.py (4 additions & 1 deletion)

```diff
@@ -35,13 +35,16 @@ class FlashInferMLAMetadataBuilder(MLACommonMetadataBuilder[MLACommonMetadata]):
 
 class FlashInferMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [32, 64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        return [32, 64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHINFER_MLA"
```

vllm/v1/attention/backends/mla/flashmla.py (4 additions & 1 deletion)

```diff
@@ -39,13 +39,16 @@
 
 class FlashMLABackend(MLACommonBackend):
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.float16, torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = [
         "auto",
         "fp8",
         "fp8_e4m3",
     ]
 
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        return [64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHMLA"
```

vllm/v1/attention/backends/mla/flashmla_sparse.py (4 additions & 1 deletion)

```diff
@@ -55,9 +55,12 @@
 class FlashMLASparseBackend(AttentionBackend):
     accept_output_buffer: bool = True
     supported_dtypes: ClassVar[list[torch.dtype]] = [torch.bfloat16]
-    supported_kernel_block_sizes: ClassVar[list[int | MultipleOf]] = [64]
     supported_kv_cache_dtypes: ClassVar[list[CacheDType]] = ["auto", "fp8_ds_mla"]
 
+    @staticmethod
+    def get_supported_kernel_block_size() -> list[int | MultipleOf]:
+        return [64]
+
     @staticmethod
     def get_name() -> str:
         return "FLASHMLA_SPARSE"
```
