[None][feat] Save state first pass #7012
Conversation
📝 Walkthrough
Adds a SAVE_HIDDEN_STATES speculative decoding mode with config, SpecMetadata/resource-manager/drafter implementations to persist final hidden states, wires a conditional final-hidden-state capture hook into the model forward, and updates utilities and public exports to construct and validate the new mode.
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Client
participant Config as LLMArgs/Config
participant Utils
participant ResourceMgr
participant Model
participant SpecMeta
participant Drafter
participant FS as Filesystem
Client->>Config: submit request with SaveState config
Config->>Utils: get_spec_metadata(), get_spec_resource_manager(), get_spec_drafter()
Utils->>ResourceMgr: construct SaveHiddenStatesResourceManager
Client->>Model: forward(input_ids, spec_metadata)
Model->>SpecMeta: is_final_output_capture()?
alt final capture requested
SpecMeta->>SpecMeta: maybe_capture_final_hidden_states(hidden_states)
SpecMeta-->>ResourceMgr: store last_hidden_states (if allocated)
end
Model-->>Client: logits/draft output
Client->>Drafter: prepare_draft_tokens(requests, resource_manager)
Drafter->>Drafter: collect per-request saved states
alt root rank & interval reached
Drafter->>FS: write .pt file(s)
end
Drafter-->>Client: draft tokens
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
TODO: actually test, handle initial warmup configs (dummy data for sizing shouldn't be saved), handle chunking, make tests.
Actionable comments posted: 13
🔭 Outside diff range comments (2)
tensorrt_llm/_torch/speculative/utils.py (1)
Lines 167-171: Add drafter branch for SAVE_HIDDEN_STATES to match has_spec_drafter().
Apply this diff:
 if spec_config.spec_dec_mode.is_ngram():
     return NGramDrafter(spec_config, spec_resource_manager)
+if spec_config.spec_dec_mode.is_save_hidden_states():
+    return SaveHiddenStatesDrafter(spec_config)
+
 return None

tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
Lines 92-112: Fix call site, draft tokens, and write cadence.
- Pass resource_manager into _process_request.
- Define draft_tokens deterministically to max_draft_len.
- Increment _iter before the write modulo check to avoid writing at iteration 0.
Apply this diff:
 def prepare_draft_tokens(
     self,
     scheduled_requests: ScheduledRequests,
     resource_manager: Optional[ResourceManager] = None,
 ) -> None:
     for request in sorted(
             scheduled_requests.context_requests,
             key=lambda r:
         (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
     ):
         request.py_max_new_tokens = 1
-        self._process_request(request, resource_manager)
-        if self._iter % self._write_interval == 0:
-            self._write_to_file()
-        self._iter += 1
-        # Pad length to `self.max_draft_len`
-        if len(draft_tokens) > 0:
-            pad_length = self.max_draft_len - len(draft_tokens)
-            draft_tokens.extend([0] * pad_length)
-        request.py_draft_tokens = draft_tokens
+        self._process_request(request, resource_manager)
+        self._iter += 1
+        if self._iter % self._write_interval == 0:
+            self._write_to_file()
+        # Always provide a fixed-length draft token list
+        draft_tokens = [0] * self.max_draft_len
+        request.py_draft_tokens = draft_tokens
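To see why the review wants the counter incremented before the modulo check, here is a minimal standalone sketch; the function name, step count, and interval are illustrative only and not part of the PR:

def writes(increment_first: bool, steps: int = 7, interval: int = 3):
    # Returns the iteration numbers at which a write would happen.
    it = 0
    written = []
    for _ in range(steps):
        if increment_first:
            it += 1
            if it % interval == 0:
                written.append(it)
        else:
            if it % interval == 0:
                written.append(it)
            it += 1
    return written

print(writes(increment_first=False))  # [0, 3, 6] -> an empty write at iteration 0
print(writes(increment_first=True))   # [3, 6]    -> first write only after data exists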
🧹 Nitpick comments (4)
tensorrt_llm/_torch/speculative/interface.py (1)
Lines 192-209: No-op hooks added appropriately; minor doc/typing polish suggested.
The hooks are fine as defaults. Consider adding a return type for is_final_output_capture -> bool for consistency.
-    def is_final_output_capture(self):
+    def is_final_output_capture(self) -> bool:
         return False

tensorrt_llm/_torch/speculative/save_hidden_state.py (2)
Lines 12-18: Add return typing; keep the API explicit.
-    def is_final_output_capture(self):
+    def is_final_output_capture(self) -> bool:
         return self.save_last_layer_post_norm
Lines 29-40: Resource manager init OK; ensure torch is imported (fixed above).
Also consider documenting the last_hidden_states semantics in a short docstring.
tensorrt_llm/llmapi/llm_args.py (1)
Lines 449-472: New config LGTM; consider ensuring the output directory exists in validate().
Apply this diff:
 def validate(self) -> None:
-    if self.output_directory is None or not self.eagle3_layers_to_capture:
+    if self.output_directory is None or not self.eagle3_layers_to_capture:
         raise ValueError(
             "Save directory and layers to capture must be provided")
+    os.makedirs(self.output_directory, exist_ok=True)
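For context, a sketch of how this config might be constructed and validated from user code; the field values are illustrative and the import path assumes the public export added later in this PR:

from tensorrt_llm.llmapi import SaveHiddenStatesDecodingConfig

cfg = SaveHiddenStatesDecodingConfig(
    output_directory="/tmp/hidden_states",  # hypothetical path
    write_interval=20,                      # flush a .pt shard every 20 processed requests
    file_prefix="data",                     # shards named data_<iter>.pt
    eagle3_layers_to_capture={1, 15, 30},   # illustrative layer indices
)
cfg.validate()  # raises ValueError if the directory or layer set is missing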
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
tensorrt_llm/_torch/speculative/interface.py (5 hunks)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
tensorrt_llm/_torch/speculative/utils.py (3 hunks)
tensorrt_llm/llmapi/llm_args.py (5 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/models/modeling_speculative.py
tensorrt_llm/_torch/speculative/utils.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/models/modeling_speculative.py
tensorrt_llm/_torch/speculative/utils.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/llmapi/llm_args.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py
12-12: Undefined name dataclass
(F821)
20-20: Undefined name torch
(F821)
29-29: Undefined name SaveHiddenStatesDecodingConfig
(F821)
30-30: Undefined name torch
(F821)
36-36: Undefined name torch
(F821)
46-46: Undefined name SaveHiddenStatesDecodingConfig
(F821)
60-60: Undefined name torch
(F821)
61-61: Undefined name torch
(F821)
63-63: Undefined name resource_manager
(F821)
65-65: Undefined name resource_manager
(F821)
68-68: Undefined name resource_manager
(F821)
78-78: Undefined name resource_manager
(F821)
87-87: Undefined name os
(F821)
89-89: Undefined name torch
(F821)
108-108: Undefined name draft_tokens
(F821)
109-109: Undefined name draft_tokens
(F821)
110-110: Undefined name draft_tokens
(F821)
111-111: Undefined name draft_tokens
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/speculative/interface.py (1)
Lines 54-56: Predicate addition LGTM.

tensorrt_llm/_torch/models/modeling_speculative.py (1)
Lines 421-424: Hook to capture final hidden states LGTM. Gated behind metadata and a predicate, with no overhead when unused. Correct placement before draft-branching.

tensorrt_llm/llmapi/llm_args.py (1)
Lines 12-13: Import of Set is fine. No issues.
16b9e78 to dc2eed8 (compare)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/speculative/utils.py (1)
Lines 144-170: Add get_spec_drafter branch for SAVE_HIDDEN_STATES.
Without this branch, the mode won't construct a drafter and will silently return None.
Apply this diff:
 if spec_config.spec_dec_mode.is_ngram():
     return NGramDrafter(spec_config, spec_resource_manager)
+if spec_config.spec_dec_mode.is_save_hidden_states():
+    return SaveHiddenStatesDrafter(spec_config)
+
 return None
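The failure mode described here (silently returning None) is easy to reproduce with a stripped-down dispatch; the snippet below is a generic illustration, not TensorRT-LLM code:

from typing import Optional

class Drafter: ...
class NGramDrafter(Drafter): ...
class SaveHiddenStatesDrafter(Drafter): ...

def get_spec_drafter(mode: str) -> Optional[Drafter]:
    # Forgetting a branch for a new mode hands callers None, and the error
    # only surfaces later when the drafter is actually used.
    if mode == "ngram":
        return NGramDrafter()
    if mode == "save_hidden_states":  # the branch the review asks to add
        return SaveHiddenStatesDrafter()
    return None

assert get_spec_drafter("save_hidden_states") is not None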
♻️ Duplicate comments (10)
tensorrt_llm/_torch/speculative/utils.py (3)
Lines 14-15: Import the drafter to complete the SAVE_HIDDEN_STATES wiring.
The new mode defines a Drafter, but it isn’t imported here, which prevents get_spec_drafter from constructing it.
Apply this diff:
-from .save_hidden_state import (SaveHiddenStatesResourceManager,
-                                SaveHiddenStatesSpecMetadata)
+from .save_hidden_state import (SaveHiddenStatesResourceManager,
+                                SaveHiddenStatesSpecMetadata,
+                                SaveHiddenStatesDrafter)
Lines 53-66: Fix incorrect config attribute: use eagle3_layers_to_capture to derive capture count.
spec_config has eagle3_layers_to_capture, not num_capture_layers. Referencing spec_config.num_capture_layers will raise at runtime.
Apply this diff:
 return SaveHiddenStatesSpecMetadata(
     max_draft_len=spec_config.max_draft_len,
     spec_dec_mode=spec_config.spec_dec_mode,
     max_num_requests=max_num_requests,
     num_layers=model_config.num_hidden_layers,
     hidden_size=model_config.hidden_size,
     max_num_tokens=max_num_tokens,
     dtype=model_config.torch_dtype,
     is_draft_model=is_draft_model,
     eagle3_resource_manager=spec_resource_manager,
-    num_capture_layers=spec_config.num_capture_layers,
+    num_capture_layers=(
+        len(spec_config.eagle3_layers_to_capture)
+        if spec_config.eagle3_layers_to_capture is not None
+        else Eagle3SpecMetadata.num_capture_layers
+    ),
     save_last_layer_post_norm=spec_config.save_last_layer_post_norm,
 )
Lines 114-122: Avoid possible None dereference: use model dtype instead of draft model dtype.
draft_model_engine can be None for this mode; using draft_model_engine.model.config.torch_dtype risks an AttributeError. Use model_config.torch_dtype.
Apply this diff:
 if spec_dec_mode.is_save_hidden_states():
     return SaveHiddenStatesResourceManager(
         spec_config,
-        draft_model_engine.model.config.torch_dtype,
+        model_config.torch_dtype,
         model_config.hidden_size,
         max_num_requests,
         max_seq_len,
         max_num_tokens,
     )

tensorrt_llm/_torch/speculative/save_hidden_state.py (5)
Lines 1-11: Add required header and missing imports; fix forward refs to avoid import-time NameErrors.
dataclass/torch/os are missing; type SaveHiddenStatesDecodingConfig is referenced but not imported or quoted. Use TYPE_CHECKING and string annotations to avoid cycles.
Apply this diff:
-from typing import Optional
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+from dataclasses import dataclass
+import os
+from typing import Optional, TYPE_CHECKING
+import torch
@@
-from tensorrt_llm._utils import local_mpi_rank
+from tensorrt_llm._utils import local_mpi_rank
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import SaveHiddenStatesDecodingConfig
Lines 19-25: Slice copy to avoid shape mismatch when capturing final hidden states.
The preallocated buffer is sized to max_num_tokens; hidden_states holds only the active tokens. Copy only the active rows.
Apply this diff:
 def maybe_capture_final_hidden_states(self,
                                       hidden_states: torch.Tensor) -> None:
     if self.save_last_layer_post_norm:
-        # Assume no chunking, BS=1
-        eagle3_hidden_states = self.eagle3_resource_manager.last_hidden_states
-        eagle3_hidden_states.copy_(hidden_states)
+        # Assume no chunking, BS=1; copy only active tokens
+        buf = self.eagle3_resource_manager.last_hidden_states
+        n = hidden_states.shape[0]
+        buf[:n].copy_(hidden_states)
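For readers unfamiliar with copy_ semantics, a small standalone illustration of the mismatch; the tensor sizes are made up:

import torch

buf = torch.empty(8192, 4096)        # preallocated to max_num_tokens x hidden_size
hidden = torch.randn(37, 4096)       # only the active tokens of this step

# buf.copy_(hidden) would raise a size-mismatch RuntimeError because the
# source cannot be broadcast to the full buffer shape.
buf[:hidden.shape[0]].copy_(hidden)  # copying only the active rows succeeds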
Lines 44-56: Ensure output directory exists; fix type annotation to avoid runtime import.
Create the output directory up-front. Quote the type annotation to avoid importing at module import time.
Apply this diff:
 def __init__(
     self,
-    spec_config: SaveHiddenStatesDecodingConfig,
+    spec_config: "SaveHiddenStatesDecodingConfig",
 ):
     super().__init__(spec_config.max_concurrency)
     self.spec_config = spec_config
     self.max_draft_len = spec_config.max_draft_len
     self._iter = 0
     self._output_directory = spec_config.output_directory
     self._file_prefix = spec_config.file_prefix
     self._write_interval = spec_config.write_interval
     self._saved_state = []
+    os.makedirs(self._output_directory, exist_ok=True)
Lines 57-84: Fix signature, rank logic, undefined names, and iteration field.
- _process_request must accept resource_manager.
- Accumulate on root (rank 0) to match write-on-root.
- Use self._iter.
- Don’t reference undefined variables.
Apply this diff:
-    def _process_request(self, request: LlmRequest) -> None:
-        out_dict = {}
-        if local_mpi_rank() != 0:
-            input_ids = torch.tensor(list(request.get_tokens(0)),
-                                     dtype=torch.long,
-                                     device='cpu')
-            hidden_size = resource_manager.hidden_size
-            if self.spec_config.save_last_layer_post_norm:
-                hidden_states = resource_manager.last_hidden_states.cpu().clone()
-            else:
-                hidden_states = resource_manager.hidden_states[:, -hidden_size:].cpu().clone()
-
-            out_dict = {
-                "id": self.iteration,
-                "input_ids": input_ids,
-                "hidden_state_features": resource_manager.hidden_states.cpu().clone(),
-                "hidden_state": hidden_states,
-            }
-
-        self._saved_state.append(out_dict)
+    def _process_request(self, request: LlmRequest,
+                         resource_manager: ResourceManager) -> None:
+        if local_mpi_rank() == 0:
+            input_ids = torch.tensor(
+                list(request.get_tokens(0)), dtype=torch.long, device='cpu')
+            hidden_size = resource_manager.hidden_size
+            if getattr(self.spec_config, "save_last_layer_post_norm", False):
+                hidden_state = resource_manager.last_hidden_states.cpu().clone()
+            else:
+                # Fallback: slice last layer/state from captured features.
+                hidden_state = resource_manager.hidden_states[:, -hidden_size:].cpu().clone()
+            out_dict = {
+                "id": self._iter,
+                "input_ids": input_ids,
+                "hidden_state_features": resource_manager.hidden_states.cpu().clone(),
+                "hidden_state": hidden_state,
+            }
+            self._saved_state.append(out_dict)
Lines 85-91: Write only when there’s data; fix undefined fields.
Use self._iter and check self._saved_state; clear after write.
Apply this diff:
 def _write_to_file(self) -> None:
-    if local_mpi_rank() == 0 and self.iteration != self.start_iteration:
-        output_path = os.path.join(self._output_directory,
-                                   f"{self._file_prefix}_{self._iter}.pt")
-        torch.save(self._saved_state, output_path)
-        self._saved_state = []
+    if local_mpi_rank() == 0 and self._saved_state:
+        output_path = os.path.join(self._output_directory,
+                                   f"{self._file_prefix}_{self._iter}.pt")
+        torch.save(self._saved_state, output_path)
+        self._saved_state = []
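Assuming the record layout shown above, the saved shards could be read back like this; the directory and prefix are the illustrative defaults, and the loop is a sketch rather than a shipped utility:

import glob
import torch

for shard_path in sorted(glob.glob("/tmp/hidden_states/data_*.pt")):
    records = torch.load(shard_path, map_location="cpu")
    # Each record is a dict with the keys written by _process_request.
    for rec in records:
        print(rec["id"], rec["input_ids"].shape, rec["hidden_state"].shape)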
tensorrt_llm/llmapi/llm_args.py (2)
Lines 943-953: Fix broken type alias: use SaveHiddenStatesDecodingConfig.
The singular name is undefined and will break imports.
Apply this diff:
 SpeculativeConfig: TypeAlias = Optional[Union[
     DraftTargetDecodingConfig,
     EagleDecodingConfig,
     LookaheadDecodingConfig,
     MedusaDecodingConfig,
     MTPDecodingConfig,
     NGramDecodingConfig,
     UserProvidedDecodingConfig,
-    SaveHiddenStateDecodingConfig,
+    SaveHiddenStatesDecodingConfig,
     AutoDecodingConfig,
 ]]
Lines 364-371: Remove duplicate "SaveState" dispatch entry.
Duplicate keys lead to confusing behavior and may shadow earlier entries.
Apply this diff:
 config_classes = {
     "MTP": MTPDecodingConfig,
     "Medusa": MedusaDecodingConfig,
     "Eagle": EagleDecodingConfig,
-    "SaveState": SaveHiddenStatesDecodingConfig,
     "Lookahead": LookaheadDecodingConfig,
     "NGram": NGramDecodingConfig,
     "DraftTarget": DraftTargetDecodingConfig,
-    "SaveState": SaveHiddenStatesDecodingConfig,
+    "SaveState": SaveHiddenStatesDecodingConfig,
     "UserProvided": UserProvidedDecodingConfig,
     "AUTO": AutoDecodingConfig,
 }
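For reference, the mapping above means a plain dict (for example parsed from YAML) selects the new class through its decoding type; the key names are assumed from the decoding_type attribute and fields in this PR, and the values are illustrative:

spec_dict = {
    "decoding_type": "SaveState",
    "output_directory": "/tmp/hidden_states",
    "write_interval": 20,
    "file_prefix": "data",
    "eagle3_layers_to_capture": [1, 15, 30],
}
# Sketch of the intended flow: DecodingBaseConfig.from_dict would route this
# through config_classes to SaveHiddenStatesDecodingConfig.from_dict.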
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/utils.py (1)
Line 1: Prepend NVIDIA copyright header.
Per coding guidelines, add the NVIDIA copyright header to all source files.
Apply this diff at the top of the file:
+ # Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
tensorrt_llm/llmapi/llm_args.py (1)
Lines 449-472: New SaveHiddenStatesDecodingConfig: basic wiring looks correct.
Definition, spec_dec_mode override, and validation are coherent. Consider adding a class docstring and field docstrings to match repo standards, but not blocking.
📜 Review details
Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (5)
tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
tensorrt_llm/_torch/speculative/interface.py (5 hunks)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
tensorrt_llm/_torch/speculative/utils.py (3 hunks)
tensorrt_llm/llmapi/llm_args.py (5 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/speculative/interface.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/utils.py
tensorrt_llm/llmapi/llm_args.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/utils.py
tensorrt_llm/llmapi/llm_args.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py
12-12: Undefined name dataclass
(F821)
20-20: Undefined name torch
(F821)
29-29: Undefined name SaveHiddenStatesDecodingConfig
(F821)
30-30: Undefined name torch
(F821)
36-36: Undefined name torch
(F821)
46-46: Undefined name SaveHiddenStatesDecodingConfig
(F821)
60-60: Undefined name torch
(F821)
61-61: Undefined name torch
(F821)
63-63: Undefined name resource_manager
(F821)
65-65: Undefined name resource_manager
(F821)
68-68: Undefined name resource_manager
(F821)
78-78: Undefined name resource_manager
(F821)
87-87: Undefined name os
(F821)
89-89: Undefined name torch
(F821)
108-108: Undefined name draft_tokens
(F821)
109-109: Undefined name draft_tokens
(F821)
110-110: Undefined name draft_tokens
(F821)
111-111: Undefined name draft_tokens
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (2)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
Lines 29-40: Ensure last_hidden_states buffer allocated only when configured; torch dtype import fixed above.
Constructor is fine; annotation for config already addressed via TYPE_CHECKING. No further changes required here.
tensorrt_llm/llmapi/llm_args.py (1)
Line 12: LGTM: typing.Set import is correct.
The added Set import aligns with new config usage.
dc2eed8 to 8eb53f3 (compare)
Actionable comments posted: 2
♻️ Duplicate comments (9)
tensorrt_llm/_torch/speculative/utils.py (3)
Lines 14-16: Resolved: Drafter/resource manager/metadata are now imported.
This addresses the previous import omission for the drafter.
Lines 117-125: Avoid possible None dereference; use model dtype, not draft model dtype.
In SAVE_HIDDEN_STATES there may be no draft model. Use model_config.torch_dtype.
Apply this diff:
 if spec_dec_mode.is_save_hidden_states():
     return SaveHiddenStatesResourceManager(
         spec_config,
-        draft_model_engine.model.config.torch_dtype,
+        model_config.torch_dtype,
         model_config.hidden_size,
         max_num_requests,
         max_seq_len,
         max_num_tokens,
     )
Lines 56-69: Fix SaveHiddenStatesSpecMetadata constructor kwargs.
The SaveHiddenStatesSpecMetadata class inherits its initializer from Eagle3SpecMetadata, which defines a field named layers_to_capture — there is no num_capture_layers parameter. Passing num_capture_layers will cause a runtime error.
• Location to update:
- File: tensorrt_llm/_torch/speculative/utils.py
- In the if spec_config.spec_dec_mode.is_save_hidden_states(): branch (around lines 56-69)
Apply this diff:
 if spec_config.spec_dec_mode.is_save_hidden_states():
     return SaveHiddenStatesSpecMetadata(
         max_draft_len=spec_config.max_draft_len,
         spec_dec_mode=spec_config.spec_dec_mode,
         max_num_requests=max_num_requests,
         num_layers=model_config.num_hidden_layers,
         hidden_size=model_config.hidden_size,
         max_num_tokens=max_num_tokens,
         dtype=model_config.torch_dtype,
         is_draft_model=is_draft_model,
         eagle3_resource_manager=spec_resource_manager,
-        num_capture_layers=spec_config.num_capture_layers,
+        layers_to_capture=spec_config.eagle3_layers_to_capture,
         save_last_layer_post_norm=spec_config.save_last_layer_post_norm,
     )

tensorrt_llm/_torch/speculative/save_hidden_state.py (6)
Lines 1-4: Missing required header and imports cause import-time failures.
The module uses dataclass/torch/os but does not import them, and the file lacks the required NVIDIA copyright header.
Apply this diff:
+# Copyright (c) 2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+from dataclasses import dataclass
+import os
+import torch
 from typing import Optional

 from tensorrt_llm._utils import local_mpi_rank
Lines 16-25: Fix shape mismatch risk when copying final hidden states.
last_hidden_states is allocated to (max_num_tokens, hidden_size), while incoming hidden_states is (#active_tokens, hidden_size). copy_ requires matching shapes.
Apply this diff:
 def maybe_capture_final_hidden_states(self,
                                       hidden_states: torch.Tensor) -> None:
     if self.save_last_layer_post_norm:
-        # Assume no chunking, BS=1
-        eagle3_hidden_states = self.eagle3_resource_manager.last_hidden_states
-        eagle3_hidden_states.copy_(hidden_states)
+        # Assume no chunking, BS=1; copy only active token rows.
+        buf = self.eagle3_resource_manager.last_hidden_states
+        n = hidden_states.shape[0]
+        buf[:n].copy_(hidden_states)
Lines 44-56: Initialize output directory and fix forward reference typing.
Ensure the output directory exists and avoid importing the config type at runtime by using a string annotation.
Apply this diff:
 def __init__(
     self,
-    spec_config: SaveHiddenStatesDecodingConfig,
+    spec_config: "SaveHiddenStatesDecodingConfig",
 ):
     super().__init__(spec_config.max_concurrency)
     self.spec_config = spec_config
     self.max_draft_len = spec_config.max_draft_len
     self._iter = 0
     self._output_directory = spec_config.output_directory
     self._file_prefix = spec_config.file_prefix
     self._write_interval = spec_config.write_interval
     self._saved_state = []
+    os.makedirs(self._output_directory, exist_ok=True)
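The quoted annotation plus TYPE_CHECKING pattern recommended here is standard Python; a generic, self-contained sketch with placeholder module and class names:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by type checkers, so no runtime import (and no cycle).
    from some_package.config import SomeConfig  # placeholder names

class Consumer:
    def __init__(self, cfg: "SomeConfig") -> None:
        self.cfg = cfg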
Lines 57-84: Fix undefined names, wrong rank condition, and signature; use resource_manager param.
- _process_request must accept resource_manager.
- Accumulate on root (rank 0) if only root writes.
- Replace undefined iteration with self._iter.
- Use getattr for optional flag.
Apply this diff:
-    def _process_request(self, request: LlmRequest) -> None:
-        out_dict = {}
-        if local_mpi_rank() != 0:
-            input_ids = torch.tensor(list(request.get_tokens(0)),
-                                     dtype=torch.long,
-                                     device='cpu')
-            hidden_size = resource_manager.hidden_size
-            if self.spec_config.save_last_layer_post_norm:
-                hidden_states = resource_manager.last_hidden_states.cpu().clone()
-            else:
-                hidden_states = resource_manager.hidden_states[:, -hidden_size:].cpu().clone()
-
-            out_dict = {
-                "id": self.iteration,
-                "input_ids": input_ids,
-                "hidden_state_features": resource_manager.hidden_states.cpu().clone(),
-                "hidden_state": hidden_states,
-            }
-
-        self._saved_state.append(out_dict)
+    def _process_request(self, request: LlmRequest,
+                         resource_manager: ResourceManager) -> None:
+        if local_mpi_rank() == 0:
+            input_ids = torch.tensor(
+                list(request.get_tokens(0)), dtype=torch.long, device="cpu")
+            hidden_size = resource_manager.hidden_size
+            if getattr(self.spec_config, "save_last_layer_post_norm", False):
+                hidden_state = resource_manager.last_hidden_states.cpu().clone()
+            else:
+                # Slice the last layer from the features buffer.
+                hidden_state = resource_manager.hidden_states[:, -hidden_size:].cpu().clone()
+            out_dict = {
+                "id": self._iter,
+                "input_ids": input_ids,
+                "hidden_state_features": resource_manager.hidden_states.cpu().clone(),
+                "hidden_state": hidden_state,
+            }
+            self._saved_state.append(out_dict)
Lines 85-91: Write only when there’s data; clear buffer after save.
Prevents empty files and resets the state list in the same rank-0 block.
Apply this diff:
 def _write_to_file(self) -> None:
-    if local_mpi_rank() == 0 and self.iteration != self.start_iteration:
-        output_path = os.path.join(self._output_directory,
-                                   f"{self._file_prefix}_{self._iter}.pt")
-        torch.save(self._saved_state, output_path)
-        self._saved_state = []
+    if local_mpi_rank() == 0 and self._saved_state:
+        output_path = os.path.join(
+            self._output_directory, f"{self._file_prefix}_{self._iter}.pt")
+        torch.save(self._saved_state, output_path)
+        self._saved_state = []
Lines 92-111: Initialize draft_tokens and pass resource_manager; avoid negative padding.
Avoid NameError and ensure padding is non-negative.
Apply this diff:
 for request in sorted(
         scheduled_requests.context_requests,
         key=lambda r:
     (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
 ):
     request.py_max_new_tokens = 1
     self._process_request(request, resource_manager)
     if self._iter % self._write_interval == 0:
         self._write_to_file()
     self._iter += 1
     # Pad length to `self.max_draft_len`
-    if len(draft_tokens) > 0:
-        pad_length = self.max_draft_len - len(draft_tokens)
-        draft_tokens.extend([0] * pad_length)
-    request.py_draft_tokens = draft_tokens
+    draft_tokens: list[int] = []
+    pad_length = max(0, self.max_draft_len - len(draft_tokens))
+    if pad_length:
+        draft_tokens.extend([0] * pad_length)
+    request.py_draft_tokens = draft_tokens
🧹 Nitpick comments (2)
tensorrt_llm/_torch/speculative/__init__.py (1)
Lines 20-21: __all__ updated appropriately; consider consistent ordering.
Entries are exposed correctly. For readability, consider keeping __all__ alphabetized or grouped by feature (optional).
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
Lines 1-111: Add minimal docstrings for new public classes/methods.
These types are now part of the public surface (also exported in __init__.py). Add short Google-style docstrings describing purpose and expected shapes/devices for hidden_state buffers.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (6)
tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
tensorrt_llm/_torch/speculative/__init__.py (2 hunks)
tensorrt_llm/_torch/speculative/interface.py (5 hunks)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
tensorrt_llm/_torch/speculative/utils.py (4 hunks)
tensorrt_llm/llmapi/llm_args.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- tensorrt_llm/llmapi/llm_args.py
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/speculative/interface.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/__init__.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
tensorrt_llm/_torch/speculative/utils.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/__init__.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
tensorrt_llm/_torch/speculative/utils.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py
12-12: Undefined name dataclass
(F821)
20-20: Undefined name torch
(F821)
29-29: Undefined name SaveHiddenStatesDecodingConfig
(F821)
30-30: Undefined name torch
(F821)
36-36: Undefined name torch
(F821)
46-46: Undefined name SaveHiddenStatesDecodingConfig
(F821)
60-60: Undefined name torch
(F821)
61-61: Undefined name torch
(F821)
63-63: Undefined name resource_manager
(F821)
65-65: Undefined name resource_manager
(F821)
68-68: Undefined name resource_manager
(F821)
78-78: Undefined name resource_manager
(F821)
87-87: Undefined name os
(F821)
89-89: Undefined name torch
(F821)
108-108: Undefined name draft_tokens
(F821)
109-109: Undefined name draft_tokens
(F821)
110-110: Undefined name draft_tokens
(F821)
111-111: Undefined name draft_tokens
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
95142ec to 8a19b2e (compare)
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/speculative/interface.py (1)
Lines 133-134: Bug: default value becomes a tuple due to trailing comma.
spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE, creates a tuple default (SpeculativeDecodingMode.NONE,). Any call like self.spec_dec_mode.is_ngram() will fail.
Apply:
-    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE,
+    spec_dec_mode: SpeculativeDecodingMode = SpeculativeDecodingMode.NONE
♻️ Duplicate comments (4)
tensorrt_llm/_torch/speculative/interface.py (1)
Lines 89-90: has_spec_drafter now includes SAVE_HIDDEN_STATES — ensure utils.get_spec_drafter returns a drafter.
If has_spec_drafter() can return True for this mode, tensorrt_llm/_torch/speculative/utils.py:get_spec_drafter must return a SaveHiddenStatesDrafter instance. Otherwise callers will receive None.
If missing, update utils.py accordingly:
-from .save_hidden_state import (SaveHiddenStatesResourceManager,
-                                SaveHiddenStatesSpecMetadata)
+from .save_hidden_state import (SaveHiddenStatesResourceManager,
+                                SaveHiddenStatesSpecMetadata,
+                                SaveHiddenStatesDrafter)
@@
 if spec_config.spec_dec_mode.is_ngram():
     return NGramDrafter(spec_config, spec_resource_manager)
+if spec_config.spec_dec_mode.is_save_hidden_states():
+    return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)

tensorrt_llm/_torch/speculative/save_hidden_state.py (3)
Lines 1-7: Missing NVIDIA copyright header (required for source files).
Add the standard NVIDIA header at the top of this file.
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
 import os
 from dataclasses import dataclass
 from typing import Optional
Lines 66-97: Root/non-root rank logic is inverted; data never gets written on rank 0.
You append to _saved_state only on non-root ranks, but _write_to_file() writes only on root. Flip the condition and simplify.
-        out_dict = {}
-        if local_mpi_rank() != 0:
+        if local_mpi_rank() == 0:
             input_ids = torch.tensor(list(request.get_tokens(0)),
                                      dtype=torch.long,
                                      device='cpu')
             hidden_size = resource_manager.hidden_size
             num_tokens = input_ids.shape[0]
             if self.spec_config.save_last_layer_post_norm:
-                hidden_states = resource_manager.last_hidden_states[:num_tokens, :].cpu().clone()
+                hidden_states = resource_manager.last_hidden_states[:num_tokens, :].cpu().clone()
             else:
                 hidden_states = resource_manager.hidden_states[:num_tokens, -hidden_size:].cpu().clone()
             out_dict = {
                 "id": self._iter,
                 "input_ids": input_ids,
                 "hidden_state_features": resource_manager.hidden_states[:num_tokens, :].cpu().clone(),
                 "hidden_state": hidden_states,
             }
             self._saved_state.append(out_dict)
Lines 98-104: Write only when there’s buffered data; reset after successful write.
As written, you can emit empty files. Also consider flushing leftovers at the end of prepare_draft_tokens.
-    def _write_to_file(self) -> None:
-        if local_mpi_rank() == 0 and self._iter != self._start_iter:
-            output_path = os.path.join(self._output_directory,
-                                       f"{self._file_prefix}_{self._iter}.pt")
-            torch.save(self._saved_state, output_path)
-            self._saved_state = []
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0 and self._saved_state:
+            output_path = os.path.join(
+                self._output_directory, f"{self._file_prefix}_{self._iter}.pt"
+            )
+            torch.save(self._saved_state, output_path)
+            self._saved_state = []
🧹 Nitpick comments (7)
tensorrt_llm/_torch/speculative/interface.py (1)
Lines 199-205: Name/docstring mismatch; clarify intent.
This method controls final-output capture, not per-layer capture. The docstring should reflect that to avoid confusion with is_layer_capture.
-    def is_final_output_capture(self):
-        """
-        Whether the layer should be captured (eg for Eagle3).
-        Captured after layer norm in modeling_speculative.
-        """
+    def is_final_output_capture(self):
+        """
+        Whether the final hidden state (post norm) should be captured.
+        Used by modes that need the last-layer output after layer norm.
+        """

tensorrt_llm/llmapi/llm_args.py (4)
Lines 461-495: Add minimal validation/docstrings for SaveHiddenStatesDecodingConfig.
Two small guardrails will prevent misconfigurations:
- Validate write_interval >= 1.
- Restrict backend via supports_backend (mirrors other modes).
Also add a brief class docstring (consistent with guidelines).
 class SaveHiddenStatesDecodingConfig(DecodingBaseConfig):
+    """
+    Configuration for saving hidden states during decoding (PyTorch backend only).
+    Writes per-request inputs and hidden states to output_directory every write_interval iterations.
+    """
     output_directory: str
     write_interval: int = 20
     file_prefix: str = "data"
     eagle3_layers_to_capture: Optional[Set[int]] = None
     save_last_layer_post_norm: bool = True

     @classmethod
     def from_dict(cls, data: dict):
         return cls(**data)

     decoding_type: ClassVar[str] = "SaveState"

     def validate(self) -> None:
         if self.output_directory is None or not self.eagle3_layers_to_capture:
             raise ValueError(
                 "Save directory and layers to capture must be provided")

+    def supports_backend(self, backend: str) -> bool:
+        return backend == "pytorch"
+
+    @field_validator('write_interval')
+    @classmethod
+    def _validate_write_interval(cls, v: int) -> int:
+        if v < 1:
+            raise ValueError("write_interval must be >= 1")
+        return v
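For readers unfamiliar with the validator decorator used in this suggestion, a minimal generic pydantic v2 sketch, unrelated to the TensorRT-LLM classes:

from pydantic import BaseModel, field_validator

class WriterConfig(BaseModel):
    write_interval: int = 20

    @field_validator("write_interval")
    @classmethod
    def _check_interval(cls, v: int) -> int:
        if v < 1:
            raise ValueError("write_interval must be >= 1")
        return v

WriterConfig(write_interval=5)    # ok
# WriterConfig(write_interval=0)  # raises a ValidationError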
Lines 968-978: Type alias updated; fix long line to appease linters.
The union includes the new config. Consider wrapping to <=120 chars for Ruff E501.
-SpeculativeConfig: TypeAlias = Optional[Union[
-    DraftTargetDecodingConfig,
-    EagleDecodingConfig,
-    LookaheadDecodingConfig,
-    MedusaDecodingConfig,
-    MTPDecodingConfig,
-    NGramDecodingConfig,
-    UserProvidedDecodingConfig,
-    SaveHiddenStatesDecodingConfig,
-    AutoDecodingConfig,
-]]
+SpeculativeConfig: TypeAlias = Optional[
+    Union[
+        DraftTargetDecodingConfig,
+        EagleDecodingConfig,
+        LookaheadDecodingConfig,
+        MedusaDecodingConfig,
+        MTPDecodingConfig,
+        NGramDecodingConfig,
+        UserProvidedDecodingConfig,
+        SaveHiddenStatesDecodingConfig,
+        AutoDecodingConfig,
+    ]
+]
Lines 1751-1761: Runtime constraints for SaveState are sensible; add explicit doc note or logging.
Forcing PyTorch, BS=1, disabling overlap scheduler, and turning off CUDA graphs are appropriate. Consider logging a one-liner when this branch triggers so users understand why their knobs were overridden.
 elif isinstance(self.speculative_config, SaveHiddenStatesDecodingConfig):
     assert self.backend in ['pytorch']
+    logger.info("SaveHiddenStates mode: forcing max_batch_size=1, disabling overlap scheduler and CUDA graphs.")
     self.build_config.max_batch_size = 1
     self.max_batch_size = 1
     self.disable_overlap_scheduler = True
     self.cuda_graph_config = None
     self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
     self.build_config.max_draft_len = 1
     self.speculative_config.max_draft_len = 1
If helpful, I can add a minimal unit/integration test that exercises this path and asserts these overrides.
Lines 368-371: Dispatch entry for "SaveState" confirmed – please update user-facing documentation.
The mapping in DecodingBaseConfig.from_dict (llm_args.py:362) and the decoding_type declaration (llm_args.py:463) correctly recognize "SaveState". However, a search across the repository (excluding third_party) shows no mentions of "SaveState" in any user-facing docs or examples.
To prevent confusion or typos, please update the documentation and examples accordingly. Consider adding or updating references to "SaveState" in:
- README.md
- Any files under a docs/ directory
- Example scripts (e.g., in an examples/ folder)

tensorrt_llm/_torch/speculative/save_hidden_state.py (2)
Lines 20-30: Final hidden-state capture: add cheap guards for shape/dtype.
Slicing to the active rows is good. Add an assertion to catch unexpected dtype/device mismatches early.
 def maybe_capture_final_hidden_states(self,
                                       hidden_states: torch.Tensor) -> None:
     if self.save_last_layer_post_norm:
         # Assume no chunking, BS=1
         eagle3_hidden_states = self.eagle3_resource_manager.last_hidden_states
+        assert eagle3_hidden_states is not None, "last_hidden_states buffer is not allocated"
+        assert hidden_states.shape[-1] == eagle3_hidden_states.shape[-1], "hidden size mismatch"
         eagle3_hidden_states[:hidden_states.shape[0], :].copy_(
             hidden_states)
Lines 34-45: Resource manager buffer allocation OK; consider propagating device via config.
Hard-coding device='cuda' is fine for now, but passing the device (or using the current device) would be more robust in multi-GPU setups.
-        self.last_hidden_states = torch.empty(
-            (max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device='cuda')
+        device = torch.device('cuda')
+        self.last_hidden_states = torch.empty(
+            (max_num_tokens, self.hidden_size), dtype=self.dtype, device=device)
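One way the device-propagation suggestion could look, sketched as a free function; the function name and parameters are made up for illustration:

import torch

def allocate_last_hidden_states(max_num_tokens, hidden_size, dtype, device="cuda"):
    # Letting the caller pass the device keeps the buffer on the right GPU
    # in multi-device setups instead of hard-coding 'cuda'.
    return torch.empty((max_num_tokens, hidden_size), dtype=dtype, device=device)

buf = allocate_last_hidden_states(8192, 4096, torch.float16, device="cpu")  # CPU for the demo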
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (8)
tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
tensorrt_llm/_torch/speculative/__init__.py (2 hunks)
tensorrt_llm/_torch/speculative/interface.py (5 hunks)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
tensorrt_llm/_torch/speculative/utils.py (4 hunks)
tensorrt_llm/llmapi/__init__.py (2 hunks)
tensorrt_llm/llmapi/llm_args.py (4 hunks)
tensorrt_llm/models/modeling_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/speculative/utils.py
- tensorrt_llm/_torch/speculative/__init__.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/models/modeling_utils.py
tensorrt_llm/llmapi/__init__.py
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/models/modeling_utils.py
tensorrt_llm/llmapi/__init__.py
tensorrt_llm/_torch/speculative/interface.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
tensorrt_llm/llmapi/llm_args.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py
34-34: Undefined name SaveHiddenStatesDecodingConfig
(F821)
51-51: Undefined name SaveHiddenStatesDecodingConfig
(F821)
121-121: Undefined name draft_tokens
(F821)
122-122: Undefined name draft_tokens
(F821)
123-123: Undefined name draft_tokens
(F821)
124-124: Undefined name draft_tokens
(F821)
tensorrt_llm/llmapi/llm_args.py
478-481: 1 blank line required between summary line and description
(D205)
976-976: Line too long (194 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/speculative/interface.py (2)
Lines 54-56: Predicate reads well. is_save_hidden_states() is consistent with the other helpers.
Lines 214-220: Hook addition LGTM; consider being explicit about no-op default.
The hook is a no-op by default; the docstring already implies it. Looks good.

tensorrt_llm/llmapi/__init__.py (1)
Lines 14-16: Public export added correctly. SaveHiddenStatesDecodingConfig is imported and exposed in __all__. Good.
Also applies to: 62-63
8a19b2e to 3f40394 (compare)
Actionable comments posted: 2
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
Lines 105-125: draft_tokens is undefined; also flush leftover records after the loop.
- draft_tokens triggers a NameError the first time. Provide a fixed-shape placeholder of length max_draft_len to satisfy downstream shape checks.
- After processing the batch, flush any remaining records not aligned to write_interval.
 def prepare_draft_tokens(
     self,
     scheduled_requests: ScheduledRequests,
     resource_manager: Optional[ResourceManager] = None,
 ) -> None:
     for request in sorted(
             scheduled_requests.context_requests,
             key=lambda r:
         (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
     ):
         request.py_max_new_tokens = 1
         self._process_request(request, self.spec_resource_manager)
-        if self._iter % self._write_interval == 0:
+        if self._iter % self._write_interval == 0:
             self._write_to_file()
         self._iter += 1
-        # Pad length to `self.max_draft_len`
-        if len(draft_tokens) > 0:
-            pad_length = self.max_draft_len - len(draft_tokens)
-            draft_tokens.extend([0] * pad_length)
-        request.py_draft_tokens = draft_tokens
+        # Always provide a placeholder of length `self.max_draft_len`
+        draft_tokens: list[int] = [0] * self.max_draft_len
+        request.py_draft_tokens = draft_tokens
+    # Final flush for leftovers not aligned to write_interval
+    if (self._iter % self._write_interval) != 0:
+        self._write_to_file()
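The sort key in the loop above is worth unpacking; a standalone illustration of how it orders requests, using a stand-in request type:

from dataclasses import dataclass
from typing import Optional

@dataclass
class Req:  # stand-in for LlmRequest
    request_id: int
    py_batch_idx: Optional[int] = None

reqs = [Req(7), Req(3, py_batch_idx=1), Req(5, py_batch_idx=0)]
ordered = sorted(reqs, key=lambda r: (r.py_batch_idx is None,
                                      r.py_batch_idx or r.request_id))
# Requests without a batch index sort last; note that a batch index of 0 is
# falsy, so the `or` falls back to request_id for that request.
print([r.request_id for r in ordered])  # [3, 5, 7]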
♻️ Duplicate comments (2)
tensorrt_llm/_torch/speculative/save_hidden_state.py (2)
Lines 66-97: Wrong MPI rank for accumulation; captured data will never be written on rank 0.
_saved_state is appended only when local_mpi_rank() != 0, but _write_to_file writes only on rank 0. On single-GPU (rank 0), nothing is captured; on multi-GPU, rank 0 writes empty shards. Flip the condition and simplify the body.
-        out_dict = {}
-        if local_mpi_rank() != 0:
+        if local_mpi_rank() == 0:
             input_ids = torch.tensor(list(request.get_tokens(0)),
                                      dtype=torch.long,
                                      device='cpu')
             hidden_size = resource_manager.hidden_size
             num_tokens = input_ids.shape[0]
             if self.spec_config.save_last_layer_post_norm:
-                hidden_states = resource_manager.last_hidden_states[:num_tokens, :].cpu().clone()
+                hidden_states = resource_manager.last_hidden_states[:num_tokens, :].cpu().clone()
             else:
                 hidden_states = resource_manager.hidden_states[:num_tokens, -hidden_size:].cpu().clone()
-            out_dict = {
+            out_dict = {
                 "id": self._iter,
                 "input_ids": input_ids,
                 "hidden_state_features": resource_manager.hidden_states[:num_tokens, :].cpu().clone(),
                 "hidden_state": hidden_states,
             }
             self._saved_state.append(out_dict)
Lines 98-104: Write only when there is data; avoid emitting empty shards.
As written, rank 0 writes even if _saved_state is empty (and on iteration 0). Gate on presence of data and clear after writing.
-    def _write_to_file(self) -> None:
-        if local_mpi_rank() == 0 and self._iter != self._start_iter:
-            output_path = os.path.join(self._output_directory,
-                                       f"{self._file_prefix}_{self._iter}.pt")
-            torch.save(self._saved_state, output_path)
-            self._saved_state = []
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0 and self._saved_state:
+            output_path = os.path.join(
+                self._output_directory, f"{self._file_prefix}_{self._iter}.pt"
+            )
+            torch.save(self._saved_state, output_path)
+            self._saved_state = []
🧹 Nitpick comments (4)
tensorrt_llm/llmapi/llm_args.py (2)
Lines 968-978: TypeAlias update includes SaveHiddenStatesDecodingConfig — LGTM.
Union now exposes the new config in the public surface. Keep line length within style limits if the linter flags it in CI.
Lines 1751-1761: PyTorch-only gating and runtime knobs — good; add a quick guardrail.
The branch correctly forces PyTorch, BS=1, disables overlap scheduler, turns off CUDA graphs, selects SAVE_HIDDEN_STATES, and pins max_draft_len=1. Suggest also guarding against conflicting user inputs to avoid silent surprises.
 elif isinstance(self.speculative_config, SaveHiddenStatesDecodingConfig):
     assert self.backend in ['pytorch']
+    if self.max_batch_size not in (None, 1):
+        logger.warning("Overriding max_batch_size to 1 for SAVE_HIDDEN_STATES")
     self.build_config.max_batch_size = 1
     self.max_batch_size = 1
     self.disable_overlap_scheduler = True
     self.cuda_graph_config = None
     self.build_config.speculative_decoding_mode = SpeculativeDecodingMode.SAVE_HIDDEN_STATES
     self.build_config.max_draft_len = 1
     self.speculative_config.max_draft_len = 1
If helpful, I can add a small unit/config test ensuring these overrides are applied when SaveHiddenStatesDecodingConfig is used.
tensorrt_llm/_torch/speculative/save_hidden_state.py (2)
Lines 20-29: Optional: Add concise docstrings to public methods.
Short docstrings for is_final_output_capture and maybe_capture_final_hidden_states help downstream users understand when the capture hook is active and what gets copied.
Lines 34-45: Allocate last_hidden_states only when needed — LGTM.
The conditional allocation matches the flag. Consider documenting the shape contract (max_num_tokens x hidden_size) in a comment.
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (8)
tensorrt_llm/_torch/models/modeling_speculative.py (1 hunks)
tensorrt_llm/_torch/speculative/__init__.py (2 hunks)
tensorrt_llm/_torch/speculative/interface.py (5 hunks)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
tensorrt_llm/_torch/speculative/utils.py (4 hunks)
tensorrt_llm/llmapi/__init__.py (2 hunks)
tensorrt_llm/llmapi/llm_args.py (4 hunks)
tensorrt_llm/models/modeling_utils.py (2 hunks)
🚧 Files skipped from review as they are similar to previous changes (5)
- tensorrt_llm/models/modeling_utils.py
- tensorrt_llm/llmapi/__init__.py
- tensorrt_llm/_torch/speculative/__init__.py
- tensorrt_llm/_torch/models/modeling_speculative.py
- tensorrt_llm/_torch/speculative/interface.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in init
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/utils.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
tensorrt_llm/llmapi/llm_args.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/utils.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
tensorrt_llm/llmapi/llm_args.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
🧬 Code graph analysis (3)
tensorrt_llm/_torch/speculative/utils.py (4)
- tensorrt_llm/_torch/speculative/save_hidden_state.py (3)
  - SaveHiddenStatesDrafter (47-124)
  - SaveHiddenStatesResourceManager (32-44)
  - SaveHiddenStatesSpecMetadata (17-29)
- tensorrt_llm/llmapi/llm_args.py (4)
  - spec_dec_mode (396-403)
  - spec_dec_mode (442-447)
  - spec_dec_mode (480-483)
  - spec_dec_mode (586-591)
- tensorrt_llm/_torch/speculative/interface.py (1)
  - is_save_hidden_states (54-55)
- tensorrt_llm/_torch/model_config.py (1)
  - torch_dtype (151-156)
tensorrt_llm/_torch/speculative/save_hidden_state.py (6)
- tensorrt_llm/_utils.py (1)
  - local_mpi_rank (502-503)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  - LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  - ScheduledRequests (18-39)
- tensorrt_llm/_torch/speculative/drafter.py (1)
  - Drafter (9-38)
- tensorrt_llm/_torch/speculative/eagle3.py (2)
  - Eagle3ResourceManager (18-77)
  - Eagle3SpecMetadata (81-186)
- tensorrt_llm/_torch/speculative/interface.py (2)
  - is_final_output_capture (199-204)
  - maybe_capture_final_hidden_states (214-219)
tensorrt_llm/llmapi/llm_args.py (2)
- tensorrt_llm/models/modeling_utils.py (1)
  - SpeculativeDecodingMode (91-127)
- tensorrt_llm/_torch/speculative/interface.py (1)
  - SpeculativeDecodingMode (12-116)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py
34-34: Undefined name SaveHiddenStatesDecodingConfig
(F821)
51-51: Undefined name SaveHiddenStatesDecodingConfig
(F821)
121-121: Undefined name draft_tokens
(F821)
122-122: Undefined name draft_tokens
(F821)
123-123: Undefined name draft_tokens
(F821)
124-124: Undefined name draft_tokens
(F821)
tensorrt_llm/llmapi/llm_args.py
478-481: 1 blank line required between summary line and description
(D205)
976-976: Line too long (194 > 120)
(E501)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/llmapi/llm_args.py (1)
368-371
: Dispatch mapping for "SaveState" looks correct. Adding "SaveState": SaveHiddenStatesDecodingConfig to the DecodingBaseConfig.from_dict map is consistent with the new class’s decoding_type and fixes prior duplicate-key issues. No changes needed.
tensorrt_llm/_torch/speculative/utils.py (4)
14-16
: Completes the SAVE_HIDDEN_STATES wiring by importing the drafter/metadata/manager. Imports are correct and keep module boundaries clean.
56-69
: Spec metadata construction looks consistent with Eagle3 and the new mode. Passing layers_to_capture and save_last_layer_post_norm into SaveHiddenStatesSpecMetadata aligns with the subclass’s API. No issues spotted.
117-125
: Resource manager uses model dtype instead of draft dtype — good catch. Using model_engine.model.config.torch_dtype avoids None dereferences when no draft engine exists for this mode.
173-175
: Drafter construction matches current constructor signature. Call site aligns with SaveHiddenStatesDrafter(spec_config, spec_resource_manager).
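For reference, the call site the comment approves has roughly this shape; the snippet is a sketch of the get_spec_drafter branch inferred from this review, not the verbatim source:

if spec_config.spec_dec_mode.is_save_hidden_states():
    # Construct the drafter with the resource manager so it can read the captured states.
    return SaveHiddenStatesDrafter(spec_config, spec_resource_manager)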
Actionable comments posted: 1
♻️ Duplicate comments (3)
tensorrt_llm/_torch/speculative/save_hidden_state.py (3)
97-103
: Write only when there’s buffered data. Avoid emitting empty files and clear the buffer only after a successful write.
-    def _write_to_file(self) -> None:
-        if local_mpi_rank() == 0:
-            output_path = os.path.join(self._output_directory,
-                                       f"{self._file_prefix}_{self._iter}.pt")
-            torch.save(self._saved_state, output_path)
-        self._saved_state = []
+    def _write_to_file(self) -> None:
+        if local_mpi_rank() == 0 and self._saved_state:
+            output_path = os.path.join(
+                self._output_directory, f"{self._file_prefix}_{self._iter}.pt"
+            )
+            torch.save(self._saved_state, output_path)
+            self._saved_state = []
119-136
: Flush leftovers not aligned to write_interval. If the loop ends without hitting the write boundary, the last batch remains in memory. Add a final flush.
 for request in sorted(
         scheduled_requests.context_requests,
         key=lambda r:
         (r.py_batch_idx is None, r.py_batch_idx or r.request_id),
 ):
     if is_warmup:
         continue
     self._process_request(request, self.spec_resource_manager)
     if self._iter % self._write_interval == 0:
         self._write_to_file()
     self._iter += 1
+# Final flush for leftovers
+if (self._iter % self._write_interval) != 0:
+    self._write_to_file()
1-7
: Add NVIDIA header and resolve F821 for forward-referenced config. Source files need the NVIDIA copyright header. Also, Ruff flags SaveHiddenStatesDecodingConfig as undefined; import it under TYPE_CHECKING (or enable postponed annotations).
+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+from __future__ import annotations
 import os
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
@@
 from tensorrt_llm._utils import local_mpi_rank
+
+if TYPE_CHECKING:
+    from tensorrt_llm.llmapi.llm_args import SaveHiddenStatesDecodingConfig
🧹 Nitpick comments (4)
tensorrt_llm/_torch/speculative/model_drafter.py (1)
432-448
: Consider guarding against forgotten overrides when post-forward is required. Right now, the default prepare_draft_tokens_post silently does nothing. If a subclass returns True from needs_draft_forward_post() but forgets to override this method, the flow will be silently broken. Optionally add a defensive check.
 @nvtx_range("prepare_draft_tokens_post")
 def prepare_draft_tokens_post(
     self,
     scheduled_requests: ScheduledRequests,
     resource_manager: Optional[ResourceManager] = None,
     is_warmup: bool = False,
 ) -> None:
-    """
+    """
     If draft forward needs to be run directly after the target model forward,
     this method can be overridden to do that.
     Used in SaveHiddenStatesDrafter (to ensure correct input_ids)

     Args:
         scheduled_requests: The scheduled requests for this iteration
         resource_manager: The resource manager for this iteration
     """
+    if self.needs_draft_forward_post():
+        raise NotImplementedError(
+            f"{type(self).__name__} needs post-forward drafting but did not override prepare_draft_tokens_post()"
+        )
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
920-921
: Compute-and-store the post-forward flag per-iteration (good), but initialize it in __init__ to avoid accidental use before set. The assignment here is correct. Initializing the flag in __init__ helps future-proof against code paths that might read it before _prepare_and_schedule_batch() runs.
 self.sampler = sampler
 self.drafter = drafter
+self.draft_forward_post_needed = False
tensorrt_llm/_torch/speculative/save_hidden_state.py (2)
49-64
: Minor: start iteration counter at 0 to match typical file-indexing; non-blocking. Current _iter = 1 is fine but unconventional. Starting at 0 makes file indices monotonically match first write boundaries.
-        self._iter = 1
+        self._iter = 0
65-96
: Tighten rank-0 block and avoid unnecessary locals; also prefer as_tensor.
out_dict is only used on rank 0; move it inside the rank-0 block and use torch.as_tensor to avoid an intermediate Python list allocation.
-        out_dict = {}
-        if local_mpi_rank() == 0:
-            input_ids = torch.tensor(list(request.get_tokens(0)),
-                                     dtype=torch.long,
-                                     device='cpu')
+        if local_mpi_rank() == 0:
+            input_ids = torch.as_tensor(
+                request.get_tokens(0), dtype=torch.long, device="cpu"
+            )
@@
-            out_dict = {
+            out_dict = {
                 "id": self._iter,
                 "input_ids": input_ids,
                 "hidden_state_features":
                 resource_manager.hidden_states[:num_tokens, :].cpu().clone(),
                 "hidden_state": hidden_states,
             }
             self._saved_state.append(out_dict)
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (4)
tensorrt_llm/_torch/pyexecutor/py_executor.py (2 hunks)
tensorrt_llm/_torch/speculative/drafter.py (1 hunks)
tensorrt_llm/_torch/speculative/model_drafter.py (1 hunks)
tensorrt_llm/_torch/speculative/save_hidden_state.py (1 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
**/*.py
: Python code must target Python 3.8+
Python indentation: 4 spaces, no tabs
Maintain module namespace in imports (from package.subpackage import foo; then use foo.SomeClass())
Python file names use snake_case
Python class names use PascalCase
Python functions/methods and local variables use snake_case; variables starting with a number get k_ prefix (e.g., k_99th_percentile)
Global variables use G_ prefixed UPPER_SNAKE_CASE (e.g., G_MY_GLOBAL)
Constants use UPPER_SNAKE_CASE in Python
Avoid shadowing variables from outer scopes in Python
Initialize all externally visible members of a Python class in __init__
Prefer docstrings for interfaces used outside a file; comments for local code
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Document attributes/variables inline with short docstrings
Avoid reflection when simple alternatives exist (e.g., prefer explicit parameters over dict(**locals()))
In try/except, catch the narrowest exceptions possible
For duck-typing with try/except, keep try body minimal and put logic in else
Files:
tensorrt_llm/_torch/speculative/drafter.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
**/*.{cpp,cxx,cc,cu,h,hpp,hxx,hh,cuh,py}
📄 CodeRabbit inference engine (CODING_GUIDELINES.md)
Prepend NVIDIA copyright header (current year) to all source files
Files:
tensorrt_llm/_torch/speculative/drafter.py
tensorrt_llm/_torch/speculative/model_drafter.py
tensorrt_llm/_torch/pyexecutor/py_executor.py
tensorrt_llm/_torch/speculative/save_hidden_state.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
📚 Learning: 2025-08-09T20:57:04.084Z
Learnt from: sklevtsov-nvidia
PR: NVIDIA/TensorRT-LLM#3294
File: cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu:118-127
Timestamp: 2025-08-09T20:57:04.084Z
Learning: In the CUTLASS MoE finalize fusion implementation (cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_tma_warp_specialized_input.cu), when setting `fused_finalize_epilogue.stride_final_output` with shape `(hidden_size, num_output_tokens, 1)`, the `num_rows_in_final_output` should be set to `num_output_tokens` (not `hidden_size`) because of a swap+transpose operation that maps rows of the output tensor to `hidden_size` and columns to `num_output_tokens`.
Applied to files:
tensorrt_llm/_torch/speculative/save_hidden_state.py
🧬 Code graph analysis (4)
tensorrt_llm/_torch/speculative/drafter.py (1)
- tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
  - needs_draft_forward_post (137-143)
tensorrt_llm/_torch/speculative/model_drafter.py (6)
- tensorrt_llm/_utils.py (1)
  - nvtx_range (843-862)
- tensorrt_llm/_torch/speculative/save_hidden_state.py (1)
  - prepare_draft_tokens_post (119-135)
- tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  - ScheduledRequests (18-39)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
  - ResourceManager (966-1001)
- tensorrt_llm/_torch/pyexecutor/py_executor.py (2)
  - is_warmup (289-290)
  - is_warmup (293-298)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
  - is_warmup (494-495)
  - is_warmup (498-501)
tensorrt_llm/_torch/pyexecutor/py_executor.py (4)
- tensorrt_llm/_torch/speculative/drafter.py (1)
  - needs_draft_forward_post (40-46)
- tensorrt_llm/_torch/speculative/save_hidden_state.py (2)
  - needs_draft_forward_post (137-143)
  - prepare_draft_tokens_post (119-135)
- tensorrt_llm/_torch/speculative/model_drafter.py (1)
  - prepare_draft_tokens_post (433-447)
- tensorrt_llm/_torch/pyexecutor/model_engine.py (2)
  - is_warmup (494-495)
  - is_warmup (498-501)
tensorrt_llm/_torch/speculative/save_hidden_state.py (7)
- tensorrt_llm/_utils.py (1)
  - local_mpi_rank (502-503)
- tensorrt_llm/_torch/pyexecutor/llm_request.py (1)
  - LlmRequest (282-422)
- tensorrt_llm/_torch/pyexecutor/resource_manager.py (1)
  - ResourceManager (966-1001)
- tensorrt_llm/_torch/pyexecutor/scheduler.py (1)
  - ScheduledRequests (18-39)
- tensorrt_llm/_torch/speculative/drafter.py (3)
  - Drafter (9-46)
  - prepare_draft_tokens (16-27)
  - needs_draft_forward_post (40-46)
- tensorrt_llm/_torch/speculative/eagle3.py (2)
  - Eagle3ResourceManager (18-77)
  - Eagle3SpecMetadata (81-186)
- tensorrt_llm/_torch/speculative/interface.py (2)
  - is_final_output_capture (199-204)
  - maybe_capture_final_hidden_states (214-219)
🪛 Ruff (0.12.2)
tensorrt_llm/_torch/speculative/save_hidden_state.py
34-34: Undefined name SaveHiddenStatesDecodingConfig
(F821)
51-51: Undefined name SaveHiddenStatesDecodingConfig
(F821)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Pre-commit Check
🔇 Additional comments (5)
tensorrt_llm/_torch/speculative/drafter.py (1)
40-46
: Good additive hook; default is safe. Adding a base no-op switch for post-forward drafting is reasonable and keeps existing drafters unaffected.
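For context, the base-class default being approved is roughly the following; this is a sketch inferred from the review rather than the verbatim drafter.py source:

def needs_draft_forward_post(self) -> bool:
    """Whether prepare_draft_tokens_post() must run after the target model forward."""
    # Default: existing drafters do not need the extra post-forward pass.
    return False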
tensorrt_llm/_torch/pyexecutor/py_executor.py (1)
1007-1010
: Verify ordering: post-forward hook runs after model forward and sampling but before request-state updates. This looks correct for SaveHiddenStates (final hidden states are captured during forward, then processed here before _update_request_states). Please confirm this ordering matches the intended contract for all drafters that may use the hook.
tensorrt_llm/_torch/speculative/save_hidden_state.py (3)
16-30
: Spec metadata: final-hidden-states capture path looks correct. Gates on save_last_layer_post_norm and slices by active tokens to avoid shape mismatches. LGTM.
34-45
: Resource manager: allocate buffer only when needed; aligns with metadata usage. The shape (max_num_tokens, hidden_size) is consistent with per-token last-layer captures. LGTM.
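Putting the two comments above together, the capture path amounts to something like the sketch below. The names follow the review text (save_last_layer_post_norm, last_hidden_states), while the assumption that the metadata holds a reference to the resource manager is mine; the actual signatures in the source may differ:

def maybe_capture_final_hidden_states(self, hidden_states: torch.Tensor) -> None:
    # Only active when the config asks for the post-norm last layer.
    if not self.is_final_output_capture():
        return
    num_tokens = hidden_states.shape[0]
    # Slice by active tokens and copy into the preallocated
    # (max_num_tokens, hidden_size) buffer on the resource manager.
    self.resource_manager.last_hidden_states[:num_tokens].copy_(hidden_states)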
104-118
: Pre-forward drafting placeholders look fine. Setting py_max_new_tokens = 1 and providing zeroed draft tokens of fixed length keeps downstream shape assumptions intact.
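Concretely, the per-request placeholder setup being approved boils down to something like this (attribute names taken from the snippets quoted in this review):

# Generate exactly one real token per step and hand back fixed-length dummy draft tokens.
request.py_max_new_tokens = 1
request.py_draft_tokens = [0] * self.max_draft_len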
@IzzyPutterman Hi Izzy, I noticed that this PR has been pending for 3 weeks; is there a rough timeline for when it will be ready to merge to main? :) June |
I will have this PR ready before the free days this week. I need to make some small changes to the logic, address Michael's comments, and add a test, I think. |
55f5575 to fd8a03c Compare
77a4eb3 to 7d7ad63 Compare
PR_Github #20001 [ run ] triggered by Bot |
PR_Github #20001 [ run ] completed with state |
/bot run |
PR_Github #20027 [ run ] triggered by Bot |
PR_Github #20027 [ run ] completed with state |
/bot run |
Signed-off-by: Izzy Putterman <[email protected]>
small changes
Signed-off-by: Izzy Putterman <[email protected]>
fixes inflight
Signed-off-by: Izzy Putterman <[email protected]>
maybe functional
Signed-off-by: Izzy Putterman <[email protected]>
cleanup + small design changes
Signed-off-by: Izzy Putterman <[email protected]>
drop unused lines
Signed-off-by: Izzy Putterman <[email protected]>
Signed-off-by: Izzy Putterman <[email protected]>
206978a to fc3112f Compare
/bot run |
PR_Github #20118 [ run ] triggered by Bot |
PR_Github #20118 [ run ] completed with state |
/bot run |
PR_Github #20122 [ run ] triggered by Bot |
Signed-off-by: Izzy Putterman <[email protected]>
fc3112f to 6b6b73b Compare
/bot run |
PR_Github #20125 [ run ] triggered by Bot |
PR_Github #20122 [ run ] completed with state |
PR_Github #20125 [ run ] completed with state |
Bump on review |
api changes look good 👍
@IzzyPutterman Next time when we merge the PR to the main branch, we can remove the "draft" annotation :) June |
Yep, sorry about that. Thanks for the catch! |
Signed-off-by: Izzy Putterman <[email protected]> Signed-off-by: Faradawn Yang <[email protected]>
Signed-off-by: Izzy Putterman <[email protected]>
Signed-off-by: Izzy Putterman <[email protected]> Signed-off-by: Faradawn Yang <[email protected]>
Summary by CodeRabbit
New Features
Behavior Changes
Description
Test Coverage
GitHub Bot Help
/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...
Provide a user friendly way for developers to interact with a Jenkins server.
Run
/bot [-h|--help]
to print this help message. See details below for each supported subcommand.
run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]
Launch build/test pipelines. All previously running jobs will be killed.
--reuse-test (optional)pipeline-id
(OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test
(OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast
(OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test
(OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx"
(OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe"
(OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp"
(OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test
(OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test
(OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test
(OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge
(OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx"
(OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log
(OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug
(OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md and the scripts/test_to_stage_mapping.py helper.
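For illustration, a single invocation combining several of the options above could look like the following; the stage name is just the example quoted in the help text:

/bot run --stage-list "A10-PyTorch-1" --disable-fail-fast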
kill
Kill all running builds associated with pull request.
skip
skip --comment COMMENT
Skip testing for latest commit on pull request.
--comment "Reason for skipping build/test"
is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.reuse-pipeline
reuse-pipeline
Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.