Skip to content

Commit e40b7f8

Browse files
committed
fix: Update beam search workspace estimation to always use standalone_stable_radix_topk_ as a better upper bound
Signed-off-by: Stefan Niebler <[email protected]>
1 parent 519a211 commit e40b7f8

File tree

1 file changed

+13
-3
lines changed

1 file changed

+13
-3
lines changed

cpp/tensorrt_llm/kernels/topkLastDim.cu

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1459,13 +1459,23 @@ template <typename T>
14591459
size_t invokeComputeTopkLastDimWorkspaceSize(
14601460
SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest)
14611461
{
1462+
using idxT = SizeType32;
1463+
14621464
size_t buf_size = 0;
14631465
void* workspace = nullptr;
14641466
T const* in = nullptr;
14651467
T* out_val = nullptr;
1466-
SizeType32* out_idx = nullptr;
1467-
standalone_stable_radix_11bits<T, SizeType32, true>(
1468-
workspace, buf_size, in, batchSize, inputLength, k, out_val, out_idx, is_largest, 0);
1468+
idxT* out_idx = nullptr;
1469+
1470+
constexpr int block_dim = 512;
1471+
constexpr bool fused_last_filter = false;
1472+
constexpr bool sorted = true;
1473+
1474+
int sm_cnt = tensorrt_llm::common::getMultiProcessorCount();
1475+
unsigned grid_dim = air_topk_stable::calc_grid_dim<T, idxT, 11, block_dim>(batchSize, inputLength, sm_cnt);
1476+
1477+
standalone_stable_radix_topk_<T, idxT, 11, block_dim>(workspace, buf_size, in, static_cast<idxT*>(nullptr),
1478+
batchSize, inputLength, k, out_val, out_idx, !is_largest, fused_last_filter, grid_dim, 0, sorted);
14691479
return buf_size;
14701480
}
14711481

0 commit comments

Comments
 (0)