Skip to content

Conversation

Fridah-nv
Copy link
Collaborator

@Fridah-nv Fridah-nv commented Aug 25, 2025

This PR includes these major changes:

  1. quantization transformation is separated into two stages: pattern matcher and post_load_fusion.
  • In pattern matcher stage, quantize_linear_from_config now quantize the linear nodes into torch fake quant ops like torch.ops.auto_deploy.torch_fake_quant_fp8_linear
  • In post_load_fusion stage, reference ops are mapped to real quant implementation, e.g. torch_quant_fp8_linear (This is the op we previously map to, can consider rename these ops)
  • We introduce the fake quantize op stage to further standardize the graph for other transforms in the pipeline, and also represent same quantization format (e.g.'FP8') in different sources (e.g. 'modelopt', 'compressed_tensor'
  1. Fusion passes now handled quantized ops with quantized fusion transforms (fuse_fp4_gemms, fuse_fp8_gemms), each quantization format would inherit from QuantizationFusionMixin and implement fuse_rule for its weight and scalers.

  2. Similarly, Sharding passes handle quantized op with specialized sharding info (FP8TPShardingInfo , FP4TPShardingInfo) that contains specific implementation for each quantization format. Different types of TPShardingInfo is dispatched when we add new object with TPShardingInfo.from_node.

  3. Refactor quantization transforms to inherit from base Quantization transform. e.g. Linear (from config, from graph), BMM, MoE quantization

  4. Separate quant special case of other transforms to inherit from the transform. e.g. quantize version of MoE pattern matching

Description

Test Coverage

  • Unit tests H100/B100
  • E2E tests: Llama3, Llama3-FP8, QWen2.5-FP4
  • E2E quant MoE: Mistral-FP8
  • use_sharding_from_factory tested with Llama3
  • Perf: trtllm-bench with llama3 and llama3-FP8

Perf results:
https://docs.google.com/spreadsheets/d/1JY2YvvNitYdFhVwGK-60Wr0KOHK0-3hbElASIsWaazY/edit?usp=sharing

**E2E tests use world_size=2

TODO

Quantized GEMM fusion is disabled because it's incompatible with the pattern matcher afterwards. (Unquantized GEMM fusion has similar issue)
Related issue: #7270

GitHub Bot Help

/bot [-h] ['run', 'kill', 'skip', 'reuse-pipeline'] ...

Provide a user friendly way for developers to interact with a Jenkins server.

Run /bot [-h|--help] to print this help message.

See details below for each supported subcommand.

run [--reuse-test (optional)pipeline-id --disable-fail-fast --skip-test --stage-list "A10-PyTorch-1, xxx" --gpu-type "A30, H100_PCIe" --test-backend "pytorch, cpp" --add-multi-gpu-test --only-multi-gpu-test --disable-multi-gpu-test --post-merge --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" --detailed-log --debug(experimental)]

Launch build/test pipelines. All previously running jobs will be killed.

--reuse-test (optional)pipeline-id (OPTIONAL) : Allow the new pipeline to reuse build artifacts and skip successful test stages from a specified pipeline or the last pipeline if no pipeline-id is indicated. If the Git commit ID has changed, this option will be always ignored. The DEFAULT behavior of the bot is to reuse build artifacts and successful test results from the last pipeline.

--disable-reuse-test (OPTIONAL) : Explicitly prevent the pipeline from reusing build artifacts and skipping successful test stages from a previous pipeline. Ensure that all builds and tests are run regardless of previous successes.

--disable-fail-fast (OPTIONAL) : Disable fail fast on build/tests/infra failures.

--skip-test (OPTIONAL) : Skip all test stages, but still run build stages, package stages and sanity check stages. Note: Does NOT update GitHub check status.

--stage-list "A10-PyTorch-1, xxx" (OPTIONAL) : Only run the specified test stages. Examples: "A10-PyTorch-1, xxx". Note: Does NOT update GitHub check status.

--gpu-type "A30, H100_PCIe" (OPTIONAL) : Only run the test stages on the specified GPU types. Examples: "A30, H100_PCIe". Note: Does NOT update GitHub check status.

--test-backend "pytorch, cpp" (OPTIONAL) : Skip test stages which don't match the specified backends. Only support [pytorch, cpp, tensorrt, triton]. Examples: "pytorch, cpp" (does not run test stages with tensorrt or triton backend). Note: Does NOT update GitHub pipeline status.

--only-multi-gpu-test (OPTIONAL) : Only run the multi-GPU tests. Note: Does NOT update GitHub check status.

--disable-multi-gpu-test (OPTIONAL) : Disable the multi-GPU tests. Note: Does NOT update GitHub check status.

--add-multi-gpu-test (OPTIONAL) : Force run the multi-GPU tests in addition to running L0 pre-merge pipeline.

--post-merge (OPTIONAL) : Run the L0 post-merge pipeline instead of the ordinary L0 pre-merge pipeline.

--extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx" (OPTIONAL) : Run the ordinary L0 pre-merge pipeline and specified test stages. Examples: --extra-stage "H100_PCIe-TensorRT-Post-Merge-1, xxx".

--detailed-log (OPTIONAL) : Enable flushing out all logs to the Jenkins console. This will significantly increase the log volume and may slow down the job.

--debug (OPTIONAL) : Experimental feature. Enable access to the CI container for debugging purpose. Note: Specify exactly one stage in the stage-list parameter to access the appropriate container environment. Note: Does NOT update GitHub check status.

For guidance on mapping tests to stage names, see docs/source/reference/ci-overview.md
and the scripts/test_to_stage_mapping.py helper.

kill

kill

Kill all running builds associated with pull request.

skip

skip --comment COMMENT

Skip testing for latest commit on pull request. --comment "Reason for skipping build/test" is required. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

reuse-pipeline

reuse-pipeline

Reuse a previous pipeline to validate current commit. This action will also kill all currently running builds associated with the pull request. IMPORTANT NOTE: This is dangerous since lack of user care and validation can cause top of tree to break.

Copy link
Contributor

coderabbitai bot commented Aug 25, 2025

📝 Walkthrough

Walkthrough

Renames FP4 ops to NVFP4 across code/tests, adds FP8/NVFP4 fake quant ops, and introduces class-based quantization, fusion, and MoE matching transforms. Updates sharding to node-centric factories with quantization-aware TP/EP variants. Config reworks transform keys and adds new post-load fusion. Tests updated/added for FP8/NVFP4 quantization, fusion, MoE, and sharding.

Changes

Cohort / File(s) Summary
Config transforms
tensorrt_llm/_torch/auto_deploy/config/default.yaml
Renames quantize_* keys to fp8/nvfp4-specific entries; adds fp8/nvfp4 graph/config quantizers and MoE variants; enables post-load fusion: fuse_fp4_gemms, fuse_fp8_gemms, fuse_fp8_linear (torch), fuse_nvfp4_linear (trtllm), fuse_moe.
Custom ops – API surface
.../custom_ops/README.md, .../custom_ops/__init__.py
Public op renamed: torch_quant_fp4_linear → torch_quant_nvfp4_linear; re-exports custom ops via wildcard in init.
Custom ops – linear/MoE
.../custom_ops/quant.py, .../custom_ops/torch_moe.py
Renames FP4 linear/MoE to NVFP4 (symbols and fakes); removes QUANT_* aggregation constants; updates internal call sites to nvfp4 names.
Custom ops – fake quant ops
.../custom_ops/torch_quant.py
Adds fake-quant linear ops: torch_fake_quant_fp8_linear and torch_fake_quant_nvfp4_linear with supporting helpers (FP8 convert/dequant, NVFP4 pack/dequant, FP4 tables).
Transforms – quant fusion rules
.../transform/library/fuse_quant.py
New pattern-matcher transforms to fuse FP8/NVFP4 fake-quant linear into fused linear ops; registers patterns (with/without bias); adds FuseFP8Linear/FuseNVFP4Linear with backend validation.
Transforms – GEMM fusion framework
.../transform/library/fusion.py
Introduces QuantizationFusionMixin; adds FuseFP8Gemms and FuseFP4Gemms; enforces homogeneous-children rule; utility check_same_children; refactors general fuse path.
Transforms – MoE fusion matcher
.../transform/library/fused_moe.py
Generalizes MoE pattern matching with op/scale-driven design; adds MatchSimple/FP8/NVFP4 pattern classes; builds fused MoE args including per-branch scales.
Transforms – quantization framework
.../transform/library/quantization.py
Adds Quantization base and concrete FP8/NVFP4 classes for linear/BMM; provides config/graph variants; registers quantize_fp8/nvfp4_* transforms; embeds hooks for scales and load/amax.
Transforms – MoE quantization
.../transform/library/quantize_moe.py
Replaces single QuantizeMOE with QuantizeFP8MOE and QuantizeNVFP4MOE using new Quantization classes and nvfp4 op.
Transforms – sharding pass
.../transform/library/sharding.py
Switches to TP/EPShardingInfo.from_node; recognizes fake-quant linear ops; updates MoE op to nvfp4 variant.
Utils – node utils
.../utils/node_utils.py
Simplifies is_linear_op/is_bmm_op (no quant arg); adds is_fake_quantized_linear_op; removes QUANT_* usage and an assertion.
Utils – quantization utils
.../utils/quantization_utils.py
Removes QuantizationImpl and related helpers; drops get_scales_and_type_from_node; streamlines checks and removals.
Utils – sharding utils
.../utils/sharding_utils.py
Adds QuantizationShardingMixin; FP8/FP4 TP/EP sharding classes; TP/EP resolver rules; _shard_fp4_weight_scale; extends _insert_sharded_matmul with quantization callback; adjusts _insert_sharded_moe.
Tests – helpers
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
Adds debug prints for detected vs expected sets.
Tests – model utils
.../_model_test_utils.py
Adds FP8_MAX, FakeFP8Linear, and generate_dynamic_shapes helper.
Tests – EP sharding
.../unit/multigpu/transformations/library/test_ep_sharding.py
Imports sharding info from utils; selects EP sharding class per MoE type (simple/FP8/NVFP4).
Tests – TP sharding
.../unit/multigpu/transformations/library/test_tp_sharding.py
Adds FP8MLP using FakeFP8Linear; uses FP8TPShardingInfo; updates detection logic to new node utils.
Tests – MoE custom op
.../unit/singlegpu/custom_ops/test_ad_moe_op.py
Switches nvfp4 MoE op symbol in FP4 test.
Tests – quant custom ops
.../unit/singlegpu/custom_ops/test_quant.py
Renames scaling_vector_size → SCALING_VECTOR_SIZE; switches nvfp4 linear op; adds FP8/NVFP4 parity tests.
Tests – GEMM fusion
.../unit/singlegpu/transformations/library/test_gemm_fusion.py
Uses FakeFP8Linear; recognizes fake-quant FP8 ops; enables fuse_fp8_gemms in config.
Tests – MoE fusion
.../unit/singlegpu/transformations/library/test_moe_fusion.py
Updates NVFP4 paths to nvfp4 ops; adds fp8/nvfp4 MoE pattern-matchers in config.
Tests – quant fusion
.../unit/singlegpu/transformations/library/test_quant_fusion.py
New suite validating FP8/NVFP4 linear fusion rewrites with reference vs fused outputs.
Tests – quant MoE
.../unit/singlegpu/transformations/library/test_quant_moe.py
Splits quantize_moe into quantize_fp8_moe and quantize_nvfp4_moe; adjusts expected params for NVFP4.
Tests – quantization
.../unit/singlegpu/transformations/library/test_quantization.py
Selects QUANT_OP per algo; updates transform keys to fp8/nvfp4-specific; removes QUANT_OPS helper.
Tests – quant utils
.../unit/singlegpu/utils/test_quantization_utils.py
Uses FP8LinearQuantizationFromConfig with TransformConfig; moves _shard_fp4_weight_scale import to sharding_utils.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant InferenceOptimizer
  participant PatternMatcher
  participant GraphModule
  participant FusedOps

  User->>InferenceOptimizer: configure post_load_fusion (fuse_fp8_linear / fuse_nvfp4_linear)
  InferenceOptimizer->>PatternMatcher: register patterns (ref fake-quant linear ± bias)
  PatternMatcher->>GraphModule: scan graph for matches
  alt matches found
    PatternMatcher->>GraphModule: replace with fused op (torch_quant_fp8_linear / torch_quant_nvfp4_linear)
    GraphModule-->>InferenceOptimizer: report num_matches
    InferenceOptimizer->>FusedOps: execute fused kernels at runtime
  else no matches
    GraphModule-->>InferenceOptimizer: skipped=true
  end
Loading
sequenceDiagram
  autonumber
  participant Config
  participant QuantTransform as Quantization (FP8/NVFP4)
  participant GraphModule
  participant Weights
  participant Scales

  Config->>QuantTransform: algo=FP8 or NVFP4
  QuantTransform->>GraphModule: iterate linear/BMM nodes
  QuantTransform->>Weights: quantize per algo
  QuantTransform->>Scales: create/register scale buffers
  QuantTransform->>GraphModule: replace op with quantized op + args (scales[, alpha])
Loading
sequenceDiagram
  autonumber
  participant Matcher as MoE Matcher (Simple/FP8/NVFP4)
  participant GraphModule
  participant FusedMoE

  Matcher->>GraphModule: detect expert compute branches (target_op)
  Matcher->>GraphModule: collect per-branch weights + scales
  Matcher->>GraphModule: insert fused MoE op (torch_moe / torch_quant_fp8_moe / torch_quant_nvfp4_moe)
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Possibly related PRs

Suggested labels

AutoDeploy

Suggested reviewers

  • suyoggupta
  • QiJune
  • yuxianq

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Upto 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
		  - name: "Undocumented Breaking Changes"
			  mode: "warning"
			  instructions: |
				  Flag potential breaking changes that are not documented:
				  1. Identify changes to public APIs/exports, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints (including removed/renamed items and changes to types, required params, return values, defaults, or behavior).
				  2. Ignore purely internal/private changes (e.g., code not exported from package entry points or marked internal).
				  3. Verify documentation exists: a "Breaking Change" section in the PR description and updates to CHANGELOG.md.

Pre-merge checks (1 passed, 2 warnings)

❌ Failed checks (2 warnings)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 29.50% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
Description Check ⚠️ Warning The PR description does not adhere to the repository’s required template: it lacks a formal title in the “[JIRA…][type] Summary” format, the “## Description” section is empty, and the “## PR Checklist” section is missing even though the template specifies these sections explicitly. The initial numbered list of changes sits outside the prescribed “## Description” heading, and the placeholder comments for title and description remain unfilled. Consequently, critical context about the issue, solution, and checklist verification steps are not clearly defined. Please update the PR description to include a properly formatted title (e.g. “[TRTLLM-1234][feat] …”), move the summary of changes into the “## Description” section, complete that section with the problem statement and solution overview, and add the “## PR Checklist” section with the required items. This will ensure full compliance with the repository’s PR submission guidelines.
✅ Passed checks (1 passed)
Check name Status Explanation
Title Check ✅ Passed The title succinctly summarizes the primary change—refactoring the quantization transforms using inheritance—and clearly relates to the main scope of the PR without extraneous detail.
✨ Finishing Touches
  • 📝 Generate Docstrings
🧪 Generate unit tests
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 10

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (6)
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1)

1-1: Add NVIDIA Apache-2.0 copyright header (required by repo guidelines).

All Python sources must carry the NVIDIA Apache-2.0 header. Please prepend it here.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (1)

1-1: Add NVIDIA Apache-2.0 header (tests are Python files too).

Please add the standard header at the top.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1)

