Skip to content

Commit 2c4afbb

Browse files
authored
[Relax][KV Cache] Refactor _attention_sequence_prefill function to … (#17362)
This PR removes batch_size from the function signature, instead deriving it as a size variable within the function body.
1 parent 72b75fe commit 2c4afbb

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

python/tvm/relax/frontend/nn/llm/kv_cache.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1237,7 +1237,7 @@ def merge_state_inplace(
12371237

12381238

12391239
def _attention_sequence_prefill(
1240-
batch_size, h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0
1240+
h_kv, h_q, d, dtype, target: Target, causal=0, attn_score_scaling_factor=1.0
12411241
): # pylint: disable=line-too-long
12421242
LOAD_VEC = 8 // ((DataType(dtype).bits + 7) // 8) # 8 bytes
12431243
group_size = h_q // h_kv
@@ -1264,6 +1264,7 @@ def batch_sequence_prefill_kv( # pylint: disable=too-many-branches
12641264
var_output: T.handle, # [total_len, h_q, d]
12651265
var_lse: T.handle # [total_len, h_q]
12661266
):
1267+
batch_size = T.int32(is_size_var=True)
12671268
qo_len = T.int32(is_size_var=True)
12681269
kv_len = T.int32(is_size_var=True)
12691270
q = T.match_buffer(var_q, (batch_size, qo_len, h_q, d), dtype)

0 commit comments

Comments (0)