Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
96 commits
Select commit Hold shift + click to select a range
12f85fb
init dev branch
heheda12345 Sep 19, 2025
1486830
add indexer module
youkaichao Sep 20, 2025
ee3271e
fix fp8 weight loading
youkaichao Sep 20, 2025
3f4154d
fix key
youkaichao Sep 20, 2025
991b94f
basic test
youkaichao Sep 20, 2025
00c455c
add indexer cache (#12)
heheda12345 Sep 20, 2025
ddaf933
setup sparse attention backend
heheda12345 Sep 21, 2025
aff9596
build sparse
LucasWilkinson Sep 21, 2025
22d0fe5
pass in selected index
heheda12345 Sep 21, 2025
3b9df19
make basic.py runable
heheda12345 Sep 21, 2025
f85564f
small fix
heheda12345 Sep 21, 2025
fe45b06
reduce api change
heheda12345 Sep 22, 2025
216c42f
revert
heheda12345 Sep 22, 2025
446c0de
Merge pull request #14 from vllm-model-0920/mla_backend
robertgshaw2-redhat Sep 22, 2025
840f205
format
LucasWilkinson Sep 22, 2025
fa13a8b
Merge pull request #13 from vllm-model-0920/lwilkinson/build-sparse-f…
LucasWilkinson Sep 22, 2025
0f54ca6
deepgemm integration
yewentao256 Sep 22, 2025
93eade0
fix clean logic
yewentao256 Sep 22, 2025
1e304d8
Merge pull request #20 from vllm-model-0920/wye-deepgemm-integration
yewentao256 Sep 22, 2025
0eba9f1
sparse decode and make prefill and decode both use MQA (#16)
LucasWilkinson Sep 23, 2025
c0c0624
Adding pytorch impl for Paged Indexer (#9)
zyongye Sep 23, 2025
6a29a01
support mtp with indexer kv (#21)
luccafong Sep 23, 2025
9ca6434
FlashMLA prefill kernel integration (#17)
heheda12345 Sep 23, 2025
75d382e
fix indexer bs>1 (#23)
heheda12345 Sep 23, 2025
9905f9d
fix build
LucasWilkinson Sep 23, 2025
23e809c
fix import (#24)
heheda12345 Sep 23, 2025
e19d0c9
enable sparse by default
heheda12345 Sep 23, 2025
bff5944
fix mla
NickLucche Sep 24, 2025
d689f18
Merge pull request #28 from vllm-model-0920/fix
mgoin Sep 24, 2025
b3a44bd
fix unify kv cache spec
heheda12345 Sep 25, 2025
08bd3e3
Merge pull request #30 from vllm-model-0920/fix-unify-kv-cache
NickLucche Sep 25, 2025
c81e3f7
Fix paged_mqa_logits clear True
yewentao256 Sep 25, 2025
70ec108
Merge pull request #34 from vllm-model-0920/wentao-fix-deepgemm-integ…
yewentao256 Sep 25, 2025
87104b5
Separate indexer prefill and decode and use different kernel (#26)
zyongye Sep 25, 2025
e2dcd85
update prefill indexer unittest
zyongye Sep 26, 2025
d7f80ed
paged_indexer_unit test
zyongye Sep 26, 2025
b7e4b60
remove unnecessary bias in wq_b and wk layer, accuracy is greatly imp…
zyongye Sep 27, 2025
9bb302e
Support piecewise cuda graph (#42)
heheda12345 Sep 27, 2025
065e9c4
set max buffer size (#45)
zyongye Sep 27, 2025
5fc3571
fix indexer + mtp (#43)
luccafong Sep 27, 2025
a95229b
indexer ref code cleanup (#47)
zyongye Sep 28, 2025
ff5eb40
fix non spec decode error (#48)
luccafong Sep 28, 2025
10e6d47
fix test
yewentao256 Sep 28, 2025
ad28f46
Merge pull request #50 from vllm-model-0920/wentao-fix-test
yewentao256 Sep 28, 2025
6853a0e
Revert "[Bug] Fix test for Blackwell"
yewentao256 Sep 28, 2025
9e4ec68
Merge pull request #51 from vllm-model-0920/revert-50-wentao-fix-test
yewentao256 Sep 28, 2025
ed9e42c
fix test
yewentao256 Sep 28, 2025
7829646
Merge pull request #52 from vllm-model-0920/wentao-fix-test
yewentao256 Sep 28, 2025
656ab3c
fix num sms (#53)
yewentao256 Sep 28, 2025
e744e06
FP8 Cache by using decode kernel only (#40)
LucasWilkinson Sep 29, 2025
53df680
Preliminary blackwell enablement (#54)
mxz297 Sep 29, 2025
fd63ddc
Move the logic of determining padding amount to class __init__
mxz297 Sep 29, 2025
224f1dd
Add indexer_k_quant_and_cache_kernel (#38)
Barry-Delaney Sep 29, 2025
5a07bc5
Merge pull request #56 from vllm-model-0920/mxz297/small-padding-fix
youkaichao Sep 29, 2025
b5ef289
remove tilelang dep (#57)
zyongye Sep 29, 2025
0e12bdb
default to fp8
LucasWilkinson Sep 29, 2025
82f0fa5
fix up prints
LucasWilkinson Sep 29, 2025
ee3edfa
Full-CG Support (#46)
LucasWilkinson Sep 29, 2025
d710dc8
reverse last commit of insert kernel (#60)
zyongye Sep 29, 2025
03fcecf
Merge pull request #59 from vllm-model-0920/lwilkinson/default-fp8
youkaichao Sep 29, 2025
b64779a
Gather cache. (#61)
zyongye Sep 29, 2025
80d834c
fix basic.py (#63)
heheda12345 Sep 29, 2025
aeee929
fix flashmla
heheda12345 Sep 29, 2025
f142654
fix unpack kernel (#64)
luccafong Sep 29, 2025
b215ed8
Merge branch 'dev' of github.com:vllm-model-0920/vllm-dsv32 into dev
heheda12345 Sep 29, 2025
093b0c0
partial configs
zyongye Sep 29, 2025
3e530a5
fix blackwell
heheda12345 Sep 29, 2025
88ef733
update config
youkaichao Sep 29, 2025
e35f98a
Merge remote-tracking branch 'upstream/main' into dsv32-base
zyongye Sep 29, 2025
98e0a0f
small fix
zyongye Sep 29, 2025
3683a69
update to support 12.8
LucasWilkinson Sep 29, 2025
b7de53e
format
zyongye Sep 29, 2025
53a3b94
fixing pre-commit
zyongye Sep 29, 2025
69fcaa2
fixing pre-commit
zyongye Sep 29, 2025
1dfc501
delete envs
zyongye Sep 29, 2025
66ebc85
fix basic.py
zyongye Sep 29, 2025
684658d
fix pre-commit
LucasWilkinson Sep 29, 2025
ae30e22
pre-commit
zyongye Sep 29, 2025
148f43a
fix more pre-commit
zyongye Sep 29, 2025
9033b4e
fix mtp config (#1)
luccafong Sep 29, 2025
1fd2cef
add tilelang kernel and skip if not installed
zyongye Sep 29, 2025
cd77644
[ci fix] DeepseekV2DecoderLayer.topk_indices_buffer
heheda12345 Sep 30, 2025
e1da843
Merge branch 'dsv32-base' of https://github.com/zyongye/vllm into dsv…
heheda12345 Sep 30, 2025
01e46c3
[ci fix] models/test_registry.py::test_registry_imports[DeepseekV32Fo…
heheda12345 Sep 30, 2025
1419ff1
[ci fix] AttentionSpec.use_mla related
heheda12345 Sep 30, 2025
8fbefb4
address review comment
zyongye Sep 30, 2025
abaa8cc
[ci] skip if not on sm90+, add vllm_config on longcat model
zyongye Sep 30, 2025
de7f7cb
[ci fix] test_can_initialize_large_subset[DeepseekV32ForCausalLM]
heheda12345 Sep 30, 2025
24fc3e7
Update vllm/transformers_utils/config.py
youkaichao Sep 30, 2025
9b1b762
Update vllm/v1/attention/backends/mla/flashmla_sparse.py
youkaichao Sep 30, 2025
eb5d331
rm files
youkaichao Sep 30, 2025
d9693e8
rm files
youkaichao Sep 30, 2025
39d9d0e
fix spacing
youkaichao Sep 30, 2025
c80dfd5
add type for return value
youkaichao Sep 30, 2025
07be34b
add type for return value
youkaichao Sep 30, 2025
a0264c7
fix for amd
youkaichao Sep 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 76 additions & 11 deletions cmake/external_projects/flashmla.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
else()
FetchContent_Declare(
flashmla
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
GIT_PROGRESS TRUE
CONFIGURE_COMMAND ""
BUILD_COMMAND ""
Expand All @@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
# The FlashMLA kernels require Hopper (CUDA 12.3 or later) or
# Blackwell (CUDA 12.8 or later). Only build FlashMLA kernels if we are
# building for something compatible with sm90a or sm100a
cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)

set(SUPPORT_ARCHS)
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
list(APPEND SUPPORT_ARCHS 9.0a)
endif()
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
list(APPEND SUPPORT_ARCHS 10.0a)
endif()


cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
if(FLASH_MLA_ARCHS)
set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")

set(FlashMLA_SOURCES
${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
)

set(FlashMLA_Extension_SOURCES
${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
)

set(FlashMLA_INCLUDES
${flashmla_SOURCE_DIR}/csrc
${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)

set(FlashMLA_Extension_INCLUDES
${flashmla_SOURCE_DIR}/csrc
${flashmla_SOURCE_DIR}/csrc/sm90
${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
${flashmla_SOURCE_DIR}/csrc/cutlass/include
${flashmla_SOURCE_DIR}/csrc)
${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
)

set_gencode_flags_for_srcs(
SRCS "${FlashMLA_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}")

set_gencode_flags_for_srcs(
SRCS "${FlashMLA_Extension_SOURCES}"
CUDA_ARCHS "${FLASH_MLA_ARCHS}")

define_gpu_extension_target(
_flashmla_C
DESTINATION vllm
Expand All @@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
USE_SABI 3
WITH_SOABI)

# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
target_compile_options(_flashmla_C PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)

define_gpu_extension_target(
_flashmla_extension_C
DESTINATION vllm
LANGUAGE ${VLLM_GPU_LANG}
SOURCES ${FlashMLA_Extension_SOURCES}
COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
ARCHITECTURES ${VLLM_GPU_ARCHES}
INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
USE_SABI 3
WITH_SOABI)

# Keep Stable ABI for the module, but *not* for CUDA/C++ files.
# This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
target_compile_options(_flashmla_extension_C PRIVATE
$<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
$<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
else()
# Create an empty target for setup.py when not targeting sm90a systems
# Create empty targets for setup.py when no supported arch (sm90a/sm100a) is targeted
add_custom_target(_flashmla_C)
add_custom_target(_flashmla_extension_C)
endif()

8 changes: 8 additions & 0 deletions csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,11 @@ void cp_gather_cache(
torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES]
torch::Tensor const& cu_seq_lens, // [BATCH+1]
int64_t batch_size, std::optional<torch::Tensor> seq_starts = std::nullopt);

// Indexer K quantization and cache function
void indexer_k_quant_and_cache(
torch::Tensor& k, // [num_tokens, head_dim]
torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride]
torch::Tensor& slot_mapping, // [num_tokens]
int64_t quant_block_size, // quantization block size
const std::string& scale_fmt);
Loading