265-283: Recompile GraphModule after pattern rewrites.

ADPatternMatcherPass.apply mutates gm.graph, but the Python code of the GraphModule isn’t regenerated automatically. Recompiling avoids stale code and ensures subsequent passes (and any eager execution) see the updated graph.

-        num_matches = patterns.apply(gm.graph)
+        num_matches = patterns.apply(gm.graph)
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3)

1-1: Add NVIDIA Apache-2.0 header.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

53-55: Fix potential IndexError when reading node.users.

user = list(node.users.keys())[0] will throw if there are zero users. Guard the access.

-        user = list(node.users.keys())[0]
-        if len(node.users) == 1 and is_quantized_op(user):
-            user.replace_all_uses_with(node)
+        if len(node.users) == 1:
+            sole_user = next(iter(node.users))
+            if is_quantized_op(sole_user):
+                sole_user.replace_all_uses_with(node)

174-211: Recompile and set skipped flag accurately in LinearQuantizationFromConfig.

  • Set skipped based on num_matches == 0.
  • Recompile after mutations to materialize changes in GraphModule code.
-        info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
-        )
-        return gm, info
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(num_matches == 0),
+            num_matches=num_matches,
+            is_clean=False,
+            has_valid_shapes=True,
+        )
🧹 Nitpick comments (21)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)

252-254: Make linear-op detection robust when custom ops are unavailable

Directly accessing torch.ops.auto_deploy.torch_fake_quant_fp{8,4}_linear can raise AttributeError on builds where these ops aren’t registered. Guard the additions so import-time doesn’t fail on older runtimes.

Apply this diff:

     if include_quantization:
         lin_ops.update(QUANT_LINEAR_OPS)
-        lin_ops.update([torch.ops.auto_deploy.torch_fake_quant_fp8_linear])
-        lin_ops.update([torch.ops.auto_deploy.torch_fake_quant_fp4_linear])
+        for _name in ("torch_fake_quant_fp8_linear", "torch_fake_quant_fp4_linear"):
+            _op = getattr(torch.ops.auto_deploy, _name, None)
+            if _op is not None:
+                lin_ops.add(_op)

1-1: Missing NVIDIA Apache-2.0 header

Per coding guidelines, prepend the NVIDIA Apache-2.0 copyright header (year 2025) to this file.

Do you want me to apply a repo-standard header template across all changed Python files?

tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3)

37-73: Fake-mode detection: good coverage; minor resilience nits

The priority (TracingContext → active dispatch stack → inputs) looks correct and fixes fragility in upstream detection. Two small nits:

  • If inputs is None, pytree.tree_leaves(None) returns [None]; harmless but you can avoid a tiny branch cost.
  • If multiple FakeTensorModes are stacked, you’re already iterating the reversed stack (most recent first) — nice.

Optional micro-tidy:

-    flat_inputs = pytree.tree_leaves(inputs)
+    flat_inputs = () if inputs is None else pytree.tree_leaves(inputs)

75-77: Monkeypatch scope: limit to the pattern-matcher pass

Replacing torch._dynamo.utils.detect_fake_mode globally can have side effects in unrelated compilation paths. Patch only during ADPatternMatcherPass.apply() and restore afterward.

Apply this diff:

-# Replace the function used as a context manager
-torch._dynamo.utils.detect_fake_mode = ad_detect_fake_mode
+_ORIG_DETECT_FAKE_MODE = getattr(torch._dynamo.utils, "detect_fake_mode", None)

And update the pass:

 class ADPatternMatcherPass(PatternMatcherPass):
@@
     def apply(self, graph: Union[torch.fx.Graph, GraphModule]) -> int:
         """Apply pattern matcher with unsupported_input_tensor patch to bypass meta tensor check."""
-        with _patch_unsupported_input_tensor():
-            return super().apply(graph)
+        with _patch_unsupported_input_tensor():
+            # temporarily patch fake-mode detection
+            orig = getattr(torch._dynamo.utils, "detect_fake_mode", None)
+            torch._dynamo.utils.detect_fake_mode = ad_detect_fake_mode
+            try:
+                return super().apply(graph)
+            finally:
+                if orig is not None:
+                    torch._dynamo.utils.detect_fake_mode = orig

1-1: Missing NVIDIA Apache-2.0 header

Please prepend the NVIDIA Apache-2.0 header (2025) at the top of this file.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

1-1: Missing NVIDIA Apache-2.0 header

Add the standard NVIDIA Apache-2.0 header (2025).

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (4)

112-114: Set sharded params to requires_grad=False: ensure consistency with BMM path

Here we set requires_grad=False for linear shards (good for inference). In BMMShardingInfo.apply, sharded parameters are created with requires_grad=True. For consistency across inference transforms, prefer disabling grads there too.

If you agree, I can provide a follow-up diff in the BMM path.


391-444: FP8/FP4 TP sharding classes: interface looks sound

  • FP8: pass-through scales and no-op load hook make sense.
  • FP4: sharding both weight_scale and handling the *_scale state_dict suffix in the pre-hook is correct.

Two small nits:

  • Consider documenting required buffers explicitly in class docstrings (e.g., FP4 expects input_scale, weight_scale, alpha on the owning submodule).
  • Minor naming: sharded_uint8_weight_shape could just be weight_shape for consistency (docstring already clarifies it’s the sharded packed shape).

446-460: Resolver fallthrough handling

The except Exception: pass in _resolve_tp_cls_from_node hides real issues (e.g., NameError on missing ops). Log at debug at least, so we can diagnose resolver paths.

Minimal change:

-        except Exception:
-            pass
+        except Exception as e:  # pragma: no cover
+            ad_logger.debug(f"TP resolver predicate failed: {e}")

1-1: Missing NVIDIA Apache-2.0 header

Please add the required header (2025).

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (1)

1-1: Missing NVIDIA Apache-2.0 header in test file

Tests are also subject to the header requirement per guidelines.

I can sweep-add the header to all modified tests if desired.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

94-95: Order of post-load fusions: consider moving fuse_quant earlier.

fuse_quant rewrites fake-quant reference ops into fused kernels. Running it earlier in post-load-fusion can:

  • Increase the chance for downstream kernel-level fusions to act on the fused ops,
  • Reduce pattern mismatch after other structural fusions.

Recommendation: place fuse_quant ahead of GEMM/RMSNorm fusions unless there is a known dependency the other way around.

tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)

12-16: Remove duplicate/unused constants; use a single SCALING_VECTOR_SIZE.

  • FORMAT_FP8 and FORMAT_NVFP4 are unused here.
  • Both scaling_vector_size and SCALING_VECTOR_SIZE represent the same value (16). Keep one for consistency.
- scaling_vector_size = 16
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
-
-SCALING_VECTOR_SIZE = 16  # NVFP4 block size along K
+SCALING_VECTOR_SIZE = 16  # NVFP4 block size along K

And replace uses of scaling_vector_size below:

-    weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
-        weight, weight_scale_2, scaling_vector_size, False
-    )
+    weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
+        weight, weight_scale_2, SCALING_VECTOR_SIZE, False
+    )

156-201: LGTM with a minor suggestion for negative-case coverage.

This validates CUTLASS-scale + alpha wiring across dtypes with appropriate tolerances. Consider adding a negative test asserting failure when K is not a multiple of SCALING_VECTOR_SIZE to lock the contract.

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1)

30-46: Be explicit with op overloads for consistency.

Elsewhere you used .default for FP8 pattern calls; mirror that for replacements to avoid ambiguity.

