5 | 5 | # LICENSE file in the root directory of this source tree.
6 | 6 |
7 | 7 | import types
8 |  | -from dataclasses import dataclass
 | 8 | +from dataclasses import dataclass, field
9 | 9 | from typing import Optional
10 | 10 |
11 | 11 | import torch
12 | 12 |
13 |  | -import torchao
14 | 13 | from torchao.core.config import AOBaseConfig
15 | 14 | from torchao.prototype.mx_formats import (
16 | 15 |     MXGemmKernelChoice,

20 | 19 |     _validate_gemm_kernel_choice,
21 | 20 | )
22 | 21 | from torchao.prototype.mx_formats.mx_tensor import MXTensor
 | 22 | +from torchao.prototype.mx_formats.nvfp4_tensor import NVFP4MMConfig, NVFP4Tensor
23 | 23 | from torchao.quantization.quant_api import to_linear_activation_quantized
24 | 24 | from torchao.quantization.transform_module import (
25 | 25 |     register_quantize_module_handler,
26 | 26 | )
27 |  | -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, is_sm_at_least_100
 | 27 | +from torchao.utils import (
 | 28 | +    TORCH_VERSION_AT_LEAST_2_5,
 | 29 | +    TORCH_VERSION_AT_LEAST_2_8,
 | 30 | +    is_sm_at_least_100,
 | 31 | +)
28 | 32 |
29 | 33 |
30 | 34 | # Note: This API is extra prototype and will change in the future
@@ -63,16 +67,13 @@ class MXFPInferenceConfig(AOBaseConfig):
63 | 67 |
64 | 68 |     block_size: int = 32
65 | 69 |
66 |  | -    # Dtypes for Input and Weights
 | 70 | +    # Dtypes for Input and Weights, supports Fp8 and Fp4 formats
