@@ -1459,13 +1459,23 @@ template <typename T>
14591459size_t invokeComputeTopkLastDimWorkspaceSize (
14601460 SizeType32 batchSize, SizeType32 inputLength, SizeType32 k, bool is_largest)
14611461{
1462+ using idxT = SizeType32;
1463+
14621464 size_t buf_size = 0 ;
14631465 void * workspace = nullptr ;
14641466 T const * in = nullptr ;
14651467 T* out_val = nullptr ;
1466- SizeType32* out_idx = nullptr ;
1467- standalone_stable_radix_11bits<T, SizeType32, true >(
1468- workspace, buf_size, in, batchSize, inputLength, k, out_val, out_idx, is_largest, 0 );
1468+ idxT* out_idx = nullptr ;
1469+
1470+ constexpr int block_dim = 512 ;
1471+ constexpr bool fused_last_filter = false ;
1472+ constexpr bool sorted = true ;
1473+
1474+ int sm_cnt = tensorrt_llm::common::getMultiProcessorCount ();
1475+ unsigned grid_dim = air_topk_stable::calc_grid_dim<T, idxT, 11 , block_dim>(batchSize, inputLength, sm_cnt);
1476+
1477+ standalone_stable_radix_topk_<T, idxT, 11 , block_dim>(workspace, buf_size, in, static_cast <idxT*>(nullptr ),
1478+ batchSize, inputLength, k, out_val, out_idx, !is_largest, fused_last_filter, grid_dim, 0 , sorted);
14691479 return buf_size;
14701480}
14711481
0 commit comments