Skip to content

Commit a533a11

Browse files
committed
Fix lint
1 parent ca86ca7 commit a533a11

File tree

3 files changed

+26
-21
lines changed

3 files changed

+26
-21
lines changed

python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2461,10 +2461,14 @@ def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
24612461

24622462
return batch_prefill_ragged_kv
24632463

2464+
24642465
def _attention_prefill_ragged(h_kv, h_q, d, dtype, rope_scaling: Dict[str, Any], target: Target):
24652466
return _attention_prefill_ragged_generic(h_kv, h_q, d, d, dtype, rope_scaling, target)
24662467

2467-
def _attention_prefill_ragged_generic(h_kv, h_q, d_qk, d_v, dtype, rope_scaling: Dict[str, Any], target: Target):
2468+
2469+
def _attention_prefill_ragged_generic(
2470+
h_kv, h_q, d_qk, d_v, dtype, rope_scaling: Dict[str, Any], target: Target
2471+
):
24682472
# pylint: disable=line-too-long
24692473
(
24702474
NUM_BLKS,

src/runtime/relax_vm/kv_state.cc

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,12 +90,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_absorbed")
9090
std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
9191
});
9292

93-
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal")
93+
TVM_REGISTER_GLOBAL("vm.builtin.attention_kv_cache_mla_normal")
9494
.set_body_typed([](AttentionKVCache kv_cache, int64_t layer_id,
95-
double attn_score_scaling_factor, NDArray q_data, NDArray k_data, NDArray v_data, NDArray compressed_kv_data,
96-
NDArray k_pe_data, NDArray o_data) {
97-
kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), std::move(v_data), std::move(compressed_kv_data),
98-
std::move(k_pe_data), std::move(o_data), attn_score_scaling_factor);
95+
double attn_score_scaling_factor, NDArray q_data, NDArray k_data,
96+
NDArray v_data, NDArray compressed_kv_data, NDArray k_pe_data,
97+
NDArray o_data) {
98+
kv_cache->MLANormal(layer_id, std::move(q_data), std::move(k_data), std::move(v_data),
99+
std::move(compressed_kv_data), std::move(k_pe_data), std::move(o_data),
100+
attn_score_scaling_factor);
99101
});
100102

101103
// RNN State methods

src/runtime/relax_vm/paged_kv_cache.cc

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2288,7 +2288,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
22882288
CHECK_EQ(v_data->shape[2], v_head_dim_);
22892289
CHECK_EQ(o_data->shape[2], v_head_dim_);
22902290

2291-
22922291
// Part 2: Synchronize streams and update auxiliary data.
22932292
ComputeStreamWaitForCopyStream();
22942293
ICHECK(!dirty_aux_data_device_);
@@ -2303,20 +2302,20 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
23032302
// Here, we use f_mla_prefill_ragged_normal_, which is designed to work for both decode
23042303
// and normal prefill cases. Optionally, you could check a flag like `use_decode_kernel_[0]`
23052304
// to adjust parameters; here we assume the kernel internally supports both cases.
2306-
f_mla_prefill_ragged_normal_(q_data,
2307-
cur_append_length_indptr_view_,
2308-
k_data,
2309-
v_data,
2310-
cur_append_length_indptr_view_,
2311-
q_rope_position_map_view_,
2312-
k_ragged_rope_pos_offset_view_,
2313-
o_data, // output tensor
2314-
merged_attn_scores_view_,
2315-
/*causal=*/1,
2316-
static_cast<int>(RoPEMode::kNone), // Rope changes have already been applied before the kernel
2317-
0, // Rope param, not important
2318-
0, // Rope param, not important
2319-
attn_score_scaling_factor);
2305+
f_mla_prefill_ragged_normal_(q_data,
2306+
cur_append_length_indptr_view_,
2307+
k_data,
2308+
v_data,
2309+
cur_append_length_indptr_view_,
2310+
q_rope_position_map_view_,
2311+
k_ragged_rope_pos_offset_view_,
2312+
o_data, // output tensor
2313+
merged_attn_scores_view_,
2314+
/*causal=*/1,
2315+
static_cast<int>(RoPEMode::kNone),
2316+
0, // Rope param, not important
2317+
0, // Rope param, not important
2318+
attn_score_scaling_factor);
23202319

23212320
// Part 5: If appending is to occur after attention, call the append kernel.
23222321
if (!append_before_attn_) {

0 commit comments

Comments
 (0)