Commit 70bf47d

Make changes suggested in PR.
Signed-off-by: Daniel Campora <[email protected]>
1 parent 5f4c562 commit 70bf47d

File tree

tensorrt_llm/_torch/pyexecutor/handle_logits.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tests/unittest/_torch/test_return_logits.py

3 files changed: 7 additions, 13 deletions

tensorrt_llm/_torch/pyexecutor/handle_logits.py

Lines changed: 6 additions & 4 deletions

@@ -16,8 +16,6 @@ def __call__(
         context_requests: List[LlmRequest],
         generation_requests: List[LlmRequest],
         logits: torch.Tensor,
-        num_context_logits_prefix_sum: List[int],
-        max_num_sequences: int,
         beam_width: int,
     ):
         """Handles context and generation logits for a batch of requests.
@@ -26,10 +24,14 @@ def __call__(
             context_requests: List of context requests to process
             generation_requests: List of generation requests to process
             logits: Input logits tensor
-            num_context_logits_prefix_sum: Prefix sum of context logits for each request
-            max_num_sequences: Maximum number of sequences to process
             beam_width: Beam width for the generation requests
         """
+        num_context_logits_prefix_sum = [0]
+        prefix_sum = 0
+        for request in context_requests:
+            prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
+            num_context_logits_prefix_sum.append(prefix_sum)
+
         # Copy logits into decoderBuffers.logits
         for batch_index, llm_req in enumerate(context_requests):
             logits_begin = num_context_logits_prefix_sum[batch_index]
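
For readers following the refactor: the prefix sum that used to be passed in as num_context_logits_prefix_sum is now computed inside HandleLogits.__call__ itself. The sketch below is a minimal standalone illustration of that logic, assuming only that each context request exposes context_chunk_size and py_return_context_logits as shown in the diff above; the helper name context_logits_slices is hypothetical and not part of the change.

    # Hypothetical helper (illustration only): reproduces the prefix sum built
    # inside HandleLogits.__call__ and returns the [begin, end) rows that each
    # context request occupies in the batched logits tensor.
    def context_logits_slices(context_requests):
        num_context_logits_prefix_sum = [0]
        prefix_sum = 0
        for request in context_requests:
            # A request that returns context logits contributes one row per token
            # of its current context chunk; every other request contributes a
            # single row for the last context token.
            prefix_sum += (request.context_chunk_size
                           if request.py_return_context_logits else 1)
            num_context_logits_prefix_sum.append(prefix_sum)
        return [(num_context_logits_prefix_sum[i], num_context_logits_prefix_sum[i + 1])
                for i in range(len(context_requests))]

Because the prefix sum is derived purely from the context requests already handed to HandleLogits, computing it locally removes two parameters (num_context_logits_prefix_sum and max_num_sequences) from the call signature without changing behaviour.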

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 0 additions & 8 deletions

@@ -176,7 +176,6 @@ def __init__(self,
         self.guided_decoder = guided_decoder
         self.dist = dist
         self.disable_overlap_scheduler = disable_overlap_scheduler
-        self.max_num_sequences = max_num_sequences

         # enqueue and _fetch_new_requests used data
         self.active = True
@@ -1489,16 +1488,9 @@ def _sample_async(self, scheduled_batch,
     def _handle_logits(self, scheduled_batch, batch_outputs):
         if any(r.py_return_context_logits or r.py_return_generation_logits
                for r in scheduled_batch.all_requests()):
-            num_context_logits_prefix_sum = [0]
-            prefix_sum = 0
-            for request in scheduled_batch.context_requests:
-                prefix_sum += request.context_chunk_size if request.py_return_context_logits else 1
-                num_context_logits_prefix_sum.append(prefix_sum)
-
             HandleLogits()(
                 scheduled_batch.context_requests,
                 scheduled_batch.generation_requests, batch_outputs["logits"],
-                num_context_logits_prefix_sum, self.max_num_sequences,
                 self.sampler.beam_width(scheduled_batch.all_requests()))

     @nvtx_range("_setup_sampler_step")
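
With the prefix sum moved into HandleLogits, the caller no longer needs max_num_sequences or any logits bookkeeping. Assembled from the diff context above (surrounding class elided), the resulting method reads roughly as follows:

    # Sketch of PyExecutor._handle_logits after this commit, reconstructed from
    # the diff context; only the lines shown are asserted by the change.
    def _handle_logits(self, scheduled_batch, batch_outputs):
        if any(r.py_return_context_logits or r.py_return_generation_logits
               for r in scheduled_batch.all_requests()):
            HandleLogits()(
                scheduled_batch.context_requests,
                scheduled_batch.generation_requests, batch_outputs["logits"],
                self.sampler.beam_width(scheduled_batch.all_requests()))

The guard is unchanged, so batches in which no request asks for context or generation logits still skip the handler entirely.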

tests/unittest/_torch/test_return_logits.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 global_kvcache_config = KvCacheConfig(max_tokens=10000)


-# @force_ampere # Save H100 resource
+@force_ampere # Save H100 resource
 @pytest.mark.parametrize("return_log_probs", [False, True])
 @pytest.mark.parametrize("gather_generation_logits", [False, True])
 @pytest.mark.parametrize("gather_context_logits", [False, True])
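
The test change re-enables the force_ampere marker that had been commented out, so the parametrized return-logits tests run on Ampere GPUs rather than occupying H100s. Purely as an illustration of the idea (the project's actual force_ampere helper may be implemented differently), such a marker can be expressed as a compute-capability skip:

    # Hypothetical sketch only; not the test suite's actual force_ampere helper.
    import pytest
    import torch

    def force_ampere(fn):
        # Skip the test unless the visible GPU reports SM 8.x (Ampere family).
        major = (torch.cuda.get_device_capability()[0]
                 if torch.cuda.is_available() else 0)
        return pytest.mark.skipif(
            major != 8, reason="test restricted to Ampere GPUs")(fn)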