-return torch.ops.auto_deploy.torch_quant_fp8_linear(
+return torch.ops.auto_deploy.torch_quant_fp8_linear.default(
     x, w_fp8, None, input_scale=input_scale, weight_scale=weight_scale)
...
-return torch.ops.auto_deploy.torch_quant_fp8_linear(
+return torch.ops.auto_deploy.torch_quant_fp8_linear.default(
     x, w_fp8, bias, input_scale=input_scale, weight_scale=weight_scale)

And similarly for FP4:

-return torch.ops.auto_deploy.torch_quant_fp4_linear(
+return torch.ops.auto_deploy.torch_quant_fp4_linear.default(
     x, w_fp4, bias=None, input_scale=input_scale, weight_scale=weight_scale, alpha=alpha)
...
-return torch.ops.auto_deploy.torch_quant_fp4_linear(
+return torch.ops.auto_deploy.torch_quant_fp4_linear.default(
     x, w_fp4, bias=bias, input_scale=input_scale, weight_scale=weight_scale, alpha=alpha)

Also applies to: 67-84

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

86-171: Consider deferring graph recompilation to the transform level (avoid per-node overhead).

_insert_quantized_bmm and _insert_quantized_linear appropriately avoid recompiling per node. The transform-level recompile suggested above keeps performance in check while ensuring correctness. No changes needed here beyond the transform finalize steps.

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)

30-33: Consider consolidating FORMAT_ constants with existing imports*

The TODO comment indicates these format constants should be imported from a common location. Since torch_quant.py also defines the same constants (FORMAT_FP8 = 0, FORMAT_NVFP4 = 1), consider importing them from there to maintain a single source of truth.

-# TODO: put the ENUMs in the same place and import it
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
+from ..custom_ops.torch_quant import FORMAT_FP8, FORMAT_NVFP4

143-154: Consider adding type hints for better clarity

The new methods would benefit from more specific type hints for the object return type to improve type safety and IDE support.

-def build_custom_kwargs_for_linear(
-    scale_getattrs: Dict[str, Node],
-) -> Dict[str, object]:
+def build_custom_kwargs_for_linear(
+    scale_getattrs: Dict[str, Node],
+) -> Dict[str, Union[List[Node], List]]:

194-210: Inconsistent comment: torch_fake_quant_fp8_linear vs FP8

The docstring mentions "torch_fake_quant_fp8_linear" but this is the FP8QuantizationImpl class. Also, the example pattern in the comment doesn't match the actual implementation which uses Node objects, not raw arguments.

    def build_custom_args_for_linear(  # renamed to reflect args
        scale_getattrs: Dict[str, Node],
    ) -> Tuple[object, ...]:
        """
-        Build the *positional* tail for torch_fake_quant_fp8_linear:
+        Build the *positional* tail for FP8 quantized linear:
            (..., bias, input_scale(list), weight_scale(list), input_zp(list), weight_zp(list))
-
-        We pass bias=None to match the exported pattern:
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
        """
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)

255-255: Consider using modulo check with error message

The assertion could provide more context about the actual value of K when it fails.

-assert K % 16 == 0, "NVFP4 requires K to be a multiple of 16"
+assert K % 16 == 0, f"NVFP4 requires K to be a multiple of 16, got K={K}"
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)

214-216: Ensure consistent node filtering

The condition checks for partial(is_op, ops=self.target_op) but the initial collection uses is_op(node, self.target_op). While functionally equivalent, consider using the same pattern for clarity.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 788fc62 and d14d5f4.

📒 Files selected for processing (13)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (3 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2 hunks)
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent with 4 spaces; do not use tabs
Preserve module namespace when importing: from package.subpackage import foo; then use foo.SomeClass()
Python filenames use snake_case (e.g., some_file.py)
Class names use PascalCase
Function and method names use snake_case
Local variables use snake_case; prefix k for names starting with a number (e.g., k_99th_percentile)
Global variables are UPPER_SNAKE_CASE prefixed with G (e.g., G_MY_GLOBAL)
Constants are UPPER_SNAKE_CASE
Avoid shadowing variables from an outer scope
Initialize all externally visible members of a class in init
For interfaces used outside a file, prefer docstrings over comments; comments for internal code or local interfaces
Use Google-style docstrings for classes and functions (Sphinx-parsable)
Attributes and variables can be documented inline with trailing docstrings under the class or module
Avoid using reflection when easily avoidable; prefer explicit parameters/constructs over dict(**locals())
In try/except, catch the narrowest exception types possible
For duck-typing try/except, keep try body minimal and place logic in else after attribute existence checks

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
**/*.{h,hpp,hxx,hh,c,cc,cpp,cxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend NVIDIA Apache-2.0 copyright header with current year to all source files

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/node_utils.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
🧬 Code graph analysis (9)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (166-198)
  • torch_fake_quant_fp8_linear (202-212)
  • torch_fake_quant_fp4_linear (216-274)
  • torch_fake_quant_fp4_linear (278-287)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (3)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3)
  • ADPatternMatcherPass (104-110)
  • register_ad_pattern (142-225)
  • apply (107-110)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (3)
  • BaseTransform (139-376)
  • SharedConfig (51-57)
  • TransformRegistry (379-407)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
  • fp8_compatible (29-30)
  • fp4_compatible (33-34)
  • trtllm_ops_available (37-38)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (166-198)
  • torch_fake_quant_fp8_linear (202-212)
  • torch_fake_quant_fp4_linear (216-274)
  • torch_fake_quant_fp4_linear (278-287)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • fp4_global_scale (62-64)
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
  • TPShardingInfo (221-264)
  • from_node (229-234)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
  • cutlass_fp4_scale_to_modelopt_fp4_scale (47-59)
  • custom_op (162-164)
  • custom_op (236-238)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (5)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4)
  • is_op (183-206)
  • extract_param_names_from_lin_node (149-170)
  • get_op_overload_packet (173-180)
  • is_linear_op (240-254)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
  • build_custom_args_for_linear (150-153)
  • build_custom_args_for_linear (195-210)
  • build_custom_args_for_linear (291-306)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (2)
  • register (385-392)
  • BaseTransform (139-376)
tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py (1)
  • cuda_memory_tracker (10-26)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (8)
  • build_custom_args_for_linear (150-153)
  • build_custom_args_for_linear (195-210)
  • build_custom_args_for_linear (291-306)
  • custom_op (162-164)
  • custom_op (236-238)
  • QuantizationImpl (72-153)
  • create (76-105)
  • should_skip_quantization (471-484)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • TransformRegistry (379-407)
  • register (385-392)
  • BaseTransform (139-376)
  • get (395-397)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
  • is_linear_op (240-254)
  • is_bmm_op (257-264)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (128-130)
  • build_custom_args_for_linear (294-301)
  • build_custom_args_for_linear (340-347)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (166-198)
  • torch_fake_quant_fp8_linear (202-212)
  • torch_fake_quant_fp4_linear (216-274)
  • torch_fake_quant_fp4_linear (278-287)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7)
  • QuantizationImpl (72-153)
  • cutlass_fp4_scale_to_modelopt_fp4_scale (47-59)
  • modelopt_fp4_scale_to_cutlass_fp4_scale (35-44)
  • scale_names (118-120)
  • scale_names (173-174)
  • scale_names (246-247)
  • scale_names (425-426)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (48-49)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Pre-commit Check
🔇 Additional comments (21)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (1)

18-18: PyTorch version guard for TracingContext import

from torch._guards import TracingContext is internal and version-sensitive. If CI still runs on older 2.x, this import will fail at import-time.

Consider guarding the import:

-from torch._guards import TracingContext
+try:
+    from torch._guards import TracingContext
+except Exception:  # pragma: no cover
+    TracingContext = None  # Fallback; ad_detect_fake_mode should branch if None

And in ad_detect_fake_mode, check if TracingContext and (context := TracingContext.try_get()): ...

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (2)

106-114: Good migration to node-aware TPShardingInfo factory

Switching to TPShardingInfo.from_node(n, …) enables quantization-aware subclassing. This looks correct and makes future extensions easier.


308-316: Quantization-aware two-way sharding: ensure fake-quant linears resolve to the right subclass

This path relies on TPShardingInfo.from_node(n, …) correctly detecting FP8/FP4 nodes. Currently, the resolver in sharding_utils.py only checks fused ops (torch_quant_linear_fp{8,4}). If the graph contains fake-quant ops (torch_fake_quant_fp{8,4}_linear), it will fall back to the base TP class and skip scale sharding.

I’ve proposed a fix in sharding_utils.py to include fake-quant ops in TP_SHARDING_RULES. See that comment for a concrete diff.

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (3)

65-68: Extension point for quantization-aware sharding is well designed

Adding quantization_cb to _insert_sharded_matmul is a clean way to decouple quantization specifics from the core sharding flow.


148-158: Quantization callback invocation: good integration point

The call to quantization_cb after weight/bias sharding ensures scales and load hooks are set while shapes are known. This ordering is correct.


378-389: _shard_fp4_weight_scale: shape reconstruction math—request a quick validation

The reconstruction of the original weight shape assumes:

  • weight_shape_original[dim] *= world_size
  • last dim doubled (* 2) for unpacking FP4

If the shard is along columns (dim=1), doubling the last dimension aligns with your packed uint8 scheme. Please verify this with a shape example in tests (e.g., N×K_packed with K_packed=K/2) to avoid off-by-one on block edges.

I can add a focused unit test exercising dim=0 vs dim=1 with uneven multiples of 128/16 to ensure correct slicing of scales.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (2)

76-79: Config key rename LGTM

Renaming to quantize_linear_from_config aligns with the split linear/BMM transforms. Looks good.


158-161: BMM quantization config key rename LGTM

quantize_bmm_from_config key matches the new transform; no issues spotted.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (2)

80-85: Re-enabling GEMM fusions can increase memory pressure; consider guardrails.

fuse_gemms, fuse_fp4_gemms, and fuse_fp8_gemms were previously disabled due to OOM risk. If we re-enable them by default, consider:

  • A config toggle to disable at runtime,
  • Conditioning on batch/hidden sizes,
  • Or staging behind a perf profile flag.

If recent CI runs haven’t covered large models, please schedule one to confirm memory headroom with these fusions enabled.


48-51: Transforms split verified – registry and config updated

  • Registered in quantization.py:
    @TransformRegistry.register("quantize_linear_from_config") at line 174
    @TransformRegistry.register("quantize_bmm_from_config") at line 213
  • Config example in config/default.yaml:
    quantize_linear_from_config at line 48
    quantize_bmm_from_config at line 50
    • No remaining references to the old quantize_from_config key
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (1)

114-143: LGTM: FP8 fused vs. unified parity test is tight and representative.

Good coverage with/without bias, explicit scale wiring ([in_s], [w_s]), and strict tolerances.

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

80-84: LGTM: switch to unified custom op tail args is correct.

Appending positional tail produced by build_custom_args_for_linear aligns with the new fake-quant op contracts.

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

103-104: Good improvement: Using custom_op() for unified quantization detection

The change from target_op() to custom_op() standardizes the node matching logic for quantized operations, providing a cleaner separation between target operations and custom kernel entry points.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)

165-198: Well-structured FP8 quantization implementation

The reference implementation is clear and well-documented, with proper error handling for dtype mismatches and missing scales.


215-274: Complex but well-documented FP4 quantization logic

The NVFP4 implementation correctly handles the complex per-block quantization scheme with proper scale vector handling and shape transformations. The comments effectively explain the multi-step process.

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (6)

86-97: Clean utility function for verifying homogeneous children

The check_same_children function provides a robust way to ensure all child nodes match the expected type before fusion, preventing mixed-precision issues.


99-131: Excellent abstraction with the QuantizationFusionMixin

The mixin pattern effectively captures the common fusion logic while allowing subclasses to customize their specific quantization formats. The clear documentation of required attributes and methods makes this easy to extend.


153-157: Good error handling with informative logging

The try-except block properly catches NotImplementedError and logs a warning with context about which operations couldn't be fused.


282-283: Add dtype validation for weight fusion

Good assertion to ensure only FP8 quantized weights are being fused. This prevents unexpected behavior with mixed precision.


266-311: Well-implemented FP8 fusion with proper scaling

The FP8 fusion correctly handles weight scale recalculation by finding the maximum scale and re-quantizing accordingly. The implementation properly preserves numerical accuracy while enabling fusion.


313-356: FP4 fusion correctly handles per-block scales

The FP4 implementation properly concatenates the per-block scale vectors along with the weights, maintaining the quantization structure needed for the fused operation.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (4)

1-1: Add NVIDIA SPDX+Apache-2.0 header (license compliance).

This source file lacks the required header.

Apply this diff at the very top:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

475-489: Replace built-in generics list[str] with typing.List[str] for Python 3.8 support.

The codebase targets Python 3.8+, which doesn’t support PEP 585 generics on built-ins.

Apply this diff:

-def should_skip_quantization(
-    node_or_name: Union[Node, str],
-    excluded_patterns: list[str],
-) -> bool:
+def should_skip_quantization(
+    node_or_name: Union[Node, str],
+    excluded_patterns: List[str],
+) -> bool:

491-518: Fix type annotations for Python 3.8; also widen return type to Optional.

Use List[str] instead of list[str]; get_scales_and_type_from_node can return None for scales.

Apply this diff:

-def extract_scales_from_node(node: Node, scale_names: list[str]) -> Dict[str, Optional[Node]]:
+def extract_scales_from_node(node: Node, scale_names: List[str]) -> Dict[str, Optional[Node]]:
@@
-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
+def get_scales_and_type_from_node(node: Node) -> Tuple[Optional[Dict[str, Node]], str]:

312-359: Ensure consistent input_scale and alpha semantics across both load paths

ModelOpt’s branch currently computes

alpha = 1 / (s_w2 * s_in)

and leaves input_scale as s_in, whereas the HF branch uses

alpha = s_w2 * s_in  
input_scale = 1 / s_in

The custom op expects input_scale to always represent 1/x and alpha to be s_in * s_w2. Please update the ModelOpt path accordingly:

• File: tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
Static method load_hook, around lines 320–335

Proposed diff:

     # ModelOpt quantized graph path
     if weight.dtype != torch.uint8:
         …
-        state_dict[alpha_name] = 1 / (weight_scale_2 * state_dict[input_scale_name])
+        state_dict[alpha_name] = weight_scale_2 * state_dict[input_scale_name]
+        state_dict[input_scale_name] = 1 / state_dict[input_scale_name]

If there are consumers relying on the old convention, consider gating this change under a version/compatibility flag. Let me know if you’d like assistance wiring that up.

♻️ Duplicate comments (3)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)

1-1: Add NVIDIA SPDX+Apache-2.0 header (license compliance).

All Python source files must start with the NVIDIA copyright header. Please add it before any imports.

Apply this diff at the top of the file:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

52-54: Use typing.Optional instead of PEP 604 unions for Python 3.8 compatibility.

The “A | B” syntax requires Python 3.10+. Project guideline targets Python 3.8+.

Apply this diff:

 def _nvfp4_get_weights_scaling_factor(
     input: torch.Tensor,
     block_size: int,
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
     keep_high_precision: bool = False,
 ):
 ...
 def _quantize_nvfp4(
     input: torch.Tensor,
     block_size: int,
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
 ):

Also applies to: 107-108

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

294-304: Docstring mentions fp8 operator name in FP4 section; correct to fp4.

Copy/paste error; use torch_fake_quant_fp4_linear in the docstring and example.

Apply this diff:

-        Build the *positional* tail for torch_fake_quant_fp8_linear:
+        Build the *positional* tail for torch_fake_quant_fp4_linear:
@@
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
+        torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
🧹 Nitpick comments (7)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)

39-46: Remove unused parameter from _dequant_weight_fp8 and its call site.

out_features is not used; simplify signature and call.

Apply this diff:

 def _dequant_weight_fp8(
     weight_fp8: torch.Tensor,
     weight_scale: torch.Tensor,
-    out_features: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
     return weight_fp8.to(dtype) * weight_scale
-    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
+    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, in_dtype)

Also applies to: 195-195


109-117: Docstring return mismatch (function returns 2 values, docstring says 3).

Update the docstring to reflect the actual 2-tuple return: (packed_weight, q_per_block_scale).

Apply this diff:

-    Returns:
-    tuple: Contains quantized data, quantized per block scaling factor, and per block scaling factor.
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: (packed_weight, q_per_block_scale)

88-101: Clarify tie-breaking mask logic for FP4 rounding.

Multiplying a boolean equality by a uint8 mask works by accident. Prefer explicit boolean logic for readability and safety.

Apply this diff for clarity:

-    # Define mask to perform rounding
-    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device)
+    # Define boolean tie mask: True on odd indices to round up ties
+    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8, device=device).bool()
 ...
-    round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1)
+    round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) & mask, dim=-1)

217-276: NVFP4 scale semantics: input_scale/alpha naming vs. usage is inconsistent with helper contract.

The code treats input_scale as the inverse scale (inv_x) and expects alpha = s_in2 * s_w2. Ensure upstream load_hooks and defaults produce these exact semantics, or normalize inside this op to avoid subtle bugs across weight-load paths.

Would you like me to provide a normalization shim at the start of this op that accepts either representation and converts to (inv_x, alpha = s_in2*s_w2) consistently?

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

30-33: Avoid duplicating FORMAT_ enums; centralize and import.*

These enums also exist in custom_ops.torch_quant. Duplicates risk divergence.

Consider moving FORMAT_FP8/NVFP4 to a single shared module (e.g., tensorrt_llm/_torch/auto_deploy/common/quant_formats.py) and import from there.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)

39-56: Avoid tracking gradients when reassigning quantized weights/bias in tests.

Wrap parameter/buffer reassignments in torch.no_grad() to be explicit and avoid autograd hooks in some envs.

Apply this diff:

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        device = self.weight.device
-        weight_scale = torch.max(torch.abs(self.weight)).to(torch.float).to(device) / FP8_MAX
-        self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn))
-        self.register_buffer(
-            "input_scale", torch.tensor(1.0, device=self.weight.device, dtype=torch.float)
-        )
-        self.register_buffer("weight_scale", weight_scale)
-        if self.bias is not None:
-            self.bias = nn.Parameter(self.bias.to(torch.half))
+        device = self.weight.device
+        with torch.no_grad():
+            weight_scale = torch.max(torch.abs(self.weight)).to(torch.float).to(device) / FP8_MAX
+            self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn))
+            if self.bias is not None:
+                self.bias = nn.Parameter(self.bias.to(torch.half))
+        self.register_buffer(
+            "input_scale", torch.tensor(1.0, device=self.weight.device, dtype=torch.float)
+        )
+        self.register_buffer("weight_scale", weight_scale)
tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

80-86: New post-load fusion switches: fuse_gemms, fuse_fp4_gemms, fuse_fp8_gemms, and fuse_quant.

Good separation of generic and quantized GEMM fusions. Ensure docs mention these toggles for users.

Would you like me to draft a short README snippet describing these flags and the expected effects?

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between d14d5f4 and 7d1f402.

📒 Files selected for processing (5)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧠 Learnings (2)
📚 Learning: 2025-07-28T17:06:08.621Z
Learnt from: moraxu
PR: NVIDIA/TensorRT-LLM#6303
File: tests/integration/test_lists/qa/examples_test_list.txt:494-494
Timestamp: 2025-07-28T17:06:08.621Z
Learning: In TensorRT-LLM testing, it's common to have both CLI flow tests (test_cli_flow.py) and PyTorch API tests (test_llm_api_pytorch.py) for the same model. These serve different purposes: CLI flow tests validate the traditional command-line workflow, while PyTorch API tests validate the newer LLM API backend. Both are legitimate and should coexist.

Applied to files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧬 Code graph analysis (3)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
  • cutlass_fp4_scale_to_modelopt_fp4_scale (47-59)
  • custom_op (166-168)
  • custom_op (240-242)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (128-130)
  • build_custom_args_for_linear (294-301)
  • build_custom_args_for_linear (340-347)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
  • torch_fake_quant_fp4_linear (218-276)
  • torch_fake_quant_fp4_linear (280-289)
🔇 Additional comments (4)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (3)

36-36: LGTM: stable FP8 constant for tests.

Using torch.finfo for FP8 max is appropriate and future-proof.


96-99: LGTM: switching target linear layers to FakeFP8Linear for FP8 scenarios.

This isolates the fusion test from HW-dependent kernels and exercises the FP8 fuse path.

Also applies to: 111-114, 135-137, 159-162, 191-193


286-288: LGTM: enabling fuse_fp8_gemms in test optimizer config.

Matches the intent of exercising the post_load_fusion path for FP8.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

48-51: Rename quantize_from_config → quantize_linear_from_config and add quantize_bmm_from_config: config aligns with code.

Configuration looks consistent with the split transforms; no action needed.

@Fridah-nv Fridah-nv force-pushed the user/fridah/inherit-quant2 branch from 7d1f402 to 40ef068 Compare August 26, 2025 23:48
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

1-1: Add NVIDIA 2025 copyright header (per repo policy).

All source files must start with the NVIDIA header for the current year. Please prepend the SPDX header to comply.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0

475-479: Python 3.8 compatibility: replace PEP 585 builtins with typing generics.

The repo targets Python 3.8+. Using list[str] requires 3.9+. Switch to List[str] (already imported).

-def should_skip_quantization(
-    node_or_name: Union[Node, str],
-    excluded_patterns: list[str],
-) -> bool:
+def should_skip_quantization(
+    node_or_name: Union[Node, str],
+    excluded_patterns: List[str],
+) -> bool:
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

110-121: FP8MLP path casts FP8 weights to float16 during pattern detection.

In _run_pattern_detection_job the generic else branch uses .to(device="cuda", dtype=torch.float16) for all non-GQA models. When model_cls == FP8MLP, this will upcast the FP8-quantized weights to fp16 and break the FP8 reference op (it expects float8_e4m3fn). This can cause export-time failures or incorrect pattern detection.

Fix by adding a dedicated FP8MLP branch that does not cast dtype:

@@ def _run_pattern_detection_job(...):
-    else:
-        model = model_cls(num_features, num_features, bias=bias).to(
-            device="cuda", dtype=torch.float16
-        )
+    elif model_cls == FP8MLP:
+        # Keep FP8 quantized params; don't cast dtype.
+        model = model_cls(num_features, num_features, bias=bias).to("cuda")
+    else:
+        model = model_cls(num_features, num_features, bias=bias).to(
+            device="cuda", dtype=torch.float16
+        )
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1)

1-5: Add mandatory NVIDIA copyright header (2025).

Per coding guidelines, every Python source must begin with the NVIDIA copyright header for the current year.

Apply this diff at the top of the file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
♻️ Duplicate comments (2)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

510-517: Update detector to also match custom_op() (align with create()).

get_scales_and_type_from_node only matches target_op(); with the new custom_op()-based path, scales won’t be found from nodes calling the custom kernels.

-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
+def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
     """Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc)."""
-    for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
-        if is_op(node, qtype.target_op()):
+    for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
+        if is_op(node, [qtype.custom_op(), qtype.target_op()]):
             return extract_scales_from_node(
                 node, qtype.scale_names()
             ), qtype.__name__.lower().replace("quantizationimpl", "")
     return None, "simple"

294-311: Fix copy-paste error in FP4 docstring (mentions FP8 instead of FP4).

The docstring for FP4’s build_custom_args_for_linear references torch_fake_quant_fp8_linear. It should reference the FP4 variant.

-        Build the *positional* tail for torch_fake_quant_fp8_linear:
+        Build the *positional* tail for torch_fake_quant_fp4_linear:
@@
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
+        torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
🧹 Nitpick comments (17)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (6)

30-33: Stabilize and document format constants; mark as Final.

These int flags are fine, but make the intent explicit to type checkers and future readers. Also, the TODO suggests centralizing—good follow-up, but marking them Final is a quick win now.

Apply this diff here:

-# TODO: put the ENUMs in the same place and import it
-FORMAT_FP8 = 0
-FORMAT_NVFP4 = 1
+# TODO: centralize format constants in a shared module and import from there
+from typing import Final  # type: ignore[unused-ignore]  (ok in local scope if not hoisted)
+FORMAT_FP8: Final[int] = 0
+FORMAT_NVFP4: Final[int] = 1

Additionally (outside this hunk), add Final to the typing import on Line 2:

from typing import Dict, List, Optional, Tuple, Union, Final

88-89: Return a clear error for unsupported NVFP4 BMM rather than None.

Returning None here can lead to confusing downstream AttributeErrors. Raise a targeted NotImplementedError to fail fast with context.

-                    "NVFP4": None,  # BMM NVFP4 is not supported yet
+                    "NVFP4": None,  # BMM NVFP4 is not supported yet
                 }
-            return quantization_impl_map[quant_type_or_node]
+            impl = quantization_impl_map[quant_type_or_node]
+            if quant_type_or_node == "NVFP4" and impl is None:
+                raise NotImplementedError("BMM NVFP4 is not supported yet")
+            return impl

101-104: Broaden node detection to support both new custom_op() and legacy target_op().

Only checking custom_op() risks missing older graphs that still use target_op(). Use an OR to be backward compatible without sacrificing the new path.

-        ]:
-            if is_op(quant_type_or_node, q.custom_op()):
-                return q
+        ]:
+            if is_op(quant_type_or_node, [q.custom_op(), q.target_op()]):
+                return q

184-197: Remove commented-out format_type to avoid confusion.

Since FP8/FP4 have distinct custom kernels now, the commented format_type is misleading.

         return dict(
             input_scale=[scale_getattrs["input_scale"]],
             weight_scale=[scale_getattrs["weight_scale"]],
             input_zp=[],
             weight_zp=[],
-            # format_type=FORMAT_FP8,
         )

198-215: Docstring example: clarify bias handling or drop the example.

Doc says “We pass bias=None,” but the example shows positional args without explicitly indicating None. Either annotate which arg is None or omit the example to prevent misreads.


271-293: Remove commented-out format_type here as well (mirror FP8 change).

Keeps the contract focused on the actual inputs the kernels consume now.

         return dict(
             input_scale=[scale_getattrs["input_scale"]],
             weight_scale=[scale_getattrs["weight_scale"], scale_getattrs["alpha"]],
             input_zp=[],
             weight_zp=[],
-            # format_type=FORMAT_NVFP4,
         )
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)

12-13: Ensure custom quant ops are registered for FP8 tests.

FP8MLP relies on FakeFP8Linear which calls torch.ops.auto_deploy.torch_fake_quant_fp8_linear. Importing the custom ops in this test file avoids order-dependent failures when running tests selectively.

Add near the top (close to other imports):

+import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401

81-92: Minor: add a short docstring and keep dtype consistent.

FP8MLP is a test-only module. Adding a one-liner docstring clarifies intent. Also consider keeping the module in half precision where appropriate for inputs, but avoid mutating FP8 params (you already do in _run_job). No code change required beyond docstring.

Example:

 class FP8MLP(nn.Module):
+    """Two-layer MLP using FakeFP8Linear modules to exercise FP8 sharding paths."""
     def __init__(self, in_features, out_features, bias=False):
         super().__init__()

304-315: Param space coverage note (optional).

Pattern-detection is only checked for world_size=[8]. If CI time permits, adding a smaller value (e.g., 2 or 4) can catch shape corner cases without significant overhead.


1-1: Missing NVIDIA copyright header.

Per the coding guidelines, prepend the NVIDIA copyright header (2025) to this file.

Add at the very top:

# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (4)

1-1: Add NVIDIA copyright header.

All source files (.py included) must carry the NVIDIA copyright header.

Insert above line 1:

+ # Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.

16-24: Initialize ModelFactory base class and return quant config from the stub.

DummyFactory inherits ModelFactory but doesn’t call super().__init__, leaving base fields uninitialized. Some optimizer paths read attributes like skip_loading_weights or call get_quant_config(). Make the stub explicit and safe.

Apply this diff:

 class DummyFactory(ModelFactory):
     def __init__(self, quant_config=None):
-        self._quant_config = quant_config or {}
+        super().__init__(model="", skip_loading_weights=True)
+        self._quant_config = dict(quant_config or {})
 
+    def get_quant_config(self) -> dict:
+        return dict(self._quant_config)
+
     def _build_model(self, device: str):
-        return
+        return None
 
     def _load_checkpoint(self, model, device):
-        return
+        return None

88-130: Clarify FP4 scale semantics (‘s_in2’ vs ‘inv’).

The custom op docstring indicates input_scale[0] = s_in2 and weight_scale[1] = alpha = s_in2 * s_w2. Here you pass input_scale_2 = s_in2 and alpha = 1/(s_in2 * s_w2). That’s the reciprocal of the docstring’s alpha, but matches the internal variable names (inv_x, etc.) in the reference implementation.

Consider adding a short comment to disambiguate “inv” vs “s2” conventions, or rename buffers to reflect “inv_” semantics if that’s the intended contract. This reduces future confusion should the op’s docs or code be refactored.

Would you like me to propose a consistent naming/comment patch once you confirm which convention we standardize on?


27-35: Assertion helpers correctly detect rewrite; minor consistency nit.

_has_fused_linear_fp8 checks .default overload for the reference op, while _has_fused_linear_fp4 checks the packet (no .default). Both work due to is_op handling OpOverloadPacket, but using .default in both places improves uniformity.

Optional tweak:

-    found_ref = any(
-        is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear) for n in gm.graph.nodes
-    )
+    found_ref = any(
+        is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear.default) for n in gm.graph.nodes
+    )
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (3)

259-261: Add a short docstring for FakeFP8Linear (public test helper).

This class is imported by multiple tests; a concise docstring clarifies its contract and limits (e.g., test-only, expects FP8 path).

Apply this diff:

 class FakeFP8Linear(nn.Linear):
     def __init__(self, *args, **kwargs):
+        """
+        Test-only Linear that stores FP8-quantized weights and invokes the FP8 fake-quant op.
+        Not intended for training/autograd; biases are cast to the input dtype at runtime.
+        """
         super().__init__(*args, **kwargs)

263-265: Consider keeping float32 weights and storing FP8 weights in a buffer to reduce surprise for code expecting nn.Linear.weight to be float.

Overwriting self.weight with FP8 may confuse utilities or checkpoints that assume a float weight tensor. A low-impact alternative: keep self.weight as-is and store weight_fp8 as a buffer.

Optional refactor:

-        self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn))
+        # Keep original float weight; store FP8 copy for the fake-quant op.
+        self.register_buffer("weight_fp8", (self.weight.detach() / weight_scale).to(torch.float8_e4m3fn))

and in forward:

-        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
-            x, self.weight, bias, [self.input_scale], [self.weight_scale], [], []
-        )
+        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
+            x, self.weight_fp8, bias, [self.input_scale], [self.weight_scale], [], []
+        )

272-275: Use .default for invoking torch_fake_quant_fp8_linear consistently

The call sites for torch.ops.auto_deploy.torch_fake_quant_fp8_linear are currently inconsistent: some invoke the op directly, while others use the .default attribute. Our convention is to always use .default when calling custom ops. Please update the following locations:

  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py:273
    Change

    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear(
        x, self.weight, self.bias, [self.input_scale], [self.weight_scale], [], []
    )

    to

    return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
        x, self.weight, self.bias, [self.input_scale], [self.weight_scale], [], []
    )
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py:131
    Similarly prepend .default to the op invocation there.

This will ensure all custom-op calls follow the same pattern.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 7d1f402 and 40ef068.

📒 Files selected for processing (8)
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
  • tensorrt_llm/_torch/auto_deploy/config/default.yaml
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (4)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (11)
  • FakeFP8Linear (259-275)
  • forward (74-94)
  • forward (106-108)
  • forward (130-135)
  • forward (154-159)
  • forward (184-203)
  • forward (220-222)
  • forward (232-234)
  • forward (247-253)
  • forward (272-275)
  • MLP (97-111)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (2)
  • forward (75-85)
  • forward (119-129)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (7)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)
  • run_test_transformed_gm (68-138)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
  • fp4_compatible (33-34)
  • fp8_compatible (29-30)
  • trtllm_ops_available (37-38)
tensorrt_llm/_torch/auto_deploy/export/export.py (1)
  • torch_export_to_gm (198-284)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)
  • fp4_global_scale (62-64)
  • fp8_scale (67-69)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
  • torch_fake_quant_fp4_linear (218-276)
  • torch_fake_quant_fp4_linear (280-289)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (128-130)
  • build_custom_args_for_linear (294-301)
  • build_custom_args_for_linear (340-347)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (168-200)
  • torch_fake_quant_fp8_linear (204-214)
  • torch_fake_quant_fp4_linear (218-276)
  • torch_fake_quant_fp4_linear (280-289)
🔇 Additional comments (4)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)

147-157: Interface extension looks good; default no-ops are appropriate.

The default kwargs/args hooks are clear and safe. Good separation of concerns for per-format overrides.


165-169: Unified entry point via custom_op() is a solid choice.

Binding to the .default overload makes matching in FX robust and consistent across sites.


239-243: Custom op entry point for FP4 mirrors FP8—good consistency.

Keeps transform plumbing uniform across formats.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1)

47-86: TinyFP8Ref looks good; ensure device/dtype handling remains consistent.

Precomputing weight_fp8/scales and using the reference op aligns with the fusion expectations. Buffers will migrate on .to("cuda"). No functional issues spotted.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

501-509: Ensure consistent operator detection across all methods.

This is the location where the detection inconsistency manifests - the function only checks target_op() but should also check custom_op() to align with the create() method's logic.

See the fix suggested in the earlier comment for lines 94-105.

♻️ Duplicate comments (2)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

94-105: Operator detection should be consistent between methods.

The create() method now uses custom_op() for FP4/FP8 detection (lines 98-99) but get_scales_and_type_from_node() (lines 503-504) still only checks target_op(). This inconsistency could cause issues.

Apply this diff to get_scales_and_type_from_node (around lines 503-504):

 def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
     """Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc)."""
     for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]:
-        if is_op(node, qtype.target_op()):
+        if is_op(node, qtype.target_op()) or is_op(node, qtype.custom_op()):
             return extract_scales_from_node(
                 node, qtype.scale_names()
             ), qtype.__name__.lower().replace("quantizationimpl", "")
     return None, "simple"

285-302: Fix copy-paste error in docstring.

The docstring incorrectly mentions "torch_fake_quant_fp8_linear" when it should be "torch_fake_quant_fp4_linear".

Apply this diff:

     @staticmethod
     def build_custom_args_for_linear(  # renamed to reflect args
         scale_getattrs: Dict[str, Node],
     ) -> Tuple[object, ...]:
         """
-        Build the *positional* tail for torch_fake_quant_fp8_linear:
+        Build the *positional* tail for torch_fake_quant_fp4_linear:
             (..., bias, input_scale(list), weight_scale(list), input_zp(list), weight_zp(list))

         We pass bias=None to match the exported pattern:
-        torch_fake_quant_fp8_linear(args_0, args_1, args_2, [args_2_0], [args_3_0], [], [])
+        torch_fake_quant_fp4_linear(args_0, args_1, args_2, [args_2_0], [args_3_0, args_3_1], [], [])
         """
🧹 Nitpick comments (3)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)

108-139: Consider extracting common test setup pattern.

The test has similar setup logic to test_fp8_linear, including input generation, weight quantization, and comparison assertions. Consider extracting common test utilities to reduce duplication and improve maintainability.

Additionally, add error handling for when CUDA device is not available:

def _get_test_device():
    if not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    return "cuda"

140-197: Add explicit K alignment validation and improve error message.

The assertion at line 156 correctly enforces the NVFP4 block size requirement, but could provide more helpful debugging information.

Apply this diff to improve the assertion:

-    assert K % SCALING_VECTOR_SIZE == 0
+    assert K % SCALING_VECTOR_SIZE == 0, (
+        f"K dimension ({K}) must be a multiple of SCALING_VECTOR_SIZE ({SCALING_VECTOR_SIZE}) "
+        f"for NVFP4 quantization"
+    )

Also consider adding a comment explaining why alpha is computed as 1/(s_in2*s_w2) for future maintainability:

+    # Alpha combines the input and weight per-tensor scales for the fused kernel
     alpha_fused = (1.0 / (s_in2 * s_w2)).to(torch.float32)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)

264-284: Good documentation, consider extracting common pattern.

The contract documentation clearly explains the expected argument structure. Both FP8 and FP4 implementations share a similar pattern for building kwargs - consider extracting common logic to reduce duplication.

📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between 40ef068 and f4463a5.

📒 Files selected for processing (3)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (2)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (3)
tests/unittest/_torch/auto_deploy/_utils_test/_torch_test_utils.py (3)
  • fp8_compatible (29-30)
  • fp4_compatible (33-34)
  • trtllm_ops_available (37-38)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (160-191)
  • torch_fake_quant_fp8_linear (195-205)
  • torch_fake_quant_fp4_linear (209-267)
  • torch_fake_quant_fp4_linear (271-280)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
  • fp4_global_scale (58-60)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_op (183-206)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (128-130)
  • build_custom_args_for_linear (294-301)
  • build_custom_args_for_linear (340-347)
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (4)
  • torch_fake_quant_fp8_linear (160-191)
  • torch_fake_quant_fp8_linear (195-205)
  • torch_fake_quant_fp4_linear (209-267)
  • torch_fake_quant_fp4_linear (271-280)
🔇 Additional comments (7)
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (2)

11-11: LGTM! Clear constant naming improvement.

The constant name change from scaling_vector_size to SCALING_VECTOR_SIZE follows Python's uppercase convention for global constants.


50-50: LGTM! Consistent use of the renamed constant.

The update correctly reflects the renamed constant SCALING_VECTOR_SIZE.

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (5)

84-84: LGTM! Clear documentation for unsupported feature.

The comment clearly indicates that BMM NVFP4 is not yet supported.


143-154: LGTM! Well-structured base class methods.

The new methods provide a clean interface for custom quantization implementations to supply kernel-specific arguments.


161-165: LGTM! Clear custom operation interface.

The custom_op() method properly returns the FP8 fake quantization operator for unified kernel entry.


180-208: Methods provide consistent FP8 argument handling.

The implementation correctly builds both kwargs and args for the FP8 linear operation, maintaining consistency with the expected signature of torch_fake_quant_fp8_linear.


232-236: LGTM! Consistent FP4 custom operation interface.

The custom_op() method properly returns the FP4 fake quantization operator.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

68-84: Scale buffers are never populated from load_hook output

_default_scales registers submodule buffers input_scale/weight_scale, but FP8/FP4 load_hook currently writes to weight_name + "_scale" in the state_dict, leaving module buffers at defaults (1.0/zeros). The fused/fake ops then read wrong scales.

I recommend updating the load hooks in quantization_utils.py to write into module buffer keys (e.g., f"{modprefix}.weight_scale" and, when derived, f"{modprefix}.input_scale"/f"{modprefix}.alpha"), or add a post-load hook here to transfer state_dict entries into those buffers before execution. Happy to provide a patch in quantization_utils.py (see below comment).

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)

180-187: FP8 load_hook should populate module buffers, not only aux keys

Currently writes weight_name + "_scale" and leaves submodule ".weight_scale" buffer at default. Populate buffer keys so the op reads correct scales.

     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
             weight = state_dict[weight_name]
             if weight.dtype != torch.float8_e4m3fn:
                 scale = fp8_scale(state_dict[weight_name])
-                state_dict[weight_name] = (state_dict[weight_name] / scale).to(torch.float8_e4m3fn)
-                state_dict[weight_name + "_scale"] = scale
+                state_dict[weight_name] = (state_dict[weight_name] / scale).to(torch.float8_e4m3fn)
+                mod_prefix = weight_name.rsplit(".", 1)[0]
+                # Ensure module buffers get correct values
+                state_dict[f"{mod_prefix}.weight_scale"] = scale
+                # input_scale often defaults to 1.0 here; set if missing
+                state_dict.setdefault(f"{mod_prefix}.input_scale", torch.tensor(1.0))

241-287: FP4 load_hook should also set module buffers

Similar issue: ensure submodule buffers (input_scale/weight_scale/alpha) are written so execution uses the right scales.

     def load_hook(state_dict, prefix, *args, weight_name):
         if weight_name in state_dict:
             input_scale_name = weight_name.rsplit(".", 1)[0] + ".input_scale"
             alpha_name = weight_name.rsplit(".", 1)[0] + ".alpha"
+            weight_scale_buf_name = weight_name.rsplit(".", 1)[0] + ".weight_scale"
             weight = state_dict[weight_name]
@@
-                state_dict[weight_name + "_scale"] = weight_scale
+                # Also populate module buffer key for execution
+                state_dict[weight_scale_buf_name] = weight_scale
@@
-                if (
+                if (
                     weight_name + "_scale_2" in state_dict
                     and weight_name + "_scale" in state_dict
                     and input_scale_name in state_dict
                     and float4_sf_dtype
                 ):
@@
-                    state_dict[weight_name + "_scale"] = (
+                    converted = (
                         torch.ops.trtllm.block_scale_interleave(
                             weight_scale.view(torch.uint8).cpu().contiguous()
                         )
                         .reshape(ori_shape)
                         .view(float4_sf_dtype)
                         .reshape(-1)
                     )
+                    state_dict[weight_name + "_scale"] = converted
+                    state_dict[weight_scale_buf_name] = converted

403-417: Exclude-pattern helper uses linear-specific extractor for BMM

For BMM nodes, extract_param_names_from_lin_node is invalid. Guard by op type and derive module name from the get_attr weight when available.

 def should_skip_quantization(
     node_or_name: Union[Node, str],
     excluded_patterns: list[str],
 ) -> bool:
@@
-    else:
-        if not (is_linear_op(node_or_name, include_quantization=False) or is_bmm_op(node_or_name)):
-            return True
-        param_name, _ = extract_param_names_from_lin_node(node_or_name)
-        modname, _, _ = param_name.rpartition(".")
+    else:
+        if is_linear_op(node_or_name, include_quantization=False):
+            param_name, _ = extract_param_names_from_lin_node(node_or_name)
+            modname, _, _ = param_name.rpartition(".")
+        elif is_bmm_op(node_or_name):
+            wt = node_or_name.args[1]
+            if getattr(wt, "op", None) == "get_attr":
+                modname, _, _ = wt.target.rpartition(".")
+            else:
+                # dynamic weight; no owning module — don't exclude
+                return False
+        else:
+            return True
♻️ Duplicate comments (5)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

300-318: quantization_cb re-registers existing buffers => runtime error

Calling register_buffer on existing names raises; update in place when present and validate buffer existence. This mirrors prior feedback.

-        for scale_name in self.scale_names():
-            scales[scale_name] = submod.get_buffer(scale_name)
+        for scale_name in self.scale_names():
+            try:
+                buf = submod.get_buffer(scale_name)
+            except Exception:
+                buf = None
+            if buf is None:
+                raise RuntimeError(
+                    f"Expected buffer '{scale_name}' on module '{type(submod).__name__}'"
+                )
+            scales[scale_name] = buf
@@
-        for k, v in sharded_scales.items():
-            submod.register_buffer(k, v)
+        for k, v in sharded_scales.items():
+            if k in submod._buffers:
+                setattr(submod, k, v)
+            else:
+                submod.register_buffer(k, v)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

212-247: Mirror finalize/flag logic (prior feedback)

This is the same suggestion previously raised for the BMM path; applies here as well.

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (3)

44-46: Replace Python 3.10 union syntax with Optional for Py3.8 compatibility.

Keep repo-wide Python 3.8 support.

-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,

(Apply in both function signatures.)

Also applies to: 99-100


81-93: Avoid boolean*uint8 mix, don’t shadow built-ins, keep inputs unmodified.

Use logical AND with a boolean mask; rename ord/round.

-    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device)
+    mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8, device=device)
@@
-    ord = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8)
+    ordinal = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8)
@@
-    round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1)
-    fp4_val = (sign_bit * 0b1000 + ord + round).to(torch.uint8)
+    round_up = torch.any(
+        (weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) & mask.bool(),
+        dim=-1,
+    )
+    fp4_val = (sign_bit * 0b1000 + ordinal + round_up).to(torch.uint8)

1-1: Add NVIDIA SPDX/Apache-2.0 header (2025) at file top.

Required by project guidelines for all source files.

Apply this diff at the very beginning of the file:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
🧹 Nitpick comments (18)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

1-1: Add NVIDIA Apache-2.0 header

Source files require the header.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

446-449: Broaden resolver to also match fused ops and guard missing symbols

Make TP_SHARDING_RULES robust across fake and fused paths and import-time safe.

-TP_SHARDING_RULES = [
-    (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear), FP8TPShardingInfo),
-    (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp4_linear), FP4TPShardingInfo),
-]
+TP_SHARDING_RULES = [
+    (
+        lambda n: is_op(
+            n,
+            (
+                getattr(torch.ops.auto_deploy, "torch_fake_quant_fp8_linear", None),
+                getattr(torch.ops.auto_deploy, "torch_quant_fp8_linear", None),
+            ),
+        ),
+        FP8TPShardingInfo,
+    ),
+    (
+        lambda n: is_op(
+            n,
+            (
+                getattr(torch.ops.auto_deploy, "torch_fake_quant_fp4_linear", None),
+                getattr(torch.ops.auto_deploy, "torch_quant_fp4_linear", None),
+            ),
+        ),
+        FP4TPShardingInfo,
+    ),
+]
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (3)

1-1: Add NVIDIA Apache-2.0 header

Required for non-test source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

136-142: Use Python 3.8-compatible typing

Replace X | None with Optional[X] and import Optional.

-from typing import Tuple
+from typing import Optional, Tuple
@@
-def _fp4_ref_repl_2(
+def _fp4_ref_repl_2(
     x: torch.Tensor,
     w_fp4: torch.Tensor,
-    bias: torch.Tensor | None,
+    bias: Optional[torch.Tensor],

259-269: Finalize graph after rewrites

Eliminate dead code and recompile for stability and downstream passes.

-        num_matches = patterns.apply(gm.graph)
+        num_matches = patterns.apply(gm.graph)
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (2)

1-1: Add NVIDIA Apache-2.0 header

Required for source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

174-209: Return skipped=True when no matches; recompile on changes

Align semantics across transforms.

@@ class LinearQuantizationFromConfig(BaseTransform):
-        info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
-        )
-        return gm, info
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(num_matches == 0),
+            num_matches=num_matches,
+            is_clean=False,
+            has_valid_shapes=True,
+        )
@@ class BMMQuantizationFromConfig(BaseTransform):
-        info = TransformInfo(
-            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
-        )
-        return gm, info
+        if num_matches:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(num_matches == 0),
+            num_matches=num_matches,
+            is_clean=False,
+            has_valid_shapes=True,
+        )

Also applies to: 244-247

tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

1-1: Add NVIDIA Apache-2.0 header

Required for source files.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

438-446: Return type annotation doesn’t reflect None possibility

Function can return (None, "simple"); update typing or avoid None returns.

-def get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]:
+from typing import Optional
+def get_scales_and_type_from_node(node: Node) -> Tuple[Optional[Dict[str, Node]], str]:
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (6)

12-13: Set explicit float dtype for FP4 tables.

Avoid implicit int64 for e2m1_values; ensure math stays in FP.

-e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
-e2m1_values = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6])
+e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5], dtype=torch.float32)
+e2m1_values = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
+                            0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0], dtype=torch.float32)

31-38: Remove unused parameter from _dequant_weight_fp8 and update call site.

out_features is not used. Simplify signature and callers.

-def _dequant_weight_fp8(
-    weight_fp8: torch.Tensor,
-    weight_scale: torch.Tensor,
-    out_features: int,
-    dtype: torch.dtype,
-) -> torch.Tensor:
-    return weight_fp8.to(dtype) * weight_scale
+def _dequant_weight_fp8(
+    weight_fp8: torch.Tensor,
+    weight_scale: torch.Tensor,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    return weight_fp8.to(dtype) * weight_scale

And adjust usage below:

-    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
+    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, in_dtype)

133-157: Add basic shape assertions to dequant path.

Fail-fast on malformed inputs to ease debugging.

 def _dequantize_nvfp4(
@@
 ) -> torch.Tensor:
     device = quantized_t.device
     N, K = orig_shape
+    assert K % 16 == 0, "NVFP4 dequant requires K to be a multiple of 16."
+    assert quantized_t.shape[-2] == N and quantized_t.shape[-1] == K // 2, \
+        f"quantized_t has shape {tuple(quantized_t.shape[-2:])}, expected {(N, K//2)}"

159-206: FP8 fake op: minor robustness, ensure scales are on same device.

Move scales to input.device to avoid device mismatch if provided as CPU tensors.

-    s_in = _expect_single_scale(input_scale, "input_scale")
-    s_w = _expect_single_scale(weight_scale, "weight_scale")
+    s_in = _expect_single_scale(input_scale, "input_scale").to(input.device)
+    s_w = _expect_single_scale(weight_scale, "weight_scale").to(input.device)

208-281: FP4 eager path: validate scale list lengths and devices.

Guard early and pin scales to device.

-    if len(weight_scale) < 2 or weight_scale[0] is None or weight_scale[1] is None:
+    if len(weight_scale) < 2 or weight_scale[0] is None or weight_scale[1] is None:
         raise ValueError(
             "NVFP4 needs weight_scale[0] (per-block vector) and weight_scale[1] (alpha)."
         )
-    cutlass_qscale = weight_scale[0]
-    alpha = weight_scale[1]
+    cutlass_qscale = weight_scale[0].to(input.device)
+    alpha = weight_scale[1].to(input.device)

11-14: Optional: treat FP tables/constants as UPPER_SNAKE_CASE.

Matches repo style for constants.

-# FP4 tables (E2M1)
-e2m1_bounds = ...
-e2m1_values = ...
+# FP4 tables (E2M1)
+E2M1_BOUNDS = ...
+E2M1_VALUES = ...

(Propagate symbol rename locally.)

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)

142-151: Defensive buffer loading to avoid hard failures.

If any expected buffer is missing, skip fusion for this group instead of raising.

-        for weight_key in keys_unfused:
-            key = weight_key.rsplit(".", 1)[0]
-            for scale_name in flat_scale_names:
-                buffer_name = key + "." + scale_name
-                scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name))
+        try:
+            for weight_key in keys_unfused:
+                key = weight_key.rsplit(".", 1)[0]
+                for scale_name in flat_scale_names:
+                    buffer_name = key + "." + scale_name
+                    scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name))
+        except (AttributeError, KeyError) as e:
+            ad_logger.warning(f"Missing quant buffers for {keys_unfused}, skipping fusion: {e}")
+            return

166-173: Prefer graph.get_attr over create_node('get_attr') for buffers.

Keeps FX graph construction consistent with parameter fetch above.

-            scale_getattrs: Dict[str, Node] = {
-                name: gm.graph.create_node("get_attr", f"{key_fused}_{name}")
-                for name in flat_scale_names
-            }
+            scale_getattrs: Dict[str, Node] = {
+                name: gm.graph.get_attr(f"{key_fused}_{name}")
+                for name in flat_scale_names
+            }

270-291: FP8 fuse_rule: scale aggregation via max is OK; ensure device alignment.

Small tweak to avoid device mismatch when stacking scales.

-        new_weight_scale = torch.max(torch.stack(weight_scale))
+        ws = [s.to(weights[0].device) for s in weight_scale]
+        new_weight_scale = torch.max(torch.stack(ws))
📜 Review details

Configuration used: Path: .coderabbit.yaml

Review profile: CHILL

Plan: Pro

💡 Knowledge Base configuration:

  • MCP integration is disabled by default for public repositories
  • Jira integration is disabled by default for public repositories
  • Linear integration is disabled by default for public repositories

You can enable these sources in your CodeRabbit configuration.

📥 Commits

Reviewing files that changed from the base of the PR and between f4463a5 and 0974df5.

📒 Files selected for processing (11)
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (1 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (7 hunks)
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (2 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (5 hunks)
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (6 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1 hunks)
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (7 hunks)
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py
🚧 Files skipped from review as they are similar to previous changes (3)
  • tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py
  • tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
  • tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py
🧰 Additional context used
📓 Path-based instructions (2)
**/*.py

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

**/*.py: Code must target Python 3.8+
Indent Python code with 4 spaces; do not use tabs
Preserve module namespaces when importing; import modules/packages and access members via the module (e.g., from package.subpackage import foo; foo.SomeClass())
Python file names should be snake_case
Python class names should be PascalCase
Python functions/methods and local variables should be snake_case; variables beginning with a number should be prefixed with k_ (e.g., k_99th_percentile)
Global variables should be UPPER_SNAKE_CASE prefixed with G_ (e.g., G_MY_GLOBAL); constants should be UPPER_SNAKE_CASE
Avoid shadowing variables from outer scopes; initialize all externally visible members in init
Prefer docstrings for interfaces used outside a file; comments should be reserved for in-function or file-local interfaces
Use Google-style docstrings for classes and functions; attributes and variables may be documented inline with trailing string literals
Avoid reflection when simpler, explicit code suffices (e.g., avoid dict(**locals()) patterns)
In try/except, catch the narrowest exceptions possible
For duck-typing patterns, keep the try body minimal and move logic to else to avoid masking unrelated failures

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
**/*.{c,cc,cpp,cxx,h,hh,hpp,hxx,cu,cuh,py}

📄 CodeRabbit inference engine (CODING_GUIDELINES.md)

Prepend the NVIDIA copyright header (current year) to all source files (.cpp, .h, .cu, .py, etc.)

Files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py
  • tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
  • tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
  • tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py
  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧠 Learnings (2)
📚 Learning: 2025-08-06T13:58:07.506Z
Learnt from: galagam
PR: NVIDIA/TensorRT-LLM#6487
File: tests/unittest/_torch/auto_deploy/unit/singlegpu/test_ad_trtllm_bench.py:1-12
Timestamp: 2025-08-06T13:58:07.506Z
Learning: In TensorRT-LLM, test files (files under tests/ directories) do not require NVIDIA copyright headers, unlike production source code files. Test files typically start directly with imports, docstrings, or code.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py
  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py
📚 Learning: 2025-08-27T16:22:10.642Z
Learnt from: Fridah-nv
PR: NVIDIA/TensorRT-LLM#7227
File: tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py:94-100
Timestamp: 2025-08-27T16:22:10.642Z
Learning: When there are inconsistent operator detection methods (like custom_op() vs target_op()), removing one method and standardizing on the other is often cleaner than supporting both methods simultaneously.

Applied to files:

  • tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py
🧬 Code graph analysis (6)
tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (4)
tensorrt_llm/_torch/auto_deploy/models/factory.py (1)
  • ModelFactory (15-207)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/pattern_matcher.py (3)
  • ADPatternMatcherPass (96-102)
  • register_ad_pattern (134-217)
  • apply (99-102)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • BaseTransform (139-376)
  • SharedConfig (51-57)
  • TransformInfo (108-133)
  • TransformRegistry (379-407)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (4)
  • is_op (183-206)
  • extract_param_names_from_lin_node (149-170)
  • get_op_overload_packet (173-180)
  • is_linear_op (240-252)
tensorrt_llm/_torch/auto_deploy/shim/interface.py (1)
  • CachedSequenceInterface (12-70)
tensorrt_llm/_torch/auto_deploy/utils/cuda_mem_tracker.py (1)
  • cuda_memory_tracker (10-26)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (4)
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (10)
  • FakeFP8Linear (259-275)
  • forward (74-94)
  • forward (106-108)
  • forward (130-135)
  • forward (154-159)
  • forward (184-203)
  • forward (220-222)
  • forward (232-234)
  • forward (247-253)
  • forward (272-275)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)
  • FP8TPShardingInfo (331-375)
  • SplitDimension (178-182)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)
  • is_linear_op (240-252)
