@@ -143,12 +143,20 @@ void TrtllmGenGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k)
143143
144144 std::vector<int32_t > sortedIndices = mPassingConfigIndices ;
145145 std::sort (sortedIndices.begin (), sortedIndices.end (),
146- [&configs](int32_t idx0, int32_t idx1)
146+ [&configs, &gemmData ](int32_t idx0, int32_t idx1)
147147 {
148148 auto const & optionsA = configs[idx0].mOptions ;
149149 auto const & optionsB = configs[idx1].mOptions ;
150150
151- // Sort by tileK sizes first
151+ // Choose the tileN that is closest to the problem N
152+ // This is the batch size dimension for low latency (transposeMmaOutput) case;
153+ if (optionsA.mTileN != optionsB.mTileN )
154+ {
155+ return abs (gemmData.mProblemDimensions .mN - optionsA.mTileN )
156+ < abs (gemmData.mProblemDimensions .mN - optionsB.mTileN );
157+ }
158+
159+ // Sort by tileK sizes
152160 if (optionsA.mTileK != optionsB.mTileK )
153161 {
154162 return optionsA.mTileK > optionsB.mTileK ;
@@ -160,6 +168,13 @@ void TrtllmGenGemmRunner::selectGemmConfig(int32_t m, int32_t n, int32_t k)
160168 return optionsA.mUseUnrollLoop2xForMma ;
161169 }
162170
171+ // Sort by tileM sizes
172+ // This is the batch size dimension for throughput (non-transposeMmaOutput) case;
173+ if (optionsA.mTileM != optionsB.mTileM )
174+ {
175+ return optionsA.mTileM > optionsB.mTileM ;
176+ }
177+
163178 // Then by splitK sizes
164179 if (optionsA.mNumSlicesForSplitK != optionsB.mNumSlicesForSplitK )
165180 {
0 commit comments