Skip to content

Commit d90aebf

Browse files
committed
Update heuristics for choosing kernels
Add rules based on tileN and tileM. Signed-off-by: Shiyang Chen <[email protected]>
1 parent 8569794 commit d90aebf

File tree

1 file changed

+17
-2
lines changed

1 file changed

+17
-2
lines changed

cpp/tensorrt_llm/kernels/trtllmGenKernels/gemm/KernelRunner.cpp

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,20 @@ void TrtllmGenGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k)
143143

144144
std::vector<int32_t> sortedIndices = mPassingConfigIndices;
145145
std::sort(sortedIndices.begin(), sortedIndices.end(),
146-
[&configs](int32_t idx0, int32_t idx1)
146+
[&configs, &gemmData](int32_t idx0, int32_t idx1)
147147
{
148148
auto const& optionsA = configs[idx0].mOptions;
149149
auto const& optionsB = configs[idx1].mOptions;
150150

151-
// Sort by tileK sizes first
151+
// Choose the tileN that is closest to the problem N
152+
// This is the batch size dimension for low latency (transposeMmaOutput) case;
153+
if (optionsA.mTileN != optionsB.mTileN)
154+
{
155+
return abs(gemmData.mProblemDimensions.mN - optionsA.mTileN)
156+
< abs(gemmData.mProblemDimensions.mN - optionsB.mTileN);
157+
}
158+
159+
// Sort by tileK sizes
152160
if (optionsA.mTileK != optionsB.mTileK)
153161
{
154162
return optionsA.mTileK > optionsB.mTileK;
@@ -160,6 +168,13 @@ void TrtllmGenGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k)
160168
return optionsA.mUseUnrollLoop2xForMma;
161169
}
162170

171+
// Sort by tileM sizes
172+
// This is the batch size dimension for throughput (non-transposeMmaOutput) case;
173+
if (optionsA.mTileM != optionsB.mTileM)
174+
{
175+
return optionsA.mTileM > optionsB.mTileM;
176+
}
177+
163178
// Then by splitK sizes
164179
if (optionsA.mNumSlicesForSplitK != optionsB.mNumSlicesForSplitK)
165180
{

0 commit comments

Comments
 (0)