
Commit 30a19fc

[TRTLLM-6291] feat: Add user-provided speculative decoding support (#5204)
Signed-off-by: Robin Kobus <[email protected]>
1 parent 85b4a68 commit 30a19fc

File tree

19 files changed: +221 −60 lines changed

tensorrt_llm/_torch/pyexecutor/_util.py

Lines changed: 1 addition & 0 deletions

@@ -400,6 +400,7 @@ def teardown_managers(self, resources: Dict) -> None:
 
 
 def create_py_executor_instance(
+    *,
     dist,
     resources,
     mapping,

tensorrt_llm/_torch/pyexecutor/model_engine.py

Lines changed: 1 addition & 0 deletions

@@ -343,6 +343,7 @@ class PyTorchModelEngine(ModelEngine):
 
     def __init__(
         self,
+        *,
         model_path: str,
         pytorch_backend_config: PyTorchConfig,
         batch_size: int = 8,
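
Both constructors above gain a bare `*`, which makes every parameter after it keyword-only, so existing positional call sites must be updated (as done in py_executor_creator.py below). A minimal sketch of the pattern, using a hypothetical function rather than the real constructor:

def make_engine(*, model_path: str, batch_size: int = 8):
    # Parameters after the bare '*' can only be passed by keyword.
    return {"model_path": model_path, "batch_size": batch_size}

make_engine(model_path="/tmp/model", batch_size=4)   # OK
# make_engine("/tmp/model", 4)  # TypeError: takes 0 positional arguments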

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 3 additions & 4 deletions

@@ -168,13 +168,13 @@ def __init__(self,
                  sampler: Sampler,
                  dist: Distributed,
                  max_num_sequences: int,
-                 drafter: Drafter = None,
+                 drafter: Optional[Drafter] = None,
                  disable_overlap_scheduler: bool = False,
                  max_input_len: int = 2048,
                  max_batch_size: int = 8,
                  max_beam_width: int = 1,
                  max_draft_tokens: int = 0,
-                 kv_cache_transceiver: KvCacheTransceiver = None,
+                 kv_cache_transceiver: Optional[KvCacheTransceiver] = None,
                  draft_model_engine: Optional[ModelEngine] = None,
                  garbage_collection_gen0_threshold: Optional[int] = None,
                  start_worker: bool = True):
@@ -922,8 +922,7 @@ def _executor_loop(self):
                 self._prepare_draft_tokens(scheduled_batch)
 
                 if self.drafter is not None:
-                    self.drafter.prepare_draft_tokens(
-                        scheduled_batch, sample_state)
+                    self.drafter.prepare_draft_tokens(scheduled_batch)
 
                 if self.kv_cache_transceiver:
                     # For generation requests which have completed KV cache transfer

tensorrt_llm/_torch/pyexecutor/py_executor_creator.py

Lines changed: 39 additions & 16 deletions

@@ -227,8 +227,8 @@ def create_py_executor(
     with mem_monitor.observe_creation_stage(
             _ExecutorCreationStage.MODEL_ENGINE_MAIN):
         model_engine = PyTorchModelEngine(
-            checkpoint_dir,
-            pytorch_backend_config,
+            model_path=checkpoint_dir,
+            pytorch_backend_config=pytorch_backend_config,
             batch_size=executor_config.max_batch_size,
             max_beam_width=executor_config.max_beam_width,
             max_num_tokens=executor_config.max_num_tokens,
@@ -250,8 +250,8 @@ def create_py_executor(
             draft_spec_config.max_draft_tokens = 0
 
             draft_model_engine = PyTorchModelEngine(
-                spec_config.draft_model_path,
-                pytorch_backend_config,
+                model_path=spec_config.draft_model_path,
+                pytorch_backend_config=pytorch_backend_config,
                 batch_size=executor_config.max_batch_size,
                 max_beam_width=executor_config.max_beam_width,
                 max_num_tokens=executor_config.max_num_tokens,
@@ -358,24 +358,36 @@ def create_py_executor(
             if estimating_kv_cache else _ExecutorCreationStage.KV_CACHE):
         kv_cache_creator.build_managers(resources)
 
+    # Drafter for speculative decoding
+    with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
+        drafter = get_spec_drafter(model_engine)
+
     # Resource managers for speculative decoding
     spec_resource_manager = get_spec_resource_manager(model_engine,
-                                                      draft_model_engine)
+                                                      draft_model_engine,
+                                                      drafter)
     if spec_resource_manager is not None:
         resources[
             ResourceManagerType.SPEC_RESOURCE_MANAGER] = spec_resource_manager
 
-    # Drafter for speculative decoding
-    with mem_monitor.observe_creation_stage(_ExecutorCreationStage.DRAFTER):
-        drafter = get_spec_drafter(model_engine, spec_resource_manager)
-
     with mem_monitor.observe_creation_stage(
             _ExecutorCreationStage.INIT_EXTRA_RESOURCES
             if estimating_kv_cache else _ExecutorCreationStage.EXTRA_RESOURCES):
         py_executor = create_py_executor_instance(
-            dist, resources, mapping, pytorch_backend_config, executor_config,
-            ctx_chunk_config, model_engine, draft_model_engine, False, sampler,
-            drafter, lora_config, garbage_collection_gen0_threshold)
+            dist=dist,
+            resources=resources,
+            mapping=mapping,
+            pytorch_backend_config=pytorch_backend_config,
+            executor_config=executor_config,
+            ctx_chunk_config=ctx_chunk_config,
+            model_engine=model_engine,
+            draft_model_engine=draft_model_engine,
+            start_worker=False,
+            sampler=sampler,
+            drafter=drafter,
+            lora_config=lora_config,
+            garbage_collection_gen0_threshold=garbage_collection_gen0_threshold,
+        )
 
     if estimating_kv_cache:
         assert kv_cache_creator is not None
@@ -404,10 +416,21 @@ def create_py_executor(
         with mem_monitor.observe_creation_stage(
                 _ExecutorCreationStage.EXTRA_RESOURCES):
             py_executor = create_py_executor_instance(
-                dist, resources, mapping, pytorch_backend_config,
-                executor_config, ctx_chunk_config, model_engine,
-                draft_model_engine, False, sampler, drafter, lora_config,
-                garbage_collection_gen0_threshold)
+                dist=dist,
+                resources=resources,
+                mapping=mapping,
+                pytorch_backend_config=pytorch_backend_config,
+                executor_config=executor_config,
+                ctx_chunk_config=ctx_chunk_config,
+                model_engine=model_engine,
+                draft_model_engine=draft_model_engine,
+                start_worker=False,
+                sampler=sampler,
+                drafter=drafter,
+                lora_config=lora_config,
+                garbage_collection_gen0_threshold=
+                garbage_collection_gen0_threshold,
+            )
 
     py_executor.start_worker()
     return py_executor
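
Note the reordering above: the drafter is now created before the speculative resource manager, because for the NGram and user-provided modes the manager is owned by the drafter (utils.py below returns `drafter.spec_resource_manager`), rather than the drafter being created from a pre-built manager as before.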

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 from .interface import SpecConfig, SpecMetadata
 from .mtp import MTPConfig, MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramConfig, NGramDrafter, NGramPoolManager
+from .user_provided import UserProvidedConfig
 from .utils import (get_num_spec_layers, get_spec_decoder, get_spec_drafter,
                     get_spec_metadata, get_spec_resource_manager,
                     get_spec_worker)
@@ -20,6 +21,7 @@
     "NGramPoolManager",
     "SpecConfig",
     "SpecMetadata",
+    "UserProvidedConfig",
     "get_num_spec_layers",
     "get_spec_decoder",
     "get_spec_drafter",

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 0 additions & 2 deletions

@@ -2,7 +2,6 @@
 from typing import Optional
 
 from ..pyexecutor.resource_manager import BaseResourceManager
-from ..pyexecutor.sampler import SampleState
 from ..pyexecutor.scheduler import ScheduledRequests
 
 
@@ -18,7 +17,6 @@ def __init__(
     def prepare_draft_tokens(
         self,
         scheduled_requests: ScheduledRequests,
-        state: SampleState,
     ) -> None:
         """
         Prepare the drafter tokens for the forward computation this step.
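
After this change, implementing a custom drafter only requires the scheduled requests; there is no sampler state to thread through. A minimal sketch of a user-supplied drafter, assuming `Drafter.__init__` accepts an optional resource manager (as the `drafter.spec_resource_manager` access in utils.py suggests) and that requests expose a `py_draft_tokens` list; the constant-token logic is purely illustrative:

from typing import Optional

from tensorrt_llm._torch.pyexecutor.resource_manager import BaseResourceManager
from tensorrt_llm._torch.pyexecutor.scheduler import ScheduledRequests
from tensorrt_llm._torch.speculative.drafter import Drafter


class ConstantDrafter(Drafter):
    """Toy drafter: proposes the same draft token for every generation request."""

    def __init__(self,
                 token_id: int = 0,
                 spec_resource_manager: Optional[BaseResourceManager] = None):
        super().__init__(spec_resource_manager)
        self.token_id = token_id

    def prepare_draft_tokens(
        self,
        scheduled_requests: ScheduledRequests,
    ) -> None:
        for request in scheduled_requests.generation_requests:
            # Attach the draft tokens consumed by the next forward step.
            request.py_draft_tokens = [self.token_id]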

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 10 additions & 2 deletions

@@ -17,6 +17,7 @@ class SpeculativeDecodingMode(IntEnum):
     EAGLE3_ONE_MODEL = auto()
     NGRAM = auto()
     DRAFT_TARGET = auto()
+    USER_PROVIDED = auto()
     NONE = auto()
 
     def is_mtp(self):
@@ -37,6 +38,9 @@ def is_eagle3_one_model(self):
     def is_ngram(self):
         return self == SpeculativeDecodingMode.NGRAM
 
+    def is_user_provided(self):
+        return self == SpeculativeDecodingMode.USER_PROVIDED
+
     def is_none(self):
         return self == SpeculativeDecodingMode.NONE
 
@@ -74,7 +78,7 @@ def has_spec_decoder(self):
         return self.is_mtp() or self.is_eagle3() or self.is_eagle3_one_model()
 
     def has_spec_drafter(self):
-        return self.is_ngram()
+        return self.is_ngram() or self.is_user_provided()
 
     def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """
@@ -86,7 +90,8 @@ def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         # Fixme: only trtllm attention backend supports eagle3 generation-phase kernels on blackwell.
         return ((self.is_eagle3() or self.is_draft_target())
                 and not (isinstance(attention_backend, TrtllmAttention)
-                         and get_sm_version() == 100)) or self.is_ngram()
+                         and get_sm_version() == 100)
+                ) or self.is_ngram() or self.is_user_provided()
 
     def attention_need_spec_dec_mode(self):
         """
@@ -185,6 +190,9 @@ class SpecMetadata:
     # if spec-dec tree wouldn't be changed at all, the mask won't be computed every step.
     is_spec_dec_dynamic_tree: bool = False
 
+    def __post_init__(self):
+        pass
+
     def prepare(self):
         """
         Hook to be called before the forward step of the model.
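
The new enum member plugs into the same predicate helpers as NGram, so the executor treats both modes uniformly. A quick illustration (`from_string` is the same lookup that user_provided.py below relies on):

mode = SpeculativeDecodingMode.from_string("USER_PROVIDED")
assert mode.is_user_provided()
assert mode.has_spec_drafter()      # a Drafter runs in the executor loop
assert not mode.has_spec_decoder()  # no dedicated speculative decoder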

tensorrt_llm/_torch/speculative/ngram.py

Lines changed: 1 addition & 16 deletions

@@ -7,10 +7,9 @@
 
 from ..pyexecutor.llm_request import *
 from ..pyexecutor.resource_manager import BaseResourceManager
-from ..pyexecutor.sampler import SampleState
 from ..pyexecutor.scheduler import ScheduledRequests
 from .drafter import Drafter
-from .interface import SpecConfig, SpecMetadata, SpeculativeDecodingMode
+from .interface import SpecConfig, SpeculativeDecodingMode
 
 
 @dataclass
@@ -40,16 +39,6 @@ def update_from_model_config(self, model_config):
         pass
 
 
-@dataclass
-class NGramSpecMetadata(SpecMetadata):
-    """
-    Metadata for NGram.
-    """
-
-    def __post_init__(self) -> None:
-        return
-
-
 class NGramPoolManager(BaseResourceManager):
     """
     Drafter for NGram. This class maintains the pattern-matches pairs for NGram drafter.
@@ -212,12 +201,8 @@ def __init__(
     def prepare_draft_tokens(
         self,
         scheduled_requests: ScheduledRequests,
-        state: SampleState,
     ) -> None:
 
-        if state is None:  # Skip the first step
-            return
-
         for request in sorted(scheduled_requests.generation_requests,
                               key=lambda r: r.py_batch_idx):
             # Add new token to a copy of the generated tokens to find new draft tokens
tensorrt_llm/_torch/speculative/user_provided.py

Lines changed: 26 additions & 0 deletions

@@ -0,0 +1,26 @@
+from dataclasses import dataclass
+from typing import Optional
+
+from tensorrt_llm._torch.speculative.drafter import Drafter
+
+from .interface import SpecConfig, SpeculativeDecodingMode
+
+
+@dataclass
+class UserProvidedConfig(SpecConfig):
+    """
+    Configuration for user provided speculative decoding.
+    """
+    # The name of speculative decoding.
+    spec_dec_name = "USER_PROVIDED"
+
+    num_extra_kv_tokens: int = 0
+    max_draft_tokens: int = 0
+    drafter: Optional[Drafter] = None
+
+    def __post_init__(self) -> None:
+        self.spec_dec_mode = SpeculativeDecodingMode.from_string(
+            self.spec_dec_name)
+
+    def update_from_model_config(self, model_config):
+        pass
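
Putting it together, a user hands their drafter to the executor through this config, and `get_spec_drafter` (utils.py below) simply returns it. A hedged usage sketch, reusing the illustrative `ConstantDrafter` from above:

from tensorrt_llm._torch.speculative import UserProvidedConfig

spec_config = UserProvidedConfig(
    max_draft_tokens=1,  # budget matching what the drafter attaches per step
    drafter=ConstantDrafter(token_id=42),
)
assert spec_config.spec_dec_mode.is_user_provided()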

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 17 additions & 9 deletions

@@ -1,13 +1,13 @@
 from tensorrt_llm._torch.pyexecutor.sampler import TorchSampler
-from tensorrt_llm._torch.speculative.interface import SpecConfig
+from tensorrt_llm._torch.speculative.interface import SpecConfig, SpecMetadata
 
 from .draft_target import DraftTargetSpecMetadata
 from .eagle3 import (Eagle3OneModelSampler, Eagle3OneModelSpecMetadata,
                      Eagle3OneModelWorker, Eagle3ResourceManager,
                      Eagle3SpecMetadata)
 from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                   MTPSpecMetadata, MTPWorker)
-from .ngram import NGramDrafter, NGramPoolManager, NGramSpecMetadata
+from .ngram import NGramDrafter, NGramPoolManager
 
 
 def get_spec_metadata(spec_config,
@@ -50,16 +50,19 @@ def get_spec_metadata(spec_config,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
         )
-    if spec_config.spec_dec_mode.is_ngram():
-        return NGramSpecMetadata(
+    if spec_config.spec_dec_mode.is_ngram(
+    ) or spec_config.spec_dec_mode.is_user_provided():
+        return SpecMetadata(
             max_draft_tokens=spec_config.max_draft_tokens,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
         )
     return None
 
 
-def get_spec_resource_manager(model_engine, draft_model_engine=None):
+def get_spec_resource_manager(model_engine,
+                              draft_model_engine=None,
+                              drafter=None):
     spec_config = model_engine.spec_config
     if spec_config is None:
         return None
@@ -95,8 +98,9 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
             max_seq_len,
             max_num_tokens,
         )
-    if spec_dec_mode.is_ngram():
-        return NGramPoolManager(spec_config, max_num_requests)
+    if spec_dec_mode.is_ngram() or spec_dec_mode.is_user_provided():
+        assert drafter is not None, "Drafter is required for ngram or user provided speculative decoding."
+        return drafter.spec_resource_manager
     return None
 
 
@@ -113,12 +117,16 @@ def get_spec_decoder(sampler_args: TorchSampler.Args, spec_config: SpecConfig):
         f"Unsupported speculative decoding mode: {spec_config.spec_dec_mode}")
 
 
-def get_spec_drafter(model_engine, spec_resource_manager=None):
+def get_spec_drafter(model_engine):
     spec_config = model_engine.spec_config
+    max_num_requests = model_engine.batch_size
     if spec_config is None:
         return None
     if spec_config.spec_dec_mode.is_ngram():
-        return NGramDrafter(spec_config, spec_resource_manager)
+        return NGramDrafter(spec_config,
+                            NGramPoolManager(spec_config, max_num_requests))
+    if spec_config.spec_dec_mode.is_user_provided():
+        return spec_config.drafter
     return None
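
These helpers now give NGram and user-provided modes the same shape: the drafter is built first (NGram constructs its own `NGramPoolManager`; user-provided comes straight from the config), and the resource manager is then read off the drafter. A condensed sketch of the dispatch, with `engine` standing in as a hypothetical model engine whose `spec_config` is the `UserProvidedConfig` above:

drafter = get_spec_drafter(engine)  # returns spec_config.drafter unchanged
manager = get_spec_resource_manager(engine, drafter=drafter)
# manager is drafter.spec_resource_manager (None for the toy ConstantDrafter),
# in which case py_executor_creator.py registers no extra resources for it.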