tensorrt_llm/_torch/modules/linear.py (1)
  • split_dim (48-49)
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (7)
  • QuantizationImpl (68-142)
  • cutlass_fp4_scale_to_modelopt_fp4_scale (43-55)
  • modelopt_fp4_scale_to_cutlass_fp4_scale (31-40)
  • scale_names (114-116)
  • scale_names (157-158)
  • scale_names (208-209)
  • scale_names (357-358)
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (4)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (126-128)
  • build_custom_args_for_linear (292-299)
  • build_custom_args_for_linear (338-345)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (10)
  • build_custom_args_for_linear (139-142)
  • build_custom_args_for_linear (165-177)
  • build_custom_args_for_linear (230-238)
  • target_op (104-106)
  • target_op (147-148)
  • target_op (199-200)
  • target_op (347-348)
  • QuantizationImpl (68-142)
  • create (72-101)
  • should_skip_quantization (403-416)
tensorrt_llm/_torch/auto_deploy/transform/interface.py (4)
  • TransformRegistry (379-407)
  • register (385-392)
  • BaseTransform (139-376)
  • get (395-397)
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)
  • is_linear_op (240-252)
  • is_bmm_op (255-262)
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (1)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)
  • build_custom_args_for_linear (126-128)
  • build_custom_args_for_linear (292-299)
  • build_custom_args_for_linear (338-345)
