diff --git a/cmake/external_projects/flashmla.cmake b/cmake/external_projects/flashmla.cmake index 02224cfe3ee8..0636f25ec13c 100644 --- a/cmake/external_projects/flashmla.cmake +++ b/cmake/external_projects/flashmla.cmake @@ -18,8 +18,8 @@ if(FLASH_MLA_SRC_DIR) else() FetchContent_Declare( flashmla - GIT_REPOSITORY https://github.com/vllm-project/FlashMLA.git - GIT_TAG a757314c04eedd166e329e846c820eb1bdd702de + GIT_REPOSITORY https://github.com/vllm-project/FlashMLA + GIT_TAG e350c2d2e42f069ced7ceee68804f224553899ac GIT_PROGRESS TRUE CONFIGURE_COMMAND "" BUILD_COMMAND "" @@ -33,23 +33,64 @@ message(STATUS "FlashMLA is available at ${flashmla_SOURCE_DIR}") # The FlashMLA kernels only work on hopper and require CUDA 12.3 or later. # Only build FlashMLA kernels if we are building for something compatible with # sm90a -cuda_archs_loose_intersection(FLASH_MLA_ARCHS "9.0a" "${CUDA_ARCHS}") -if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) + +set(SUPPORT_ARCHS) +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3) + list(APPEND SUPPORT_ARCHS 9.0a) +endif() +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.8) + list(APPEND SUPPORT_ARCHS 10.0a) +endif() + + +cuda_archs_loose_intersection(FLASH_MLA_ARCHS "${SUPPORT_ARCHS}" "${CUDA_ARCHS}") +if(FLASH_MLA_ARCHS) + set(VLLM_FLASHMLA_GPU_FLAGS ${VLLM_GPU_FLAGS}) + list(APPEND VLLM_FLASHMLA_GPU_FLAGS "--expt-relaxed-constexpr" "--expt-extended-lambda" "--use_fast_math") + set(FlashMLA_SOURCES - ${flashmla_SOURCE_DIR}/csrc/flash_api.cpp - ${flashmla_SOURCE_DIR}/csrc/kernels/get_mla_metadata.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/mla_combine.cu - ${flashmla_SOURCE_DIR}/csrc/kernels/splitkv_mla.cu - ${flashmla_SOURCE_DIR}/csrc/kernels_fp8/flash_fwd_mla_fp8_sm90.cu) + ${flashmla_SOURCE_DIR}/csrc/torch_api.cpp + ${flashmla_SOURCE_DIR}/csrc/pybind.cpp + ${flashmla_SOURCE_DIR}/csrc/smxx/get_mla_metadata.cu + ${flashmla_SOURCE_DIR}/csrc/smxx/mla_combine.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/dense/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/decode/sparse_fp8/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm90/prefill/sparse/fwd.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/decode/sparse_fp8/splitkv_mla.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_fwd_sm100.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/dense/fmha_cutlass_bwd_sm100.cu + ${flashmla_SOURCE_DIR}/csrc/sm100/prefill/sparse/fwd.cu + ) + + set(FlashMLA_Extension_SOURCES + ${flashmla_SOURCE_DIR}/csrc/extension/torch_api.cpp + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/pybind.cpp + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/flash_fwd_mla_fp8_sm90.cu + ) set(FlashMLA_INCLUDES + ${flashmla_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/sm90 + ${flashmla_SOURCE_DIR}/csrc/cutlass/include + ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include + ) + + set(FlashMLA_Extension_INCLUDES + ${flashmla_SOURCE_DIR}/csrc + ${flashmla_SOURCE_DIR}/csrc/sm90 + ${flashmla_SOURCE_DIR}/csrc/extension/sm90/dense_fp8/ ${flashmla_SOURCE_DIR}/csrc/cutlass/include - ${flashmla_SOURCE_DIR}/csrc) + ${flashmla_SOURCE_DIR}/csrc/cutlass/tools/util/include + ) set_gencode_flags_for_srcs( SRCS "${FlashMLA_SOURCES}" CUDA_ARCHS "${FLASH_MLA_ARCHS}") + set_gencode_flags_for_srcs( + SRCS "${FlashMLA_Extension_SOURCES}" + CUDA_ARCHS "${FLASH_MLA_ARCHS}") + define_gpu_extension_target( _flashmla_C DESTINATION vllm @@ -60,8 +101,32 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.3 AND FLASH_MLA_ARCHS) INCLUDE_DIRECTORIES 
${FlashMLA_INCLUDES} USE_SABI 3 WITH_SOABI) + + # Keep Stable ABI for the module, but *not* for CUDA/C++ files. + # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles. + target_compile_options(_flashmla_C PRIVATE + $<$:-UPy_LIMITED_API> + $<$:-UPy_LIMITED_API>) + + define_gpu_extension_target( + _flashmla_extension_C + DESTINATION vllm + LANGUAGE ${VLLM_GPU_LANG} + SOURCES ${FlashMLA_Extension_SOURCES} + COMPILE_FLAGS ${VLLM_FLASHMLA_GPU_FLAGS} + ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${FlashMLA_Extension_INCLUDES} + USE_SABI 3 + WITH_SOABI) + + # Keep Stable ABI for the module, but *not* for CUDA/C++ files. + # This prevents Py_LIMITED_API from affecting nvcc and C++ compiles. + target_compile_options(_flashmla_extension_C PRIVATE + $<$:-UPy_LIMITED_API> + $<$:-UPy_LIMITED_API>) else() - # Create an empty target for setup.py when not targeting sm90a systems + # Create empty targets for setup.py when not targeting sm90a systems add_custom_target(_flashmla_C) + add_custom_target(_flashmla_extension_C) endif() diff --git a/csrc/cache.h b/csrc/cache.h index fd230bec27fc..3a4fc92a6c25 100644 --- a/csrc/cache.h +++ b/csrc/cache.h @@ -56,3 +56,11 @@ void cp_gather_cache( torch::Tensor const& block_table, // [BATCH, BLOCK_INDICES] torch::Tensor const& cu_seq_lens, // [BATCH+1] int64_t batch_size, std::optional seq_starts = std::nullopt); + +// Indexer K quantization and cache function +void indexer_k_quant_and_cache( + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt); diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 80b4c47c5547..b9fb1b680806 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -16,6 +16,7 @@ #include #include +#include // FLT_MIN #include #include @@ -396,6 +397,160 @@ __global__ void concat_and_cache_mla_kernel( copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank); } +template +__global__ void concat_and_cache_ds_mla_kernel( + const scalar_t* __restrict__ kv_c, // [num_tokens, kv_lora_rank] + const scalar_t* __restrict__ k_pe, // [num_tokens, pe_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, (kv_lora_rank + // + pe_dim)] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, // + const int entry_stride, // + const int kv_c_stride, // + const int k_pe_stride, // + const int kv_lora_rank, // + const int pe_dim, // + const int block_size, // + const float* scale // +) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int64_t dst_idx_start = + block_idx * block_stride + block_offset * entry_stride; + + // Create 4 tile scales in shared memory + __shared__ float smem[20]; + float* shard_abs_max = smem; + float* tile_scales = smem + 16; + + // For the NoPE part, each tile of 128 elements is handled by 4 warps + // (128 threads). There are 4 total tiles, so 16 warps (512 threads). + // The first thread of the first warp in each tile writes the scale + // value for the tile. The RoPE part (last 64 elements) is handled + // by another 2 warps (64 threads). + // So in total, we use 18 warps (576 threads) per block. 
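+  // Each fp8_ds_mla cache entry is 656 bytes: 512 bytes of fp8 NoPE values,
+  // four float32 tile scales (16 bytes), and 64 16-bit RoPE values (128 bytes).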
+ + // Cast kv_cache to 16_bit for RoPE values + scalar_t* kv_cache_16bit = + reinterpret_cast(&kv_cache[dst_idx_start]); + + // The last 64 threads handle the RoPE part + if (threadIdx.x >= kv_lora_rank) { + const int8_t pe_idx = threadIdx.x - kv_lora_rank; + const int64_t src_idx = token_idx * k_pe_stride + pe_idx; + // RoPE values start after the packed 8-bit NoPE values and the + // 32-bit scales + const int64_t dst_idx = kv_lora_rank / 2 + 8 + pe_idx; + kv_cache_16bit[dst_idx] = k_pe[src_idx]; + return; + } + + // Determine the scale for each chunk of NoPE + const int16_t tile_idx = threadIdx.x >> 7; + const int16_t warp_idx = (threadIdx.x & 127) >> 5; + const int16_t lane_idx = threadIdx.x & 31; + + // Load the NoPE element for this thread into registers + const int64_t src_idx = token_idx * kv_c_stride + threadIdx.x; + const scalar_t src_val = kv_c[src_idx]; + + // Warp-level reduction to find the max absolute value in the warp + float max_abs = fabsf(src_val); +#pragma unroll + for (int offset = 16; offset > 0; offset /= 2) { + max_abs = fmaxf(max_abs, __shfl_down_sync(0xFFFFFFFF, max_abs, offset)); + } + + // The first lane of each warp in each tile writes the max_abs of this part + // of the tile to shared memory + if (lane_idx == 0) { + shard_abs_max[tile_idx * 4 + warp_idx] = max_abs; + } + __syncthreads(); + + // The first lane of the first warp in each tile computes the scale for the + // tile and writes it to shared memory and to kv_cache + if (warp_idx == 0 && lane_idx == 0) { + float4 shard_abs_max_vec = + reinterpret_cast(shard_abs_max)[tile_idx]; + float tile_scale = fmaxf(fmaxf(shard_abs_max_vec.x, shard_abs_max_vec.y), + fmaxf(shard_abs_max_vec.z, shard_abs_max_vec.w)) / + 448.f; + + // Avoid division by zero in `scaled_convert` + tile_scales[tile_idx] = fmaxf(tile_scale, FLT_MIN); + float* kv_cache_32bit = reinterpret_cast(&kv_cache[dst_idx_start]); + const uint64_t dst_idx = kv_lora_rank / 4 + tile_idx; + kv_cache_32bit[dst_idx] = tile_scales[tile_idx]; + } + + __syncthreads(); + + // Now all threads in the block scale and write their element + const float scale_val = tile_scales[tile_idx]; + const int64_t dst_idx = dst_idx_start + threadIdx.x; + kv_cache[dst_idx] = + fp8::scaled_convert( + src_val, scale_val); +} + +template +__global__ void indexer_k_quant_and_cache_kernel( + const scalar_t* __restrict__ k, // [num_tokens, head_dim] + cache_t* __restrict__ kv_cache, // [num_blocks, block_size, cache_stride] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int head_dim, // dimension of each head + const int quant_block_size, // quantization block size + const int cache_block_size, // cache block size + const int cache_stride, // stride for each token in kv_cache + const bool use_ue8m0 // use ue8m0 scale format +) { + constexpr int VEC_SIZE = 4; + const int64_t token_idx = blockIdx.x; + const int64_t head_dim_idx = (blockIdx.y * blockDim.y * blockDim.x + threadIdx.y * blockDim.x + threadIdx.x) * VEC_SIZE; + const int64_t slot_idx = slot_mapping[token_idx]; + const int64_t block_idx = slot_idx / cache_block_size; + const int64_t block_offset = slot_idx % cache_block_size; + + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0 || (head_dim_idx >= head_dim)) { + return; + } + + float2 k_val = (reinterpret_cast(k))[(token_idx * head_dim + head_dim_idx) / VEC_SIZE]; + scalar_t* k_val_ptr = reinterpret_cast(&k_val); + float amax = 0.0f; + for (int i = 0; i < VEC_SIZE; i++) { + amax = fmaxf(amax, fabsf(float(k_val_ptr[i]))); + } + 
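+  // NOTE: each warp spans 32 lanes * VEC_SIZE = 128 contiguous values
+  // (one quantization block when quant_block_size == 128), so the
+  // warp-level reduction below produces the per-block amax.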
__syncwarp(); + + // Reduced amax + for (int mask = 16; mask > 0; mask /= 2) { + amax = fmaxf(amax, __shfl_xor_sync(unsigned(-1), amax, mask)); + } + __syncwarp(); + float scale = fmaxf(amax, 1e-4) / 448.0f; + if (use_ue8m0) { + scale = exp2f(ceilf(log2f(scale))); + } + + const int64_t dst_offset = block_idx * cache_block_size * cache_stride + block_offset * head_dim + head_dim_idx; + for (int i = 0; i < VEC_SIZE; i++) { + kv_cache[dst_offset + i] = fp8::scaled_convert(k_val_ptr[i], scale); + } + if (threadIdx.x == 0) { + const int64_t dst_scale_idx = block_idx * cache_block_size * cache_stride + cache_block_size * head_dim + (block_offset * head_dim + head_dim_idx) * 4 / quant_block_size; + reinterpret_cast(kv_cache)[dst_scale_idx / 4] = scale; + } +} + } // namespace vllm // KV_T is the data type of key and value tensors. @@ -438,7 +593,7 @@ void reshape_and_cache( const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, - CALL_RESHAPE_AND_CACHE) + CALL_RESHAPE_AND_CACHE); } // KV_T is the data type of key and value tensors. @@ -509,6 +664,18 @@ void reshape_and_cache_flash( kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ reinterpret_cast(scale.data_ptr())); +// KV_T is the data type of key and value tensors. +// CACHE_T is the stored data type of kv-cache. +#define CALL_CONCAT_AND_CACHE_DS_MLA(KV_T, CACHE_T, KV_DTYPE) \ + vllm::concat_and_cache_ds_mla_kernel \ + <<>>( \ + reinterpret_cast(kv_c.data_ptr()), \ + reinterpret_cast(k_pe.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, entry_stride, \ + kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size, \ + reinterpret_cast(scale.data_ptr())); + void concat_and_cache_mla( torch::Tensor& kv_c, // [num_tokens, kv_lora_rank] torch::Tensor& k_pe, // [num_tokens, pe_dim] @@ -531,20 +698,44 @@ void concat_and_cache_mla( int pe_dim = k_pe.size(1); int block_size = kv_cache.size(1); - TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + if (kv_cache_dtype == "fp8_ds_mla") { + TORCH_CHECK(kv_lora_rank == 512, "kv_lora_rank must be 512 for fp8_ds_mla"); + TORCH_CHECK(pe_dim == 64, "pe_dim must be 64 for fp8_ds_mla"); + TORCH_CHECK(kv_cache.size(2) == 656 / kv_cache.itemsize(), + "kv_cache.size(2) must be 656 bytes for fp8_ds_mla"); + TORCH_CHECK(kv_c.itemsize() == 2, + "kv_c.itemsize() must be 2 for fp8_ds_mla"); + TORCH_CHECK(k_pe.itemsize() == 2, + "k_pe.itemsize() must be 2 for fp8_ds_mla"); + } else { + TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim); + } int kv_c_stride = kv_c.stride(0); int k_pe_stride = k_pe.stride(0); int block_stride = kv_cache.stride(0); int entry_stride = kv_cache.stride(1); - dim3 grid(num_tokens); - dim3 block(std::min(kv_lora_rank, 512)); const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c)); const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, - CALL_CONCAT_AND_CACHE_MLA); + if (kv_cache_dtype == "fp8_ds_mla") { + dim3 grid(num_tokens); + // For the NoPE part, each tile of 128 elements is handled by 4 warps + // (128 threads). There are 4 total tiles, so 16 warps (512 threads). + // The first thread of the first warp in each tile writes the scale + // value for the tile. The RoPE part (last 64 elements) is handled + // by another 2 warps (64 threads). + // So in total, we use 18 warps (576 threads) per block. 
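+    // The fp8_ds_mla cache is uint8, so block_stride and entry_stride are
+    // measured in bytes (entry_stride is 656 for a contiguous cache).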
+ dim3 block(576); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_DS_MLA); + } else { + dim3 grid(num_tokens); + dim3 block(std::min(kv_lora_rank, 512)); + DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype, + CALL_CONCAT_AND_CACHE_MLA); + } } namespace vllm { @@ -922,3 +1113,42 @@ void cp_gather_cache( TORCH_CHECK(false, "Unsupported data type width: ", dtype_bits); } } + +// Macro to dispatch the kernel based on the data type. +#define CALL_INDEXER_K_QUANT_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::indexer_k_quant_and_cache_kernel \ + <<>>( \ + reinterpret_cast(k.data_ptr()), \ + reinterpret_cast(kv_cache.data_ptr()), \ + slot_mapping.data_ptr(), head_dim, quant_block_size, \ + cache_block_size, cache_stride, use_ue8m0); + +void indexer_k_quant_and_cache( + torch::Tensor& k, // [num_tokens, head_dim] + torch::Tensor& kv_cache, // [num_blocks, block_size, cache_stride] + torch::Tensor& slot_mapping, // [num_tokens] + int64_t quant_block_size, // quantization block size + const std::string& scale_fmt) { + + int num_tokens = k.size(0); + int head_dim = k.size(1); + int cache_block_size = kv_cache.size(1); + int cache_stride = kv_cache.size(2); + bool use_ue8m0 = scale_fmt == "ue8m0"; + + TORCH_CHECK(k.device() == kv_cache.device(), + "k and kv_cache must be on the same device"); + TORCH_CHECK(k.device() == slot_mapping.device(), + "k and slot_mapping must be on the same device"); + TORCH_CHECK(head_dim % quant_block_size == 0, + "head_dim must be divisible by quant_block_size"); + + constexpr int vec_size = 4; + dim3 grid(num_tokens, (head_dim + quant_block_size * vec_size - 1) / (quant_block_size * vec_size)); + dim3 block(32, vec_size); + const at::cuda::OptionalCUDAGuard device_guard(device_of(k)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(k.dtype(), "fp8_e4m3", + CALL_INDEXER_K_QUANT_AND_CACHE); +} \ No newline at end of file diff --git a/csrc/quantization/fp8/nvidia/quant_utils.cuh b/csrc/quantization/fp8/nvidia/quant_utils.cuh index 5b9c2df8468c..5361a8b1a598 100644 --- a/csrc/quantization/fp8/nvidia/quant_utils.cuh +++ b/csrc/quantization/fp8/nvidia/quant_utils.cuh @@ -576,6 +576,17 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { TORCH_CHECK(false, \ "Unsupported input type of kv cache: ", SRC_DTYPE); \ } \ + } else if (KV_DTYPE == "fp8_ds_mla") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ } else { \ TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ } \ diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index bc096406c51a..9e7fbeb80bb3 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -713,6 +713,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { "cp_gather_cache(Tensor src_cache, Tensor! dst, Tensor block_table, " "Tensor cu_seq_lens, int batch_size, Tensor? seq_starts) -> ()"); cache_ops.impl("cp_gather_cache", torch::kCUDA, &cp_gather_cache); + + cache_ops.def( + "indexer_k_quant_and_cache(Tensor k, Tensor! 
kv_cache, Tensor slot_mapping, " + "int quant_block_size, str kv_cache_dtype) -> ()"); + cache_ops.impl("indexer_k_quant_and_cache", torch::kCUDA, &indexer_k_quant_and_cache); } TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cuda_utils), cuda_utils) { diff --git a/examples/offline_inference/basic/basic.py b/examples/offline_inference/basic/basic.py index 78bfda9bcf4e..909fc9e4df66 100644 --- a/examples/offline_inference/basic/basic.py +++ b/examples/offline_inference/basic/basic.py @@ -32,4 +32,4 @@ def main(): if __name__ == "__main__": - main() + main() \ No newline at end of file diff --git a/setup.py b/setup.py index e4c40d22b928..6434bada6898 100644 --- a/setup.py +++ b/setup.py @@ -322,6 +322,8 @@ def extract_precompiled_and_patch_package(wheel_url_or_path: str) -> dict: "vllm/_C.abi3.so", "vllm/_moe_C.abi3.so", "vllm/_flashmla_C.abi3.so", + "vllm/_flashmla_extension_C.abi3.so", + "vllm/_sparse_flashmla_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so", "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so", "vllm/cumem_allocator.abi3.so", @@ -589,6 +591,8 @@ def _read_requirements(filename: str) -> list[str]: # not targeting a hopper system ext_modules.append( CMakeExtension(name="vllm._flashmla_C", optional=True)) + ext_modules.append( + CMakeExtension(name="vllm._flashmla_extension_C", optional=True)) ext_modules.append(CMakeExtension(name="vllm.cumem_allocator")) if _build_custom_ops(): diff --git a/tests/compile/test_fusion_attn.py b/tests/compile/test_fusion_attn.py index 022f183b3193..76e82bfa8087 100644 --- a/tests/compile/test_fusion_attn.py +++ b/tests/compile/test_fusion_attn.py @@ -191,7 +191,6 @@ def __init__(self, num_qo_heads: int, num_kv_heads: int, head_size: int, num_kv_heads=self.num_kv_heads, head_size=self.head_size, dtype=self.kv_cache_dtype, - use_mla=False, ), layer_names=[self.attn.layer_name], vllm_config=self.vllm_config, diff --git a/tests/kernels/attention/test_cache.py b/tests/kernels/attention/test_cache.py index 69e96dfd2cb1..75bdcb6808b9 100644 --- a/tests/kernels/attention/test_cache.py +++ b/tests/kernels/attention/test_cache.py @@ -578,6 +578,119 @@ def test_concat_and_cache_mla( torch.testing.assert_close(kv_cache, ref_kv_cache) +@pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) +@pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) +@pytest.mark.parametrize("num_tokens", NUM_TOKENS_MLA) +@pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) +@pytest.mark.parametrize("num_blocks", NUM_BLOCKS_MLA) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_concat_and_cache_ds_mla( + kv_lora_rank: int, + qk_rope_head_dim: int, + num_tokens: int, + block_size: int, + num_blocks: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + if dtype.itemsize != 2: + pytest.skip("ds_mla only supports 16-bit input") + kv_cache_dtype = "fp8_ds_mla" + current_platform.seed_everything(seed) + torch.set_default_device(device) + + total_slots = num_blocks * block_size + slot_mapping_lst = random.sample(range(total_slots), num_tokens) + slot_mapping = torch.tensor(slot_mapping_lst, + dtype=torch.long, + device=device) + + kv_c = torch.randn(num_tokens, kv_lora_rank, dtype=dtype, device=device) + k_pe = torch.randn(num_tokens, + qk_rope_head_dim, + dtype=dtype, + device=device) + entry_size = kv_lora_rank + (4 * 4) + (2 * qk_rope_head_dim) + + scale = torch.tensor(1.0, dtype=torch.float32, device=device) + kv_cache = 
_create_mla_cache(num_blocks, + block_size, + entry_size, + dtype=torch.uint8, + kv_cache_dtype=kv_cache_dtype, + device=device) + + ref_cache = torch.zeros_like(kv_cache, dtype=kv_cache.dtype) + tile_data = torch.zeros(128, dtype=dtype, device=device) + + for i in range(num_tokens): + slot = slot_mapping[i].item() + block_idx = slot // block_size + block_offset = slot % block_size + + ref_cache_slice = ref_cache[block_idx, block_offset] + ref_cache_16bit = ref_cache_slice.view(dtype) + ref_cache_32bit = ref_cache_slice.view(torch.float32) + + kv_c_data = kv_c[i] + for tile_idx in range(4): + tile_start = tile_idx * 128 + tile_end = (tile_idx + 1) * 128 + tile_data[:] = kv_c_data[tile_start:tile_end] + + # tile_scale = tile_data.amax().to(torch.float32) / 448. + # NOTE: Using torch's amax() gives different results, + # so this must be manually computed. + tile_data_float = tile_data.to(torch.float32) + manual_max = abs(tile_data_float[0]) + for j in range(1, 128): + manual_max = max(manual_max, abs(tile_data_float[j])) + tile_scale = manual_max / 448. + + ref_cache_32bit[kv_lora_rank // 4 + tile_idx] = tile_scale + + ops.convert_fp8(ref_cache_slice[tile_start:tile_end], + tile_data, + tile_scale.item(), + kv_dtype="fp8") + + for j in range(qk_rope_head_dim): + ref_cache_16bit[kv_lora_rank // 2 + 8 + j] = k_pe[i, j] + + opcheck( + torch.ops._C_cache_ops.concat_and_cache_mla, + (kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale), + test_utils=DEFAULT_OPCHECK_TEST_UTILS, + ) + + ops.concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, + kv_cache_dtype, scale) + + for i in range(num_tokens): + slot = slot_mapping[i].item() + block_idx = slot // block_size + block_offset = slot % block_size + kv_cache_slice = kv_cache[block_idx, block_offset] + ref_cache_slice = ref_cache[block_idx, block_offset] + + kv_nope = kv_cache_slice[:kv_lora_rank] + ref_nope = ref_cache_slice[:kv_lora_rank] + kv_scales = kv_cache_slice.view(torch.float32)[kv_lora_rank // + 4:kv_lora_rank // 4 + 4] + ref_scales = ref_cache_slice.view( + torch.float32)[kv_lora_rank // 4:kv_lora_rank // 4 + 4] + kv_rope = kv_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] + ref_rope = ref_cache_slice.view(dtype)[kv_lora_rank // 2 + 8:] + + torch.testing.assert_close(kv_nope, ref_nope, atol=0.001, rtol=0.1) + torch.testing.assert_close(kv_scales, ref_scales, atol=0.001, rtol=0.1) + torch.testing.assert_close(kv_rope, ref_rope, atol=0.001, rtol=0.1) + + @pytest.mark.parametrize("kv_lora_rank", KV_LORA_RANKS) @pytest.mark.parametrize("qk_rope_head_dim", QK_ROPE_HEAD_DIMS) @pytest.mark.parametrize("block_size", BLOCK_SIZES_MLA) diff --git a/tests/kernels/attention/test_deepgemm_attention.py b/tests/kernels/attention/test_deepgemm_attention.py new file mode 100644 index 000000000000..50c547b84be6 --- /dev/null +++ b/tests/kernels/attention/test_deepgemm_attention.py @@ -0,0 +1,299 @@ +import random +import pytest +import torch + +from vllm.platforms import current_platform +from vllm.utils import has_deep_gemm, cdiv +from vllm.utils.deep_gemm import ( + _ceil_to_ue8m0, + fp8_mqa_logits, + calc_diff, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, + get_num_sms, +) + + +def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: + # x: (num_blocks, block_size, 1, head_dim) + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + x_fp8 = torch.empty( + 
(num_blocks, block_size * (head_dim + 4)), + device=x.device, + dtype=torch.uint8, + ) + x_fp8[:, : block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim + ).view(dtype=torch.uint8) + x_fp8[:, block_size * head_dim :] = sf.view(num_blocks, block_size).view( + dtype=torch.uint8 + ) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + + +def per_custom_dims_cast_to_fp8( + x: torch.Tensor, dims: tuple, use_ue8m0: bool +) -> tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + sf = _ceil_to_ue8m0(sf) if use_ue8m0 else sf + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + return x_scaled, sf.squeeze() + + +def _generate_cp_test_data(seq_len: int, seq_len_kv: int): + assert seq_len_kv % seq_len == 0 and seq_len % 2 == 0 + chunk_size = seq_len // 2 + cp_size = seq_len_kv // seq_len + cp_id = cp_size // 3 + ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") + ke = torch.zeros(seq_len, dtype=torch.int, device="cuda") + for i in range(chunk_size): + ke[i] = cp_id * chunk_size + i + ke[i + chunk_size] = (cp_size * 2 - 1 - cp_id) * chunk_size + i + return ks, ke + + +def _ref_fp8_mqa_logits( + q: torch.Tensor, + kv: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + seq_len_kv = kv.shape[0] + + k = kv + q = q.float() + k = k.float() + + mask_lo = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] + >= cu_seqlen_ks[:, None] + ) + mask_hi = ( + torch.arange(0, seq_len_kv, device="cuda")[None, :] + < cu_seqlen_ke[:, None] + ) + mask = mask_lo & mask_hi + + score = torch.einsum("mhd,nd->hmn", q, k) + logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0) + logits = logits.masked_fill(~mask, float("-inf")) + + return logits + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +def test_deepgemm_fp8_mqa_logits(): + torch.manual_seed(0) + random.seed(0) + num_heads, head_dim = 32, 128 + for seq_len in (512,): + for seq_len_kv in (1024,): + for disable_cp in (False, True): + q = torch.randn( + seq_len, + num_heads, + head_dim, + device="cuda", + dtype=torch.bfloat16, + ) + kv = torch.randn( + seq_len_kv, head_dim, device="cuda", dtype=torch.bfloat16 + ) + weights = torch.randn( + seq_len, num_heads, device="cuda", dtype=torch.float32 + ) + + if disable_cp: + ks = torch.zeros(seq_len, dtype=torch.int, device="cuda") + ke = torch.arange( + seq_len, dtype=torch.int, device="cuda" + ) + (seq_len_kv - seq_len) + else: + ks, ke = _generate_cp_test_data(seq_len, seq_len_kv) + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_fp8 = per_custom_dims_cast_to_fp8(kv, (0,), False) + logits = fp8_mqa_logits(q_fp8, kv_fp8, weights, ks, ke) + + ref_logits = _ref_fp8_mqa_logits( + q=q, + kv=kv, + weights=weights, + cu_seqlen_ks=ks, + cu_seqlen_ke=ke, + ) + + ref_neginf_mask = ref_logits == float("-inf") + neginf_mask = logits == float("-inf") + assert torch.equal(neginf_mask, ref_neginf_mask) + + ref_logits = ref_logits.masked_fill(ref_neginf_mask, 0) + logits = logits.masked_fill(neginf_mask, 0) + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + + +def _ref_fp8_paged_mqa_logits( + q: torch.Tensor, + kv_cache: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + max_model_len: int, 
+): + batch_size, next_n, _, _ = q.size() + _, block_size, _, _ = kv_cache.size() + logits = torch.full( + [batch_size * next_n, max_model_len], + float("-inf"), + device=q.device, + dtype=torch.float32, + ) + context_lens_list = context_lens.tolist() + for i in range(batch_size): + context_len = context_lens_list[i] + q_offsets = torch.arange( + context_len - next_n, context_len, device="cuda" + ) + weight_slice = ( + weights[i * next_n : (i + 1) * next_n, :] + .transpose(0, 1) + .contiguous() + ) + for block_rk in range(cdiv(context_len, block_size)): + block_idx = block_tables[i][block_rk] + qx, kx = q[i], kv_cache[block_idx] + k_offsets = torch.arange( + block_rk * block_size, + (block_rk + 1) * block_size, + device="cuda", + ) + mask = (k_offsets[None, :] < context_len) & ( + k_offsets[None, :] <= q_offsets[:, None] + ) + s = torch.where( + mask[None, :, :], + (qx.transpose(0, 1) @ kx.transpose(0, 1).transpose(1, 2)).to( + logits.dtype + ), + float("-inf"), + ) + s = torch.relu(s) * weight_slice[..., None] + s = s.sum(dim=0) + logits[ + i * next_n : (i + 1) * next_n, + block_rk * block_size : (block_rk + 1) * block_size, + ] = torch.where( + k_offsets[None, :] <= q_offsets[:, None], s, float("-inf") + ) + return logits + + +@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA only") +@pytest.mark.skipif(not has_deep_gemm(), reason="DeepGEMM not available") +def test_deepgemm_fp8_paged_mqa_logits(): + torch.manual_seed(0) + random.seed(0) + + max_model_len = 4096 + for batch_size, next_n in [(4, 1), (2, 2)]: + for heads, index_dim in [(32, 128)]: + for avg_kv in (2048,): + num_blocks, blocksize = max_model_len * 2, 64 + + q = torch.randn( + (batch_size, next_n, heads, index_dim), + device="cuda", + dtype=torch.bfloat16, + ) + kv_cache = torch.randn( + (num_blocks, blocksize, 1, index_dim), + device="cuda", + dtype=torch.bfloat16, + ) + weights = torch.randn( + (batch_size * next_n, heads), + device="cuda", + dtype=torch.float32, + ) + + context_lens = ( + torch.randint( + int(0.8 * avg_kv), int(1.2 * avg_kv), (batch_size,) + ) + .cuda() + .to(torch.int32) + ) + max_block_len = ( + (context_lens.max().item() + blocksize - 1) + // blocksize + * blocksize + ) + block_tables = torch.zeros( + (batch_size, max_block_len), + device="cuda", + dtype=torch.int32, + ) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(batch_size): + ctx_len = int(context_lens[i].item()) + for j in range((ctx_len + blocksize - 1) // blocksize): + block_tables[i][j] = block_idx_pool[counter] + counter += 1 + + q_fp8 = q.to(torch.float8_e4m3fn) + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + + schedule_metadata = get_paged_mqa_logits_metadata( + context_lens, blocksize, get_num_sms() + ) + logits = fp8_paged_mqa_logits( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + ) + + ref_logits = _ref_fp8_paged_mqa_logits( + q, + kv_cache, + weights, + context_lens, + block_tables, + max_model_len, + ) + + positions = ( + torch.arange(max_model_len, device="cuda") + .unsqueeze(0) + .expand(batch_size * next_n, -1) + ) + row_indices = ( + torch.arange(batch_size * next_n, device="cuda") // next_n + ) + next_n_offset = ( + torch.arange(batch_size * next_n, device="cuda") % next_n + ) + mask = positions <= ( + context_lens[row_indices] - next_n + next_n_offset + ).unsqueeze(1) + + logits = logits.masked_fill(~mask, 0) + ref_logits = ref_logits.masked_fill(~mask, 0) + diff = calc_diff(logits, 
ref_logits) + assert diff < 1e-3, f"{diff=}" diff --git a/tests/kernels/attention/test_flashmla.py b/tests/kernels/attention/test_flashmla.py index abcfe828d5ac..bddd7e5c50ed 100644 --- a/tests/kernels/attention/test_flashmla.py +++ b/tests/kernels/attention/test_flashmla.py @@ -97,18 +97,16 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, block_size, causal, descale_k = None def flash_mla(): - return flash_mla_with_kvcache( - q, - blocked_k, - block_table, - cache_seqlens, - dv, - tile_scheduler_metadata, - num_splits, - causal=causal, - descale_q=descale_q, - descale_k=descale_k, - ) + return flash_mla_with_kvcache(q, + blocked_k, + block_table, + cache_seqlens, + dv, + tile_scheduler_metadata, + num_splits, + causal=causal, + descale_q=descale_q, + descale_k=descale_k) def scaled_dot_product_attention(query, key, value, is_causal=False): query = query.float() diff --git a/tests/kernels/attention/test_flashmla_sparse.py b/tests/kernels/attention/test_flashmla_sparse.py new file mode 100644 index 000000000000..62ff7f65a0a2 --- /dev/null +++ b/tests/kernels/attention/test_flashmla_sparse.py @@ -0,0 +1,110 @@ +import pytest +import torch + + +def _cuda_sm90_available() -> bool: + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major == 9 + + +def test_sparse_flashmla_metadata_smoke(): + import vllm.attention.ops.flashmla as fm + ok, reason = fm.is_flashmla_supported() + if not ok or not _cuda_sm90_available(): + pytest.skip(reason or "SM90 not available") + + device = torch.device("cuda") + batch_size = 1 + seqlen_q = 1 + num_heads_q = 128 + num_heads_k = 1 + q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k + topk = 128 + + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) + + tile_md, num_splits = fm.get_mla_metadata(cache_seqlens, + q_seq_per_hk, + num_heads_k, + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True) + assert tile_md.dtype == torch.int32 + assert num_splits.dtype == torch.int32 + + +def test_sparse_flashmla_decode_smoke(): + import vllm.attention.ops.flashmla as fm + ok, reason = fm.is_flashmla_supported() + if not ok or not _cuda_sm90_available(): + pytest.skip(reason or "SM90 not available") + + device = torch.device("cuda") + batch_size = 1 + seqlen_q = 1 + num_heads_q = 1 + head_dim_k = 576 + head_dim_v = 512 + num_heads_k = 1 + page_block_size = 64 + bytes_per_token = 656 + topk = 128 + + # Metadata + q_seq_per_hk = seqlen_q * num_heads_q // num_heads_k + q_heads_per_hk = num_heads_q // num_heads_k + cache_seqlens = torch.zeros(batch_size, dtype=torch.int32, device=device) + tile_md, num_splits = fm.get_mla_metadata(cache_seqlens, + + q_seq_per_hk, + num_heads_k, + num_heads_q=num_heads_q, + topk=topk, + is_fp8_kvcache=True) + + # Inputs + q = torch.zeros((batch_size, seqlen_q, num_heads_q, head_dim_k), + dtype=torch.bfloat16, + device=device) + k_cache = torch.zeros((1, page_block_size, num_heads_k, bytes_per_token), + dtype=torch.uint8, + device=device) + indices = torch.zeros((batch_size, seqlen_q, topk), + dtype=torch.int32, + device=device) + + block_table = torch.zeros((batch_size, 128), dtype=torch.int32, device=device) + out, lse = fm.flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens, + head_dim_v, tile_md, + num_splits, indices=indices, is_fp8_kvcache=True) + assert out.shape[0] == batch_size + assert out.shape[-1] == head_dim_v + assert lse.shape[0] == batch_size + + +def test_sparse_flashmla_prefill_smoke(): + import vllm.attention.ops.flashmla 
as fm + ok, reason = fm.is_flashmla_supported() + if not ok or not _cuda_sm90_available(): + pytest.skip(reason or "SM90 not available") + + device = torch.device("cuda") + s_q = 1 + s_kv = 1 + h_q = 64 # kernel expects multiple of 64 + h_kv = 1 + d_qk = 576 + d_v = 512 + topk = 128 + + q = torch.zeros((s_q, h_q, d_qk), dtype=torch.bfloat16, device=device) + kv = torch.zeros((s_kv, h_kv, d_qk), dtype=torch.bfloat16, device=device) + indices = torch.zeros((s_q, h_kv, topk), dtype=torch.int32, device=device) + + out, max_logits, lse = fm.flash_mla_sparse_prefill(q, kv, indices, 1.0, d_v) + assert out.shape == (s_q, h_q, d_v) + assert max_logits.shape == (s_q, h_q) + assert lse.shape == (s_q, h_q) + diff --git a/tests/kernels/attention/test_indexer.py b/tests/kernels/attention/test_indexer.py new file mode 100644 index 000000000000..5ed6c212e528 --- /dev/null +++ b/tests/kernels/attention/test_indexer.py @@ -0,0 +1,224 @@ +import random + +import torch + +from vllm.utils import cdiv +from vllm.v1.attention.backends.mla.indexer import kv_spans_from_batches +from vllm.utils.tile_lang_kernels import act_quant, fp8_index +from vllm import _custom_ops as ops +from vllm.model_executor.models.deepseek_v2 import indexer_k_quant_and_cache +from vllm.utils.deep_gemm import ( + fp8_mqa_logits, + calc_diff, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, + get_num_sms, +) + +def kv_cache_cast_to_fp8(x: torch.Tensor) -> torch.Tensor: + num_blocks, block_size, num_heads, head_dim = x.shape + assert num_heads == 1 + x_amax = x.abs().float().amax(dim=3, keepdim=True).clamp(1e-4) + sf = x_amax / 448.0 + x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn) + x_fp8 = torch.empty((num_blocks, block_size * (head_dim + 4)), + device=x.device, + dtype=torch.uint8) + x_fp8[:, :block_size * head_dim] = x_scaled.view( + num_blocks, block_size * head_dim).view(dtype=torch.uint8) + x_fp8[:, + block_size * head_dim:] = sf.view(num_blocks, + block_size).view(dtype=torch.uint8) + return x_fp8.view(num_blocks, block_size, num_heads, head_dim + 4) + +def ref_compute_logits_fp8(q, kv, weights, mask, block_size): + q_fp8, q_scale = act_quant(q, block_size, "ue8m0") + k_fp8, k_scale = act_quant(kv, block_size, "ue8m0") + + weights = weights.unsqueeze(-1) * q_scale + weights = weights * (128**(-0.5)) + index_score = fp8_index( + q_fp8.contiguous(), weights, + k_fp8.contiguous(), + k_scale.contiguous()) + if mask is not None: + index_score += mask + return index_score + +def ref_indexer(seq_len, q, kv, weights, block_size, topk): + B = seq_len.shape[0] + total_seqlen = torch.sum(seq_len) + varlen_logits = torch.full((total_seqlen, total_seqlen), float("-inf"), device="cuda") + + current_context_ptr = 0 + for i in range(B): + S = seq_len[i] + q_s = q[i][:S].contiguous().unsqueeze(0) + kv_s = kv[i][:S].contiguous().unsqueeze(0) + weights_s = weights[i][:S].contiguous().unsqueeze(0) + mask = torch.full( + (S, S), float("-inf"), + device="cuda").triu_(1) + logits = ref_compute_logits_fp8(q_s, kv_s, weights_s, mask, block_size) + logits = logits.squeeze(0) + + varlen_logits[current_context_ptr:current_context_ptr + S, current_context_ptr: current_context_ptr + S] = logits + current_context_ptr += S + return varlen_logits + +def deepgemm_mqa_indexer(seq_len, query_seq_len, q, kv, weights, block_size, topk, is_kv_batched=True): + B = seq_len.shape[0] + concat_q = [] + concat_kv = [] + concat_weights = [] + + for i in range(B): + S = seq_len[i] + q_s = q[i][:S].contiguous() + if is_kv_batched: + kv_s = kv[i][:S].contiguous() + 
weight_s = weights[i][:S].contiguous() + concat_q.append(q_s) + if is_kv_batched: + concat_kv.append(kv_s) + concat_weights.append(weight_s) + + q = torch.cat(concat_q, dim=0) + if is_kv_batched: + kv = torch.cat(concat_kv, dim=0) + weights = torch.cat(concat_weights, dim=0) + q_fp8, q_scale = act_quant(q, block_size, "ue8m0") + kv_fp8, kv_scale = act_quant(kv, block_size, "ue8m0") + + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale + weights = weights.squeeze(-1) + query_start_loc = torch.empty((B + 1), device="cuda") + query_start_loc[0] = 0 + query_start_loc[1:] = query_seq_len.cumsum(dim=0).to(dtype=torch.int32) + + cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(query_start_loc, seq_len) + + logits = fp8_mqa_logits( + q_fp8, + (kv_fp8, kv_scale), + weights, + cu_seqlen_ks, + cu_seqlen_ke + ) + topk_indices = logits.topk(topk, dim=-1)[1] + mask_lo = topk_indices >= cu_seqlen_ks[:, None] + mask_hi = topk_indices < cu_seqlen_ke[:, None] + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + return logits + +def test_prefill_indexer(): + B = 3 + S = 128 + SKV = S + H = 64 + HKV = 1 + D = 128 + block_size = 128 + topk = 64 + device = "cuda" + seq_len = torch.randint(low=64, high=S, size=(B,)) + + q = torch.randn(B, S, H, D, device="cuda", + dtype=torch.bfloat16) + kv = torch.randn(B, SKV, D, device="cuda", + dtype=torch.bfloat16) + weights = torch.randn(B, S, H, device=device, dtype=torch.float32) * H**-0.5 + + ref_logits = ref_indexer(seq_len, q, kv, weights, block_size, topk) + deepgemm_logits = deepgemm_mqa_indexer(seq_len, seq_len, q, kv, weights, block_size, topk) + torch.testing.assert_close(ref_logits, deepgemm_logits) + + +def test_decode_paged_indexer(): + num_blocks, blocksize = 111 * 3000, 64 + B = 3 + S = 128 + SKV = S + H = 64 + HKV = 1 + D = 128 + block_size = 128 + topk = 64 + device = "cuda" + seq_len = torch.randint(low=64, high=S, size=(B,), device="cuda") + + query_seq_len = torch.ones(B, device="cuda") + + q = torch.randn((B, 1, H, D), + device='cuda', + dtype=torch.bfloat16) + kv_cache = torch.randn((num_blocks, blocksize, 1, D), + device='cuda', + dtype=torch.bfloat16) + weights = torch.randn((B * 1, H), + device='cuda', + dtype=torch.float32) * H**-0.5 + max_block_len = (seq_len.max().item() + blocksize - + 1) // blocksize * blocksize + + block_tables = torch.zeros((B, max_block_len), + device='cuda', + dtype=torch.int32) + + counter = 0 + block_idx_pool = list(range(num_blocks)) + random.shuffle(block_idx_pool) + for i in range(B): + ctx_len = seq_len[i].item() + for j in range(cdiv(ctx_len, blocksize)): + block_tables[i][j] = block_idx_pool[counter] + counter += 1 + + flatten_kv = torch.empty( + [seq_len.sum(), D], device="cuda", dtype=torch.bfloat16 + ) + cu_seq_lens = torch.cat([ + torch.zeros(1, dtype=torch.int32, device=device), + seq_len.cumsum(dim=0) + ]).to(torch.int32).cuda() + + ops.cp_gather_cache( + kv_cache, + flatten_kv, + block_tables, + cu_seq_lens, + B, + ) + + ref_logits = deepgemm_mqa_indexer(seq_len, query_seq_len, q, flatten_kv, weights, block_size, topk, is_kv_batched=False) + + q_fp8, q_scale = act_quant(q, block_size, "ue8m0") + kv_cache_fp8 = kv_cache_cast_to_fp8(kv_cache) + + schedule_metadata = get_paged_mqa_logits_metadata( + seq_len.int(), blocksize, get_num_sms()) + + weights = weights.unsqueeze(-1) * (128**(-0.5)) * q_scale.squeeze(1) + weights = weights.squeeze(-1) + + logits = fp8_paged_mqa_logits( + q_fp8, kv_cache_fp8, weights, seq_len.int(), block_tables, + schedule_metadata, 4096) + + 
concat_logit = [] + context = 0 + for i in range(B): + per_seq_logits = torch.zeros(4096, device="cuda") + S = seq_len[i] + per_seq_logits[:S] = ref_logits[i][context: context + S] + concat_logit.append(per_seq_logits) + context += S + ref_logits = torch.stack(concat_logit, dim=0) + logits[logits == float("-inf")] = 0 + diff = calc_diff(logits, ref_logits) + assert diff < 1e-3, f"{diff=}" + +if __name__ == "__main__": + test_prefill_indexer() + test_decode_paged_indexer() \ No newline at end of file diff --git a/tests/kernels/attention/test_pack_unpack_triton.py b/tests/kernels/attention/test_pack_unpack_triton.py new file mode 100644 index 000000000000..59a9b64eebff --- /dev/null +++ b/tests/kernels/attention/test_pack_unpack_triton.py @@ -0,0 +1,230 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +from torch.testing import assert_close + +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton + + +def test_pack_seq_basic_fp8(): + """Test basic functionality of pack_seq_triton with fp8 and 3D tensors.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test cases with 3D tensors (N, H, D) + test_cases = [ + (6, 8, 4, 2, [3, 3]), # (6, 8, 4) -> (2, 3, 8, 4) + (10, 4, 8, 3, [2, 4, 4]), # (10, 4, 8) -> (3, 4, 4, 8) + (20, 16, 32, 4, [5, 5, 5, 5]), # (20, 16, 32) -> (4, 5, 16, 32) + ] + + for N, H, D, B, lengths_list in test_cases: + # Create input tensor with small values for fp8 + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor(lengths_list, device=device) + + # Pack the data + packed = pack_seq_triton(x, lengths) + + # Check output shape and properties + expected_shape = (B, max(lengths_list), H, D) + assert packed.shape == expected_shape + assert packed.dtype == dtype + assert packed.device == x.device + + # Check that valid data is preserved (within fp8 precision) + for b in range(B): + start_idx = sum(lengths_list[:b]) + seq_len = lengths_list[b] + + expected_data = x[start_idx:start_idx + seq_len].to(torch.float32) + actual_data = packed[b, :seq_len].to(torch.float32) + + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + +def test_pack_seq_custom_padding_fp8(): + """Test pack_seq_triton with custom padding values for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 20, 8, 16, 2 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + # Test with different padding values + for pad_value in [-100.0, -10.0, 0.0, 10.0, 100.0]: + result = pack_seq_triton(x, lengths, pad_value=pad_value) + + # Check valid data + for b in range(B): + start_idx = b * 10 + expected_data = x[start_idx:start_idx + 10].to(torch.float32) + actual_data = result[b, :10].to(torch.float32) + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + # Check padding (fp8 has limited range, so check for large values) + padded_data = result[:, 10:].to(torch.float32) + if pad_value < 0: + assert torch.all(padded_data < -50) # Large negative values + elif pad_value > 0: + assert torch.all(padded_data > 50) # Large positive values + else: + assert torch.allclose(padded_data, torch.zeros_like(padded_data), atol=1e-2) + + +def test_pack_seq_default_negative_inf_padding_fp8(): + """Test that pack_seq_triton uses -inf padding by default for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 
20, 8, 16, 2 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + result = pack_seq_triton(x, lengths) + + # Check that padding is large negative values (fp8 representation of -inf) + padded_data = result[:, 10:].to(torch.float32) + assert torch.all(padded_data < -100) # fp8 -inf is represented as large negative number + + +def test_pack_seq_edge_cases_fp8(): + """Test pack_seq_triton with edge cases for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test with single batch element + x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([10], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (1, 10, 8, 16) + + # Test with very short sequences + x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([1, 1, 1], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (3, 1, 4, 8) + + # Test with different sequence lengths + x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([5, 7, 3], device=device) + result = pack_seq_triton(x, lengths) + assert result.shape == (3, 7, 8, 16) + + +def test_pack_seq_different_block_sizes_fp8(): + """Test pack_seq_triton with different block sizes for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 100, 16, 32, 4 + lengths = torch.tensor([25, 25, 25, 25], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + # Test different block sizes + for block_t, block_d in [(32, 32), (64, 64), (128, 128)]: + result = pack_seq_triton(x, lengths, block_t=block_t, block_d=block_d) + + assert result.shape == (B, 25, H, D) + + # Check that valid data is preserved (within fp8 precision) + for b in range(B): + start_idx = b * 25 + expected_data = x[start_idx:start_idx + 25].to(torch.float32) + actual_data = result[b, :25].to(torch.float32) + assert_close(actual_data, expected_data, rtol=1e-1, atol=1e-2) + + +def test_pack_seq_shape_consistency(): + """Test that pack_seq_triton maintains shape consistency.""" + device = "cuda" + dtype = torch.float8_e4m3fn + N, H, D, B = 20, 8, 16, 2 + lengths = torch.tensor([10, 10], device=device) + + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + + result = pack_seq_triton(x, lengths) + + # Check shape consistency + assert result.shape[0] == B # Batch dimension + assert result.shape[1] == lengths.max().item() # Max sequence length + assert result.shape[2:] == x.shape[1:] # Feature dimensions preserved + + +def test_pack_unpack_roundtrip_fp8(): + """Test that pack -> unpack gives us back the original data for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test cases with 3D tensors + test_cases = [ + (6, 8, 4, 2, [3, 3]), + (10, 4, 8, 3, [2, 4, 4]), + (20, 16, 32, 4, [5, 5, 5, 5]), + (15, 8, 16, 3, [7, 5, 3]), + ] + + for N, H, D, B, lengths_list in test_cases: + # Create input tensor with small values for fp8 + x = torch.randn(N, H, D, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor(lengths_list, device=device) + + # Pack the data + packed = pack_seq_triton(x, lengths) + + # Unpack the data + unpacked = unpack_seq_triton(packed, lengths) + + # Check that we get back the original data (within fp8 precision) + 
assert unpacked.shape == x.shape + x_f32 = x.to(torch.float32) + unpacked_f32 = unpacked.to(torch.float32) + assert_close(x_f32, unpacked_f32, rtol=1e-3, atol=1e-3) + + # Unpack without explicit start locations (computed in kernel) + unpacked_with_loc = unpack_seq_triton(packed, lengths) + assert_close(x_f32, unpacked_with_loc.to(torch.float32), rtol=1e-3, atol=1e-2) + + +def test_unpack_seq_triton_edge_cases_fp8(): + """Test unpack function with edge cases for fp8.""" + device = "cuda" + dtype = torch.float8_e4m3fn + + # Test with single batch element + x = torch.randn(10, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([10], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths) + assert unpacked.shape == x.shape + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) + + # Test with very short sequences + x = torch.randn(20, 4, 8, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([1, 1, 1], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths) + # Only compare the first 3 elements that were actually packed + assert_close(x[:3].to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) + + x = torch.randn(15, 8, 16, dtype=torch.float32, device=device) * 0.1 + x = x.to(dtype=dtype) + lengths = torch.tensor([5, 7, 3], device=device) + packed = pack_seq_triton(x, lengths) + unpacked = unpack_seq_triton(packed, lengths) + assert unpacked.shape == x.shape + assert_close(x.to(torch.float32), unpacked.to(torch.float32), rtol=1e-1, atol=1e-2) \ No newline at end of file diff --git a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py index 8d0687b49bb4..30d721304b5c 100644 --- a/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py +++ b/tests/plugins/vllm_add_dummy_platform/vllm_add_dummy_platform/dummy_platform.py @@ -26,5 +26,5 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: def get_attn_backend_cls(self, backend_name, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, - has_sink): + has_sink, use_sparse): return "vllm_add_dummy_platform.dummy_attention_backend.DummyAttentionBackend" # noqa E501 diff --git a/tests/v1/attention/test_mla_backends.py b/tests/v1/attention/test_mla_backends.py index a62993950aff..a5a7af7c0250 100644 --- a/tests/v1/attention/test_mla_backends.py +++ b/tests/v1/attention/test_mla_backends.py @@ -5,6 +5,8 @@ import pytest import torch +from vllm import _custom_ops as ops + from tests.v1.attention.utils import (BatchSpec, _Backend, create_common_attn_metadata, create_standard_kv_cache_spec, @@ -78,7 +80,9 @@ def create_and_prepopulate_kv_cache( device: torch.device, num_blocks: int, common_attn_metadata: CommonAttentionMetadata, - randomize_blocks: bool = True) -> torch.Tensor: + randomize_blocks: bool = True, + kv_cache_dtype: str | None = None, + scale: float | torch.Tensor = 1.0) -> torch.Tensor: """Create and prepopulate an MLA KV cache with context data. Args: @@ -93,6 +97,11 @@ def create_and_prepopulate_kv_cache( common_attn_metadata: Common attention metadata randomize_blocks: Whether to randomly permute blocks or use sequential order + kv_cache_dtype: Optional kv cache dtype string. 
When set to + "fp8_ds_mla" the cache is populated using the + fp8 DeepSeek MLA layout via concat_and_cache_mla. + scale: Scaling factor forwarded to concat_and_cache_mla when the + fp8 cache layout is requested. Returns: MLA KV cache tensor @@ -105,23 +114,62 @@ def create_and_prepopulate_kv_cache( block_table = common_attn_metadata.block_table_tensor slot_mapping = common_attn_metadata.slot_mapping - # Create MLA KV cache: (num_blocks, block_size, head_size) - kv_cache = torch.empty(num_blocks, - block_size, - head_size, - dtype=dtype, - device=device) - kv_cache_flat = kv_cache.view(-1, head_size) + use_fp8_ds_mla = kv_cache_dtype == "fp8_ds_mla" + + if use_fp8_ds_mla: + if not kv_c_contexts: + raise ValueError("kv_c_contexts cannot be empty when using" + " fp8_ds_mla cache dtype") + kv_lora_rank = kv_c_contexts[0].shape[-1] + rope_dim = k_pe_contexts[0].shape[-1] + entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim + kv_cache = torch.zeros(num_blocks, + block_size, + entry_size, + dtype=torch.uint8, + device=device) + scale_tensor = (scale if isinstance(scale, torch.Tensor) else + torch.tensor(scale, dtype=torch.float32, + device=device)) + scale_tensor = scale_tensor.to(device=device, dtype=torch.float32) + else: + # Create MLA KV cache: (num_blocks, block_size, head_size) + kv_cache = torch.empty(num_blocks, + block_size, + head_size, + dtype=dtype, + device=device) + kv_cache_flat = kv_cache.view(-1, head_size) # Populate the cache with the context tokens # Start from block_id=1 since block_id=0 is considered the null block start_block_idx = 1 for i in range(batch_size): kv_c_context, k_pe_context = kv_c_contexts[i], k_pe_contexts[i] - kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], dim=-1) + context_len = kv_c_context.shape[0] + if context_len == 0: + start_block_idx += cdiv(int(seq_lens[i]), block_size) + continue + start = start_block_idx * block_size - end = start + kv_context.shape[0] - kv_cache_flat[start:end, ...] = kv_context + + if use_fp8_ds_mla: + slots = torch.arange(context_len, + device=device, + dtype=torch.long) + start + ops.concat_and_cache_mla( + kv_c_context, + k_pe_context.squeeze(1), + kv_cache, + slots, + kv_cache_dtype="fp8_ds_mla", + scale=scale_tensor, + ) + else: + kv_context = torch.cat([kv_c_context, k_pe_context.squeeze(1)], + dim=-1) + end = start + kv_context.shape[0] + kv_cache_flat[start:end, ...] 
= kv_context # Stay block aligned and allocate enough blocks for the new tokens start_block_idx += cdiv(int(seq_lens[i]), block_size) diff --git a/tests/v1/attention/test_sparse_mla_backends.py b/tests/v1/attention/test_sparse_mla_backends.py new file mode 100644 index 000000000000..74eea6f716fe --- /dev/null +++ b/tests/v1/attention/test_sparse_mla_backends.py @@ -0,0 +1,426 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for the FlashMLA sparse backend utilities.""" + +import math +from types import MethodType, SimpleNamespace + +import numpy as np +import pytest +import torch + +from tests.v1.attention.test_mla_backends import ( + BATCH_SPECS, BatchSpec, MockAttentionLayer, + create_and_prepopulate_kv_cache) +from tests.v1.attention.utils import (create_common_attn_metadata, + create_standard_kv_cache_spec, + create_vllm_config) +from vllm import _custom_ops as ops +from vllm.attention.ops import flashmla +from vllm.model_executor.layers.linear import ColumnParallelLinear +from vllm.utils import cdiv +from vllm.v1.attention.backends.mla.flashmla_sparse import ( + FlashMLASparseBackend, FlashMLASparseDecodeAndContextMetadata, + FlashMLASparseImpl, FlashMLASparseMetadata) + +SPARSE_BACKEND_BATCH_SPECS = { + name: BATCH_SPECS[name] + for name in [ + "mixed_small", + "mixed_medium", + "small_prefill", + "medium_prefill", + "single_prefill", + ] +} + +SPARSE_BACKEND_BATCH_SPECS["large_q_prefill"] = BatchSpec(seq_lens=[1024] * 2, + query_lens=[256] * 2) +SPARSE_BACKEND_BATCH_SPECS["large_q_pure_prefill"] = BatchSpec( + seq_lens=[256] * 2, query_lens=[256] * 2) + + +def _dequantize_fp8_ds_mla_entry( + cache_slice: torch.Tensor, kv_lora_rank: int, rope_dim: int, + dtype: torch.dtype) -> tuple[torch.Tensor, torch.Tensor]: + """Dequantize a single fp8_ds_mla cache entry back to latent + rope.""" + + # The first kv_lora_rank bytes store FP8 latent values with one scale per + # 128 element tile written as float32 right after the latent payload. 
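+    # The 16-bit RoPE values follow the four scales, starting at 16-bit
+    # offset kv_lora_rank // 2 + 8 (byte offset kv_lora_rank + 16).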
+ scales = cache_slice.view(torch.float32)[kv_lora_rank // + 4:kv_lora_rank // 4 + 4] + latent = torch.empty(kv_lora_rank, + dtype=torch.float16, + device=cache_slice.device) + for tile_idx in range(4): + tile_start = tile_idx * 128 + tile_end = tile_start + 128 + ops.convert_fp8(latent[tile_start:tile_end], + cache_slice[tile_start:tile_end], + float(scales[tile_idx].item()), + kv_dtype="fp8") + latent = latent.to(dtype) + + rope_offset = kv_lora_rank // 2 + 8 + rope_vals = cache_slice.view(dtype)[rope_offset:rope_offset + rope_dim] + return latent, rope_vals.clone() + + +def _quantize_dequantize_fp8_ds_mla( + kv_c: torch.Tensor, k_pe: torch.Tensor, block_size: int, + scale: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Round-trip kv_c/k_pe though the fp8_ds_mla cache layout.""" + + if kv_c.numel() == 0: + return kv_c.clone(), k_pe.clone() + + kv_lora_rank = kv_c.shape[-1] + rope_dim = k_pe.shape[-1] + num_tokens = kv_c.shape[0] + num_blocks = max(1, math.ceil(num_tokens / block_size)) + entry_size = kv_lora_rank + 4 * 4 + 2 * rope_dim + + tmp_cache = torch.zeros(num_blocks, + block_size, + entry_size, + dtype=torch.uint8, + device=kv_c.device) + slot_mapping = torch.arange(num_tokens, + dtype=torch.long, + device=kv_c.device) + + ops.concat_and_cache_mla(kv_c, + k_pe, + tmp_cache, + slot_mapping, + kv_cache_dtype="fp8_ds_mla", + scale=scale) + + dequant_kv_c = torch.empty_like(kv_c) + dequant_k_pe = torch.empty_like(k_pe) + + for token_idx in range(num_tokens): + slot = slot_mapping[token_idx].item() + block_idx = slot // block_size + block_offset = slot % block_size + cache_slice = tmp_cache[block_idx, block_offset] + latent, rope_vals = _dequantize_fp8_ds_mla_entry( + cache_slice, kv_lora_rank, rope_dim, kv_c.dtype) + dequant_kv_c[token_idx] = latent + dequant_k_pe[token_idx] = rope_vals + + return dequant_kv_c, dequant_k_pe + + +def test_sparse_backend_metadata_registration(): + backend = FlashMLASparseBackend + + assert backend.get_name() == "FLASHMLA_SPARSE_VLLM_V1" + assert backend.get_metadata_cls() is FlashMLASparseMetadata + assert backend.get_impl_cls() is FlashMLASparseImpl + + dtype_list = backend.get_supported_dtypes() + assert torch.bfloat16 in dtype_list + + shape = backend.get_kv_cache_shape(num_blocks=2, + block_size=64, + num_kv_heads=1, + head_size=576) + assert shape == (2, 64, 576) + + +def test_sparse_decode_metadata_filters_prefill_indices(): + prefill_context_lengths = torch.tensor([4, 2], dtype=torch.int32) + metadata = FlashMLASparseDecodeAndContextMetadata( + scheduler_metadata=torch.tensor([[0]], dtype=torch.int32), + num_splits=torch.tensor([1, 1], dtype=torch.int32), + cache_lens=torch.tensor([10, 12], dtype=torch.int32), + prefill_context_lengths=prefill_context_lengths, + ) + + indices = torch.tensor([[0, 3, 5], [1, 2, 4]], dtype=torch.int32) + + context_indices, new_token_indices = metadata.filter_prefill_indices( + indices) + + expected_context = torch.tensor([[-1, -1, 5], [-1, -1, 4]], + dtype=torch.int32) + expected_new_tokens = torch.tensor([[-1, -1, 1], [-1, 0, 2]], + dtype=torch.int32) + + assert torch.equal(context_indices, expected_context) + assert torch.equal(new_token_indices, expected_new_tokens) + + +def test_sparse_impl_zero_fills_when_metadata_missing(): + impl = FlashMLASparseImpl.__new__(FlashMLASparseImpl) + dummy_layer = object() + q = torch.zeros((2, 1, 3)) + k_c = torch.zeros((2, 3)) + k_pe = torch.zeros((2, 1, 1)) + kv_cache = torch.zeros((1, 1, 1)) + output = torch.ones((2, 4)) + + result = 
FlashMLASparseImpl.forward(impl, + dummy_layer, + q, + k_c, + k_pe, + kv_cache, + attn_metadata=None, + output=output) + + assert result is output + assert torch.all(result == 0) + + +@pytest.mark.parametrize("batch_name", list(SPARSE_BACKEND_BATCH_SPECS.keys())) +@pytest.mark.parametrize("kv_cache_dtype", ["fp8_ds_mla", "auto"]) +def test_sparse_backend_decode_correctness(dist_init, batch_name, + kv_cache_dtype): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required for sparse MLA decode test") + + device = torch.device("cuda") + dtype = torch.bfloat16 + + batch_spec = SPARSE_BACKEND_BATCH_SPECS[batch_name] + + # Model hyper-parameters (kept intentionally small for the unit test) + num_heads = 128 + kv_lora_rank = 512 + qk_nope_head_dim = 128 + qk_rope_head_dim = 64 + v_head_dim = 128 + head_size = kv_lora_rank + qk_rope_head_dim + topk_tokens = 2048 + + max_seqlen = max(batch_spec.seq_lens) + total_cache_tokens = sum(batch_spec.seq_lens) + block_size = 64 + + vllm_config = create_vllm_config( + model_name="deepseek-ai/DeepSeek-V2-Lite-Chat", + max_model_len=max_seqlen, + num_gpu_blocks=max(2048, + cdiv(total_cache_tokens, block_size) + 1), + block_size=block_size) + model_config = vllm_config.model_config + model_config.hf_config = SimpleNamespace( + attn_module_list_cfg=[{ + "topk_tokens": topk_tokens + }]) + model_config.hf_text_config = SimpleNamespace( + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + v_head_dim=v_head_dim, + model_type="deepseek_v2", + ) + model_config.dtype = dtype + model_config.get_num_attention_heads = MethodType( + lambda self, parallel_config: num_heads, model_config) + model_config.get_num_kv_heads = MethodType(lambda self, parallel_config: 1, + model_config) + model_config.get_head_size = MethodType(lambda self: head_size, + model_config) + model_config.get_sliding_window = MethodType(lambda self: None, + model_config) + + kv_cache_spec = create_standard_kv_cache_spec(vllm_config) + + torch.manual_seed(0) + + scale = 1.0 / math.sqrt(head_size) + + # Shared MLA projection weights to keep reference and backend in sync + W_UK = torch.randn(kv_lora_rank, + num_heads, + qk_nope_head_dim, + dtype=dtype, + device=device) + W_UV = torch.randn(kv_lora_rank, + num_heads, + v_head_dim, + dtype=dtype, + device=device) + + # Build synthetic decode-only workload + seq_lens = batch_spec.seq_lens + query_lens = batch_spec.query_lens + + all_q_vllm, all_kv_c_vllm, all_k_pe_vllm = [], [], [] + kv_c_contexts, k_pe_contexts = [], [] + reference_outputs = [] + + kv_cache_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + for i in range(batch_spec.batch_size): + s_len = seq_lens[i] + q_len = query_lens[i] + ctx_len = s_len - q_len + + q_c = torch.rand(q_len, + num_heads, + qk_nope_head_dim + qk_rope_head_dim, + dtype=dtype, + device=device) + kv_c_full = torch.rand(s_len, kv_lora_rank, dtype=dtype, device=device) + k_pe_full = torch.rand(s_len, + 1, + qk_rope_head_dim, + dtype=dtype, + device=device) + + kv_c_full, k_pe_full = _quantize_dequantize_fp8_ds_mla( + kv_c_full, + k_pe_full.squeeze(1), + block_size=vllm_config.cache_config.block_size, + scale=kv_cache_scale, + ) + + q_nope, q_pe = q_c.split([qk_nope_head_dim, qk_rope_head_dim], dim=-1) + ql_nope = torch.einsum("qnh,lnh->qnl", q_nope, W_UK) + q_mqa = torch.cat([ql_nope, q_pe], dim=-1) + + k_mqa = torch.cat([kv_c_full, k_pe_full], dim=-1) + k_mqa = k_mqa.unsqueeze(1).expand(-1, num_heads, -1) + v_mqa = 
kv_c_full.unsqueeze(1).expand(-1, num_heads, -1) + + attn_mask = torch.ones(q_len, s_len, dtype=torch.bool, device=device) + causal_mask = torch.tril(torch.ones(q_len, q_len, device=device)) + attn_mask[:, ctx_len:] = causal_mask + + q_sdpa_in = q_mqa.unsqueeze(0).transpose(1, 2) + k_sdpa_in = k_mqa.unsqueeze(0).transpose(1, 2) + v_sdpa_in = v_mqa.unsqueeze(0).transpose(1, 2) + + sdpa_out = torch.nn.functional.scaled_dot_product_attention( + q_sdpa_in, k_sdpa_in, v_sdpa_in, attn_mask=attn_mask, scale=scale) + sdpa_out = sdpa_out.transpose(1, 2).squeeze(0) + + sdpa_out = torch.einsum("qnl,lnv->qnv", sdpa_out, W_UV) + reference_outputs.append(sdpa_out.flatten(start_dim=-2)) + + all_q_vllm.append(q_c) + all_kv_c_vllm.append(kv_c_full[ctx_len:]) + all_k_pe_vllm.append(k_pe_full[ctx_len:]) + kv_c_contexts.append(kv_c_full[:ctx_len + 1]) + k_pe_contexts.append(k_pe_full[:ctx_len + 1]) + + query_vllm = torch.cat(all_q_vllm, dim=0) + kv_c_vllm = torch.cat(all_kv_c_vllm, dim=0) + k_pe_vllm = torch.cat(all_k_pe_vllm, dim=0) + sdpa_reference = torch.cat(reference_outputs, dim=0) + + vllm_config.cache_config.cache_dtype = kv_cache_dtype + + common_attn_metadata = create_common_attn_metadata( + batch_spec, + vllm_config.cache_config.block_size, + device, + arange_block_indices=True) + + kv_cache = create_and_prepopulate_kv_cache( + kv_c_contexts=kv_c_contexts, + k_pe_contexts=k_pe_contexts, + block_size=vllm_config.cache_config.block_size, + head_size=head_size, + dtype=dtype, + device=device, + num_blocks=vllm_config.cache_config.num_gpu_blocks, + common_attn_metadata=common_attn_metadata, + randomize_blocks=False, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + scale=kv_cache_scale, + ) + + builder_cls = FlashMLASparseBackend.get_builder_cls() + builder = builder_cls(kv_cache_spec, ["placeholder"], vllm_config, device) + metadata = builder.build(common_prefix_len=0, + common_attn_metadata=common_attn_metadata) + + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, + dtype=np.int32) + seg_lengths = np.diff(starts) + positions = np.arange(starts[-1], dtype=np.int32) - np.repeat( + starts[:-1], seg_lengths) + seq_lengths = np.asarray(common_attn_metadata.seq_lens_cpu, dtype=np.int32) + prefix_lengths = seq_lengths - seg_lengths + positions += np.repeat(prefix_lengths, seg_lengths) + + pos_gpu = torch.as_tensor(positions, device=device, dtype=torch.int32) + topk = metadata.topk_tokens + debug_indices = torch.arange(topk, device=device, + dtype=torch.int32).unsqueeze(0) + token_positions = pos_gpu.unsqueeze(1) + causal_mask = (debug_indices <= token_positions) + debug_indices = torch.where(causal_mask, debug_indices, + torch.full_like(debug_indices, -1)) + + # FlashMLASparseImpl now reads top-k indices from the indexer-provided + # buffer, so emulate that contract with a simple namespace mock. 
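+    # Contract assumed here: the buffer is int32 with shape
+    # [num_actual_tokens, topk_tokens], and -1 marks positions the sparse
+    # kernel treats as invalid (out-of-range or masked-out tokens).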
+ debug_indices = debug_indices.expand(metadata.num_actual_tokens, + -1).clone() + mock_indexer = SimpleNamespace(topk_indices_buffer=debug_indices) + + ok, reason = flashmla.is_flashmla_supported() + if not ok: + pytest.skip(reason) + + kv_b_proj_weight = torch.cat([W_UK, W_UV], dim=-1) + kv_b_proj_weight = kv_b_proj_weight.view( + kv_lora_rank, num_heads * (qk_nope_head_dim + v_head_dim)) + + mock_kv_b_proj = ColumnParallelLinear(input_size=kv_lora_rank, + output_size=num_heads * + (qk_nope_head_dim + v_head_dim), + bias=False).to(device=device, + dtype=dtype) + mock_kv_b_proj.weight = torch.nn.Parameter(kv_b_proj_weight.T.contiguous()) + + impl_cls = FlashMLASparseBackend.get_impl_cls() + impl = impl_cls(num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=1, + alibi_slopes=None, + sliding_window=None, + kv_cache_dtype=vllm_config.cache_config.cache_dtype, + logits_soft_cap=None, + attn_type="decoder", + kv_sharing_target_layer_name=None, + q_lora_rank=None, + kv_lora_rank=kv_lora_rank, + qk_nope_head_dim=qk_nope_head_dim, + qk_rope_head_dim=qk_rope_head_dim, + qk_head_dim=qk_nope_head_dim + qk_rope_head_dim, + v_head_dim=v_head_dim, + kv_b_proj=mock_kv_b_proj, + indexer=mock_indexer) + + impl.process_weights_after_loading(dtype) + + layer = MockAttentionLayer(device) + out_buffer = torch.empty(metadata.num_actual_tokens, + num_heads * v_head_dim, + dtype=dtype, + device=device) + + backend_output = impl.forward(layer, + query_vllm, + kv_c_vllm, + k_pe_vllm, + kv_cache, + metadata, + output=out_buffer) + + assert backend_output.shape == sdpa_reference.shape + assert backend_output.dtype == sdpa_reference.dtype + assert torch.isfinite(backend_output).all() + + torch.testing.assert_close(backend_output, + sdpa_reference, + rtol=0.5, + atol=0.5) diff --git a/tests/v1/attention/utils.py b/tests/v1/attention/utils.py index f07c6eb0ea4d..41b71e33e0c4 100644 --- a/tests/v1/attention/utils.py +++ b/tests/v1/attention/utils.py @@ -168,7 +168,6 @@ def create_standard_kv_cache_spec( vllm_config.parallel_config), head_size=vllm_config.model_config.get_head_size(), dtype=vllm_config.model_config.dtype, - use_mla=vllm_config.model_config.use_mla, sliding_window=vllm_config.model_config.get_sliding_window(), ) diff --git a/tests/v1/core/test_kv_cache_utils.py b/tests/v1/core/test_kv_cache_utils.py index 4cb7ed6ce382..452b16ef4a91 100644 --- a/tests/v1/core/test_kv_cache_utils.py +++ b/tests/v1/core/test_kv_cache_utils.py @@ -24,7 +24,8 @@ make_block_hash_with_group_id) from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, - KVCacheTensor, SlidingWindowSpec, + KVCacheTensor, MLAAttentionSpec, + SlidingWindowSpec, UniformTypeKVCacheSpecs) from vllm.v1.metrics.stats import PrefixCacheStats from vllm.v1.request import Request @@ -77,13 +78,11 @@ def new_kv_cache_spec(block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, - use_mla=False, sliding_window=None): return FullAttentionSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - use_mla=use_mla, sliding_window=sliding_window) @@ -91,13 +90,11 @@ def new_sliding_window_spec(block_size=16, num_kv_heads=2, head_size=64, dtype=torch.float32, - use_mla=False, sliding_window=1): return SlidingWindowSpec(block_size=block_size, num_kv_heads=num_kv_heads, head_size=head_size, dtype=dtype, - use_mla=use_mla, sliding_window=sliding_window) @@ -894,7 +891,6 @@ def test_merge_kv_cache_spec(): num_kv_heads=full_spec.num_kv_heads, 
head_size=full_spec.head_size, dtype=full_spec.dtype, - use_mla=full_spec.use_mla, sliding_window=1, ), ] @@ -991,7 +987,6 @@ def test_estimate_max_model_len(model_id, max_model_len, num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, ) # Estimate the maximum model length, 16384 model_len need 8GB estimated_max_len = estimate_max_model_len(vllm_config, kv_cache_spec, @@ -1022,7 +1017,6 @@ def test_get_max_concurrency_for_kv_cache_config(): num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, ) sliding_window_spec = SlidingWindowSpec( @@ -1030,7 +1024,6 @@ def test_get_max_concurrency_for_kv_cache_config(): num_kv_heads=32, head_size=128, dtype=torch.float16, - use_mla=False, sliding_window=1024, ) @@ -1412,3 +1405,48 @@ def test_generate_scheduler_kv_cache_config(): KVCacheGroupSpec(['layer_1', 'layer_2'], new_kv_cache_spec()) ], ) + + +def new_mla_spec(cache_dtype_str=None): + return MLAAttentionSpec(block_size=16, + num_kv_heads=16, + head_size=64, + dtype=torch.float32, + cache_dtype_str=cache_dtype_str) + + +def test_merge_mla_spec(): + kv_cache_specs = [ + new_mla_spec(), + new_mla_spec(), + ] + mla_spec = kv_cache_specs[0].merge(kv_cache_specs) + assert mla_spec == new_mla_spec() + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + ] + mla_spec = kv_cache_specs[0].merge(kv_cache_specs) + assert mla_spec == new_mla_spec(cache_dtype_str="fp8_ds_mla") + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + new_mla_spec(cache_dtype_str=None), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) + + kv_cache_specs = [ + new_kv_cache_spec(), + new_mla_spec(), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) + + kv_cache_specs = [ + new_mla_spec(cache_dtype_str="fp8_ds_mla"), + new_kv_cache_spec(), + ] + with pytest.raises(AssertionError): + kv_cache_specs[0].merge(kv_cache_specs) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index 3cf9d9369676..3ddfaf71a1ca 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1337,7 +1337,6 @@ def test_eagle_with_sliding_window(): head_size=1, dtype=torch.float32, sliding_window=block_size, - use_mla=False, ) manager = KVCacheManager( KVCacheConfig( diff --git a/tests/v1/core/test_single_type_kv_cache_manager.py b/tests/v1/core/test_single_type_kv_cache_manager.py index b70850a9bcff..e1a26cfd898f 100644 --- a/tests/v1/core/test_single_type_kv_cache_manager.py +++ b/tests/v1/core/test_single_type_kv_cache_manager.py @@ -35,7 +35,6 @@ def test_chunked_local_attention_possible_cached_prefix(): head_size=1, dtype=torch.float32, attention_chunk_size=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -101,7 +100,6 @@ def test_sliding_window_possible_cached_prefix(): head_size=1, dtype=torch.float32, sliding_window=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -167,7 +165,6 @@ def test_chunked_local_attention_remove_skipped_blocks(): head_size=1, dtype=torch.float32, attention_chunk_size=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -219,7 +216,6 @@ def test_sliding_window_remove_skipped_blocks(): head_size=1, dtype=torch.float32, sliding_window=4, - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=2000, enable_caching=True) @@ -287,7 +283,6 @@ def 
test_get_num_blocks_to_allocate(): head_size=1, dtype=torch.float32, sliding_window=4, # Placeholder value, not related to test result - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) @@ -310,7 +305,6 @@ def test_chunked_local_attention_get_num_blocks_to_allocate(): head_size=1, dtype=torch.float32, attention_chunk_size=4, # Placeholder value, not related to test result - use_mla=False, ) block_pool = BlockPool(num_gpu_blocks=100, enable_caching=True) diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index 992c4e01386e..10adac9bab5f 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -836,8 +836,7 @@ def create_mock_executor(vllm_config): mock_spec = FullAttentionSpec(block_size=16, num_kv_heads=1, head_size=64, - dtype=torch.float16, - use_mla=False) + dtype=torch.float16) mock_executor.get_kv_cache_specs.return_value = [{ "default": mock_spec diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 8b571f95c5ec..49a7a61e1889 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -39,7 +39,6 @@ def initialize_kv_cache(runner: GPUModelRunner): runner.parallel_config), head_size=runner.model_config.get_head_size(), dtype=runner.kv_cache_dtype, - use_mla=False, ) tensor_size = attn_spec.page_size_bytes * NUM_BLOCKS kv_cache_config = KVCacheConfig( diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 712295aa9288..3c06cce130f7 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -1671,6 +1671,15 @@ def cp_gather_cache(src_cache: torch.Tensor, cu_seq_lens, batch_size, seq_starts) +def indexer_k_quant_and_cache(k: torch.Tensor, + kv_cache: torch.Tensor, + slot_mapping: torch.Tensor, + quant_block_size: int, + kv_cache_dtype: str) -> None: + torch.ops._C_cache_ops.indexer_k_quant_and_cache(k, kv_cache, slot_mapping, + quant_block_size, kv_cache_dtype) + + def get_device_attribute(attribute: int, device: int) -> int: return torch.ops._C_cuda_utils.get_device_attribute(attribute, device) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index dfde67e1713c..754545e6f2d6 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -76,6 +76,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: raise NotImplementedError diff --git a/vllm/attention/backends/differential_flash_attn.py b/vllm/attention/backends/differential_flash_attn.py index a7d0e3afb517..7dce44489a21 100644 --- a/vllm/attention/backends/differential_flash_attn.py +++ b/vllm/attention/backends/differential_flash_attn.py @@ -53,6 +53,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 78c768f92d3c..25f05dac28c2 100755 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -71,6 +71,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/attention/backends/mla/common.py 
b/vllm/attention/backends/mla/common.py index 789393eb39a7..3feaee438523 100644 --- a/vllm/attention/backends/mla/common.py +++ b/vllm/attention/backends/mla/common.py @@ -263,6 +263,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return (num_blocks, block_size, head_size) diff --git a/vllm/attention/backends/placeholder_attn.py b/vllm/attention/backends/placeholder_attn.py index e630a6c6de8c..aaa12da3c67b 100644 --- a/vllm/attention/backends/placeholder_attn.py +++ b/vllm/attention/backends/placeholder_attn.py @@ -51,6 +51,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return (1, 1, 1, 1, 1) diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 9262144e37b5..5dc7790bacf9 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -82,6 +82,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: paged_attn = _get_paged_attn_module() return paged_attn.get_kv_cache_shape(num_blocks, block_size, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 302d3d7ea903..495225127fe2 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -53,6 +53,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return PagedAttention.get_kv_cache_shape(num_blocks, block_size, num_kv_heads, head_size) diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 544a72052442..2a89db21fb47 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -92,6 +92,7 @@ def __init__( logits_soft_cap: Optional[float] = None, per_layer_sliding_window: Optional[int] = None, use_mla: bool = False, + use_sparse: bool = False, prefix: str = "", attn_type: str = AttentionType.DECODER, kv_sharing_target_layer_name: Optional[str] = None, @@ -154,6 +155,7 @@ def __init__( self._o_scale_float: Optional[float] = None self.use_mla = use_mla + self.use_sparse = use_sparse self.num_heads = num_heads self.head_size = head_size self.num_kv_heads = num_kv_heads @@ -187,7 +189,8 @@ def __init__( block_size, is_attention_free, use_mla=use_mla, - has_sink=self.has_sink) + has_sink=self.has_sink, + use_sparse=use_sparse) else: self.attn_backend = attn_backend diff --git a/vllm/attention/ops/common.py b/vllm/attention/ops/common.py index 189b57e8e8b8..6f4a695637d3 100644 --- a/vllm/attention/ops/common.py +++ b/vllm/attention/ops/common.py @@ -137,3 +137,183 @@ def cp_lse_ag_out_rs(cp_attn_out: torch.Tensor, assert out.is_contiguous() out = cp_group.reduce_scatter(out, dim=1) return out + +@triton.jit +def _pack_seq_kernel( + x_ptr, # [N, D] + out_ptr, # [B, Lmax, D] + lengths_ptr, # *i32, [B] + N: tl.constexpr, D: tl.constexpr, Lmax: tl.constexpr, + PAD_VALUE: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr # features per program +): + pid_b = tl.program_id(0) # batch id + pid_t = tl.program_id(1) # block over time dimension + pid_d = tl.program_id(2) # block over feature dimension + off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] + off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] + + # Compute start index and sequence length from 
cumulative lengths + in_start = 0 + for i in range(pid_b): + in_start += tl.load(lengths_ptr + i) + seq_len = tl.load(lengths_ptr + pid_b) + + # valid time positions for this block + t_mask = off_t < Lmax + + # compute input row indices for valid (b, t) + in_row = in_start + off_t + valid_row = (off_t < seq_len) & t_mask + + # Pointers + # x_ptr: row-major [N, D] + x_row_ptr = x_ptr + in_row[:, None] * D + off_d[None, :] + + # out_ptr: row-major [B, Lmax, D] + out_row_ptr = out_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] + + # Initialize with PAD (cast will occur as needed based on out_ptr dtype) + d_mask = off_d[None, :] < D + pad_vals = tl.full([BLOCK_T, BLOCK_D], PAD_VALUE, tl.float32) + tl.store(out_row_ptr, pad_vals, mask=t_mask[:, None] & d_mask) + + # Load & write only where within seq_len + x_vals = tl.load(x_row_ptr, mask=valid_row[:, None] & d_mask) + tl.store(out_row_ptr, x_vals, mask=valid_row[:, None] & d_mask) + +def pack_seq_triton(x, lengths, pad_value=-float('inf'), block_t=64, block_d=64): + """ + Pack sequences of different lengths into a batched tensor. + + Args: + x: [N, ...] - input tensor where N is total number of tokens + lengths: [B] - sequence lengths for each batch + pad_value: value to use for padding + block_t: block size for time dimension + block_d: block size for feature dimension + + Returns: + packed: [B, Lmax, ...] - packed tensor + """ + + # Handle multi-dimensional input by reshaping to (N, -1) + original_shape = x.shape + if len(original_shape) > 2: + N = original_shape[0] + x_reshaped = x.reshape(N, -1) + D = x_reshaped.shape[1] + else: + N, D = x.shape + x_reshaped = x + + B = lengths.numel() + Lmax = int(lengths.max().item()) + + # Starts are computed inside the kernel from lengths + + out = torch.empty((B, Lmax, D), device=x.device, dtype=x.dtype) + + grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) + _pack_seq_kernel[grid]( + x_reshaped, out, lengths.int(), + N, D, Lmax, + PAD_VALUE=float(pad_value), + BLOCK_T=block_t, BLOCK_D=block_d, + num_warps=4, num_stages=2 + ) + + # Reshape output back to original dimensions (except first dimension) + if len(original_shape) > 2: + output_shape = (B, Lmax) + original_shape[1:] + out = out.reshape(output_shape) + + return out + + +@triton.jit +def _unpack_seq_triton_kernel( + packed_ptr, # [B, Lmax, D] + out_ptr, # [N, D] + lengths_ptr, # *i32, [B] + B: tl.constexpr, Lmax: tl.constexpr, D: tl.constexpr, + BLOCK_T: tl.constexpr, # timesteps per program + BLOCK_D: tl.constexpr # features per program +): + pid_b = tl.program_id(0) # batch id + pid_t = tl.program_id(1) # block over time dimension + pid_d = tl.program_id(2) # block over feature dimension + off_t = pid_t * BLOCK_T + tl.arange(0, BLOCK_T) # [BLOCK_T] + off_d = pid_d * BLOCK_D + tl.arange(0, BLOCK_D) # [BLOCK_D] + + # bounds: compute start from cumulative lengths + in_start = 0 + for i in range(pid_b): + in_start += tl.load(lengths_ptr + i) + seq_len = tl.load(lengths_ptr + pid_b) + + # valid time positions for this block + t_mask = off_t < Lmax + valid_row = (off_t < seq_len) & t_mask + + # compute output row indices for valid (b, t) + out_row = in_start + off_t + + # Pointers + # packed_ptr: row-major [B, Lmax, D] + packed_row_ptr = packed_ptr + (pid_b * Lmax + off_t)[:, None] * D + off_d[None, :] + + # out_ptr: row-major [N, D] + out_row_ptr = out_ptr + out_row[:, None] * D + off_d[None, :] + + # Load from packed tensor and store to output + d_mask = off_d[None, :] < D + packed_vals = tl.load(packed_row_ptr, 
mask=valid_row[:, None] & d_mask) + tl.store(out_row_ptr, packed_vals, mask=valid_row[:, None] & d_mask) + + +def unpack_seq_triton(packed_tensor, lengths, block_t=64, block_d=64): + """ + Unpack a packed decode query tensor back to the original format. + Efficient Triton implementation. + + Args: + packed_tensor: [B, Lmax, ...] - packed tensor from pack_seq_triton + lengths: [B] - sequence lengths for each batch + block_t: block size for time dimension + block_d: block size for feature dimension + + Returns: + unpacked_tensor: [N, ...] where N = sum(lengths) + """ + + # Handle multi-dimensional input by reshaping to (B, Lmax, -1) + original_shape = packed_tensor.shape + if len(original_shape) > 3: + B, Lmax = original_shape[:2] + packed_reshaped = packed_tensor.reshape(B, Lmax, -1) + D = packed_reshaped.shape[2] + else: + B, Lmax, D = packed_tensor.shape + packed_reshaped = packed_tensor + + # Calculate total number of elements + N = int(lengths.sum().item()) + + out = torch.empty((N, D), device=packed_tensor.device, dtype=packed_tensor.dtype) + + grid = (B, triton.cdiv(Lmax, block_t), triton.cdiv(D, block_d)) + _unpack_seq_triton_kernel[grid]( + packed_reshaped, out, lengths.int(), + B, Lmax, D, + BLOCK_T=block_t, BLOCK_D=block_d, + num_warps=4, num_stages=2 + ) + + # Reshape output back to original dimensions (except first dimension) + if len(original_shape) > 3: + output_shape = (N,) + original_shape[2:] + out = out.reshape(output_shape) + + return out \ No newline at end of file diff --git a/vllm/attention/ops/flashmla.py b/vllm/attention/ops/flashmla.py index 2c3e8c42400c..9c9eee24ebeb 100644 --- a/vllm/attention/ops/flashmla.py +++ b/vllm/attention/ops/flashmla.py @@ -19,6 +19,15 @@ else: _flashmla_C_AVAILABLE = False +if current_platform.is_cuda(): + try: + import vllm._flashmla_extension_C # noqa: F401 + _flashmla_extension_C_AVAILABLE = True + except ImportError: + _flashmla_extension_C_AVAILABLE = False +else: + _flashmla_extension_C_AVAILABLE = False + def is_flashmla_supported() -> Tuple[bool, Optional[str]]: """ @@ -37,24 +46,28 @@ def is_flashmla_supported() -> Tuple[bool, Optional[str]]: def get_mla_metadata( - cache_seqlens: torch.Tensor, - num_heads_per_head_k: int, - num_heads_k: int, -) -> Tuple[torch.Tensor, torch.Tensor]: + cache_seqlens: torch.Tensor, + num_q_tokens_per_head_k: int, + num_heads_k: int, + num_heads_q: Optional[int] = None, + is_fp8_kvcache: bool = False, + topk: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: cache_seqlens: (batch_size), dtype torch.int32. - num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k. - num_heads_k: num_heads_k. + num_q_tokens_per_head_k: Equals to num_q_tokens_per_q_seq * num_heads_q // num_heads_k. + num_heads_k: The number of k heads. + num_heads_q: The number of q heads. This argument is optional when sparse attention is not enabled + is_fp8_kvcache: Whether the k_cache and v_cache are in fp8 format. + topk: If not None, sparse attention will be enabled, and only tokens in the `indices` array passed to `flash_mla_with_kvcache_sm90` will be attended to. - Return: - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - dtype torch.int32. + Returns: + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32. num_splits: (batch_size + 1), dtype torch.int32. 
""" - return torch.ops._flashmla_C.get_mla_metadata(cache_seqlens, - num_heads_per_head_k, - num_heads_k) + return torch.ops._flashmla_C.get_mla_decoding_metadata( + cache_seqlens, num_q_tokens_per_head_k, num_heads_k, num_heads_q, + is_fp8_kvcache, topk) def flash_mla_with_kvcache( @@ -69,6 +82,8 @@ def flash_mla_with_kvcache( causal: bool = False, descale_q: Optional[torch.Tensor] = None, descale_k: Optional[torch.Tensor] = None, + is_fp8_kvcache: bool = False, + indices: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Arguments: @@ -76,38 +91,67 @@ def flash_mla_with_kvcache( k_cache: (num_blocks, page_block_size, num_heads_k, head_dim). block_table: (batch_size, max_num_blocks_per_seq), torch.int32. cache_seqlens: (batch_size), torch.int32. - head_dim_v: Head_dim of v. - tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), - torch.int32, return by get_mla_metadata. - num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(head_dim). + head_dim_v: Head dimension of v. + tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, returned by get_mla_metadata. + num_splits: (batch_size + 1), torch.int32, returned by get_mla_metadata. + softmax_scale: float. The scale of QK^T before applying softmax. Default to 1 / sqrt(head_dim). causal: bool. Whether to apply causal attention mask. - descale_q: (batch_size), torch.float32. Descaling factors for Q. - descale_k: (batch_size), torch.float32. Descaling factors for K. + descale_q: (batch_size), torch.float32. Descaling factors for Q, used for fp8 quantization. + descale_k: (batch_size), torch.float32. Descaling factors for K, used for fp8 quantization. + is_fp8_kvcache: bool. Whether the k_cache and v_cache are in fp8 format. For the format of FP8 KV cache, please refer to README.md + indices: (batch_size, seq_len_q, topk), torch.int32. If not None, sparse attention will be enabled, and only tokens in the `indices` array will be attended to. Invalid indices should be set to -1 or numbers >= total_seq_len_kv. For details about how to set up `indices`, please refer to README.md. - Return: + Returns: out: (batch_size, seq_len_q, num_heads_q, head_dim_v). softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32. """ if softmax_scale is None: softmax_scale = q.shape[-1]**(-0.5) - out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( - q, - k_cache, - head_dim_v, - cache_seqlens, - block_table, - softmax_scale, - causal, - tile_scheduler_metadata, - num_splits, - descale_q, - descale_k, - ) - - # Note(hc): need revisit when we support DCP with decode query_len > 1. - return out.squeeze(1), softmax_lse.squeeze(-1) + if indices is not None: + assert causal == False, "causal must be `false` if sparse attention is enabled." 
+ assert (descale_q is None) == ( + descale_k is None + ), "descale_q and descale_k should be both None or both not None" + + if (descale_q is not None) and (descale_k is not None): + out, softmax_lse = torch.ops._flashmla_extension_C.fwd_kvcache_mla_fp8( + q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, + causal, tile_scheduler_metadata, num_splits, descale_q, descale_k) + else: + out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( + q, k_cache, head_dim_v, cache_seqlens, block_table, softmax_scale, + causal, tile_scheduler_metadata, num_splits, is_fp8_kvcache, + indices) + return out, softmax_lse + + +def flash_mla_sparse_prefill( + q: torch.Tensor, + kv: torch.Tensor, + indices: torch.Tensor, + sm_scale: float, + d_v: int = 512, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Sparse attention prefill kernel + + Args: + q: [s_q, h_q, d_qk], bfloat16 + kv: [s_kv, h_kv, d_qk], bfloat16 + indices: [s_q, h_kv, topk], int32. Invalid indices should be set to -1 or numbers >= s_kv + sm_scale: float + d_v: The dimension of value vectors. Can only be 512 + + Returns: + (output, max_logits, lse) + About the definition of output, max_logits and lse, please refer to README.md + - output: [s_q, h_q, d_v], bfloat16 + - max_logits: [s_q, h_q], float + - lse: [s_q, h_q], float, 2-based log-sum-exp + """ + results = torch.ops._flashmla_C.sparse_prefill_fwd(q, kv, indices, + sm_scale, d_v) + return results # diff --git a/vllm/attention/ops/paged_attn.py b/vllm/attention/ops/paged_attn.py index 4d870a45e580..539b57e41de7 100644 --- a/vllm/attention/ops/paged_attn.py +++ b/vllm/attention/ops/paged_attn.py @@ -50,6 +50,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> Tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) diff --git a/vllm/attention/selector.py b/vllm/attention/selector.py index 3a235ba6e0b4..e53674494a12 100644 --- a/vllm/attention/selector.py +++ b/vllm/attention/selector.py @@ -145,6 +145,7 @@ def get_attn_backend( is_attention_free: bool = False, use_mla: bool = False, has_sink: bool = False, + use_sparse: bool = False, ) -> type[AttentionBackend]: """Selects which attention backend to use and lazily imports it.""" # Accessing envs.* behind an @lru_cache decorator can cause the wrong @@ -160,6 +161,7 @@ def get_attn_backend( use_v1=envs.VLLM_USE_V1, use_mla=use_mla, has_sink=has_sink, + use_sparse=use_sparse, ) @@ -173,6 +175,7 @@ def _cached_get_attn_backend( use_v1: bool = False, use_mla: bool = False, has_sink: bool = False, + use_sparse: bool = False, ) -> type[AttentionBackend]: # If there are no attention layers (e.g. 
we are running Mamba), # use the placeholder NO_ATTENTION @@ -204,7 +207,7 @@ def _cached_get_attn_backend( # get device-specific attn_backend attention_cls = current_platform.get_attn_backend_cls( selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, - use_mla, has_sink) + use_mla, has_sink, use_sparse) if not attention_cls: raise ValueError( f"Invalid attention backend for {current_platform.device_name}") diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 4c4e39c37ee5..bf13a18e0e0c 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -22,7 +22,8 @@ logger = init_logger(__name__) BlockSize = Literal[1, 8, 16, 32, 64, 128] -CacheDType = Literal["auto", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] +CacheDType = Literal[ + "auto", "bfloat16", "fp8", "fp8_e4m3", "fp8_e5m2", "fp8_inc"] MambaDType = Literal["auto", "float32"] PrefixCachingHashAlgo = Literal["sha256", "sha256_cbor"] @@ -52,7 +53,11 @@ class CacheConfig: cache_dtype: CacheDType = "auto" """Data type for kv cache storage. If "auto", will use model data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. ROCm (AMD GPU) supports - fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc).""" + fp8 (=fp8_e4m3). Intel Gaudi (HPU) supports fp8 (using fp8_inc). + Some models (namely DeepSeekV3.2) default to fp8, set to bfloat16 to use + bfloat16 instead, this is an invalid option for models that do not default + to fp8. + """ is_attention_free: bool = False """Whether the model is attention-free. This is primarily set in `ModelConfig` and that value should be manually duplicated here.""" @@ -171,11 +176,12 @@ def _verify_cache_dtype(self) -> None: if self.cache_dtype == "auto": pass elif self.cache_dtype in get_args(CacheDType): - logger.info( - "Using fp8 data type to store kv cache. It reduces the GPU " - "memory footprint and boosts the performance. " - "Meanwhile, it may cause accuracy drop without a proper " - "scaling factor.") + if self.cache_dtype.startswith("fp8"): + logger.info( + "Using fp8 data type to store kv cache. It reduces the GPU " + "memory footprint and boosts the performance. 
" + "Meanwhile, it may cause accuracy drop without a proper " + "scaling factor.") else: raise ValueError(f"Unknown kv cache dtype: {self.cache_dtype}") diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 22b38daf46c3..4dad3b668285 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -362,6 +362,7 @@ class CompilationConfig: "vllm.linear_attention", "vllm.plamo2_mamba_mixer", "vllm.gdn_attention", + "vllm.sparse_attn_indexer", ] def compute_hash(self) -> str: diff --git a/vllm/config/model.py b/vllm/config/model.py index 921322bb475c..baf577b49ef5 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -1198,13 +1198,13 @@ def is_deepseek_mla(self) -> bool: if not hasattr(self.hf_text_config, "model_type"): return False elif self.hf_text_config.model_type in \ - ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2'): + ('deepseek_v2', 'deepseek_v3', 'deepseek_mtp', 'kimi_k2', 'deepseek_v32'): return self.hf_text_config.kv_lora_rank is not None elif self.hf_text_config.model_type == 'eagle': # if the model is an EAGLE module, check for the # underlying architecture return self.hf_text_config.model.model_type in \ - ('deepseek_v2', 'deepseek_v3') \ + ('deepseek_v2', 'deepseek_v3', 'deepseek_v32') \ and self.hf_text_config.kv_lora_rank is not None return False diff --git a/vllm/config/speculative.py b/vllm/config/speculative.py index 2c861723c396..c3dfbaba6426 100644 --- a/vllm/config/speculative.py +++ b/vllm/config/speculative.py @@ -142,7 +142,7 @@ def compute_hash(self) -> str: @staticmethod def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: - if hf_config.model_type == "deepseek_v3": + if hf_config.model_type in ("deepseek_v3", "deepseek_v32"): hf_config.model_type = "deepseek_mtp" if hf_config.model_type == "deepseek_mtp": n_predict = getattr(hf_config, "num_nextn_predict_layers", None) @@ -204,9 +204,8 @@ def __post_init__(self): # mtp acceleration for more models besides deepseek_v3 if self.target_model_config and \ (self.target_model_config.hf_text_config.model_type \ - == "deepseek_v3" or - self.target_model_config.hf_text_config.model_type in - ("mimo","ernie4_5_moe", "qwen3_next")): + in ("deepseek_v3", "deepseek_v32", + "mimo","ernie4_5_moe", "qwen3_next")): # use the draft model from the same model: self.model = self.target_model_config.model # Align the quantization of draft model for cases such as diff --git a/vllm/model_executor/layers/layernorm.py b/vllm/model_executor/layers/layernorm.py index f875f712ba9c..a44ca5c8939e 100644 --- a/vllm/model_executor/layers/layernorm.py +++ b/vllm/model_executor/layers/layernorm.py @@ -5,6 +5,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp @@ -379,3 +380,20 @@ def forward_cuda( x: torch.Tensor, ) -> torch.Tensor: return poly_norm(x, self.weight, self.bias, self.variance_epsilon) + + +class LayerNorm(nn.Module): + """ + Layer Normalization. 
+ """ + + def __init__(self, dim: int, eps: float = 1e-6): + super().__init__() + self.dim = dim + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim, dtype=torch.float32)) + self.bias = nn.Parameter(torch.zeros(dim, dtype=torch.float32)) + + def forward(self, x: torch.Tensor): + return F.layer_norm(x.float(), (self.dim, ), self.weight, self.bias, + self.eps).type_as(x) diff --git a/vllm/model_executor/layers/mla.py b/vllm/model_executor/layers/mla.py index a05716190365..d23232ab09b5 100644 --- a/vllm/model_executor/layers/mla.py +++ b/vllm/model_executor/layers/mla.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import os from dataclasses import dataclass from typing import Optional @@ -24,6 +25,9 @@ class MLAModules: q_a_layernorm: Optional[torch.nn.Module] q_b_proj: Optional[torch.nn.Module] q_proj: Optional[torch.nn.Module] + indexer: Optional[torch.nn.Module] + is_sparse: bool + topk_indices_buffer: Optional[torch.Tensor] @CustomOp.register("multi_head_latent_attention") @@ -76,6 +80,15 @@ def __init__( self.kv_b_proj = mla_modules.kv_b_proj self.rotary_emb = mla_modules.rotary_emb self.o_proj = mla_modules.o_proj + self.indexer = mla_modules.indexer + self.use_sparse = mla_modules.is_sparse and os.getenv( + "VLLM_MLA_SPARSE_DISABLED") != "1" + + if self.indexer is not None: + assert hasattr(self.indexer, "topk_tokens") + self.topk_tokens = self.indexer.topk_tokens \ + if self.indexer else None + self.topk_indices_buffer = mla_modules.topk_indices_buffer # In the MLA backend, kv_cache includes both k_c and # pe (i.e. decoupled position embeddings). In particular, @@ -92,6 +105,7 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.attn", use_mla=True, + use_sparse=mla_modules.is_sparse, # MLA Args q_lora_rank=self.q_lora_rank, kv_lora_rank=self.kv_lora_rank, @@ -100,6 +114,7 @@ def __init__( qk_head_dim=self.qk_head_dim, v_head_dim=self.v_head_dim, kv_b_proj=self.kv_b_proj, + indexer=self.indexer, ) self.prefix = prefix @@ -146,6 +161,10 @@ def forward_native( q[..., self.qk_nope_head_dim:], k_pe = self.rotary_emb( positions, q[..., self.qk_nope_head_dim:], k_pe) + if self.indexer and self.use_sparse: + _topk_indices = self.indexer(hidden_states, q_c, positions, + self.rotary_emb) + attn_out = self.mla_attn( q, kv_c_normed, diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index ce3d23763ed6..494403492c2a 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -414,6 +414,32 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "exactly equal.", mamba_padding_pct) +class DeepseekV3ForCausalLM(VerifyAndUpdateConfig): + + @classmethod + def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: + """ + Updated fp8 cache to custom "fp8_ds_mla" format for DeepSeekV32 + """ + hf_config = vllm_config.model_config.hf_config + + is_v32 = hasattr( + hf_config, "index_topk" + ) + + if is_v32: + # For DeepSeekV3.2, we use a custom fp8 format as default (i.e. 
+ # "auto") + cache_config = vllm_config.cache_config + if cache_config.cache_dtype == "auto" or \ + cache_config.cache_dtype.startswith("fp8"): + cache_config.cache_dtype = "fp8_ds_mla" + logger.info("Using custom fp8 kv-cache format for DeepSeekV3.2") + if cache_config.cache_dtype == "bfloat16": + cache_config.cache_dtype = "auto" + logger.info("Using bfloat16 kv-cache for DeepSeekV3.2") + + MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = { "GteModel": SnowflakeGteNewModelConfig, "GteNewModel": GteNewModelConfig, @@ -431,4 +457,5 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: "MambaForCausalLM": MambaModelConfig, "Mamba2ForCausalLM": MambaModelConfig, "FalconMambaForCausalLM": MambaModelConfig, + "DeepseekV3ForCausalLM": DeepseekV3ForCausalLM, } diff --git a/vllm/model_executor/models/deepseek_mtp.py b/vllm/model_executor/models/deepseek_mtp.py index 8fbf16d206a8..04f15ec53f77 100644 --- a/vllm/model_executor/models/deepseek_mtp.py +++ b/vllm/model_executor/models/deepseek_mtp.py @@ -54,8 +54,20 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: self.eh_proj = nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False) + + self.is_v32 = hasattr( + config, "index_topk" + ) + if self.is_v32: + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda") + else: + topk_indices_buffer = None self.shared_head = SharedHead(config=config, quant_config=quant_config) - self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix) + self.mtp_block = DeepseekV2DecoderLayer(vllm_config, prefix, topk_indices_buffer) def forward( self, diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 636554bd648f..c92ab16c9a8c 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -31,18 +31,24 @@ import torch from torch import nn from transformers import DeepseekV2Config, DeepseekV3Config +import torch.distributed as dist +from vllm.attention.backends.abstract import AttentionBackend +from vllm.logger import init_logger +from vllm.config.compilation import CompilationConfig import vllm.envs as envs from vllm.attention import Attention from vllm.compilation.decorators import support_torch_compile -from vllm.config import CacheConfig, ParallelConfig, VllmConfig +from vllm.config import CacheConfig, ParallelConfig, VllmConfig, get_current_vllm_config from vllm.distributed import (get_ep_group, get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather) +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase from vllm.model_executor.layers.fused_moe import FusedMoE -from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, ReplicatedLinear, @@ -60,11 +66,28 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.utils import cdiv, direct_register_custom_op +from vllm.v1.attention.backends.mla.indexer import DeepseekV32IndexerBackend, DeepseekV32IndexerMetadata +from vllm.attention.ops.common import pack_seq_triton, unpack_seq_triton from 
.interfaces import MixtureOfExperts, SupportsLoRA, SupportsPP -from .utils import (PPMissingLayer, is_pp_missing_parameter, +from .utils import (PPMissingLayer, extract_layer_index, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) +from vllm.v1.kv_cache_interface import MLAAttentionSpec, KVCacheSpec +from vllm.utils.deep_gemm import ( + fp8_mqa_logits, + get_paged_mqa_logits_metadata, + fp8_paged_mqa_logits, +) +from vllm.model_executor.layers.quantization.utils.fp8_utils import per_token_group_quant_fp8 + +if current_platform.is_cuda_alike(): + from vllm import _custom_ops as ops +elif current_platform.is_xpu(): + from vllm._ipex_ops import ipex_ops as ops + +logger = init_logger(__name__) class DeepseekV2MLP(nn.Module): @@ -332,6 +355,7 @@ class DeepseekV2Attention(nn.Module): def __init__( self, + vllm_config: VllmConfig, config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, @@ -474,6 +498,363 @@ def forward( return output +class DeepseekV32IndexerCache(torch.nn.Module, AttentionLayerBase): + + def __init__(self, head_dim: int, dtype: torch.dtype, prefix: str, + cache_config: CacheConfig): + super().__init__() + self.kv_cache = [torch.tensor([])] + self.head_dim = head_dim + self.prefix = prefix + self.cache_config = cache_config + self.dtype = dtype + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + + def get_kv_cache_spec(self) -> KVCacheSpec: + return MLAAttentionSpec( # Only has one vector instead of K + V + block_size=self.cache_config.block_size, + num_kv_heads=1, + head_size=self.head_dim, + dtype=self.dtype, + ) + + def forward(self): + attn_metadata = get_forward_context().attn_metadata + if isinstance(attn_metadata, dict): + attn_metadata = attn_metadata[self.prefix] + logger.info(f"attn_metadata {attn_metadata}") + + def get_attn_backend(self) -> AttentionBackend: + return DeepseekV32IndexerBackend + +@torch.inference_mode() +def cp_gather_indexer_k_quant_cache( + kv_cache, # [num_blocks, block_size, head_dim + 1] + dst_value, # [cu_seq_lens[-1], head_dim] + dst_scale, # [cu_seq_lens[-1], 4] + block_table, # [batch_size, num_blocks] + cu_seq_lens, # [batch_size + 1, ] + batch_size, +): + num_blocks, block_size, _ = kv_cache.shape + head_dim = dst_value.shape[-1] + kv_cache = kv_cache.view(num_blocks, -1) + + expected_value = [] + expected_scale = [] + for b in range(batch_size): + s = cu_seq_lens[b + 1] - cu_seq_lens[b] + if s == 0: + continue + tot = cdiv(s, block_size) + blocks = block_table[b, :tot] + + value = [] + scale = [] + full_block = torch.arange(tot - 1, device=kv_cache.device, dtype=torch.int32) + # print(f"full_blocks: {blocks[full_block]}") + non_remaining_value = kv_cache[blocks[full_block], : block_size * head_dim].view(-1, head_dim) + non_remaining_scale = kv_cache[blocks[full_block], block_size * head_dim:].view(-1, 4) + + # for i in range(tot - 1): + # value.append(kv_cache[blocks[i], :block_size * head_dim]) + # scale.append(kv_cache[blocks[i], block_size * head_dim:]) + + remaining = s - (tot - 1) * block_size + # value.append(kv_cache[blocks[-1], :remaining * head_dim]) + # scale.append(kv_cache[blocks[-1], block_size * head_dim: block_size * head_dim + remaining * 4]) + + value = torch.cat([non_remaining_value, kv_cache[blocks[-1], :remaining * head_dim].view(-1, head_dim)], dim=0) + scale 
= torch.cat([non_remaining_scale, kv_cache[blocks[-1], block_size * head_dim: block_size * head_dim + remaining * 4].view(-1, 4)], dim=0) + + expected_value.append(value) + expected_scale.append(scale) + + gather_value = torch.cat(expected_value, dim=0).view(-1, head_dim) + gather_scale = torch.cat(expected_scale, dim=0).view(-1, 4) + gather_value = gather_value.view(torch.float8_e4m3fn) + gather_scale = gather_scale.view(torch.float32) + dst_value.copy_(gather_value) + dst_scale.copy_(gather_scale) + + +def sparse_attn_indexer( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: Optional[torch.Tensor], +) -> torch.Tensor: + + # careful! this will be None in dummy run + attn_metadata = get_forward_context().attn_metadata + # assert isinstance(attn_metadata, dict) + if not isinstance(attn_metadata, dict): + return sparse_attn_indexer_fake( + hidden_states, + k_cache_prefix, + kv_cache, + q_fp8, + k, + weights, + quant_block_size, + scale_fmt, + topk_tokens, + head_dim, + max_model_len, + total_seq_lens, + topk_indices_buffer, + ) + attn_metadata = attn_metadata[k_cache_prefix] + assert isinstance(attn_metadata, DeepseekV32IndexerMetadata) + slot_mapping = attn_metadata.slot_mapping + has_decode = attn_metadata.num_decodes > 0 + has_prefill = attn_metadata.num_prefills > 0 + num_decode_tokens = attn_metadata.num_decode_tokens + + ops.indexer_k_quant_and_cache( + k, + kv_cache, + slot_mapping, + quant_block_size, + scale_fmt, + ) + + topk_indices_buffer[:hidden_states.shape[0]] = -1 + if has_prefill: + prefill_metadata = attn_metadata.prefill + num_prefills = attn_metadata.num_prefills + k_fp8 = torch.empty( + [prefill_metadata.total_seq_lens, head_dim], + device=k.device, + dtype=torch.float8_e4m3fn) + k_scale = torch.empty( + [prefill_metadata.total_seq_lens, 1], + device=k.device, + dtype=torch.float32) + cp_gather_indexer_k_quant_cache( + kv_cache, + k_fp8, + k_scale, + prefill_metadata.block_table, + prefill_metadata.cu_seq_lens, + num_prefills, + ) + cu_seqlen_ks = prefill_metadata.cu_seqlen_ks + cu_seqlen_ke = prefill_metadata.cu_seqlen_ke + num_tokens = attn_metadata.num_actual_tokens + logits = fp8_mqa_logits( + q_fp8[num_decode_tokens:num_tokens], + (k_fp8, k_scale), + weights[num_decode_tokens:num_tokens], + cu_seqlen_ks, + cu_seqlen_ke, + ) + topk_indices = logits.topk(min(topk_tokens, logits.shape[-1]), + dim=-1)[1] + topk_indices -= cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0 + mask = torch.full_like(topk_indices, + False, + dtype=torch.bool, + device=topk_indices.device) + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + topk_indices_buffer[num_decode_tokens:num_tokens, :topk_indices. 
+ shape[-1]] = topk_indices.to(dtype=torch.int32) + + if has_decode: + decode_metadata = attn_metadata.decode + # kv_cache size requirement [num_block, block_size, n_head, head_dim], + # we only have [num_block, block_size, head_dim], + kv_cache = kv_cache.unsqueeze(-2) + decode_lens = decode_metadata.decode_lens + if decode_metadata.requires_padding: + # pad in edge case where we have short chunked prefill length < + # decode_threshold since we unstrictly split + # prefill and decode by decode_threshold (currently set to 1 + speculative tokens) + padded_q_fp8_decode_tokens = pack_seq_triton(q_fp8[:num_decode_tokens], decode_lens) + else: + padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(decode_lens.shape[0], -1, *q_fp8.shape[1:]) + # TODO: move and optimize below logic with triton kernels + batch_size = padded_q_fp8_decode_tokens.shape[0] + next_n = padded_q_fp8_decode_tokens.shape[1] + assert batch_size == decode_metadata.seq_lens.shape[0] + num_padded_tokens = batch_size * next_n + logits = fp8_paged_mqa_logits( + padded_q_fp8_decode_tokens, + kv_cache, + weights[:num_padded_tokens], + decode_metadata.seq_lens, + decode_metadata.block_table, + decode_metadata.schedule_metadata, + max_model_len=max_model_len, + ) + # padded query len + current_device = padded_q_fp8_decode_tokens.device + padded_num_tokens = batch_size * next_n + positions = torch.arange(max_model_len, device=current_device).unsqueeze(0).expand( + batch_size * next_n, -1) + row_indices = torch.arange(padded_num_tokens, device=current_device) // next_n + next_n_offset = torch.arange(padded_num_tokens, device=padded_q_fp8_decode_tokens.device) % next_n + index_end_pos = (decode_metadata.seq_lens[row_indices] - next_n + next_n_offset).unsqueeze(1) + # index_end_pos: [B * N, 1] + mask = positions <= index_end_pos + # mask: [B * N, L] + logits = logits.masked_fill(~mask, float('-inf')) + topk_indices = logits.topk(topk_tokens, dim=-1)[1].to( + torch.int32) # [B * N, K] + # ensure we don't set indices for the top k that out of range(masked already) + # this will happen if context length is shorter than K + topk_indices[topk_indices > index_end_pos] = -1 + if decode_metadata.requires_padding: + # if padded, we need to unpack the topk indices removing padded tokens + topk_indices = unpack_seq_triton(topk_indices.reshape(batch_size, -1, topk_indices.shape[-1]), decode_lens) + topk_indices_buffer[:num_decode_tokens, :topk_indices. + shape[-1]] = topk_indices.to( + dtype=torch.int32) + + return topk_indices_buffer + + +def sparse_attn_indexer_fake( + hidden_states: torch.Tensor, + k_cache_prefix: str, + kv_cache: torch.Tensor, + q_fp8: torch.Tensor, + k: torch.Tensor, + weights: torch.Tensor, + quant_block_size: int, + scale_fmt: Optional[str], + topk_tokens: int, + head_dim: int, + max_model_len: int, + total_seq_lens: int, + topk_indices_buffer: Optional[torch.Tensor], +) -> torch.Tensor: + # profile run + # NOTE(Chen): create the max possible flattened_kv. So that + # profile_run can get correct memory usage. 
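+    # head_dim + 4 bytes per token: head_dim fp8 values plus one float32
+    # scale, mirroring the k_fp8 / k_scale buffers gathered in the real
+    # prefill path above.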
+ _flattened_kv = torch.empty([total_seq_lens, head_dim + 4], + device=k.device, + dtype=torch.uint8) + _k_fp8 = _flattened_kv[..., :head_dim].view( + torch.float8_e4m3fn).contiguous() + _k_scale = _flattened_kv[..., head_dim:].view(torch.float32).contiguous() + return topk_indices_buffer + + +direct_register_custom_op( + op_name="sparse_attn_indexer", + op_func=sparse_attn_indexer, + mutates_args=["topk_indices_buffer"], + fake_impl=sparse_attn_indexer_fake, + dispatch_key=current_platform.dispatch_key, +) + + +class Indexer(nn.Module): + + def __init__(self, + vllm_config: VllmConfig, + config: Union[DeepseekV2Config, DeepseekV3Config], + hidden_size: int, + q_lora_rank: int, + quant_config: Optional[QuantizationConfig], + cache_config: Optional[CacheConfig], + topk_indices_buffer: Optional[torch.Tensor], + prefix: str = ""): + super().__init__() + self.vllm_config = vllm_config + self.config = config + # self.indexer_cfg = config.attn_module_list_cfg[0]["attn_index"] + self.topk_tokens = config.index_topk + self.n_head = config.index_n_heads # 64 + self.head_dim = config.index_head_dim # 128 + self.rope_dim = config.qk_rope_head_dim # 64 + self.q_lora_rank = q_lora_rank # 1536 + # no tensor parallel, just replicated + self.wq_b = ReplicatedLinear(self.q_lora_rank, + self.head_dim * self.n_head, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wq_b") + self.wk = ReplicatedLinear(hidden_size, + self.head_dim, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wk") + self.k_norm = LayerNorm(self.head_dim, eps=1e-6) + self.weights_proj = ReplicatedLinear(hidden_size, + self.n_head, + quant_config=None, + prefix=f"{prefix}.weights_proj") + self.softmax_scale = self.head_dim**-0.5 + + self.scale_fmt = "ue8m0" + self.quant_block_size = 128 # TODO: get from config + self.topk_indices_buffer = topk_indices_buffer + + #TODO (zyongye) change dim to fp8 later to (self.head_dim + 4) + self.k_cache = DeepseekV32IndexerCache(head_dim=self.head_dim + 4, + dtype=torch.uint8, + prefix=f"{prefix}.k_cache", + cache_config=cache_config) + self.max_model_len = vllm_config.model_config.max_model_len + self.prefix = prefix + from vllm.v1.attention.backends.mla.indexer import get_max_prefill_buffer_size + self.max_total_seq_len = get_max_prefill_buffer_size(vllm_config) + + def forward(self, hidden_states: torch.Tensor, qr: torch.Tensor, positions, + rotary_emb) -> torch.Tensor: + q, _ = self.wq_b(qr) + q = q.view(-1, self.n_head, self.head_dim) + q_pe, q_nope = torch.split( + q, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + + k, _ = self.wk(hidden_states) + k = self.k_norm(k) + k_pe, k_nope = torch.split( + k, [self.rope_dim, self.head_dim - self.rope_dim], dim=-1) + + #FIXME (zyongye) this will cause OOM when using full sequence forward on 8xH200 + q_pe, k_pe = rotary_emb(positions, q_pe, k_pe.unsqueeze(1)) + q = torch.cat([q_pe, q_nope], dim=-1) + k = torch.cat([k_pe.squeeze(1), k_nope], dim=-1) + + # we only quant q here since k quant is fused with cache insertion + q = q.view(-1, self.head_dim) + q_fp8, q_scale = per_token_group_quant_fp8(q, + self.quant_block_size, + column_major_scales=False, + use_ue8m0=self.scale_fmt is not None) + q_fp8 = q_fp8.view(-1, self.n_head, self.head_dim) + q_scale = q_scale.view(-1, self.n_head, 1) + + weights, _ = self.weights_proj(hidden_states) + weights = weights.unsqueeze( + -1) * q_scale * self.softmax_scale * self.n_head**-0.5 + weights = weights.squeeze(-1) + + return torch.ops.vllm.sparse_attn_indexer( + hidden_states, 
self.k_cache.prefix, self.k_cache.kv_cache[0], + q_fp8, k, weights, self.quant_block_size, self.scale_fmt, + self.topk_tokens, self.head_dim, self.max_model_len, + self.max_total_seq_len, self.topk_indices_buffer) + + class DeepseekV2MLAAttention(nn.Module): """ Main reference: DeepseekV2 paper, and FlashInfer Implementation @@ -484,6 +865,7 @@ class DeepseekV2MLAAttention(nn.Module): def __init__( self, + vllm_config: VllmConfig, config: Union[DeepseekV2Config, DeepseekV3Config], hidden_size: int, num_heads: int, @@ -498,6 +880,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + topk_indices_buffer: Optional[torch.Tensor] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -578,6 +961,17 @@ def __init__( mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) self.scaling = self.scaling * mscale * mscale + self.is_v32 = hasattr( + config, "index_topk" + ) + + if self.is_v32: + self.indexer = Indexer(vllm_config, config, hidden_size, + q_lora_rank, quant_config, cache_config, + topk_indices_buffer, f"{prefix}.indexer") + else: + self.indexer = None + mla_modules = MLAModules( kv_a_layernorm=self.kv_a_layernorm, kv_b_proj=self.kv_b_proj, @@ -591,7 +985,11 @@ def __init__( if self.q_lora_rank is not None else None, q_b_proj=self.q_b_proj if self.q_lora_rank is not None else None, q_proj=self.q_proj if self.q_lora_rank is None else None, + indexer=self.indexer, + is_sparse=self.is_v32, + topk_indices_buffer=topk_indices_buffer, ) + self.mla_attn = MultiHeadLatentAttention( self.hidden_size, self.num_local_heads, @@ -612,12 +1010,14 @@ def forward( positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: + # self.indexer(torch.tensor([]), torch.tensor([])) return self.mla_attn(positions, hidden_states) class DeepseekV2DecoderLayer(nn.Module): - def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: + def __init__(self, vllm_config: VllmConfig, prefix: str, + topk_indices_buffer: Optional[torch.Tensor]) -> None: super().__init__() config = vllm_config.model_config.hf_config @@ -640,6 +1040,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: else: attn_cls = DeepseekV2Attention self.self_attn = attn_cls( + vllm_config=vllm_config, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -655,6 +1056,7 @@ def __init__(self, vllm_config: VllmConfig, prefix: str) -> None: cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", + topk_indices_buffer=topk_indices_buffer, ) if (config.n_routed_experts is not None @@ -738,6 +1140,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.vocab_size = config.vocab_size + self.is_v32 = hasattr( + config, "index_topk" + ) + if self.is_v32: + topk_tokens = config.index_topk + topk_indices_buffer = torch.empty(vllm_config.scheduler_config.max_num_batched_tokens, + topk_tokens, + dtype=torch.int32, + device="cuda") + else: + topk_indices_buffer = None if get_pp_group().is_first_rank: self.embed_tokens = VocabParallelEmbedding( @@ -750,7 +1163,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix), + lambda prefix: DeepseekV2DecoderLayer(vllm_config, prefix, + topk_indices_buffer), prefix=f"{prefix}.layers") if get_pp_group().is_last_rank: diff --git 
a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 5dc5d545bb9c..371222af9e62 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -69,6 +69,7 @@ "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), "DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), + "DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"), "Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"), "Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"), "Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"), diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index 544e091491bf..78d2b0f37328 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -93,11 +93,14 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: + has_sink: bool, use_sparse: bool) -> str: if selected_backend and selected_backend != _Backend.TORCH_SDPA: logger.info("Cannot use %s backend on CPU.", selected_backend) if use_mla: raise NotImplementedError("MLA is not supported on CPU.") + if use_sparse: + raise NotImplementedError( + "Sparse Attention is not supported on CPU.") logger.info("Using Torch SDPA backend.") if not use_v1: raise ValueError("CPU backend only supports V1.") diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 87d8f2b7481b..1575b78d1a06 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -139,6 +139,7 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: # TODO(lucas): handle this more gracefully # Note: model_config may be None during testing if model_config is not None and model_config.use_mla: + use_sparse = os.getenv("VLLM_MLA_SPARSE_DISABLED") != "1" # If `VLLM_ATTENTION_BACKEND` is not set and we are using MLA, # then we default to FlashMLA backend for non-blackwell GPUs, # else we default to CutlassMLA. For each case, we force the @@ -185,6 +186,12 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: "Forcing kv cache block size to 64 for FlashInferMLA " "backend.") + # TODO(Chen): remove this hacky code + if use_sparse and cache_config.block_size != 64: + cache_config.block_size = 64 + logger.info( + "Forcing kv cache block size to 64 for FlashMLASparse " + "backend.") # lazy import to avoid circular import from vllm.config import CUDAGraphMode @@ -234,14 +241,21 @@ def get_vit_attn_backend(cls, head_size: int, @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, - has_sink) -> str: + has_sink, use_sparse) -> str: if use_mla: + use_sparse = os.getenv( + "VLLM_MLA_SPARSE_DISABLED") != "1" and use_sparse # TODO(lucas): refactor to be more concise # we should probably consider factoring out V1 here from vllm.attention.ops.flashmla import is_flashmla_supported from vllm.attention.utils.fa_utils import flash_attn_supports_mla + if use_sparse: + logger.info_once("Using Sparse MLA backend on V1 engine.") + return ("vllm.v1.attention.backends.mla.flashmla_sparse." 
+ "FlashMLASparseBackend") + use_cutlassmla = selected_backend == _Backend.CUTLASS_MLA or ( selected_backend is None and cls.is_device_capability(100) and block_size == 128) diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 53fc762dce54..0abc536d8a6f 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -200,7 +200,7 @@ def get_vit_attn_backend(cls, head_size: int, def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: + has_sink: bool, use_sparse: bool) -> str: """Get the attention backend class of a device.""" return "" diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 4f540fe965e2..ba97df02e8d2 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -189,7 +189,10 @@ def get_vit_attn_backend(cls, head_size: int, @classmethod def get_attn_backend_cls(cls, selected_backend, head_size, dtype, kv_cache_dtype, block_size, use_v1, use_mla, - has_sink) -> str: + has_sink, use_sparse) -> str: + if use_sparse: + raise NotImplementedError( + "Sparse Attention is not supported on ROCm.") if use_mla: from vllm.v1.attention.backends.mla.rocm_aiter_mla import ( is_aiter_mla_enabled) diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 4e4db116abca..d846eebac136 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -49,7 +49,10 @@ class TpuPlatform(Platform): def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, - has_sink) -> str: + has_sink, use_sparse) -> str: + if use_sparse: + raise NotImplementedError( + "Sparse Attention is not supported on TPU.") if (selected_backend != _Backend.PALLAS and selected_backend != _Backend.PALLAS_VLLM_V1): logger.info("Cannot use %s backend on TPU.", selected_backend) diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index 67ef058df10f..574576f3e9ed 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -36,7 +36,10 @@ class XPUPlatform(Platform): def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int, dtype: torch.dtype, kv_cache_dtype: Optional[str], block_size: int, use_v1: bool, use_mla: bool, - has_sink: bool) -> str: + has_sink: bool, use_sparse) -> str: + if use_sparse: + raise NotImplementedError( + "Sparse Attention is not supported on XPU.") use_v1 = envs.VLLM_USE_V1 if not use_v1: raise ValueError("XPU backend only supports V1.") diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index cafc43f6b767..6bbfaad5fa1a 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -64,6 +64,7 @@ def __getitem__(self, key): _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict( chatglm="ChatGLMConfig", deepseek_vl_v2="DeepseekVLV2Config", + deepseek_v32="DeepseekV3Config", kimi_vl="KimiVLConfig", Llama_Nemotron_Nano_VL="Nemotron_Nano_VL_Config", RefinedWeb="RWConfig", # For tiiuae/falcon-40b(-instruct) diff --git a/vllm/transformers_utils/configs/__init__.py b/vllm/transformers_utils/configs/__init__.py index 91bfeb8c55ee..efb249da2e87 100644 --- a/vllm/transformers_utils/configs/__init__.py +++ b/vllm/transformers_utils/configs/__init__.py @@ -32,10 +32,12 @@ Step3VisionEncoderConfig, Step3VLConfig) from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from 
vllm.transformers_utils.configs.deepseek_v3 import DeepseekV3Config __all__ = [ "ChatGLMConfig", "DeepseekVLV2Config", + "DeepseekV3Config", "EAGLEConfig", "RWConfig", "JAISConfig", diff --git a/vllm/transformers_utils/configs/deepseek_v3.py b/vllm/transformers_utils/configs/deepseek_v3.py new file mode 100644 index 000000000000..235b7b0fd33c --- /dev/null +++ b/vllm/transformers_utils/configs/deepseek_v3.py @@ -0,0 +1,193 @@ +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + +DEEPSEEK_PRETRAINED_CONFIG_ARCHIVE_MAP = {} +class DeepseekV3Config(PretrainedConfig): + r""" + This is the configuration class to store the configuration of a [`DeepseekV3Model`]. It is used to instantiate an DeepSeek + model according to the specified arguments, defining the model architecture. Instantiating a configuration with the + defaults will yield a similar configuration to that of the DeepSeek-V3. + Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the + documentation from [`PretrainedConfig`] for more information. + Args: + vocab_size (`int`, *optional*, defaults to 129280): + Vocabulary size of the Deep model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`DeepseekV3Model`] + hidden_size (`int`, *optional*, defaults to 4096): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 11008): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1407): + Dimension of the MoE representations. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + num_nextn_predict_layers (`int`, *optional*, defaults to 1): + Number of nextn predict layers in the DeepSeekV3 Model. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + n_shared_experts (`int`, *optional*, defaults to None): + Number of shared experts, None means dense model. + n_routed_experts (`int`, *optional*, defaults to None): + Number of routed experts, None means dense model. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor or routed experts. + topk_method (`str`, *optional*, defaults to `gready`): + Topk method used in routed gate. + n_group (`int`, *optional*, defaults to None): + Number of groups for routed experts. + topk_group (`int`, *optional*, defaults to None): + Number of selected groups for each token(for each token, ensuring the selected experts is only within `topk_group` groups). + num_experts_per_tok (`int`, *optional*, defaults to None): + Number of selected experts, None means dense model. + moe_layer_freq (`int`, *optional*, defaults to 1): + The frequency of the MoE layer: one expert layer for every `moe_layer_freq - 1` dense layers. + first_k_dense_replace (`int`, *optional*, defaults to 0): + Number of dense layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + \--k dense layers--/ + norm_topk_prob (`bool`, *optional*, defaults to False): + Whether to normalize the weights of the routed experts. + scoring_func (`str`, *optional*, defaults to 'softmax'): + Method of computing expert weights. + aux_loss_alpha (`float`, *optional*, defaults to 0.001): + Auxiliary loss weight coefficient. 
+ seq_aux = (`bool`, *optional*, defaults to True): + Whether to compute the auxiliary loss for each individual sample. + num_key_value_heads (`int`, *optional*): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details checkout [this + paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to + `num_attention_heads`. + hidden_act (`str` or `function`, *optional*, defaults to `"silu"`): + The non-linear activation function (function or string) in the decoder. + max_position_embeddings (`int`, *optional*, defaults to 2048): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + rms_norm_eps (`float`, *optional*, defaults to 1e-06): + The epsilon used by the rms normalization layers. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + pad_token_id (`int`, *optional*): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `False`): + Whether to tie weight embeddings + rope_theta (`float`, *optional*, defaults to 10000.0): + The base period of the RoPE embeddings. + rope_scaling (`Dict`, *optional*): + Dictionary containing the scaling configuration for the RoPE embeddings. Currently supports two scaling + strategies: linear and dynamic. Their scaling factor must be a float greater than 1. The expected format is + `{"type": strategy name, "factor": scaling factor}`. When using this flag, don't update + `max_position_embeddings` to the expected new maximum. + attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`): + Whether to use a bias in the query, key, value and output projection layers during self-attention. + attention_dropout (`float`, *optional*, defaults to 0.0): + The dropout ratio for the attention probabilities. 
+ ```python + >>> from transformers import DeepseekV3Model, DeepseekV3Config + >>> # Initializing a Deepseek-V3 style configuration + >>> configuration = DeepseekV3Config() + >>> # Accessing the model configuration + >>> configuration = model.config + ```""" + + model_type = "deepseek_v3" + keys_to_ignore_at_inference = ["past_key_values"] + + def __init__( + self, + vocab_size=129280, + hidden_size=7168, + intermediate_size=18432, + moe_intermediate_size = 2048, + num_hidden_layers=61, + num_nextn_predict_layers=1, + num_attention_heads=128, + num_key_value_heads=128, + n_shared_experts = 1, + n_routed_experts = 256, + ep_size = 1, + routed_scaling_factor = 2.5, + kv_lora_rank = 512, + q_lora_rank = 1536, + qk_rope_head_dim = 64, + v_head_dim = 128, + qk_nope_head_dim = 128, + topk_method = 'noaux_tc', + n_group = 8, + topk_group = 4, + num_experts_per_tok = 8, + moe_layer_freq = 1, + first_k_dense_replace = 3, + norm_topk_prob = True, + scoring_func = 'sigmoid', + hidden_act="silu", + max_position_embeddings=4096, + initializer_range=0.02, + rms_norm_eps=1e-6, + use_cache=True, + pad_token_id=None, + bos_token_id=0, + eos_token_id=1, + tie_word_embeddings=False, + rope_theta=10000.0, + rope_scaling=None, + attention_bias=False, + attention_dropout=0.0, + **kwargs, + ): + self.vocab_size = vocab_size + self.max_position_embeddings = max_position_embeddings + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.moe_intermediate_size = moe_intermediate_size + self.num_hidden_layers = num_hidden_layers + self.num_nextn_predict_layers = num_nextn_predict_layers + self.num_attention_heads = num_attention_heads + self.n_shared_experts = n_shared_experts + self.n_routed_experts = n_routed_experts + self.ep_size = ep_size + self.routed_scaling_factor = routed_scaling_factor + self.kv_lora_rank = kv_lora_rank + self.q_lora_rank = q_lora_rank + self.qk_rope_head_dim = qk_rope_head_dim + self.v_head_dim = v_head_dim + self.qk_nope_head_dim = qk_nope_head_dim + self.topk_method = topk_method + self.n_group = n_group + self.topk_group = topk_group + self.num_experts_per_tok = num_experts_per_tok + self.moe_layer_freq = moe_layer_freq + self.first_k_dense_replace = first_k_dense_replace + self.norm_topk_prob = norm_topk_prob + self.scoring_func = scoring_func + # for backward compatibility + if num_key_value_heads is None: + num_key_value_heads = num_attention_heads + + self.num_key_value_heads = num_key_value_heads + self.hidden_act = hidden_act + self.initializer_range = initializer_range + self.rms_norm_eps = rms_norm_eps + self.use_cache = use_cache + self.rope_theta = rope_theta + self.rope_scaling = rope_scaling + self.attention_bias = attention_bias + self.attention_dropout = attention_dropout + + super().__init__( + pad_token_id=pad_token_id, + bos_token_id=bos_token_id, + eos_token_id=eos_token_id, + tie_word_embeddings=tie_word_embeddings, + **kwargs, + ) \ No newline at end of file diff --git a/vllm/utils/__init__.py b/vllm/utils/__init__.py index 968bba664f0a..a38d0d58e4ec 100644 --- a/vllm/utils/__init__.py +++ b/vllm/utils/__init__.py @@ -189,6 +189,7 @@ "fp8_e5m2": torch.uint8, "int8": torch.int8, "fp8_inc": torch.float8_e4m3fn, + "fp8_ds_mla": torch.uint8, } TORCH_DTYPE_TO_NUMPY_DTYPE = { diff --git a/vllm/utils/deep_gemm.py b/vllm/utils/deep_gemm.py index 38d92f01192b..9b95de373fe9 100644 --- a/vllm/utils/deep_gemm.py +++ b/vllm/utils/deep_gemm.py @@ -70,15 +70,26 @@ def _missing(*_: Any, **__: Any) -> NoReturn: _fp8_gemm_nt_impl: Callable[..., Any] | None 
= None _grouped_impl: Callable[..., Any] | None = None _grouped_masked_impl: Callable[..., Any] | None = None +_fp8_mqa_logits_impl: Callable[..., Any] | None = None +_fp8_paged_mqa_logits_impl: Callable[..., Any] | None = None +_get_paged_mqa_logits_metadata_impl: Callable[..., Any] | None = None def _lazy_init() -> None: """Import deep_gemm and resolve symbols on first use.""" global _fp8_gemm_nt_impl, _grouped_impl, _grouped_masked_impl + global _fp8_mqa_logits_impl, _fp8_paged_mqa_logits_impl + global _get_paged_mqa_logits_metadata_impl # fast path - if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None - or _grouped_masked_impl is not None): + if ( + _fp8_gemm_nt_impl is not None + or _grouped_impl is not None + or _grouped_masked_impl is not None + or _fp8_mqa_logits_impl is not None + or _fp8_paged_mqa_logits_impl is not None + or _get_paged_mqa_logits_metadata_impl is not None + ): return if not has_deep_gemm(): @@ -95,6 +106,17 @@ def _lazy_init() -> None: _fp8_gemm_nt_impl = getattr(_dg, "fp8_gemm_nt", None) _grouped_impl = getattr(_dg, "m_grouped_fp8_gemm_nt_contiguous", None) _grouped_masked_impl = getattr(_dg, "fp8_m_grouped_gemm_nt_masked", None) + _fp8_mqa_logits_impl = getattr(_dg, "fp8_mqa_logits", None) + _fp8_paged_mqa_logits_impl = getattr(_dg, "fp8_paged_mqa_logits", None) + _get_paged_mqa_logits_metadata_impl = getattr( + _dg, "get_paged_mqa_logits_metadata", None + ) + + +def get_num_sms() -> int: + _lazy_init() + _dg = importlib.import_module("deep_gemm") + return int(_dg.get_num_sms()) def fp8_gemm_nt(*args, **kwargs): @@ -123,6 +145,106 @@ def fp8_m_grouped_gemm_nt_masked(*args, **kwargs): *args, disable_ue8m0_cast=not is_deep_gemm_e8m0_used(), **kwargs) +def fp8_mqa_logits( + q: torch.Tensor, + kv: tuple[torch.Tensor, torch.Tensor], + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Compute FP8 MQA logits for a single sequence without KV paging. + + Args: + q: Query tensor of shape [M, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv: Tuple `(k_fp8, k_scales)` where `k_fp8` has shape [N, D] with + dtype `torch.float8_e4m3fn` and `k_scales` has shape [N] (or + [N, 1]) with dtype `torch.float32`. + weights: weights of shape [M, H], dtype `torch.float32`. + cu_seqlen_ks: Start indices (inclusive) for valid K per query position, + shape [M], dtype int32. + cu_seqlen_ke: End indices (exclusive) for valid K per query position, + shape [M], dtype int32. + + Returns: + Logits tensor of shape [M, N], dtype `torch.float32`. + """ + _lazy_init() + if _fp8_mqa_logits_impl is None: + return _missing() + return _fp8_mqa_logits_impl(q, kv, weights, cu_seqlen_ks, cu_seqlen_ke) + + + +def get_paged_mqa_logits_metadata( + context_lens: torch.Tensor, block_size: int, num_sms: int +) -> torch.Tensor: + """Build scheduling metadata for paged MQA logits. + + Args: + context_lens: Tensor of shape [B], dtype int32; effective context length + per batch element. + block_size: KV-cache block size in tokens (e.g., 64). + num_sms: Number of SMs available. 132 for Hopper + + Returns: + Backend-specific tensor consumed by `fp8_paged_mqa_logits` to + schedule work across SMs. 
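For context, a hypothetical usage sketch of the DeepGEMM wrappers added here (`get_num_sms`, `get_paged_mqa_logits_metadata`, `fp8_paged_mqa_logits`). It assumes a CUDA device and a DeepGEMM build that actually exposes these symbols; the shapes are illustrative only and the snippet is a call-order sketch, not a tested example.

```python
# Hypothetical wiring of the wrappers above; assumes deep_gemm is installed.
import torch

from vllm.utils.deep_gemm import (fp8_paged_mqa_logits, get_num_sms,
                                  get_paged_mqa_logits_metadata)

context_lens = torch.tensor([1024, 777], dtype=torch.int32, device="cuda")
schedule = get_paged_mqa_logits_metadata(context_lens, block_size=64,
                                         num_sms=get_num_sms())
# `schedule` is then forwarded as `schedule_metadata` to fp8_paged_mqa_logits()
# together with the packed fp8 KV cache, block tables and per-token weights.
```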
+ """ + _lazy_init() + if _get_paged_mqa_logits_metadata_impl is None: + return _missing() + return _get_paged_mqa_logits_metadata_impl( + context_lens, block_size, num_sms + ) + + +def fp8_paged_mqa_logits( + q_fp8: torch.Tensor, + kv_cache_fp8: torch.Tensor, + weights: torch.Tensor, + context_lens: torch.Tensor, + block_tables: torch.Tensor, + schedule_metadata: torch.Tensor, + max_model_len: int, +) -> torch.Tensor: + """Compute FP8 MQA logits using paged KV-cache. + + Args: + q_fp8: Query tensor of shape [B, next_n, H, D]. Casted to + `torch.float8_e4m3fn` by caller. + kv_cache_fp8: Paged KV-cache in packed FP8+scale layout with shape + [num_blocks, block_size, 1, D+4], dtype `torch.uint8`. The last + 4 bytes per (block,pos) store the `float` dequant scale. + weights: Tensor of shape [B * next_n, H], dtype `torch.float32`. + context_lens: Tensor of shape [B], dtype int32; effective context length + for each batch element. + block_tables: Tensor of shape [B, max_blocks], dtype int32; maps logical + block indices to physical blocks in the paged cache. + schedule_metadata: Returned by `get_paged_mqa_logits_metadata`; + used to distribute work across SMs. + max_model_len: Maximum sequence length used to size the logits output. + + Returns: + Logits tensor of shape [B * next_n, max_model_len], dtype + `torch.float32`. + """ + _lazy_init() + if _fp8_paged_mqa_logits_impl is None: + return _missing() + return _fp8_paged_mqa_logits_impl( + q_fp8, + kv_cache_fp8, + weights, + context_lens, + block_tables, + schedule_metadata, + max_model_len, + clean_logits=True + ) + + + def _ceil_to_ue8m0(x: torch.Tensor): return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) @@ -183,8 +305,12 @@ def should_use_deepgemm_for_fp8_linear(output_dtype: torch.dtype, "fp8_gemm_nt", "m_grouped_fp8_gemm_nt_contiguous", "fp8_m_grouped_gemm_nt_masked", + "fp8_mqa_logits", + "fp8_paged_mqa_logits", + "get_paged_mqa_logits_metadata", "per_block_cast_to_fp8", "is_deep_gemm_e8m0_used", "is_deep_gemm_supported", + "get_num_sms", "should_use_deepgemm_for_fp8_linear", ] diff --git a/vllm/v1/attention/backends/cpu_attn.py b/vllm/v1/attention/backends/cpu_attn.py index 6627164c9879..466e6320c591 100644 --- a/vllm/v1/attention/backends/cpu_attn.py +++ b/vllm/v1/attention/backends/cpu_attn.py @@ -79,6 +79,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return _get_paged_attn_impl().get_kv_cache_shape( num_blocks, block_size, num_kv_heads, head_size) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 20f1904b3be6..5d407fca1ad9 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -82,6 +82,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/attention/backends/flashinfer.py b/vllm/v1/attention/backends/flashinfer.py index cb092aa74e7f..d5a00a3afebe 100755 --- a/vllm/v1/attention/backends/flashinfer.py +++ b/vllm/v1/attention/backends/flashinfer.py @@ -176,6 +176,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (num_blocks, 2, block_size, num_kv_heads, head_size) diff --git a/vllm/v1/attention/backends/flex_attention.py b/vllm/v1/attention/backends/flex_attention.py 
index 662d3984554a..c0e5acdd245a 100644 --- a/vllm/v1/attention/backends/flex_attention.py +++ b/vllm/v1/attention/backends/flex_attention.py @@ -88,6 +88,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (2, num_blocks, block_size, num_kv_heads, head_size) diff --git a/vllm/v1/attention/backends/mla/common.py b/vllm/v1/attention/backends/mla/common.py index 5b307810de93..d8749aaab930 100755 --- a/vllm/v1/attention/backends/mla/common.py +++ b/vllm/v1/attention/backends/mla/common.py @@ -286,6 +286,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, # assumed to be 1 for MLA head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: return (num_blocks, block_size, head_size) @@ -407,6 +408,7 @@ def __post_init__(self): M = TypeVar("M", bound=MLACommonMetadata) +A = TypeVar("A") def use_flashinfer_prefill() -> bool: @@ -916,7 +918,9 @@ def reorg_kvcache( return reorganized_kv_c_normed, reorganized_k_pe -class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): +# TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl, +# and MLACommonImpl -> MLACommonDenseImpl or somthing like that +class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]): """ NOTE: Please read the comment at the top of the file before trying to understand this class @@ -942,6 +946,7 @@ def __init__( qk_head_dim: int, v_head_dim: int, kv_b_proj: ColumnParallelLinear, + indexer=None, ) -> None: if kv_sharing_target_layer_name is not None: raise NotImplementedError("KV sharing is not supported for MLA") @@ -959,6 +964,127 @@ def __init__( self.qk_head_dim = qk_head_dim self.v_head_dim = v_head_dim self.kv_b_proj = kv_b_proj + self.indexer = indexer + + def process_weights_after_loading(self, act_dtype: torch.dtype): + + def get_layer_weight(layer): + WEIGHT_NAMES = ("weight", "qweight", "weight_packed") + for attr in WEIGHT_NAMES: + if hasattr(layer, attr): + return getattr(layer, attr) + raise AttributeError( + f"Layer '{layer}' has no recognized weight attribute:" + f" {WEIGHT_NAMES}.") + + def get_and_maybe_dequant_weights(layer: LinearBase): + if not isinstance(layer.quant_method, UnquantizedLinearMethod): + # NOTE: This should only be used offline, since it's O(N^3) + eye = torch.eye(layer.input_size_per_partition, + dtype=act_dtype, + device=get_layer_weight(layer).device) + dequant_weights = layer.quant_method.apply(layer, + eye, + bias=None) + del eye + # standardize to (output, input) + return dequant_weights.T + return layer.weight + + # we currently do not have quantized bmm's which are needed for + # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform + # the bmm's in 16-bit, the extra memory overhead of this is fairly low + kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T + assert kv_b_proj_weight.shape == ( + self.kv_lora_rank, + self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( + f"{kv_b_proj_weight.shape=}, " + f"{self.kv_lora_rank=}, " + f"{self.num_heads=}, " + f"{self.qk_nope_head_dim=}, " + f"{self.v_head_dim=}") + kv_b_proj_weight = kv_b_proj_weight.view( + self.kv_lora_rank, + self.num_heads, + self.qk_nope_head_dim + self.v_head_dim, + ) + + W_UK, W_UV = kv_b_proj_weight.split( + [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + + if is_rocm_aiter_fp8bmm_enabled(): + W_K = W_UK.transpose(0, 1) # 16 512 128 + W_V = W_UV.permute(1, 2, 0) # 16 128 512 + self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( + W_K, 
dtype=current_platform.fp8_dtype()) + self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( + W_V, dtype=current_platform.fp8_dtype()) + + # The kernel operates on non-padded inputs. Hence, pre-compiling + # triton kernel to avoid runtime compilation for unseen batch sizes + # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. + # On DS-R1, this step adds roughly 50s to the model loading time. + max_batch_size = 1024 # [ToDo] Find the optimal upper limit + pre_compilation_list = list(range(1, max_batch_size + 1)) + if is_global_first_rank(): + pre_compilation_list = tqdm( + pre_compilation_list, + desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", + total=max_batch_size, + ) + + for m in pre_compilation_list: + x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), + dtype=torch.bfloat16, + device=self.W_K.device) + aiter_triton_fp8_bmm(x, + self.W_K, + self.W_K_scale, + group_size=128, + transpose_bm=True) + + x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), + dtype=torch.bfloat16, + device=self.W_V.device) + aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + else: + # Convert from (L, N, V) to (N, L, V) + self.W_UV = W_UV.transpose(0, 1) + # Convert from (L, N, P) to (N, P, L) + self.W_UK_T = W_UK.permute(1, 2, 0) + + def _v_up_proj(self, x): + # Convert from (B, N, L) to (N, B, L) + x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) + if is_rocm_aiter_fp8bmm_enabled(): + # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) + x = aiter_triton_fp8_bmm(x, + self.W_V, + self.W_V_scale, + group_size=128, + transpose_bm=True) + # Convert from (B, N, V) to (B, N * V) + x = x.reshape(-1, self.num_heads * self.v_head_dim) + else: + # Multiply (N, B, L) x (N, L, V) -> (N, B, V) + x = torch.bmm(x, self.W_UV) + # Convert from (N, B, V) to (B, N * V) + x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) + return x + + +class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]): + """ + NOTE: Please read the comment at the top of the file before trying to + understand this class + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) if use_flashinfer_prefill(): logger.debug_once("Using FlashInfer prefill for MLA") @@ -1134,116 +1260,6 @@ def _run_prefill_context_chunk_cudnn(self, True, #Indicates actual_seq_lens are on GPU or CPU. 
) - def _v_up_proj(self, x): - # Convert from (B, N, L) to (N, B, L) - x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1) - if is_rocm_aiter_fp8bmm_enabled(): - # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V) - x = aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) - # Convert from (B, N, V) to (B, N * V) - x = x.reshape(-1, self.num_heads * self.v_head_dim) - else: - # Multiply (N, B, L) x (N, L, V) -> (N, B, V) - x = torch.bmm(x, self.W_UV) - # Convert from (N, B, V) to (B, N * V) - x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim) - return x - - def process_weights_after_loading(self, act_dtype: torch.dtype): - - def get_layer_weight(layer): - WEIGHT_NAMES = ("weight", "qweight", "weight_packed") - for attr in WEIGHT_NAMES: - if hasattr(layer, attr): - return getattr(layer, attr) - raise AttributeError( - f"Layer '{layer}' has no recognized weight attribute:" - f" {WEIGHT_NAMES}.") - - def get_and_maybe_dequant_weights(layer: LinearBase): - if not isinstance(layer.quant_method, UnquantizedLinearMethod): - # NOTE: This should only be used offline, since it's O(N^3) - eye = torch.eye(layer.input_size_per_partition, - dtype=act_dtype, - device=get_layer_weight(layer).device) - dequant_weights = layer.quant_method.apply(layer, - eye, - bias=None) - del eye - # standardize to (output, input) - return dequant_weights.T - return layer.weight - - # we currently do not have quantized bmm's which are needed for - # `W_UV` and `W_UK_T`, we just store fp16/bf16 copies and perform - # the bmm's in 16-bit, the extra memory overhead of this is fairly low - kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T - assert kv_b_proj_weight.shape == ( - self.kv_lora_rank, - self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), ( - f"{kv_b_proj_weight.shape=}, " - f"{self.kv_lora_rank=}, " - f"{self.num_heads=}, " - f"{self.qk_nope_head_dim=}, " - f"{self.v_head_dim=}") - kv_b_proj_weight = kv_b_proj_weight.view( - self.kv_lora_rank, - self.num_heads, - self.qk_nope_head_dim + self.v_head_dim, - ) - - W_UK, W_UV = kv_b_proj_weight.split( - [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - - if is_rocm_aiter_fp8bmm_enabled(): - W_K = W_UK.transpose(0, 1) # 16 512 128 - W_V = W_UV.permute(1, 2, 0) # 16 128 512 - self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant( - W_K, dtype=current_platform.fp8_dtype()) - self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant( - W_V, dtype=current_platform.fp8_dtype()) - - # The kernel operates on non-padded inputs. Hence, pre-compiling - # triton kernel to avoid runtime compilation for unseen batch sizes - # Pre-compile for batch sizes 1 to 1024 to cover most use-cases. - # On DS-R1, this step adds roughly 50s to the model loading time. 
- max_batch_size = 1024 # [ToDo] Find the optimal upper limit - pre_compilation_list = list(range(1, max_batch_size + 1)) - if is_global_first_rank(): - pre_compilation_list = tqdm( - pre_compilation_list, - desc="[Aiter Triton] Pre-compiling fp8 BMM kernel", - total=max_batch_size, - ) - - for m in pre_compilation_list: - x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]), - dtype=torch.bfloat16, - device=self.W_K.device) - aiter_triton_fp8_bmm(x, - self.W_K, - self.W_K_scale, - group_size=128, - transpose_bm=True) - - x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]), - dtype=torch.bfloat16, - device=self.W_V.device) - aiter_triton_fp8_bmm(x, - self.W_V, - self.W_V_scale, - group_size=128, - transpose_bm=True) - else: - # Convert from (L, N, V) to (N, L, V) - self.W_UV = W_UV.transpose(0, 1) - # Convert from (L, N, P) to (N, P, L) - self.W_UK_T = W_UK.permute(1, 2, 0) - def _compute_prefill_context( self, q: torch.Tensor, @@ -1424,6 +1440,7 @@ def _forward_prefill( attn_metadata: MLACommonMetadata, k_scale: torch.Tensor, ) -> torch.Tensor: + # TODO (zyongye): Prefill function here assert attn_metadata.prefill is not None assert self.dcp_world_size is not None diff --git a/vllm/v1/attention/backends/mla/flashmla.py b/vllm/v1/attention/backends/mla/flashmla.py index 150e38553e4b..bb145e1f4a29 100644 --- a/vllm/v1/attention/backends/mla/flashmla.py +++ b/vllm/v1/attention/backends/mla/flashmla.py @@ -177,6 +177,7 @@ def _forward_decode( attn_metadata: FlashMLAMetadata, layer: AttentionLayer, ) -> tuple[torch.Tensor, Optional[torch.Tensor]]: + # TODO: (zyongye) decode function for mla here assert kv_c_and_k_pe_cache.numel() > 0 assert attn_metadata.decode is not None diff --git a/vllm/v1/attention/backends/mla/flashmla_sparse.py b/vllm/v1/attention/backends/mla/flashmla_sparse.py new file mode 100644 index 000000000000..93e40ac44658 --- /dev/null +++ b/vllm/v1/attention/backends/mla/flashmla_sparse.py @@ -0,0 +1,552 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +import math +from dataclasses import dataclass +from typing import TYPE_CHECKING, ClassVar, Optional + +import numpy as np +import torch + +from vllm import _custom_ops as ops +from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, + AttentionMetadata) +from vllm.attention.backends.utils import get_mla_dims +from vllm.attention.ops.flashmla import (flash_mla_sparse_prefill, + flash_mla_with_kvcache, + get_mla_metadata) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.triton_utils import tl, triton +from vllm.utils import cdiv +from vllm.v1.attention.backends.mla.common import MLACommonBaseImpl +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata) +from vllm.v1.kv_cache_interface import AttentionSpec +from vllm.platforms import current_platform + +if TYPE_CHECKING: + from vllm.model_executor.models.deepseek_v2 import Indexer + +logger = init_logger(__name__) +""" +NOTE: FlashMLA Sparse uses an fp8 cache with the following format + +In the "FP8 with scale" format, each token's KV cache is 656 Bytes, +structured as: +- **First 512 bytes:** The "quantized NoPE" part, containing 512 + `float8_e4m3` values. +- **Next 16 bytes:** Scale factors, containing 4 `float32` values. + The first `float32` is the scale for the first 128 `float8_e4m3` values, + the second for the next 128, and so on. 
+- **Last 128 bytes:** The "RoPE" part, containing 64 `bfloat16` values. This + part is not quantized for accuracy. +""" + + +def _lse2_to_lse(lse_base2: torch.Tensor) -> torch.Tensor: + # Convert base-2 LSE to natural-log LSE + # Keep FP32 for numerical stability during the merge. + return (lse_base2.to(torch.float32) * math.log(2.0)) + + +class FlashMLASparseBackend(AttentionBackend): + + accept_output_buffer: bool = True + + @staticmethod + def get_name() -> str: + return "FLASHMLA_SPARSE_VLLM_V1" + + @staticmethod + def get_metadata_cls() -> type[AttentionMetadata]: + return FlashMLASparseMetadata + + @staticmethod + def get_builder_cls() -> type["FlashMLASparseMetadataBuilder"]: + return FlashMLASparseMetadataBuilder + + @staticmethod + def get_impl_cls() -> type["FlashMLASparseImpl"]: + return FlashMLASparseImpl + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, # assumed to be 1 for MLA + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + if cache_dtype_str == "fp8_ds_mla": + # custom storage fromat is 656 bytes + # see FlashMLA readme.md for details + return (num_blocks, block_size, 656) + else: + return (num_blocks, block_size, head_size) + + @classmethod + def get_supported_dtypes(cls) -> list[torch.dtype]: + return [torch.bfloat16] + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [576] + + +@dataclass +class MLASparsePrefillMetadata: + # NOTE(Chen): not call it "FlashMLASparsePrefillMetadata" because + # the kernel is not from flashmla + block_table: torch.Tensor + has_context: bool = False + context_lens: Optional[torch.Tensor] = None + + +@dataclass +class FlashMLASparseDecodeAndContextMetadata: + scheduler_metadata: torch.Tensor = None + num_splits: torch.Tensor = None + cache_lens: torch.Tensor = None + prefill_context_lengths: Optional[torch.Tensor] = None + prefill_new_k_start_locs: Optional[torch.Tensor] = None + dummy_block_table: torch.Tensor = None + + def filter_prefill_indices( + self, indices: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + assert self.prefill_context_lengths is not None + prefill_context_lengths = self.prefill_context_lengths.unsqueeze(-1) + context_indices = torch.where(indices < prefill_context_lengths, + indices, -1) + new_token_indices = torch.where(indices >= prefill_context_lengths, + indices - prefill_context_lengths, -1) + return context_indices, new_token_indices + + +@dataclass +class FlashMLASparseMetadata: + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. 
+ query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + + block_table: torch.Tensor + req_id_per_token: torch.Tensor + block_size: int = 64 + topk_tokens: int = 2048 + + @dataclass + class FP8KernelMetadata: + scheduler_metadata: Optional[torch.Tensor] + num_splits: torch.Tensor + dummy_block_table: torch.Tensor + cache_lens: torch.Tensor + + fp8_extra_metadata: Optional[FP8KernelMetadata] = None + + +@triton.jit +def _convert_req_index_to_global_index_kernel( + req_id_ptr, # int32 [num_tokens] + block_table_ptr, # int32 [num_requests, max_num_blocks_per_req] + token_indices_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + out_ptr, # int32 [num_tokens, NUM_TOPK_TOKENS] + # shapes (compile-time where possible) + max_num_blocks_per_req: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_N: tl.constexpr, # tile width along columns + # strides (in elements) + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, +): + # program_id(0) -> token_id (row) + # program_id(1) -> tile index along columns + token_id = tl.program_id(0) + tile_id = tl.program_id(1) + + # Each program covers BLOCK_N consecutive columns + indice_id = tile_id * BLOCK_N + tl.arange(0, BLOCK_N) + + # Load request id for this token (no mask: grid is exact) + req = tl.load(req_id_ptr + token_id) + + # Load token indices for this tile + ti_ptr = token_indices_ptr + token_id * ti_stride0 + indice_id * ti_stride1 + tok = tl.load(ti_ptr) # int32 + + # Only token == -1 should propagate as -1 + is_invalid_tok = tok < 0 + + # Compute block id and in-block offset + block_id = tok // BLOCK_SIZE + inblock_off = tok % BLOCK_SIZE + + # Guard block_table access + valid_block = block_id < max_num_blocks_per_req + bt_ptr = block_table_ptr + req * bt_stride0 + block_id * bt_stride1 + base = tl.load(bt_ptr, mask=valid_block, other=0) + + # If token == -1 OR block_id OOB, output -1; else base * BLOCK_SIZE + offset + out_val = tl.where(is_invalid_tok | (~valid_block), -1, + base * BLOCK_SIZE + inblock_off) + + # Store results + out_ptr_ij = out_ptr + token_id * out_stride0 + indice_id * out_stride1 + tl.store(out_ptr_ij, out_val) + + +def triton_convert_req_index_to_global_index( + req_id: torch.Tensor, # int32 [num_tokens] + block_table: torch. + Tensor, # int32 [num_requests, max_num_blocks_per_req] + token_indices: torch.Tensor, # int32 [num_tokens, NUM_TOPK_TOKENS] + BLOCK_SIZE: int = 64, + NUM_TOPK_TOKENS: int = 2048, + BLOCK_N: int = 128, # tile width along columns +): + """ + out[token_id, indice_id] = + block_table[req_id[token_id], + token_indices[token_id, indice_id] // BLOCK_SIZE] * BLOCK_SIZE + + token_indices[token_id, indice_id] % BLOCK_SIZE + + Only when token_indices[token_id, indice_id] == -1 do we output -1. + For safety, we also output -1 if the derived block_id would be + out-of-bounds. 
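A pure-PyTorch reference of the mapping documented above makes the kernel's contract easier to check by eye. This is a toy sketch with made-up block tables, not code from the patch; the `-1` propagation and out-of-range handling follow the docstring.

```python
# Toy reference for triton_convert_req_index_to_global_index (illustrative).
import torch

def convert_ref(req_id, block_table, token_indices, block_size=64):
    block_id = token_indices // block_size
    offset = token_indices % block_size
    # Clamp so the gather stays in bounds; invalid entries are masked below.
    safe_block_id = block_id.clamp(0, block_table.shape[1] - 1).long()
    base = block_table[req_id.long()[:, None], safe_block_id]
    out = base * block_size + offset
    invalid = (token_indices < 0) | (block_id >= block_table.shape[1])
    return torch.where(invalid, torch.full_like(out, -1), out)

req_id = torch.tensor([0, 1], dtype=torch.int32)
block_table = torch.tensor([[7, 3], [5, 2]], dtype=torch.int32)
token_indices = torch.tensor([[0, 65, -1, 200],
                              [1, 64, 127, -1]], dtype=torch.int32)
print(convert_ref(req_id, block_table, token_indices))
# -> [[448, 193,  -1,  -1],
#     [321, 128, 191,  -1]]
```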
+ """ + assert req_id.dtype == torch.int32 + assert block_table.dtype == torch.int32 + assert token_indices.dtype == torch.int32 + assert token_indices.shape[1] == NUM_TOPK_TOKENS + assert NUM_TOPK_TOKENS % BLOCK_N == 0, \ + f"NUM_TOPK_TOKENS ({NUM_TOPK_TOKENS}) must be divisible by" \ + f"BLOCK_N ({BLOCK_N})" + + num_tokens = req_id.shape[0] + num_requests, max_num_blocks_per_req = block_table.shape + tiles_per_row = NUM_TOPK_TOKENS // BLOCK_N + + # Ensure contiguous tensors on the same device + req_id_c = req_id.contiguous() + block_table_c = block_table.contiguous() + token_indices_c = token_indices.contiguous() + out = torch.empty_like(token_indices_c) + + # Strides in elements + bt_stride0, bt_stride1 = block_table_c.stride() + ti_stride0, ti_stride1 = token_indices_c.stride() + out_stride0, out_stride1 = out.stride() + + # Exact 2D grid: tokens × column tiles + grid = (num_tokens, tiles_per_row) + + _convert_req_index_to_global_index_kernel[grid]( + req_id_c, + block_table_c, + token_indices_c, + out, + # shapes / constexprs + max_num_blocks_per_req, + BLOCK_SIZE, + BLOCK_N, + # strides + bt_stride0, + bt_stride1, + ti_stride0, + ti_stride1, + out_stride0, + out_stride1, + ) + return out + + +@dataclass +class FlashMLASparseMetadataBuilder( + AttentionMetadataBuilder[FlashMLASparseMetadata]): + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: ClassVar[int] = 1 + + def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str], + vllm_config: VllmConfig, device: torch.device): + + cache_config = vllm_config.cache_config + self.kv_cache_spec = kv_cache_spec + self.model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.device = device + + props = torch.cuda.get_device_properties(device) + sm_count = props.multi_processor_count + + self.num_heads = self.model_config.get_num_attention_heads( + parallel_config) + self.mla_dims = get_mla_dims(self.model_config) + self.topk_tokens = vllm_config.model_config.hf_config.index_topk + self.use_fp8_kv_cache = cache_config.cache_dtype == "fp8_ds_mla" + self.topk_tokens_tensor = torch.tensor([self.topk_tokens], + device=device, + dtype=torch.int32) + self.max_model_len_tensor = torch.tensor( + [self.model_config.max_model_len], + device=device, + dtype=torch.int32) + # this is ignored by `flash_mla_with_kvcache` if indices not None + self.dummy_block_table = torch.empty((1, 1), + dtype=torch.int32, + device=self.device) + self.num_speculative_tokens = ( + vllm_config.speculative_config.num_speculative_tokens + if vllm_config.speculative_config else 0 + ) + # Now deepgemm fp8_paged_mqa_logits does not support next_n > 2 + self.reorder_batch_threshold += min(self.num_speculative_tokens, 1) + + # Equation taken from FlashMLA/csrc/pybind.cpp + h_q, h_k = self.num_heads, 1 + s_q = 1 # inversely proportional to s_q, so s_q = 1 is the largest + max_num_sm_parts = int( + max((sm_count // 2) / h_k // (cdiv(h_q // h_k, 2 * 64) * s_q), 1)) + if current_platform.is_device_capability(100): + max_num_sm_parts *= 2 + self.tile_scheduler_metadata_buffer = torch.empty( + # TileSchedulerMetaDataSize = 8 + # see: FlashMLA/csrc/params.h + (max_num_sm_parts, 8), + dtype=torch.int32, + device=device) + self.num_splits_buffer = torch.empty( + # We pack all the tokens into one batch for sparse attention. + # Otherwise, we can exceed the sm of `get_mla_metadata`. 
+ (2, ), + dtype=torch.int32, + device=device) + self.req_id_per_token_buffer = torch.empty( + (vllm_config.scheduler_config.max_num_batched_tokens, ), + dtype=torch.int32, + device=device) + + def build(self, + common_prefix_len: int, + common_attn_metadata: CommonAttentionMetadata, + fast_build: bool = False) -> FlashMLASparseMetadata: + + num_tokens = common_attn_metadata.num_actual_tokens + starts = np.asarray(common_attn_metadata.query_start_loc_cpu, + dtype=np.int32) + seg_lengths = np.diff(starts) + req_id_per_token = np.repeat( + np.arange(seg_lengths.shape[0], dtype=np.int32), seg_lengths) + # Zero-fill for cudagraphs + self.req_id_per_token_buffer.fill_(0) + self.req_id_per_token_buffer[:req_id_per_token.shape[0]]\ + .copy_(torch.from_numpy(req_id_per_token), non_blocking=True) + req_id_per_token = self.req_id_per_token_buffer[:num_tokens] + + fp8_extra_metadata = None + if self.use_fp8_kv_cache: + tile_scheduler_metadata, num_splits = get_mla_metadata( + cache_seqlens=self.topk_tokens_tensor, + num_q_tokens_per_head_k=num_tokens * self.num_heads, + topk=self.topk_tokens, + num_heads_q=self.num_heads, + num_heads_k=1, + is_fp8_kvcache=True, + ) + + num_sm_parts = tile_scheduler_metadata.size(0) + # Copy to persistent buffer for full-CG support + tile_scheduler_metadata_buffer = \ + self.tile_scheduler_metadata_buffer[:num_sm_parts] + tile_scheduler_metadata_buffer.copy_(tile_scheduler_metadata) + self.num_splits_buffer.copy_(num_splits) + + fp8_extra_metadata = FlashMLASparseMetadata.FP8KernelMetadata( + scheduler_metadata=tile_scheduler_metadata_buffer, + num_splits=self.num_splits_buffer, + # cache_lens and block_table are basically unused in sparse case + # but the decode kernel will treat -1 and indices >= cache_lens + # as invalid so we make sure cache_lens is large enough to not + # accidentally mark indices invalid, we will use -1 exclusively + # to mark invalid indices + cache_lens=self.max_model_len_tensor, + dummy_block_table=self.dummy_block_table) + + metadata = FlashMLASparseMetadata( + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + block_table=common_attn_metadata.block_table_tensor, + req_id_per_token=req_id_per_token, + block_size=self.kv_cache_spec.block_size, + topk_tokens=self.topk_tokens, + fp8_extra_metadata=fp8_extra_metadata, + ) + return metadata + + +class FlashMLASparseImpl(MLACommonBaseImpl[FlashMLASparseMetadata]): + + def __init__( + self, + num_heads: int, + head_size: int, + scale: float, + num_kv_heads: int, + alibi_slopes: Optional[list[float]], + sliding_window: Optional[int], + kv_cache_dtype: str, + logits_soft_cap: Optional[float], + attn_type: str, + kv_sharing_target_layer_name: Optional[str], + # MLA Specific Arguments + topk_indice_buffer: Optional[torch.Tensor] = None, + indexer: Optional["Indexer"] = None, + **mla_args) -> None: + super().__init__(num_heads, head_size, scale, num_kv_heads, + alibi_slopes, sliding_window, kv_cache_dtype, + logits_soft_cap, attn_type, + kv_sharing_target_layer_name, **mla_args) + self.softmax_scale = scale + assert indexer is not None + self.topk_indices_buffer = indexer.topk_indices_buffer + self.padding = 128 if current_platform.is_device_capability(100) else 64 + + def _forward_bf16_kv( + self, q: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor, + 
topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + num_tokens = q.shape[0] + kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.view( + -1, 1, kv_c_and_k_pe_cache.shape[-1]) + + # NOTE(Chen): kernel requires num_local_head to be a multiple of + # 64 on hopper and 128 on blackwell + if self.num_heads % self.padding != 0: + assert self.padding % self.num_heads == 0 + logger.warning_once( + f"padding num_heads to {self.padding} due to sparse attn kernel requirement" + ) + q_padded = q.new_empty((q.shape[0], self.padding, q.shape[2])) + q_padded[:, :self.num_heads, :] = q + q = q_padded + + topk_indices = topk_indices.view(num_tokens, 1, -1) + output = flash_mla_sparse_prefill(q, kv_c_and_k_pe_cache, topk_indices, + self.softmax_scale)[0] + output = output[:, :self.num_heads, :] + return output + + def _forward_fp8_kv(self, q: torch.Tensor, + kv_c_and_k_pe_cache: torch.Tensor, + topk_indices: torch.Tensor, + attn_metadata: FlashMLASparseMetadata) -> torch.Tensor: + + assert attn_metadata.fp8_extra_metadata is not None + extra_metadata = attn_metadata.fp8_extra_metadata + + _attn_out, _ = flash_mla_with_kvcache( + q=q.unsqueeze(0), # unsqueeze to add batch_dim + k_cache=kv_c_and_k_pe_cache.view(torch.uint8).unsqueeze(-2), + block_table=extra_metadata.dummy_block_table, + head_dim_v=512, + cache_seqlens=extra_metadata.cache_lens, + tile_scheduler_metadata=extra_metadata.scheduler_metadata, + num_splits=extra_metadata.num_splits, + is_fp8_kvcache=True, + indices=topk_indices.unsqueeze(0), # unsqueeze to add batch_dim + softmax_scale=self.softmax_scale, + ) + + return _attn_out + + def forward( + self, + layer: AttentionLayer, + q: torch.Tensor, + k_c_normed: torch.Tensor, # key in unified attn + k_pe: torch.Tensor, # value in unified attn + kv_cache: torch.Tensor, + attn_metadata: FlashMLASparseMetadata, + output: Optional[torch.Tensor] = None, + output_scale: Optional[torch.Tensor] = None, + output_block_scale: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use + # MQA 576/512 approach for both prefill and decode (see: + # https://vllm-dev.slack.com/archives/C09GKA1D4LR/p1758506094148479) + + assert output is not None, "Output tensor must be provided." + + if output_scale is not None or output_block_scale is not None: + raise NotImplementedError( + "fused output quantization is not yet supported" + " for MLACommonImpl") + + if attn_metadata is None: + # The zero fill is required when used with DP + EP + # to ensure all ranks within a DP group compute the + # same expert outputs. + return output.fill_(0) + + num_actual_toks = attn_metadata.num_actual_tokens + + # Inputs and outputs may be padded for CUDA graphs + + q = q[:num_actual_toks, ...] + k_c_normed = k_c_normed[:num_actual_toks, ...] + k_pe = k_pe[:num_actual_toks, ...] 
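To make the `fp8_ds_mla` entry layout from the NOTE above concrete, here is a small decoding sketch for a single 656-byte cache row as seen by `_forward_fp8_kv` (which views the cache as `uint8`). The offsets come from that comment and are not re-verified against the kernels; treat this as an illustration of the layout, not patch code.

```python
# Illustrative decode of one fp8_ds_mla cache entry (656 bytes per token):
# 512 B fp8 NoPE + 4 float32 scales (one per 128 values) + 64 bf16 RoPE values.
import torch

entry = torch.zeros(656, dtype=torch.uint8)         # one (block, position) slot

nope_fp8 = entry[:512].view(torch.float8_e4m3fn)    # quantized NoPE latent
scales = entry[512:528].view(torch.float32)         # 4 per-tile scales
rope = entry[528:].view(torch.bfloat16)             # unquantized RoPE tail

# Dequantize tile by tile: scale i covers values [128 * i, 128 * (i + 1)).
nope = nope_fp8.to(torch.float32).reshape(4, 128) * scales[:, None]
print(nope.shape, rope.shape)   # torch.Size([4, 128]) torch.Size([64])
```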
+ + q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], + dim=-1) + # Convert from (B, N, P) to (N, B, P) + q_nope = q_nope.transpose(0, 1) + # Multiply (N, B, P) x (N, P, L) -> (N, B, L) + ql_nope = torch.bmm(q_nope, self.W_UK_T) + # Convert from (N, B, L) to (B, N, L) + ql_nope = ql_nope.transpose(0, 1) + + topk_indices = self.topk_indices_buffer[:num_actual_toks] + + # TODO: handle index / kv_cache correctly + topk_indices_global = triton_convert_req_index_to_global_index( + attn_metadata.req_id_per_token, + attn_metadata.block_table, + topk_indices, + BLOCK_SIZE=attn_metadata.block_size, + NUM_TOPK_TOKENS=attn_metadata.topk_tokens, + ) + + q = torch.cat([ql_nope, q_pe], dim=-1) + + # write the latent and rope to kv cache + if kv_cache.numel() > 0: + ops.concat_and_cache_mla( + k_c_normed, + k_pe.squeeze(1), + kv_cache, + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype=self.kv_cache_dtype, + scale=layer._k_scale, + ) + + if self.kv_cache_dtype != "fp8_ds_mla": + attn_out = self._forward_bf16_kv(q, kv_cache, topk_indices_global, + attn_metadata) + else: + attn_out = self._forward_fp8_kv(q, kv_cache, topk_indices_global, + attn_metadata) + + output[:num_actual_toks] = self._v_up_proj(attn_out) + return output diff --git a/vllm/v1/attention/backends/mla/indexer.py b/vllm/v1/attention/backends/mla/indexer.py new file mode 100644 index 000000000000..eda23c701a50 --- /dev/null +++ b/vllm/v1/attention/backends/mla/indexer.py @@ -0,0 +1,283 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from dataclasses import dataclass +from typing import ClassVar, Optional + +import torch + +from vllm.attention.backends.abstract import (AttentionBackend, + AttentionMetadata) +from vllm.config import VllmConfig +from vllm.logger import init_logger +from vllm.utils.deep_gemm import get_paged_mqa_logits_metadata +from vllm.v1.attention.backends.utils import (AttentionCGSupport, + AttentionMetadataBuilder, + CommonAttentionMetadata, + split_decodes_and_prefills) + +logger = init_logger(__name__) + + +class DeepseekV32IndexerBackend(AttentionBackend): + + @staticmethod + def get_metadata_cls() -> type["AttentionMetadata"]: + return DeepseekV32IndexerMetadata + + @classmethod + def get_supported_head_sizes(cls) -> list[int]: + return [32, 64, 128] + + @staticmethod + def get_builder_cls() -> type["DeepseekV32IndexerMetadataBuilder"]: + return DeepseekV32IndexerMetadataBuilder + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + cache_dtype_str: str = "auto", + ) -> tuple[int, ...]: + assert num_kv_heads == 1 + return (num_blocks, block_size, head_size) + + @staticmethod + def get_kv_cache_stride_order() -> tuple[int, ...]: + return (0, 1, 2) + + +@dataclass +class DeepseekV32IndexerPrefillMetadata: + block_table: torch.Tensor + query_start_loc: torch.Tensor + max_query_len: int + cu_seqlen_ks: torch.Tensor + cu_seqlen_ke: torch.Tensor + cu_seq_lens: torch.Tensor + total_seq_lens: int + + +@dataclass +class DeepSeekV32IndexerDecodeMetadata: + block_table: torch.Tensor + seq_lens: torch.Tensor + decode_lens: torch.Tensor + requires_padding: bool + schedule_metadata: torch.Tensor + + +@dataclass +class DeepseekV32IndexerMetadata: + + #FIXME (zyongye) hacky way to access the data now, need to be in chunked meta + seq_lens: torch.Tensor + + num_reqs: int + max_query_len: int + max_seq_len: int + + num_actual_tokens: int # Number of tokens excluding padding. 
+ query_start_loc: torch.Tensor + slot_mapping: torch.Tensor + # The dimension of the attention heads + head_dim: int + + # New for MLA (compared to FlashAttention) + # For handling prefill decode split + num_decodes: int + num_decode_tokens: int + num_prefills: int + num_prefill_tokens: int + + decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None + prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None + + +# TODO (zyongye) optimize this, this is now vibe coded +def kv_spans_from_batches(start_seq_loc: torch.Tensor, + seq_len_per_batch: torch.Tensor): + """ + Args: + start_seq_loc: 1D long tensor [B+1], cumulative counts of selected tokens per batch. + Example: [0, 2, 4, 7] -> batch sizes (selected) [2, 2, 3], N=7 tokens total. + seq_len_per_batch: 1D long tensor [B], full sequence length (KV length) of each batch. + Example: [5, 9, 4]. + + Returns: + start_tensor: 1D long tensor [N], start offset in the concatenated KV cache for each token's batch. + end_location: 1D long tensor [N], **exclusive** end = start + token's local position. + (So the attended KV slice is kv[start:end].) + + Assumes each batch contributes its full `seq_len_per_batch[i]` keys to the KV cache, and + the selected tokens within a batch are the **last** `counts[i]` positions of that sequence. + """ + q = start_seq_loc.to(dtype=torch.long) + L = seq_len_per_batch.to(dtype=torch.long, device=q.device) + assert q.dim() == 1 and L.dim() == 1 + assert q.numel() == L.numel() + 1, "start_seq_loc must have length B+1" + + # Selected tokens per batch and totals + counts = q[1:] - q[:-1] # [B] + N = int(q[-1].item()) # total selected tokens + B = L.numel() + device = L.device + + if N == 0: + return (torch.empty(0, dtype=torch.long, device=device), + torch.empty(0, dtype=torch.long, device=device)) + + # KV start offsets per batch in the concatenated KV cache + kv_starts_per_batch = torch.cumsum(L, dim=0) - L # [B] + + # For each selected token, which batch does it belong to? + batch_id = torch.repeat_interleave(torch.arange(B, device=device), + counts) # [N] + + # Map batch KV start to each token + start_tensor = kv_starts_per_batch[batch_id] # [N] + + # End-align local positions inside each batch: + # local_pos = L[b] - counts[b] + (1..counts[b]) for each batch b + L_expand = torch.repeat_interleave(L, counts) # [N] + m_expand = torch.repeat_interleave(counts, counts) # [N] + # position within the selected block: 1..counts[b] + pos_within = (torch.arange(N, device=device, dtype=torch.long) - + torch.repeat_interleave(q[:-1], counts) + 1) + + local_pos = L_expand - m_expand + pos_within # [N], 1-based + end_location = start_tensor + local_pos # exclusive end + + return start_tensor.int(), end_location.int() + + +def get_max_prefill_buffer_size(vllm_config: VllmConfig): + max_model_len = vllm_config.model_config.max_model_len + max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens + max_num_seq = vllm_config.scheduler_config.max_num_seqs + # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. + return max_model_len * max_num_seq + + +class DeepseekV32IndexerMetadataBuilder(AttentionMetadataBuilder): + cudagraph_support: ClassVar[AttentionCGSupport] = \ + AttentionCGSupport.UNIFORM_BATCH + + reorder_batch_threshold: ClassVar[int] = 1 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + scheduler_config = self.vllm_config.scheduler_config + # NOTE(Chen): an estimated max size of flattened_kv. Need to double check. 
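+        # Illustrative bound (assumed example values): with
+        # max_model_len = 32768 and max_num_seqs = 8,
+        # get_max_prefill_buffer_size() allows up to
+        # 32768 * 8 = 262144 flattened KV tokens.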
+        self.max_prefill_buffer_size = get_max_prefill_buffer_size(
+            self.vllm_config)
+        self.num_speculative_tokens = (
+            self.vllm_config.speculative_config.num_speculative_tokens
+            if self.vllm_config.speculative_config else 0
+        )
+        # Currently, DeepGEMM's fp8_paged_mqa_logits does not support next_n > 2
+        self.reorder_batch_threshold += min(self.num_speculative_tokens, 1)
+
+        props = torch.cuda.get_device_properties(self.device)
+        sm_count = props.multi_processor_count
+        self.num_sms = sm_count
+
+        self.decode_lens_buffer = torch.empty(
+            (scheduler_config.max_num_seqs, ),
+            dtype=torch.int32,
+            device=self.device)
+
+        # See: DeepGEMM/csrc/apis/attention.hpp
+        self.scheduler_metadata_buffer = torch.empty(
+            (self.num_sms + 1, 2), dtype=torch.int32, device=self.device
+        )
+
+    def build(self,
+              common_prefix_len: int,
+              common_attn_metadata: CommonAttentionMetadata,
+              fast_build: bool = False) -> DeepseekV32IndexerMetadata:
+
+        num_reqs = common_attn_metadata.num_reqs
+        num_tokens = common_attn_metadata.num_actual_tokens
+
+        device = self.device
+        block_table_tensor = common_attn_metadata.block_table_tensor
+
+        query_start_loc = common_attn_metadata.query_start_loc
+
+        num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
+            split_decodes_and_prefills(
+                common_attn_metadata,
+                decode_threshold=self.reorder_batch_threshold)
+
+        assert num_decodes + num_prefills == num_reqs
+        assert num_decode_tokens + num_prefill_tokens == num_tokens
+
+        prefill_metadata = None
+        if num_prefills > 0:
+            reqs_start = num_decodes
+            prefill_query_start_loc = query_start_loc[
+                reqs_start:] - query_start_loc[reqs_start]
+            cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
+                prefill_query_start_loc,
+                common_attn_metadata.seq_lens[reqs_start:])
+            total_seq_lens = common_attn_metadata.seq_lens[reqs_start:].sum()
+            assert total_seq_lens < self.max_prefill_buffer_size
+            cu_seq_lens = torch.cat([
+                torch.zeros(1, dtype=torch.int32, device=device),
+                common_attn_metadata.seq_lens[reqs_start:].cumsum(dim=0)
+            ]).to(torch.int32).cuda()
+            prefill_metadata = DeepseekV32IndexerPrefillMetadata(
+                block_table=block_table_tensor[reqs_start:, ...],
+                query_start_loc=prefill_query_start_loc,
+                max_query_len=common_attn_metadata.max_query_len,
+                cu_seqlen_ks=cu_seqlen_ks,
+                cu_seqlen_ke=cu_seqlen_ke,
+                cu_seq_lens=cu_seq_lens,
+                total_seq_lens=total_seq_lens,
+            )
+
+        decode_metadata = None
+        if num_decodes > 0:
+            torch.diff(common_attn_metadata.query_start_loc[:num_decodes+1],
+                       out=self.decode_lens_buffer[:num_decodes])
+            decode_lens = self.decode_lens_buffer[:num_decodes]
+            decode_lens_cpu = torch.diff(
+                common_attn_metadata.query_start_loc_cpu[:num_decodes+1])
+
+            # Use CPU values to avoid a GPU sync that breaks async scheduling
+            requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
+
+            seq_lens = common_attn_metadata.seq_lens[:num_decodes]
+
+            self.scheduler_metadata_buffer[:] = get_paged_mqa_logits_metadata(
+                seq_lens, self.kv_cache_spec.block_size, self.num_sms)
+            decode_metadata = DeepSeekV32IndexerDecodeMetadata(
+                block_table=common_attn_metadata.
+ block_table_tensor[:num_decodes, ...], + seq_lens=common_attn_metadata.seq_lens[:num_decodes], + decode_lens=decode_lens, + requires_padding=requires_padding, + schedule_metadata=self.scheduler_metadata_buffer, + ) + + attn_metadata = DeepseekV32IndexerMetadata( + seq_lens=common_attn_metadata.seq_lens, + num_reqs=common_attn_metadata.num_reqs, + max_query_len=common_attn_metadata.max_query_len, + max_seq_len=common_attn_metadata.max_seq_len, + num_actual_tokens=common_attn_metadata.num_actual_tokens, + query_start_loc=common_attn_metadata.query_start_loc, + slot_mapping=common_attn_metadata.slot_mapping, + head_dim=128, + num_decodes=num_decodes, + num_decode_tokens=num_decode_tokens, + num_prefills=num_prefills, + num_prefill_tokens=num_prefill_tokens, + prefill=prefill_metadata, + decode=decode_metadata, + ) + + # if get_tensor_model_parallel_rank() == 0: + # logger.info(f"attn_metadata: {attn_metadata}") + return attn_metadata diff --git a/vllm/v1/attention/backends/pallas.py b/vllm/v1/attention/backends/pallas.py index 26f9abf13d0e..f05c3a7e93a9 100644 --- a/vllm/v1/attention/backends/pallas.py +++ b/vllm/v1/attention/backends/pallas.py @@ -107,6 +107,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: padded_head_size = cdiv( head_size, TPU_HEAD_SIZE_ALIGNMENT) * TPU_HEAD_SIZE_ALIGNMENT diff --git a/vllm/v1/attention/backends/rocm_aiter_fa.py b/vllm/v1/attention/backends/rocm_aiter_fa.py index afb2283c44d3..1fffe4a6b191 100644 --- a/vllm/v1/attention/backends/rocm_aiter_fa.py +++ b/vllm/v1/attention/backends/rocm_aiter_fa.py @@ -360,6 +360,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/attention/backends/tree_attn.py b/vllm/v1/attention/backends/tree_attn.py index 10238f36455d..14e3c57a8683 100644 --- a/vllm/v1/attention/backends/tree_attn.py +++ b/vllm/v1/attention/backends/tree_attn.py @@ -68,6 +68,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 784912a122f6..2dbccb1d284d 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -179,6 +179,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 63326d19194f..6336ba1d2629 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -694,7 +694,6 @@ def split_decodes_and_prefills( return num_reqs, 0, num_tokens, 0 first_prefill = is_prefill.int().argmax(dim=-1).item() - assert torch.all(query_lens[first_prefill:] > decode_threshold) assert torch.all(query_lens[:first_prefill] <= decode_threshold) num_decodes = first_prefill num_prefills = num_reqs - num_decodes diff --git a/vllm/v1/attention/backends/xformers.py b/vllm/v1/attention/backends/xformers.py index a6ca33491235..88ecdfcd00f5 100644 --- a/vllm/v1/attention/backends/xformers.py +++ 
b/vllm/v1/attention/backends/xformers.py @@ -106,6 +106,7 @@ def get_kv_cache_shape( block_size: int, num_kv_heads: int, head_size: int, + cache_dtype_str: str = "auto", ) -> tuple[int, ...]: if block_size % 16 != 0: raise ValueError("Block size must be a multiple of 16.") diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 47a41322c423..2ff1bb681d80 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1103,7 +1103,9 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): kv_cache_spec: The kv cache spec of each attention layer in the model """ - if is_kv_cache_spec_uniform(kv_cache_spec): + if is_kv_cache_spec_uniform( + kv_cache_spec) or UniformTypeKVCacheSpecs.is_uniform_type( + kv_cache_spec): return logger.warning( @@ -1128,7 +1130,6 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, - use_mla=spec.use_mla, sliding_window=spec.sliding_window, ) elif isinstance(spec, ChunkedLocalAttentionSpec): @@ -1137,11 +1138,11 @@ def unify_hybrid_kv_cache_specs(kv_cache_spec: dict[str, KVCacheSpec]): num_kv_heads=spec.num_kv_heads, head_size=spec.head_size, dtype=spec.dtype, - use_mla=spec.use_mla, attention_chunk_size=spec.attention_chunk_size, ) - if not is_kv_cache_spec_uniform(kv_cache_spec): + if not (is_kv_cache_spec_uniform(kv_cache_spec) + or UniformTypeKVCacheSpecs.is_uniform_type(kv_cache_spec)): raise ValueError("Hybrid KV cache manager is disabled but failed to " "convert the KV cache specs to one unified type.") diff --git a/vllm/v1/core/single_type_kv_cache_manager.py b/vllm/v1/core/single_type_kv_cache_manager.py index d27239164b0d..58fe12aef0a9 100644 --- a/vllm/v1/core/single_type_kv_cache_manager.py +++ b/vllm/v1/core/single_type_kv_cache_manager.py @@ -9,6 +9,7 @@ from vllm.v1.core.kv_cache_utils import BlockHash, KVCacheBlock from vllm.v1.kv_cache_interface import (ChunkedLocalAttentionSpec, CrossAttentionSpec, FullAttentionSpec, + MLAAttentionSpec, KVCacheSpec, MambaSpec, SlidingWindowSpec) from vllm.v1.request import Request @@ -656,6 +657,7 @@ def remove_skipped_blocks(self, request_id: str, spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { FullAttentionSpec: FullAttentionManager, + MLAAttentionSpec: FullAttentionManager, SlidingWindowSpec: SlidingWindowManager, ChunkedLocalAttentionSpec: ChunkedLocalAttentionManager, MambaSpec: MambaManager, diff --git a/vllm/v1/kv_cache_interface.py b/vllm/v1/kv_cache_interface.py index f72cc8f93a6c..281816653540 100644 --- a/vllm/v1/kv_cache_interface.py +++ b/vllm/v1/kv_cache_interface.py @@ -59,13 +59,10 @@ class AttentionSpec(KVCacheSpec): num_kv_heads: int head_size: int dtype: torch.dtype - use_mla: bool @property def page_size_bytes(self) -> int: - # For MLA we only store a single latent vector - coef = 1 if self.use_mla else 2 - return coef * self.block_size * self.num_kv_heads * self.head_size \ + return 2 * self.block_size * self.num_kv_heads * self.head_size \ * get_dtype_size(self.dtype) @@ -118,12 +115,13 @@ def merge(cls, specs: list[Self]) -> Self: if spec.sliding_window is not None) attention_chunk_size = set(spec.attention_chunk_size for spec in specs if spec.attention_chunk_size is not None) + assert not any(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "MLAAttentionSpec should be merged in MLAAttentionSpec.merge") merged_spec = cls( block_size=specs[0].block_size, num_kv_heads=specs[0].num_kv_heads, 
head_size=specs[0].head_size, dtype=specs[0].dtype, - use_mla=specs[0].use_mla, sliding_window=cls.merge_window_sizes(sliding_window), attention_chunk_size=cls.merge_window_sizes(attention_chunk_size), ) @@ -140,6 +138,38 @@ def merge(cls, specs: list[Self]) -> Self: return merged_spec +@dataclass(frozen=True) +class MLAAttentionSpec(FullAttentionSpec): + # TODO(Lucas/Chen): less hacky way to do this + cache_dtype_str: Optional[str] = None + + @property + def page_size_bytes(self) -> int: + if self.cache_dtype_str == "fp8_ds_mla": + # See `vllm/v1/attention/backends/mla/flashmla_sparse.py` + # for details. + return self.block_size * 656 + return self.block_size * self.num_kv_heads * self.head_size \ + * get_dtype_size(self.dtype) + + @classmethod + def merge(cls, specs: list[Self]) -> Self: + assert all(isinstance(spec, MLAAttentionSpec) for spec in specs), ( + "All attention layers in the same KV cache group must be " + "MLAAttentionSpec.") + cache_dtype_str_set = set(spec.cache_dtype_str for spec in specs) + assert len(cache_dtype_str_set) == 1, ( + "All attention layers in the same KV cache group must use the same " + "quantization method.") + return cls( + block_size=specs[0].block_size, + num_kv_heads=specs[0].num_kv_heads, + head_size=specs[0].head_size, + dtype=specs[0].dtype, + cache_dtype_str=cache_dtype_str_set.pop(), + ) + + @dataclass(frozen=True) class ChunkedLocalAttentionSpec(AttentionSpec): attention_chunk_size: int @@ -163,9 +193,6 @@ def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: class SlidingWindowSpec(AttentionSpec): sliding_window: int - def __post_init__(self): - assert not self.use_mla, "MLA is not supported for sliding window" - def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: assert vllm_config.parallel_config.decode_context_parallel_size == 1, \ "DCP not support sliding window." @@ -266,9 +293,13 @@ def is_uniform_type(cls, kv_cache_specs: dict[str, KVCacheSpec]) -> bool: # Different block sizes, not uniform. 
return False one_spec = next(iter(kv_cache_specs.values())) - if isinstance(one_spec, (FullAttentionSpec, CrossAttentionSpec)): + if isinstance(one_spec, FullAttentionSpec): + return all( + isinstance(spec, FullAttentionSpec) + for spec in kv_cache_specs.values()) + elif isinstance(one_spec, CrossAttentionSpec): return all( - isinstance(spec, type(one_spec)) + isinstance(spec, CrossAttentionSpec) for spec in kv_cache_specs.values()) elif isinstance(one_spec, SlidingWindowSpec): return all( diff --git a/vllm/v1/spec_decode/eagle.py b/vllm/v1/spec_decode/eagle.py index 2a178ddf4877..f7a5dd20df97 100644 --- a/vllm/v1/spec_decode/eagle.py +++ b/vllm/v1/spec_decode/eagle.py @@ -16,6 +16,7 @@ from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.platforms import current_platform @@ -62,6 +63,7 @@ def __init__( self.method = self.speculative_config.method self.runner = runner + self.device = device self.dtype = vllm_config.model_config.dtype self.max_model_len = vllm_config.model_config.max_model_len self.block_size = vllm_config.cache_config.block_size @@ -197,12 +199,26 @@ def propose( self.runner.attn_groups[0][0].metadata_builders[ubatch_id] attn_metadata = attn_metadata_builder.build_for_drafting( common_attn_metadata=common_attn_metadata, draft_index=0) - + # FIXME: support hybrid kv for draft model (remove separate indexer) + if self.draft_indexer_metadata_builder: + draft_indexer_metadata = ( + self.draft_indexer_metadata_builder + .build_for_drafting( + common_attn_metadata=common_attn_metadata, + draft_index=0, + ) + ) + else: + draft_indexer_metadata = None # At this moment, we assume all eagle layers belong to the same KV # cache group, thus using the same attention metadata. 
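+        # The DeepSeek-V3.2 indexer layers are not regular Attention layers,
+        # so they are given the separately built draft_indexer_metadata
+        # below instead of the shared attention metadata.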
per_layer_attn_metadata = {} for layer_name in self.attn_layer_names: per_layer_attn_metadata[layer_name] = attn_metadata + for layer_name in self.indexer_layer_names: + assert draft_indexer_metadata is not None + per_layer_attn_metadata[layer_name] = draft_indexer_metadata + if self.use_cuda_graph and \ num_tokens <= self.cudagraph_batch_sizes[-1]: num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens) @@ -806,6 +822,10 @@ def load_model(self, target_model: nn.Module) -> None: self.vllm_config.speculative_config.draft_model_config target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, Attention).keys()) + # FIXME: support hybrid kv for draft model + target_indexer_layer_names = set( + get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache).keys()) from vllm.compilation.backends import set_model_tag with set_model_tag("eagle_head"): @@ -815,8 +835,25 @@ def load_model(self, target_model: nn.Module) -> None: draft_attn_layer_names = ( get_layers_from_vllm_config(self.vllm_config, Attention).keys() - target_attn_layer_names) - + indexer_layers = get_layers_from_vllm_config(self.vllm_config, DeepseekV32IndexerCache) + draft_indexer_layer_names = (indexer_layers.keys() - target_indexer_layer_names) self.attn_layer_names = list(draft_attn_layer_names) + self.indexer_layer_names = list(draft_indexer_layer_names) + + if self.indexer_layer_names: + first_layer = self.indexer_layer_names[0] + self.draft_indexer_metadata_builder = ( + indexer_layers[first_layer] + .get_attn_backend() + .get_builder_cls()( + indexer_layers[first_layer].get_kv_cache_spec(), + self.indexer_layer_names, + self.vllm_config, + self.device, + ) + ) + else: + self.draft_indexer_metadata_builder = None if supports_multimodal(target_model): # handle multimodality diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d0946e8c5d7d..ef60866c074b 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -40,6 +40,7 @@ from vllm.model_executor.layers.mamba.abstract import MambaBase from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding from vllm.model_executor.model_loader import TensorizerLoader, get_model_loader +from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.interfaces import (is_mixture_of_experts, supports_eagle3, supports_mrope, @@ -74,7 +75,8 @@ EncoderOnlyAttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, KVCacheSpec, - MambaSpec, SlidingWindowSpec, + MambaSpec, MLAAttentionSpec, + SlidingWindowSpec, UniformTypeKVCacheSpecs) # yapf: enable from vllm.v1.outputs import (EMPTY_MODEL_RUNNER_OUTPUT, AsyncModelRunnerOutput, @@ -2995,7 +2997,7 @@ def _dummy_run( attn_metadata_i = (attn_group\ .get_metadata_builder(ubatch_id=ubid)\ .build_for_cudagraph_capture(common_attn_metadata)) - for layer_name in kv_cache_group_spec.layer_names: + for layer_name in attn_group.layer_names: assert type(attn_metadata) is list attn_metadata[ubid][ layer_name] = attn_metadata_i @@ -3003,7 +3005,7 @@ def _dummy_run( assert type(attn_metadata) is dict attn_metadata_i = attn_group.get_metadata_builder()\ .build_for_cudagraph_capture(common_attn_metadata) - for layer_name in kv_cache_group_spec.layer_names: + for layer_name in attn_group.layer_names: attn_metadata[layer_name] = attn_metadata_i with self.maybe_dummy_run_with_lora(self.lora_config, @@ -3725,8 +3727,11 @@ def _reshape_kv_cache_tensors( if isinstance(kv_cache_spec, 
AttentionSpec):
                 has_attn = True
                 kv_cache_shape = attn_backend.get_kv_cache_shape(
-                    num_blocks, kv_cache_spec.block_size,
-                    kv_cache_spec.num_kv_heads, kv_cache_spec.head_size)
+                    num_blocks,
+                    kv_cache_spec.block_size,
+                    kv_cache_spec.num_kv_heads,
+                    kv_cache_spec.head_size,
+                    cache_dtype_str=self.cache_config.cache_dtype)
                 dtype = kv_cache_spec.dtype
                 try:
                     kv_cache_stride_order = \
@@ -3910,7 +3915,6 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
         """
         Add encoder-only layers to the KV cache config.
         """
         block_size = self.vllm_config.cache_config.block_size
-        use_mla = self.vllm_config.model_config.use_mla
         encoder_only_attn_specs: dict[AttentionSpec,
                                       list[str]] = defaultdict(list)
         attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
@@ -3920,8 +3924,7 @@ def may_add_encoder_only_layers_to_kv_cache_config(self) -> None:
                 block_size=block_size,
                 num_kv_heads=attn_module.num_kv_heads,
                 head_size=attn_module.head_size,
-                dtype=self.kv_cache_dtype,
-                use_mla=use_mla)
+                dtype=self.kv_cache_dtype)
             encoder_only_attn_specs[attn_spec].append(layer_name)
             self.runner_only_attn_layers.add(layer_name)
         if len(encoder_only_attn_specs) > 0:
@@ -3943,6 +3946,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
         block_size = self.vllm_config.cache_config.block_size
         use_mla = self.vllm_config.model_config.use_mla
+        cache_dtype_str = self.vllm_config.cache_config.cache_dtype
         kv_cache_spec: dict[str, KVCacheSpec] = {}
         attn_layers = get_layers_from_vllm_config(self.vllm_config, Attention)
         for layer_name, attn_module in attn_layers.items():
@@ -3962,13 +3966,21 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
             # the attention backends
             if attn_module.attn_type == AttentionType.DECODER:
                 if attn_module.sliding_window is not None:
+                    assert not use_mla, \
+                        "MLA is not supported for sliding window"
                     kv_cache_spec[layer_name] = SlidingWindowSpec(
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
                         head_size=attn_module.head_size,
                         dtype=self.kv_cache_dtype,
-                        sliding_window=attn_module.sliding_window,
-                        use_mla=use_mla)
+                        sliding_window=attn_module.sliding_window)
+                elif use_mla:
+                    kv_cache_spec[layer_name] = MLAAttentionSpec(
+                        block_size=block_size,
+                        num_kv_heads=attn_module.num_kv_heads,
+                        head_size=attn_module.head_size,
+                        dtype=self.kv_cache_dtype,
+                        cache_dtype_str=cache_dtype_str)
                 elif self.attention_chunk_size is not None \
                     and isinstance(attn_module, ChunkedLocalAttention):
                     kv_cache_spec[layer_name] = ChunkedLocalAttentionSpec(
@@ -3976,22 +3988,19 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:
                         num_kv_heads=attn_module.num_kv_heads,
                         head_size=attn_module.head_size,
                         dtype=self.kv_cache_dtype,
-                        attention_chunk_size=self.attention_chunk_size,
-                        use_mla=use_mla)
+                        attention_chunk_size=self.attention_chunk_size)
                 else:
                     kv_cache_spec[layer_name] = FullAttentionSpec(
                         block_size=block_size,
                         num_kv_heads=attn_module.num_kv_heads,
                         head_size=attn_module.head_size,
-                        dtype=self.kv_cache_dtype,
-                        use_mla=use_mla)
+                        dtype=self.kv_cache_dtype)
             elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
                 kv_cache_spec[layer_name] = CrossAttentionSpec(
                     block_size=block_size,
                     num_kv_heads=attn_module.num_kv_heads,
                     head_size=attn_module.head_size,
-                    dtype=self.kv_cache_dtype,
-                    use_mla=use_mla)
+                    dtype=self.kv_cache_dtype)
             elif attn_module.attn_type in (AttentionType.ENCODER,
                                            AttentionType.ENCODER_ONLY):
                 # encoder-only attention does not need KV cache.
@@ -4028,6 +4037,10 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: self.speculative_config.num_speculative_tokens if self.speculative_config else 0), ) + ds_indexer_layers = get_layers_from_vllm_config( + self.vllm_config, DeepseekV32IndexerCache) + for layer_name, ds_indexer_module in ds_indexer_layers.items(): + kv_cache_spec[layer_name] = ds_indexer_module.get_kv_cache_spec() return kv_cache_spec diff --git a/vllm/v1/worker/tpu_model_runner.py b/vllm/v1/worker/tpu_model_runner.py index 48070c1e3e7c..c7d6dcd77b2c 100644 --- a/vllm/v1/worker/tpu_model_runner.py +++ b/vllm/v1/worker/tpu_model_runner.py @@ -528,7 +528,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: head_size=attn_module.head_size, dtype=self.kv_cache_dtype, sliding_window=attn_module.sliding_window, - use_mla=False, ) else: kv_cache_spec[layer_name] = FullAttentionSpec( @@ -536,7 +535,6 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]: num_kv_heads=attn_module.num_kv_heads, head_size=attn_module.head_size, dtype=self.kv_cache_dtype, - use_mla=False, ) elif attn_module.attn_type in (AttentionType.ENCODER, AttentionType.ENCODER_ONLY): diff --git a/vllm/worker/cache_engine.py b/vllm/worker/cache_engine.py index 530907012f70..6349f0e97592 100644 --- a/vllm/worker/cache_engine.py +++ b/vllm/worker/cache_engine.py @@ -73,7 +73,8 @@ def _allocate_kv_cache( ) -> List[torch.Tensor]: """Allocates KV cache on the specified device.""" kv_cache_generic_shape = self.attn_backend.get_kv_cache_shape( - num_blocks, self.block_size, self.num_kv_heads, self.head_size) + num_blocks, self.block_size, self.num_kv_heads, self.head_size, + self.cache_config.cache_dtype) pin_memory = is_pin_memory_available() if device == "cpu" else False kv_cache: List[torch.Tensor] = [] try: