Skip to content

Conversation

@NVShreyas
Copy link
Collaborator

@NVShreyas NVShreyas commented Jul 29, 2025

Summary by CodeRabbit

  • New Features

    • Enhanced model with improved handling of residual connections for more precise control during processing.
    • Added a new forward method supporting flexible input options and iterative layer processing.
  • Refactor

    • Updated forward pass to return both hidden states and residuals, improving output structure and normalization flow.

Description

The nsys profile showed that these ops were not fused so this PR addresses that and improves perf as shown below:

  1. Llama nemotron super v1:
Config TPS improvement
BF16 H100 TP8 1.03 - 1.12x
BF16 A100 TP8 1.01 - 1.07x

Test Coverage

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

@NVShreyas NVShreyas requested a review from a team as a code owner July 29, 2025 16:05
@NVShreyas NVShreyas requested review from hyukn and nv-yilinf July 29, 2025 16:05
@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jul 29, 2025

📝 Walkthrough

"""

Walkthrough

The NemotronNASDecoderLayer's forward method was refactored to handle an optional residual tensor, changing its internal normalization and output signature. The NemotronNASModel class now includes a new forward method that manages input validation, embedding lookup, residual propagation, and normalization across decoder layers, updating the model's control flow.

Changes

Cohort / File(s) Change Summary
NemotronNAS Decoder Layer Residual Refactor
tensorrt_llm/_torch/models/modeling_nemotron_nas.py
Refactored NemotronNASDecoderLayer.forward to accept and propagate an optional residual tensor, altered normalization logic, changed output signature to return a tuple.
NemotronNAS Model Forward Implementation
tensorrt_llm/_torch/models/modeling_nemotron_nas.py
Added NemotronNASModel.forward to handle input validation, embedding resolution, residual initialization and propagation, iterative decoder layer calls, and final normalization.

Sequence Diagram(s)

sequenceDiagram
    participant Caller
    participant NemotronNASModel
    participant DecoderLayer
    participant RMSNorm

    Caller->>NemotronNASModel: forward(attn_metadata, input_ids/inputs_embeds, ...)
    alt input_ids provided
        NemotronNASModel->>NemotronNASModel: lookup embeddings
    end
    NemotronNASModel->>NemotronNASModel: residual = None
    loop for each DecoderLayer
        NemotronNASModel->>DecoderLayer: forward(position_ids, hidden_states, attn_metadata, residual, ...)
        DecoderLayer-->>NemotronNASModel: (hidden_states, residual)
    end
    NemotronNASModel->>RMSNorm: forward(hidden_states, residual)
    RMSNorm-->>NemotronNASModel: normalized_hidden_states
    NemotronNASModel-->>Caller: normalized_hidden_states
Loading

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes
"""

Note

⚡️ Unit Test Generation is now available in beta!

Learn more here, or try it out under "Finishing Touches" below.