🔇 Additional comments (6)
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (2)

82-93: Add FP8 MLP test module: looks good

Class mirrors MLP correctly using FakeFP8Linear; forward shape/activation path matches expectations.


331-333: Param additions LGTM

Including FP8MLP in sharding tests and focusing pattern-detection on FP8 is fine for this PR scope.

Also applies to: 349-353

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (1)

31-38: dequant_weight_fp8 call sites verified; no updates needed

The only invocation of _dequant_weight_fp8 is in

  • tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py:186
    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)

which exactly matches the updated parameter order and count. No other references were found across the codebase, so no further changes are required.

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (3)

86-97: check_same_children: implementation is sound.

Converts users to a stable list before iteration and enforces unanimous type; good helper for safe fusion.


201-205: Bias-less constraint is explicit; good.

The bias==None guard simplifies fusion guarantees; matches non-quant path semantics.


316-337: FP4 fuse_rule: concatenating weight scales assumes identical layout.

Document/verify that per-block vectors are already in fused order; otherwise, derive via recorded split sizes.

Would you confirm that weight_scale vectors are laid out as [out_i blocks] per weight, matching cat(weights, dim=0)? If not, we should reorder before concat.

@Fridah-nv Fridah-nv changed the title [None][autodeploy] Quantization Transforms with Inheritance [#5861][autodeploy] Quantization Transforms with Inheritance Aug 27, 2025
@suyoggupta
Copy link
Collaborator

Thanks @Fridah-nv . Did you get a chance to test this on a sharded quantized model? maybe llama 70B Fp8 and/or llama4-fp8?

@Fridah-nv Fridah-nv requested a review from meenchen August 29, 2025 19:16
@Fridah-nv
Copy link
Collaborator Author

Thanks @Fridah-nv . Did you get a chance to test this on a sharded quantized model? maybe llama 70B Fp8 and/or llama4-fp8?

Yes, I tested TP sharding for FP8 and NVFP4 e2e. I'm planning to test quantized MoE and BMM models for sanity check

@Fridah-nv Fridah-nv self-assigned this Sep 3, 2025
@Fridah-nv Fridah-nv changed the title [#5861][autodeploy] Quantization Transforms with Inheritance [#5861][autodeploy] Refactor: Quantization Transforms with Inheritance Sep 3, 2025
@Fridah-nv Fridah-nv changed the title [#5861][autodeploy] Refactor: Quantization Transforms with Inheritance [https://github.com/NVIDIA/TensorRT-LLM/issues/5255][autodeploy] Refactor: Quantization Transforms with Inheritance Sep 4, 2025
@Fridah-nv Fridah-nv changed the title [https://github.com/NVIDIA/TensorRT-LLM/issues/5255][autodeploy] Refactor: Quantization Transforms with Inheritance [#5861][autodeploy] Refactor: Quantization Transforms with Inheritance Sep 4, 2025
@Fridah-nv Fridah-nv force-pushed the user/fridah/inherit-quant2 branch from 70f882b to f57fa57 Compare September 8, 2025 15:50
Copy link
Member

@lucaslie lucaslie left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just a few minor comments. Feel free to skip them and merge right away or apply them as you see fit

@Fridah-nv
Copy link
Collaborator Author

is any of it used outside the quantization transforms? If not I would put it into transform/library/quantization.py

They are also used by sharding/fusion/MoE source files and their tests. Some of the functions are shared across the files. So I'll just keep as it.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 12

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (20)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)

1-1: Add required NVIDIA Apache-2.0 header.

Repository guidelines require the standard 2025 header at the top of all .py files.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and limitations under the License.
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (1)

1-1: Add required NVIDIA Apache-2.0 header.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
+# an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and limitations under the License.
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (3)

110-114: Fix 3.8-incompatible list[str] annotation.

-    excluded_patterns: list[str],
+    excluded_patterns: List[str],

126-143: Fix 3.8-incompatible list[str] annotation here as well.

-def extract_scales_from_node(node: Node, scale_names: list[str]) -> Dict[str, Optional[Node]]:
+def extract_scales_from_node(node: Node, scale_names: List[str]) -> Dict[str, Optional[Node]]:

58-61: Guard FP8 scale against amax==0 to prevent NaNs/Inf downstream.

Mirror the test fix: clamp with tiny eps.

-def fp8_scale(input: torch.Tensor) -> torch.Tensor:
-    """Computes the FP8 per-tensor scale of the input."""
-    return torch.max(torch.abs(input).to(torch.float)) / FP8_MAX
+def fp8_scale(input: torch.Tensor) -> torch.Tensor:
+    """Computes the FP8 per-tensor scale of the input (guarded)."""
+    amax = torch.max(torch.abs(input).to(torch.float))
+    eps = torch.finfo(torch.float32).tiny
+    return torch.clamp(amax / FP8_MAX, min=eps)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (3)

63-76: Fix typo in helper name and use Python 3.8-compatible typing.

  • Rename “ancessor” → “ancestor” for clarity.
  • Replace PEP 585 generics (list[]) with typing.List[...] to meet Python 3.8 target.
-def _find_lowest_common_ancessor(nodes: list[Node]) -> Optional[Node]:
+def _find_lowest_common_ancestor(nodes: List[Node]) -> Optional[Node]:
@@
-    common = nodes[0]
+    common = nodes[0]
@@
-            return None
+            return None
@@
-            return None
+            return None
@@
-    return common
+    return common
-            normalized_routing_weights = _find_lowest_common_ancessor(arg1_list)
+            normalized_routing_weights = _find_lowest_common_ancestor(arg1_list)
@@
-            common_ancessor2 = _find_lowest_common_ancessor(arg2_list)
-            if not common_ancessor2:
+            common_ancestor2 = _find_lowest_common_ancestor(arg2_list)
+            if not common_ancestor2:
                 continue
             selected_experts = bfs(
-                common_ancessor2,
+                common_ancestor2,
                 lambda node: is_op(node, torch.ops.aten.one_hot),
                 attr_next="all_input_nodes",
                 boundary=start_boundary,
             ).args[0]
-            hidden_states = _find_lowest_common_ancessor(pattern_input_nodes)
+            hidden_states = _find_lowest_common_ancestor(pattern_input_nodes)

Also applies to: 118-126, 420-433, 437-439


283-291: Select the correct mul user deterministically.

Iteration over users isn’t ordered; explicitly filter for aten.mul.

-        mul_node = next(iter(node.users))
-        if not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
+        mul_node = next((u for u in node.users if is_op(u, torch.ops.aten.mul)), None)
+        if mul_node is None or not (hasattr(mul_node, "args") and len(mul_node.args) >= 2):
             return None

39-42: Mark fused MoE parameters as non-trainable.

Fused weights are constants; avoid tracking grads.

-        param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts)
-        param_w2 = torch.nn.Parameter(fused_w2_experts)
+        param_w3_w1 = torch.nn.Parameter(fused_w3_w1_experts, requires_grad=False)
+        param_w2 = torch.nn.Parameter(fused_w2_experts, requires_grad=False)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)

65-66: Remove extra dtype argument from Graph.get_attr calls

Graph.get_attr only accepts the attribute name; update both occurrences in fusion.py (lines 65 and 167):

-        get_param_node = gm.graph.get_attr(key_fused, torch.Tensor)
+        get_param_node = gm.graph.get_attr(key_fused)
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025) at file top.
Required by repo guidelines for all .py sources.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py (1)

