diff --git a/tensorrt_llm/_torch/auto_deploy/config/default.yaml b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
index 7a85ce8a232..5a0645c2edb 100644
--- a/tensorrt_llm/_torch/auto_deploy/config/default.yaml
+++ b/tensorrt_llm/_torch/auto_deploy/config/default.yaml
@@ -45,11 +45,19 @@ transforms:
   # see https://github.com/NVIDIA/TensorRT-LLM/pull/3668#discussion_r2052714528
   optimize_rope:
     stage: pattern_matcher
-  quantize_from_config:
+  quantize_fp8_linear_from_config:
     stage: pattern_matcher
-  quantize_from_graph:
+  quantize_nvfp4_linear_from_config:
     stage: pattern_matcher
-  quantize_moe:
+  quantize_fp8_bmm_from_config:
+    stage: pattern_matcher
+  quantize_fp8_from_graph:
+    stage: pattern_matcher
+  quantize_nvfp4_from_graph:
+    stage: pattern_matcher
+  quantize_fp8_moe:
+    stage: pattern_matcher
+  quantize_nvfp4_moe:
     stage: pattern_matcher
   # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
   detect_sharding:
@@ -70,10 +78,21 @@ transforms:
   # RUN POST-LOAD FUSION AND OPTIMIZATIONS
   ############################################################################################
   # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
-  # fuse_moe:
-  #   stage: post_load_fusion
   # fuse_gemms:
   #   stage: post_load_fusion
+  # fuse_fp4_gemms:
+  #   stage: post_load_fusion
+  # fuse_fp8_gemms:
+  #   stage: post_load_fusion
+  fuse_fp8_linear:
+    stage: post_load_fusion
+    backend: torch
+  fuse_nvfp4_linear:
+    stage: post_load_fusion
+    backend: trtllm
+  # TODO: https://github.com/NVIDIA/TensorRT-LLM/issues/4674 this is causing OOMs
+  fuse_moe:
+    stage: post_load_fusion
   fuse_allreduce_residual_rmsnorm:
     stage: post_load_fusion
   fuse_collectives:
diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
index 6bef175199b..0ed95a83ab1 100644
--- a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
+++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
@@ -25,7 +25,7 @@ The table below lists the operators ordered by their backend.
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation | | `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values | | `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation | -| `torch.ops.auto_deploy.torch_quant_fp4_linear` | FP4 quantized linear layer | +| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer | | `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer | | `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies | | `torch.ops.auto_deploy.torch_rope_with_explicit_cos_sin` | RoPE with explicit cosine/sine | diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py index 23a80b94d74..4c2b8c0fae8 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py @@ -11,6 +11,7 @@ from .torch_attention import * from .torch_backend_attention import * from .torch_moe import * +from .torch_quant import * from .torch_rope import * from .triton_attention import * from .triton_rope import * diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py index d17b816e825..cc4c2b6bd1f 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py @@ -157,9 +157,9 @@ def forward(self, x): ) -@torch.library.custom_op("auto_deploy::torch_quant_fp4_linear", mutates_args=()) +@torch.library.custom_op("auto_deploy::torch_quant_nvfp4_linear", mutates_args=()) @torch.compile(dynamic=True) -def fp4_linear( +def nvfp4_linear( input: torch.Tensor, weight_fp4: torch.Tensor, bias: Optional[torch.Tensor] = None, @@ -212,7 +212,7 @@ def fp4_linear( return output.reshape(*input_shape[:-1], n) -@fp4_linear.register_fake +@nvfp4_linear.register_fake def fp4_linear_fake( input: torch.Tensor, weight_fp4: torch.Tensor, @@ -299,15 +299,3 @@ def fp8_bmm_fake( """Fake implementation of fp8_bmm for testing and tracing.""" # Use standard bmm return torch.bmm(input.to(torch.float), mat2.to(torch.float)).to(input.dtype) - - -QUANT_LINEAR_OPS = [ - torch.ops.auto_deploy.torch_quant_fp8_linear, - torch.ops.auto_deploy.torch_quant_fp4_linear, -] - -QUANT_BMM_OPS = [ - torch.ops.auto_deploy.torch_quant_fp8_bmm, -] - -QUANT_OPS = QUANT_LINEAR_OPS + QUANT_BMM_OPS diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py index 5b7131f1296..cbfa59bd2c2 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py @@ -235,8 +235,8 @@ def torch_quant_fp8_moe_fake( return torch.empty_like(x) -@torch.library.custom_op("auto_deploy::torch_quant_fp4_moe", mutates_args=()) -def torch_quant_fp4_moe( +@torch.library.custom_op("auto_deploy::torch_quant_nvfp4_moe", mutates_args=()) +def torch_quant_nvfp4_moe( x: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor, @@ -273,7 +273,7 @@ def make_fp4_mlp(i): def mlp(inp): if inp.shape[0] == 0: return torch.zeros_like(inp) - gate_out = torch.ops.auto_deploy.torch_quant_fp4_linear( + gate_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear( inp, w1_weight[i], bias=None, @@ -281,7 +281,7 @@ def mlp(inp): weight_scale=w1_weight_scale[i], alpha=w1_alpha[i], ) - 
up_out = torch.ops.auto_deploy.torch_quant_fp4_linear( + up_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear( inp, w3_weight[i], bias=None, @@ -290,7 +290,7 @@ def mlp(inp): alpha=w3_alpha[i], ) prod = F.silu(gate_out) * up_out - return torch.ops.auto_deploy.torch_quant_fp4_linear( + return torch.ops.auto_deploy.torch_quant_nvfp4_linear( prod, w2_weight[i], bias=None, @@ -305,8 +305,8 @@ def mlp(inp): return _template_moe(x, selected_experts, routing_weights, mlps) -@torch_quant_fp4_moe.register_fake -def torch_quant_fp4_moe_fake( +@torch_quant_nvfp4_moe.register_fake +def torch_quant_nvfp4_moe_fake( x: torch.Tensor, selected_experts: torch.Tensor, routing_weights: torch.Tensor, diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py new file mode 100644 index 00000000000..b3b4a1c372f --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/torch_quant.py @@ -0,0 +1,278 @@ +from typing import List, Optional + +import torch + +from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import ( + cutlass_fp4_scale_to_modelopt_fp4_scale, +) + +# FP4 tables (E2M1) +e2m1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]) +e2m1_values = torch.tensor([0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6]) + + +# ===== Helpers ===== +def _expect_single_scale(scales: List[Optional[torch.Tensor]], name: str) -> torch.Tensor: + if len(scales) == 0 or scales[0] is None: + raise ValueError(f"{name} must provide at least one scale tensor (scales[0]).") + return scales[0] + + +def _to_fp8_fake(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + return (x / scale).to(torch.float8_e4m3fn) + + +def _from_fp8(x_fp8: torch.Tensor, scale: torch.Tensor, dtype: torch.dtype) -> torch.Tensor: + return x_fp8.to(dtype) * scale + + +def _dequant_weight_fp8( + weight_fp8: torch.Tensor, + weight_scale: torch.Tensor, + out_features: int, + dtype: torch.dtype, +) -> torch.Tensor: + return weight_fp8.to(dtype) * weight_scale + + +# The NVFP4 helpers below are adapted from modelopt.torch.quantization.qtensor.nvfp4_tensor.NVFP4QTensor +def _nvfp4_get_weights_scaling_factor( + input: torch.Tensor, + block_size: int, + weights_scaling_factor_2: torch.Tensor | None = None, + keep_high_precision: bool = False, +): + """Returns quantized per block weight scaling factor.""" + if weights_scaling_factor_2 is None: + # per-tensor scale-2 = amax / (6 * 448) + weights_scaling_factor_2 = input.abs().amax().float() / (6.0 * 448.0) + + # Get per_block amax + [n, k] = input.shape[-2:] + assert block_size != 0, "Block size is zero. Cannot return per_block amax for given input." + + assert k % block_size == 0, ( + "Weight shape is not divisible for block size for block quantization." 
+ ) + + input = input.reshape((*tuple(input.shape[:-2]), n, k // block_size, block_size)) + # Get per block amax + per_block_amax = input.abs().amax(dim=-1).float() + # Get per-block-scale + per_block_scale = per_block_amax / 6.0 + # Quantize per_block_scale to FP8 + q_per_block_scale = per_block_scale / weights_scaling_factor_2 + # Set all zero values in scale to 1.0 + q_per_block_scale[per_block_scale == 0] = 1.0 + # Convert to torch.float8_e4m3fn + if not keep_high_precision: + q_per_block_scale = q_per_block_scale.to(torch.float8_e4m3fn) + return q_per_block_scale, weights_scaling_factor_2 + + +def _cast_fp4(weight: torch.Tensor): + """Converts tensor to uint4.""" + # Get device + device = weight.device + + # Define mask to perform rounding + mask = torch.tensor([0, 1, 0, 1, 0, 1, 0], dtype=torch.uint8).to(device) + mask_shape = list(weight.shape) + mask = mask.expand([*mask_shape, 7]) + + sign_bit = (weight < 0).to(torch.uint8) + + weight_abs = weight.abs() # avoid in-place modification to input + # Calculate the ordinal value based on the bounds + ord = torch.searchsorted(e2m1_bounds.to(device), weight_abs, out_int32=True).to(torch.uint8) + # All values equal to e2m1_bounds at odd indices are rounded up and even indices are rounded down + round = torch.any((weight_abs.unsqueeze(-1) == e2m1_bounds.to(device)) * mask, dim=-1) + fp4_val = (sign_bit * 0b1000 + ord + round).to(torch.uint8) + return fp4_val + + +def _quantize_nvfp4( + input: torch.Tensor, + block_size: int, + weights_scaling_factor_2: torch.Tensor | None = None, +): + """Converting a tensor to a quantized format based on NVFP4 quantization. + + Args: + input (torch.Tensor): The input tensor to be quantized. + block_size (int): The size of each block for quantization. + weights_scaling_factor_2 (torch.Tensor): The per-tensor scaling factor for the weights. 
+ Returns: + tuple: Contains quantized data and quantized per block scaling factor + """ + + weights_scaling_factor, weights_scaling_factor_2 = _nvfp4_get_weights_scaling_factor( + input, block_size, weights_scaling_factor_2 + ) + + # Reshape the weight and scale factors + input = input.view((*tuple(input.shape[:-1]), -1, block_size)) + + # Scale weights + scaled_weight = input / ( + (weights_scaling_factor.to(torch.float32) * weights_scaling_factor_2).unsqueeze(-1) + ) + + # Reshape weights to original + scaled_weight = scaled_weight.view((*tuple(scaled_weight.shape[:-2]), -1)) + + # Cast weights to fp4 + q_weight = _cast_fp4(scaled_weight) + # Pack weights + packed_weight = (q_weight[..., 1::2] << 4) | q_weight[..., 0::2] + return packed_weight, weights_scaling_factor + + +def _dequantize_nvfp4( + quantized_t: torch.Tensor, # [N, K/2] uint8 + scale_1: torch.Tensor, # q_per_block_scale (FP8/FP32), flat or shaped + scale_2: torch.Tensor, # per-tensor scale-2 (FP32 scalar) + orig_shape: tuple, # (N, K) + orig_dtype: torch.dtype, +) -> torch.Tensor: + device = quantized_t.device + N, K = orig_shape + # slice/pad handling for the scale vector: take exactly N*K/16 entries + num_blocks = N * (K // 16) + s1 = scale_1.reshape(-1)[:num_blocks] + + high = (quantized_t >> 4) & 0x0F + low = quantized_t & 0x0F + idx = torch.empty(N, (K // 2) * 2, dtype=torch.long, device=device) + idx[..., 0::2] = low.long() + idx[..., 1::2] = high.long() + + vals = e2m1_values.to(device)[idx] # [N, K], float32 + + scale_real = (s1.to(torch.float32) * scale_2.to(torch.float32)).view(N, K // 16, 1) + vals = vals.view(N, K // 16, 16) * scale_real + return vals.view(N, K).to(orig_dtype) + + +@torch.library.custom_op("auto_deploy::torch_fake_quant_fp8_linear", mutates_args=()) +def torch_fake_quant_fp8_linear( + input: torch.Tensor, + weight_quantized: torch.Tensor, + bias: torch.Tensor, + input_scale: List[torch.Tensor], + weight_scale: List[torch.Tensor], + input_zp: List[torch.Tensor], + weight_zp: List[torch.Tensor], +) -> torch.Tensor: + """ + Reference (eager) implementation for multiple quant formats via `format_type`. 
+ For FP8: + - input_scale[0] and weight_scale[0] are required (amax/448 style) + - input_zp / weight_zp ignored + """ + if weight_quantized.dtype != torch.float8_e4m3fn: + raise TypeError("FP8 path requires weight_quantized.dtype == float8_e4m3fn") + s_in = _expect_single_scale(input_scale, "input_scale") + s_w = _expect_single_scale(weight_scale, "weight_scale") + + in_dtype = input.dtype + out_features, in_features = weight_quantized.shape + + input_fp8 = _to_fp8_fake(input, s_in) + input_deq = _from_fp8(input_fp8, s_in, in_dtype) + + weight_deq = _dequant_weight_fp8(weight_quantized, s_w, out_features, in_dtype) + + out = torch.matmul(input_deq.reshape(-1, in_features), weight_deq.t()) + if bias is not None: + out = out + bias + return out.reshape(*input.shape[:-1], out_features) + + +@torch_fake_quant_fp8_linear.register_fake +def torch_fake_quant_fp8_linear( + input: torch.Tensor, + weight_quantized: torch.Tensor, + bias: torch.Tensor, + input_scale: List[torch.Tensor], + weight_scale: List[torch.Tensor], + input_zp: List[torch.Tensor], + weight_zp: List[torch.Tensor], +) -> torch.Tensor: + w = weight_quantized.to(input.dtype) + return torch.ops.aten.linear(input, w, bias) + + +@torch.library.custom_op("auto_deploy::torch_fake_quant_nvfp4_linear", mutates_args=()) +def torch_fake_quant_nvfp4_linear( + input: torch.Tensor, + weight_quantized: torch.Tensor, + bias: torch.Tensor, + input_scale: List[torch.Tensor], + weight_scale: List[torch.Tensor], + input_zp: List[torch.Tensor], + weight_zp: List[torch.Tensor], +) -> torch.Tensor: + """ + Reference (eager) implementation for multiple quant formats via `format_type`. + For FP4: + - input_scale[0] = s_in2 (scalar, amax/(448*6)) + - weight_scale[0] = q_per_block_scale_w (len >= N*K/16; may be padded) + - weight_scale[1] = alpha = s_in2 * s_w2 (combined per-tensor scales) + """ + if weight_quantized.dtype != torch.uint8: + raise TypeError("NVFP4 path requires packed uint8 weights (2x FP4 per byte).") + + inv_x = _expect_single_scale(input_scale, "input_scale") + if len(weight_scale) < 2 or weight_scale[0] is None or weight_scale[1] is None: + raise ValueError( + "NVFP4 needs weight_scale[0] (per-block vector) and weight_scale[1] (alpha)." 
+ ) + cutlass_qscale = weight_scale[0] + alpha = weight_scale[1] + + if cutlass_qscale.dtype != torch.uint8: + raise TypeError("NVFP4 expects CUTLASS per-block scale vector in uint8 (same as fused op).") + + inv_w = 1 / (inv_x * alpha) + s2_x = 1.0 / inv_x + s2_w = 1.0 / inv_w + + # Shapes + in_dtype = input.dtype + input_shape = input.shape + N, K_packed = weight_quantized.shape[-2], weight_quantized.shape[-1] + K = K_packed * 2 + assert K % 16 == 0, "NVFP4 requires K to be a multiple of 16" + num_blocks_w = N * (K // 16) + + q_scale_w_slice = cutlass_fp4_scale_to_modelopt_fp4_scale(cutlass_qscale, (N, K)) + # (1) Dequantize weights with scale_1 = q_scale_w (sliced), scale_2 = s_w2 + q_scale_w_slice = q_scale_w_slice.reshape(-1)[:num_blocks_w] + W_deq = _dequantize_nvfp4(weight_quantized, q_scale_w_slice, s2_w, (N, K), in_dtype) # [N, K] + + # (2) Quantize+dequantize inputs with _quantize_nvfp4/_dequantize_nvfp4 + # Flatten batch for NVFP4 block processing + X_2d = input.reshape(-1, K) + + X_packed, X_q_scale = _quantize_nvfp4(X_2d, block_size=16, weights_scaling_factor_2=s2_x) + X_deq = _dequantize_nvfp4(X_packed, X_q_scale, s2_x, (X_2d.shape[0], K), in_dtype) # [B, K] + + # (3) GEMM + bias (float GEMM with codec error baked in) + out_2d = torch.matmul(X_deq, W_deq.t()) # [B, N] + if bias is not None: + out_2d = out_2d + bias + return out_2d.reshape(*input_shape[:-1], N) + + +@torch_fake_quant_nvfp4_linear.register_fake +def torch_fake_quant_nvfp4_linear( + input: torch.Tensor, + weight_quantized: torch.Tensor, + bias: torch.Tensor, + input_scale: List[torch.Tensor], + weight_scale: List[torch.Tensor], + input_zp: List[torch.Tensor], + weight_zp: List[torch.Tensor], +) -> torch.Tensor: + return torch.ops.aten.linear(input, weight_quantized.repeat(1, 2).to(input.dtype), bias) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py new file mode 100644 index 00000000000..3380442061c --- /dev/null +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/fuse_quant.py @@ -0,0 +1,333 @@ +from typing import Tuple, Type + +import torch +from pydantic import Field +from torch.fx import GraphModule + +from ...models.factory import ModelFactory +from ...shim.interface import CachedSequenceInterface +from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern +from ..interface import ( + BaseTransform, + SharedConfig, + TransformConfig, + TransformInfo, + TransformRegistry, +) + + +# with bias=None +def _fp8_ref_pattern_1( + x: torch.Tensor, + w_fp8: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, +): + return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default( + x, + w_fp8, + None, + input_scale=[input_scale], + weight_scale=[weight_scale], + input_zp=[], + weight_zp=[], + ) + + +def _fp8_ref_repl_1( + x: torch.Tensor, + w_fp8: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, +): + return torch.ops.auto_deploy.torch_quant_fp8_linear( + x, + w_fp8, + None, + input_scale=input_scale, + weight_scale=weight_scale, + ) + + +# with bias!=None +def _fp8_ref_pattern_2( + x: torch.Tensor, + w_fp8: torch.Tensor, + bias: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, +): + return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default( + x, + w_fp8, + bias, + input_scale=[input_scale], + weight_scale=[weight_scale], + input_zp=[], + weight_zp=[], + ) + + +def _fp8_ref_repl_2( + x: torch.Tensor, + w_fp8: torch.Tensor, + 
bias: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, +): + return torch.ops.auto_deploy.torch_quant_fp8_linear( + x, + w_fp8, + bias, + input_scale=input_scale, + weight_scale=weight_scale, + ) + + +# NVFP4: with bias=None +def _fp4_ref_pattern_1( + x: torch.Tensor, + w_fp4: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, +): + return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear( + x, + w_fp4, + None, + input_scale=[input_scale], + weight_scale=[weight_scale, alpha], + input_zp=[], + weight_zp=[], + ) + + +def _fp4_ref_repl_1( + x: torch.Tensor, + w_fp4: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, +): + return torch.ops.auto_deploy.torch_quant_nvfp4_linear( + x, + w_fp4, + bias=None, + input_scale=input_scale, + weight_scale=weight_scale, + alpha=alpha, + ) + + +# with bias!=None +def _fp4_ref_pattern_2( + x: torch.Tensor, + w_fp4: torch.Tensor, + bias: torch.Tensor, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, +): + return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear( + x, + w_fp4, + bias, + input_scale=[input_scale], + weight_scale=[weight_scale, alpha], + input_zp=[], + weight_zp=[], + ) + + +def _fp4_ref_repl_2( + x: torch.Tensor, + w_fp4: torch.Tensor, + bias: torch.Tensor | None, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + alpha: torch.Tensor, +): + return torch.ops.auto_deploy.torch_quant_nvfp4_linear( + x, + w_fp4, + bias=bias, + input_scale=input_scale, + weight_scale=weight_scale, + alpha=alpha, + ) + + +def _register_quant_fp8_linear_patterns(patterns: ADPatternMatcherPass) -> None: + """ + Register FP8 linear patterns with robust dummy args and minimal ignores. + """ + # FP8 dummy tensors + x_fp8 = torch.randn(3, 16, device="meta", dtype=torch.float16) + w_fp8 = torch.randn(32, 16, device="meta", dtype=torch.float16) + bias32 = torch.randn(32, device="meta", dtype=torch.float32) + one = torch.tensor(1.0, device="meta", dtype=torch.float32) + + # no-bias variant + dummy_args_fp8 = [ + x_fp8, + w_fp8, + one, + torch.tensor(0.5, device="meta", dtype=torch.float32), + ] + register_ad_pattern( + search_fn=_fp8_ref_pattern_1, + replace_fn=_fp8_ref_repl_1, + patterns=patterns, + dummy_args=dummy_args_fp8, + ) + + # bias variant + dummy_args_fp8_2 = [ + x_fp8, + w_fp8, + bias32, + one, + torch.tensor(0.5, device="meta", dtype=torch.float32), + ] + register_ad_pattern( + search_fn=_fp8_ref_pattern_2, + replace_fn=_fp8_ref_repl_2, + patterns=patterns, + dummy_args=dummy_args_fp8_2, + ) + + +def _register_quant_fp4_linear_patterns(patterns: ADPatternMatcherPass) -> None: + """ + Register FP4 linear patterns with robust dummy args and minimal ignores. 
+ """ + # FP4 shape params + N = 32 + K_packed = 32 # weight is packed by 2 FP4 per byte + K_eff = 2 * K_packed + + # FP4 dummy tensors + x_fp4 = torch.randn(3, K_eff, device="meta", dtype=torch.float16) + w_fp4 = torch.randint(0, 255, (N, K_packed), device="meta", dtype=torch.uint8) + + s_in2 = torch.tensor(0.01, device="meta", dtype=torch.float32) + alpha = torch.tensor(1.2345, device="meta", dtype=torch.float32) + + cutlass_len = N * (K_eff // 16) # 32 * (64/16) = 128 + cutlass_vec = torch.randint(0, 255, (cutlass_len,), device="meta", dtype=torch.uint8) + + # no-bias variant + dummy_args_fp4_1 = [ + x_fp4, + w_fp4, + s_in2, + cutlass_vec, + alpha, + ] + register_ad_pattern( + search_fn=_fp4_ref_pattern_1, + replace_fn=_fp4_ref_repl_1, + patterns=patterns, + dummy_args=dummy_args_fp4_1, + ) + + # bias variant + dummy_args_fp4_2 = [ + x_fp4, + w_fp4, + torch.randn(N, device="meta", dtype=torch.float16), # bias + s_in2, + cutlass_vec, + alpha, + ] + register_ad_pattern( + search_fn=_fp4_ref_pattern_2, + replace_fn=_fp4_ref_repl_2, + patterns=patterns, + dummy_args=dummy_args_fp4_2, + ) + + +class FuseFP8LinearConfig(TransformConfig): + """Configuration for FP8 linear fusion transform.""" + + backend: str = Field( + default="torch", + description="Backend to use for FP8 linear computation (default: 'torch').", + ) + + +@TransformRegistry.register("fuse_fp8_linear") +class FuseFP8Linear(BaseTransform): + """Matches and replaces FP8 fake quantized linear ops with fused torch backend ops.""" + + config: FuseFP8LinearConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return FuseFP8LinearConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + if self.config.backend.lower() != "torch": + raise ValueError(f"Unsupported FP8 backend: {self.config.backend}") + + patterns = ADPatternMatcherPass() + _register_quant_fp8_linear_patterns(patterns) + cnt = patterns.apply(gm.graph) + + info = TransformInfo( + skipped=(cnt == 0), + num_matches=cnt, + is_clean=False, + has_valid_shapes=False, + ) + return gm, info + + +class FuseNVFP4LinearConfig(TransformConfig): + """Configuration for NVFP4 linear fusion transform.""" + + backend: str = Field( + default="trtllm", + description="Backend to use for NVFP4 linear computation (default: 'trtllm').", + ) + + +@TransformRegistry.register("fuse_nvfp4_linear") +class FuseNVFP4Linear(BaseTransform): + """Matches and replaces NVFP4 fake quantized linear ops with fused TensorRT-LLM ops.""" + + config: FuseNVFP4LinearConfig + + @classmethod + def get_config_class(cls) -> Type[TransformConfig]: + return FuseNVFP4LinearConfig + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + if self.config.backend.lower() != "trtllm": + raise ValueError(f"Unsupported NVFP4 backend: {self.config.backend}") + + patterns = ADPatternMatcherPass() + _register_quant_fp4_linear_patterns(patterns) + cnt = patterns.apply(gm.graph) + + info = TransformInfo( + skipped=(cnt == 0), + num_matches=cnt, + is_clean=False, + has_valid_shapes=False, + ) + return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py index 8a395ea9124..d645fa1e8bc 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py +++ 
b/tensorrt_llm/_torch/auto_deploy/transform/library/fused_moe.py @@ -1,5 +1,5 @@ from collections import defaultdict -from typing import Optional, Tuple +from typing import Dict, List, Optional, Tuple import torch from torch.fx import GraphModule, Node @@ -7,8 +7,7 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.cuda_mem_tracker import cuda_memory_tracker -from ...utils.node_utils import bfs, identify_regions_between_residuals, is_linear_op, is_op -from ...utils.quantization_utils import get_scales_and_type_from_node +from ...utils.node_utils import bfs, identify_regions_between_residuals, is_op from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @@ -126,33 +125,42 @@ def lca_two(a: Node, b: Node) -> Optional[Node]: return common -def _extract_linear_parameters(linear_node: Node) -> tuple[Node, torch.Tensor, Optional[dict], str]: +def _extract_linear_parameters( + linear_node: Node, + target_op, + scale_arg_indices: Dict[str, int], +) -> Tuple[Node, Node, Dict[str, Node]]: """ - Given a linear op node, extract the input tensor node, weight tensor, - any quantization scales (if the op is quantized), and return a weight type. + Extract (input_node, weight_node, scales) from a *specific* linear op variant. - For a torch.ops.auto_deploy.torch_linear_simple.default op: - - Returns (input_node, weight, None, "simple") - - For a torch.ops.auto_deploy.torch_quant_fp8_linear op: - - Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale}, "fp8") - For a torch.ops.auto_deploy.torch_quant_fp4_linear op: - - Returns (input_node, weight, {"input_scale": input_scale, "weight_scale": weight_scale, "alpha": alpha}, "fp4") + Returns (None, None, {}) if `linear_node` is not the expected target_op. """ + if not is_op(linear_node, target_op): + return None, None, {} + + # Expected argument layout: + # input, weight, (optional bias), then scale args at provided indices. + if not linear_node.args or not isinstance(linear_node.args[0], Node): + return None, None, {} input_node = linear_node.args[0] - if is_op(linear_node, torch.ops.auto_deploy.torch_linear_simple): - weight = linear_node.args[1] - return input_node, weight, None, "" - elif { - is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp4_linear) - or is_op(linear_node, torch.ops.auto_deploy.torch_quant_fp8_linear), - }: - weight = linear_node.args[1] - scales, quant_type = get_scales_and_type_from_node(linear_node) - return input_node, weight, scales or {}, quant_type - - -def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node): + weight = linear_node.args[1] + + scales: Dict[str, Node] = {} + for k, idx in scale_arg_indices.items(): + try: + scales[k] = linear_node.args[idx] + except Exception: + return None, None, {} + + return input_node, weight, scales + + +def _match_expert_compute_pattern( + start_boundary: Node, + end_boundary: Node, + target_op, + scale_arg_indices: Dict[str, int], +): """ Match the expert compute pattern between the given boundaries. @@ -166,7 +174,7 @@ def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node): This function supports both: - torch.ops.auto_deploy.torch_linear_simple.default ops, and - torch.ops.auto_deploy.torch_quant_fp8_linear ops (also extracts quantization scales). - - torch.ops.auto_deploy.torch_quant_fp4_linear ops (also extracts quantization scales). 
+ - torch.ops.auto_deploy.torch_quant_nvfp4_linear ops (also extracts quantization scales). Returns: A tuple: @@ -182,14 +190,12 @@ def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node): pattern_input_nodes, pattern_output_nodes = [], [] expert_weights = defaultdict(list) expert_scales = defaultdict(list) - weight_type = "simple" # default nodes = list(start_boundary.graph.nodes) region_nodes = nodes[nodes.index(start_boundary) + 1 : nodes.index(end_boundary)] for node in region_nodes: - # Accept both simple and quantized linear ops. - if not is_linear_op(node, include_quantization=True): + if not is_op(node, target_op): continue final_linear = node @@ -211,58 +217,44 @@ def _match_expert_compute_pattern(start_boundary: Node, end_boundary: Node): if silu_node is None: continue - if not (silu_node.args and is_linear_op(silu_node.args[0], include_quantization=True)): + if not (silu_node.args and is_op(silu_node.args[0], target_op)): continue linear_w1_node = silu_node.args[0] # The other branch should be a linear op (w3 branch). linear_w3_node = arg_b if arg_a is silu_node else arg_a - if not is_linear_op(linear_w3_node, include_quantization=True): + if not is_op(linear_w3_node, target_op): continue if not (linear_w1_node.args and linear_w3_node.args): continue # Extract parameters from each linear op. - input_node_w1, weight_w1, quant_params_w1, wt_type_w1 = _extract_linear_parameters( - linear_w1_node + input_node_w1, weight_w1, s_w1 = _extract_linear_parameters( + linear_w1_node, target_op, scale_arg_indices ) - _, weight_w3, quant_params_w3, wt_type_w3 = _extract_linear_parameters(linear_w3_node) - _, weight_w2, quant_params_w2, wt_type_w2 = _extract_linear_parameters(final_linear) + _, weight_w3, s_w3 = _extract_linear_parameters( + linear_w3_node, target_op, scale_arg_indices + ) + _, weight_w2, s_w2 = _extract_linear_parameters(final_linear, target_op, scale_arg_indices) if None in (weight_w1, weight_w3, weight_w2): continue - # Ensure the weight type is consistent across branches. 
- if wt_type_w1 != wt_type_w3 or wt_type_w1 != wt_type_w2: - continue - weight_type = wt_type_w1 - pattern_input_nodes.append(input_node_w1) pattern_output_nodes.append(final_linear) expert_weights["w1"].append(weight_w1) expert_weights["w3"].append(weight_w3) expert_weights["w2"].append(weight_w2) - # TODO: sanity check that all experts have same weight type - if weight_type == "fp8": - expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"]) - expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"]) - expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"]) - expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"]) - expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"]) - expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"]) - elif weight_type == "fp4": - expert_scales["w1_input_scale"].append(quant_params_w1["input_scale"]) - expert_scales["w1_weight_scale"].append(quant_params_w1["weight_scale"]) - expert_scales["w1_alpha"].append(quant_params_w1["alpha"]) - expert_scales["w3_input_scale"].append(quant_params_w3["input_scale"]) - expert_scales["w3_weight_scale"].append(quant_params_w3["weight_scale"]) - expert_scales["w3_alpha"].append(quant_params_w3["alpha"]) - expert_scales["w2_input_scale"].append(quant_params_w2["input_scale"]) - expert_scales["w2_weight_scale"].append(quant_params_w2["weight_scale"]) - expert_scales["w2_alpha"].append(quant_params_w2["alpha"]) - - return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales, weight_type + # Collect scales per-branch with keys "w{1|2|3}_" + for key, node_scale in s_w1.items(): + expert_scales[f"w1_{key}"].append(node_scale) + for key, node_scale in s_w3.items(): + expert_scales[f"w3_{key}"].append(node_scale) + for key, node_scale in s_w2.items(): + expert_scales[f"w2_{key}"].append(node_scale) + + return pattern_input_nodes, pattern_output_nodes, expert_weights, expert_scales def _find_final_hidden_state_node( @@ -369,8 +361,24 @@ def target(n: torch.fx.Node) -> bool: return False -@TransformRegistry.register("match_moe_pattern") class MatchMoePattern(BaseTransform): + """Base MoE pattern matcher; subclasses specify linear and fused MoE ops and scale layouts.""" + + # Subclasses must implement: + def target_op(self): # linear op to match + raise NotImplementedError + + def moe_op(self): # fused MoE op to insert + raise NotImplementedError + + def scale_arg_indices(self) -> Dict[str, int]: + """Map scale names -> arg index in the matched linear op.""" + raise NotImplementedError + + def scale_keys(self) -> List[str]: + """Order of scale keys to emit into fused MoE op (e.g., ['input_scale','weight_scale',...]).""" + raise NotImplementedError + def _apply( self, gm: GraphModule, @@ -385,6 +393,11 @@ def _apply( num_moe_patterns = 0 + lin_op = self.target_op() + scale_idx = self.scale_arg_indices() + scale_keys = self.scale_keys() + fused_moe = self.moe_op() + for start_boundary, end_boundary in zip(boundary_nodes[:-1], boundary_nodes[1:]): # Step 1: Identify Expert Compute pattern ( @@ -392,8 +405,12 @@ def _apply( pattern_output_nodes, expert_weights, expert_scales, - weight_type, - ) = _match_expert_compute_pattern(start_boundary, end_boundary) + ) = _match_expert_compute_pattern( + start_boundary, + end_boundary, + target_op=lin_op, + scale_arg_indices=scale_idx, + ) if not expert_weights: continue # TODO: naming convention to verify the order of the weight nodes @@ -434,58 +451,27 @@ def _apply( w2_list = 
expert_weights["w2"] w3_list = expert_weights["w3"] - if weight_type == "fp8": - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_quant_fp8_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - expert_scales["w1_input_scale"], - expert_scales["w2_input_scale"], - expert_scales["w3_input_scale"], - expert_scales["w1_weight_scale"], - expert_scales["w2_weight_scale"], - expert_scales["w3_weight_scale"], - ), - ) - elif weight_type == "fp4": - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_quant_fp4_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - expert_scales["w1_input_scale"], - expert_scales["w2_input_scale"], - expert_scales["w3_input_scale"], - expert_scales["w1_weight_scale"], - expert_scales["w2_weight_scale"], - expert_scales["w3_weight_scale"], - expert_scales["w1_alpha"], - expert_scales["w2_alpha"], - expert_scales["w3_alpha"], - ), - ) - else: - fused_moe_node = graph.call_function( - torch.ops.auto_deploy.torch_moe, - args=( - hidden_states, - selected_experts, - normalized_routing_weights, - w1_list, - w2_list, - w3_list, - ), + fused_args = [ + hidden_states, + selected_experts, + normalized_routing_weights, + w1_list, + w2_list, + w3_list, + ] + + # Append scales as: for each key -> (w1_key_list, w2_key_list, w3_key_list) + for key in scale_keys: + fused_args.extend( + [ + expert_scales[f"w1_{key}"], + expert_scales[f"w2_{key}"], + expert_scales[f"w3_{key}"], + ] ) + fused_moe_node = graph.call_function(fused_moe, args=tuple(fused_args)) + final_hidden_state_node.replace_all_uses_with(fused_moe_node) graph.erase_node(final_hidden_state_node) @@ -500,6 +486,57 @@ def _apply( return gm, info +@TransformRegistry.register("match_moe_pattern") +class MatchSimpleMoePattern(MatchMoePattern): + """Match and fuse simple (unquantized) MoE subgraph.""" + + def target_op(self): + return torch.ops.auto_deploy.torch_linear_simple + + def moe_op(self): + return torch.ops.auto_deploy.torch_moe + + def scale_arg_indices(self) -> Dict[str, int]: + return {} + + def scale_keys(self) -> List[str]: + return [] + + +@TransformRegistry.register("match_fp8_moe_pattern") +class MatchFP8MoePattern(MatchMoePattern): + """Match and fuse FP8-quantized MoE subgraph.""" + + def target_op(self): + return torch.ops.auto_deploy.torch_quant_fp8_linear + + def moe_op(self): + return torch.ops.auto_deploy.torch_quant_fp8_moe + + def scale_arg_indices(self) -> Dict[str, int]: + return {"input_scale": 3, "weight_scale": 4} + + def scale_keys(self) -> List[str]: + return ["input_scale", "weight_scale"] + + +@TransformRegistry.register("match_nvfp4_moe_pattern") +class MatchNVFP4MoePattern(MatchMoePattern): + """Match and fuse NVFP4-quantized MoE subgraph.""" + + def target_op(self): + return torch.ops.auto_deploy.torch_quant_nvfp4_linear + + def moe_op(self): + return torch.ops.auto_deploy.torch_quant_nvfp4_moe + + def scale_arg_indices(self) -> Dict[str, int]: + return {"input_scale": 3, "weight_scale": 4, "alpha": 5} + + def scale_keys(self) -> List[str]: + return ["input_scale", "weight_scale", "alpha"] + + @TransformRegistry.register("fuse_moe") class FuseMoe(BaseTransform): """ diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py index 2d422c42d6b..8a4dd058437 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py +++ 
b/tensorrt_llm/_torch/auto_deploy/transform/library/fusion.py @@ -1,6 +1,9 @@ import operator +from abc import ABC, abstractmethod from collections import defaultdict -from typing import List, Tuple +from functools import partial +from itertools import chain +from typing import Callable, Dict, List, Tuple import torch import torch.nn as nn @@ -14,8 +17,8 @@ extract_param_names_from_lin_node, get_op_overload_packet, is_linear_op, + is_op, ) -from ...utils.quantization_utils import QuantizationImpl from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry @@ -43,8 +46,6 @@ def _insert_fused_gemm(gm: GraphModule, idx: int, parent_node: Node, linear_node sizes_unfused = [p.size(0) for p in params_unfused] key_fused = f"fused_weight_{idx}" - quantization_impls = [QuantizationImpl.create(n) for n in linear_nodes] - def fuse_weights(tensors: List[torch.Tensor]) -> torch.Tensor: """Fuse weights of linear nodes.""" return torch.cat(tensors, dim=0) @@ -53,36 +54,7 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]: """Split the output tensor of the fused linear node to obtain the original outputs.""" return tuple(t.contiguous() for t in torch.split(tensor, sizes_unfused, dim=-1)) - if all( - q is not None and quantization_impls[0].target_op() == q.target_op() - for q in quantization_impls - ): - scales = {} - for weight_key in keys_unfused: - key = weight_key.rsplit(".", 1)[0] - - for scale_name in quantization_impls[0].scale_names(): - buffer_name = key + "." + scale_name - scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name)) - - try: - weights_fused, buffer_fused = quantization_impls[0].fuse_linear_weights( - params_unfused, **scales - ) - except NotImplementedError as e: - ad_logger.warning(f"Cannot fuse ops {keys_unfused}, skipping: {e}") - return - param_fused = nn.Parameter(weights_fused, requires_grad=False) - - for scale_name, buffer in buffer_fused.items(): - fused_buffer_name = key_fused + "_" + scale_name - gm.register_buffer(fused_buffer_name, buffer) - - elif all(q is None for q in quantization_impls): - param_fused = nn.Parameter(fuse_weights([gm.get_parameter(k) for k in keys_unfused])) - else: - ad_logger.warning(f"Cannot fuse ops {keys_unfused} for mixed-precision linear nodes.") - return + param_fused = nn.Parameter(fuse_weights([gm.get_parameter(k) for k in keys_unfused])) setattr(gm, key_fused, param_fused) @@ -91,12 +63,6 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]: with gm.graph.inserting_before(linear_nodes[0]): get_param_node = gm.graph.get_attr(key_fused, torch.Tensor) - if quantization_impls[0]: - for scale_name in quantization_impls[0].scale_names(): - # Creates new nodes for the fused scales so the unfused linear ops can be fully erased. - fused_kwargs[scale_name] = gm.graph.create_node( - "get_attr", key_fused + "_" + scale_name - ) # add new linear node + split node with gm.graph.inserting_before(linear_nodes[0]): @@ -118,6 +84,146 @@ def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]: gm.delete_all_unused_submodules() +def check_same_children(parent_node: Node, is_desired_child: Callable[[Node], bool]) -> bool: + """ + Return True iff *all* direct users of `parent_node` satisfy `is_desired_child`. 
+ """ + users_dict = getattr(parent_node, "users", None) + if not users_dict: + return False + users = list(users_dict.keys()) if isinstance(users_dict, dict) else list(users_dict) + if not users: + return False + return all(is_desired_child(u) for u in users) + + +class QuantizationFusionMixin(ABC): + """ + Mixin that factors out the shared logic for fusing quantized GEMMs + that share the same input activation (parent node). + + Subclasses must define: + - target_op: the torch op identifying the quantized linear + - scale_groups: List[List[str]] describing how kwargs should be grouped, e.g. + FP8 -> [["input_scale"], ["weight_scale"]] + FP4 -> [["input_scale"], ["weight_scale", "alpha"]] + - fuse_rule(weights, **scales) -> (fused_weight, fused_buffers: Dict[str, Tensor]) + which takes: + weights: List[Tensor] # unfused per-out features (stacked along dim=0) + **scales: Dict[str, List[Tensor]] with keys = flattened(scale_groups) + and returns: + fused_weight: Tensor + fused_buffers: Dict[str, Tensor] to register as buffers on the fused module + """ + + target_op: Callable + scale_groups: List[List[str]] + + @abstractmethod + def fuse_rule( + self, weights: List[torch.Tensor], **scales + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + raise NotImplementedError + + @abstractmethod + def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple[object, ...]: + """Return the *positional* tail after bias for the fused call.""" + raise NotImplementedError + + def _insert_fused_quant_gemm( + self, gm: GraphModule, idx: int, parent_node: Node, linear_nodes: List[Node] + ): + keys_unfused = [extract_param_names_from_lin_node(n)[0] for n in linear_nodes] + params_unfused = [gm.get_parameter(k) for k in keys_unfused] + sizes_unfused = [p.size(0) for p in params_unfused] + key_fused = f"fused_weight_{idx}" + + def split_output(tensor: torch.Tensor) -> Tuple[torch.Tensor, ...]: + """Split the output tensor of the fused linear node to obtain the original outputs.""" + return tuple(t.contiguous() for t in torch.split(tensor, sizes_unfused, dim=-1)) + + # Load scale buffers grouped by flattened scale names + flat_scale_names = list(chain.from_iterable(self.scale_groups)) + scales: Dict[str, List[torch.Tensor]] = {} + for weight_key in keys_unfused: + key = weight_key.rsplit(".", 1)[0] + for scale_name in flat_scale_names: + buffer_name = key + "." + scale_name + scales.setdefault(scale_name, []).append(gm.get_buffer(buffer_name)) + + try: + weights_fused, buffers_fused = self.fuse_rule(params_unfused, **scales) + except NotImplementedError as e: + ad_logger.warning(f"Cannot fuse ops {keys_unfused}, skipping: {e}") + return + param_fused = nn.Parameter(weights_fused, requires_grad=False) + setattr(gm, key_fused, param_fused) + for name, buf in buffers_fused.items(): + gm.register_buffer(f"{key_fused}_{name}", buf) + + # Handle fused_kwargs for quantized fused gemm. + fused_kwargs = dict(linear_nodes[0].kwargs) + with gm.graph.inserting_before(linear_nodes[0]): + get_param_node = gm.graph.get_attr(key_fused, torch.Tensor) + + # For each kwarg group (e.g., input_scale, weight_scale[, alpha]), + # create a list of get_attr nodes in the same structure the op expects. 
+ scale_getattrs: Dict[str, Node] = { + name: gm.graph.create_node("get_attr", f"{key_fused}_{name}") + for name in flat_scale_names + } + custom_tail_args = self.build_custom_args_for_linear(scale_getattrs) + + # add new linear node + split node + with gm.graph.inserting_before(linear_nodes[0]): + fused_linear_node = gm.graph.call_function( + get_op_overload_packet(linear_nodes[0].target), + args=(parent_node, get_param_node, None, *custom_tail_args), + kwargs=fused_kwargs, + ) + split_node = gm.graph.call_function(split_output, args=(fused_linear_node,)) + + # now we need to replace all the linear nodes with the correct index of the split node + for i, n in enumerate(linear_nodes): + with gm.graph.inserting_before(n): + get_split_node = gm.graph.call_function(operator.getitem, args=(split_node, i)) + n.replace_all_uses_with(get_split_node) + + # Clean up deleted modules to save GPU memory + gm.graph.eliminate_dead_code() + gm.delete_all_unused_submodules() + + def _apply_fusion_pass( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # Group quantized linear nodes by their parent (same activation) + quant_linear_nodes = defaultdict(list) + for node in gm.graph.nodes: + if is_op(node, self.target_op) and node.args[2] is None: + quant_linear_nodes[node.args[0]].append(node) + + idx = -1 + num_matches = 0 + with cuda_memory_tracker(): + for parent_node, lin_children in quant_linear_nodes.items(): + if len(lin_children) < 2: + continue + if not check_same_children(parent_node, partial(is_op, ops=self.target_op)): + # Mixed children (e.g., quantized or non-linear) — skip fusion + continue + self._insert_fused_quant_gemm(gm, idx := idx + 1, parent_node, lin_children) + num_matches += 1 + + torch.cuda.empty_cache() + return gm, TransformInfo( + skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False + ) + + @TransformRegistry.register("fuse_gemms") class FuseGemms(BaseTransform): def _apply( @@ -131,7 +237,7 @@ def _apply( linear_nodes = defaultdict(list) for node in gm.graph.nodes: # TODO: we don't handle bias for now... - if is_linear_op(node, include_quantization=True) and node.args[2] is None: + if is_linear_op(node) and node.args[2] is None: linear_nodes[node.args[0]].append(node) # fuse linear nodes @@ -141,6 +247,9 @@ def _apply( for parent_node, lin_children in linear_nodes.items(): if len(lin_children) < 2: continue + if not check_same_children(parent_node, is_linear_op): + # Mixed children (e.g., quantized or non-linear) — skip fusion + continue # linear nodes to fuse _insert_fused_gemm(gm, idx := idx + 1, parent_node, lin_children) num_matches += 1 @@ -151,3 +260,96 @@ def _apply( skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=False ) return gm, info + + +@TransformRegistry.register("fuse_fp8_gemms") +class FuseFP8Gemms(QuantizationFusionMixin, BaseTransform): + target_op = torch.ops.auto_deploy.torch_fake_quant_fp8_linear + scale_groups = [["input_scale"], ["weight_scale"]] + + def fuse_rule( + self, weights: List[torch.Tensor], **scales + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + weight_scale: List[torch.Tensor] = scales["weight_scale"] + input_scale: List[torch.Tensor] = scales["input_scale"] + + if not all(s == input_scale[0] for s in input_scale): + raise NotImplementedError(f"Cannot fuse due to mismatched input_scale {input_scale}") + + # Handle quantized weights with weight_scale. 
+ # First we upcast to FP32 precision and then downcast back to the original precision (FP8) + assert weights[0].dtype == torch.float8_e4m3fn, "Only support FP8 quantized weights fusion." + fused_fp32_weights = torch.cat( + [t.to(torch.float) * s for t, s in zip(weights, weight_scale)], dim=0 + ) + new_weight_scale = torch.max(torch.stack(weight_scale)) + fused_fp8_weights = (fused_fp32_weights / new_weight_scale).to(weights[0].dtype) + + return fused_fp8_weights, { + "weight_scale": new_weight_scale, + "input_scale": input_scale[0].clone(), + } + + def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple[object, ...]: + # (..., bias, input_scale(list), weight_scale(list), input_zp(list), weight_zp(list)) + return ( + [scale_getattrs["input_scale"]], + [scale_getattrs["weight_scale"]], + [], + [], + ) + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + return self._apply_fusion_pass(gm, cm, factory, shared_config) + + +@TransformRegistry.register("fuse_fp4_gemms") +class FuseFP4Gemms(QuantizationFusionMixin, BaseTransform): + target_op = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear + scale_groups = [["input_scale"], ["weight_scale", "alpha"]] + + def fuse_rule( + self, weights: List[torch.Tensor], **scales + ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: + weight_scale: List[torch.Tensor] = scales["weight_scale"] + input_scale: List[torch.Tensor] = scales["input_scale"] + alpha: List[torch.Tensor] = scales["alpha"] + + if not all(s == input_scale[0] for s in input_scale): + raise NotImplementedError(f"Cannot fuse due to mismatched input_scale {input_scale}") + + if not all(s == alpha[0] for s in alpha): + raise NotImplementedError(f"Cannot fuse due to mismatched alpha {alpha}") + + fused_weights = torch.cat(weights, dim=0) + fused_weight_scale = torch.cat(weight_scale, dim=0) + + return fused_weights, { + "weight_scale": fused_weight_scale, + "alpha": alpha[0], + "input_scale": input_scale[0].clone(), + } + + def build_custom_args_for_linear(self, scale_getattrs: Dict[str, Node]) -> Tuple[object, ...]: + # (..., bias, input_scale(list), weight_scale(list-with-alpha), input_zp(list), weight_zp(list)) + return ( + [scale_getattrs["input_scale"]], + [scale_getattrs["weight_scale"], scale_getattrs["alpha"]], + [], + [], + ) + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + return self._apply_fusion_pass(gm, cm, factory, shared_config) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py index 7f0a55b9ee0..fb45660776f 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py @@ -1,9 +1,16 @@ from functools import partial -from typing import Tuple +from typing import Dict, List, Tuple +import torch import torch.nn as nn from torch.fx import GraphModule, Node +from ...custom_ops.quant import ( + FP4_GLOBAL_SCALE_MAX, + FP8_MAX, + TRTLLM_NVFP4_SCALING_VECTOR_SIZE, + is_column_major, +) from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import ( @@ -13,7 +20,8 @@ is_linear_op, ) from ...utils.quantization_utils import ( - QuantizationImpl, + fp4_global_scale, + 
fp8_scale, get_quantization_from_linear_node, is_quantized_graph, is_quantized_op, @@ -22,157 +30,452 @@ ) from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry +try: + from .....quantization.utils.fp4_utils import float4_sf_dtype +except ImportError: + float4_sf_dtype = None -def _insert_quantized_linear( - gm: GraphModule, - node: Node, - quantization_impl: QuantizationImpl, - is_quantized_graph: bool = False, -): - """Replaces the matmul node with a new quantized matmul node. - - The state_dict is also updated to contain the sharded weights. - """ - param_name, _ = extract_param_names_from_lin_node(node) - original_weight = gm.get_parameter(param_name) - new_param = nn.Parameter( - quantization_impl.quantize_weight(original_weight), requires_grad=False - ) - modname, _, attrname = param_name.rpartition(".") - - submod = gm.get_submodule(modname) - setattr(submod, attrname, new_param) - - # check modelopt quantizers from graph - if is_quantized_graph: - input_params, weight_params, output_params = get_quantization_params_from_linear_node(node) - # redirect to input and weight - node.args = (input_params.input_node, weight_params.input_node, *node.args[2:]) - - # redirect output to skip output quantizer if any - user = list(node.users.keys())[0] - if len(node.users) == 1 and is_quantized_op(user): - user.replace_all_uses_with(node) - - # when loading the state_dict, we need to convert input amax to input scale - input_scale_name = quantization_impl.scale_names()[0] - gm._register_load_state_dict_pre_hook( - partial( - quantization_impl.convert_amax_hook, - scale_name=modname + "." + input_scale_name, - amax_name=input_params.amax.target, - ) - ) - # Note: canonicalize_graph() will remove input/weight/output quantizer - for scale_name, scale in quantization_impl.default_scales(original_weight.shape).items(): - submod.register_buffer(scale_name, scale) +class Quantization(BaseTransform): + """Abstract base for config-driven quantization of a single algorithm/op-kind. - gm._register_load_state_dict_pre_hook( - partial(quantization_impl.load_hook, weight_name=param_name) - ) - - node.target = quantization_impl.target_op() - - with gm.graph.inserting_before(node): - scales = {} - for scale_name in quantization_impl.scale_names(): - scales[scale_name] = gm.graph.create_node("get_attr", modname + "." 
+ scale_name) - - node.kwargs = {**node.kwargs, **scales} + Subclasses MUST implement: + - algo_name: str # e.g., "FP8" or "NVFP4" + - target_op(self) -> Callable + - quantize_weight(self, w: Tensor) -> Tensor + - scale_names(self) -> List[str] + - default_scales(self, shape: Tuple) -> Dict[str, Tensor] + - build_custom_args_for_linear(self, scales: Dict[str, Node]) -> Tuple + - _apply(self, gm, cm, factory, shared_config) -> (gm, TransformInfo) + Optional (define only if needed): + - load_hook(self, state_dict, prefix, *args, weight_name: str) + - post_load_hook(self, module, incompatible_keys, weight_name: str) + - convert_amax_hook(self, state_dict, prefix, *args, scale_name: str, amax_name: str) + """ -def _insert_quantized_bmm( - gm: GraphModule, - node: Node, - quantization_impl: QuantizationImpl, - is_quantized_graph: bool = False, -): - """Replaces the bmm node with a new quantized bmm node.""" - weight_node = node.args[1] + algo_name: str = None # override in subclasses + + # Algorithm API + @staticmethod + def target_op(): + """Returns the target quantization ops.""" + raise NotImplementedError("Abstract Interface") + + @staticmethod + def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor: + """Returns the quantized weight from the original unquantized weight.""" + raise NotImplementedError("Abstract Interface") + + @staticmethod + def scale_names() -> List[str]: + """Returns the list of names of the scales for this quantization.""" + return [] + + @staticmethod + def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]: + """Returns a dict of the default scale values for this quantization.""" + return {} + + @staticmethod + def load_hook(state_dict, prefix, *args, weight_name: str): + """Load hook for state_dict quantization pre-processing.""" + pass + + @staticmethod + def post_load_hook(state_dict, prefix, *args, weight_name: str): + """Load hook for state_dict quantization post-processing.""" + pass + + @staticmethod + def convert_amax_hook(state_dict, prefix, *args, scale_name: str, amax_name: str): + """Convert amax from modelopt quantized graph to scales.""" + pass + + @staticmethod + def build_custom_args_for_linear( # renamed to reflect args + scale_getattrs: Dict[str, Node], + ) -> Tuple[object, ...]: + return () + + # Transform logic for ModelOPT linear layers + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + qcfg = factory.get_quant_config() + if not qcfg or qcfg.get("quant_algo", "").upper() != self.algo_name: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True + ) - # Weight is a parameter - if weight_node.op == "get_attr": - # Handle parameter tensor - param_name = weight_node.target - original_weight = gm.get_parameter(param_name) - weight_shape = original_weight.shape + excluded = qcfg.get("exclude_modules", []) + cnt = 0 + for n in gm.graph.nodes: + if not is_linear_op(n): + continue + if should_skip_quantization(n, excluded): + continue + self._insert_quantized_linear(gm, n, is_quantized_graph=False) + cnt += 1 - # Quantize the weight - new_param = nn.Parameter( - quantization_impl.quantize_weight(original_weight), requires_grad=False + return gm, TransformInfo( + skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=True ) - # Update the parameter in the model + def _insert_quantized_linear( + self, + gm: GraphModule, + node: Node, + is_quantized_graph: 
bool = False, + ): + """Replaces the matmul node with a new custom quantized linear node. + + The state_dict is also updated to contain the sharded weights. + """ + param_name, _ = extract_param_names_from_lin_node(node) + original_weight = gm.get_parameter(param_name) + new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False) modname, _, attrname = param_name.rpartition(".") + submod = gm.get_submodule(modname) setattr(submod, attrname, new_param) - # Register load state dict hook - gm._register_load_state_dict_pre_hook( - partial(quantization_impl.load_hook, weight_name=param_name) - ) - if quantization_impl.post_load_hook: - gm.register_load_state_dict_post_hook( - partial(quantization_impl.post_load_hook, weight_name=param_name) + # check modelopt quantizers from graph + if is_quantized_graph: + input_params, weight_params, output_params = get_quantization_params_from_linear_node( + node ) + # redirect to input and weight + node.args = (input_params.input_node, weight_params.input_node, *node.args[2:]) + + # redirect output to skip output quantizer if any + user = list(node.users.keys())[0] + if len(node.users) == 1 and is_quantized_op(user): + user.replace_all_uses_with(node) + + # when loading the state_dict, we need to convert input amax to input scale + input_scale_name = self.scale_names()[0] + gm._register_load_state_dict_pre_hook( + partial( + self.convert_amax_hook, + scale_name=modname + "." + input_scale_name, + amax_name=input_params.amax.target, + ) + ) + # Note: canonicalize_graph() will remove input/weight/output quantizer - # Setup scale names and target module for parameter case - def get_scale_name(scale_name): - return attrname + "_" + scale_name - - scale_target_module = submod - scale_name_prefix = f"{modname}." - - # Weight is a dynamic tensor - elif hasattr(weight_node, "meta") and "val" in weight_node.meta: - weight_shape = weight_node.meta["val"].shape - - # Create a unique identifier for this dynamic weight node - node_id = f"bmm_dynamic_{id(node)}" - - # Setup scale names and target module for dynamic case - def get_scale_name(scale_name): - return f"{node_id}_{scale_name}" + for scale_name, scale in self.default_scales(original_weight.shape).items(): + submod.register_buffer(scale_name, scale) - scale_target_module = gm # Register in root module - scale_name_prefix = "" + gm._register_load_state_dict_pre_hook(partial(self.load_hook, weight_name=param_name)) - else: - # If we can't determine the shape, skip quantization - return + with gm.graph.inserting_before(node): + scales = {} + for scale_name in self.scale_names(): + scales[scale_name] = gm.graph.create_node("get_attr", modname + "." 
+ scale_name) - # Common logic for both parameter and dynamic tensor cases - # Register scales in the target module - for scale_name, scale in quantization_impl.default_scales(weight_shape).items(): - scale_buffer_name = get_scale_name(scale_name) - scale_target_module.register_buffer(scale_buffer_name, scale) + custom_args = self.build_custom_args_for_linear(scales) - # Change node target to quantized bmm op - node.target = quantization_impl.target_op() + node.target = self.target_op() + node.args = (*node.args, *custom_args) - # Insert scale nodes - with gm.graph.inserting_before(node): - scales = {} - for scale_name in quantization_impl.scale_names(): + def _insert_quantized_bmm( + self, + gm: GraphModule, + node: Node, + is_quantized_graph: bool = False, + ) -> bool: + """Replace a bmm op with its quantized equivalent and wire scales/state_dict hooks. + + Returns: + True if quantization was applied; False if skipped (e.g., unknown shape). + """ + weight_node = node.args[1] + + # Weight is a parameter + if weight_node.op == "get_attr": + # Handle parameter tensor + param_name = weight_node.target + original_weight = gm.get_parameter(param_name) + weight_shape = original_weight.shape + + # Quantize the weight + new_param = nn.Parameter(self.quantize_weight(original_weight), requires_grad=False) + + # Update the parameter in the model + modname, _, attrname = param_name.rpartition(".") + submod = gm.get_submodule(modname) + setattr(submod, attrname, new_param) + + # Register load state dict hook + gm._register_load_state_dict_pre_hook(partial(self.load_hook, weight_name=param_name)) + if self.post_load_hook: + gm.register_load_state_dict_post_hook( + partial(self.post_load_hook, weight_name=param_name) + ) + + # Setup scale names and target module for parameter case + def get_scale_name(scale_name): + return attrname + "_" + scale_name + + scale_target_module = submod + scale_name_prefix = f"{modname}." 
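+            # e.g. for a weight parameter "layers.0.w" (illustrative name), the buffers
+            # registered below are "w_input_scale"/"w_weight_scale" on submodule "layers.0"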
+ + # Weight is a dynamic tensor + elif hasattr(weight_node, "meta") and "val" in weight_node.meta: + weight_shape = weight_node.meta["val"].shape + + # Create a unique identifier for this dynamic weight node + node_id = f"bmm_dynamic_{id(node)}" + + # Setup scale names and target module for dynamic case + def get_scale_name(scale_name): + return f"{node_id}_{scale_name}" + + scale_target_module = gm # Register in root module + scale_name_prefix = "" + + else: + # If we can't determine the shape, skip quantization + return False + + # Common logic for both parameter and dynamic tensor cases + # Register scales in the target module + for scale_name, scale in self.default_scales(weight_shape).items(): scale_buffer_name = get_scale_name(scale_name) - scales[scale_name] = gm.graph.create_node( - "get_attr", f"{scale_name_prefix}{scale_buffer_name}" + scale_target_module.register_buffer(scale_buffer_name, scale) + + # Change node target to quantized bmm op + node.target = self.target_op() + + # Insert scale nodes + with gm.graph.inserting_before(node): + scales = {} + for scale_name in self.scale_names(): + scale_buffer_name = get_scale_name(scale_name) + scales[scale_name] = gm.graph.create_node( + "get_attr", f"{scale_name_prefix}{scale_buffer_name}" + ) + + # Update node arguments and kwargs + scale_values = [scales[scale_name] for scale_name in self.scale_names()] + node.args = (*node.args, *scale_values) + return True + + +@TransformRegistry.register("quantize_fp8_linear_from_config") +class FP8LinearQuantizationFromConfig(Quantization): + algo_name = "FP8" + + def target_op(self): + return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default + + def quantize_weight(self, w: torch.Tensor) -> torch.Tensor: + return torch.empty_like(w, dtype=torch.float8_e4m3fn, device=w.device) + + def scale_names(self) -> List[str]: + return ["input_scale", "weight_scale"] + + def default_scales(self, _shape: Tuple) -> Dict[str, torch.Tensor]: + return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)} + + def build_custom_args_for_linear(self, scales: Dict[str, Node]) -> Tuple: + # (input_scale(list), weight_scale(list), input_zp(list), weight_zp(list)) + return ([scales["input_scale"]], [scales["weight_scale"]], [], []) + + def load_hook(self, state_dict, prefix, *args, weight_name): + if weight_name in state_dict: + weight = state_dict[weight_name] + if weight.dtype != torch.float8_e4m3fn: + scale = fp8_scale(state_dict[weight_name]) + state_dict[weight_name] = (state_dict[weight_name] / scale).to(torch.float8_e4m3fn) + state_dict[weight_name + "_scale"] = scale + + def convert_amax_hook(self, state_dict, prefix, *args, scale_name: str, amax_name: str): + """Convert amax from modelopt quantized graph to scales.""" + if amax_name in state_dict: + amax = state_dict[amax_name] + scale = amax / FP8_MAX + state_dict[scale_name] = scale + + +@TransformRegistry.register("quantize_nvfp4_linear_from_config") +class NVFP4LinearQuantizationFromConfig(Quantization): + algo_name = "NVFP4" + + def target_op(self): + return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear.default + + def quantize_weight(self, w: torch.Tensor) -> torch.Tensor: + m, n = w.shape + return torch.empty((m, n // 2), dtype=torch.uint8, device=w.device) + + def scale_names(self) -> List[str]: + return ["input_scale", "weight_scale", "alpha"] + + def default_scales(self, original_weight_shape: Tuple) -> Dict[str, torch.Tensor]: + m, n = original_weight_shape + # scaling factors m is padded along 128 and n is padded 
along 4. + # check cpp/tensorrt_llm/plugins/fp4GemmPlugin/fp4GemmPlugin.cpp for more details. + n = n // TRTLLM_NVFP4_SCALING_VECTOR_SIZE + padded_m = (m + 127) // 128 * 128 + padded_n = (n + 3) // 4 * 4 + # definition of scales + # input_scale: FP4_GLOBAL_SCALE_MAX / input_amax + # weight_scale_2: FP4_GLOBAL_SCALE_MAX / weight_amax + # alpha: 1 / (input_scale * weight_scale_2) + return { + "input_scale": torch.tensor(1.0 / 6.0), + "weight_scale": torch.empty((padded_m * padded_n), dtype=torch.uint8), + "alpha": torch.tensor(1.0 / 6.0), + } + + def build_custom_args_for_linear(self, scales: Dict[str, Node]) -> Tuple: + # weight_scale list is (cutlass_vec, alpha) + return ([scales["input_scale"]], [scales["weight_scale"], scales["alpha"]], [], []) + + def load_hook(self, state_dict, prefix, *args, weight_name): + if weight_name in state_dict: + input_scale_name = weight_name.rsplit(".", 1)[0] + ".input_scale" + alpha_name = weight_name.rsplit(".", 1)[0] + ".alpha" + weight = state_dict[weight_name] + # ModelOpt quantized graph path + if weight.dtype != torch.uint8: + assert input_scale_name in state_dict + # Unquantized weight + amax_name = weight_name + "_quantizer._amax" + if amax_name in state_dict: + weight_scale_2 = FP4_GLOBAL_SCALE_MAX / state_dict[amax_name].to(torch.float) + else: + weight_scale_2 = fp4_global_scale(weight) + weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize( + weight.to("cuda"), + weight_scale_2.to("cuda"), + TRTLLM_NVFP4_SCALING_VECTOR_SIZE, + False, + ) + state_dict[weight_name] = weight_fp4 + state_dict[weight_name + "_scale"] = weight_scale + state_dict[weight_name + "_scale_2"] = weight_scale_2 + state_dict[alpha_name] = 1 / (weight_scale_2 * state_dict[input_scale_name]) + # Unified HF ckpt path + else: + if ( + weight_name + "_scale_2" in state_dict + and weight_name + "_scale" in state_dict + and input_scale_name in state_dict + and float4_sf_dtype + ): + state_dict[alpha_name] = ( + state_dict[weight_name + "_scale_2"] * state_dict[input_scale_name] + ) + state_dict[input_scale_name] = 1 / state_dict[input_scale_name] + weight_scale = state_dict[weight_name + "_scale"].view(float4_sf_dtype) + ori_shape = weight_scale.shape + state_dict[weight_name + "_scale"] = ( + torch.ops.trtllm.block_scale_interleave( + weight_scale.view(torch.uint8).cpu().contiguous() + ) + .reshape(ori_shape) + .view(float4_sf_dtype) + .reshape(-1) + ) + + def convert_amax_hook(self, state_dict, prefix, *args, scale_name: str, amax_name: str): + """Convert amax from modelopt quantized graph to scales.""" + if amax_name in state_dict: + amax = state_dict[amax_name] + scale = ((448 * 6) / amax).float() + state_dict[scale_name] = scale + + +@TransformRegistry.register("quantize_fp8_bmm_from_config") +class FP8BMMQuantizationFromConfig(Quantization): + algo_name = "FP8" + + def target_op(self): + return torch.ops.auto_deploy.torch_quant_fp8_bmm + + def quantize_weight(self, w: torch.Tensor) -> torch.Tensor: + return torch.empty_like(w, dtype=torch.float8_e4m3fn, device=w.device) + + def scale_names(self) -> List[str]: + return ["input_scale", "weight_scale"] + + def default_scales(self, _shape: Tuple) -> Dict[str, torch.Tensor]: + return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)} + + def load_hook(self, state_dict, prefix, *args, weight_name): + """Pre-hook: Only handle quantization.""" + if weight_name in state_dict: + weight = state_dict[weight_name] + + # If weight is not already quantized (not float8) + if weight.dtype != torch.float8_e4m3fn: + # 
Compute weight scale + weight_scale = fp8_scale(weight) + weight = (weight / weight_scale).to(torch.float8_e4m3fn) + state_dict[weight_name + "_scale"] = weight_scale + state_dict[weight_name] = weight + + def post_load_hook(self, module, incompatible_keys, weight_name): + """Post-hook: Handle column-major conversion after parameter is loaded.""" + # Navigate to the actual parameter + *path, attr_name = weight_name.split(".") + target_module = module + for p in path: + target_module = getattr(target_module, p) + + if hasattr(target_module, attr_name): + param = getattr(target_module, attr_name) + if isinstance(param, torch.nn.Parameter): + # Convert to column-major format + if not is_column_major(param): + with torch.no_grad(): + # Create column-major version + param_cm = param.transpose(-2, -1).contiguous().transpose(-2, -1) + # Replace the parameter + setattr( + target_module, + attr_name, + torch.nn.Parameter(param_cm, requires_grad=param.requires_grad), + ) + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + qcfg = factory.get_quant_config() + if not qcfg or qcfg.get("quant_algo", "").upper() != self.algo_name: + return gm, TransformInfo( + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - # Update node arguments and kwargs - scale_values = [scales[scale_name] for scale_name in quantization_impl.scale_names()] - node.args = (*node.args, *scale_values) + excluded = qcfg.get("exclude_modules", []) + cnt = 0 + for n in gm.graph.nodes: + if not is_bmm_op(n): + continue + if should_skip_quantization(n, excluded): + continue + if self._insert_quantized_bmm(gm, n, is_quantized_graph=False): + cnt += 1 + return gm, TransformInfo( + skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=True + ) -@TransformRegistry.register("quantize_from_config") -class QuantizationFromConfig(BaseTransform): - """ - Quantize linear and BMM ops using a quantization config. - Replaces eligible ops with quantized equivalents based on the quantization algorithm - and exclude patterns defined in the config. 
- """ +@TransformRegistry.register("quantize_fp8_from_graph") +class FP8QuantizationFromGraph(FP8LinearQuantizationFromConfig): + """Fuse ModelOpt-quantized FP8 linears into fused ops.""" def _apply( self, @@ -181,50 +484,29 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - quant_config = factory.get_quant_config() - if not quant_config: - return gm, TransformInfo( - skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True - ) - quant_algo = quant_config.get("quant_algo", None) - excluded_patterns = quant_config.get("exclude_modules", []) - if not quant_algo: + if not is_quantized_graph(gm): return gm, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - num_matches = 0 - + cnt = 0 for n in gm.graph.nodes: - if should_skip_quantization(n, excluded_patterns): - continue - - if is_linear_op(n, include_quantization=False): - impl = QuantizationImpl.create(quant_algo, is_bmm=False) - _insert_quantized_linear(gm, n, impl, False) - num_matches += 1 - - # TODO: Make _insert_quantized_bmm return a bool and increment only on success - elif is_bmm_op(n): - impl = QuantizationImpl.create(quant_algo, is_bmm=True) - _insert_quantized_bmm(gm, n, impl, False) - num_matches += 1 + if is_linear_op(n): + algo_n = get_quantization_from_linear_node(n) + if (algo_n or "").upper() != "FP8": + continue + self._insert_quantized_linear(gm, n, is_quantized_graph=True) + cnt += 1 - info = TransformInfo( - skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True + remove_output_quantizers(gm) + return gm, TransformInfo( + skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=True ) - return gm, info - -@TransformRegistry.register("quantize_from_graph") -class QuantizationFromGraph(BaseTransform): - """ - Fuse ModelOpt-quantized linear ops into fused quantized implementations. - - Detects quantized nodes from ModelOpt checkpoints's graph and replaces them with - fused linear ops based on the quantization type. - """ +@TransformRegistry.register("quantize_nvfp4_from_graph") +class NVFP4QuantizationFromGraph(NVFP4LinearQuantizationFromConfig): + """Fuse ModelOpt-quantized NVFP4 linears into fused ops.""" def _apply( self, @@ -233,34 +515,21 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - is_quant_graph = is_quantized_graph(gm) - - # no quantization to do - if not is_quant_graph: + if not is_quantized_graph(gm): return gm, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - # tracking quantized operations in the graph - num_matches = 0 + cnt = 0 for n in gm.graph.nodes: - # Process linear operations - if is_linear_op(n, include_quantization=False): - # get per-layer quantization format from the node - quant_algo_n: str = get_quantization_from_linear_node(n) - if not quant_algo_n: + if is_linear_op(n): + algo_n = get_quantization_from_linear_node(n) + if (algo_n or "").upper() != "NVFP4": continue - - # insert quantized linear node - _insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), True) - num_matches += 1 - - # To check: quant BMM does not have graph based pass? 
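+                # per-node algo is NVFP4 here; is_quantized_graph=True additionally rewires
+                # the ModelOpt input/weight quantizers and drops the output quantizer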
+ self._insert_quantized_linear(gm, n, is_quantized_graph=True) + cnt += 1 remove_output_quantizers(gm) - - info = TransformInfo( - skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True + return gm, TransformInfo( + skipped=False, num_matches=cnt, is_clean=False, has_valid_shapes=True ) - - return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py index e930543aeff..d25ad7c270a 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py @@ -8,19 +8,24 @@ from ...models.factory import ModelFactory from ...shim.interface import CachedSequenceInterface from ...utils.node_utils import is_op -from ...utils.quantization_utils import QuantizationImpl, should_skip_quantization -from ..interface import BaseTransform, SharedConfig, TransformInfo, TransformRegistry +from ...utils.quantization_utils import should_skip_quantization +from ..interface import SharedConfig, TransformInfo, TransformRegistry +from .quantization import ( + FP8LinearQuantizationFromConfig, + NVFP4LinearQuantizationFromConfig, + Quantization, +) quantized_moe_op_map = { "FP8": torch.ops.auto_deploy.torch_quant_fp8_moe, - "NVFP4": torch.ops.auto_deploy.torch_quant_fp4_moe, + "NVFP4": torch.ops.auto_deploy.torch_quant_nvfp4_moe, } def _quantize_moe_node( gm: GraphModule, node: Node, - quant_impl: QuantizationImpl, + quant_impl: Quantization, quantized_op: Callable[..., Node], ): """ @@ -131,13 +136,16 @@ def _unwrap_list(arg) -> List[str]: return w1_names, w2_names, w3_names -@TransformRegistry.register("quantize_moe") -class QuantizeMOE(BaseTransform): +@TransformRegistry.register("quantize_fp8_moe") +class QuantizeFP8MOE(FP8LinearQuantizationFromConfig): """ Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the quantized version using the quant_algo from quant_config. 
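+    Only applies when quant_config specifies quant_algo == "FP8"; MoE nodes whose expert
+    weights match exclude_modules are skipped.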
""" + def target_op(self): + return torch.ops.auto_deploy.torch_quant_fp8_moe + def _apply( self, gm: GraphModule, @@ -145,39 +153,86 @@ def _apply( factory: ModelFactory, shared_config: SharedConfig, ) -> Tuple[GraphModule, TransformInfo]: - quant_config = factory.get_quant_config() - quant_algo = quant_config.get("quant_algo") if quant_config else None - - if not quant_config or not quant_algo: + # Gate by algo in quant_config + qcfg = factory.get_quant_config() + if not qcfg or qcfg.get("quant_algo", "").upper() != self.algo_name: return gm, TransformInfo( skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) - excluded_patterns = quant_config.get("exclude_modules", []) - - quant_impl = QuantizationImpl.create(quant_algo) - quantized_op = quantized_moe_op_map[quant_algo] + excluded_patterns = qcfg.get("exclude_modules", []) count = 0 for node in list(gm.graph.nodes): - if is_op(node, torch.ops.auto_deploy.torch_moe): - # Check that all expert weights should be quantized - w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node) - if any( - should_skip_quantization(n, excluded_patterns) - for n in w1_names + w2_names + w3_names - ): - continue - _quantize_moe_node(gm, node, quant_impl, quantized_op) - count += 1 - - if count == 0: + if not is_op(node, torch.ops.auto_deploy.torch_moe): + continue + + # Check experts are allowed (no excludes) + w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node) + if any( + should_skip_quantization(n, excluded_patterns) + for n in (w1_names + w2_names + w3_names) + ): + continue + + _quantize_moe_node(gm, node, self, self.target_op()) + count += 1 + + info = TransformInfo( + skipped=(count == 0), + num_matches=count, + is_clean=(count == 0), + has_valid_shapes=True, + ) + return gm, info + + +@TransformRegistry.register("quantize_nvfp4_moe") +class QuantizeNVFP4MOE(NVFP4LinearQuantizationFromConfig): + """ + Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the + quantized version using the quant_algo from quant_config. 
+ """ + + def target_op(self): + return torch.ops.auto_deploy.torch_quant_nvfp4_moe + + def _apply( + self, + gm: GraphModule, + cm: CachedSequenceInterface, + factory: ModelFactory, + shared_config: SharedConfig, + ) -> Tuple[GraphModule, TransformInfo]: + # Gate by algo in quant_config + qcfg = factory.get_quant_config() + if not qcfg or qcfg.get("quant_algo", "").upper() != self.algo_name: return gm, TransformInfo( - skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True + skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True ) + excluded_patterns = qcfg.get("exclude_modules", []) + count = 0 + + for node in list(gm.graph.nodes): + if not is_op(node, torch.ops.auto_deploy.torch_moe): + continue + + # Check experts are allowed (no excludes) + w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node) + if any( + should_skip_quantization(n, excluded_patterns) + for n in (w1_names + w2_names + w3_names) + ): + continue + + _quantize_moe_node(gm, node, self, self.target_op()) + count += 1 + info = TransformInfo( - skipped=False, num_matches=count, is_clean=False, has_valid_shapes=False + skipped=(count == 0), + num_matches=count, + is_clean=(count == 0), + has_valid_shapes=True, ) - return gm, info diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index c37b627240f..a85389c94a7 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -31,6 +31,7 @@ from ...utils.node_utils import ( filtered_nodes, identify_regions_between_residuals, + is_fake_quantized_linear_op, is_linear_op, is_op, ) @@ -109,8 +110,8 @@ def _append_simple_shard( for node_group in nodes_linear.values(): for n in node_group: tp_shards.append( - TPShardingInfo( - target_node=n.name, + TPShardingInfo.from_node( + n, split_dim=SplitDimension.COLUMN, rank=rank, world_size=world_size, @@ -312,8 +313,8 @@ def detect_sharding_from_factory_config( config = tp_plan[key] if config == "colwise": sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, + TPShardingInfo.from_node( + lin_node, split_dim=SplitDimension.COLUMN, rank=rank, world_size=world_size, @@ -323,8 +324,8 @@ def detect_sharding_from_factory_config( ) elif config == "rowwise": sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, + TPShardingInfo.from_node( + lin_node, split_dim=SplitDimension.ROW, rank=rank, world_size=world_size, @@ -342,8 +343,8 @@ def detect_sharding_from_factory_config( elif "gather" in config: # Simple shard (row + all_gather) sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=lin_node.name, + TPShardingInfo.from_node( + lin_node, split_dim=SplitDimension.COLUMN, rank=rank, world_size=world_size, @@ -455,7 +456,7 @@ def detect_column_row_shard( unaccounted_nodes: Set[Node] = set() current_node = n_start while current_node != n_end: - if is_linear_op(current_node, include_quantization=True): + if is_linear_op(current_node) or is_fake_quantized_linear_op(current_node): nodes_linear[current_node.args[0]].append(current_node) elif is_op(current_node, shardable_attention_nodes): attention_nodes.add(current_node) @@ -540,8 +541,8 @@ def detect_column_row_shard( else: dist_op = None sharding_config.tp_transforms.append( - TPShardingInfo( - target_node=n.name, + TPShardingInfo.from_node( + n, split_dim=i, rank=rank, world_size=world_size, @@ -654,13 +655,13 @@ def 
detect_ep_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Transfo ( torch.ops.auto_deploy.torch_moe, torch.ops.auto_deploy.torch_quant_fp8_moe, - torch.ops.auto_deploy.torch_quant_fp4_moe, + torch.ops.auto_deploy.torch_quant_nvfp4_moe, ), ): continue sharding_config.ep_transforms.append( - EPShardingInfo( - target_node=node.name, + EPShardingInfo.from_node( + node, rank=rank, world_size=world_size, ) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 6cc98616d47..c0d76c0548b 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -8,7 +8,6 @@ from torch._ops import OpOverload, OpOverloadPacket from torch.fx import Graph, GraphModule, Node -from ..custom_ops.quant import QUANT_BMM_OPS, QUANT_LINEAR_OPS from .logger import ad_logger try: @@ -152,9 +151,6 @@ def extract_param_names_from_lin_node(mm_node: Node) -> Tuple[str, Optional[str] Args: mm_node: Matmul node in the graph. """ - assert is_linear_op(mm_node, include_quantization=True) or is_bmm_op(mm_node), ( - f"Expecting linear or bmm node, Found: {mm_node}" - ) weight_node = extract_weight_node(mm_node) assert weight_node, "Cannot identify weight parameter of linear node." @@ -251,7 +247,7 @@ def filtered_nodes( yield node -def is_linear_op(node: Node, include_quantization: bool = False) -> bool: +def is_linear_op(node: Node) -> bool: """Check if the node is a linear op. Using this function is preferred over `is_op` for linear ops to ensure all variants are covered. @@ -261,19 +257,22 @@ def is_linear_op(node: Node, include_quantization: bool = False) -> bool: torch.ops.auto_deploy.torch_linear_simple, } - if include_quantization: - lin_ops.update(QUANT_LINEAR_OPS) return is_op(node, lin_ops) -def is_bmm_op(node: Node, include_quantization: bool = False) -> bool: - """Check if the node is a distributed op.""" - dist_ops = {torch.ops.aten.bmm} +def is_fake_quantized_linear_op(node: Node) -> bool: + quantized_linear_op = { + torch.ops.auto_deploy.torch_fake_quant_fp8_linear, + torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear, + } - if include_quantization: - dist_ops.update(QUANT_BMM_OPS) + return is_op(node, quantized_linear_op) - return is_op(node, dist_ops) + +def is_bmm_op(node: Node) -> bool: + bmm_ops = {torch.ops.aten.bmm} + + return is_op(node, bmm_ops) def is_dist_op(node: Node) -> bool: diff --git a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py index 05878745980..873f12227f5 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/quantization_utils.py @@ -1,16 +1,11 @@ from fnmatch import fnmatch -from typing import Dict, List, Optional, Tuple, Union +from typing import Dict, Optional, Tuple, Union import torch import torch.nn.functional as F from torch.fx import GraphModule, Node -from ..custom_ops.quant import ( - FP4_GLOBAL_SCALE_MAX, - FP8_MAX, - TRTLLM_NVFP4_SCALING_VECTOR_SIZE, - is_column_major, -) +from ..custom_ops.quant import FP4_GLOBAL_SCALE_MAX, FP8_MAX from .logger import ad_logger from .node_utils import ( extract_param_names_from_lin_node, @@ -65,305 +60,10 @@ def fp8_scale(input: torch.Tensor) -> torch.Tensor: return torch.max(torch.abs(input).to(torch.float)) / FP8_MAX -class QuantizationImpl: - """An abstracted static class for node quantization.""" - - @staticmethod - def create(quant_type_or_node: Union[str, Node], 
is_bmm: bool = False): - """Returns the QuantizationImpl based on quantization type or quantized node. - - Args: - quant_type_or_node: Quantization type string or quantized node - is_bmm: Whether the operation is BMM (batch matrix multiplication) - """ - if isinstance(quant_type_or_node, str): - if is_bmm: - quantization_impl_map = { - "": None, - "FP8": FP8BMMQuantizationImpl, - } - else: - quantization_impl_map = { - "": None, - "FP8": FP8QuantizationImpl, - "NVFP4": FP4QuantizationImpl, - } - return quantization_impl_map[quant_type_or_node] - - for q in [ - FP4QuantizationImpl, - FP8QuantizationImpl, - FP8BMMQuantizationImpl, - ]: - if is_op(quant_type_or_node, q.target_op()): - return q - return None - - @staticmethod - def target_op(): - """Returns the target quantization ops.""" - raise NotImplementedError("Abstract Interface") - - @staticmethod - def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor: - """Returns the quantized weight from the original unquantized weight.""" - raise NotImplementedError("Abstract Interface") - - @staticmethod - def scale_names() -> List[str]: - """Returns the list of names of the scales for this quantization.""" - return [] - - @staticmethod - def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]: - """Returns a dict of the default scale values for this quantization.""" - return {} - - @staticmethod - def load_hook(state_dict, prefix, *args, weight_name: str): - """Load hook for state_dict quantization pre-processing.""" - pass - - @staticmethod - def post_load_hook(state_dict, prefix, *args, weight_name: str): - """Load hook for state_dict quantization post-processing.""" - pass - - @staticmethod - def convert_amax_hook(state_dict, prefix, *args, scale_name: str, amax_name: str): - """Convert amax from modelopt quantized graph to scales.""" - pass - - @staticmethod - def shard_scales(dim, rank, world_size, **kwargs) -> Dict[str, torch.Tensor]: - """Returns a dict of sharded quantization scales.""" - return {} - - @staticmethod - def shard_load_hook( - state_dict, - prefix, - *args, - weight_name: str, - weight_shape: Tuple, - dim: int, - rank: int, - world_size: int, - ): - """Load hook for state_dict quantized sharding pre-processing. - - This load_hook handles the sharding of the quantization scales. 
- """ - pass - - @staticmethod - def fuse_linear_weights(weights, **kwargs) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - pass - - -class FP8QuantizationImpl(QuantizationImpl): - @staticmethod - def target_op(): - return torch.ops.auto_deploy.torch_quant_fp8_linear - - @staticmethod - def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor: - return torch.empty_like( - original_weight, dtype=torch.float8_e4m3fn, device=original_weight.device - ) - - @staticmethod - def scale_names() -> List[str]: - return ["input_scale", "weight_scale"] - - @staticmethod - def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]: - return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)} - - @staticmethod - def load_hook(state_dict, prefix, *args, weight_name): - if weight_name in state_dict: - weight = state_dict[weight_name] - if weight.dtype != torch.float8_e4m3fn: - scale = fp8_scale(state_dict[weight_name]) - state_dict[weight_name] = (state_dict[weight_name] / scale).to(torch.float8_e4m3fn) - state_dict[weight_name + "_scale"] = scale - - @staticmethod - def convert_amax_hook(state_dict, prefix, *args, scale_name: str, amax_name: str): - """Convert amax from modelopt quantized graph to scales.""" - if amax_name in state_dict: - amax = state_dict[amax_name] - scale = amax / FP8_MAX - state_dict[scale_name] = scale - - @staticmethod - def fuse_linear_weights( - weights, weight_scale, input_scale - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - if not all(s == input_scale[0] for s in input_scale): - raise NotImplementedError(f"Cannot fuse due to mismatched input_scale {input_scale}") - - # Handle quantized weights with weight_scale. - # First we upcast to FP32 precision and then downcast back to the original precision (FP8) - assert weights[0].dtype == torch.float8_e4m3fn, "Only support FP8 quantized weights fusion." - fused_fp32_weights = torch.cat( - [t.to(torch.float) * s for t, s in zip(weights, weight_scale)], dim=0 - ) - new_weight_scale = torch.max(torch.stack(weight_scale)) - fused_fp8_weights = (fused_fp32_weights / new_weight_scale).to(weights[0].dtype) - - return fused_fp8_weights, { - "weight_scale": new_weight_scale, - "input_scale": input_scale[0].clone(), - } - - -def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, world_size): - assert weight_scale.dim() == 1 - weight_shape_original = list(sharded_uint8_weight_shape) - weight_shape_original[dim] = weight_shape_original[dim] * world_size - weight_shape_original[-1] *= 2 - modelopt_weight_scale = cutlass_fp4_scale_to_modelopt_fp4_scale( - weight_scale, tuple(weight_shape_original) - ) - return modelopt_fp4_scale_to_cutlass_fp4_scale( - modelopt_weight_scale.tensor_split(world_size, dim=dim)[rank] - ) - - -class FP4QuantizationImpl(QuantizationImpl): - @staticmethod - def target_op(): - return torch.ops.auto_deploy.torch_quant_fp4_linear - - @staticmethod - def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor: - m, n = original_weight.shape - return torch.empty((m, n // 2), dtype=torch.uint8, device=original_weight.device) - - @staticmethod - def scale_names() -> List[str]: - return ["input_scale", "weight_scale", "alpha"] - - @staticmethod - def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]: - m, n = original_weight_shape - # scaling factors m is padded along 128 and n is padded along 4. - # check cpp/tensorrt_llm/plugins/fp4GemmPlugin/fp4GemmPlugin.cpp for more details. 
- n = n // TRTLLM_NVFP4_SCALING_VECTOR_SIZE - padded_m = (m + 127) // 128 * 128 - padded_n = (n + 3) // 4 * 4 - # definition of scales - # input_scale: FP4_GLOBAL_SCALE_MAX / input_amax - # weight_scale_2: FP4_GLOBAL_SCALE_MAX / weight_amax - # alpha: 1 / (input_scale * weight_scale_2) - return { - "input_scale": torch.tensor(1.0 / 6.0), - "weight_scale": torch.empty((padded_m * padded_n), dtype=torch.uint8), - "alpha": torch.tensor(1.0 / 6.0), - } - - @staticmethod - def load_hook(state_dict, prefix, *args, weight_name): - if weight_name in state_dict: - input_scale_name = weight_name.rsplit(".", 1)[0] + ".input_scale" - alpha_name = weight_name.rsplit(".", 1)[0] + ".alpha" - weight = state_dict[weight_name] - # ModelOpt quantized graph path - if weight.dtype != torch.uint8: - assert input_scale_name in state_dict - # Unquantized weight - amax_name = weight_name + "_quantizer._amax" - if amax_name in state_dict: - weight_scale_2 = FP4_GLOBAL_SCALE_MAX / state_dict[amax_name].to(torch.float) - else: - weight_scale_2 = fp4_global_scale(weight) - weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize( - weight.to("cuda"), - weight_scale_2.to("cuda"), - TRTLLM_NVFP4_SCALING_VECTOR_SIZE, - False, - ) - state_dict[weight_name] = weight_fp4 - state_dict[weight_name + "_scale"] = weight_scale - state_dict[weight_name + "_scale_2"] = weight_scale_2 - state_dict[alpha_name] = 1 / (weight_scale_2 * state_dict[input_scale_name]) - # Unified HF ckpt path - else: - if ( - weight_name + "_scale_2" in state_dict - and weight_name + "_scale" in state_dict - and input_scale_name in state_dict - and float4_sf_dtype - ): - state_dict[alpha_name] = ( - state_dict[weight_name + "_scale_2"] * state_dict[input_scale_name] - ) - state_dict[input_scale_name] = 1 / state_dict[input_scale_name] - weight_scale = state_dict[weight_name + "_scale"].view(float4_sf_dtype) - ori_shape = weight_scale.shape - state_dict[weight_name + "_scale"] = ( - torch.ops.trtllm.block_scale_interleave( - weight_scale.view(torch.uint8).cpu().contiguous() - ) - .reshape(ori_shape) - .view(float4_sf_dtype) - .reshape(-1) - ) - - def convert_amax_hook(state_dict, prefix, *args, scale_name: str, amax_name: str): - """Convert amax from modelopt quantized graph to scales.""" - if amax_name in state_dict: - amax = state_dict[amax_name] - scale = ((448 * 6) / amax).float() - state_dict[scale_name] = scale - - @staticmethod - def shard_scales(dim, rank, world_size, weight_scale, alpha, input_scale, weight_shape): - result = {} - result["alpha"] = alpha - result["input_scale"] = input_scale - result["weight_scale"] = _shard_fp4_weight_scale( - weight_scale, weight_shape, dim, rank, world_size - ) - - return result - - @staticmethod - def shard_load_hook( - state_dict, prefix, *args, weight_name, weight_shape, dim, rank, world_size - ): - if weight_name + "_scale" in state_dict: - weight_scale = state_dict[weight_name + "_scale"] - state_dict[weight_name + "_scale"] = _shard_fp4_weight_scale( - weight_scale, weight_shape, dim, rank, world_size - ) - - @staticmethod - def fuse_linear_weights( - weights, weight_scale, alpha, input_scale - ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: - if not all(s == input_scale[0] for s in input_scale): - raise NotImplementedError(f"Cannot fuse due to mismatched input_scale {input_scale}") - - if not all(s == alpha[0] for s in alpha): - raise NotImplementedError(f"Cannot fuse due to mismatched alpha {alpha}") - - fused_weights = torch.cat(weights, dim=0) - fused_weight_scale = torch.cat(weight_scale, dim=0) 
- - return fused_weights, { - "weight_scale": fused_weight_scale, - "alpha": alpha[0], - "input_scale": input_scale[0].clone(), - } - - def is_quantized_graph(gm: GraphModule): """Check if the graph is quantized by modelopt.""" for n in gm.graph.nodes: - if is_linear_op(n, include_quantization=False): + if is_linear_op(n): input_params, weight_params, output_params = get_quantization_params_from_linear_node(n) if input_params or weight_params or output_params: return True @@ -384,7 +84,7 @@ def is_quantized_op(node: Node): def remove_output_quantizers(gm: GraphModule): """Remove output quatnizer if any from the graph.""" for n in gm.graph.nodes: - if is_linear_op(n, include_quantization=False) and len(n.users) == 1: + if is_linear_op(n) and len(n.users) == 1: user = list(n.users.keys())[0] if is_quantized_op(user): # skip the output quantizer @@ -407,66 +107,6 @@ def get_quantization_from_linear_node(node: torch.fx.node.Node): return "" -class FP8BMMQuantizationImpl(QuantizationImpl): - """Implementation of FP8 quantization for BMM operations.""" - - @staticmethod - def target_op(): - return torch.ops.auto_deploy.torch_quant_fp8_bmm - - @staticmethod - def quantize_weight(original_weight: torch.Tensor) -> torch.Tensor: - return torch.empty_like( - original_weight, dtype=torch.float8_e4m3fn, device=original_weight.device - ) - - @staticmethod - def scale_names() -> List[str]: - return ["input_scale", "weight_scale"] - - @staticmethod - def default_scales(original_weight_shape: Tuple) -> Dict[str, torch.Tensor]: - return {"input_scale": torch.tensor(1.0), "weight_scale": torch.tensor(1.0)} - - @staticmethod - def load_hook(state_dict, prefix, *args, weight_name): - """Pre-hook: Only handle quantization.""" - if weight_name in state_dict: - weight = state_dict[weight_name] - - # If weight is not already quantized (not float8) - if weight.dtype != torch.float8_e4m3fn: - # Compute weight scale - weight_scale = fp8_scale(weight) - weight = (weight / weight_scale).to(torch.float8_e4m3fn) - state_dict[weight_name + "_scale"] = weight_scale - state_dict[weight_name] = weight - - @staticmethod - def post_load_hook(module, incompatible_keys, weight_name): - """Post-hook: Handle column-major conversion after parameter is loaded.""" - # Navigate to the actual parameter - *path, attr_name = weight_name.split(".") - target_module = module - for p in path: - target_module = getattr(target_module, p) - - if hasattr(target_module, attr_name): - param = getattr(target_module, attr_name) - if isinstance(param, torch.nn.Parameter): - # Convert to column-major format - if not is_column_major(param): - with torch.no_grad(): - # Create column-major version - param_cm = param.transpose(-2, -1).contiguous().transpose(-2, -1) - # Replace the parameter - setattr( - target_module, - attr_name, - torch.nn.Parameter(param_cm, requires_grad=param.requires_grad), - ) - - def should_skip_quantization( node_or_name: Union[Node, str], excluded_patterns: list[str], @@ -475,7 +115,7 @@ def should_skip_quantization( if isinstance(node_or_name, str): modname, _, _ = node_or_name.rpartition(".") else: - if not (is_linear_op(node_or_name, include_quantization=False) or is_bmm_op(node_or_name)): + if not (is_linear_op(node_or_name) or is_bmm_op(node_or_name)): return True param_name, _ = extract_param_names_from_lin_node(node_or_name) modname, _, _ = param_name.rpartition(".") @@ -500,13 +140,3 @@ def extract_scales_from_node(node: Node, scale_names: list[str]) -> Dict[str, Op scales[name] = args[3 + i] return scales - - -def 
get_scales_and_type_from_node(node: Node) -> Tuple[Dict[str, Node], str]: - """Returns a dict of scale args and quantization type string ('fp4', 'fp8', etc).""" - for qtype in [FP4QuantizationImpl, FP8QuantizationImpl]: - if is_op(node, qtype.target_op()): - return extract_scales_from_node( - node, qtype.scale_names() - ), qtype.__name__.lower().replace("quantizationimpl", "") - return None, "simple" diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 7f9833a2d18..40680ada291 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -5,7 +5,7 @@ from abc import ABC, abstractmethod from enum import IntEnum from functools import partial -from typing import Any, Callable, Dict, List, Literal, Optional +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence import torch import torch.nn as nn @@ -15,7 +15,10 @@ from ..models.factory import ShardingConfigSource from ..utils.logger import ad_logger from .node_utils import extract_param_names_from_lin_node, is_op, num_users_of_weight_node -from .quantization_utils import QuantizationImpl +from .quantization_utils import ( + cutlass_fp4_scale_to_modelopt_fp4_scale, + modelopt_fp4_scale_to_cutlass_fp4_scale, +) def _load_hook( @@ -59,6 +62,9 @@ def _insert_sharded_matmul( world_size: int, add_dist: bool = False, min_local_shape: int = 1, + quantization_cb: Optional[ + Callable[[GraphModule, nn.Module, Node, str, torch.Size, int, int, int], None] + ] = None, ) -> None: """Replace the matmul node with a new matmul node that accepts sharded weights. @@ -67,8 +73,6 @@ def _insert_sharded_matmul( assert dim in [0, 1], "Only dim 0 and 1 are supported for sharding" assert add_dist or dim == 0, "For dim=1 sharding, dist_op is required." 
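+    # quantization handling is no longer resolved here via QuantizationImpl; callers pass
+    # `quantization_cb` (see QuantizationShardingMixin below) to reshard the scale buffers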
- quantization_impl = QuantizationImpl.create(node) - def split_tensor( t: torch.Tensor, d: int = dim, @@ -105,8 +109,7 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to None if remove else nn.Parameter( - split_tensor(gm.get_parameter(param_key)).detach().clone(), - requires_grad=quantization_impl is None, + split_tensor(gm.get_parameter(param_key)).detach().clone(), requires_grad=False ) ) @@ -142,24 +145,16 @@ def set_new_param(submod: nn.Module, param_key: str, remove: bool = False) -> to set_new_param(submod, bias_key, remove=True) gm._register_load_state_dict_pre_hook(partial(_load_hook_remove, param_key=bias_key)) - if quantization_impl: - scales = {} - for scale_name in quantization_impl.scale_names(): - scales[scale_name] = submod.get_buffer(scale_name) - scales["weight_shape"] = weight_new_shape - sharded_scales = quantization_impl.shard_scales(dim, rank, world_size, **scales) - for k, v in sharded_scales.items(): - submod.register_buffer(k, v) - - gm._register_load_state_dict_pre_hook( - partial( - quantization_impl.shard_load_hook, - weight_name=weight_key, - weight_shape=weight_new_shape, - dim=dim, - rank=rank, - world_size=world_size, - ) + if quantization_cb is not None: + quantization_cb( + gm=gm, + submod=submod, + node=node, + weight_key=weight_key, + weight_new_shape=weight_new_shape, + dim=dim, + rank=rank, + world_size=world_size, ) # no comm node needed for single device @@ -234,6 +229,14 @@ class TPShardingInfo(ShardingTransformInfo): dist_op: Optional[Literal["all_reduce", "all_gather"]] = None min_local_shape: int = 1 + @classmethod + def from_node(cls, node: Node, **kwargs) -> "TPShardingInfo": + """ + Create the correct TPShardingInfo subclass (FP8/FP4/base) based on `node`. + """ + subcls = _resolve_tp_cls_from_node(node) + return subcls(target_node=node.name, **kwargs) + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: """Validate the transformation configuration.""" if self.dist_op is not None: @@ -265,6 +268,201 @@ def apply(self, gm: GraphModule, node: Node) -> None: ) +class QuantizationShardingMixin(ABC): + """ + Mixin that provides a callback to handle quantization-aware sharding: + - shards/rewrites scale buffers + - registers the quantized shard load hook + """ + + @abstractmethod + def scale_names(self) -> List[str]: ... 
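+    # Subclasses enumerate their scale buffer names above; quantization_cb() reads those
+    # buffers off the sharded submodule, reshards them, and registers the shard load hook.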
+ + def shard_scales( + self, + dim: int, + rank: int, + world_size: int, + weight_shape: torch.Size, + **scales: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + return {k: v for k, v in scales.items() if isinstance(v, torch.Tensor)} + + def shard_load_hook( + self, + state_dict, + prefix, + *args, + weight_name: str, + weight_shape: torch.Size, + dim: int, + rank: int, + world_size: int, + ) -> None: + return + + def quantization_cb( + self, + gm: GraphModule, + submod: nn.Module, + node: Node, + weight_key: str, + weight_new_shape: torch.Size, + dim: int, + rank: int, + world_size: int, + ) -> None: + scales = {} + for scale_name in self.scale_names(): + scales[scale_name] = submod.get_buffer(scale_name) + scales["weight_shape"] = weight_new_shape + sharded_scales = self.shard_scales(dim, rank, world_size, **scales) + for k, v in sharded_scales.items(): + submod.register_buffer(k, v) + + gm._register_load_state_dict_pre_hook( + partial( + self.shard_load_hook, + weight_name=weight_key, + weight_shape=weight_new_shape, + dim=dim, + rank=rank, + world_size=world_size, + ) + ) + + +class FP8TPShardingInfo(QuantizationShardingMixin, TPShardingInfo): + """Tensor-parallel sharding for FP8-quantized linears.""" + + def scale_names(self) -> List[str]: + return ["input_scale", "weight_scale"] + + def shard_scales( + self, + dim: int, + rank: int, + world_size: int, + weight_shape: torch.Size, + *, + input_scale: torch.Tensor, + weight_scale: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + return { + "input_scale": input_scale, + "weight_scale": weight_scale, + } + + def shard_load_hook( + self, + state_dict, + prefix, + *args, + weight_name: str, + weight_shape: torch.Size, + dim: int, + rank: int, + world_size: int, + ) -> None: + return + + def apply(self, gm: GraphModule, node: Node) -> None: + _insert_sharded_matmul( + gm=gm, + node=node, + dim=self.split_dim.value, + rank=self.rank, + world_size=self.world_size, + add_dist=self.dist_op is not None, + min_local_shape=self.min_local_shape, + quantization_cb=self.quantization_cb, # quant callback + ) + + +def _shard_fp4_weight_scale(weight_scale, sharded_uint8_weight_shape, dim, rank, world_size): + assert weight_scale.dim() == 1 + weight_shape_original = list(sharded_uint8_weight_shape) + weight_shape_original[dim] = weight_shape_original[dim] * world_size + weight_shape_original[-1] *= 2 + modelopt_weight_scale = cutlass_fp4_scale_to_modelopt_fp4_scale( + weight_scale, tuple(weight_shape_original) + ) + return modelopt_fp4_scale_to_cutlass_fp4_scale( + modelopt_weight_scale.tensor_split(world_size, dim=dim)[rank] + ) + + +class FP4TPShardingInfo(QuantizationShardingMixin, TPShardingInfo): + """Tensor-parallel sharding for FP4-quantized linears.""" + + def scale_names(self) -> List[str]: + return ["input_scale", "weight_scale", "alpha"] + + def shard_scales( + self, + dim: int, + rank: int, + world_size: int, + weight_shape: torch.Size, + *, + weight_scale: torch.Tensor, + alpha: torch.Tensor, + input_scale: torch.Tensor, + ) -> Dict[str, torch.Tensor]: + return { + "alpha": alpha, + "input_scale": input_scale, + "weight_scale": _shard_fp4_weight_scale( + weight_scale, weight_shape, dim, rank, world_size + ), + } + + def shard_load_hook( + self, + state_dict, + prefix, + *args, + weight_name: str, + weight_shape: torch.Size, + dim: int, + rank: int, + world_size: int, + ) -> None: + key = weight_name + "_scale" + if key in state_dict: + state_dict[key] = _shard_fp4_weight_scale( + state_dict[key], weight_shape, dim, rank, world_size + ) + + 
def apply(self, gm: GraphModule, node: Node) -> None: + _insert_sharded_matmul( + gm=gm, + node=node, + dim=self.split_dim.value, + rank=self.rank, + world_size=self.world_size, + add_dist=self.dist_op is not None, + min_local_shape=self.min_local_shape, + quantization_cb=self.quantization_cb, # quant callback + ) + + +TP_SHARDING_RULES = [ + (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear), FP8TPShardingInfo), + (lambda n: is_op(n, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear), FP4TPShardingInfo), +] + + +def _resolve_tp_cls_from_node(node: Node): + for pred, cls in TP_SHARDING_RULES: + try: + if pred(node): + return cls + except Exception: + pass + return TPShardingInfo + + class BMMShardingInfo(ShardingTransformInfo): """Configuration for BMM sharding transformations.""" @@ -372,13 +570,13 @@ def _insert_sharded_moe( node: Node, rank: int, world_size: int, + scale_names: Sequence[str] = (), ): """Update the torch_moe node with sharded weight lists, sharded `selected_experts` and `final_scales(router_logics)`. Add an all_reduce node after the moe node. """ - quant_impl = QuantizationImpl.create(node) - scale_names = quant_impl.scale_names() if quant_impl else [] + scale_names = list(scale_names) num_experts = len(node.args[3]) args = list(node.args) @@ -460,23 +658,74 @@ class EPShardingInfo(ShardingTransformInfo): rank: int world_size: int + @classmethod + def from_node(cls, node: Node, **kwargs) -> "EPShardingInfo": + """ + Create the correct EPShardingInfo subclass (FP8/NVFP4/base) based on `node`. + """ + subcls = _resolve_ep_cls_from_node(node) + return subcls(target_node=node.name, **kwargs) + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: """Validate the transformation configuration.""" - if not is_op( - node, - ( - torch.ops.auto_deploy.torch_moe, - torch.ops.auto_deploy.torch_quant_fp8_moe, - torch.ops.auto_deploy.torch_quant_fp4_moe, - ), - ): + if not is_op(node, torch.ops.auto_deploy.torch_moe): ad_logger.warning(f"EP sharding is only supported for MOE nodes. Skipping {self}.") return False return True def apply(self, gm: GraphModule, node: Node) -> None: """Apply EP sharding transformation to the graph module.""" - _insert_sharded_moe(gm, node, self.rank, self.world_size) + _insert_sharded_moe(gm, node, self.rank, self.world_size, []) + + +class FP8EPShardingInfo(EPShardingInfo, QuantizationShardingMixin): + """FP8-specific EP sharding behavior.""" + + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + if not is_op(node, torch.ops.auto_deploy.torch_quant_fp8_moe): + ad_logger.warning(f"EP sharding is only supported for MOE nodes. Skipping {self}.") + return False + return True + + def scale_names(self) -> List[str]: + return ["input_scale", "weight_scale"] + + def apply(self, gm: GraphModule, node: Node) -> None: + _insert_sharded_moe(gm, node, self.rank, self.world_size, self.scale_names()) + + +class NVFP4EPShardingInfo(EPShardingInfo, QuantizationShardingMixin): + """NVFP4-specific EP sharding behavior.""" + + def validate(self, gm: GraphModule = None, node: Node = None) -> bool: + if not is_op(node, torch.ops.auto_deploy.torch_quant_nvfp4_moe): + ad_logger.warning(f"EP sharding is only supported for MOE nodes. 
Skipping {self}.") + return False + return True + + def scale_names(self) -> List[str]: + return ["input_scale", "weight_scale", "alpha"] + + def apply(self, gm: GraphModule, node: Node) -> None: + _insert_sharded_moe(gm, node, self.rank, self.world_size, self.scale_names()) + + +EP_SHARDING_RULES = [ + (lambda n: is_op(n, torch.ops.auto_deploy.torch_quant_fp8_moe), FP8EPShardingInfo), + (lambda n: is_op(n, torch.ops.auto_deploy.torch_quant_nvfp4_moe), NVFP4EPShardingInfo), + (lambda n: is_op(n, torch.ops.auto_deploy.torch_moe), EPShardingInfo), +] + + +def _resolve_ep_cls_from_node(node: Node) -> type[EPShardingInfo]: + for pred, cls in EP_SHARDING_RULES: + try: + if pred(node): + return cls + except Exception: + # Missing op variant in this build or other harmless issues — keep trying. + pass + return EPShardingInfo class ShardingConfig(BaseModel): diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py index 9548bed96e5..b00ee2bb97a 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py @@ -264,5 +264,7 @@ def run_sharding_pattern_detection_test( # Convert to sets for unordered comparison detected_set = set(detected_transformations) expected_set = set(expected_transformations) + print("detected_set", detected_set) + print("expected_set", expected_set) assert detected_set == expected_set, "Expected sharding pattern does not match detected pattern" diff --git a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py index 1bc03eebd0e..968b52013d1 100644 --- a/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py +++ b/tests/unittest/_torch/auto_deploy/_utils_test/_model_test_utils.py @@ -253,6 +253,28 @@ def forward(self, x): return torch.bmm(x, dynamic_weights) +FP8_MAX = torch.finfo(torch.float8_e4m3fn).max + + +class FakeFP8Linear(nn.Linear): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + device = self.weight.device + amax = self.weight.detach().abs().max().to(torch.float) + eps = torch.finfo(torch.float32).tiny + weight_scale = torch.clamp(amax / FP8_MAX, min=eps).to(device) + self.weight = nn.Parameter((self.weight / weight_scale).to(torch.float8_e4m3fn)) + self.register_buffer( + "input_scale", torch.tensor(1.0, device=self.weight.device, dtype=torch.float) + ) + self.register_buffer("weight_scale", weight_scale) + + def forward(self, x): + return torch.ops.auto_deploy.torch_fake_quant_fp8_linear( + x, self.weight, self.bias, [self.input_scale], [self.weight_scale], [], [] + ) + + def generate_dynamic_shapes(max_batch_size, max_seq_len): dynamic_shapes = ( { diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py index 0d8c7a33936..94e236cd4e4 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_ep_sharding.py @@ -10,9 +10,13 @@ import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm -from tensorrt_llm._torch.auto_deploy.transform.library.sharding import EPShardingInfo from 
tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
+from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import (
+    EPShardingInfo,
+    FP8EPShardingInfo,
+    NVFP4EPShardingInfo,
+)


 def _run_ep_shard_job(num_experts: int, rank: int, world_size: int) -> None:
@@ -83,14 +87,7 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) ->
     # if world_size == 1, no sharding transformations should be detected
     if world_size > 1:
         for node in gm.graph.nodes:
-            if is_op(
-                node,
-                (
-                    torch.ops.auto_deploy.torch_moe,
-                    torch.ops.auto_deploy.torch_quant_fp8_moe,
-                    torch.ops.auto_deploy.torch_quant_fp4_moe,
-                ),
-            ):
+            if is_op(node, torch.ops.auto_deploy.torch_moe):
                 expected_transformations.append(
                     EPShardingInfo(
                         target_node=node.name,
@@ -98,6 +95,22 @@ def _run_pattern_detection_job(num_experts: int, rank: int, world_size: int) ->
                         world_size=world_size,
                     )
                 )
+            elif is_op(node, torch.ops.auto_deploy.torch_quant_fp8_moe):
+                expected_transformations.append(
+                    FP8EPShardingInfo(
+                        target_node=node.name,
+                        rank=rank,
+                        world_size=world_size,
+                    )
+                )
+            elif is_op(node, torch.ops.auto_deploy.torch_quant_nvfp4_moe):
+                expected_transformations.append(
+                    NVFP4EPShardingInfo(
+                        target_node=node.name,
+                        rank=rank,
+                        world_size=world_size,
+                    )
+                )

     # get detected transformations
     optimizer = InferenceOptimizer(
diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
index 802ec15b5bf..c4554bf89b0 100644
--- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
+++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_tp_sharding.py
@@ -9,6 +9,7 @@
 import torch.nn.functional as F
 from _dist_test_utils import get_device_counts
 from _graph_test_helpers import run_sharding_pattern_detection_test, run_test_transformed_gm
+from _model_test_utils import FakeFP8Linear

 import tensorrt_llm._torch.auto_deploy.distributed.common as dist_common
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
@@ -18,6 +19,7 @@
 )
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
 from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op
+from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import FP8TPShardingInfo

 base_model_tp_plan = {
     "q_proj": "colwise",
@@ -107,6 +109,19 @@ def forward(self, x):
         return self.linear2(y)


+class FP8MLP(nn.Module):
+    def __init__(self, in_features, out_features, bias=False):
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        self.linear1 = FakeFP8Linear(in_features, 4 * in_features, bias=bias)
+        self.linear2 = FakeFP8Linear(4 * in_features, out_features, bias=bias)
+
+    def forward(self, x):
+        y = F.relu(self.linear1(x))
+        return self.linear2(y)
+
+
 def _run_job(
     model_cls: nn.Module,
     dist_op_expected: str,
@@ -130,6 +145,8 @@ def _run_job(
             hidden_size=num_features,
             num_key_value_heads=num_key_value_heads,
         ).to(device="cuda", dtype=torch.float16)
+    elif model_cls == FP8MLP:
+        model = model_cls(num_features, num_features, bias=bias).to("cuda")
     else:
         model = model_cls(num_features, num_features, bias=bias).to(
             device="cuda", dtype=torch.float16
@@ -243,7 +260,7 @@ def _run_pattern_detection_job(
     if model_cls == GQA_Block:
         min_local_shape = num_features // num_heads
         for node in gm.graph.nodes:
-            if is_linear_op(node, include_quantization=True):
+            if is_linear_op(node):
                 # for Q, K, V layers, we expect:
                 # dim = 0, add_dist = False
                 # for O layer, we expect:
@@ -266,7 +283,7 @@
                 )
     elif model_cls == MLP:
         for node in gm.graph.nodes:
-            if is_linear_op(node, include_quantization=True):
+            if is_linear_op(node):
                 # linear1 should be sharded on dim=0, add_dist=False, min_local_shape=1
                 # linear2 should be sharded on dim=1, add_dist=True, min_local_shape=1
                 if "linear1" in node.args[1].name:
@@ -288,7 +305,7 @@
     elif model_cls == nn.Linear:
         # expect simple shard only (dim=0, add_dist=True, min_local_shape=1)
         for node in gm.graph.nodes:
-            if is_linear_op(node, include_quantization=True):
+            if is_linear_op(node):
                 expected_transformations.append(
                     TPShardingInfo(
                         target_node=node.name,
@@ -299,6 +316,27 @@
                         min_local_shape=1,
                     )
                 )
+    elif model_cls == FP8MLP:
+        for node in gm.graph.nodes:
+            if is_op(node, torch.ops.auto_deploy.torch_fake_quant_fp8_linear):
+                # linear1 should be sharded on dim=0, add_dist=False, min_local_shape=1
+                # linear2 should be sharded on dim=1, add_dist=True, min_local_shape=1
+                if "linear1" in node.args[1].name:
+                    dim = SplitDimension.COLUMN
+                    dist_op = None
+                else:
+                    dim = SplitDimension.ROW
+                    dist_op = "all_reduce"
+                expected_transformations.append(
+                    FP8TPShardingInfo(
+                        target_node=node.name,
+                        split_dim=dim,
+                        rank=rank,
+                        world_size=world_size,
+                        dist_op=dist_op,
+                        min_local_shape=1,
+                    )
+                )

     # get detected transformations
     optimizer = InferenceOptimizer(
@@ -328,6 +366,7 @@
     "model_cls, dist_op_expected",
     (
         (MLP, "torch_dist_all_reduce"),
+        (FP8MLP, "torch_dist_all_reduce"),
         (nn.Linear, "torch_dist_all_gather"),
         (GQA_Block, "torch_dist_all_reduce"),
     ),
@@ -352,6 +391,7 @@ def test_sharding(
     "model_cls, dist_op_expected",
     (
         (MLP, "torch_dist_all_reduce"),
+        (FP8MLP, "torch_dist_all_reduce"),
         (nn.Linear, "torch_dist_all_gather"),
         (GQA_Block, "torch_dist_all_reduce"),
     ),
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
index 2b8b16dcd73..da99e08bbe2 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_ad_moe_op.py
@@ -248,7 +248,7 @@ def test_fp4_moe_op_run(dtype):

     # run FP4 MoE op
     with torch.inference_mode():
-        output_torch_fp4_moe = torch.ops.auto_deploy.torch_quant_fp4_moe(
+        output_torch_fp4_moe = torch.ops.auto_deploy.torch_quant_nvfp4_moe(
             x,
             selected_experts,
             final_scales,
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
index a7b48c83c81..440254d7caf 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/custom_ops/test_quant.py
@@ -8,7 +8,7 @@

 torch.manual_seed(0)

-scaling_vector_size = 16
+SCALING_VECTOR_SIZE = 16  # NVFP4 block size along K


 @pytest.mark.parametrize("bias", [torch.rand(32).to("cuda") * 10, None])
@@ -47,10 +47,10 @@ def test_fp4_linear():
     weight_scale_2 = fp4_global_scale(weight)

     weight_fp4, weight_scale = torch.ops.trtllm.fp4_quantize(
-        weight, weight_scale_2, scaling_vector_size, False
+        weight, weight_scale_2, SCALING_VECTOR_SIZE, False
     )

-    output_fp4_gemm = torch.ops.auto_deploy.torch_quant_fp4_linear(
+    output_fp4_gemm = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
         input,
         weight_fp4,
         bias=None,
@@ -105,3 +105,92 @@ def test_fp8_bmm(input_dtype, mat2_dtype):
     )
     assert cos_sim > 0.99
     assert cos_sim_unquantized > 0.99
+
+
+@pytest.mark.parametrize("bias", [torch.rand(32, device="cuda") * 10, None])
+@pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
+def test_quant_linear_fp8_matches_fused_op(bias):
+    input = torch.rand(3, 16, device="cuda")
+    weight = torch.rand(32, 16, device="cuda")
+
+    weight_scale = (torch.max(torch.abs(weight)) / 448).to("cuda")
+    weight_fp8 = (weight / weight_scale).to(torch.float8_e4m3fn)
+
+    out_fused = torch.ops.auto_deploy.torch_quant_fp8_linear(
+        input,
+        weight_fp8,
+        bias=bias,
+        input_scale=torch.tensor(1.0, device="cuda"),
+        weight_scale=weight_scale,
+    )
+
+    out_unified = torch.ops.auto_deploy.torch_fake_quant_fp8_linear(
+        input,
+        weight_fp8,
+        bias,
+        [torch.tensor(1.0, device="cuda")],
+        [weight_scale],
+        [],
+        [],
+    )
+
+    assert out_unified.shape == out_fused.shape
+    torch.testing.assert_close(out_unified, out_fused, rtol=5e-4, atol=5e-4)
+
+
+@pytest.mark.parametrize(
+    "bias",
+    [
+        (torch.rand(32, device="cuda") * 10).to(torch.float16),
+        None,
+    ],
+)
+@pytest.mark.skipif(
+    not (fp4_compatible() and trtllm_ops_available()),
+    reason="Requires NVFP4 and TRT-LLM ops",
+)
+def test_quant_linear_nvfp4_matches_fused_op(bias):
+    x = torch.rand(3, 32, device="cuda", dtype=torch.half)  # [..., K]
+    W = torch.rand(32, 32, device="cuda", dtype=torch.half)  # [N, K]
+    N, K = W.shape
+    assert K % SCALING_VECTOR_SIZE == 0
+
+    # Per-tensor scale-2 (amax / (448 * 6))
+    s_in2 = fp4_global_scale(x).to(torch.float32)  # input per-tensor scale
+    s_w2 = fp4_global_scale(W).to(torch.float32)  # weight per-tensor scale
+
+    weight_fp4, weight_scale_cutlass = torch.ops.trtllm.fp4_quantize(
+        W, s_w2, SCALING_VECTOR_SIZE, False
+    )
+    assert weight_fp4.dtype == torch.uint8
+    assert weight_scale_cutlass.dtype == torch.uint8
+
+    # Fused op (expects CUTLASS uint8 scale + kernel alpha = 1/(s_in2*s_w2))
+    alpha_fused = (1.0 / (s_in2 * s_w2)).to(torch.float32)
+    if bias is not None and bias.dtype != x.dtype:
+        bias = bias.to(x.dtype)
+
+    out_fused = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
+        x,
+        weight_fp4,
+        bias=bias,
+        input_scale=s_in2,
+        weight_scale=weight_scale_cutlass,
+        alpha=alpha_fused,
+    )
+
+    out_unified = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
+        x,
+        weight_fp4,
+        bias,
+        [s_in2],  # input_scale list
+        [
+            weight_scale_cutlass,
+            alpha_fused,
+        ],  # weight_scale list: [per-block vector, combined alpha]
+        [],  # input_zp
+        [],  # weight_zp
+    )
+
+    assert out_unified.shape == out_fused.shape
+    torch.testing.assert_close(out_unified, out_fused, rtol=1e-3, atol=5e-3)
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
index b99862fdc1d..f48f7dc2b10 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_gemm_fusion.py
@@ -8,12 +8,12 @@
 import torch.nn as nn
 import torch.nn.functional as F
 from _graph_test_helpers import count_buffers, run_test_transformed_gm
+from _model_test_utils import FakeFP8Linear
 from _torch_test_utils import all_close, fp8_compatible, reset_parameters

-from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8Linear
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
-from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op
+from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_linear_op, is_op


 torch.manual_seed(0)
@@ -74,7 +74,7 @@ def __init__(self, **kwargs):

 class FusableModel1_M_FP8(FusableModel1_M):
     def __init__(self, **kwargs):
-        super().__init__(**{"cls": FP8Linear, **kwargs})
+        super().__init__(**{"cls": FakeFP8Linear, **kwargs})


 class FusableModel1_L(FusableModel1):
@@ -89,7 +89,7 @@ def __init__(self, **kwargs):

 class FusableModel1_XL_FP8(FusableModel1_XL):
     def __init__(self, **kwargs):
-        super().__init__(**{"cls": FP8Linear, **kwargs})
+        super().__init__(**{"cls": FakeFP8Linear, **kwargs})


 class FusableModel2(FusableModel):
@@ -112,7 +112,7 @@ def num_gemms_after_fusion(self) -> int:

 class FusableModel2_FP8(FusableModel2):
     def __init__(self, **kwargs):
-        super().__init__(**{"cls": FP8Linear, **kwargs})
+        super().__init__(**{"cls": FakeFP8Linear, **kwargs})


 class FusableModel3(FusableModel):
@@ -137,7 +137,7 @@ def num_gemms_after_fusion(self) -> int:

 class FusableModel3_FP8(FusableModel3):
     def __init__(self, **kwargs):
-        super().__init__(**{"cls": FP8Linear, **kwargs})
+        super().__init__(**{"cls": FakeFP8Linear, **kwargs})


 class FusableModel4(FusableModel):
@@ -168,7 +168,7 @@ def num_gemms_after_fusion(self) -> int:

 class FusableModel4_FP8(FusableModel4):
     def __init__(self, **kwargs):
-        super().__init__(**{"cls": FP8Linear, **kwargs})
+        super().__init__(**{"cls": FakeFP8Linear, **kwargs})


 # TODO: consider adding test cases for classic GQA and MLP layers
@@ -262,6 +262,9 @@ def test_fusion(get_model: Callable[[], TestModel], dtype: str):
             "fuse_gemms": {
                 "stage": "post_load_fusion",
             },
+            "fuse_fp8_gemms": {
+                "stage": "post_load_fusion",
+            },
         },
     )(None, gm)

@@ -269,7 +272,10 @@ def test_fusion(get_model: Callable[[], TestModel], dtype: str):
         model,
         x,
         gm_transformed,
-        lambda gm: sum(is_linear_op(n, include_quantization=True) for n in gm.graph.nodes)
+        lambda gm: sum(
+            (is_linear_op(n) or is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear))
+            for n in gm.graph.nodes
+        )
         == model.num_gemms_after_fusion,
         lambda num_p_og: num_p_og,  # unchanged since fusing doesn't change param count
         atol=tol,
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
index dba864f4f1b..2d5cff42706 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_moe_fusion.py
@@ -133,7 +133,7 @@ def __init__(self, ffn_dim, hidden_dim, input_sample, dtype=torch.bfloat16, devi

     def forward(self, hidden_states):
         x = hidden_states
-        w1_out = torch.ops.auto_deploy.torch_quant_fp4_linear(
+        w1_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
             x,
             self.w1_fp4,
             bias=None,
@@ -141,7 +141,7 @@ def forward(self, hidden_states):
             weight_scale=self.w1_weight_scale,
             alpha=self.w1_alpha,
         )
-        w3_out = torch.ops.auto_deploy.torch_quant_fp4_linear(
+        w3_out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
             x,
             self.w3_fp4,
             bias=None,
@@ -150,7 +150,7 @@ def forward(self, hidden_states):
             weight_scale=self.w3_weight_scale,
             alpha=self.w3_alpha,
         )
         fused = self.act_fn(w1_out) * w3_out
-        out = torch.ops.auto_deploy.torch_quant_fp4_linear(
+        out = torch.ops.auto_deploy.torch_quant_nvfp4_linear(
             fused,
             self.w2_fp4,
             bias=None,
@@ -274,7 +274,7 @@ def get_input(self, device):
         ),
         pytest.param(
             "NVFP4",
-            torch.ops.auto_deploy.torch_quant_fp4_moe,
+            torch.ops.auto_deploy.torch_quant_nvfp4_moe,
             0.05,
             0.01,
             marks=[
@@ -308,6 +308,12 @@ def test_moe_matching(quant_type, expected_op, atol, rtol):
             "match_moe_pattern": {
                 "stage": "pattern_matcher",
            },
+            "match_fp8_moe_pattern": {
+                "stage": "pattern_matcher",
+            },
+            "match_nvfp4_moe_pattern": {
+                "stage": "pattern_matcher",
+            },
         },
     )(None, gm)
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
new file mode 100644
index 00000000000..075373706e6
--- /dev/null
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_fusion.py
@@ -0,0 +1,180 @@
+# test_quant_fusion.py
+import pytest
+import torch
+import torch.nn as nn
+from _graph_test_helpers import run_test_transformed_gm
+from _torch_test_utils import fp4_compatible, fp8_compatible, trtllm_ops_available
+
+import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401
+from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
+from tensorrt_llm._torch.auto_deploy.utils.node_utils import is_op
+from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp4_global_scale, fp8_scale
+
+
+def _has_fused_linear_fp8(gm):
+    found_fused = any(
+        is_op(n, torch.ops.auto_deploy.torch_quant_fp8_linear) for n in gm.graph.nodes
+    )
+    found_ref = any(
+        is_op(n, torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default) for n in gm.graph.nodes
+    )
+    return found_fused and not found_ref
+
+
+def _has_fused_linear_fp4(gm):
+    found_fused = any(
+        is_op(n, torch.ops.auto_deploy.torch_quant_nvfp4_linear) for n in gm.graph.nodes
+    )
+    found_ref = any(
+        is_op(n, torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear) for n in gm.graph.nodes
+    )
+    return found_fused and not found_ref
+
+
+class TinyFP8Ref(nn.Module):
+    """
+    A tiny module whose forward uses the reference FP8 op:
+      torch_fake_quant_fp8_linear(input, weight_fp8, bias, [in_s], [w_s], [], [])
+    """
+
+    def __init__(self, in_features=16, out_features=32, use_bias=True):
+        super().__init__()
+        self.use_bias = use_bias
+        self.weight = nn.Parameter(torch.rand(out_features, in_features, dtype=torch.float16))
+        if use_bias:
+            self.bias = nn.Parameter(torch.rand(out_features, dtype=torch.float16))
+        else:
+            self.register_parameter("bias", None)
+
+        # Precompute FP8 packing + scales as buffers
+        with torch.no_grad():
+            w_s = fp8_scale(self.weight)  # per-tensor scale
+            w_fp8 = (self.weight / w_s).to(torch.float8_e4m3fn)
+
+        self.register_buffer("weight_fp8", w_fp8)
+        self.register_buffer("weight_scale", w_s)
+        self.register_buffer(
+            "input_scale", torch.tensor(1.0, dtype=torch.float32)
+        )  # simple test scale
+
+    def forward(self, x):
+        bias = self.bias if self.use_bias else None
+        return torch.ops.auto_deploy.torch_fake_quant_fp8_linear.default(
+            x,
+            self.weight_fp8,
+            bias,
+            [self.input_scale],
+            [self.weight_scale],
+            [],
+            [],
+        )
+
+
+class TinyFP4Ref(nn.Module):
+    """
+    A tiny module whose forward uses the reference NVFP4 op:
+      torch_fake_quant_nvfp4_linear(x, w_fp4, bias, [s_in2], [cutlass_vec, alpha], [], [])
+    """
+
+    def __init__(self, in_features=64, out_features=32, use_bias=True):
+        super().__init__()
+        assert in_features % 16 == 0, "NVFP4 requires K % 16 == 0 for CUTLASS scaling."
+        device = torch.device("cuda")
+
+        self.use_bias = use_bias
+        self.weight = nn.Parameter(
+            torch.rand(out_features, in_features, dtype=torch.half, device=device)
+        )
+        if use_bias:
+            self.bias = nn.Parameter(torch.rand(out_features, dtype=torch.half, device=device))
+        else:
+            self.register_parameter("bias", None)
+
+        with torch.no_grad():
+            s_in2 = fp4_global_scale(torch.rand(1, in_features, dtype=torch.half, device=device))
+            s_w2 = fp4_global_scale(self.weight)
+            w_fp4, cutlass_vec = torch.ops.trtllm.fp4_quantize(self.weight, s_w2, 16, False)
+            alpha = (1.0 / (s_in2 * s_w2)).to(torch.float32)
+
+        self.register_buffer("weight_fp4", w_fp4)  # uint8 packed
+        self.register_buffer("input_scale_2", s_in2.to(torch.float32))
+        self.register_buffer("weight_scale_cutlass", cutlass_vec)  # uint8 vec
+        self.register_buffer("alpha", alpha.to(torch.float32))
+
+    def forward(self, x):
+        bias = self.bias if self.use_bias else None
+        return torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear(
+            x,
+            self.weight_fp4,
+            bias,
+            [self.input_scale_2],
+            [self.weight_scale_cutlass, self.alpha],
+            [],
+            [],
+        )
+
+
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(not fp8_compatible(), reason="Requires fp8 support")
+def test_fuse_quant_rewrites_fp8_linear(use_bias):
+    torch.manual_seed(0)
+    model = TinyFP8Ref(use_bias=use_bias).to("cuda")
+    x = torch.rand(3, 16, dtype=torch.float16, device="cuda")
+
+    gm = torch_export_to_gm(model, args=(x,), clone=True)
+    gm_transformed = InferenceOptimizer(
+        None,
+        {
+            "fuse_fp8_linear": {"stage": "post_load_fusion", "backend": "torch"},
+        },
+    )(None, gm)
+    gm_transformed.to("cuda")
+
+    run_test_transformed_gm(
+        model,
+        x,
+        gm_transformed,
+        _has_fused_linear_fp8,
+        lambda n: n,
+        0.1,  # atol
+        0.05,  # rtol
+        False,  # test_load_hook
+        False,  # strict_loading
+        None,  # dynamic_shapes
+        False,  # skip_output_assert
+    )
+
+
+@pytest.mark.parametrize("use_bias", [True, False])
+@pytest.mark.skipif(
+    not (fp4_compatible() and trtllm_ops_available()),
+    reason="Requires NVFP4 and TRT-LLM ops",
+)
+def test_fuse_quant_rewrites_fp4_linear(use_bias):
+    torch.manual_seed(0)
+    model = TinyFP4Ref(use_bias=use_bias).to("cuda")
+    x = torch.rand(3, 64, dtype=torch.float16, device="cuda")
+
+    gm = torch_export_to_gm(model, args=(x,), clone=True)
+    gm_transformed = InferenceOptimizer(
+        None,
+        {
+            "fuse_nvfp4_linear": {"stage": "post_load_fusion", "backend": "trtllm"},
+        },
+    )(None, gm)
+    gm_transformed.to("cuda")
+
+    run_test_transformed_gm(
+        model,
+        x,
+        gm_transformed,
+        _has_fused_linear_fp4,
+        lambda n: n,
+        0.1,  # atol
+        0.05,  # rtol
+        False,  # test_load_hook
+        False,  # strict_loading
+        None,  # dynamic_shapes
+        False,  # skip_output_assert
+    )
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py
index 0327f01329d..e324323839f 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quant_moe.py
@@ -19,7 +19,7 @@
         ),
         pytest.param(
             "NVFP4",
-            torch.ops.auto_deploy.torch_quant_fp4_moe,
+            torch.ops.auto_deploy.torch_quant_nvfp4_moe,
             marks=pytest.mark.skipif(
                 not (fp4_compatible() and trtllm_ops_available()), reason="Requires FP4 + TRTLLM"
             ),
@@ -67,7 +67,10 @@ def _expected_num_params(n):
     gm_transformed = InferenceOptimizer(
         FakeFactory(quant_config=quant_config),
         {
-            "quantize_moe": {
+            "quantize_fp8_moe": {
+                "stage": "pattern_matcher",
+            },
+            "quantize_nvfp4_moe": {
                 "stage": "pattern_matcher",
             },
         },
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
index 341edae905d..f8677d70adf 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/transformations/library/test_quantization.py
@@ -8,7 +8,6 @@
 from _model_test_utils import MLP, BMMDynamicModel, BMMModel
 from _torch_test_utils import fp4_compatible, fp8_compatible

-from tensorrt_llm._torch.auto_deploy.custom_ops.quant import QUANT_OPS
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
 from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
 from tensorrt_llm._torch.auto_deploy.transform.optimizer import InferenceOptimizer
@@ -16,10 +15,6 @@
 from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import fp8_scale


-def check_quantized(gm):
-    return any(is_op(n, QUANT_OPS) for n in gm.graph.nodes)
-
-
 class DummyFactory(ModelFactory):
     """Dummy factory to pass quant_config for testing."""

@@ -68,12 +63,18 @@ def test_quantization(quant_config, atol, rtol, num_p_og):
         model.linear2.register_buffer(
             "input_scale", torch.tensor([1.0], device=model.linear2.weight.device)
         )
+        QUANT_OP = torch.ops.auto_deploy.torch_fake_quant_nvfp4_linear
+    elif quant_config.get("quant_algo") == "FP8":
+        QUANT_OP = torch.ops.auto_deploy.torch_fake_quant_fp8_linear
     # set up sequence+cache objects
     gm = torch_export_to_gm(model, args=(x,), clone=True)
     gm_transformed = InferenceOptimizer(
         DummyFactory(quant_config),
         {
-            "quantize_from_config": {
+            "quantize_fp8_linear_from_config": {
+                "stage": "pattern_matcher",
+            },
+            "quantize_nvfp4_linear_from_config": {
                 "stage": "pattern_matcher",
             },
         },
@@ -84,7 +85,7 @@ def test_quantization(quant_config, atol, rtol, num_p_og):
         model,
         x,
         gm_transformed,
-        check_quantized,
+        lambda gm: any(is_op(n, QUANT_OP) for n in gm.graph.nodes),
         num_p_og,
         atol,
         rtol,
@@ -155,18 +156,19 @@ def test_bmm_quantization(quant_config, atol, rtol, num_p_og, model_class):
     gm_transformed = InferenceOptimizer(
         DummyFactory(quant_config),
         {
-            "quantize_from_config": {
+            "quantize_fp8_bmm_from_config": {
                 "stage": "pattern_matcher",
             },
         },
     )(None, gm)
     gm_transformed.to("cuda")
+    QUANT_OP = torch.ops.auto_deploy.torch_quant_fp8_bmm

     run_test_transformed_gm(
         model,
         x,
         gm_transformed,
-        check_quantized,
+        lambda gm: any(is_op(n, QUANT_OP) for n in gm.graph.nodes),
         num_p_og,
         atol,
         rtol,
diff --git a/tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py b/tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
index 2dd7ace087a..005e893af08 100644
--- a/tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
+++ b/tests/unittest/_torch/auto_deploy/unit/singlegpu/utils/test_quantization_utils.py
@@ -2,12 +2,15 @@
 import torch

 from tensorrt_llm._torch.auto_deploy.custom_ops.quant import FP8_MAX
+from tensorrt_llm._torch.auto_deploy.transform.interface import TransformConfig
+from tensorrt_llm._torch.auto_deploy.transform.library.quantization import (
+    FP8LinearQuantizationFromConfig,
+)
 from tensorrt_llm._torch.auto_deploy.utils.quantization_utils import (
-    FP8QuantizationImpl,
-    _shard_fp4_weight_scale,
     fp4_global_scale,
     modelopt_fp4_scale_to_cutlass_fp4_scale,
 )
+from tensorrt_llm._torch.auto_deploy.utils.sharding_utils import _shard_fp4_weight_scale


 @pytest.mark.parametrize("dim", [0, 1])
@@ -48,7 +51,8 @@ def test_fp4_global_scale():

 @pytest.mark.parametrize("amax, expected_scale", [(FP8_MAX, 1.0), (FP8_MAX / 2.0, 0.5)])
 def test_fp8_convert_amax_hook(amax, expected_scale):
-    fp8_imp = FP8QuantizationImpl()
+    config = TransformConfig(stage="pattern_matcher")
+    fp8_imp = FP8LinearQuantizationFromConfig(config)
     mock_state_dict = {"amax": amax}