Skip to content

Commit 057d988

Browse files
committed
perf: speed up sampling of 'trivial' request batches
Signed-off-by: ixlmar <[email protected]>
1 parent 259cc66 commit 057d988

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

tensorrt_llm/_torch/pyexecutor/sampler.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -373,6 +373,7 @@ def _group_requests_by_sampling_strategy(
373373
requests: Iterable[LlmRequest],
374374
*,
375375
pin_memory: bool = False) -> dict[Strategy, torch.Tensor]:
376+
# NB: Client code relies on request indices in returned torch.Tensor being sorted.
376377
strategy_dict: dict[Strategy, list[int]] = defaultdict(list)
377378
for req_index, req in enumerate(requests):
378379
strategy_dict[_request_strategy(req)].append(req_index)
@@ -1176,12 +1177,20 @@ def _sample_batched_by_strategy(
11761177
len(speculation_group_indices), dtype=torch.int32)
11771178

11781179
group_logits_cuda_indices = logits_cuda_indexer[group_req_indices]
1179-
if group_logits_cuda_indices.numel() != logits_cuda.size(0):
1180+
# NB: Assuming that group_req_indices are sorted
1181+
group_req_1st_index, group_req_last_index = group_req_indices[
1182+
0], group_req_indices[-1]
1183+
if group_req_last_index - group_req_1st_index + 1 == len(
1184+
group_req_indices):
1185+
# Avoid data movement if indices are contiguous
1186+
group_logits_cuda = logits_cuda[
1187+
req_offsets[group_req_1st_index]:(
1188+
req_offsets[group_req_last_index] +
1189+
req_num_steps[group_req_last_index])]
1190+
else:
11801191
group_logits_cuda_indices_cuda = group_logits_cuda_indices.to(
11811192
device=logits_cuda.device, non_blocking=True)
11821193
group_logits_cuda = logits_cuda[group_logits_cuda_indices_cuda]
1183-
else:
1184-
group_logits_cuda = logits_cuda
11851194

11861195
# Indexer for accessing tokens in 'group_logits_cuda' (and 'group_next_tokens_cuda')
11871196
# corresponding to the requests in 'group_req_indices'.

0 commit comments

Comments
 (0)