Conversation

@IzzyPutterman (Collaborator) commented Aug 18, 2025

Summary by CodeRabbit

  • New Features

    • Adds a speculative decoding mode that saves final and selectable-layer hidden states, with periodic per-request exports and config options for the output directory, file prefix, write interval, layers to capture, and last-layer post-norm capture (see the usage sketch after this list).
    • Model now conditionally captures final hidden states when that mode is active.
  • Behavior Changes

    • Selecting this mode requires the PyTorch backend, forces max batch size = 1, fixes draft length = 1, and disables the overlap scheduler.
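A minimal usage sketch, assuming the SaveHiddenStatesDecodingConfig fields introduced by this PR and its export from tensorrt_llm.llmapi; the model path and layer indices are hypothetical:

```python
from tensorrt_llm.llmapi import LLM, SaveHiddenStatesDecodingConfig

spec_config = SaveHiddenStatesDecodingConfig(
    output_directory="/tmp/hidden_states",  # required
    write_interval=20,                      # flush every N iterations
    file_prefix="data",                     # files are named <prefix>_<iter>.pt
    eagle3_layers_to_capture={1, 15, 28},   # hypothetical layer indices
    save_last_layer_post_norm=True,         # also capture the post-norm last layer
)

# Selecting this mode forces the PyTorch backend, max_batch_size=1,
# max_draft_len=1, and disables the overlap scheduler and CUDA graphs.
llm = LLM(model="/path/to/model",           # hypothetical model path
          speculative_config=spec_config)
```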

Description

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provides a user-friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline, or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will always be ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running the L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purposes. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.
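For example, a comment that reruns a single stage with fail fast disabled might look like this (stage name illustrative):

```
/bot run --stage-list "A10-PyTorch-1" --disable-fail-fast
```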

kill

kill

Kill all running builds associated with the pull request.

skip

skip --comment COMMENT

Skip testing for the latest commit on the pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate the current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous, since lack of user care and validation can cause the top of tree to break.

@coderabbitai bot (Contributor) commented Aug 18, 2025

Caution

Review failed

The head commit changed during the review from 373f82f to 75a598f.

📝 Walkthrough

Walkthrough

Adds a SAVE_HIDDEN_STATES speculative decoding mode with config, SpecMetadata/resource-manager/drafter implementations to persist final hidden states, wires a conditional final-hidden-state capture hook into the model forward, and updates utilities and public exports to construct and validate the new mode.

Changes

| Cohort / File(s) | Summary |
| --- | --- |
| Model forward hook<br>`tensorrt_llm/_torch/models/modeling_speculative.py` | Call `spec_metadata.maybe_capture_final_hidden_states(hidden_states)` when `spec_metadata` is present and `spec_metadata.is_final_output_capture()` is true; no change to the return path. |
| Speculative interface extensions<br>`tensorrt_llm/_torch/speculative/interface.py` | Add `SpeculativeDecodingMode.SAVE_HIDDEN_STATES` and `is_save_hidden_states()`; add `SpecMetadata.is_final_output_capture()` and the no-op hook `maybe_capture_final_hidden_states(...)`; treat SAVE_HIDDEN_STATES as a valid drafter mode. |
| SaveHiddenStates implementation<br>`tensorrt_llm/_torch/speculative/save_hidden_state.py` | New classes: `SaveHiddenStatesSpecMetadata`, `SaveHiddenStatesResourceManager` (optional CUDA `last_hidden_states` allocation), and `SaveHiddenStatesDrafter` (collects per-request states, pads draft tokens, periodic root-rank writes to files). |
| Utils wiring for new mode<br>`tensorrt_llm/_torch/speculative/utils.py` | Import the new SaveHiddenStates types and add branches in `get_spec_metadata`, `get_spec_resource_manager`, and `get_spec_drafter` to construct SaveHiddenStates variants when the mode is SAVE_HIDDEN_STATES. |
| LLM API config and validation<br>`tensorrt_llm/llmapi/llm_args.py` | Add `SaveHiddenStatesDecodingConfig` (`output_directory`, `write_interval`, `file_prefix`, `eagle3_layers_to_capture`, `save_last_layer_post_norm`), include it in the dispatch/union, and add validation forcing the PyTorch backend, batch size 1, disabled overlap scheduler/CUDA graphs, mode SAVE_HIDDEN_STATES, and `max_draft_len=1`. |
| Public exports<br>`tensorrt_llm/_torch/speculative/__init__.py`, `tensorrt_llm/llmapi/__init__.py` | Export `SaveHiddenStatesDrafter`, `SaveHiddenStatesSpecMetadata`, and `SaveHiddenStatesDecodingConfig` via module `__all__`. |
| Modeling utils enum<br>`tensorrt_llm/models/modeling_utils.py` | Add SAVE_HIDDEN_STATES to `SpeculativeDecodingMode` and parsing support in `from_arguments`. |
| Drafter hooks and executor<br>`tensorrt_llm/_torch/speculative/drafter.py`, `tensorrt_llm/_torch/speculative/model_drafter.py`, `tensorrt_llm/_torch/pyexecutor/py_executor.py` | Add `Drafter.needs_draft_forward_post()` (default False) and `ModelDrafter.prepare_draft_tokens_post(...)` (no-op hook). PyExecutor tracks `draft_forward_post_needed` and invokes `drafter.prepare_draft_tokens_post(...)` after sampling when needed (see the sketch after this table). |
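A hedged sketch of the post-forward hook pattern described in the last row; the signatures are illustrative, not the exact ones in drafter.py:

```python
from typing import Optional


class Drafter:
    """Base drafter, showing only the new post-forward hook (sketch)."""

    def needs_draft_forward_post(self) -> bool:
        # Default: no extra drafting pass is needed after sampling.
        return False

    def prepare_draft_tokens_post(
            self, scheduled_requests,
            resource_manager: Optional[object] = None) -> None:
        # No-op hook; a drafter that returns True above overrides this, and
        # PyExecutor invokes it after sampling when draft_forward_post_needed is set.
        pass
```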

Sequence Diagram(s)

sequenceDiagram
    autonumber
    participant Client
    participant Config as LLMArgs/Config
    participant Utils
    participant ResourceMgr
    participant Model
    participant SpecMeta
    participant Drafter
    participant FS as Filesystem

    Client->>Config: submit request with SaveState config
    Config->>Utils: get_spec_metadata(), get_spec_resource_manager(), get_spec_drafter()
    Utils->>ResourceMgr: construct SaveHiddenStatesResourceManager
    Client->>Model: forward(input_ids, spec_metadata)
    Model->>SpecMeta: is_final_output_capture()?
    alt final capture requested
        SpecMeta->>SpecMeta: maybe_capture_final_hidden_states(hidden_states)
        SpecMeta-->>ResourceMgr: store last_hidden_states (if allocated)
    end
    Model-->>Client: logits/draft output
    Client->>Drafter: prepare_draft_tokens(requests, resource_manager)
    Drafter->>Drafter: collect per-request saved states
    alt root rank & interval reached
        Drafter->>FS: write .pt file(s)
    end
    Drafter-->>Client: draft tokens

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Suggested labels

Community want to contribute

Suggested reviewers

  • Superjomn
  • ziyixiong-nv
  • syuoni
  • nv-guomingz

@IzzyPutterman (Collaborator, Author) commented:

TODO: actually test, handle initial warmup configs (dummy data for sizing shouldn't be saved), handle chunking, make tests.

@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 13

🔭 Outside diff range comments (2)
tensorrt_llm/_torch/speculative/utils.py (1)

167-171: Add drafter branch for SAVE_HIDDEN_STATES to match has_spec_drafter().

Apply this diff:

     if spec_config.spec_dec_mode.is_ngram():
         return NGramDrafter(spec_config, spec_resource_manager)
 
+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        return SaveHiddenStatesDrafter(spec_config)
+
     return None
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)

92-112: Fix call site, draft tokens, and write cadence.

  • Pass resource_manager into _process_request.
  • Define draft_tokens deterministically to max_draft_len.
  • Increment _iter before write modulo check to avoid writing at iteration 0.

Apply this diff:

     def prepare_draft_tokens(
         self,
         scheduled_requests: ScheduledRequests,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         for request in sorted(
                 scheduled_requests.context_requests,
                 key=lambda r:
             (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
         ):
             request.py_max_new_tokens = 1
-            self._process_request(request, resource_manager)
-            if self._iter % self._write_interval == 0:
-                self._write_to_file()
-            self._iter += 1
-            # Pad length to `self.max_draft_len`
-            if len(draft_tokens) > 0:
-                pad_length = self.max_draft_len - len(draft_tokens)
-                draft_tokens.extend([0] * pad_length)
-            request.py_draft_tokens = draft_tokens
+            self._process_request(request, resource_manager)
+            self._iter += 1
+            if self._iter % self._write_interval == 0:
+                self._write_to_file()
+            # Always provide a fixed-length draft token list
+            draft_tokens = [0] * self.max_draft_len
+            request.py_draft_tokens = draft_tokens
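For downstream consumption, each write produces a list of per-request dicts; a hedged sketch of reading one file back, assuming the record keys built in _process_request and a hypothetical path:

```python
import torch

records = torch.load("/tmp/hidden_states/data_20.pt")  # hypothetical output file
for rec in records:
    # Keys follow the out_dict assembled in _process_request.
    print(rec["id"], rec["input_ids"].shape,
          rec["hidden_state_features"].shape, rec["hidden_state"].shape)
```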
🧹 Nitpick comments (4)
tensorrt_llm/_torch/speculative/interface.py (1)

192-209: No-op hooks added appropriately; minor doc/typing polish suggested.

The hooks are fine as defaults. Consider adding a return type for is_final_output_capture -> bool for consistency.

-    def is_final_output_capture(self):
+    def is_final_output_capture(self) -> bool:
         return False
tensorrt_llm/_torch/speculative/save_hidden_state.py (2)

12-18: Add return typing; keep API explicit.

-    def is_final_output_capture(self):
+    def is_final_output_capture(self) -> bool:
         return self.save_last_layer_post_norm

29-40: Resource manager init OK; ensure torch is imported (fixed above).

Also consider documenting last_hidden_states semantics in a short docstring.

tensorrt_llm/llmapi/llm_args.py (1)

449-472: New config LGTM; consider ensuring the output directory exists in validate().

Apply this diff:

     def validate(self) -> None:
         if self.output_directory is None or not self.eagle3_layers_to_capture:
             raise ValueError(
                 "Save directory and layers to capture must be provided")
+        os.makedirs(self.output_directory, exist_ok=True)
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between d1d17db and 16b9e78.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (5 hunks)
  • tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
  • tensorrt_llm/_torch/speculative/utils.py (3 hunks)
  • tensorrt_llm/llmapi/llm_args.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/llmapi/llm_args.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py

12-12: Undefined name dataclass (F821)
20-20: Undefined name torch (F821)
29-29: Undefined name SaveHiddenStatesDecodingConfig (F821)
30-30: Undefined name torch (F821)
36-36: Undefined name torch (F821)
46-46: Undefined name SaveHiddenStatesDecodingConfig (F821)
60-60: Undefined name torch (F821)
61-61: Undefined name torch (F821)
63-63: Undefined name resource_manager (F821)
65-65: Undefined name resource_manager (F821)
68-68: Undefined name resource_manager (F821)
78-78: Undefined name resource_manager (F821)
87-87: Undefined name os (F821)
89-89: Undefined name torch (F821)
108-108: Undefined name draft_tokens (F821)
109-109: Undefined name draft_tokens (F821)
110-110: Undefined name draft_tokens (F821)
111-111: Undefined name draft_tokens (F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/speculative/interface.py (1)

54-56: Predicate addition LGTM.

tensorrt_llm/_torch/models/modeling_speculative.py (1)

421-424: Hook to capture final hidden states LGTM.

Gated behind metadata and a predicate, with no overhead when unused. Correct placement before draft-branching.
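The call-site pattern under review, as a hedged sketch (names from the walkthrough; the surrounding forward() logic is omitted):

```python
# Sketch of the capture hook inside the target model's forward(),
# placed just before draft branching; not the exact source.
if spec_metadata is not None and spec_metadata.is_final_output_capture():
    spec_metadata.maybe_capture_final_hidden_states(hidden_states)
```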

tensorrt_llm/llmapi/llm_args.py (1)

12-13: Import of Set is fine.

No issues.

@IzzyPutterman force-pushed the iputterman/savestate-config branch from 16b9e78 to dc2eed8 on August 20, 2025 16:47
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/speculative/utils.py (1)

144-170: Add get_spec_drafter branch for SAVE_HIDDEN_STATES.

Without this branch, the mode won't construct a drafter and will silently return None.

Apply this diff:

     if spec_config.spec_dec_mode.is_ngram():
         return NGramDrafter(spec_config, spec_resource_manager)

+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        return SaveHiddenStatesDrafter(spec_config)
+
     return None
♻️ Duplicate comments (10)
tensorrt_llm/_torch/speculative/utils.py (3)

14-15: Import the drafter to complete the SAVE_HIDDEN_STATES wiring.

The new mode defines a Drafter, but it isn’t imported here, which prevents get_spec_drafter from constructing it.

Apply this diff:

-from .save_hidden_state import (SaveHiddenStatesResourceManager,
-                                SaveHiddenStatesSpecMetadata)
+from .save_hidden_state import (SaveHiddenStatesResourceManager,
+                                SaveHiddenStatesSpecMetadata,
+                                SaveHiddenStatesDrafter)

53-66: Fix incorrect config attribute: use eagle3_layers_to_capture to derive capture count.

spec_config has eagle3_layers_to_capture, not num_capture_layers. Referencing spec_config.num_capture_layers will raise at runtime.

Apply this diff:

         return SaveHiddenStatesSpecMetadata(
             max_draft_len=spec_config.max_draft_len,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
             num_layers=model_config.num_hidden_layers,
             hidden_size=model_config.hidden_size,
             max_num_tokens=max_num_tokens,
             dtype=model_config.torch_dtype,
             is_draft_model=is_draft_model,
             eagle3_resource_manager=spec_resource_manager,
-            num_capture_layers=spec_config.num_capture_layers,
+            num_capture_layers=(
+                len(spec_config.eagle3_layers_to_capture)
+                if spec_config.eagle3_layers_to_capture is not None
+                else Eagle3SpecMetadata.num_capture_layers
+            ),
             save_last_layer_post_norm=spec_config.save_last_layer_post_norm,
         )

114-122: Avoid possible None dereference: use model dtype instead of draft model dtype.

draft_model_engine can be None for this mode; using draft_model_engine.model.config.torch_dtype risks an AttributeError. Use model_config.torch_dtype.

Apply this diff:

     if spec_dec_mode.is_save_hidden_states():
         return SaveHiddenStatesResourceManager(
             spec_config,
-            draft_model_engine.model.config.torch_dtype,
+            model_config.torch_dtype,
             model_config.hidden_size,
             max_num_requests,
             max_seq_len,
             max_num_tokens,
         )
tensorrt_llm/_torch/speculative/save_hidden_state.py (5)

1-11: Add required header and missing imports; fix forward refs to avoid import-time NameErrors.

dataclass/torch/os are missing; type SaveHiddenStatesDecodingConfig is referenced but not imported or quoted. Use TYPE_CHECKING and string annotations to avoid cycles.

Apply this diff:

-from typing import Optional
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+from dataclasses import dataclass
+import os
+from typing import Optional, TYPE_CHECKING
+import torch
@@
-from tensorrt_llm._utils import local_mpi_rank
+from tensorrt_llm._utils import local_mpi_rank
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import SaveHiddenStatesDecodingConfig

19-25: Slice copy to avoid shape mismatch when capturing final hidden states.

The preallocated buffer is sized to max_num_tokens; hidden_states holds only the active tokens. Copy only the active rows.

Apply this diff:

     def maybe_capture_final_hidden_states(self,
                                           hidden_states: torch.Tensor) -> None:
         if self.save_last_layer_post_norm:
-            # Assume no chunking, BS=1
-            eagle3_hidden_states = self.eagle3_resource_manager.last_hidden_states
-            eagle3_hidden_states.copy_(hidden_states)
+            # Assume no chunking, BS=1; copy only active tokens
+            buf = self.eagle3_resource_manager.last_hidden_states
+            n = hidden_states.shape[0]
+            buf[:n].copy_(hidden_states)
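A standalone illustration of the partial-copy pattern this diff applies; the buffer sizes are arbitrary:

```python
import torch

buf = torch.empty(8, 4)    # preallocated to max_num_tokens rows
h = torch.randn(3, 4)      # only the active tokens this step
buf[:h.shape[0]].copy_(h)  # copy_ requires matching shapes, so slice first
```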

44-56: Ensure output directory exists; fix type annotation to avoid runtime import.

Create the output directory up-front. Quote the type annotation to avoid importing at module import time.

Apply this diff:

     def __init__(
         self,
-        spec_config: SaveHiddenStatesDecodingConfig,
+        spec_config: "SaveHiddenStatesDecodingConfig",
     ):
         super().__init__(spec_config.max_concurrency)
         self.spec_config = spec_config
         self.max_draft_len = spec_config.max_draft_len
         self._iter = 0
         self._output_directory = spec_config.output_directory
         self._file_prefix = spec_config.file_prefix
         self._write_interval = spec_config.write_interval
         self._saved_state = []
+        os.makedirs(self._output_directory, exist_ok=True)

57-84: Fix signature, rank logic, undefined names, and iteration field.

  • _process_request must accept resource_manager.
  • Accumulate on root (rank 0) to match write-on-root.
  • Use self._iter.
  • Don’t reference undefined variables.

Apply this diff:

-    def _process_request(self, request: LlmRequest) -> None:
-        out_dict = {}
-        if local_mpi_rank() != 0:
-            input_ids = torch.tensor(list(request.get_tokens(0)),
-                                     dtype=torch.long,
-                                     device='cpu')
-            hidden_size = resource_manager.hidden_size
-            if self.spec_config.save_last_layer_post_norm:
-                hidden_states = resource_manager.last_hidden_states.cpu().clone(
-                )
-            else:
-                hidden_states = resource_manager.hidden_states[:,
-                                                               -hidden_size:].cpu(
-                                                               ).clone()
-
-            out_dict = {
-                "id":
-                self.iteration,
-                "input_ids":
-                input_ids,
-                "hidden_state_features":
-                resource_manager.hidden_states.cpu().clone(),
-                "hidden_state":
-                hidden_states,
-            }
-
-            self._saved_state.append(out_dict)
+    def _process_request(self, request: LlmRequest,
+                         resource_manager: ResourceManager) -> None:
+        if local_mpi_rank() == 0:
+            input_ids = torch.tensor(
+                list(request.get_tokens(0)), dtype=torch.long, device='cpu')
+            hidden_size = resource_manager.hidden_size
+            if getattr(self.spec_config, "save_last_layer_post_norm", False):
+                hidden_state = resource_manager.last_hidden_states.cpu().clone()
+            else:
+                # Fallback: slice last layer/state from captured features.
+                hidden_state = resource_manager.hidden_states[:, -hidden_size:
+                                                             ].cpu().clone()
+            out_dict = {
+                "id": self._iter,
+                "input_ids": input_ids,
+                "hidden_state_features": resource_manager.hidden_states.cpu().
+                clone(),
+                "hidden_state": hidden_state,
+            }
+            self._saved_state.append(out_dict)

85-91: Write only when there’s data; fix undefined fields.

Use self._iter and check self._saved_state; clear after write.

Apply this diff:

     def _write_to_file(self) -> None:
-        if local_mpi_rank() == 0 and self.iteration != self.start_iteration:
-            output_path = os.path.join(self._output_directory,
-                                       f"{self._file_prefix}_{self._iter}.pt")
-            torch.save(self._saved_state, output_path)
-        self._saved_state = []
+        if local_mpi_rank() == 0 and self._saved_state:
+            output_path = os.path.join(self._output_directory,
+                                       f"{self._file_prefix}_{self._iter}.pt")
+            torch.save(self._saved_state, output_path)
+            self._saved_state = []
tensorrt_llm/llmapi/llm_args.py (2)

943-953: Fix broken type alias: use SaveHiddenStatesDecodingConfig.

The singular name is undefined and will break imports.

Apply this diff:

 SpeculativeConfig: TypeAlias = Optional[Union[
     DraftTargetDecodingConfig,
     EagleDecodingConfig,
     LookaheadDecodingConfig,
     MedusaDecodingConfig,
     MTPDecodingConfig,
     NGramDecodingConfig,
     UserProvidedDecodingConfig,
-    SaveHiddenStateDecodingConfig,
+    SaveHiddenStatesDecodingConfig,
     AutoDecodingConfig,
 ]]

364-371: Remove duplicate "SaveState" dispatch entry.

Duplicate keys lead to confusing behavior and may shadow earlier entries.

Apply this diff:

         config_classes = {
             "MTP": MTPDecodingConfig,
             "Medusa": MedusaDecodingConfig,
             "Eagle": EagleDecodingConfig,
-            "SaveState": SaveHiddenStatesDecodingConfig,
             "Lookahead": LookaheadDecodingConfig,
             "NGram": NGramDecodingConfig,
             "DraftTarget": DraftTargetDecodingConfig,
-            "SaveState": SaveHiddenStatesDecodingConfig,
+            "SaveState": SaveHiddenStatesDecodingConfig,
             "UserProvided": UserProvidedDecodingConfig,
             "AUTO": AutoDecodingConfig,
         }
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/utils.py (1)

1-1: Prepend NVIDIA copyright header.

Per coding guidelines, add the NVIDIA copyright header to all source files.

Apply this diff at the top of the file:

+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
tensorrt_llm/llmapi/llm_args.py (1)

449-472: New SaveHiddenStatesDecodingConfig: basic wiring looks correct.

Definition, spec_dec_mode override, and validation are coherent. Consider adding a class docstring and field docstrings to match repo standards, but not blocking.

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 16b9e78 and dc2eed8.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (5 hunks)
  • tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
  • tensorrt_llm/_torch/speculative/utils.py (3 hunks)
  • tensorrt_llm/llmapi/llm_args.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/speculative/interface.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py

12-12: Undefined name dataclass (F821)
20-20: Undefined name torch (F821)
29-29: Undefined name SaveHiddenStatesDecodingConfig (F821)
30-30: Undefined name torch (F821)
36-36: Undefined name torch (F821)
46-46: Undefined name SaveHiddenStatesDecodingConfig (F821)
60-60: Undefined name torch (F821)
61-61: Undefined name torch (F821)
63-63: Undefined name resource_manager (F821)
65-65: Undefined name resource_manager (F821)
68-68: Undefined name resource_manager (F821)
78-78: Undefined name resource_manager (F821)
87-87: Undefined name os (F821)
89-89: Undefined name torch (F821)
108-108: Undefined name draft_tokens (F821)
109-109: Undefined name draft_tokens (F821)
110-110: Undefined name draft_tokens (F821)
111-111: Undefined name draft_tokens (F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)

29-40: Ensure last_hidden_states buffer allocated only when configured; torch dtype import fixed above.

Constructor is fine; annotation for config already addressed via TYPE_CHECKING. No further changes required here.

tensorrt_llm/llmapi/llm_args.py (1)

12-12: LGTM: typing.Set import is correct.

The added Set import aligns with new config usage.

@IzzyPutterman force-pushed the iputterman/savestate-config branch from dc2eed8 to 8eb53f3 on August 22, 2025 19:01
@IzzyPutterman changed the title from "Draft: Save state first pass" to "[None][feat] Draft: Save state first pass" on Aug 22, 2025
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

♻️ Duplicate comments (9)
tensorrt_llm/_torch/speculative/utils.py (3)

14-16: Resolved: Drafter/resource manager/metadata are now imported.

This addresses the previous import omission for the drafter.


117-125: Avoid possible None dereference; use model dtype, not draft model dtype.

In SAVE_HIDDEN_STATES there may be no draft model. Use model_config.torch_dtype.

Apply this diff:

     if spec_dec_mode.is_save_hidden_states():
         return SaveHiddenStatesResourceManager(
             spec_config,
-            draft_model_engine.model.config.torch_dtype,
+            model_config.torch_dtype,
             model_config.hidden_size,
             max_num_requests,
             max_seq_len,
             max_num_tokens,
         )

56-69: Fix SaveHiddenStatesSpecMetadata constructor kwargs

The SaveHiddenStatesSpecMetadata class inherits its initializer from Eagle3SpecMetadata, which defines a field named layers_to_capture—there is no num_capture_layers parameter. Passing num_capture_layers will cause a runtime error.

• Location to update:

  • File: tensorrt_llm/_torch/speculative/utils.py
  • In the if spec_config.spec_dec_mode.is_save_hidden_states(): branch (around lines 56–69)

Apply this diff:

     if spec_config.spec_dec_mode.is_save_hidden_states():
         return SaveHiddenStatesSpecMetadata(
             max_draft_len=spec_config.max_draft_len,
             spec_dec_mode=spec_config.spec_dec_mode,
             max_num_requests=max_num_requests,
             num_layers=model_config.num_hidden_layers,
             hidden_size=model_config.hidden_size,
             max_num_tokens=max_num_tokens,
             dtype=model_config.torch_dtype,
             is_draft_model=is_draft_model,
             eagle3_resource_manager=spec_resource_manager,
-            num_capture_layers=spec_config.num_capture_layers,
+            layers_to_capture=spec_config.eagle3_layers_to_capture,
             save_last_layer_post_norm=spec_config.save_last_layer_post_norm,
         )
tensorrt_llm/_torch/speculative/save_hidden_state.py (6)

1-4: Missing required header and imports cause import-time failures.

The module uses dataclass/torch/os but does not import them, and the file lacks the required NVIDIA copyright header.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+from dataclasses import dataclass
+import os
+import torch
 from typing import Optional
 
 from tensorrt_llm._utils import local_mpi_rank

16-25: Fix shape mismatch risk when copying final hidden states.

last_hidden_states is allocated to (max_num_tokens, hidden_size), while incoming hidden_states is (#active_tokens, hidden_size). copy_ requires matching shapes.

Apply this diff:

     def maybe_capture_final_hidden_states(self,
                                           hidden_states: torch.Tensor) -> None:
         if self.save_last_layer_post_norm:
-            # Assume no chunking, BS=1
-            eagle3_hidden_states = self.eagle3_resource_manager.last_hidden_states
-            eagle3_hidden_states.copy_(hidden_states)
+            # Assume no chunking, BS=1; copy only active token rows.
+            buf = self.eagle3_resource_manager.last_hidden_states
+            n = hidden_states.shape[0]
+            buf[:n].copy_(hidden_states)

44-56: Initialize output directory and fix forward reference typing.

Ensure the output directory exists and avoid importing the config type at runtime by using a string annotation.

Apply this diff:

-    def __init__(
-        self,
-        spec_config: SaveHiddenStatesDecodingConfig,
-    ):
+    def __init__(
+        self,
+        spec_config: "SaveHiddenStatesDecodingConfig",
+    ):
         super().__init__(spec_config.max_concurrency)
         self.spec_config = spec_config
         self.max_draft_len = spec_config.max_draft_len
         self._iter = 0
         self._output_directory = spec_config.output_directory
         self._file_prefix = spec_config.file_prefix
         self._write_interval = spec_config.write_interval
         self._saved_state = []
+        os.makedirs(self._output_directory, exist_ok=True)

57-84: Fix undefined names, wrong rank condition, and signature; use resource_manager param.

  • _process_request must accept resource_manager.
  • Accumulate on root (rank 0) if only root writes.
  • Replace undefined iteration with self._iter.
  • Use getattr for optional flag.

Apply this diff:

-    def _process_request(self, request: LlmRequest) -> None:
-        out_dict = {}
-        if local_mpi_rank() != 0:
-            input_ids = torch.tensor(list(request.get_tokens(0)),
-                                     dtype=torch.long,
-                                     device='cpu')
-            hidden_size = resource_manager.hidden_size
-            if self.spec_config.save_last_layer_post_norm:
-                hidden_states = resource_manager.last_hidden_states.cpu().clone(
-                )
-            else:
-                hidden_states = resource_manager.hidden_states[:,
-                                                               -hidden_size:].cpu(
-                                                               ).clone()
-
-            out_dict = {
-                "id":
-                self.iteration,
-                "input_ids":
-                input_ids,
-                "hidden_state_features":
-                resource_manager.hidden_states.cpu().clone(),
-                "hidden_state":
-                hidden_states,
-            }
-
-            self._saved_state.append(out_dict)
+    def _process_request(self, request: LlmRequest,
+                         resource_manager: ResourceManager) -> None:
+        if local_mpi_rank() == 0:
+            input_ids = torch.tensor(
+                list(request.get_tokens(0)), dtype=torch.long, device="cpu")
+            hidden_size = resource_manager.hidden_size
+            if getattr(self.spec_config, "save_last_layer_post_norm", False):
+                hidden_state = resource_manager.last_hidden_states.cpu().clone()
+            else:
+                # Slice the last layer from the features buffer.
+                hidden_state = resource_manager.hidden_states[:, -hidden_size:].cpu().clone()
+            out_dict = {
+                "id": self._iter,
+                "input_ids": input_ids,
+                "hidden_state_features": resource_manager.hidden_states.cpu().clone(),
+                "hidden_state": hidden_state,
+            }
+            self._saved_state.append(out_dict)

85-91: Write only when there’s data; clear buffer after save.

Prevents empty files and resets the state list in the same rank-0 block.

Apply this diff:

     def _write_to_file(self) -> None:
-        if local_mpi_rank() == 0 and self.iteration != self.start_iteration:
-            output_path = os.path.join(self._output_directory,
-                                       f"{self._file_prefix}_{self._iter}.pt")
-            torch.save(self._saved_state, output_path)
-        self._saved_state = []
+        if local_mpi_rank() == 0 and self._saved_state:
+            output_path = os.path.join(
+                self._output_directory, f"{self._file_prefix}_{self._iter}.pt")
+            torch.save(self._saved_state, output_path)
+            self._saved_state = []

92-111: Initialize draft_tokens and pass resource_manager; avoid negative padding.

Avoid NameError and ensure padding is non-negative.

Apply this diff:

         for request in sorted(
                 scheduled_requests.context_requests,
                 key=lambda r:
             (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
         ):
             request.py_max_new_tokens = 1
             self._process_request(request, resource_manager)
             if self._iter % self._write_interval == 0:
                 self._write_to_file()
             self._iter += 1
             # Pad length to `self.max_draft_len`
-            if len(draft_tokens) > 0:
-                pad_length = self.max_draft_len - len(draft_tokens)
-                draft_tokens.extend([0] * pad_length)
-            request.py_draft_tokens = draft_tokens
+            draft_tokens: list[int] = []
+            pad_length = max(0, self.max_draft_len - len(draft_tokens))
+            if pad_length:
+                draft_tokens.extend([0] * pad_length)
+            request.py_draft_tokens = draft_tokens
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/__init__.py (1)

20-21: __all__ updated appropriately; consider consistent ordering.

Entries are exposed correctly. For readability, consider keeping __all__ alphabetized or grouped by feature (optional).

tensorrt_llm/_torch/speculative/save_hidden_state.py (1)

1-111: Add minimal docstrings for new public classes/methods.

These types are now part of the public surface (also exported in __init__.py). Add short Google-style docstrings describing purpose and expected shapes/devices for hidden_state buffers.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between dc2eed8 and 8eb53f3.

📒 Files selected for processing (6)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/speculative/__init__.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (5 hunks)
  • tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
  • tensorrt_llm/_torch/speculative/utils.py (4 hunks)
  • tensorrt_llm/llmapi/llm_args.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • tensorrt_llm/llmapi/llm_args.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/speculative/interface.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/__init__.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/_torch/speculative/utils.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/__init__.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/_torch/speculative/utils.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py

12-12: Undefined name dataclass (F821)
20-20: Undefined name torch (F821)
29-29: Undefined name SaveHiddenStatesDecodingConfig (F821)
30-30: Undefined name torch (F821)
36-36: Undefined name torch (F821)
46-46: Undefined name SaveHiddenStatesDecodingConfig (F821)
60-60: Undefined name torch (F821)
61-61: Undefined name torch (F821)
63-63: Undefined name resource_manager (F821)
65-65: Undefined name resource_manager (F821)
68-68: Undefined name resource_manager (F821)
78-78: Undefined name resource_manager (F821)
87-87: Undefined name os (F821)
89-89: Undefined name torch (F821)
108-108: Undefined name draft_tokens (F821)
109-109: Undefined name draft_tokens (F821)
110-110: Undefined name draft_tokens (F821)
111-111: Undefined name draft_tokens (F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check

@IzzyPutterman force-pushed the iputterman/savestate-config branch from 95142ec to 8a19b2e on August 23, 2025 05:15
@coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/speculative/interface.py (1)

133-134: Bug: default value becomes a tuple due to trailing comma.

spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE, creates a tuple default (SpeculativeDecodingMode.NONE,). Any call like self.spec_dec_mode.is_ngram() will fail.

Apply:

-    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE,
+    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
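A standalone illustration of the trailing-comma gotcha, independent of this codebase:

```python
class C:
    a = 1,  # trailing comma makes this a one-element tuple: (1,)
    b = 1   # plain int

assert C.a == (1,)
assert C.b == 1
```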
♻️ Duplicate comments (4)
tensorrt_llm/_torch/speculative/interface.py (1)

89-90: has_spec_drafter now includes SAVE_HIDDEN_STATES — ensure utils.get_spec_drafter returns a drafter.

If has_spec_drafter() can return True for this mode, tensorrt_llm/_torch/speculative/utils.py:get_spec_drafter must return a SaveHiddenStatesDrafter instance. Otherwise callers will receive None.

If missing, update utils.py accordingly:

-from .save_hidden_state import (SaveHiddenStatesResourceManager,
-                                SaveHiddenStatesSpecMetadata)
+from .save_hidden_state import (SaveHiddenStatesResourceManager,
+                                SaveHiddenStatesSpecMetadata,
+                                SaveHiddenStatesDrafter)
@@
     if spec_config.spec_dec_mode.is_ngram():
         return NGramDrafter(spec_config, spec_resource_manager)
 
+    if spec_config.spec_dec_mode.is_save_hidden_states():
+        return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)
tensorrt_llm/_torch/speculative/save_hidden_state.py (3)

1-7: Missing NVIDIA copyright header (required for source files).

Add the standard NVIDIA header at the top of this file.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import os
 from dataclasses import dataclass
 from typing import Optional

66-97: Root/non-root rank logic is inverted; data never gets written on rank 0.

You append to _saved_state only on non-root ranks, but _write_to_file() writes only on root. Flip the condition and simplify.

-        out_dict = {}
-        if local_mpi_rank() != 0:
+        if local_mpi_rank() == 0:
             input_ids = torch.tensor(list(request.get_tokens(0)),
                                      dtype=torch.long,
                                      device='cpu')
             hidden_size = resource_manager.hidden_size
             num_tokens = input_ids.shape[0]
             if self.spec_config.save_last_layer_post_norm:
-                hidden_states = resource_manager.last_hidden_states[:
-                                                                    num_tokens, :].cpu(
-                                                                    ).clone()
+                hidden_states = resource_manager.last_hidden_states[:num_tokens, :].cpu().clone()
             else:
-                hidden_states = resource_manager.hidden_states[:num_tokens,
-                                                               -hidden_size:].cpu(
-                                                               ).clone()
+                hidden_states = resource_manager.hidden_states[:num_tokens, -hidden_size:].cpu().clone()
 
             out_dict = {
                 "id":
                 self._iter,
                 "input_ids":
                 input_ids,
                 "hidden_state_features":
                 resource_manager.hidden_states[:num_tokens, :].cpu().clone(),
                 "hidden_state":
                 hidden_states,
             }
 
             self._saved_state.append(out_dict)

98-104: Write only when there’s buffered data; reset after successful write.

As written, you can emit empty files. Also consider flushing leftovers at the end of prepare_draft_tokens.

-    def _write_to_file(self) -> None:
-        if local_mpi_rank() == 0 and self._iter != self._start_iter:
-            output_path = os.path.join(self._output_directory,
-                                       f"{self._file_prefix}_{self._iter}.pt")
-            torch.save(self._saved_state, output_path)
-        self._saved_state = []
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0 and self._saved_state:
+            output_path = os.path.join(
+                self._output_directory, f"{self._file_prefix}_{self._iter}.pt"
+            )
+            torch.save(self._saved_state, output_path)
+            self._saved_state = []
🧹 Nitpick comments (7)
tensorrt_llm/_torch/speculative/interface.py (1)

199-205: Name/docstring mismatch; clarify intent.

This method controls final-output capture, not per-layer capture. The docstring should reflect that to avoid confusion with is_layer_capture.

-    def is_final_output_capture(self):
-        """
-        Whether the layer should be captured (eg for Eagle3).
-        Captured after layer norm in modeling_speculative.
-        """
+    def is_final_output_capture(self):
+        """
+        Whether the final hidden state (post norm) should be captured.
+        Used by modes that need the last-layer output after layer norm.
+        """
tensorrt_llm/llmapi/llm_args.py (4)

461-495: Add minimal validation/docstrings for SaveHiddenStatesDecodingConfig.

Two small guardrails will prevent misconfigurations:

  • Validate write_interval >= 1.
  • Restrict backend via supports_backend override (mirrors other modes).

Also add a brief class docstring (consistent with guidelines).

 class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
+    """
+    Configuration for saving hidden states during decoding (PyTorch backend only).
+    Writes per-request inputs and hidden states to output_directory every write_interval iterations.
+    """
     output_directory: str
     write_interval: int = 20
     file_prefix: str = "data"
     eagle3_layers_to_capture: Optional[Set[int]] = None
     save_last_layer_post_norm: bool = True
 
     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)
 
     decoding_type: ClassVar[str] = "SaveState"
 
     def validate(self) -> None:
         if self.output_directory is None or not self.eagle3_layers_to_capture:
             raise ValueError(
                 "Save directory and layers to capture must be provided")
 
+    def supports_backend(self, backend: str) -> bool:
+        return backend == "pytorch"
+
+    @field_validator('write_interval')
+    @classmethod
+    def _validate_write_interval(cls, v: int) -> int:
+        if v < 1:
+            raise ValueError("write_interval must be >= 1")
+        return v

968-978: Type alias updated; fix long line to appease linters.

The union includes the new config. Consider wrapping to <=120 chars for Ruff E501.

-SpeculativeConfig: TypeAlias = Optional[Union[
-    DraftTargetDecodingConfig,
-    EagleDecodingConfig,
-    LookaheadDecodingConfig,
-    MedusaDecodingConfig,
-    MTPDecodingConfig,
-    NGramDecodingConfig,
-    UserProvidedDecodingConfig,
-    SaveHiddenStatesDecodingConfig,
-    AutoDecodingConfig,
-]]
+SpeculativeConfig: TypeAlias = Optional[
+    Union[
+        DraftTargetDecodingConfig,
+        EagleDecodingConfig,
+        LookaheadDecodingConfig,
+        MedusaDecodingConfig,
+        MTPDecodingConfig,
+        NGramDecodingConfig,
+        UserProvidedDecodingConfig,
+        SaveHiddenStatesDecodingConfig,
+        AutoDecodingConfig,
+    ]
+]

1751-1761: Runtime constraints for SaveState are sensible; add explicit doc note or logging.

Forcing PyTorch, BS=1, disabling overlap scheduler, and turning off CUDA graphs are appropriate. Consider logging a one-liner when this branch triggers so users understand why their knobs were overridden.

             elif isinstance(self.speculative_config,
                             SaveHiddenStatesDecodingConfig):
                 assert self.backend in ['pytorch']
+                logger.info("SaveHiddenStates mode: forcing max_batch_size=1, disabling overlap scheduler and CUDA graphs.")
                 self.build_config.max_batch_size = 1
                 self.max_batch_size = 1
                 self.disable_overlap_scheduler = True
                 self.cuda_graph_config = None
                 self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
                 self.build_config.max_draft_len = 1
                 self.speculative_config.max_draft_len = 1

If helpful, I can add a minimal unit/integration test that exercises this path and asserts these overrides.


368-371: Dispatch entry for “SaveState” confirmed – please update user-facing documentation

The mapping in DecodingBaseConfig.from_dict (llm_args.py:362) and the decoding_type declaration (llm_args.py:463) correctly recognize "SaveState". However, a search across the repository (excluding third_party) shows no mentions of “SaveState” in any user-facing docs or examples.

To prevent confusion or typos, please update the documentation and examples accordingly. Consider adding or updating references to “SaveState” in:

  • README.md
  • Any files under a docs/ directory
  • Example scripts (e.g., in an examples/ folder)
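For instance, a docs snippet could demonstrate the string-keyed dispatch (a hedged sketch: the path and layer ids are placeholders, and the exact key handling inside from_dict may differ):

    from tensorrt_llm.llmapi.llm_args import DecodingBaseConfig

    # "SaveState" selects SaveHiddenStatesDecodingConfig via the from_dict mapping.
    spec = DecodingBaseConfig.from_dict({
        "decoding_type": "SaveState",
        "output_directory": "/tmp/hidden_states",  # placeholder
        "eagle3_layers_to_capture": [1, 15, 28],   # placeholder layer ids
    })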
tensorrt_llm/_torch/speculative/save_hidden_state.py (2)

20-30: Final hidden-state capture: add cheap guards for buffer allocation and shape.

Slicing to the active rows is good. Add assertions to catch a missing buffer or a hidden-size mismatch early.

     def maybe_capture_final_hidden_states(self,
                                           hidden_states: torch.Tensor) -> None:
         if self.save_last_layer_post_norm:
             # Assume no chunking, BS=1
             eagle3_hidden_states = self.eagle3_resource_manager.last_hidden_states
+            assert eagle3_hidden_states is not None, "last_hidden_states buffer is not allocated"
+            assert hidden_states.shape[-1] == eagle3_hidden_states.shape[-1], "hidden size mismatch"
             eagle3_hidden_states[:hidden_states.shape[0], :].copy_(
                 hidden_states)

34-45: Resource manager buffer allocation OK; consider propagating device via config.

Hard-coding device='cuda' is fine for now, but passing the device (or using current device) would be more robust in multi-GPU setups.

-            self.last_hidden_states = torch.empty(
-                (max_num_tokens, self.hidden_size),
-                dtype=self.dtype,
-                device='cuda')
+            device = torch.device('cuda', torch.cuda.current_device())
+            self.last_hidden_states = torch.empty(
+                (max_num_tokens, self.hidden_size), dtype=self.dtype, device=device)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 8eb53f3 and 8a19b2e.

📒 Files selected for processing (8)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/speculative/__init__.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (5 hunks)
  • tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
  • tensorrt_llm/_torch/speculative/utils.py (4 hunks)
  • tensorrt_llm/llmapi/__init__.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (4 hunks)
  • tensorrt_llm/models/modeling_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/__init__.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/models/modeling_utils.py
  • tensorrt_llm/llmapi/__init__.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/models/modeling_utils.py
  • tensorrt_llm/llmapi/__init__.py
  • tensorrt_llm/_torch/speculative/interface.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/llmapi/llm_args.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py

34-34: Undefined name SaveHiddenStatesDecodingConfig

(F821)


51-51: Undefined name SaveHiddenStatesDecodingConfig

(F821)


121-121: Undefined name draft_tokens

(F821)


122-122: Undefined name draft_tokens

(F821)


123-123: Undefined name draft_tokens

(F821)


124-124: Undefined name draft_tokens

(F821)

tensorrt_llm/llmapi/llm_args.py

478-481: 1 blank line required between summary line and description

(D205)


976-976: Line too long (194 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/speculative/interface.py (2)

54-56: Predicate reads well.

is_save_hidden_states() is consistent with the other helpers.


214-220: Hook addition LGTM; consider being explicit about no-op default.

The hook is a no-op by default; the docstring already implies it. Looks good.

tensorrt_llm/llmapi/__init__.py (1)

14-16: Public export added correctly.

SaveHiddenStatesDecodingConfig is imported and exposed in __all__. Good.

Also applies to: 62-63

@IzzyPutterman IzzyPutterman force-pushed the iputterman/savestate-config branch from 8a19b2e to 3f40394 Compare August 24, 2025 19:24
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)

105-125: draft_tokens is undefined; also flush leftover records after the loop.

  • draft_tokens is never defined, so its first use raises a NameError. Provide a fixed-shape placeholder of length max_draft_len to satisfy downstream shape checks.
  • After processing the batch, flush any remaining records not aligned to write_interval.
     def prepare_draft_tokens(
         self,
         scheduled_requests: ScheduledRequests,
         resource_manager: Optional[ResourceManager] = None,
     ) -> None:
         for request in sorted(
                 scheduled_requests.context_requests,
                 key=lambda r:
             (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
         ):
             request.py_max_new_tokens = 1
             self._process_request(request, self.spec_resource_manager)
             if self._iter % self._write_interval == 0:
                 self._write_to_file()
             self._iter += 1
-            # Pad length to `self.max_draft_len`
-            if len(draft_tokens) > 0:
-                pad_length = self.max_draft_len - len(draft_tokens)
-                draft_tokens.extend([0] * pad_length)
-            request.py_draft_tokens = draft_tokens
+            # Always provide a placeholder of length `self.max_draft_len`
+            draft_tokens: list[int] = [0] * self.max_draft_len
+            request.py_draft_tokens = draft_tokens
+        # Final flush for leftovers not aligned to write_interval
+        if (self._iter % self._write_interval) != 0:
+            self._write_to_file()
♻️ Duplicate comments (2)
tensorrt_llm/_torch/speculative/save_hidden_state.py (2)

66-97: Wrong MPI rank for accumulation; captured data will never be written on rank 0.

_saved_state is appended only when local_mpi_rank() != 0, but _write_to_file writes only on rank 0. On single-GPU (rank 0), nothing is captured; on multi-GPU, rank 0 writes empty shards. Flip the condition and simplify the body.

-        out_dict = {}
-        if local_mpi_rank() != 0:
+        if local_mpi_rank() == 0:
             input_ids = torch.tensor(list(request.get_tokens(0)),
                                      dtype=torch.long,
                                      device='cpu')
             hidden_size = resource_manager.hidden_size
             num_tokens = input_ids.shape[0]
             if self.spec_config.save_last_layer_post_norm:
                 hidden_states = resource_manager.last_hidden_states[:num_tokens, :].cpu().clone()
             else:
                 hidden_states = resource_manager.hidden_states[:num_tokens,
                                                                -hidden_size:].cpu(
                                                                ).clone()
 
             out_dict = {
                 "id":
                 self._iter,
                 "input_ids":
                 input_ids,
                 "hidden_state_features":
                 resource_manager.hidden_states[:num_tokens, :].cpu().clone(),
                 "hidden_state":
                 hidden_states,
             }
 
             self._saved_state.append(out_dict)

98-104: Write only when there is data; avoid emitting empty shards.

As written, rank 0 writes even if _saved_state is empty (and on iteration 0). Gate on presence of data and clear after writing.

-    def _write_to_file(self) -> None:
-        if local_mpi_rank() == 0 and self._iter != self._start_iter:
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0 and self._saved_state:
             output_path = os.path.join(self._output_directory,
                                        f"{self._file_prefix}_{self._iter}.pt")
             torch.save(self._saved_state, output_path)
-        self._saved_state = []
+            self._saved_state = []
🧹 Nitpick comments (4)
tensorrt_llm/llmapi/llm_args.py (2)

968-978: TypeAlias update includes SaveHiddenStatesDecodingConfig — LGTM.

Union now exposes the new config in the public API surface. Keep the line within length limits if the linter flags it in CI.


1751-1761: PyTorch-only gating and runtime knobs — good; add a quick guardrail.

The branch correctly forces PyTorch, BS=1, disables overlap scheduler, turns off CUDA graphs, selects SAVE_HIDDEN_STATES, and pins max_draft_len=1. Suggest also guarding against conflicting user inputs to avoid silent surprises.

             elif isinstance(self.speculative_config,
                             SaveHiddenStatesDecodingConfig):
                 assert self.backend in ['pytorch']
+                if self.max_batch_size not in (None, 1):
+                    logger.warning("Overriding max_batch_size to 1 for SAVE_HIDDEN_STATES")
                 self.build_config.max_batch_size = 1
                 self.max_batch_size = 1
                 self.disable_overlap_scheduler = True
                 self.cuda_graph_config = None
                 self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
                 self.build_config.max_draft_len = 1
                 self.speculative_config.max_draft_len = 1

If helpful, I can add a small unit/config test ensuring these overrides are applied when SaveHiddenStatesDecodingConfig is used.

tensorrt_llm/_torch/speculative/save_hidden_state.py (2)

20-29: Optional: Add concise docstrings to public methods.

Short docstrings for is_final_output_capture and maybe_capture_final_hidden_states help downstream users understand when the capture hook is active and what gets copied.


34-45: Allocate last_hidden_states only when needed — LGTM.

The conditional allocation matches the flag. Consider documenting the shape contract (max_num_tokens x hidden_size) in a comment.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 8a19b2e and 3f40394.

📒 Files selected for processing (8)
  • tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
  • tensorrt_llm/_torch/speculative/__init__.py (2 hunks)
  • tensorrt_llm/_torch/speculative/interface.py (5 hunks)
  • tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
  • tensorrt_llm/_torch/speculative/utils.py (4 hunks)
  • tensorrt_llm/llmapi/__init__.py (2 hunks)
  • tensorrt_llm/llmapi/llm_args.py (4 hunks)
  • tensorrt_llm/models/modeling_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
  • tensorrt_llm/models/modeling_utils.py
  • tensorrt_llm/llmapi/__init__.py
  • tensorrt_llm/_torch/speculative/__init__.py
  • tensorrt_llm/_torch/models/modeling_speculative.py
  • tensorrt_llm/_torch/speculative/interface.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/utils.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
  • tensorrt_llm/llmapi/llm_args.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/speculative/utils.py (4)
tensorrt_llm/_torch/speculative/save_hidden_state.py (3)
  • SaveHiddenStatesDrafter (47-124)
  • SaveHiddenStatesResourceManager (32-44)
  • SaveHiddenStatesSpecMetadata (17-29)
tensorrt_llm/llmapi/llm_args.py (4)
  • spec_dec_mode (396-403)
  • spec_dec_mode (442-447)
  • spec_dec_mode (480-483)
  • spec_dec_mode (586-591)
tensorrt_llm/_torch/speculative/interface.py (1)
  • is_save_hidden_states (54-55)
tensorrt_llm/_torch/model_config.py (1)
  • torch_dtype (151-156)
tensorrt_llm/_torch/speculative/save_hidden_state.py (6)
tensorrt_llm/_utils.py (1)
  • local_mpi_rank (502-503)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  • LlmRequest (282-422)
tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  • ScheduledRequests (18-39)
tensorrt_llm/_torch/speculative/drafter.py (1)
  • Drafter (9-38)
tensorrt_llm/_torch/speculative/eagle3.py (2)
  • Eagle3ResourceManager (18-77)
  • Eagle3SpecMetadata (81-186)
tensorrt_llm/_torch/speculative/interface.py (2)
  • is_final_output_capture (199-204)
  • maybe_capture_final_hidden_states (214-219)
tensorrt_llm/llmapi/llm_args.py (2)
tensorrt_llm/models/modeling_utils.py (1)
  • SpeculativeDecodingMode (91-127)
tensorrt_llm/_torch/speculative/interface.py (1)
  • SpeculativeDecodingMode (12-116)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py

34-34: Undefined name SaveHiddenStatesDecodingConfig

(F821)


51-51: Undefined name SaveHiddenStatesDecodingConfig

(F821)


121-121: Undefined name draft_tokens

(F821)


122-122: Undefined name draft_tokens

(F821)


123-123: Undefined name draft_tokens

(F821)


124-124: Undefined name draft_tokens

(F821)

tensorrt_llm/llmapi/llm_args.py

478-481: 1 blank line required between summary line and description

(D205)


976-976: Line too long (194 > 120)

(E501)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/llmapi/llm_args.py (1)

368-371: Dispatch mapping for "SaveState" looks correct.

Adding "SaveState": SaveHiddenStatesDecodingConfig to the DecodingBaseConfig.from_dict map is consistent with the new class’s decoding_type and fixes prior duplicate-key issues. No changes needed.

tensorrt_llm/_torch/speculative/utils.py (4)

14-16: Completes the SAVE_HIDDEN_STATES wiring by importing the drafter/metadata/manager.

Imports are correct and keep module boundaries clean.


56-69: Spec metadata construction looks consistent with Eagle3 and the new mode.

Passing layers_to_capture and save_last_layer_post_norm into SaveHiddenStatesSpecMetadata aligns with the subclass’s API. No issues spotted.


117-125: Resource manager uses model dtype instead of draft dtype — good catch.

Using model_engine.model.config.torch_dtype avoids None dereferences when no draft engine exists for this mode.


173-175: Drafter construction matches current constructor signature.

Call site aligns with SaveHiddenStatesDrafter(spec_config, spec_resource_manager).

@IzzyPutterman IzzyPutterman requested a review from a team as a code owner August 25, 2025 00:20
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

♻️ Duplicate comments (3)
tensorrt_llm/_torch/speculative/save_hidden_state.py (3)

97-103: Write only when there’s buffered data.

Avoid emitting empty files and clear the buffer only after a successful write.

-    def _write_to_file(self) -> None:
-        if local_mpi_rank() == 0:
-            output_path = os.path.join(self._output_directory,
-                                       f"{self._file_prefix}_{self._iter}.pt")
-            torch.save(self._saved_state, output_path)
-        self._saved_state = []
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0 and self._saved_state:
+            output_path = os.path.join(
+                self._output_directory, f"{self._file_prefix}_{self._iter}.pt"
+            )
+            torch.save(self._saved_state, output_path)
+            self._saved_state = []
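For downstream consumers, a small loading sketch (the directory is a placeholder; file naming follows the prefix/iteration scheme above, and the feature width of hidden_state_features depends on which layers were captured):

    import glob
    import os

    import torch

    output_directory = "/tmp/hidden_states"  # placeholder
    for path in sorted(glob.glob(os.path.join(output_directory, "data_*.pt"))):
        records = torch.load(path)  # list of per-request dicts
        for rec in records:
            input_ids = rec["input_ids"]             # (num_tokens,) int64, CPU
            features = rec["hidden_state_features"]  # (num_tokens, captured feature width)
            last = rec["hidden_state"]               # (num_tokens, hidden_size)
            print(rec["id"], input_ids.shape, features.shape, last.shape)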

119-136: Flush leftovers not aligned to write_interval.

If the loop ends without hitting the write boundary, the last batch remains in memory. Add a final flush.

         for request in sorted(
                 scheduled_requests.context_requests,
                 key=lambda r:
             (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
         ):
             if is_warmup:
                 continue
             self._process_request(request, self.spec_resource_manager)
             if self._iter % self._write_interval == 0:
                 self._write_to_file()
             self._iter += 1
+        # Final flush for leftovers
+        if (self._iter % self._write_interval) != 0:
+            self._write_to_file()

1-7: Add NVIDIA header and resolve F821 for forward-referenced config.

Source files need the NVIDIA copyright header. Also, Ruff flags SaveHiddenStatesDecodingConfig as undefined; import it under TYPE_CHECKING (or enable postponed annotations).

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+from __future__ import annotations
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
@@
 from tensorrt_llm._utils import local_mpi_rank
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import SaveHiddenStatesDecodingConfig
🧹 Nitpick comments (4)
tensorrt_llm/_torch/speculative/model_drafter.py (1)

432-448: Consider guarding against forgotten overrides when post-forward is required.

Right now, the default prepare_draft_tokens_post silently does nothing. If a subclass returns True from needs_draft_forward_post() but forgets to override this method, the flow will be silently broken. Optionally add a defensive check.

 @nvtx_range("prepare_draft_tokens_post")
 def prepare_draft_tokens_post(
     self,
     scheduled_requests: ScheduledRequests,
     resource_manager: Optional[ResourceManager] = None,
     is_warmup: bool = False,
 ) -> None:
-    """
+    """
     If draft forward needs to be run directly after the target model forward,
     this method can be overridden to do that.
     Used in SaveHiddenStatesDrafter (to ensure correct input_ids)
 
     Args:
         scheduled_requests: The scheduled requests for this iteration
         resource_manager: The resource manager for this iteration
     """
+    if self.needs_draft_forward_post():
+        raise NotImplementedError(
+            f"{type(self).__name__} needs post-forward drafting but did not override prepare_draft_tokens_post()"
+        )
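For illustration, a hypothetical subclass under this contract (names mirror the drafter API above; this is a sketch, not the shipped drafter):

    class PostForwardDrafter(ModelDrafter):
        """Hypothetical drafter that runs its draft step after the target forward."""

        def needs_draft_forward_post(self) -> bool:
            return True

        def prepare_draft_tokens_post(self,
                                      scheduled_requests,
                                      resource_manager=None,
                                      is_warmup=False) -> None:
            # Must be overridden when needs_draft_forward_post() is True;
            # otherwise the base-class guard above raises NotImplementedError.
            ...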
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

920-921: Compute-and-store the post-forward flag per iteration (good), but initialize it in __init__ to avoid accidental use before it is set.

The assignment here is correct. Initializing the flag in __init__ helps future-proof against code paths that might read it before _prepare_and_schedule_batch() runs.

         self.sampler = sampler
         self.drafter = drafter
+        self.draft_forward_post_needed = False
tensorrt_llm/_torch/speculative/save_hidden_state.py (2)

49-64: Minor: start iteration counter at 0 to match typical file-indexing; non-blocking.

Current _iter = 1 is fine but unconventional. Starting at 0 keeps the file indices aligned with write boundaries; note that with _iter = 0 the _iter % _write_interval == 0 check fires on the very first request, which is one more reason to gate _write_to_file on a non-empty buffer (see above).

-        self._iter = 1
+        self._iter = 0

65-96: Tighten rank-0 block and avoid unnecessary locals; also prefer as_tensor.

out_dict is only used on rank 0; move it inside the rank-0 block and use torch.as_tensor to avoid an intermediate Python list allocation.

-        out_dict = {}
-        if local_mpi_rank() == 0:
-            input_ids = torch.tensor(list(request.get_tokens(0)),
-                                     dtype=torch.long,
-                                     device='cpu')
+        if local_mpi_rank() == 0:
+            input_ids = torch.as_tensor(
+                request.get_tokens(0), dtype=torch.long, device="cpu"
+            )
@@
             out_dict = {
                 "id":
                 self._iter,
                 "input_ids":
                 input_ids,
                 "hidden_state_features":
                 resource_manager.hidden_states[:num_tokens, :].cpu().clone(),
                 "hidden_state":
                 hidden_states,
             }
 
             self._saved_state.append(out_dict)
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 3f40394 and 4a8a0af.

📒 Files selected for processing (4)
  • tensorrt_llm/_torch/pyexecutor/py_executor.py (2 hunks)
  • tensorrt_llm/_torch/speculative/drafter.py (1 hunks)
  • tensorrt_llm/_torch/speculative/model_drafter.py (1 hunks)
  • tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else

Files:

  • tensorrt_llm/_torch/speculative/drafter.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA copyright header (current year) to all source files

Files:

  • tensorrt_llm/_torch/speculative/drafter.py
  • tensorrt_llm/_torch/speculative/model_drafter.py
  • tensorrt_llm/_torch/pyexecutor/py_executor.py
  • tensorrt_llm/_torch/speculative/save_hidden_state.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.

Applied to files:

  • tensorrt_llm/_torch/speculative/save_hidden_state.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/speculative/drafter.py (1)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
  • needs_draft_forward_post (137-143)
tensorrt_llm/_torch/speculative/model_drafter.py (6)
tensorrt_llm/_utils.py (1)
  • nvtx_range (843-862)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
  • prepare_draft_tokens_post (119-135)
tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  • ScheduledRequests (18-39)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
  • ResourceManager (966-1001)
tensorrt_llm/_torch/pyexecutor/py_executor.py (2)
  • is_warmup (289-290)
  • is_warmup (293-298)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
  • is_warmup (494-495)
  • is_warmup (498-501)
tensorrt_llm/_torch/pyexecutor/py_executor.py (4)
tensorrt_llm/_torch/speculative/drafter.py (1)
  • needs_draft_forward_post (40-46)
tensorrt_llm/_torch/speculative/save_hidden_state.py (2)
  • needs_draft_forward_post (137-143)
  • prepare_draft_tokens_post (119-135)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
  • prepare_draft_tokens_post (433-447)
tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
  • is_warmup (494-495)
  • is_warmup (498-501)
tensorrt_llm/_torch/speculative/save_hidden_state.py (7)
tensorrt_llm/_utils.py (1)
  • local_mpi_rank (502-503)
tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  • LlmRequest (282-422)
tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
  • ResourceManager (966-1001)
tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  • ScheduledRequests (18-39)
tensorrt_llm/_torch/speculative/drafter.py (3)
  • Drafter (9-46)
  • prepare_draft_tokens (16-27)
  • needs_draft_forward_post (40-46)
tensorrt_llm/_torch/speculative/eagle3.py (2)
  • Eagle3ResourceManager (18-77)
  • Eagle3SpecMetadata (81-186)
tensorrt_llm/_torch/speculative/interface.py (2)
  • is_final_output_capture (199-204)
  • maybe_capture_final_hidden_states (214-219)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py

34-34: Undefined name SaveHiddenStatesDecodingConfig

(F821)


51-51: Undefined name SaveHiddenStatesDecodingConfig

(F821)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/speculative/drafter.py (1)

40-46: Good additive hook; default is safe.

Adding a base no-op switch for post-forward drafting is reasonable and keeps existing drafters unaffected.

tensorrt_llm/_torch/pyexecutor/py_executor.py (1)

1007-1010: Verify ordering: post-forward hook runs after model forward and sampling but before request-state updates.

This looks correct for SaveHiddenStates (final hidden states are captured during forward, then processed here before _update_request_states). Please confirm this ordering matches the intended contract for all drafters that may use the hook.

tensorrt_llm/_torch/speculative/save_hidden_state.py (3)

16-30: Spec metadata: final-hidden-states capture path looks correct.

Gates on save_last_layer_post_norm and slices by active tokens to avoid shape mismatches. LGTM.


34-45: Resource manager: allocate buffer only when needed; aligns with metadata usage.

The shape (max_num_tokens, hidden_size) is consistent with per-token last-layer captures. LGTM.


104-118: Pre-forward drafting placeholders look fine.

Setting py_max_new_tokens = 1 and providing zeroed draft tokens of fixed length keeps downstream shape assumptions intact.

@juney-nvidia
Copy link
Collaborator

@IzzyPutterman Hi Izzy, I noticed that this PR has been pending for 3 weeks. Is there a rough timeline for when it will be ready to merge to main? :)

June

@IzzyPutterman
Copy link
Collaborator Author

@IzzyPutterman Hi Izzy, I noticed that this PR has been pending for 3 weeks. Is there a rough timeline for when it will be ready to merge to main? :)

June

I will have this PR ready before the free days this week. I need to make some small changes to the logic, address Michael's comments, and add a test, I think.

@IzzyPutterman IzzyPutterman force-pushed the iputterman/savestate-config branch 3 times, most recently from 55f5575 to fd8a03c Compare September 15, 2025 23:22
@IzzyPutterman IzzyPutterman force-pushed the iputterman/savestate-config branch 2 times, most recently from 77a4eb3 to 7d7ad63 Compare September 18, 2025 01:56
@IzzyPutterman IzzyPutterman requested a review from a team as a code owner September 18, 2025 01:56
@tensorrt-cicd
Copy link
Collaborator

PR_Github #20001 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20001 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15063 completed with status: 'FAILURE'

@IzzyPutterman
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20027 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20027 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15085 completed with status: 'FAILURE'

@IzzyPutterman
Copy link
Collaborator Author

/bot run

Signed-off-by: Izzy Putterman <[email protected]>

small changes

Signed-off-by: Izzy Putterman <[email protected]>

fixes inflight

Signed-off-by: Izzy Putterman <[email protected]>

maybe functional

Signed-off-by: Izzy Putterman <[email protected]>

cleanup + small design changes

Signed-off-by: Izzy Putterman <[email protected]>

drop unused lines

Signed-off-by: Izzy Putterman <[email protected]>
Signed-off-by: Izzy Putterman <[email protected]>
@IzzyPutterman IzzyPutterman force-pushed the iputterman/savestate-config branch from 206978a to fc3112f Compare September 26, 2025 19:27
@IzzyPutterman
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20118 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20118 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15161 completed with status: 'FAILURE'

@IzzyPutterman
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20122 [ run ] triggered by Bot

Signed-off-by: Izzy Putterman <[email protected]>
@IzzyPutterman IzzyPutterman force-pushed the iputterman/savestate-config branch from fc3112f to 6b6b73b Compare September 26, 2025 22:38
@IzzyPutterman
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20125 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20122 [ run ] completed with state ABORTED
LLM/main/L0_MergeRequest_PR #15165 (Blue Ocean) completed with status: ABORTED

@tensorrt-cicd
Copy link
Collaborator

PR_Github #20125 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #15168 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@IzzyPutterman
Copy link
Collaborator Author

Bump on review

Copy link
Member

@lucaslie lucaslie left a comment


api changes look good 👍

@mikeiovine mikeiovine merged commit 1ad7bc4 into NVIDIA:main Oct 1, 2025
5 checks passed
@juney-nvidia juney-nvidia changed the title from "[None][feat] Draft: Save state first pass" to "[None][feat] Save state first pass" Oct 2, 2025
@juney-nvidia
Copy link
Collaborator

@IzzyPutterman Next time, when we merge a PR to the main branch, we can remove the "Draft" annotation :)
I have already removed it, but the word "Draft" still exists in the history.

June

@IzzyPutterman
Copy link
Collaborator Author

@IzzyPutterman Next time, when we merge a PR to the main branch, we can remove the "Draft" annotation :) I have already removed it, but the word "Draft" still exists in the history.

June

Yep, sorry about that. Thanks for the catch!

faradawn pushed a commit to faradawn/TensorRT-LLM that referenced this pull request Oct 2, 2025
Signed-off-by: Izzy Putterman <[email protected]>
Signed-off-by: Faradawn Yang <[email protected]>
evezhier pushed a commit to evezhier/TensorRT-LLM that referenced this pull request Oct 3, 2025
faradawn pushed a commit to faradawn/TensorRT-LLM that referenced this pull request Oct 3, 2025
Signed-off-by: Izzy Putterman <[email protected]>
Signed-off-by: Faradawn Yang <[email protected]>