@@ -380,6 +380,7 @@ def from_dict(cls, data: dict):
380380 "Lookahead" : LookaheadDecodingConfig ,
381381 "NGram" : NGramDecodingConfig ,
382382 "DraftTarget" : DraftTargetDecodingConfig ,
383+ "SaveState" : SaveHiddenStatesDecodingConfig ,
383384 "UserProvided" : UserProvidedDecodingConfig ,
384385 "AUTO" : AutoDecodingConfig ,
385386 }
@@ -562,6 +563,52 @@ def num_capture_layers(self) -> int:
562563 return 3
563564
564565
class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
    """Config for the "SaveState" mode: capture target-model hidden states.

    ``write_interval`` and ``file_prefix`` control how captured states are
    persisted under ``output_directory`` (exact file layout is handled by the
    runtime, not here).  Drafting is effectively disabled in this mode, so
    ``max_total_draft_tokens`` and ``eagle_choices`` are fixed via
    ``init=False`` fields.
    """

    output_directory: str
    write_interval: int = 20
    file_prefix: str = "data"
    eagle3_layers_to_capture: Optional[Set[int]] = None

    # Not user-settable: this mode performs no real speculation.
    max_total_draft_tokens: Optional[int] = Field(default=1, init=False)
    eagle_choices: Optional[List[List[int]]] = Field(default=None, init=False)

    decoding_type: ClassVar[str] = "SaveState"

    def model_post_init(self, __context):
        # Record whether the post-norm last hidden state (index -1) was
        # explicitly requested; when a capture set exists without it, add it
        # on the caller's behalf.
        capture_set = self.eagle3_layers_to_capture
        if capture_set is None:
            self._last_hidden_in_save = False
        elif -1 in capture_set:
            self._last_hidden_in_save = True
        else:
            self._last_hidden_in_save = False
            capture_set.add(-1)

    @classmethod
    def from_dict(cls, data: dict):
        """Build a config directly from a plain keyword dict."""
        return cls(**data)

    def validate(self) -> None:
        # A destination directory and a non-empty capture set are both
        # required before this config may be used.
        if self.output_directory is None or not self.eagle3_layers_to_capture:
            raise ValueError(
                "Save directory and layers to capture must be provided")

    @functools.cached_property
    def spec_dec_mode(self):
        # Imported lazily to avoid a circular import at module load time.
        from tensorrt_llm._torch.speculative.interface import \
            SpeculativeDecodingMode as TorchSpeculativeDecodingMode
        return TorchSpeculativeDecodingMode.SAVE_HIDDEN_STATES

    @functools.cached_property
    def num_capture_layers(self):
        """
        Returns the number of layers to capture of the target model.
        If eagle3_layers_to_capture is not None, return the length of the set.
        Otherwise, assume Eagle3 base set and return 3 + 1 (for post norm last hidden state).
        """
        if self.eagle3_layers_to_capture is None:
            return 4
        return len(self.eagle3_layers_to_capture)
610+
611+
565612class UserProvidedDecodingConfig (DecodingBaseConfig ):
566613 # Cannot use real type annotations due to circular imports
567614 drafter : object # Type is Drafter
@@ -1050,6 +1097,7 @@ def supports_backend(self, backend: str) -> bool:
10501097 MTPDecodingConfig ,
10511098 NGramDecodingConfig ,
10521099 UserProvidedDecodingConfig ,
1100+ SaveHiddenStatesDecodingConfig ,
10531101 AutoDecodingConfig ,
10541102]]
10551103
@@ -1869,6 +1917,20 @@ def validate_speculative_config(self):
18691917 self .build_config .speculative_decoding_mode = SpeculativeDecodingMode .AUTO
18701918 self .build_config .max_draft_len = self .speculative_config .max_draft_len
18711919
1920+ elif isinstance (self .speculative_config ,
1921+ SaveHiddenStatesDecodingConfig ):
1922+ assert self .backend in ['pytorch' ]
1923+ logger .warning (
1924+ "SaveHiddenStatesDecodingConfig is active, setting max_batch_size to 1, disabling overlap scheduler, and setting cuda_graph_config to None"
1925+ )
1926+ self .build_config .max_batch_size = 1
1927+ self .max_batch_size = 1
1928+ self .disable_overlap_scheduler = True
1929+ self .cuda_graph_config = None
1930+ self .build_config .speculative_decoding_mode = SpeculativeDecodingMode .SAVE_HIDDEN_STATES
1931+ self .build_config .max_draft_len = 1
1932+ self .speculative_config .max_draft_len = 1
1933+
18721934 else :
18731935 raise ValueError (
18741936 f"Unrecognized speculative config type { type (self .speculative_config )} "
0 commit comments