@@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR)
 else()
   FetchContent_Declare(
         flashmla
-        GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git
-        GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de
+        GIT_REPOSITORY https://github.com/vllm-project/FlashMLA
+        GIT_TAG 5f65b85703c7ed75fda01e06495077caad207c3f
         GIT_PROGRESS TRUE
         CONFIGURE_COMMAND ""
         BUILD_COMMAND ""
@@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}")
 # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later.
 # Only build FlashMLA kernels if we are building for something compatible with
 # sm90a
-cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}")
-if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
+
+set(SUPPORT_ARCHS)
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3)
+  list(APPEND SUPPORT_ARCHS 9.0a)
+endif()
+if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8)
+  list(APPEND SUPPORT_ARCHS 10.0a)
+endif()
+
+
+cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}")
+if(FLASH_MLA_ARCHS)
+  set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS})
+  list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math")
+
   set(FlashMLA_SOURCES
-    ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp
-    ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu
-    ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu
-    ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu
-    ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu)
+    ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp
+    ${flashmla_SOURCE_DIR}/csrc/pybind.cpp
+    ${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu
+    ${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu
+    ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu
+  )
+
+  set(FlashMLA_Extension_SOURCES
+    ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp
+    ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp
+    ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu
+  )

   set(FlashMLA_INCLUDES
+    ${flashmla_SOURCE_DIR}/csrc
+    ${flashmla_SOURCE_DIR}/csrc/sm90
+    ${flashmla_SOURCE_DIR}/csrc/cutlass/include
+    ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
+  )
+
+  set(FlashMLA_Extension_INCLUDES
+    ${flashmla_SOURCE_DIR}/csrc
+    ${flashmla_SOURCE_DIR}/csrc/sm90
+    ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/
     ${flashmla_SOURCE_DIR}/csrc/cutlass/include
-    ${flashmla_SOURCE_DIR}/csrc)
+    ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include
+  )

   set_gencode_flags_for_srcs(
     SRCS "${FlashMLA_SOURCES}"
     CUDA_ARCHS "${FLASH_MLA_ARCHS}")

+  set_gencode_flags_for_srcs(
+    SRCS "${FlashMLA_Extension_SOURCES}"
+    CUDA_ARCHS "${FLASH_MLA_ARCHS}")
+
   define_gpu_extension_target(
     _flashmla_C
     DESTINATION vllm
@@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS)
     INCLUDE_DIRECTORIES ${FlashMLA_INCLUDES}
     USE_SABI 3
     WITH_SOABI)
+
+  # Keep Stable ABI for the module, but *not* for CUDA/C++ files.
+  # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
+  target_compile_options(_flashmla_C PRIVATE
+    $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
+    $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
+
+  define_gpu_extension_target(
+    _flashmla_extension_C
+    DESTINATION vllm
+    LANGUAGE ${VLLM_GPU_LANG}
+    SOURCES ${FlashMLA_Extension_SOURCES}
+    COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS}
+    ARCHITECTURES ${VLLM_GPU_ARCHES}
+    INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES}
+    USE_SABI 3
+    WITH_SOABI)
+
+  # Keep Stable ABI for the module, but *not* for CUDA/C++ files.
+  # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles.
+  target_compile_options(_flashmla_extension_C PRIVATE
+    $<$<COMPILE_LANGUAGE:CUDA>:-UPy_LIMITED_API>
+    $<$<COMPILE_LANGUAGE:CXX>:-UPy_LIMITED_API>)
 else()
-  # Create an empty target for setup.py when not targeting sm90a systems
+  # Create empty targets for setup.py when not targeting sm90a systems
   add_custom_target(_flashmla_C)
+  add_custom_target(_flashmla_extension_C)
 endif()