1-3: Add NVIDIA Apache-2.0 header (2025) at file top.
Required header is missing.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025) at file top.
Header missing.

Apply:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# You may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025).

All source files must start with the NVIDIA Apache-2.0 copyright header for the current year.

Apply at top of file:

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (2)

1-1: Add NVIDIA Apache-2.0 header (2025).

Required on all Python sources.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

272-301: Add explicit shape validation for weight_scale and alpha in torch_quant_nvfp4_linear. Schema already includes an alpha kwarg; add checks (e.g. ensure weight_scale and alpha are 1-D tensors of the expected length) before calling torch.ops.auto_deploy.torch_quant_nvfp4_linear to prevent silent mis-scaling.

tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025).

Tests also require the header.

+# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (1)

14-20: Parametrized bias is overwritten inside the test.

The local reassignment nullifies @pytest.mark.parametrize. Keep the parametrized value.

Apply:

 @pytest.mark.parametrize("bias", [torch.rand(32).to("cuda") * 10, None])
 @pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
 def test_fp8_linear(bias):
     input = torch.rand(3, 16).to("cuda")
     weight = torch.rand(32, 16).to("cuda")
-    bias = torch.rand(32).to("cuda") * 10
tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

1-1: Add NVIDIA Apache-2.0 copyright header (2025).