📜 Recent review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4132eaa and a65a00b.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without reflection.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py
**/*.{cpp,h,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (3)
tensorrt_llm/_torch/models/modeling_nemotron_nas.py (3)

157-181: LGTM! Residual handling enables operator fusion optimization.

The conditional residual handling is well-implemented:

  • Properly initializes residual when None
  • Correctly passes both tensors to normalization layers when residual is provided
  • Maintains consistency between attention and FFN blocks
  • Returns both tensors to support chaining through decoder layers

This change aligns with the PR objective of enabling operator fusion for performance improvements.


250-266: LGTM! Model forward method correctly implements residual propagation.

The implementation properly:

  • Handles embedding lookup when needed
  • Initializes residual state for decoder layer chain
  • Propagates residual through all decoder layers
  • Applies final normalization with residual connection

This enables the performance optimization mentioned in the PR objectives by allowing operator fusion across the model.


245-248: Fix the input validation logic.

The XOR condition (input_ids is None) ^ (inputs_embeds is not None) is incorrect. It should ensure that exactly one of them is provided, not that they have different truthiness values.

Apply this diff to fix the validation logic:

-        if (input_ids is None) ^ (inputs_embeds is not None):
+        if (input_ids is None) == (inputs_embeds is None):
             raise ValueError(
                 "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
             )

Likely an incorrect or invalid review comment.

✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share
🪧 Tips

Chat

There are 3 ways to chat with CodeRabbit:

  • Review comments: Directly reply to a review comment made by CodeRabbit. Example:
    • I pushed a fix in commit <commit_id>, please review it.
    • Explain this complex logic.
    • Open a follow-up GitHub issue for this discussion.
  • Files and specific lines of code (under the "Files changed" tab): Tag @coderabbitai in a new review comment at the desired location with your query. Examples:
    • @coderabbitai explain this code block.
    • @coderabbitai modularize this function.
  • PR comments: Tag @coderabbitai in a new PR comment to ask questions about the PR branch. For the best results, please provide a very specific query, as very limited context is provided in this mode. Examples:
    • @coderabbitai gather interesting stats about this repository and render them as a table. Additionally, render a pie chart showing the language distribution in the codebase.
    • @coderabbitai read src/utils.ts and explain its main purpose.
    • @coderabbitai read the files in the src/scheduler package and generate a class diagram using mermaid and a README in the markdown format.
    • @coderabbitai help me debug CodeRabbit configuration file.

Support

Need help? Create a ticket on our support page for assistance with any issues or questions.

Note: Be mindful of the bot's finite context window. It's strongly recommended to break down tasks such as reading entire modules into smaller chunks. For a focused discussion, use review comments to chat about specific files and their changes, instead of using the PR comments.

CodeRabbit Commands (Invoked using PR comments)

  • @coderabbitai pause to pause the reviews on a PR.
  • @coderabbitai resume to resume the paused reviews.
  • @coderabbitai review to trigger an incremental review. This is useful when automatic reviews are disabled for the repository.
  • @coderabbitai full review to do a full review from scratch and review all the files again.
  • @coderabbitai summary to regenerate the summary of the PR.
  • @coderabbitai generate docstrings to generate docstrings for this PR.
  • @coderabbitai generate sequence diagram to generate a sequence diagram of the changes in this PR.
  • @coderabbitai generate unit tests to generate unit tests for this PR.
  • @coderabbitai resolve resolve all the CodeRabbit review comments.
  • @coderabbitai configuration to show the current CodeRabbit configuration for the repository.
  • @coderabbitai help to get help.

Other keywords and placeholders

  • Add @coderabbitai ignore anywhere in the PR description to prevent this PR from being reviewed.
  • Add @coderabbitai summary to generate the high-level summary at a specific location in the PR description.
  • Add @coderabbitai or @coderabbitai title anywhere in the PR title to generate the title automatically.

Documentation and Community

  • Visit our Documentation for detailed information on how to use CodeRabbit.
  • Join our Discord Community to get help, request features, and share feedback.
  • Follow us on X/Twitter for updates and announcements.

@NVShreyas NVShreyas marked this pull request as draft July 29, 2025 16:06
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🔭 Outside diff range comments (1)
tensorrt_llm/_torch/models/modeling_nemotron_nas.py (1)

147-154: Fix return type annotation to match the actual return value.

The method now returns a tuple (hidden_states, residual) but the return type annotation still shows torch.Tensor.

Apply this diff to fix the return type annotation:

+from typing import Any, Dict, Optional, Tuple
 
 def forward(
     self,
     position_ids: torch.IntTensor,
     hidden_states: torch.Tensor,
     attn_metadata: AttentionMetadata,
     residual: Optional[torch.Tensor] = None,
     **kwargs,
-) -> torch.Tensor:
+) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:

Note: Add Tuple to the imports at the top of the file if not already present.

Also applies to: 181-181

🧹 Nitpick comments (1)
tensorrt_llm/_torch/models/modeling_nemotron_nas.py (1)

245-248: Improve error message clarity.

The error message could be more precise about the requirement.

Apply this diff to clarify the error message:

 if (input_ids is None) ^ (inputs_embeds is not None):
     raise ValueError(
-        "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+        "You must specify exactly one of input_ids or inputs_embeds"
     )
📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7231134 and a2f47de.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile = ...).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL = ...).
Python constants should use upper snake_case (e.g., MY_CONSTANT = ...).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a class in the constructor in Python.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for classes and functions in Python, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without it.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py
**/*.{cpp,h,hpp,cc,cxx,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py
🔇 Additional comments (3)
tensorrt_llm/_torch/models/modeling_nemotron_nas.py (3)

157-162: Residual handling implementation looks good!

The conditional residual initialization and propagation through normalization layers enables operator fusion while maintaining backward compatibility. This aligns well with the PR's performance optimization goals.

Also applies to: 173-179


235-267: Well-implemented forward method with proper residual propagation!

The method correctly manages residual connections through all decoder layers and applies the final normalization. The implementation enables the operator fusion optimizations mentioned in the PR objectives.


161-162: RMSNorm two-argument support confirmed

The RMSNorm implementation in tensorrt_llm/_torch/modules/rms_norm.py defines:

def forward(
    self,
    hidden_states: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    …
    if isinstance(residual, torch.Tensor):
        flashinfer_fused_add_rmsnorm(hidden_states, residual, …)
    else:
        hidden_states = flashinfer_rmsnorm(hidden_states, …)
    …

This shows that forward accepts (hidden_states, residual) and correctly dispatches the fused operator.

No changes required.

@NVShreyas NVShreyas marked this pull request as ready for review July 29, 2025 16:35
@NVShreyas
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13395 [ run ] triggered by Bot

@NVShreyas
Copy link
Collaborator Author

duplicate of #6157 without extra merge commits

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13395 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10018 completed with status: 'FAILURE'

@NVShreyas NVShreyas force-pushed the user/shreyasm/nemotron-norm-fuse branch from a2f47de to 1061c4e Compare July 29, 2025 18:51
@NVShreyas
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13409 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13409 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10030 completed with status: 'FAILURE'

@NVShreyas
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13546 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13546 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #10154 completed with status: 'SUCCESS'

@NVShreyas NVShreyas force-pushed the user/shreyasm/nemotron-norm-fuse branch from 1061c4e to f756b0d Compare July 30, 2025 14:57
@NVShreyas NVShreyas requested a review from yizhang-nv July 30, 2025 14:59
@NVShreyas
Copy link
Collaborator Author

/bot reuse-pipeline

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13553 [ reuse-pipeline ] triggered by Bot

@NVShreyas NVShreyas force-pushed the user/shreyasm/nemotron-norm-fuse branch from f756b0d to 4132eaa Compare July 30, 2025 15:12
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

📜 Review details

Configuration used: .coderabbit.yaml
Review profile: CHILL
Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f756b0d and 4132eaa.

📒 Files selected for processing (1)
  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

**/*.py: The code developed for TensorRT-LLM should conform to Python 3.8+.
Indent Python code with 4 spaces. Do not use tabs.
Always maintain the namespace when importing in Python, even if only one class or function from a module is used.
Python filenames should use snake_case (e.g., some_file.py).
Python classes should use PascalCase (e.g., class SomeClass).
Python functions and methods should use snake_case (e.g., def my_awesome_function():).
Python local variables should use snake_case. Prefix k for variable names that start with a number (e.g., k_99th_percentile).
Python global variables should use upper snake_case and prefix G (e.g., G_MY_GLOBAL).
Python constants should use upper snake_case (e.g., MY_CONSTANT).
Avoid shadowing variables declared in an outer scope in Python.
Initialize all externally visible members of a Python class in the constructor.
For interfaces that may be used outside a file, prefer docstrings over comments in Python.
Comments in Python should be reserved for code within a function, or interfaces that are local to a file.
Use Google style docstrings for Python classes and functions, which can be parsed by Sphinx.
Attributes and variables in Python can be documented inline; attribute docstrings will be rendered under the docstring for the class.
Avoid using reflection in Python when functionality can be easily achieved without reflection.
When using try-except blocks in Python, limit the except to the smallest set of errors possible.
When using try-except blocks to handle multiple possible variable types in Python, keep the body of the try as small as possible, using the else block to implement the logic.

Files:

  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py
**/*.{cpp,h,cu,py}

📄 CodeRabbit Inference Engine (CODING_GUIDELINES.md)

All TensorRT-LLM Open Source Software code should contain an NVIDIA copyright header that includes the current year. This includes .cpp, .h, .cu, .py, and any other source files which are compiled or interpreted.

Files:

  • tensorrt_llm/_torch/models/modeling_nemotron_nas.py
🔇 Additional comments (5)
tensorrt_llm/_torch/models/modeling_nemotron_nas.py (5)

173-178: LGTM! Consistent residual handling implementation.

The FFN block residual handling follows the same pattern as the attention block, ensuring consistent behavior throughout the layer. This implementation should enable the intended operation fusion for performance gains.


181-181: LGTM! Correct tuple return for residual propagation.

The return statement correctly returns both hidden_states and residual, enabling proper residual propagation between decoder layers for the fusion optimization.


250-266: LGTM! Well-implemented forward pass with residual propagation.

The forward method correctly implements:

  • Embedding lookup from input_ids when inputs_embeds is not provided
  • Proper residual initialization and propagation through decoder layers
  • Final normalization using both hidden_states and residual for fusion benefits

This implementation should achieve the intended performance improvements through operation fusion, assuming the input validation is fixed.


245-248: Fix the input validation logic.

The XOR validation logic is incorrect. The condition (input_ids is None) ^ (inputs_embeds is not None) will be True when both are None or both are provided, but the error message suggests you want exactly one to be provided.

Apply this diff to fix the validation logic:

-        if (input_ids is None) ^ (inputs_embeds is not None):
+        if (input_ids is None) == (inputs_embeds is None):
             raise ValueError(
-                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
+                "You must specify exactly one of input_ids or inputs_embeds"
             )

Likely an incorrect or invalid review comment.


157-162: RMSNorm dual-tensor forward verified

The RMSNorm.forward(self, hidden_states, residual: Optional[Tensor] = None) signature returns Union[Tensor, Tuple[Tensor, Tensor]], confirming it supports both single- and dual-tensor inputs. The attention block’s residual handling logic is correct and ready to merge.

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13553 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #13546 for commit f756b0d

Signed-off-by: Shreyas Misra <[email protected]>
@NVShreyas NVShreyas force-pushed the user/shreyasm/nemotron-norm-fuse branch from 4132eaa to a65a00b Compare July 30, 2025 15:34
@NVShreyas
Copy link
Collaborator Author

/bot reuse-pipeline

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13557 [ reuse-pipeline ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #13557 [ reuse-pipeline ] completed with state SUCCESS
Reusing PR_Github #13546 for commit a65a00b

@Tabrizian Tabrizian merged commit e67f4da into NVIDIA:main Jul 30, 2025
3 checks passed
lancelly pushed a commit to lancelly/TensorRT-LLM that referenced this pull request Aug 6, 2025
jain-ria pushed a commit to jain-ria/TensorRT-LLM that referenced this pull request Aug 7, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants