@@ -373,6 +373,7 @@ def _group_requests_by_sampling_strategy(
373373 requests : Iterable [LlmRequest ],
374374 * ,
375375 pin_memory : bool = False ) -> dict [Strategy , torch .Tensor ]:
 376+ # NB: Client code relies on the request indices in the returned dict's tensors being sorted in ascending order.
376377 strategy_dict : dict [Strategy , list [int ]] = defaultdict (list )
377378 for req_index , req in enumerate (requests ):
378379 strategy_dict [_request_strategy (req )].append (req_index )
@@ -1176,12 +1177,20 @@ def _sample_batched_by_strategy(
11761177 len (speculation_group_indices ), dtype = torch .int32 )
11771178
11781179 group_logits_cuda_indices = logits_cuda_indexer [group_req_indices ]
1179- if group_logits_cuda_indices .numel () != logits_cuda .size (0 ):
 1180+ # NB: Assumes group_req_indices is sorted in ascending order, as guaranteed by _group_requests_by_sampling_strategy.
1181+ group_req_1st_index , group_req_last_index = group_req_indices [
1182+ 0 ], group_req_indices [- 1 ]
1183+ if group_req_last_index - group_req_1st_index + 1 == len (
1184+ group_req_indices ):
1185+ # Avoid data movement if indices are contiguous
1186+ group_logits_cuda = logits_cuda [
1187+ req_offsets [group_req_1st_index ]:(
1188+ req_offsets [group_req_last_index ] +
1189+ req_num_steps [group_req_last_index ])]
1190+ else :
11801191 group_logits_cuda_indices_cuda = group_logits_cuda_indices .to (
11811192 device = logits_cuda .device , non_blocking = True )
11821193 group_logits_cuda = logits_cuda [group_logits_cuda_indices_cuda ]
1183- else :
1184- group_logits_cuda = logits_cuda
11851194
11861195 # Indexer for accessing tokens in 'group_logits_cuda' (and 'group_next_tokens_cuda')
11871196 # corresponding to the requests in 'group_req_indices'.
0 commit comments