Commit 82d14cd

[None][refactor] Move draft token padding out of Drafter
Signed-off-by: Mike Iovine <[email protected]>
1 parent 90bfc8c commit 82d14cd

3 files changed: +14, -22 lines


tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 10 additions & 1 deletion
@@ -41,7 +41,7 @@
 from .guided_decoder import GuidedDecoder
 from .kv_cache_transceiver import KvCacheTransceiver
 from .llm_request import (ExecutorRequest, LlmRequest, LlmRequestState,
-                          LlmResponse)
+                          LlmResponse, get_draft_token_length)
 from .model_engine import ModelEngine
 from .sampler import Sampler, SampleState, SampleStateTensors
 from .scheduler import RequestScheduler, ScheduledRequests
@@ -1001,6 +1001,15 @@ def _executor_loop(self):
                         self.drafter.prepare_draft_tokens(
                             scheduled_batch, self.resource_manager)

+                        # Pad draft tokens to the max draft length. This is for CUDA
+                        # graph compatibility.
+                        for req in scheduled_batch.generation_requests:
+                            max_draft_tokens = self.max_draft_len
+                            num_draft_tokens = get_draft_token_length(req)
+                            req.py_draft_tokens.extend(
+                                0 for _ in range(max_draft_tokens -
+                                                 num_draft_tokens))
+
                     batch_outputs = self._forward_step(scheduled_batch)
                     self._execute_guided_decoder(scheduled_batch,
                                                  batch_outputs['logits'])
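In short, the executor loop now owns the padding: after the drafter runs, every generation request's py_draft_tokens list is right-padded with zeros up to the maximum draft length, so the speculative batch keeps a fixed shape for CUDA graph replay. A minimal standalone sketch of that step (the helper name is hypothetical, and plain len() stands in for get_draft_token_length):

def pad_draft_tokens_to_max(generation_requests, max_draft_len):
    # Right-pad each request's draft tokens with zeros so every request
    # carries exactly max_draft_len of them (fixed shape for CUDA graphs).
    for req in generation_requests:
        num_draft_tokens = len(req.py_draft_tokens)  # stand-in for get_draft_token_length(req)
        req.py_draft_tokens.extend(0 for _ in range(max_draft_len - num_draft_tokens))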

tensorrt_llm/_torch/speculative/model_drafter.py

Lines changed: 1 addition & 12 deletions
@@ -9,8 +9,7 @@
 from tensorrt_llm.logger import logger

 from ..pyexecutor.guided_decoder import GuidedDecoder
-from ..pyexecutor.llm_request import (LlmRequest, LlmRequestState,
-                                      get_draft_token_length)
+from ..pyexecutor.llm_request import LlmRequest, LlmRequestState
 from ..pyexecutor.resource_manager import BaseResourceManager, ResourceManager
 from ..pyexecutor.sampler import Sampler, SampleState, TorchSampler
 from ..pyexecutor.scheduler import ScheduledRequests
@@ -311,15 +310,6 @@ def _process_decoded_tokens(

         return new_requests

-    def _pad_to_max_draft_tokens(self,
-                                 scheduled_requests: ScheduledRequests) -> None:
-        """Pad draft tokens to maximum length for all generation requests."""
-        for req in scheduled_requests.generation_requests:
-            max_draft_tokens = self.max_draft_tokens
-            num_draft_tokens = get_draft_token_length(req)
-            req.py_draft_tokens.extend(
-                0 for _ in range(max_draft_tokens - num_draft_tokens))
-
     def _execute_guided_decoder(self,
                                 scheduled_batch: ScheduledRequests,
                                 logits: torch.Tensor,
@@ -403,7 +393,6 @@ def prepare_draft_tokens(
             self._update_requests(previous_batch)
             self._process_decoded_tokens(previous_batch.scheduled_requests,
                                          req_id_to_old_request)
-            self._pad_to_max_draft_tokens(scheduled_requests)

             if self.guided_decoder is not None:
                 self.guided_decoder.rollback_draft_tokens(scheduled_requests)
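With _pad_to_max_draft_tokens gone, ModelDrafter.prepare_draft_tokens leaves py_draft_tokens at whatever length was actually drafted and relies on the executor-side loop above to top it up. A hedged sketch of the resulting call order, where the driver function is hypothetical but prepare_draft_tokens and py_draft_tokens match the diff:

def draft_then_pad(drafter, scheduled_batch, resource_manager, max_draft_len):
    # 1) The drafter fills req.py_draft_tokens with only the tokens it drafted.
    drafter.prepare_draft_tokens(scheduled_batch, resource_manager)
    # 2) The caller pads every generation request up to the fixed draft length.
    for req in scheduled_batch.generation_requests:
        req.py_draft_tokens.extend([0] * (max_draft_len - len(req.py_draft_tokens)))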

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 3 additions & 9 deletions
@@ -87,13 +87,13 @@ def get_draft_tokens(
         self,
         prefix: list[int],
         request_id: int,
-        padding_id: int,
         max_sequence_length: int,
     ):
         prefix_len = len(prefix)
         max_draft_token_length_this_step = max_sequence_length - 1 - prefix_len
         if max_draft_token_length_this_step <= 0:  # No draft token is need if the prefix is long enough
-            return [padding_id]
+            return []
+
         if request_id not in self.start_index:  # Extend start_index and pool for a new request
             self.start_index[request_id] = 0
             if not self.is_public_pool:
@@ -125,8 +125,7 @@ def get_draft_tokens(
                 pool[pattern].remove(match)
                 pool[pattern].add(new_match)

-        # Find match
-        draft_tokens = [padding_id]  # fallback value
+        draft_tokens = []
         for size in range(min(self.max_matching_ngram_size, prefix_len - 1), 0,
                           -1):
             pattern = tuple(prefix[-size:])
@@ -194,12 +193,7 @@ def prepare_draft_tokens(
             draft_tokens = self.spec_resource_manager.get_draft_tokens(
                 prefix,
                 request.request_id,
-                padding_id=0,
                 max_sequence_length=request.py_orig_prompt_len +
                 request.py_max_new_tokens,
             )
-            # Pad length to `self.max_draft_len`
-            if len(draft_tokens) > 0:
-                pad_length = self.max_draft_len - len(draft_tokens)
-                draft_tokens.extend([0] * pad_length)
             request.py_draft_tokens = draft_tokens
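The n-gram path follows the same convention: get_draft_tokens now starts from an empty list and returns only what the longest matching n-gram yields, with no padding_id fallback and no padding in prepare_draft_tokens. A toy version of that longest-match lookup, assuming a plain dict pool keyed by n-gram tuples (the real spec_resource_manager keeps more bookkeeping):

def lookup_draft_tokens(prefix, pool, max_matching_ngram_size):
    # Try the longest n-gram suffix of the prefix first, then shorter ones.
    for size in range(min(max_matching_ngram_size, len(prefix) - 1), 0, -1):
        pattern = tuple(prefix[-size:])
        if pattern in pool:
            return list(pool[pattern])  # unpadded; the executor pads to max_draft_len
    return []  # was [padding_id] before this change

# For example, lookup_draft_tokens([5, 1, 2], {(1, 2): [3]}, 4) returns [3],
# while an unseen suffix returns [].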
