Commit 29c8870

fix illegal smem access with chunked attention
Signed-off-by: Perkz Zheng <[email protected]>
1 parent eb157ac

File tree: 1 file changed (+3 −1 lines)


cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderMaskedMultiheadAttentionTemplate.h

Lines changed: 3 additions & 1 deletion
@@ -1336,6 +1336,8 @@ __global__ void __launch_bounds__(MAX_THEADS_PER_BLOCK, MIN_BLOCKS_PER_SM) maske
     // Note max_attention_window_size is maximum of cyclic_attention_window_size among all layers.
     // By default, you can assume that they are the same.
     auto const cyclic_kv_cache_len = static_cast<unsigned>(params.cyclic_attention_window_size);
+    // The chunked attention size.
+    auto const chunked_attention_size = static_cast<unsigned>(params.chunked_attention_size);
     // The number of sink tokens in kv cache to support streamingllm
     auto const sink_token_len = static_cast<unsigned>(params.sink_token_length);
     // The current timestep (including paddings).
@@ -1361,7 +1363,7 @@
 #ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS
     if (sizeof(Tk) != 4)
     {
-        auto const max_timesteps = min(timestep, cyclic_kv_cache_len);
+        auto const max_timesteps = min(timestep, min(cyclic_kv_cache_len, chunked_attention_size));
         logits_smem_ += divUp(max_timesteps + 1, 4u) * 16;
     }
     Tk* logits_smem = reinterpret_cast<Tk*>(logits_smem_);
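Why the one-line change matters: the kernel advances logits_smem_ past a scratch buffer whose size is derived from the number of timesteps the block actually attends to. With chunked attention enabled, that attended span is capped at chunked_attention_size, and the shared memory reserved at launch is presumably sized with the same cap; computing the offset from cyclic_kv_cache_len alone can therefore push logits_smem_ past the reserved region, which matches the "illegal smem access" this commit fixes. Below is a minimal, self-contained sketch of the offset arithmetic. The helper name logitsSmemOffset and the example values are hypothetical; only the min(...) capping and the divUp(max_timesteps + 1, 4u) * 16 expression come from the diff.

// Sketch of the logits scratch offset, mirroring the
// #ifndef MMHA_USE_FP32_ACCUM_FOR_LOGITS branch in the diff.
#include <algorithm>
#include <cstdio>

// Same rounding helper the kernel uses for the offset.
constexpr unsigned divUp(unsigned a, unsigned b)
{
    return (a + b - 1u) / b;
}

// Bytes by which logits_smem_ is advanced past the preceding scratch buffer.
unsigned logitsSmemOffset(unsigned timestep, unsigned cyclic_kv_cache_len, unsigned chunked_attention_size)
{
    // Before this commit: min(timestep, cyclic_kv_cache_len).
    // If chunked attention caps the attended span below the cyclic window,
    // the old formula over-advances the pointer relative to a scratch region
    // that was (presumably) sized with the chunk cap applied.
    unsigned const max_timesteps = std::min(timestep, std::min(cyclic_kv_cache_len, chunked_attention_size));
    return divUp(max_timesteps + 1u, 4u) * 16u;
}

int main()
{
    // Hypothetical example: an 8192-token cyclic window with 2048-token attention chunks.
    // Old formula: divUp(5001, 4) * 16 = 20016 bytes; fixed: divUp(2049, 4) * 16 = 8208 bytes.
    printf("offset = %u bytes\n", logitsSmemOffset(5000u, 8192u, 2048u));
    return 0;
}

With the cap applied, the offset stays within the scratch space reserved when chunked attention shrinks the attended span below the cyclic window, while behavior is unchanged whenever chunked_attention_size >= cyclic_kv_cache_len.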
