@@ -396,51 +396,6 @@ __global__ void concat_and_cache_mla_kernel(
   copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
 }
 
-template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
-__global__ void cp_fused_concat_and_cache_mla_kernel(
-    const scalar_t* __restrict__ kv_c,  // [num_full_tokens, kv_lora_rank]
-    const scalar_t* __restrict__ k_pe,  // [num_full_tokens, pe_dim]
-    const int64_t* __restrict__ cp_local_token_select_indices,  // [num_tokens]
-    cache_t* __restrict__ kv_cache,  // [num_blocks, block_size, (kv_lora_rank
-                                     // + pe_dim)]
-    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
-    const int block_stride,                    //
-    const int entry_stride,                    //
-    const int kv_c_stride,                     //
-    const int k_pe_stride,                     //
-    const int kv_lora_rank,                    //
-    const int pe_dim,                          //
-    const int block_size,                      //
-    const float* scale                         //
-) {
-  const int64_t token_idx = cp_local_token_select_indices[blockIdx.x];
-  const int64_t slot_idx = slot_mapping[blockIdx.x];
-  // NOTE: slot_idx can be -1 if the token is padded
-  if (slot_idx < 0) {
-    return;
-  }
-  const int64_t block_idx = slot_idx / block_size;
-  const int64_t block_offset = slot_idx % block_size;
-
-  auto copy = [&](const scalar_t* __restrict__ src, cache_t* __restrict__ dst,
-                  int src_stride, int dst_stride, int size, int offset) {
-    for (int i = threadIdx.x; i < size; i += blockDim.x) {
-      const int64_t src_idx = token_idx * src_stride + i;
-      const int64_t dst_idx =
-          block_idx * block_stride + block_offset * entry_stride + i + offset;
-      if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
-        dst[dst_idx] = src[src_idx];
-      } else {
-        dst[dst_idx] =
-            fp8::scaled_convert<cache_t, scalar_t, kv_dt>(src[src_idx], *scale);
-      }
-    }
-  };
-
-  copy(kv_c, kv_cache, kv_c_stride, block_stride, kv_lora_rank, 0);
-  copy(k_pe, kv_cache, k_pe_stride, block_stride, pe_dim, kv_lora_rank);
-}
-
 }  // namespace vllm
 
 // KV_T is the data type of key and value tensors.
@@ -554,20 +509,6 @@ void reshape_and_cache_flash(
           kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,        \
           reinterpret_cast<const float*>(scale.data_ptr()));
 
-// KV_T is the data type of key and value tensors.
-// CACHE_T is the stored data type of kv-cache.
-// KV_DTYPE is the real data type of kv-cache.
-#define CALL_CP_FUSED_CONCAT_AND_CACHE_MLA(KV_T, CACHE_T, KV_DTYPE)          \
-  vllm::cp_fused_concat_and_cache_mla_kernel<KV_T, CACHE_T, KV_DTYPE>        \
-      <<<grid, block, 0, stream>>>(                                          \
-          reinterpret_cast<KV_T*>(kv_c.data_ptr()),                          \
-          reinterpret_cast<KV_T*>(k_pe.data_ptr()),                          \
-          cp_local_token_select_indices.data_ptr<int64_t>(),                 \
-          reinterpret_cast<CACHE_T*>(kv_cache.data_ptr()),                   \
-          slot_mapping.data_ptr<int64_t>(), block_stride, entry_stride,      \
-          kv_c_stride, k_pe_stride, kv_lora_rank, pe_dim, block_size,        \
-          reinterpret_cast<const float*>(scale.data_ptr()));
-
 void concat_and_cache_mla(
     torch::Tensor& kv_c,  // [num_tokens, kv_lora_rank]
     torch::Tensor& k_pe,  // [num_tokens, pe_dim]
@@ -606,50 +547,6 @@ void concat_and_cache_mla(
                              CALL_CONCAT_AND_CACHE_MLA);
 }
 
-// Note(hc): cp_fused_concat_and_cache_mla fuses the following three kernel
-// calls into one:
-//   k_c_normed.index_select(0, cp_local_token_select_indices) + \
-//   k_pe.squeeze(1).index_select(0, cp_local_token_select_indices) + \
-//   concat_and_cache_mla.
-void cp_fused_concat_and_cache_mla(
-    torch::Tensor& kv_c,  // [num_total_tokens, kv_lora_rank]
-    torch::Tensor& k_pe,  // [num_total_tokens, pe_dim]
-    torch::Tensor& cp_local_token_select_indices,  // [num_tokens]
-    torch::Tensor& kv_cache,  // [num_blocks, block_size, (kv_lora_rank +
-                              // pe_dim)]
-    torch::Tensor& slot_mapping,  // [num_tokens] or [num_actual_tokens]
-    const std::string& kv_cache_dtype, torch::Tensor& scale) {
-  // NOTE(woosuk): In vLLM V1, key.size(0) can be different from
-  // slot_mapping.size(0) because of padding for CUDA graphs.
-  // In vLLM V0, key.size(0) is always equal to slot_mapping.size(0) because
-  // both include padding.
-  // In vLLM V1, however, key.size(0) can be larger than slot_mapping.size(0)
-  // since key includes padding for CUDA graphs, while slot_mapping does not.
-  // In this case, slot_mapping.size(0) represents the actual number of tokens
-  // before padding.
-  // For compatibility with both cases, we use slot_mapping.size(0) as the
-  // number of tokens.
-  int num_tokens = slot_mapping.size(0);
-  int kv_lora_rank = kv_c.size(1);
-  int pe_dim = k_pe.size(1);
-  int block_size = kv_cache.size(1);
-
-  TORCH_CHECK(kv_cache.size(2) == kv_lora_rank + pe_dim);
-
-  int kv_c_stride = kv_c.stride(0);
-  int k_pe_stride = k_pe.stride(0);
-  int block_stride = kv_cache.stride(0);
-  int entry_stride = kv_cache.stride(1);
-
-  dim3 grid(num_tokens);
-  dim3 block(std::min(kv_lora_rank, 512));
-  const at::cuda::OptionalCUDAGuard device_guard(device_of(kv_c));
-  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-  DISPATCH_BY_KV_CACHE_DTYPE(kv_c.dtype(), kv_cache_dtype,
-                             CALL_CP_FUSED_CONCAT_AND_CACHE_MLA);
-}
-
 namespace vllm {
 
 template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
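Editor's note: per the Note(hc) comment in the removed code, the fused kernel was equivalent to index-selecting the locally owned tokens from kv_c and k_pe and then calling the existing concat_and_cache_mla entry point kept by this change. The sketch below illustrates that unfused path; the helper name and the assumption that k_pe arrives with shape [num_total_tokens, 1, pe_dim] (hence the squeeze) are hypothetical, not part of the diff, and it is assumed to live in the same translation unit as concat_and_cache_mla.

// Rough sketch of the unfused equivalent described by the removed Note(hc)
// comment. Names and shapes here are illustrative assumptions.
void cp_concat_and_cache_mla_unfused(
    torch::Tensor& kv_c,   // [num_total_tokens, kv_lora_rank]
    torch::Tensor& k_pe,   // [num_total_tokens, 1, pe_dim] (assumed)
    torch::Tensor& cp_local_token_select_indices,  // [num_tokens]
    torch::Tensor& kv_cache, torch::Tensor& slot_mapping,
    const std::string& kv_cache_dtype, torch::Tensor& scale) {
  // Gather only the tokens selected for this context-parallel rank.
  torch::Tensor kv_c_local =
      kv_c.index_select(0, cp_local_token_select_indices);
  torch::Tensor k_pe_local =
      k_pe.squeeze(1).index_select(0, cp_local_token_select_indices);
  // Write both halves into the MLA KV cache with the existing kernel.
  concat_and_cache_mla(kv_c_local, k_pe_local, kv_cache, slot_mapping,
                       kv_cache_dtype, scale);
}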