Merged

Changes from all commits (49 commits)
5ea3788  sync with amd, support v1 (LucasWilkinson, Feb 28, 2025)
d048631  fix IMA (LucasWilkinson, Mar 4, 2025)
e09841c  bugfix (LucasWilkinson, Mar 4, 2025)
52e7234  working (LucasWilkinson, Mar 4, 2025)
4604201  cleanup (LucasWilkinson, Mar 5, 2025)
f9f3e3e  fa MLA (LucasWilkinson, Mar 5, 2025)
b3b060b  commit wip (LucasWilkinson, Mar 6, 2025)
e286de8  cleanup (LucasWilkinson, Apr 18, 2025)
27a2cd2  fix (LucasWilkinson, Apr 20, 2025)
9165af3  move files (LucasWilkinson, Apr 20, 2025)
d056efd  fix up (LucasWilkinson, Apr 20, 2025)
ac4c624  v0 support + decode threshold (LucasWilkinson, Apr 22, 2025)
1f6bb3d  v0 fix (LucasWilkinson, Apr 22, 2025)
73c8736  fix (LucasWilkinson, Apr 22, 2025)
0f9ed95  fix logs (LucasWilkinson, Apr 22, 2025)
f2dc4a3  don't schedule prefills (LucasWilkinson, Apr 22, 2025)
d695fdc  still default to FlashMLA (LucasWilkinson, Apr 24, 2025)
82c9393  Remove V0 FlashAttention MLA (MatthewBonanni, Aug 21, 2025)
dc16bb5  Move back to original location (MatthewBonanni, Aug 21, 2025)
046af0b  Undo change (MatthewBonanni, Aug 21, 2025)
8a0fe94  Match main (MatthewBonanni, Aug 21, 2025)
9c5445d  Use reorder_batch_threshold throughout (MatthewBonanni, Aug 21, 2025)
6b90fd7  Remove input_positions (MatthewBonanni, Aug 21, 2025)
790bde6  Match main, remove unused arguments (MatthewBonanni, Aug 21, 2025)
161f50e  Align _build_decode signature (MatthewBonanni, Aug 21, 2025)
d87b921  Fix more arguments (MatthewBonanni, Aug 21, 2025)
5a32eeb  More compatibility fixes (MatthewBonanni, Aug 21, 2025)
63ec527  Fix backend enum (MatthewBonanni, Aug 21, 2025)
da96e28  Remove unused helpers (MatthewBonanni, Aug 21, 2025)
fb09124  Rename (MatthewBonanni, Aug 25, 2025)
70343e7  Loosen tolerances for FA MLA backend (MatthewBonanni, Aug 25, 2025)
5daadfe  Fix _forward_decode signature (MatthewBonanni, Aug 25, 2025)
b8e6e0a  Respect each backend's decode threshold (MatthewBonanni, Aug 26, 2025)
91f01d4  Fix backend selection logic (MatthewBonanni, Aug 26, 2025)
4201218  Address pre-commit (MatthewBonanni, Aug 26, 2025)
fe5ba41  Update GIT_TAG (MatthewBonanni, Aug 27, 2025)
6455578  Decode threshold tuning (MatthewBonanni, Aug 27, 2025)
513fdeb  Undo V0 change (MatthewBonanni, Aug 27, 2025)
fd25615  Pass qkv_dtype (MatthewBonanni, Aug 27, 2025)
398e55b  increase wheel size (LucasWilkinson, Aug 28, 2025)
4f29ce1  missing line (LucasWilkinson, Aug 28, 2025)
8672a7f  Fix backend selector logic and test (MatthewBonanni, Aug 28, 2025)
ea0f9c4  Merge remote-tracking branch 'origin/main' into lwilkinson/fa-mla (LucasWilkinson, Aug 29, 2025)
8298b9e  Merge remote-tracking branch 'origin/main' into lwilkinson/fa-mla (LucasWilkinson, Aug 29, 2025)
15c0fed  Merge branch 'main' into lwilkinson/fa-mla (MatthewBonanni, Aug 29, 2025)
98f3592  Merge branch 'main' into lwilkinson/fa-mla (MatthewBonanni, Aug 29, 2025)
6ef55b0  Merge branch 'main' into lwilkinson/fa-mla (MatthewBonanni, Sep 2, 2025)
84737e7  Merge branch 'main' into lwilkinson/fa-mla (MatthewBonanni, Sep 3, 2025)
dd2516a  Merge branch 'main' into lwilkinson/fa-mla (MatthewBonanni, Sep 3, 2025)
8 changes: 4 additions & 4 deletions .buildkite/check-wheel-size.py

@@ -5,11 +5,11 @@
 import sys
 import zipfile
 
-# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 400 MiB
-# Note that we have 400 MiB quota, please use it wisely.
-# See https://github.com/pypi/support/issues/3792 .
+# Read the VLLM_MAX_SIZE_MB environment variable, defaulting to 450 MiB
+# Note that we have 800 MiB quota, please use it wisely.
+# See https://github.com/pypi/support/issues/6326 .
 # Please also sync the value with the one in Dockerfile.
-VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 400))
+VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450))
 
 
 def print_top_10_largest_files(zip_file):
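
For context, this script is the CI gate that keeps the built wheel under the size limit: it reads VLLM_MAX_SIZE_MB from the environment (now defaulting to 450) and reports the largest files inside the wheel when the limit is exceeded. The sketch below is a minimal, self-contained illustration of that kind of check, not the repository's actual script; only VLLM_MAX_SIZE_MB and print_top_10_largest_files appear in the diff, while check_wheel_size and the command-line handling are assumptions.

import os
import sys
import zipfile

# Size limit in MiB, overridable via the environment (default matches the diff).
VLLM_MAX_SIZE_MB = int(os.environ.get("VLLM_MAX_SIZE_MB", 450))


def print_top_10_largest_files(zip_file: str) -> None:
    """Print the ten largest members of the wheel to help debug size regressions."""
    with zipfile.ZipFile(zip_file, "r") as zf:
        entries = sorted(zf.infolist(), key=lambda e: e.file_size, reverse=True)
    for info in entries[:10]:
        print(f"{info.file_size / (1 << 20):8.2f} MiB  {info.filename}")


def check_wheel_size(wheel_path: str) -> int:
    """Return 0 if the wheel is under the limit, 1 otherwise (hypothetical helper)."""
    size_mb = os.path.getsize(wheel_path) / (1 << 20)
    if size_mb > VLLM_MAX_SIZE_MB:
        print(f"{wheel_path}: {size_mb:.1f} MiB exceeds the {VLLM_MAX_SIZE_MB} MiB limit")
        print_top_10_largest_files(wheel_path)
        return 1
    print(f"{wheel_path}: {size_mb:.1f} MiB is within the {VLLM_MAX_SIZE_MB} MiB limit")
    return 0


if __name__ == "__main__":
    sys.exit(check_wheel_size(sys.argv[1]))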
2 changes: 1 addition & 1 deletion cmake/external_projects/vllm_flash_attn.cmake

@@ -38,7 +38,7 @@ else()
   FetchContent_Declare(
         vllm-flash-attn
         GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git
-        GIT_TAG 57b4e68b9f9d94750b46de8f8dbd2bfcc86edd4f
+        GIT_TAG ee4d25bd84e0cbc7e0b9b9685085fd5db2dcb62a
         GIT_PROGRESS TRUE
         # Don't share the vllm-flash-attn build between build types
         BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn
2 changes: 1 addition & 1 deletion docker/Dockerfile

@@ -237,7 +237,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \
 # Check the size of the wheel if RUN_WHEEL_CHECK is true
 COPY .buildkite/check-wheel-size.py check-wheel-size.py
 # sync the default value with .buildkite/check-wheel-size.py
-ARG VLLM_MAX_SIZE_MB=400
+ARG VLLM_MAX_SIZE_MB=450
 ENV VLLM_MAX_SIZE_MB=$VLLM_MAX_SIZE_MB
 ARG RUN_WHEEL_CHECK=true
 RUN if [ "$RUN_WHEEL_CHECK" = "true" ]; then \
110 changes: 83 additions & 27 deletions tests/kernels/attention/test_attention_selector.py

@@ -22,7 +22,7 @@ def clear_cache():

# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
"cuda": ["TRITON_MLA", "FLASHMLA"],
"cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
}
@@ -98,21 +98,14 @@ def test_env(
with patch("vllm.attention.selector.current_platform",
RocmPlatform()):
if use_mla:
# Validate HIP MLA backend-block_size combinations
valid_combination = (
(name == "TRITON_MLA" and block_size != 1)
or (name == "ROCM_AITER_MLA" and block_size == 1))

if valid_combination:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
# ROCm MLA backend logic:
# - TRITON_MLA: supported when block_size != 1
# - ROCM_AITER_MLA: supported when block_size == 1
# If backend is forced but doesn't match block_size,
# should raise ValueError

if name == "TRITON_MLA" and block_size == 1:
# TRITON_MLA doesn't support block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
@@ -122,6 +115,27 @@
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
elif name == "ROCM_AITER_MLA" and block_size != 1:
# ROCM_AITER_MLA only supports block_size == 1
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
assert f"The selected backend, {name}" in str(
exc_info.value)
else:
# Valid backend-block_size combination
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
else:
backend = get_attn_backend(16,
torch.float16,
@@ -136,26 +150,68 @@
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
if use_mla:
if name == "FLASHMLA" and block_size == 64:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)

# only on cuda platforms with specific capability.
is_supported, _ = is_flashmla_supported()

if not is_supported:
# if platform is not supported then skip this case.
pytest.skip()
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases

if name == "CUTLASS_MLA":
if not use_v1:
# CUTLASS_MLA only supported on V1 engine
pytest.skip(
"CUTLASS_MLA only supported on V1 engine")
elif block_size != 128:
# CUTLASS_MLA only supports block_size == 128
pytest.skip(
"CUTLASS_MLA only supports block_size 128")
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = "CUTLASS_MLA_VLLM_V1"
assert backend.get_name() == expected
elif name == "FLASHMLA":
if block_size != 64:
# FlashMLA only supports block_size == 64
pytest.skip("FlashMLA only supports block_size 64")
else:
from vllm.attention.backends.flashmla import (
is_flashmla_supported)
is_supported, _ = is_flashmla_supported()
if not is_supported:
pytest.skip(
"FlashMLA not supported on this platform")
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
assert backend.get_name() == expected
elif name == "FLASH_ATTN_MLA":
if not use_v1:
# FlashAttention MLA only supported on V1 engine
pytest.skip(
"FlashAttention MLA only supported on V1 engine"
)
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
block_size,
False,
use_mla=use_mla)
expected = f"{name}_VLLM_V1" if use_v1 else name
expected = "FLASH_ATTN_MLA"
assert backend.get_name() == expected
else:
# TRITON_MLA or other fallback
backend = get_attn_backend(16,
torch.float16,
torch.float16,
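
Taken together, the updated selector test now encodes a small compatibility matrix for forced MLA backends: ROCM_AITER_MLA requires block_size == 1 while TRITON_MLA on ROCm rejects it, CUTLASS_MLA requires the V1 engine with block_size == 128 (and Blackwell, SM 10.0), FLASHMLA requires block_size == 64 plus a runtime is_flashmla_supported() check, FLASH_ATTN_MLA requires the V1 engine, and TRITON_MLA remains the fallback. The helper below is a condensed sketch of those rules for illustration only; it is not part of vLLM and simply transcribes the constraints stated in the test's comments.

def mla_backend_is_compatible(name: str, block_size: int, use_v1: bool) -> bool:
    """Return True if a forced MLA backend is usable for this configuration."""
    if name == "ROCM_AITER_MLA":
        return block_size == 1                # ROCm: only block_size == 1 is supported
    if name == "TRITON_MLA":
        return block_size != 1                # rule from the ROCm test path; otherwise the fallback
    if name == "CUTLASS_MLA":
        return use_v1 and block_size == 128   # V1 engine only; Blackwell (SM 10.0) assumed
    if name == "FLASHMLA":
        return block_size == 64               # also gated by is_flashmla_supported() at runtime
    if name == "FLASH_ATTN_MLA":
        return use_v1                         # V1 engine only
    return False


# Example: FlashAttention MLA is selectable on V1 regardless of block size,
# while CUTLASS MLA is rejected for any block size other than 128.
assert mla_backend_is_compatible("FLASH_ATTN_MLA", block_size=16, use_v1=True)
assert not mla_backend_is_compatible("CUTLASS_MLA", block_size=64, use_v1=True)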
16 changes: 0 additions & 16 deletions tests/v1/attention/test_attention_backends.py

@@ -70,22 +70,6 @@ def _convert_dtype_to_torch(dtype):
}


def create_dummy_kv_cache(kv_cache_spec: FullAttentionSpec,
device: torch.device,
num_blocks: int = 100) -> torch.Tensor:
"""Create a dummy KV cache tensor for testing."""
kv_cache = torch.randn(
2, # K and V
num_blocks,
kv_cache_spec.block_size,
kv_cache_spec.num_kv_heads,
kv_cache_spec.head_size,
dtype=_convert_dtype_to_torch(kv_cache_spec.dtype),
device=device,
)
return kv_cache


def create_and_prepopulate_kv_cache(
k_contexts: list[torch.Tensor],
v_contexts: list[torch.Tensor],