diff --git a/python/tvm/relax/frontend/nn/llm/kv_cache.py b/python/tvm/relax/frontend/nn/llm/kv_cache.py
index f5ff0105d0f2..ea6f15331654 100644
--- a/python/tvm/relax/frontend/nn/llm/kv_cache.py
+++ b/python/tvm/relax/frontend/nn/llm/kv_cache.py
@@ -180,6 +180,49 @@ def mla_absorbed(
             )
         ).reshape(b, s, h_qo, kv_lora_rank)
 
+    def mla_normal(
+        self,
+        layer_id: int,
+        q: Tensor,
+        k: Tensor,
+        v: Tensor,
+        compressed_kv: Tensor,
+        k_pe: Tensor,
+        attn_score_scaling_factor: float = 1.0,
+    ) -> Tensor:
+        """Compute multi-head latent attention with the given data
+        on the specified layer, using the normal flow (without weight absorption).
+        """
+        # pylint: disable=protected-access
+        b, s, h_qo, d_qk = q._expr.struct_info.shape
+        d_v = v._expr.struct_info.shape[3]
+        kv_lora_rank = compressed_kv._expr.struct_info.shape[3]
+        qk_rope_head_dim = k_pe._expr.struct_info.shape[3]
+        q = q.reshape(b * s, h_qo, d_qk)
+        k = k.reshape(b * s, h_qo, d_qk)
+        v = v.reshape(b * s, h_qo, d_v)
+        compressed_kv = compressed_kv.reshape(b * s, kv_lora_rank)
+        k_pe = k_pe.reshape(b * s, qk_rope_head_dim)
+
+        return Tensor(
+            _expr=rx.BlockBuilder.current().emit(
+                rx.call_dps_packed(
+                    "vm.builtin.attention_kv_cache_mla_normal",
+                    [
+                        self._expr,
+                        rx.PrimValue(layer_id),  # type: ignore[arg-type]
+                        rx.PrimValue(attn_score_scaling_factor),
+                        q._expr,
+                        k._expr,
+                        v._expr,
+                        compressed_kv._expr,
+                        k_pe._expr,
+                    ],
+                    out_sinfo=rx.TensorStructInfo((b * s, h_qo, d_v), q.dtype),
+                )
+            )
+        ).reshape(b, s, h_qo, d_v)
+
     def get_query_positions(self, total_length: tir.PrimExpr) -> Tensor:
         """Get the in-sequence positions of each slot in the query,
         which are needed for applying positional embeddings in some models.
@@ -591,7 +634,7 @@ def create_mla_kv_cache(  # pylint: disable=too-many-locals
             rx.PrimValue(0),
             bb.add_func(_attention_prefill_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_prefill_mla"),
             bb.add_func(_attention_decode_mla(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, False, target), "tir_attention_decode_mla"),
-            bb.add_func(_attention_prefill_ragged(num_key_value_heads, num_attention_heads, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
+            bb.add_func(_attention_prefill_ragged_generic(num_key_value_heads, num_attention_heads, qk_rope_head_dim, v_head_dim, dtype, {}, target), "tir_attention_prefill_ragged_mla_normal"),
             bb.add_func(_attention_prefill_ragged_mla_absorbed(num_attention_heads, kv_lora_rank, qk_rope_head_dim, dtype, target), "tir_attention_prefill_ragged_mla_absorbed"),
             bb.add_func(_merge_state_inplace(num_attention_heads, kv_lora_rank, dtype, target), "tir_attention_merge_state"),
             bb.add_func(llama_rope_with_position_map(10000, 1, qk_rope_head_dim, num_attention_heads, num_key_value_heads, dtype, {}, None), "tir_split_rotary"),
@@ -2420,6 +2463,12 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
 
 
 def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target):
+    return _attention_prefill_ragged_generic(h_kv, h_q, d, d, dtype, rope_scaling, target)
+
+
+def _attention_prefill_ragged_generic(
+    h_kv, h_q, d_qk, d_v, dtype, rope_scaling: Dict[str, Any], target: Target
+):
     # pylint: disable=line-too-long
     (
         NUM_BLKS,
@@ -2431,7 +2480,7 @@ def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any],
         tile_x,
         tile_y,
         tile_z,
-    ) = _get_prefill_kernel_config(h_kv, h_q, d, dtype, target)
+    ) = _get_prefill_kernel_config(h_kv, h_q, d_qk, dtype, target)
 
     # fmt: off
     @T.prim_func
@@ -2459,14 +2508,14 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
         q_rope_position_elem_offset = T.int32(is_size_var=True)
         k_rope_pos_offset_elem_offset = T.int32(is_size_var=True)
 
-        q = T.match_buffer(var_q, (qo_len, h_q, d), dtype)
+        q = T.match_buffer(var_q, (qo_len, h_q, d_qk), dtype)
         q_indptr = T.match_buffer(var_q_indptr, (batch_size + 1,), "int32", elem_offset=q_indptr_elem_offset)
-        k = T.match_buffer(var_k, (kv_len, h_kv, d), dtype)
-        v = T.match_buffer(var_v, (kv_len, h_kv, d), dtype)
+        k = T.match_buffer(var_k, (kv_len, h_kv, d_qk), dtype)
+        v = T.match_buffer(var_v, (kv_len, h_kv, d_v), dtype)
         kv_indptr = T.match_buffer(var_kv_indptr, (batch_size + 1,), "int32", elem_offset=kv_indptr_elem_offset)
         q_rope_position = T.match_buffer(var_q_rope_position, (qo_len,), "int32", elem_offset=q_rope_position_elem_offset)
         k_rope_pos_offset = T.match_buffer(var_k_rope_pos_offset, (batch_size,), "int32", elem_offset=k_rope_pos_offset_elem_offset)
-        output = T.match_buffer(var_output, (qo_len, h_q, d), dtype)
+        output = T.match_buffer(var_output, (qo_len, h_q, d_v), dtype)
         lse = T.match_buffer(var_lse, (qo_len, h_q), "float32")  # pylint: disable=unused-variable
 
         # kernel code
@@ -2485,13 +2534,13 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                             iterator = _var("int32")
                             kv_chunk_len = _var("int32")
 
-                            Q_smem = T.alloc_buffer((tile_x, d), dtype, scope="shared")
-                            K_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
-                            V_smem = T.alloc_buffer((tile_z, d), dtype, scope="shared")
+                            Q_smem = T.alloc_buffer((tile_x, d_qk), dtype, scope="shared")
+                            K_smem = T.alloc_buffer((tile_z, d_qk), dtype, scope="shared")
+                            V_smem = T.alloc_buffer((tile_z, d_v), dtype, scope="shared")
                             S_smem = T.alloc_buffer((tile_x, tile_z), "float32", scope="shared")
 
                             S_local = T.alloc_buffer((tile_x, tile_z), "float32", scope="local")
-                            O_local = T.alloc_buffer((tile_x, d), "float32", scope="local")
+                            O_local = T.alloc_buffer((tile_x, d_v), "float32", scope="local")
 
                             m_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
                             m_prev_smem = T.alloc_buffer((tile_x, ), "float32", scope="shared")
@@ -2548,7 +2597,7 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                                        if cur_L < q_indptr[b_idx + 1]:
                                            Q_smem[i, j] = T.if_then_else(
                                                rotary_mode == 1,
-                                               _rope(q, q_rope_position[cur_L], d, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
+                                               _rope(q, q_rope_position[cur_L], d_qk, rope_theta, rope_scale, (cur_L, cur_H_qo, j), dtype, rope_scaling),
                                                q[cur_L, cur_H_qo, j]
                                            )
                                        else:
@@ -2565,7 +2614,7 @@ def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
                                        if cur_L < kv_chunk_len[0]:
                                            K_smem[i, j] = T.if_then_else(
                                                rotary_mode == 1,
-                                               _rope(k, k_rope_pos_offset[b_idx] + cur_L, d, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype, rope_scaling),
+                                               _rope(k, k_rope_pos_offset[b_idx] + cur_L, d_qk, rope_theta, rope_scale, (L_kv_base + cur_L, by, j), dtype, rope_scaling),
                                                k[L_kv_base + cur_L, by, j]
                                            )
                                        else:
diff --git a/src/runtime/relax_vm/kv_state.cc b/src/runtime/relax_vm/kv_state.cc
index c78ada58e6d6..1b1867f06093 100644
--- a/src/runtime/relax_vm/kv_state.cc
+++ b/src/runtime/relax_vm/kv_state.cc
@@ -90,6 +90,16 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
                             std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
     });
 
+TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal")
+    .set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
+                       double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
+                       NDArray v_data, NDArray compressed_kv_data, NDArray k_pe_data,
+                       NDArray o_data) {
+      kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
+                          std::move(compressed_kv_data), std::move(k_pe_data), std::move(o_data),
+                          attn_score_scaling_factor);
+    });
+
 // RNN State methods
 TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_get").set_body_method<RNNState>(&RNNStateObj::Get);
 TVM_REGISTER_GLOBAL("vm.builtin.rnn_state_set")
diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 075ff0b94471..a936f429eeec 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -2241,7 +2241,82 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
   void MLANormal(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
                  NDArray compressed_kv_data, NDArray k_pe_data, NDArray o_data,
                  double attn_score_scaling_factor) {
-    // Todo(ruihang): implement it
+    // Part 1: Basic checks and setup.
+    int64_t local_layer_id = layer_id - layer_id_begin_offset_;
+    CHECK_GE(local_layer_id, 0);
+    CHECK_LT(local_layer_id, num_layers_);
+    NDArray pages = pages_[local_layer_id];
+    CHECK(q_data.DataType() == pages.DataType());
+    CHECK(k_data.DataType() == pages.DataType());
+    CHECK(v_data.DataType() == pages.DataType());
+    CHECK(compressed_kv_data.DataType() == pages.DataType());
+    CHECK(k_pe_data.DataType() == pages.DataType());
+    CHECK(o_data.DataType() == pages.DataType());
+    CHECK(attn_kinds_[layer_id] == AttnKind::kMLA);
+
+    // Expected shapes:
+    //   q_data:             (num_total_length, num_qo_heads, qk_head_dim)
+    //   k_data:             (num_total_length, num_qo_heads, qk_head_dim)
+    //   v_data:             (num_total_length, num_qo_heads, v_head_dim)
+    //   compressed_kv_data: (num_total_length, qk_head_dim - qk_rope_head_dim)
+    //   k_pe_data:          (num_total_length, qk_rope_head_dim)
+    //   o_data:             (num_total_length, num_qo_heads, v_head_dim)
+    CHECK_EQ(q_data->ndim, 3);
+    CHECK_EQ(k_data->ndim, 3);
+    CHECK_EQ(v_data->ndim, 3);
+    CHECK_EQ(compressed_kv_data->ndim, 2);
+    CHECK_EQ(k_pe_data->ndim, 2);
+    CHECK_EQ(o_data->ndim, 3);
+
+    int64_t total_seq_length = 0;
+    for (int64_t i = 0; i < cur_batch_size_; ++i) {
+      total_seq_length += cur_append_lengths_[i];
+    }
+    CHECK_LE(q_data->shape[0], total_seq_length);
+    CHECK_LE(k_data->shape[0], total_seq_length);
+    CHECK_LE(v_data->shape[0], total_seq_length);
+    CHECK_LE(compressed_kv_data->shape[0], total_seq_length);
+    CHECK_LE(k_pe_data->shape[0], total_seq_length);
+    CHECK_EQ(k_pe_data->shape[1], qk_rope_head_dim_);
+    CHECK_LE(o_data->shape[0], total_seq_length);
+    CHECK_EQ(q_data->shape[1], num_qo_heads_);
+    CHECK_EQ(o_data->shape[1], num_qo_heads_);
+    CHECK_EQ(k_data->shape[1], num_qo_heads_);
+    CHECK_EQ(v_data->shape[1], num_qo_heads_);
+    CHECK_EQ(q_data->shape[2], qk_head_dim_);
+    CHECK_EQ(k_data->shape[2], qk_head_dim_);
+    CHECK_EQ(v_data->shape[2], v_head_dim_);
+    CHECK_EQ(o_data->shape[2], v_head_dim_);
+
+    // Part 2: Synchronize streams and update auxiliary data.
+    ComputeStreamWaitForCopyStream();
+    ICHECK(!dirty_aux_data_device_);
+
+    // Part 3: Append the compressed KV data to the cache when append_before_attn_ is set.
+    if (append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
+                              append_position_map_view_);
+    }
+
+    // Part 4: Call the ragged attention kernel.
+    // f_mla_prefill_ragged_normal_ handles both the decode and the normal prefill
+    // cases, so no dispatch on use_decode_kernel_[0] is needed here; the kernel
+    // is assumed to support both cases internally.
+    f_mla_prefill_ragged_normal_(q_data, cur_append_length_indptr_view_, k_data, v_data,
+                                 cur_append_length_indptr_view_, q_rope_position_map_view_,
+                                 k_ragged_rope_pos_offset_view_,
+                                 o_data,  // output tensor
+                                 merged_attn_scores_view_,
+                                 /*causal=*/1, static_cast<int>(RoPEMode::kNone),
+                                 0,  // RoPE parameter (unused: RoPE mode is kNone)
+                                 0,  // RoPE parameter (unused: RoPE mode is kNone)
+                                 attn_score_scaling_factor);
+
+    // Part 5: Otherwise, append the compressed KV data after attention.
+    if (!append_before_attn_) {
+      f_transpose_append_mla_(pages_[local_layer_id], compressed_kv_data, k_pe_data,
+                              append_position_map_view_);
+    }
   }
 
   void LinearAttention(int64_t layer_id, NDArray q_data, NDArray k_data, NDArray v_data,
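
Usage sketch (illustrative, not part of the patch): once a model-side attention block has produced the usual MLA projections, the normal (non-absorbed) flow goes through the cache in a single call to the new `mla_normal` frontend method added above. All variable names below are hypothetical.

    # Assumed input shapes, for batch size b and sequence length s:
    #   q:             (b, s, num_qo_heads, qk_head_dim)   full query (nope + rope parts)
    #   k:             (b, s, num_qo_heads, qk_head_dim)   full key (nope + rope parts)
    #   v:             (b, s, num_qo_heads, v_head_dim)
    #   compressed_kv: (b, s, 1, kv_lora_rank)             latent KV written into the cache
    #   k_pe:          (b, s, 1, qk_rope_head_dim)         shared rotary key written into the cache
    output = paged_kv_cache.mla_normal(
        layer_id,
        q,
        k,
        v,
        compressed_kv,
        k_pe,
        attn_score_scaling_factor=1.0,
    )
    # output: (b, s, num_qo_heads, v_head_dim). q/k/v are fed straight to the ragged
    # prefill kernel, while (compressed_kv, k_pe) are what get appended to the MLA pages.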