Merged
25 commits
03d1966
add torch ref impl for FP8, add unit test
Fridah-nv Aug 12, 2025
dcd68df
add torch ref impl for FP4, add op map unit test
Fridah-nv Aug 14, 2025
f38101e
split linear and bmm quantization
Fridah-nv Aug 14, 2025
3932877
update quantize_linear_from_config to point to the custom op
Fridah-nv Aug 15, 2025
a19eeeb
separate custom op into two torch ops
Fridah-nv Aug 19, 2025
7c2c0d1
quantized fusion transforms, WIP for FP4
Fridah-nv Aug 19, 2025
d5cfe13
add QuantizationFusionMixin class
Fridah-nv Aug 21, 2025
437b942
quantized sharding class for FP8 and FP4
Fridah-nv Aug 22, 2025
5c3d7b4
remove QuantizationImpl in sharding, remove unused methods in Quantiz…
Fridah-nv Aug 23, 2025
aaa28d9
remove custom_quant_linear op
Fridah-nv Aug 23, 2025
3733366
rename custom quant ops
Fridah-nv Aug 23, 2025
e477434
WIP to map custom quant op to real implementation using pattern matcher
Fridah-nv Aug 23, 2025
7f8a8f8
fix unit tests
Fridah-nv Aug 26, 2025
10a4994
remove unused ENUM
Fridah-nv Aug 26, 2025
24885ea
minor updates: rabbit feedback, docstrings, code cleaning
Fridah-nv Aug 27, 2025
c8b21f9
clear unit tests on blackwell; address a few comments; rename FP ops …
Fridah-nv Sep 3, 2025
961c33f
remove include_quantization from is_linear_node
Fridah-nv Sep 4, 2025
19500d7
address few comments: remove pattern matcher fake mode patch; remove …
Fridah-nv Sep 5, 2025
d01af1e
update quantization transforms:Linear, BMM, MoE and MoE matching into…
Fridah-nv Sep 6, 2025
7d011b6
remove QuantizationImpl class; remove more reference of is_quantized_op
Fridah-nv Sep 6, 2025
f57fa57
fix test_quantization_utils.py
Fridah-nv Sep 6, 2025
75063f1
minor: address comments, uncomment skipped unit tests
Fridah-nv Sep 8, 2025
d9afed8
skip quant gemm fusion for perf
Fridah-nv Sep 9, 2025
30cef34
Merge branch 'main' into user/fridah/inherit-quant2
Fridah-nv Sep 9, 2025
5178d0f
Merge branch 'main' into user/fridah/inherit-quant2
Fridah-nv Sep 9, 2025
29 changes: 24 additions & 5 deletions tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -45,11 +45,19 @@ transforms:
# see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
optimize_rope:
stage: pattern_matcher
quantize_from_config:
quantize_fp8_linear_from_config:
stage: pattern_matcher
quantize_from_graph:
quantize_nvfp4_linear_from_config:
stage: pattern_matcher
quantize_moe:
quantize_fp8_bmm_from_config:
stage: pattern_matcher
quantize_fp8_from_graph:
stage: pattern_matcher
quantize_nvfp4_from_graph:
stage: pattern_matcher
quantize_fp8_moe:
stage: pattern_matcher
quantize_nvfp4_moe:
stage: pattern_matcher
# TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
detect_sharding:
@@ -70,10 +78,21 @@ transforms:
# RUN POST-LOAD FUSION AND OPTIMIZATIONS
############################################################################################
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
# fuse_moe:
# stage: post_load_fusion
# fuse_gemms:
# stage: post_load_fusion
# fuse_fp4_gemms:
# stage: post_load_fusion
# fuse_fp8_gemms:
# stage: post_load_fusion
fuse_fp8_linear:
stage: post_load_fusion
backend: torch
fuse_nvfp4_linear:
stage: post_load_fusion
backend: trtllm
# TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
fuse_moe:
stage: post_load_fusion
fuse_allreduce_residual_rmsnorm:
stage: post_load_fusion
fuse_collectives:
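The reworked `transforms` section above splits the old `quantize_from_config` / `quantize_from_graph` / `quantize_moe` entries into per-format FP8 and NVFP4 transforms at the `pattern_matcher` stage and enables `fuse_fp8_linear` / `fuse_nvfp4_linear` at the `post_load_fusion` stage. As a minimal sketch of how such a stage-keyed section can be grouped for inspection — `load_transform_stages` and the inlined config snippet are hypothetical illustrations, not part of the TensorRT-LLM API:

```python
# Hypothetical helper: group a `transforms:` section like the one in
# default.yaml by pipeline stage. Illustration only, not TensorRT-LLM code.
from collections import defaultdict

import yaml  # PyYAML

_EXAMPLE_CONFIG = """
transforms:
  quantize_fp8_linear_from_config:
    stage: pattern_matcher
  quantize_nvfp4_linear_from_config:
    stage: pattern_matcher
  fuse_fp8_linear:
    stage: post_load_fusion
    backend: torch
  fuse_nvfp4_linear:
    stage: post_load_fusion
    backend: trtllm
"""

def load_transform_stages(config_text: str) -> dict[str, list[str]]:
    """Return {stage: [transform names]}, preserving file order."""
    cfg = yaml.safe_load(config_text)
    stages: dict[str, list[str]] = defaultdict(list)
    for name, opts in cfg["transforms"].items():
        stages[opts["stage"]].append(name)
    return dict(stages)

if __name__ == "__main__":
    for stage, names in load_transform_stages(_EXAMPLE_CONFIG).items():
        print(stage, "->", names)
```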
2 changes: 1 addition & 1 deletion tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
@@ -25,7 +25,7 @@ The table below lists the operators ordered by their backend.
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.torch_quant_fp4_linear` | FP4 quantized linear layer |
| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer |
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
| `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine |
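For context on the `torch_quant_fn` row above ("scales, rounds, and clamps input values"), here is a minimal illustrative re-implementation assuming symmetric per-tensor FP8 (e4m3) quantization; the function name and signature are placeholders, not the library's actual `torch_quant_fn`:

```python
# Illustrative scale/round/clamp quant-dequant, assuming symmetric per-tensor
# scaling to float8_e4m3fn (max representable magnitude ~448). Not the actual
# torch_quant_fn from tensorrt_llm.
import torch

def quant_dequant_ref(x: torch.Tensor, scale: torch.Tensor,
                      qmin: float = -448.0, qmax: float = 448.0) -> torch.Tensor:
    """Scale, clamp to the e4m3 range, cast to FP8 (rounds), then dequantize."""
    q = torch.clamp(x / scale, qmin, qmax).to(torch.float8_e4m3fn)
    return q.to(x.dtype) * scale

if __name__ == "__main__":
    x = torch.randn(4, 8)
    scale = x.abs().amax() / 448.0  # per-tensor scale for e4m3
    print((quant_dequant_ref(x, scale) - x).abs().max())  # small quantization error
```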
1 change: 1 addition & 0 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py
@@ -11,6 +11,7 @@
from .torch_attention import *
from .torch_backend_attention import *
from .torch_moe import *
from .torch_quant import *
from .torch_rope import *
from .triton_attention import *
from .triton_rope import *
18 changes: 3 additions & 15 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
@@ -157,9 +157,9 @@ def forward(self, x):
)


@torch.library.custom_op("auto_deploy::torch_quant_fp4_linear", mutates_args=())
@torch.library.custom_op("auto_deploy::torch_quant_nvfp4_linear", mutates_args=())
@torch.compile(dynamic=True)
def fp4_linear(
def nvfp4_linear(
input: torch.Tensor,
weight_fp4: torch.Tensor,
bias: Optional[torch.Tensor] = None,
@@ -212,7 +212,7 @@ def fp4_linear(
return output.reshape(*input_shape[:-1], n)


@fp4_linear.register_fake
@nvfp4_linear.register_fake
def fp4_linear_fake(
input: torch.Tensor,
weight_fp4: torch.Tensor,
@@ -299,15 +299,3 @@ def fp8_bmm_fake(
"""Fake implementation of fp8_bmm for testing and tracing."""
# Use standard bmm
return torch.bmm(input.to(torch.float), mat2.to(torch.float)).to(input.dtype)


QUANT_LINEAR_OPS = [
torch.ops.auto_deploy.torch_quant_fp8_linear,
torch.ops.auto_deploy.torch_quant_fp4_linear,
]

QUANT_BMM_OPS = [
torch.ops.auto_deploy.torch_quant_fp8_bmm,
]

QUANT_OPS = QUANT_LINEAR_OPS + QUANT_BMM_OPS
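The rename from `torch_quant_fp4_linear` to `torch_quant_nvfp4_linear` keeps the existing `torch.library.custom_op` + `register_fake` structure, and the module-level `QUANT_LINEAR_OPS` / `QUANT_BMM_OPS` / `QUANT_OPS` lists are dropped. A toy sketch of that registration pattern, using a hypothetical `demo::scaled_linear` op (the real NVFP4 op additionally takes packed FP4 weights, input/weight scales, and `alpha`):

```python
# Toy custom op showing the custom_op + register_fake pattern used in quant.py.
# The "demo::scaled_linear" namespace and op are hypothetical.
from typing import Optional

import torch

@torch.library.custom_op("demo::scaled_linear", mutates_args=())
def scaled_linear(input: torch.Tensor, weight: torch.Tensor,
                  bias: Optional[torch.Tensor] = None,
                  scale: float = 1.0) -> torch.Tensor:
    """Eager implementation: a plain linear with an output scale."""
    return torch.nn.functional.linear(input, weight, bias) * scale

@scaled_linear.register_fake
def _(input, weight, bias=None, scale=1.0):
    """Shape/dtype-only implementation used during tracing and export."""
    n = weight.shape[0]
    return input.new_empty(*input.shape[:-1], n)

if __name__ == "__main__":
    x, w = torch.randn(2, 4), torch.randn(3, 4)
    print(torch.ops.demo.scaled_linear(x, w, None, 2.0).shape)  # torch.Size([2, 3])
```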
14 changes: 7 additions & 7 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py
@@ -235,8 +235,8 @@ def torch_quant_fp8_moe_fake(
return torch.empty_like(x)


@torch.library.custom_op("auto_deploy::torch_quant_fp4_moe", mutates_args=())
def torch_quant_fp4_moe(
@torch.library.custom_op("auto_deploy::torch_quant_nvfp4_moe", mutates_args=())
def torch_quant_nvfp4_moe(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
@@ -273,15 +273,15 @@ def make_fp4_mlp(i):
def mlp(inp):
if inp.shape[0] == 0:
return torch.zeros_like(inp)
gate_out = torch.ops.auto_deploy.torch_quant_fp4_linear(
gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w1_weight[i],
bias=None,
input_scale=w1_input_scale[i],
weight_scale=w1_weight_scale[i],
alpha=w1_alpha[i],
)
up_out = torch.ops.auto_deploy.torch_quant_fp4_linear(
up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
inp,
w3_weight[i],
bias=None,
@@ -290,7 +290,7 @@ def mlp(inp):
alpha=w3_alpha[i],
)
prod = F.silu(gate_out) * up_out
return torch.ops.auto_deploy.torch_quant_fp4_linear(
return torch.ops.auto_deploy.torch_quant_nvfp4_linear(
prod,
w2_weight[i],
bias=None,
@@ -305,8 +305,8 @@ def mlp(inp):
return _template_moe(x, selected_experts, routing_weights, mlps)


@torch_quant_fp4_moe.register_fake
def torch_quant_fp4_moe_fake(
@torch_quant_nvfp4_moe.register_fake
def torch_quant_nvfp4_moe_fake(
x: torch.Tensor,
selected_experts: torch.Tensor,
routing_weights: torch.Tensor,
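Each expert in `torch_quant_nvfp4_moe` is a SiLU-gated MLP built from three `torch_quant_nvfp4_linear` calls (`w1` gate, `w3` up, `w2` down projection). A minimal unquantized sketch of that per-expert computation, with plain `F.linear` standing in for the quantized op and random fp32 weights as stand-ins for the packed FP4 weights and scales:

```python
# Unquantized sketch of one expert's gated MLP in torch_quant_nvfp4_moe:
# out = w2 @ (silu(w1 @ x) * (w3 @ x)). Plain fp32 weights stand in for the
# packed FP4 weights, scales, and alpha used by the real custom op.
import torch
import torch.nn.functional as F

def gated_mlp_expert(inp: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor,
                     w3: torch.Tensor) -> torch.Tensor:
    if inp.shape[0] == 0:  # mirror the empty-token early return in the diff
        return torch.zeros_like(inp)
    gate_out = F.linear(inp, w1)  # stands in for torch_quant_nvfp4_linear(inp, w1_weight[i], ...)
    up_out = F.linear(inp, w3)
    return F.linear(F.silu(gate_out) * up_out, w2)

if __name__ == "__main__":
    hidden, inter = 16, 32
    x = torch.randn(5, hidden)
    w1, w3 = torch.randn(inter, hidden), torch.randn(inter, hidden)
    w2 = torch.randn(hidden, inter)
    print(gated_mlp_expert(x, w1, w2, w3).shape)  # torch.Size([5, 16])
```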