Please prepend the standard header above the module docstring.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+#
+ # (rest of standard header as used in the repo)
tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (1)

1-1: Add NVIDIA Apache-2.0 header.

Required for all .py sources.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from functools import partial
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

1-1: Add NVIDIA Apache-2.0 header (2025) at file top.

Required by project guidelines for all source files.

Apply this diff:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
♻️ Duplicate comments (12)
tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (1)

1-1: Add NVIDIA SPDX/Apache-2.0 header (2025) at file top.

Consistent with project policy. This was flagged previously.

+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1)

14-14: Replace wildcard export with explicit re-export to avoid symbol shadowing (Ruff F403)
The star import broadens the public API and risks silent overrides. Prefer a curated re-export via all.

Apply:

-from .torch_quant import *
+from . import torch_quant as _torch_quant
+# Re-export only what torch_quant marks public
+__all__ = [*globals().get("__all__", []), *_torch_quant.__all__]

If torch_quant lacks all, please add one listing only the public ops.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

243-252: Avoid dtype casting FP8 weights during pattern detection (still pending here).
FP8MLP falls into the generic branch, upcasting weights to fp16; export/pattern detection can fail.

Apply:

     if model_cls == GQA_Block:
         model = model_cls(
             num_attention_heads=num_heads,
             hidden_size=num_features,
             num_key_value_heads=num_key_value_heads,
         ).to(device="cuda", dtype=torch.float16)
+    elif model_cls == FP8MLP:
+        model = model_cls(num_features, num_features, bias=bias).to("cuda")
     else:
         model = model_cls(num_features, num_features, bias=bias).to(
             device="cuda", dtype=torch.float16
         )
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (2)

134-146: run_test_transformed_gm call is properly ordered.

Past misalignment is resolved; no extra positional args after skip_output_assert.


168-180: run_test_transformed_gm call is properly ordered.

Same as FP8 path; args are correctly positioned.

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (2)

144-149: Use Python 3.8-compatible typing (Optional instead of |).

Replace union operator with Optional and import it.

-from typing import Tuple, Type
+from typing import Optional, Tuple, Type
...
-def _fp4_ref_repl_2(
+def _fp4_ref_repl_2(
     x: torch.Tensor,
     w_fp4: torch.Tensor,
-    bias: torch.Tensor | None,
+    bias: Optional[torch.Tensor],
     input_scale: torch.Tensor,
     weight_scale: torch.Tensor,
     alpha: torch.Tensor,
 ):

1-1: Add NVIDIA Apache-2.0 header.

Required for all .py sources.

+# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 from typing import Tuple, Type
tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (2)

315-323: Quantization buffers: avoid re-registering, and fail clearly if missing.

Re-registering an existing buffer raises; also emit a clear error when a required buffer is absent. This mirrors prior feedback.

Apply this diff:

-        scales = {}
-        for scale_name in self.scale_names():
-            scales[scale_name] = submod.get_buffer(scale_name)
+        scales = {}
+        for scale_name in self.scale_names():
+            try:
+                buf = submod.get_buffer(scale_name)
+            except Exception:
+                buf = None
+            if buf is None:
+                raise RuntimeError(
+                    f"Expected buffer '{scale_name}' on module '{type(submod).__name__}'"
+                )
+            scales[scale_name] = buf
@@
-        for k, v in sharded_scales.items():
-            submod.register_buffer(k, v)
+        for k, v in sharded_scales.items():
+            if k in submod._buffers:  # update existing
+                setattr(submod, k, v)
+            else:
+                submod.register_buffer(k, v)

450-463: from_node resolver should also recognize fused quant linears (not only fake-quant).

Ensures TP sharding works post-fusion too; use getattr to remain import-safe. This updates an earlier ask to include quant + fake-quant variants.

Apply this diff:

-TP_SHARDING_RULES = [
-    (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear), FP8TPShardingInfo),
-    (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear), FP4TPShardingInfo),
-]
+TP_SHARDING_RULES = [
+    (
+        lambda n: is_op(
+            n,
+            (
+                getattr(torch.ops.auto_deploy, "torch_quant_fp8_linear", None),
+                getattr(torch.ops.auto_deploy, "torch_fake_quant_fp8_linear", None),
+            ),
+        ),
+        FP8TPShardingInfo,
+    ),
+    (
+        lambda n: is_op(
+            n,
+            (
+                getattr(torch.ops.auto_deploy, "torch_quant_nvfp4_linear", None),
+                getattr(torch.ops.auto_deploy, "torch_fake_quant_nvfp4_linear", None),
+            ),
+        ),
+        FP4TPShardingInfo,
+    ),
+]

Run to confirm op symbol names present in this repo:

#!/bin/bash
rg -nP -C1 '(torch_quant_(nv)?fp4_linear|torch_quant_fp8_linear|torch_fake_quant_nvfp4_linear|torch_fake_quant_fp8_linear)'
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)

1-1: Add NVIDIA Apache-2.0 header (2025) at file top.

Still missing per prior comment.

Apply this diff:

+# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+

39-47: Use Optional[...] for Python 3.8 compatibility (avoid T | None).

Same issue flagged earlier; update both signatures.

Apply this diff:

-def _nvfp4_get_weights_scaling_factor(
+def _nvfp4_get_weights_scaling_factor(
     input: torch.Tensor,
     block_size: int,
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
     keep_high_precision: bool = False,
 ):
@@
-def _quantize_nvfp4(
+def _quantize_nvfp4(
     input: torch.Tensor,
     block_size: int,
-    weights_scaling_factor_2: torch.Tensor | None = None,
+    weights_scaling_factor_2: Optional[torch.Tensor] = None,
 ):

Also applies to: 94-101

tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (1)

449-473: BMM transform: also recompile after edits.

Mirror linear path finalize behavior.

Apply this diff:

-        return gm, TransformInfo(
-            skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=True
-        )
+        if cnt:
+            gm.graph.eliminate_dead_code()
+            gm.recompile()
+        return gm, TransformInfo(
+            skipped=(cnt == 0), num_matches=cnt, is_clean=False, has_valid_shapes=True
+        )
🧹 Nitpick comments (40)
tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py (1)

267-270: Replace debug prints with actionable assertion diff.

Avoid noisy prints in tests; include the diffs in the assertion message instead.

-    print("detected_set", detected_set)
-    print("expected_set", expected_set)
-
-    assert detected_set == expected_set, "Expected sharding pattern does not match detected pattern"
+    missing = expected_set - detected_set
+    unexpected = detected_set - expected_set
+    assert detected_set == expected_set, (
+        f"Sharding pattern mismatch. Missing: {missing}; Unexpected: {unexpected}; "
+        f"Expected: {expected_set}; Detected: {detected_set}"
+    )
tensorrt_llm/_torch/auto_deploy/custom_ops/README.md (1)

28-28: Name reads “FP4” but entry is NVFP4.

Tighten wording to avoid confusion.

-| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer |
+| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | NVFP4 quantized linear layer |
tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py (2)

5-8: Prefer import torch.nn as nn for namespace consistency.

-from torch import nn
+import torch.nn as nn

272-275: Select the .default overload explicitly; keep bias high-precision.

Align with other callsites and avoid overload ambiguity.

-        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear(
+        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
             x, self.weight, self.bias, [self.input_scale], [self.weight_scale], [], []
         )
tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py (2)

84-86: Typo: “quatnizer” → “quantizer”.

-    """Remove output quatnizer if any from the graph."""
+    """Remove output quantizer if any from the graph."""

87-92: Handle multiple users when removing output quantizers.

Current logic skips when a linear node has >1 user. Consider removing only quantizer users.

-    for n in gm.graph.nodes:
-        if is_linear_op(n) and len(n.users) == 1:
-            user = list(n.users.keys())[0]
-            if is_quantized_op(user):
-                # skip the output quantizer
-                user.replace_all_uses_with(n)
+    for n in gm.graph.nodes:
+        if is_linear_op(n):
+            for user in list(n.users.keys()):
+                if is_quantized_op(user):
+                    user.replace_all_uses_with(n)
tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py (3)

171-189: Update return-value docstring to match implementation.

Docstring still mentions weight_type; the function returns 4 items.

-    Returns:
-        A tuple:
-          (pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type)
-
-          - pattern_input_nodes: List of input nodes (x) used for the expert compute.
-          - pattern_output_nodes: List of final expert output nodes (the linear op with weight w2).
-          - expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight tensors.
-          - expert_scales: Dict with keys "w1_input_scale", "w1_weight_scale", etc., containing scale tensors
-                           (empty if weight_type is "simple").
-          - weight_type: "fp8" if FP8 ops were used, "simple" otherwise.
+    Returns:
+        A tuple:
+          (pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales)
+
+          - pattern_input_nodes: List of input nodes (x) used for the expert compute.
+          - pattern_output_nodes: List of final expert output nodes (the linear op with weight w2).
+          - expert_weights: Dict with keys "w1", "w2", "w3" mapping to lists of weight nodes.
+          - expert_scales: Dict with keys like "w1_input_scale", "w2_input_scale", ... (empty for unquantized).

260-264: Use Python 3.8-compatible typing in annotations.

Replace list[...] and tuple[...] with List[...] and Tuple[...].

-def _find_final_hidden_state_node(
-    pattern_output_nodes: list[Node], end_boundary: Node
+def _find_final_hidden_state_node(
+    pattern_output_nodes: List[Node], end_boundary: Node
 ) -> Optional[Node]:
-def _extract_index_branches_from_expert_outputs(
-    pattern_output_nodes: list[Node],
-) -> tuple[list[Node], list[Node]]:
+def _extract_index_branches_from_expert_outputs(
+    pattern_output_nodes: List[Node],
+) -> Tuple[List[Node], List[Node]]:

Also applies to: 311-314


364-381: Consider making MatchMoePattern abstract.

Subclassing ABC clarifies required overrides (target_op, moe_op, scale_arg_indices, scale_keys).

If desired, inherit from ABC and annotate abstract methods with @AbstractMethod.

tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py (2)

171-176: Prefer graph.get_attr over create_node('get_attr', ...).

Keeps style consistent and avoids low-level node construction.

-            scale_getattrs: Dict[str, Node] = {
-                name: gm.graph.create_node("get_attr", f"{key_fused}_{name}")
-                for name in flat_scale_names
-            }
+            scale_getattrs: Dict[str, Node] = {
+                name: gm.graph.get_attr(f"{key_fused}_{name}") for name in flat_scale_names
+            }

6-7: Annotate class attributes with ClassVar and import it.

Silences linters and clarifies intent for mixin consumers.

-from typing import Callable, Dict, List, Tuple
+from typing import Callable, ClassVar, Dict, List, Tuple
 class FuseFP8Gemms(QuantizationFusionMixin, BaseTransform):
-    target_op = torch.ops.auto_deploy.torch_fake_quant_fp8_linear
-    scale_groups = [["input_scale"], ["weight_scale"]]
+    target_op: ClassVar[Callable] = torch.ops.auto_deploy.torch_fake_quant_fp8_linear
+    scale_groups: ClassVar[List[List[str]]] = [["input_scale"], ["weight_scale"]]
 class FuseFP4Gemms(QuantizationFusionMixin, BaseTransform):
-    target_op = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear
-    scale_groups = [["input_scale"], ["weight_scale", "alpha"]]
+    target_op: ClassVar[Callable] = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear
+    scale_groups: ClassVar[List[List[str]]] = [["input_scale"], ["weight_scale", "alpha"]]

Also applies to: 265-269, 313-316

tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py (1)

1-1: Add NVIDIA Apache-2.0 header
Per repo guidelines, prepend the 2025 NVIDIA Apache-2.0 header at the top of this file.

tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py (1)

90-113: Prefer factory: simplify with EPShardingInfo.from_node
Leverage the dispatcher to future-proof the test and reduce branching.

-    if world_size > 1:
-        for node in gm.graph.nodes:
-            if is_op(node, torch.ops.auto_deploy.torch_moe):
-                expected_transformations.append(
-                    EPShardingInfo(
-                        target_node=node.name,
-                        rank=rank,
-                        world_size=world_size,
-                    )
-                )
-            elif is_op(node, torch.ops.auto_deploy.torch_quant_fp8_moe):
-                expected_transformations.append(
-                    FP8EPShardingInfo(
-                        target_node=node.name,
-                        rank=rank,
-                        world_size=world_size,
-                    )
-                )
-            elif is_op(node, torch.ops.auto_deploy.torch_quant_nvfp4_moe):
-                expected_transformations.append(
-                    NVFP4EPShardingInfo(
-                        target_node=node.name,
-                        rank=rank,
-                        world_size=world_size,
-                    )
-                )
+    if world_size > 1:
+        for node in gm.graph.nodes:
+            info = EPShardingInfo.from_node(node, rank=rank, world_size=world_size)
+            if info is not None:
+                expected_transformations.append(info)
tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py (3)

160-169: Op rename LGTM; fix docstring params (alpha vs weight_scale_2)
Docstring still mentions weight_scale_2; update to alpha and clarify shapes.

-    Args:
-        input: unquantized input tensor
-        weight_fp4: pre-quantized weight tensor, with dtype torch.uint8 (1 uint8 == 2 elements)
-        input_scale: a scalar tensor defined as per_tensor_amax / (FP8 max value (448.0) * FP4 max value (6.0)).
-        weight_scale: a 1D tensor with shape (out_dim * in_dim / 16) padded to be multiple of (128 * 4).
-            with value: per_block_amax / per_tensor_amax * FP8 max value (448.0)
-        weight_scale_2: a scalar tensor defined as per_tensor_amax / (FP8 max value (448.0) * FP4 max value (6.0)).
+    Args:
+        input: Unquantized input tensor.
+        weight_fp4: Pre-quantized weight tensor (torch.uint8; 1 byte packs 2 FP4 values).
+        input_scale: Scalar tensor, per-tensor scale (amax / (FP8_MAX * FP4_MAX)).
+        weight_scale: 1D tensor of per-block scales with length (out_dim * in_dim / 16), padded to a multiple of (128 * 4).
+                      Values typically encode per_block_amax / per_tensor_amax * FP8_MAX.
+        alpha: Scalar tensor equal to 1 / (input_scale * weight_scale_tensor_scale).

215-224: Rename fake impl for consistency
Function name still says fp4_linear_fake; rename to nvfp4_linear_fake for clarity.

-@nvfp4_linear.register_fake
-def fp4_linear_fake(
+@nvfp4_linear.register_fake
+def nvfp4_linear_fake(
     input: torch.Tensor,
     weight_fp4: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
     input_scale: Optional[torch.Tensor] = None,
     weight_scale: Optional[torch.Tensor] = None,
     alpha: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:

1-1: Add NVIDIA Apache-2.0 header
Please prepend the 2025 NVIDIA Apache-2.0 header at the top of this source file.

tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py (2)

11-11: Import style nit: prefer module namespace.
Use module import per guideline (keeps symbols scoped).

Example:

-from _model_test_utils import FakeFP8Linear
+import _model_test_utils as mtu

and replace usages with mtu.FakeFP8Linear.


275-278: Minor: centralize FP8/linear op predicate.
Consider a helper (e.g., is_fake_quantized_linear_op) or tuple ops to avoid inline duplication.

Example:

-        lambda gm: sum(
-            (is_linear_op(n) or is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear))
-            for n in gm.graph.nodes
-        )
+        lambda gm: sum(
+            is_linear_op(n) or is_op(n, (torch.ops.auto_deploy.torch_fake_quant_fp8_linear,))
+            for n in gm.graph.nodes
+        )
tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py (1)

12-12: Import style nit: prefer module namespace.
Scope FakeFP8Linear via module import.

Example:

-from _model_test_utils import FakeFP8Linear
+import _model_test_utils as mtu
tensorrt_llm/_torch/auto_deploy/utils/node_utils.py (2)

263-270: Docstring + naming nit for new predicate.

Add a short docstring and pluralize the local set name for clarity.

-def is_fake_quantized_linear_op(node: Node) -> bool:
-    quantized_linear_op = {
+def is_fake_quantized_linear_op(node: Node) -> bool:
+    """Check if node is a fake-quantized linear op (FP8 or NVFP4)."""
+    quantized_linear_ops = {
         torch.ops.auto_deploy.torch_fake_quant_fp8_linear,
         torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear,
     }
-
-    return is_op(node, quantized_linear_op)
+    return is_op(node, quantized_linear_ops)

272-275: Optional: support matmul variants or document scope.

If downstream code expects to catch batched matmul dispatched via aten.matmul, consider adding it here or explicitly document that only aten.bmm is recognized.

-def is_bmm_op(node: Node) -> bool:
-    bmm_ops = {torch.ops.aten.bmm}
-    return is_op(node, bmm_ops)
+def is_bmm_op(node: Node) -> bool:
+    """Check if the node is a batched matmul."""
+    return is_op(node, {torch.ops.aten.bmm})
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py (2)

238-256: NVFP4 naming consistency in docstring.

Docstring still says “FP4 MoE”; prefer “NVFP4 MoE” to match public API.

-    """
-    FP4 MoE op using quantized linear operations.
+    """
+    NVFP4 MoE op using quantized linear operations.

272-301: Zero-batch guard: redundant with _template_moe.

_template_moe() already avoids calling mlps for empty token sets. You can drop the early-return branch to reduce branching; keep if you’ve seen kernels invoked elsewhere.

-        def mlp(inp):
-            if inp.shape[0] == 0:
-                return torch.zeros_like(inp)
+        def mlp(inp):
tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py (1)

116-116: Prefer helper over hard-coded FP8 constant.

Use fp8_scale() for clarity/consistency.

-    weight_scale = (torch.max(torch.abs(weight)) / 448).to("cuda")
+    from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp8_scale
+    weight_scale = fp8_scale(weight).to("cuda")
tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py (3)

8-8: noqa is unnecessary; keep import for side effects.

Ruff warns RUF100; drop the noqa and explain the side effect.

-import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
+import tensorrt_llm._torch.auto_deploy.custom_ops  # required to register custom ops

41-61: TinyFP8Ref setup looks correct.

Optional: consider storing bias as fp32 to better reflect FP8 practice; cast at add-site if needed.


80-104: TinyFP4Ref setup looks correct.

Optional: compute s_in2 from the real input in forward for tighter parity; current random-based scale is acceptable for this test.

tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py (1)

544-552: split_dim passed as int; prefer explicit enum for type safety.

TPShardingInfo.split_dim is an enum; ints rely on coercion. Be explicit.

-                    TPShardingInfo.from_node(
-                        n,
-                        split_dim=i,
+                    TPShardingInfo.from_node(
+                        n,
+                        split_dim=SplitDimension(i),
                         rank=rank,
                         world_size=world_size,
                         dist_op=dist_op,
                         min_local_shape=min_local_shape,
                     )

Please confirm pydantic doesn’t choke on raw ints if coercion is disabled in future.

tensorrt_llm/_torch/auto_deploy/config/default.yaml (1)

83-86: Naming consistency: FP4 vs NVFP4.

You renamed ops to NVFP4 elsewhere; consider aligning this key from fuse_fp4_gemms → fuse_nvfp4_gemms for consistency (or document why GEMM keeps FP4 while linear uses NVFP4).

tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py (3)

271-291: Silence unused-args warnings in _apply.

cm, factory, shared_config aren’t used; explicitly discard to appease linters.

     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
+        # Unused in this transform
+        del cm, factory, shared_config
         if self.config.backend.lower() != "torch":
             raise ValueError(f"Unsupported FP8 backend: {self.config.backend}")

320-333: Silence unused-args warnings in _apply (NVFP4).

Same as above.

     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
+        # Unused in this transform
+        del cm, factory, shared_config
         if self.config.backend.lower() != "trtllm":
             raise ValueError(f"Unsupported NVFP4 backend: {self.config.backend}")

252-259: Optional: constrain backend via typing Literal to catch config typos.

This gives earlier validation errors if someone sets backend: torhc.

-from pydantic import Field
+from pydantic import Field
+from typing import Literal
...
-    backend: str = Field(
+    backend: Literal["torch"] = Field(
         default="torch",
         description="Backend to use for FP8 linear computation (default: 'torch').",
     )
...
-    backend: str = Field(
+    backend: Literal["trtllm"] = Field(
         default="trtllm",
         description="Backend to use for NVFP4 linear computation (default: 'trtllm').",
     )

Also applies to: 294-301

tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py (2)

19-22: Remove unused quantized_moe_op_map.

Not referenced; dead code.

-quantized_moe_op_map = {
-    "FP8": torch.ops.auto_deploy.torch_quant_fp8_moe,
-    "NVFP4": torch.ops.auto_deploy.torch_quant_nvfp4_moe,
-}

149-161: Silence unused-args warnings in _apply.

cm and shared_config aren’t used; explicitly discard to appease linters.

     def _apply(
         self,
         gm: GraphModule,
         cm: CachedSequenceInterface,
         factory: ModelFactory,
         shared_config: SharedConfig,
     ) -> Tuple[GraphModule, TransformInfo]:
+        # Unused in this transform
+        del cm, shared_config
         # Gate by algo in quant_config

Also applies to: 200-212

tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py (1)

456-463: Narrow blind except and log at debug.

Make failures observable while staying tolerant to missing ops.

Apply this diff:

-def _resolve_tp_cls_from_node(node: Node):
-    for pred, cls in TP_SHARDING_RULES:
-        try:
-            if pred(node):
-                return cls
-        except Exception:
-            pass
+def _resolve_tp_cls_from_node(node: Node):
+    for pred, cls in TP_SHARDING_RULES:
+        try:
+            if pred(node):
+                return cls
+        except (AttributeError, RuntimeError) as e:
+            ad_logger.debug(f"TP resolver predicate failed for {node}: {e}")
     return TPShardingInfo
tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py (2)

29-36: Remove unused parameter from _dequant_weight_fp8.

Tighten API and silence unused-arg warning.

Apply this diff:

-def _dequant_weight_fp8(
+def _dequant_weight_fp8(
     weight_fp8: torch.Tensor,
     weight_scale: torch.Tensor,
-    out_features: int,
     dtype: torch.dtype,
 ) -> torch.Tensor:
-    return weight_fp8.to(dtype) * weight_scale
+    return weight_fp8.to(dtype) * weight_scale
@@
-    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype)
+    weight_deq = _dequant_weight_fp8(weight_quantized, s_w, in_dtype)

Also applies to: 184-185


52-56: Replace asserts with explicit errors in _nvfp4_get_weights_scaling_factor.

More robust in production.

Apply this diff:

-    assert block_size != 0, "Block size is zero. Cannot return per_block amax for given input."
+    if block_size == 0:
+        raise ValueError("Block size is zero. Cannot return per_block amax for given input.")
@@
-    assert k % block_size == 0, (
-        "Weight shape is not divisible for block size for block quantization."
-    )
+    if k % block_size != 0:
+        raise ValueError("Weight shape is not divisible by block size for block quantization.")
tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py (3)

156-159: Prefer next(iter(...)) over list(...)[0].

Avoids intermediate list and handles large user sets better.

Apply this diff:

-            user = list(node.users.keys())[0]
+            user = next(iter(node.users.keys()))

338-387: NVFP4 load hook: avoid hard-coding CUDA device; guard for availability.

Make device selection robust; improves portability of state_dict loading.

Apply this diff:

-                weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
-                    weight.to("cuda"),
-                    weight_scale_2.to("cuda"),
+                # Select device: use existing device if CUDA, else current CUDA if available.
+                if not torch.cuda.is_available():
+                    raise RuntimeError("NVFP4 weight pre-quantization requires CUDA.")
+                dev = weight.device if weight.is_cuda else torch.device("cuda")
+                weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
+                    weight.to(dev),
+                    weight_scale_2.to(dev),
                     TRTLLM_NVFP4_SCALING_VECTOR_SIZE,
                     False,
                 )

Please confirm all CI jobs that invoke load_state_dict have CUDA available; otherwise we should add a CPU fallback path (can prototype one using the Python helpers in custom_ops/torch_quant.py).


39-55: Unify hook signatures in Quantization base.

post_load_hook signature in base differs from usage; consider aligning for clarity (module, incompatible_keys, weight_name) or make it abstract only in subclasses that use it.

Apply this diff:

-    @staticmethod
-    def post_load_hook(state_dict, prefix, *args, weight_name: str):
-        """Load hook for state_dict quantization post-processing."""
-        pass
+    # Subclasses may implement: def post_load_hook(self, module, incompatible_keys, weight_name: str): ...

@Fridah-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18112 [ run ] triggered by Bot

@Fridah-nv Fridah-nv enabled auto-merge (squash) September 9, 2025 01:12
@tensorrt-cicd
Copy link
Collaborator

PR_Github #18112 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13573 completed with status: 'FAILURE'

@Fridah-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18163 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18163 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13610 completed with status: 'FAILURE'

@Fridah-nv
Copy link
Collaborator Author

/bot run

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18190 [ run ] triggered by Bot

@tensorrt-cicd
Copy link
Collaborator

PR_Github #18190 [ run ] completed with state SUCCESS
/LLM/main/L0_MergeRequest_PR pipeline #13632 completed with status: 'SUCCESS'
Pipeline passed with automatic retried tests. Check the rerun report for details.

@Fridah-nv Fridah-nv merged commit bbb5ae3 into NVIDIA:main Sep 10, 2025
5 checks passed
@github-project-automation github-project-automation bot moved this from Backlog to Done in AutoDeploy Board Sep 10, 2025
Wong4j pushed a commit to Wong4j/TensorRT-LLM that referenced this pull request Sep 20, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
Archived in project
Development

Successfully merging this pull request may close these issues.

Add Simulated Quantized Linear Operator for Debugging Quantization-aware Transformation Pipeline via Early Fusion + Node Annotation
5 participants