67 | 71 |     activation_dtype: torch.dtype = torch.float8_e4m3fn
68 | 72 |     weight_dtype: torch.dtype = torch.float8_e4m3fn
69 | 73 |
70 | 74 |     # Which kernel to run for mm
71 | 75 |     gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
72 | 76 |
73 |  | -    # Set some magic perf settings
74 |  | -    set_inductor_config: bool = False
75 |  | -
76 | 77 |     def __post_init__(self):
77 | 78 |         assert self.activation_dtype == self.weight_dtype, (
78 | 79 |             "For now - we only support matching input/weight dtypes."
@@ -115,8 +116,6 @@ def _mx_inference_linear_transform(
115 | 116 |     # TODO Sm120 has slightly more restrictive reqs
116 | 117 |     # TODO handle AMD
117 | 118 |     assert is_sm_at_least_100(), "MXFP is only supported on sm100 machiens for now"
118 |  | -    if config.set_inductor_config:
119 |  | -        torchao.quantization.utils.recommended_inductor_config_setter()
120 | 119 |
121 | 120 |     activation_dtype = config.activation_dtype
122 | 121 |     weight_dtype = config.weight_dtype
@@ -151,7 +150,90 @@ def _mx_inference_linear_transform(
151 | 150 |     return module
152 | 151 |
153 | 152 |
 | 153 | +def _get_nvfp4_dtype():
 | 154 | +    """Factory function for NVFP4 dtype defaults."""
 | 155 | +    if not TORCH_VERSION_AT_LEAST_2_8:
 | 156 | +        raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
 | 157 | +    return torch.float4_e2m1fn_x2
 | 158 | +
 | 159 | +
 | 160 | +@dataclass
 | 161 | +class NVFP4InferenceConfig(AOBaseConfig):
 | 162 | +    """
 | 163 | +    NVIDIA FP4 (NVFP4) Inference Quantization Configuration
 | 164 | +
 | 165 | +    This is a specialized configuration for NVIDIA's FP4 format with UE4M3 scales.
 | 166 | +    It provides defaults optimized for NVFP4:
 | 167 | +    - Data: float4_e2m1fn_x2
 | 168 | +    - Scales: float8_e4m3fn (UE4M3)
 | 169 | +    - Block size: 16 (required for NVFP4)
 | 170 | +    - CUBLAS kernel (optimized for VEC16_UE4M3)
 | 171 | +    """
 | 172 | +
 | 173 | +    block_size: int = 16  # NVFP4 requires block size 16
 | 174 | +
 | 175 | +    # NVFP4 uses FP4 data
 | 176 | +    activation_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
 | 177 | +    weight_dtype: torch.dtype = field(default_factory=_get_nvfp4_dtype)
 | 178 | +
 | 179 | +    # NVFP4 uses E4M3 scales
 | 180 | +    scale_dtype: torch.dtype = torch.float8_e4m3fn
 | 181 | +
 | 182 | +    # CUBLAS is preferred for NVFP4 with VEC16_UE4M3 support
 | 183 | +    gemm_kernel_choice: MXGemmKernelChoice = MXGemmKernelChoice.CUBLAS
 | 184 | +
 | 185 | +    # Matrix multiplication configuration
 | 186 | +    mm_config: NVFP4MMConfig = NVFP4MMConfig.DYNAMIC
 | 187 | +
 | 188 | +    def __post_init__(self):
 | 189 | +        # Validate NVFP4 constraints
 | 190 | +        if not TORCH_VERSION_AT_LEAST_2_8:
 | 191 | +            raise RuntimeError("NVFP4InferenceConfig requires PyTorch 2.8 or later")
 | 192 | +
 | 193 | +        assert self.activation_dtype == torch.float4_e2m1fn_x2, (
 | 194 | +            f"NVFP4 requires activation_dtype=float4_e2m1fn_x2, got {self.activation_dtype}"
 | 195 | +        )
 | 196 | +        assert self.weight_dtype == torch.float4_e2m1fn_x2, (
 | 197 | +            f"NVFP4 requires weight_dtype=float4_e2m1fn_x2, got {self.weight_dtype}"
 | 198 | +        )
 | 199 | +        assert self.scale_dtype == torch.float8_e4m3fn, (
 | 200 | +            f"NVFP4 requires scale_dtype=float8_e4m3fn, got {self.scale_dtype}"
 | 201 | +        )
 | 202 | +        assert self.block_size == 16, (
 | 203 | +            f"NVFP4 requires block_size=16, got {self.block_size}"
 | 204 | +        )
 | 205 | +
 | 206 | +
 | 207 | +@register_quantize_module_handler(NVFP4InferenceConfig)
 | 208 | +def _nvfp4_inference_linear_transform(
 | 209 | +    module: torch.nn.Module, config: NVFP4InferenceConfig
 | 210 | +):
 | 211 | +    """Quantization handler for NVFP4InferenceConfig"""
 | 212 | +    assert is_sm_at_least_100(), "NVFP4 is only supported on sm100+ machines"
 | 213 | +
 | 214 | +    weight = module.weight
 | 215 | +    assert weight.dtype == torch.bfloat16, (
 | 216 | +        f"Only supporting bf16 out dtype for now, got {weight.dtype}"
 | 217 | +    )
 | 218 | +
 | 219 | +    quantized_weight = NVFP4Tensor.to_nvfp4(
 | 220 | +        weight,
 | 221 | +        block_size=config.block_size,
 | 222 | +        mm_config=config.mm_config,
 | 223 | +    )
 | 224 | +
 | 225 | +    module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False)
 | 226 | +    module.extra_repr = types.MethodType(_linear_extra_repr, module)
 | 227 | +    return module
 | 228 | +
 | 229 | +
154 | 230 | if TORCH_VERSION_AT_LEAST_2_5:
155 | 231 |     torch.serialization.add_safe_globals(
156 |  | -        [MXTensor, MXGemmKernelChoice, _input_activation_quant_func_mxfp]
 | 232 | +        [
 | 233 | +            MXTensor,
 | 234 | +            NVFP4Tensor,
 | 235 | +            NVFP4MMConfig,
 | 236 | +            MXGemmKernelChoice,
 | 237 | +            _input_activation_quant_func_mxfp,
 | 238 | +        ]
157 | 239 |     )