Skip to content

Commit 2e71471

Browse files
committed
fix precompiled multi_query_token kernel not having is_fp8_out hash key (NVIDIA#6279)
Signed-off-by: Jhao-Ting Chen <[email protected]>
1 parent 515a229 commit 2e71471

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplCommon.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ XQAKernelRuntimeHashKey getRuntimeHashKeyFromXQAParams(XQAParams const& xqaParam
5252
unsigned int kernel_m_tilesize
5353
= getKernelMTileSize(num_q_heads_over_kv, xqaParams.multi_query_tokens, qSeqLen, isXqaJit);
5454

55+
// precompiled XQA does not use is_fp8_output as hashing key
5556
return {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, kernel_m_tilesize,
5657
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0, xqaParams.paged_kv_cache,
57-
xqaParams.multi_query_tokens, xqaParams.is_fp8_output};
58+
xqaParams.multi_query_tokens, isXqaJit ? xqaParams.is_fp8_output : false};
5859
}
5960

6061
} // namespace tensorrt_llm::kernels

cpp/tensorrt_llm/kernels/decoderMaskedMultiheadAttention/decoderXQAImplPrecompiled.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,10 +124,11 @@ class XQAKernelList
124124
m_tilesize = num_q_heads_over_kv;
125125
}
126126

127+
// precompiled XQA does not support param is_fp8_output in hash key
127128
XQAKernelRuntimeHashKey hash_key
128129
= {xqaParams.kv_cache_data_type, head_size, beam_width, kernel_num_q_heads_over_kv, m_tilesize,
129130
xqaParams.paged_kv_cache ? static_cast<unsigned int>(xqaParams.tokens_per_block) : 0,
130-
xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, xqaParams.is_fp8_output};
131+
xqaParams.paged_kv_cache, xqaParams.multi_query_tokens, 0 /* xqa jit param is_fp8_output */};
131132
auto const findIter = mFunctions.find(hash_key);
132133
return findIter != mFunctions.end();
133134
}

0 commit comments

Comments
 (0)