@@ -181,20 +181,6 @@ class AttentionKVCacheObj : public KVStateObj {
181181 virtual void AttentionWithFusedQKV (int64_t layer_id, NDArray qkv_data, Optional<NDArray> mask,
182182 NDArray o_data, double attn_score_scaling_factor) = 0;
183183
184- /* !
185- * \brief Compute attention with Q/K/V data.
186- * \param layer_id The model layer where the attention compute happens.
187- * \param q_data The input Q data, in layout `(total_length, num_qo_heads, head_dim)`
188- * \param k_data The input K data, in layout `(total_length, num_kv_heads, head_dim)`
189- * \param v_data The input V data, in layout `(total_length, num_kv_heads, head_dim)`
190- * \param mask The input mask data, in layout `(total_sqr_length)`.
191- * \param o_data The output O data, in layout `(total_length, num_qo_heads, head_dim)`.
192- * \param attn_score_scaling_factor The additional attention scaling factor.
193- */
194- virtual void AttentionWithSeparateQKV (int64_t layer_id, NDArray q_data, NDArray k_data,
195- NDArray v_data, Optional<NDArray> mask, NDArray o_data,
196- double attn_score_scaling_factor) = 0;
197-
198184 /* !
199185 * \brief Compute multi-head latent attention after applying weight absorption.
200186 * \param layer_id The model layer where the attention compute happens.
@@ -275,6 +261,16 @@ class AttentionKVCacheObj : public KVStateObj {
275261 virtual void DebugGetKV (int64_t seq_id, //
276262 int64_t start_pos, int64_t end_pos, NDArray k_data, NDArray v_data) = 0;
277263
264+ /* !
265+ * \brief Fetch the compact K/V data of the given sequence for MLA cache.
266+ * \param seq_id The sequence whose K/V data is to be fetched.
267+ * \param start_pos The start position (inclusive) of the K/V data to fetch.
268+ * \param end_pos The end position (exclusive) of the K/V data to fetch.
269+ * \param kv_data The output KV data of the given sequence in layout elaborated above.
270+ */
271+ virtual void DebugGetKVMLA (int64_t seq_id, int64_t start_pos, int64_t end_pos,
272+ NDArray kv_data) = 0;
273+
278274 /* !
279275 * \brief Set the K/V data of the given sequence from input K/V data.
280276 * `start_pos` (inclusive) controls starting position of K/V data
0 commit comments