
Commit f5c9163

youzhedian, hongchao, and youkaichao authored and committed
[Feature] Support Decode Context Parallel (DCP) for MLA (vllm-project#23734)
Signed-off-by: hongchao <[email protected]>
Signed-off-by: youkaichao <[email protected]>
Co-authored-by: hongchao <[email protected]>
Co-authored-by: youkaichao <[email protected]>
Signed-off-by: xuebwang-amd <[email protected]>
1 parent d53ceee commit f5c9163

27 files changed: +1000 −231 lines

.buildkite/test-pipeline.yaml

Lines changed: 2 additions & 1 deletion

```diff
@@ -837,7 +837,7 @@ steps:
   - pytest -v -s models/test_oot_registration.py # it needs a clean process
   - pytest -v -s plugins/lora_resolvers # unit tests for in-tree lora resolver plugins
 
-- label: Pipeline Parallelism Test # 45min
+- label: Pipeline + Context Parallelism Test # 45min
   timeout_in_minutes: 60
   mirror_hardwares: [amdexperimental]
   working_dir: "/vllm-workspace/tests"
@@ -851,6 +851,7 @@ steps:
   commands:
   - pytest -v -s distributed/test_pp_cudagraph.py
   - pytest -v -s distributed/test_pipeline_parallel.py
+  # - pytest -v -s distributed/test_context_parallel.py # TODO: enable it on Hopper runners or add triton MLA support
 
 - label: LoRA TP Test (Distributed) # 17 min
   timeout_in_minutes: 30
```

csrc/cache.h

Lines changed: 0 additions & 7 deletions

```diff
@@ -36,13 +36,6 @@ void concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
                           const std::string& kv_cache_dtype,
                           torch::Tensor& scale);
 
-void cp_fused_concat_and_cache_mla(torch::Tensor& kv_c, torch::Tensor& k_pe,
-                                   torch::Tensor& cp_local_token_select_indices,
-                                   torch::Tensor& kv_cache,
-                                   torch::Tensor& slot_mapping,
-                                   const std::string& kv_cache_dtype,
-                                   torch::Tensor& scale);
-
 // Just for unittest
 void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache,
                  const double scale, const std::string& kv_cache_dtype);
```

csrc/cache_kernels.cu

Lines changed: 0 additions & 103 deletions

```diff
@@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel(
   copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
 }
 
-template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
-__global__ void cp_fused_concat_and_cache_mla_kernel(
-    const scalar_t* __restrict__ kv_c,  // [num_full_tokens, kv_lora_rank]
-    const scalar_t* __restrict__ k_pe,  // [num_full_tokens, pe_dim]
-    const int64_t* __restrict__ cp_local_token_select_indices,  // [num_tokens]
-    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
-                                     // + pe_dim)]
-    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int block_stride,                    //
-    const int entry_stride,                    //
-    const int kv_c_stride,                     //
-    const int k_pe_stride,                     //
-    const int kv_lora_rank,                    //
-    const int pe_dim,                          //
-    const int block_size,                      //
-    const float* scale                         //
-) {
-  const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
-  const int64_t slot_idx = slot_mapping[blockIdx.x];
-  // NOTE: slot_idx can be -1 if the token is padded
-  if (slot_idx < 0) {
-    return;
-  }
-  const int64_t block_idx = slot_idx / block_size;
-  const int64_t block_offset = slot_idx % block_size;
-
-  auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
-                  int src_stride, int dst_stride, int size, int offset) {
-    for (int i = threadIdx.x; i < size; i += blockDim.x) {
-      const int64_t src_idx = token_idx * src_stride + i;
-      const int64_t dst_idx =
-          block_idx * block_stride + block_offset * entry_stride + i + offset;
-      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-        dst[dst_idx] = src[src_idx];
-      } else {
-        dst[dst_idx] =
-            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
-      }
-    }
-  };
-
-  copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
-  copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
-}
-
 }  // namespace vllm
 
 // KV_T is the data type of key and value tensors.
```
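Before the remaining hunks of this file, a quick orientation: both the deleted kernel and the surviving `concat_and_cache_mla` use the standard paged-KV addressing, where each token's flat slot index from `slot_mapping` decomposes into a cache-block index and an offset within that block. Below is a minimal PyTorch sketch of a single token's write; `write_mla_entry` is a hypothetical helper (not vLLM code), with shapes taken from the comments above.

```python
import torch

def write_mla_entry(
    kv_cache: torch.Tensor,  # [num_blocks, block_size, kv_lora_rank + pe_dim]
    kv_c_row: torch.Tensor,  # [kv_lora_rank]
    k_pe_row: torch.Tensor,  # [pe_dim]
    slot_idx: int,
) -> None:
    # slot_mapping can contain -1 for padded tokens; the kernel skips them.
    if slot_idx < 0:
        return
    block_size = kv_cache.size(1)
    block_idx = slot_idx // block_size    # which physical cache block
    block_offset = slot_idx % block_size  # position inside that block
    kv_lora_rank = kv_c_row.numel()
    # Concatenate the latent KV vector and the RoPE key into one cache entry.
    kv_cache[block_idx, block_offset, :kv_lora_rank] = kv_c_row
    kv_cache[block_idx, block_offset, kv_lora_rank:] = k_pe_row
```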
```diff
@@ -554,20 +509,6 @@ void reshape_and_cache_flash(
       kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,       \
       reinterpret_cast<const float*>(scale.data_ptr()));
 
-// KV_T is the data type of key and value tensors.
-// CACHE_T is the stored data type of kv-cache.
-// KV_DTYPE is the real data type of kv-cache.
-#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)     \
-  vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>   \
-      <<<grid, block, 0, stream>>>(                                     \
-          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                     \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                     \
-          cp_local_token_select_indices.data_ptr<int64_t>(),            \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),              \
-          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride, \
-          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,   \
-          reinterpret_cast<const float*>(scale.data_ptr()));
-
 void concat_and_cache_mla(
     torch::Tensor& kv_c,  // [num_tokens, kv_lora_rank]
     torch::Tensor& k_pe,  // [num_tokens, pe_dim]
```
```diff
@@ -606,50 +547,6 @@ void concat_and_cache_mla(
                              CALL_CONCAT_AND_CACHE_MLA);
 }
 
-// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
-// calls into one:
-//   k_c_normed.index_select(0, cp_local_token_select_indices) + \
-//   k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \
-//   concat_and_cache_mla.
-void cp_fused_concat_and_cache_mla(
-    torch::Tensor& kv_c,  // [num_total_tokens, kv_lora_rank]
-    torch::Tensor& k_pe,  // [num_total_tokens, pe_dim]
-    torch::Tensor& cp_local_token_select_indices,  // [num_tokens]
-    torch::Tensor& kv_cache,  // [num_blocks, block_size, (kv_lora_rank +
-                              // pe_dim)]
-    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
-    const std::string& kv_cache_dtype, torch::Tensor& scale) {
-  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
-  // slot_mapping.size(0) because of padding for CUDA graphs.
-  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
-  // both include padding.
-  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
-  // since key includes padding for CUDA graphs, while slot_mapping does not.
-  // In this case, slot_mapping.size(0) represents the actual number of tokens
-  // before padding.
-  // For compatibility with both cases, we use slot_mapping.size(0) as the
-  // number of tokens.
-  int num_tokens = slot_mapping.size(0);
-  int kv_lora_rank = kv_c.size(1);
-  int pe_dim = k_pe.size(1);
-  int block_size = kv_cache.size(1);
-
-  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
-
-  int kv_c_stride = kv_c.stride(0);
-  int k_pe_stride = k_pe.stride(0);
-  int block_stride = kv_cache.stride(0);
-  int entry_stride = kv_cache.stride(1);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(kv_lora_rank, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
-                             CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
-}
-
 namespace vllm {
 
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
```
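The deleted Note(hc) comment spells out exactly what the fused op collapsed into one kernel launch. Here is a rough pure-PyTorch sketch of that unfused three-step sequence; `concat_and_cache_mla_ref` and `cp_unfused_update` are hypothetical stand-ins (not vLLM code), and `k_pe` is assumed to arrive as `[num_total_tokens, 1, pe_dim]`, consistent with the `squeeze(1)` in the note.

```python
import torch

def concat_and_cache_mla_ref(kv_c, k_pe, kv_cache, slot_mapping):
    # Reference semantics of concat_and_cache_mla: write each token's
    # concat(kv_c, k_pe) row into its paged-cache slot, skipping padding.
    block_size = kv_cache.size(1)
    for i, slot in enumerate(slot_mapping.tolist()):
        if slot < 0:
            continue  # padded token
        blk, off = divmod(slot, block_size)
        kv_cache[blk, off] = torch.cat([kv_c[i], k_pe[i]])

def cp_unfused_update(kv_c, k_pe, cp_local_token_select_indices,
                      kv_cache, slot_mapping):
    # Step 1: gather this context-parallel rank's local tokens.
    kv_c_local = kv_c.index_select(0, cp_local_token_select_indices)
    # Step 2: drop k_pe's singleton dim, then gather the same tokens.
    k_pe_local = k_pe.squeeze(1).index_select(0, cp_local_token_select_indices)
    # Step 3: concat-and-cache. The deleted kernel did all three in one pass,
    # avoiding the two intermediate gathered tensors.
    concat_and_cache_mla_ref(kv_c_local, k_pe_local, kv_cache, slot_mapping)
```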

csrc/torch_bindings.cpp

Lines changed: 0 additions & 10 deletions

```diff
@@ -693,16 +693,6 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       " Tensor scale) -> ()");
   cache_ops.impl("concat_and_cache_mla", torch::kCUDA, &concat_and_cache_mla);
 
-  cache_ops.def(
-      "cp_fused_concat_and_cache_mla(Tensor kv_c, Tensor k_pe,"
-      " Tensor cp_local_token_select_indices,"
-      " Tensor! kv_cache,"
-      " Tensor slot_mapping,"
-      " str kv_cache_dtype,"
-      " Tensor scale) -> ()");
-  cache_ops.impl("cp_fused_concat_and_cache_mla", torch::kCUDA,
-                 &cp_fused_concat_and_cache_mla);
-
   // Convert the key and value cache to fp8 data type.
   cache_ops.def(
       "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, "
```
