
Commit 1ad7bc4

[None][feat] Draft: Save state first pass (#7012)
Signed-off-by: Izzy Putterman <[email protected]>
1 parent e107749 commit 1ad7bc4

File tree

12 files changed: +367 -4 lines changed

tensorrt_llm/_torch/pyexecutor/py_executor.py

Lines changed: 4 additions & 0 deletions

@@ -1110,6 +1110,10 @@ def _executor_loop(self):

                 sample_state = self._sample_async(scheduled_batch,
                                                   batch_outputs)
+                if self.drafter is not None:
+                    self.drafter.run_drafter_post(scheduled_batch,
+                                                  self.resource_manager,
+                                                  self.is_warmup)

                 self._update_request_states(scheduled_batch)
                 self._update_requests(sample_state, self.resource_manager)

tensorrt_llm/_torch/speculative/__init__.py

Lines changed: 2 additions & 0 deletions

@@ -3,6 +3,7 @@
 from .interface import SpecMetadata
 from .mtp import MTPEagleWorker, MTPSpecMetadata, MTPWorker
 from .ngram import NGramDrafter, NGramPoolManager
+from .save_hidden_state import SaveHiddenStatesDrafter
 from .spec_tree_manager import SpecTreeManager
 from .utils import (get_num_extra_kv_tokens, get_num_spec_layers,
                     get_spec_decoder, get_spec_drafter, get_spec_metadata,
@@ -16,6 +17,7 @@
     "MTPWorker",
     "NGramDrafter",
     "NGramPoolManager",
+    "SaveHiddenStatesDrafter",
     "SpecMetadata",
     "get_num_extra_kv_tokens",
     "get_num_spec_layers",

tensorrt_llm/_torch/speculative/drafter.py

Lines changed: 12 additions & 0 deletions

@@ -67,3 +67,15 @@ def pad_draft_tokens_for_cuda_graph(
             num_draft_tokens = get_draft_token_length(req)
             req.py_draft_tokens.extend(
                 0 for _ in range(max_draft_tokens - num_draft_tokens))
+
+    def run_drafter_post(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+        is_warmup: bool = False,
+    ) -> None:
+        """
+        If the draft forward pass needs to run directly after the target
+        model forward, this method can be overridden to do that.
+        Used by SaveHiddenStatesDrafter (to ensure correct input_ids).
+        """

tensorrt_llm/_torch/speculative/eagle3.py

Lines changed: 4 additions & 0 deletions

@@ -126,6 +126,10 @@ def __post_init__(self):
                                             self.num_layers - 4)
         else:
             self.layers_to_capture = sorted(list(self.layers_to_capture))
+            if self.layers_to_capture[0] == -1:
+                self.layers_to_capture = self.layers_to_capture[1:] + [
+                    self.layers_to_capture.pop(0)
+                ]
         self.num_capture_layers = len(self.layers_to_capture)

         # Initialize to 0 to avoid reading uninitialized memory during warmup
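
`sorted()` places `-1` (the sentinel for the post-norm last hidden state) at the front of the list; the new branch rotates it to the end, which appears to line up with the `[:, -hidden_size:]` slice the save path takes later. A quick trace with an illustrative capture set (values are placeholders, not from the commit):

layers_to_capture = sorted([28, 1, -1, 15])  # -> [-1, 1, 15, 28]
if layers_to_capture[0] == -1:
    # The slice is built before pop(0) runs, so this appends -1 to the rest.
    layers_to_capture = layers_to_capture[1:] + [layers_to_capture.pop(0)]
assert layers_to_capture == [1, 15, 28, -1]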

tensorrt_llm/_torch/speculative/interface.py

Lines changed: 7 additions & 2 deletions

@@ -19,6 +19,7 @@ class SpeculativeDecodingMode(IntEnum):
     NGRAM = auto()
     DRAFT_TARGET = auto()
     USER_PROVIDED = auto()
+    SAVE_HIDDEN_STATES = auto()
     NONE = auto()
     AUTO = auto()

@@ -55,6 +56,9 @@ def is_none(self):
     def is_draft_target(self):
         return self == SpeculativeDecodingMode.DRAFT_TARGET

+    def is_save_hidden_states(self):
+        return self == SpeculativeDecodingMode.SAVE_HIDDEN_STATES
+
     def without_logits(self):
         return self.is_mtp_one_model() or self.is_eagle3_one_model()

@@ -95,8 +99,9 @@ def has_spec_decoder(self):
         ) or self.is_eagle3_one_model()

     def has_spec_drafter(self):
-        return self.is_eagle3() or self.is_draft_target() or self.is_ngram(
-        ) or self.is_user_provided() or self.is_mtp_eagle()
+        return self.is_eagle3(
+        ) or self.is_draft_target() or self.is_ngram() or self.is_user_provided(
+        ) or self.is_mtp_eagle() or self.is_save_hidden_states()

     def extend_ctx(self, attention_backend: Type[AttentionBackend]):
         """

tensorrt_llm/_torch/speculative/save_hidden_state.py

Lines changed: 99 additions & 0 deletions

@@ -0,0 +1,99 @@
+import os
+from typing import Optional
+
+import torch
+
+from tensorrt_llm._utils import local_mpi_rank
+
+from ..pyexecutor.llm_request import LlmRequest
+from ..pyexecutor.resource_manager import ResourceManager
+from ..pyexecutor.scheduler import ScheduledRequests
+from .drafter import Drafter
+
+
+class SaveHiddenStatesDrafter(Drafter):
+
+    def __init__(
+        self,
+        spec_config: "SaveHiddenStatesDecodingConfig",
+        spec_resource_manager,
+    ):
+        super().__init__(spec_config.max_concurrency)
+        self.spec_config = spec_config
+        self.max_draft_len = spec_config.max_draft_len
+        self._iter = 1
+        self._output_directory = spec_config.output_directory
+        self._file_prefix = spec_config.file_prefix
+        self._write_interval = spec_config.write_interval
+        self._saved_state = []
+        self.spec_resource_manager = spec_resource_manager
+        os.makedirs(self._output_directory, exist_ok=True)
+
+    def _process_request(self, request: LlmRequest, resource_manager) -> None:
+        out_dict = {}
+        if local_mpi_rank() == 0:
+            input_ids = torch.tensor(list(request.get_tokens(0)),
+                                     dtype=torch.long,
+                                     device='cpu')
+            hidden_size = resource_manager.hidden_size
+            num_tokens = input_ids.shape[0]
+            hidden_states = resource_manager.hidden_states[
+                :num_tokens, -hidden_size:].cpu().clone()
+
+            out_dict = {
+                "id": self._iter,
+                "input_ids": input_ids,
+                "hidden_state": hidden_states,
+            }
+            if len(self.spec_config.eagle3_layers_to_capture) > 1:
+                if self.spec_config._last_hidden_in_save:
+                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
+                        :num_tokens, :].cpu().clone()
+                else:
+                    out_dict["aux_hidden_states"] = resource_manager.hidden_states[
+                        :num_tokens, :-hidden_size].cpu().clone()
+
+        self._saved_state.append(out_dict)
+
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0:
+            output_path = os.path.join(self._output_directory,
                                       f"{self._file_prefix}_{self._iter}.pt")
+            torch.save(self._saved_state, output_path)
+        self._saved_state = []
+
+    def prepare_draft_tokens(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+    ) -> None:
+        for request in sorted(
+                scheduled_requests.context_requests,
+                key=lambda r: (r.py_batch_idx is None,
+                               r.py_batch_idx or r.request_id),
+        ):
+            request.py_max_new_tokens = 1
+
+    def run_drafter_post(
+        self,
+        scheduled_requests: ScheduledRequests,
+        resource_manager: Optional[ResourceManager] = None,
+        is_warmup: bool = False,
+    ) -> None:
+        if is_warmup:
+            return
+        for request in sorted(
+                scheduled_requests.context_requests,
+                key=lambda r: (r.py_batch_idx is None,
+                               r.py_batch_idx or r.request_id),
+        ):
+            self._process_request(request, self.spec_resource_manager)
+        if self._iter % self._write_interval == 0:
+            self._write_to_file()
+        self._iter += 1
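
Each shard is an ordinary `torch.save` of a list of per-request dicts, so dumps can be inspected offline. A minimal sketch of reading one back, assuming an `output_directory` of `hidden_states` and the default `file_prefix="data"` and `write_interval=20` (with `_iter` starting at 1, the first shard is written at iteration 20):

import torch

# Assumes output_directory="hidden_states" and the defaults noted above.
records = torch.load("hidden_states/data_20.pt")
for rec in records:
    print(rec["id"], rec["input_ids"].shape, rec["hidden_state"].shape)
    # Present only when more than one layer is captured:
    if "aux_hidden_states" in rec:
        print(rec["aux_hidden_states"].shape)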

tensorrt_llm/_torch/speculative/utils.py

Lines changed: 32 additions & 0 deletions

@@ -11,6 +11,7 @@
 from .mtp import (MTPEagleWorker, MTPHiddenStatesManager, MTPSampler,
                   MTPSpecMetadata, MTPWorker)
 from .ngram import NGramDrafter, NGramPoolManager
+from .save_hidden_state import SaveHiddenStatesDrafter


 def get_spec_metadata(spec_config,
@@ -55,6 +56,25 @@ def get_spec_metadata(spec_config,
             max_num_tokens=max_num_tokens,
             layers_to_capture=spec_config.eagle3_layers_to_capture,
         )
+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        if spec_config.eagle3_layers_to_capture is None:
+            spec_config.eagle3_layers_to_capture = {
+                1, model_config.num_hidden_layers // 2 - 1,
+                model_config.num_hidden_layers - 4, -1
+            }
+        return Eagle3SpecMetadata(
+            max_draft_len=spec_config.max_draft_len,
+            spec_dec_mode=spec_config.spec_dec_mode,
+            max_num_requests=max_num_requests,
+            num_layers=model_config.num_hidden_layers,
+            hidden_size=model_config.hidden_size,
+            max_num_tokens=max_num_tokens,
+            dtype=model_config.torch_dtype,
+            is_draft_model=is_draft_model,
+            eagle3_resource_manager=spec_resource_manager,
+            layers_to_capture=spec_config.eagle3_layers_to_capture,
+            max_total_draft_tokens=1,
+        )
     if spec_config.spec_dec_mode.is_draft_target() or \
        spec_config.spec_dec_mode.is_ngram() or \
        spec_config.spec_dec_mode.is_user_provided():
@@ -102,6 +122,15 @@ def get_spec_resource_manager(model_engine, draft_model_engine=None):
             max_seq_len,
             max_num_tokens,
         )
+    if spec_dec_mode.is_save_hidden_states():
+        return Eagle3ResourceManager(
+            spec_config,
+            model_engine.model.config.torch_dtype,
+            model_config.hidden_size,
+            max_num_requests,
+            max_seq_len,
+            max_num_tokens,
+        )
     if spec_dec_mode.is_ngram():
         return NGramPoolManager(spec_config, max_num_requests)
     if spec_dec_mode.is_user_provided():
@@ -151,6 +180,9 @@ def get_spec_drafter(model_engine,
     if spec_config.spec_dec_mode.is_ngram():
         return NGramDrafter(spec_config, spec_resource_manager)

+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)
+
     return None
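
If no capture set is supplied, the fallback mirrors the Eagle3 defaults and adds `-1` for the final post-norm hidden state. For a hypothetical 32-layer target model the set works out to:

num_hidden_layers = 32  # illustrative, not from the commit
layers = {1, num_hidden_layers // 2 - 1, num_hidden_layers - 4, -1}
print(sorted(layers))  # [-1, 1, 15, 28]: layers 1, 15, 28 plus the post-norm output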

tensorrt_llm/llmapi/__init__.py

Lines changed: 3 additions & 1 deletion

@@ -11,7 +11,8 @@
                        DynamicBatchConfig, EagleDecodingConfig,
                        ExtendedRuntimePerfKnobConfig, KvCacheConfig, LlmArgs,
                        LookaheadDecodingConfig, MedusaDecodingConfig, MoeConfig,
-                       MTPDecodingConfig, NGramDecodingConfig, SchedulerConfig,
+                       MTPDecodingConfig, NGramDecodingConfig,
+                       SaveHiddenStatesDecodingConfig, SchedulerConfig,
                        TorchCompileConfig, TorchLlmArgs, TrtLlmArgs,
                        UserProvidedDecodingConfig)
 from .llm_utils import (BuildConfig, KvCacheRetentionConfig, QuantAlgo,
@@ -59,4 +60,5 @@
     'AutoDecodingConfig',
     'AttentionDpConfig',
     'LoRARequest',
+    'SaveHiddenStatesDecodingConfig',
 ]

tensorrt_llm/llmapi/llm_args.py

Lines changed: 62 additions & 0 deletions

@@ -380,6 +380,7 @@ def from_dict(cls, data: dict):
         "Lookahead": LookaheadDecodingConfig,
         "NGram": NGramDecodingConfig,
         "DraftTarget": DraftTargetDecodingConfig,
+        "SaveState": SaveHiddenStatesDecodingConfig,
         "UserProvided": UserProvidedDecodingConfig,
         "AUTO": AutoDecodingConfig,
     }
@@ -562,6 +563,52 @@ def num_capture_layers(self) -> int:
         return 3


+class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
+    output_directory: str
+    write_interval: int = 20
+    file_prefix: str = "data"
+    eagle3_layers_to_capture: Optional[Set[int]] = None
+
+    max_total_draft_tokens: Optional[int] = Field(default=1, init=False)
+    eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False)
+
+    def model_post_init(self, __context):
+        self._last_hidden_in_save = True
+        if self.eagle3_layers_to_capture is None:
+            self._last_hidden_in_save = False
+        elif -1 not in self.eagle3_layers_to_capture:
+            self._last_hidden_in_save = False
+            self.eagle3_layers_to_capture.add(-1)
+
+    @classmethod
+    def from_dict(cls, data: dict):
+        return cls(**data)
+
+    decoding_type: ClassVar[str] = "SaveState"
+
+    def validate(self) -> None:
+        if self.output_directory is None or not self.eagle3_layers_to_capture:
+            raise ValueError(
+                "Save directory and layers to capture must be provided")
+
+    @functools.cached_property
+    def spec_dec_mode(self):
+        from tensorrt_llm._torch.speculative.interface import \
+            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
+        return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES
+
+    @functools.cached_property
+    def num_capture_layers(self):
+        """
+        Returns the number of target-model layers to capture.
+        If eagle3_layers_to_capture is not None, return the size of the set.
+        Otherwise, assume the Eagle3 base set and return 3 + 1 (for the
+        post-norm last hidden state).
+        """
+        if self.eagle3_layers_to_capture is None:
+            return 4
+        return len(self.eagle3_layers_to_capture)
+
+
 class UserProvidedDecodingConfig(DecodingBaseConfig):
     # Cannot use real type annotations due to circular imports
     drafter: object  # Type is Drafter
@@ -1050,6 +1097,7 @@ def supports_backend(self, backend: str) -> bool:
         MTPDecodingConfig,
         NGramDecodingConfig,
         UserProvidedDecodingConfig,
+        SaveHiddenStatesDecodingConfig,
         AutoDecodingConfig,
     ]]

@@ -1869,6 +1917,20 @@ def validate_speculative_config(self):
             self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.AUTO
             self.build_config.max_draft_len = self.speculative_config.max_draft_len

+        elif isinstance(self.speculative_config,
+                        SaveHiddenStatesDecodingConfig):
+            assert self.backend in ['pytorch']
+            logger.warning(
+                "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None"
+            )
+            self.build_config.max_batch_size = 1
+            self.max_batch_size = 1
+            self.disable_overlap_scheduler = True
+            self.cuda_graph_config = None
+            self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
+            self.build_config.max_draft_len = 1
+            self.speculative_config.max_draft_len = 1
+
         else:
             raise ValueError(
                 f"Unrecognized speculative config type {type(self.speculative_config)}"

tensorrt_llm/models/modeling_utils.py

Lines changed: 3 additions & 0 deletions

@@ -98,6 +98,7 @@ class SpeculativeDecodingMode(IntFlag):
     EAGLE = auto()
     NGRAM = auto()
     USER_PROVIDED = auto()
+    SAVE_HIDDEN_STATES = auto()
     AUTO = auto()

     @staticmethod
@@ -120,6 +121,8 @@ def from_arguments(args: argparse.Namespace):
         return SpeculativeDecodingMode.USER_PROVIDED
     elif args.speculative_decoding_mode == "auto":
         return SpeculativeDecodingMode.AUTO
+    elif args.speculative_decoding_mode == "save_hidden_states":
+        return SpeculativeDecodingMode.SAVE_HIDDEN_STATES
     else:
         assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode