From 6bfcb9d177d9c0d57ed9342a742b350ae2d4491d Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 19 Sep 2025 09:10:05 -0700 Subject: [PATCH 01/47] [wip] enable quantize_ for 3d weights Summary: This PR adds in a simple 2d and 3d moe implementation and tests `quantize_` on them to see if we get the same results. Test Plan: ``` pytest test/prototype/test_parameter.py -k test_quantize_parameter ``` Reviewers: Subscribers: Tasks: Tags: --- test/prototype/test_parameter.py | 169 ++++++++++++++++++ .../workflows/float8/float8_tensor.py | 23 +++ 2 files changed, 192 insertions(+) create mode 100644 test/prototype/test_parameter.py diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py new file mode 100644 index 0000000000..d34e4bc111 --- /dev/null +++ b/test/prototype/test_parameter.py @@ -0,0 +1,169 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +import copy +import logging +import unittest + +import torch +from torch import nn +from torch.testing._internal import common_utils + +from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout +from torchao.quantization import ( + Float8DynamicActivationFloat8SemiSparseWeightConfig, + Float8DynamicActivationFloat8WeightConfig, +) +from torchao.quantization.quant_api import ( + Int4WeightOnlyConfig, + Int8DynamicActivationInt8WeightConfig, + PerRow, + PerTensor, + quantize_, +) +from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ +from torchao.utils import is_sm_at_least_90 +import torch.nn.functional as F + +logging.basicConfig( + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO +) + +class TestMoE2d(nn.Module): + def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: + """ + Args: + num_experts (int): + Number of experts. + input_size (int): + Size of the input. + output_size (int): + Size of the output. + """ + super().__init__() + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + self.experts = nn.ModuleList([nn.Linear(input_size, output_size, bias=None) for _ in range(self.num_experts)]) + + def forward(self, inputs, expert_size): + """ + Forward pass of the JetMoeParallelExperts module. + + Args: + inputs (Tensor): + Input tensor. + expert_size: + Expert size information. + + Returns: + Tensor: Output tensor. + """ + # return True + input_list = inputs.split(expert_size, dim=0) + output_list = [] + + assert len(input_list) == len(self.experts) + for expert, expert_input in zip(self.experts, input_list): + output_list.append(expert(expert_input)) + results = torch.cat(output_list, dim=0) + return results + + +class TestMoE3d(nn.Module): + def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: + """ + This implementation is taken from: + https://github.com/huggingface/transformers/blob/6cade29278c4aee3f174f8950f97a3873bdb212f/src/transformers/models/jetmoe/modeling_jetmoe.py#L141 + + Args: + num_experts (int): + Number of experts. + input_size (int): + Size of the input. + output_size (int): + Size of the output. 
+ """ + super().__init__() + self.weight = nn.Parameter(torch.randn(num_experts, output_size, input_size)) + self.num_experts = num_experts + self.input_size = input_size + self.output_size = output_size + + def forward(self, inputs, expert_size): + """ + Forward pass of the JetMoeParallelExperts module. + + Args: + inputs (Tensor): + Input tensor. + expert_size: + Expert size information. + + Returns: + Tensor: Output tensor. + """ + # return True + input_list = inputs.split(expert_size, dim=0) + output_list = [] + for i in range(self.num_experts): + output_list.append(F.linear(input_list[i], self.weight[i])) + results = torch.cat(output_list, dim=0) + return results + +def print_model_fqn(model): + print("=== Parameters ===") + for name, param in model.named_parameters(): + print(f"{name}: {param.shape}") + + print("\n=== Modules ===") + for name, module in model.named_modules(): + if name: # Skip empty name for root module + print(f"{name}: {type(module).__name__}") + + +class TestQuantizeParameterMoE(common_utils.TestCase): + + def test_2d_3d_moe_equivalent(self): + test_input = torch.randn(1024, 1024).cuda() + + model_2d = TestMoE2d(2, 1024, 1024).cuda() + model_3d = TestMoE3d(2, 1024, 1024).cuda() + + for i, expert in enumerate(model_2d.experts): + model_3d.weight.data[i] = expert.weight.detach().clone() + + output_2d = model_2d(test_input, 512) + output_3d = model_3d(test_input, 512) + + torch.testing.assert_close(output_2d, output_3d, rtol=1e-3, atol=1e-3) + + + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_quantize_parameter(self): + test_input = torch.randn(1024, 1024).cuda().bfloat16() + + model_2d = TestMoE2d(2, 1024, 1024).cuda().bfloat16() + model_3d = TestMoE3d(2, 1024, 1024).cuda().bfloat16() + + for i, expert in enumerate(model_2d.experts): + model_3d.weight.data[i] = expert.weight.detach().clone() + + # quantize all linears in 2d + quantize_( + model_2d, + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + ) + + quantize_( + model_3d, + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + filter_fn=lambda mod, fqn: fqn is '' # top level module has no fqn + ) + + output_3d = model_3d(test_input, 512) + output_2d = model_2d(test_input, 512) + + torch.testing.assert_close(output_2d, output_3d, rtol=1e-3, atol=1e-3) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 49c8b1cd24..cf6543980c 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -617,6 +617,29 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, new) +@implements(aten.select.int) +def _(func, types, args, kwargs): + self, dim, index = args + qdata = self.qdata.select(dim, index) + # TODO we need to handle this case differently based on the scaling config + + scale = self.scale.select(dim, index).t() + """ + Without the transpose here, I run into the following runtime error: + E RuntimeError: Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. For RowWise scaling, scale_a should be (512, 1) and scale_b should be (1, 1024). 
Got scale_a.size()=(512, 1) and scale_b.size()=(1024, 1) + """ + + new = self.__class__( + qdata, + scale, + self.block_size, + self.mm_config, + self.act_quant_kwargs, + self.kernel_preference, + ) + return return_and_correct_aliasing(func, args, kwargs, new) + + Float8Tensor.__module__ = "torchao.quantization" # Allow a model with Float8Tensor weights to be loaded with `weights_only=True` From 95867eec0fe98ca248f2588d3b924657c488619e Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 22 Sep 2025 11:11:19 -0700 Subject: [PATCH 02/47] update for user-defined parameter names --- test/prototype/test_parameter.py | 51 ++++++++++++++++++++++++++++--- torchao/quantization/quant_api.py | 13 +++++--- 2 files changed, 55 insertions(+), 9 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index d34e4bc111..8716ce8f6f 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -27,6 +27,16 @@ from torchao.utils import is_sm_at_least_90 import torch.nn.functional as F +import re +import unittest +import warnings +import torch +from torch.testing._internal.common_utils import TestCase, run_tests +from torchao.utils import is_fbcode, is_sm_at_least_90 + +if not is_fbcode(): + from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) @@ -87,7 +97,7 @@ def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: Size of the output. """ super().__init__() - self.weight = nn.Parameter(torch.randn(num_experts, output_size, input_size)) + self.moe_weight = nn.Parameter(torch.randn(num_experts, output_size, input_size)) self.num_experts = num_experts self.input_size = input_size self.output_size = output_size @@ -109,7 +119,7 @@ def forward(self, inputs, expert_size): input_list = inputs.split(expert_size, dim=0) output_list = [] for i in range(self.num_experts): - output_list.append(F.linear(input_list[i], self.weight[i])) + output_list.append(F.linear(input_list[i], self.moe_weight[i])) results = torch.cat(output_list, dim=0) return results @@ -133,7 +143,7 @@ def test_2d_3d_moe_equivalent(self): model_3d = TestMoE3d(2, 1024, 1024).cuda() for i, expert in enumerate(model_2d.experts): - model_3d.weight.data[i] = expert.weight.detach().clone() + model_3d.moe_weight.data[i] = expert.weight.detach().clone() output_2d = model_2d(test_input, 512) output_3d = model_3d(test_input, 512) @@ -149,7 +159,7 @@ def test_quantize_parameter(self): model_3d = TestMoE3d(2, 1024, 1024).cuda().bfloat16() for i, expert in enumerate(model_2d.experts): - model_3d.weight.data[i] = expert.weight.detach().clone() + model_3d.moe_weight.data[i] = expert.weight.detach().clone() # quantize all linears in 2d quantize_( @@ -159,7 +169,7 @@ def test_quantize_parameter(self): quantize_( model_3d, - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), param_name="moe_weight"), filter_fn=lambda mod, fqn: fqn is '' # top level module has no fqn ) @@ -167,3 +177,34 @@ def test_quantize_parameter(self): output_2d = model_2d(test_input, 512) torch.testing.assert_close(output_2d, output_3d, rtol=1e-3, atol=1e-3) + +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") +@unittest.skipIf( + is_fbcode(), + "Skipping the test in fbcode for now, not sure how to 
download from transformers", +) +class TestTorchAOCheckpoint(TestCase): + + def test_comprehensive_checkpoint_loading(self): + from transformers import AutoTokenizer, Llama4ForConditionalGeneration + import torch + + model_id = "RedHatAI/Llama-4-Scout-17B-16E-Instruct" + + tokenizer = AutoTokenizer.from_pretrained(model_id) + + messages = [ + {"role": "user", "content": "Who are you?"}, + ] + inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True) + + model = Llama4ForConditionalGeneration.from_pretrained( + model_id, + device_map="auto", + dtype=torch.bfloat16 + ) + + outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) + outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:]) + print(outputs[0]) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 248b804790..bc582af509 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1742,7 +1742,6 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default) """ - activation_dtype: torch.dtype = e4m3_dtype weight_dtype: torch.dtype = e4m3_dtype granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None @@ -1752,6 +1751,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): kernel_preference: KernelPreference = KernelPreference.AUTO set_inductor_config: bool = True version: int = 2 + param_name: Optional[str] = None def __post_init__(self): torch._C._log_api_usage_once( @@ -1851,14 +1851,19 @@ def _float8_dynamic_activation_float8_weight_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - assert hasattr(module, "weight"), ( + + param_name = getattr(config, "param_name") + if param_name is None: + param_name = "weight" + + assert hasattr(module, param_name), ( "applying float8 dynamic activation quant requires module to have weight attribute" + f"but {module} does not have one" ) quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( - module.weight, config + getattr(module, param_name), config ) - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) + setattr(module, param_name, torch.nn.Parameter(quantized_weight, requires_grad=False)) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module From d2ebe99325a156f2858b26d272214a399a3cf0c2 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 23 Sep 2025 11:32:20 -0700 Subject: [PATCH 03/47] update --- test/prototype/test_parameter.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index 8716ce8f6f..27810063ab 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -187,24 +187,13 @@ def test_quantize_parameter(self): class TestTorchAOCheckpoint(TestCase): def test_comprehensive_checkpoint_loading(self): - from transformers import AutoTokenizer, Llama4ForConditionalGeneration - import torch + from transformers import AutoConfig, AutoModel + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - model_id = "RedHatAI/Llama-4-Scout-17B-16E-Instruct" - - tokenizer = AutoTokenizer.from_pretrained(model_id) - - messages = [ - {"role": "user", "content": "Who are you?"}, - 
] - inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True) - - model = Llama4ForConditionalGeneration.from_pretrained( - model_id, - device_map="auto", - dtype=torch.bfloat16 - ) - - outputs = model.generate(**inputs.to(model.device), max_new_tokens=100) - outputs = tokenizer.batch_decode(outputs[:, inputs["input_ids"].shape[-1]:]) - print(outputs[0]) + config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct") + model = Llama4TextMoe(config.text_config).to(torch.bfloat16).cuda() + quantize_( + model, + Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + print(model) + print("DONE") From b4e280932dc8f799c04dde029041dcf028bbaa01 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 11:00:31 -0700 Subject: [PATCH 04/47] added ParamFQN config --- test/prototype/test_parameter.py | 53 +++++++++++++++-- torchao/quantization/quant_api.py | 58 ++++++++++++++++--- .../workflows/float8/float8_tensor.py | 5 +- 3 files changed, 101 insertions(+), 15 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index 27810063ab..da856b5a13 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -23,6 +23,7 @@ PerTensor, quantize_, ) +from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ from torchao.utils import is_sm_at_least_90 import torch.nn.functional as F @@ -190,10 +191,52 @@ def test_comprehensive_checkpoint_loading(self): from transformers import AutoConfig, AutoModel from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct") - model = Llama4TextMoe(config.text_config).to(torch.bfloat16).cuda() + config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct").text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + input_tensor = torch.randn(16, 128, config.hidden_size).cuda().bfloat16() + # print(model.experts) + for name, param in model.named_parameters(): + print(name) + + from torchao.quantization.quant_api import ParamFqnToConfig + + quant_config = ParamFqnToConfig({ + "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + }) + + + quantize_( + model, + quant_config, + ) + + assert isinstance(model.experts.gate_up_proj, Float8Tensor) + + def test_regex(self): + from transformers import AutoConfig, AutoModel + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + + config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct").text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + input_tensor = torch.randn(16, 128, config.hidden_size).cuda().bfloat16() + # print(model.experts) + for name, param in model.named_parameters(): + print(name) + + from torchao.quantization.quant_api import ParamFqnToConfig + + quant_config = ParamFqnToConfig({ + ".*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + }) + + quantize_( model, - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) - print(model) - print("DONE") + quant_config, + ) + + assert isinstance(model.experts.gate_up_proj, Float8Tensor) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index bc582af509..4d37b2a48a 100644 --- a/torchao/quantization/quant_api.py +++ 
b/torchao/quantization/quant_api.py @@ -20,6 +20,7 @@ import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import re import torch import torch.nn as nn @@ -532,18 +533,31 @@ def quantize_( extra_args=(config,), ) return + if isinstance(config, ParamFqnToConfig): + def my_filter_fn(mod, fqn): + for name, _ in mod.named_parameters(): + if "." not in name: + return any(re.match(pattern, f"{fqn}.{name}") for pattern in config.param_fqn_to_config) - if isinstance(config, AOBaseConfig): - handler = _QUANTIZE_CONFIG_HANDLER[type(config)] - # for each linear in the model, apply the transform if filtering passes - _replace_with_custom_fn_if_matches_filter( + _replace_with_custom_fn_if_matches_filter_with_name( model, - handler, - filter_fn, + _param_fqn_to_config_handler, + my_filter_fn, device=device, extra_args=(config,), ) - + return + if isinstance(config, AOBaseConfig): + handler = _QUANTIZE_CONFIG_HANDLER[type(config)] + # for each linear in the model, apply the transform if filtering passes + if isinstance(model, nn.Module): + _replace_with_custom_fn_if_matches_filter( + model, + handler, + filter_fn, + device=device, + extra_args=(config,), + ) else: raise AssertionError( """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead.""" @@ -2356,6 +2370,36 @@ def _module_fqn_to_config_handler( return module +@dataclass +class ParamFqnToConfig(AOBaseConfig): + """Per param configurations for torchao quantize_ API + + Args: + `param_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from + the fully qualified name of the parameter to the AOBaseConfig that we want to apply to that parameter. + """ + + param_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field( + default_factory=dict + ) + + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.ParamFqnToConfig") + +def _param_fqn_to_config_handler( + mod_containg_param: torch.nn.Module, fqn: str, config: ParamFqnToConfig +): + for name, param in list(mod_containg_param.named_parameters()): + # skip if not direct child + print(mod_containg_param, fqn, name) + if "." 
not in name: + for pattern in config.param_fqn_to_config: + if re.match(pattern, f"{fqn}.{name}"): + print("Matching pattern", pattern) + param_config = config.param_fqn_to_config.get(pattern) + setattr(mod_containg_param, name, nn.Parameter(_float8_dynamic_activation_float8_weight_quantize_tensor(param, param_config))) + + return mod_containg_param torch.serialization.add_safe_globals( [ diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index cf6543980c..f4ddbd6fa2 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -386,7 +386,7 @@ def _(func, types, args, kwargs): a_data = input_tensor.qdata a_scale = input_tensor.scale - b_data = weight_tensor.qdata + b_data = weight_tensor.qdata.transpose(1, 2).contiguous() b_scale = weight_tensor.scale.squeeze(-1) assert b_data.is_contiguous(), "weight for bmm must be contiguous" @@ -400,7 +400,7 @@ def _(func, types, args, kwargs): ), "bmm only works for per row activation quantization" orig_out_features = b_data.shape[-2] - + breakpoint() res = torch.ops.fbgemm.f8f8bf16_rowwise_batched( a_data, b_data, @@ -415,7 +415,6 @@ def _(func, types, args, kwargs): return res - @implements(aten.slice.Tensor) def _(func, types, args, kwargs): """Only supports slicing for dim == 1 and dim == 2 From 8a0236473629e430bc288c834a44455b44dda71c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 12:21:06 -0700 Subject: [PATCH 05/47] added param support to quantize_ --- test/prototype/test_parameter.py | 7 ++++++ torchao/quantization/quant_api.py | 12 +++++++---- torchao/quantization/transform_module.py | 27 ++++++++++++++++++++++++ 3 files changed, 42 insertions(+), 4 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index da856b5a13..86c6d601a3 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -240,3 +240,10 @@ def test_regex(self): ) assert isinstance(model.experts.gate_up_proj, Float8Tensor) + + def test_top_level_param(self): + param = nn.Parameter(torch.randn(1024, 1024).cuda().to(torch.bfloat16)) + + new_param = quantize_(param, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + + assert isinstance(new_param, Float8Tensor) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 4d37b2a48a..f958296d3f 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -86,7 +86,9 @@ ) from torchao.quantization.transform_module import ( _QUANTIZE_CONFIG_HANDLER, + _QUANTIZE_CONFIG_PARAM_HANDLER, register_quantize_module_handler, + register_quantize_param_handler, ) from torchao.quantization.weight_tensor_linear_activation_quantization import ( to_weight_tensor_with_linear_activation_quantization_metadata, @@ -547,6 +549,9 @@ def my_filter_fn(mod, fqn): extra_args=(config,), ) return + if isinstance(model, nn.Parameter): + handler = _QUANTIZE_CONFIG_PARAM_HANDLER[type(config)] + return nn.Parameter(handler(model, config)) if isinstance(config, AOBaseConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] # for each linear in the model, apply the transform if filtering passes @@ -1784,7 +1789,7 @@ def __post_init__(self): "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig ) - +@register_quantize_param_handler(Float8DynamicActivationFloat8WeightConfig) def 
_float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype @@ -2391,13 +2396,12 @@ def _param_fqn_to_config_handler( ): for name, param in list(mod_containg_param.named_parameters()): # skip if not direct child - print(mod_containg_param, fqn, name) if "." not in name: for pattern in config.param_fqn_to_config: if re.match(pattern, f"{fqn}.{name}"): - print("Matching pattern", pattern) param_config = config.param_fqn_to_config.get(pattern) - setattr(mod_containg_param, name, nn.Parameter(_float8_dynamic_activation_float8_weight_quantize_tensor(param, param_config))) + assert param_config is not None + setattr(mod_containg_param, name, quantize_(param, param_config)) return mod_containg_param diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index 52bc721f1f..6a7d218442 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -15,6 +15,10 @@ Callable[[torch.nn.Module, AOBaseConfig], torch.nn.Module], ] = {} +_QUANTIZE_CONFIG_PARAM_HANDLER: Dict[ + Type[AOBaseConfig], + Callable[[torch.nn.Parameter, AOBaseConfig], torch.nn.Parameter], +] = {} def register_quantize_module_handler(config_type): """ @@ -50,3 +54,26 @@ def decorator(func): return func # needed to make the functions usable externally return decorator + +def register_quantize_param_handler(config_type): + """ + A decorator to register a transform function to map from a workflow + configuration (child of `AOBaseConfig`) to a function that transforms + a `torch.nn.Parameter` according to the specified configuration. + + For example:: + + # user facing code + class WorkflowFooConfig(AOBaseConfig): ... + # configuration for workflow `Foo` is defined here + bar = 'baz' + + # non user facing code + @resgister_quantize + """ + @functools.wraps(config_type) + def decorator(func): + _QUANTIZE_CONFIG_PARAM_HANDLER[config_type] = func + return func # needed to make the functions usable externally + + return decorator From 2f2715aae4f6e7403c0ff3aeca68f4b4877bb048 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 12:25:55 -0700 Subject: [PATCH 06/47] remove float8 changes --- .../quantization/quantize_/workflows/float8/float8_tensor.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index f4ddbd6fa2..cf6543980c 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -386,7 +386,7 @@ def _(func, types, args, kwargs): a_data = input_tensor.qdata a_scale = input_tensor.scale - b_data = weight_tensor.qdata.transpose(1, 2).contiguous() + b_data = weight_tensor.qdata b_scale = weight_tensor.scale.squeeze(-1) assert b_data.is_contiguous(), "weight for bmm must be contiguous" @@ -400,7 +400,7 @@ def _(func, types, args, kwargs): ), "bmm only works for per row activation quantization" orig_out_features = b_data.shape[-2] - breakpoint() + res = torch.ops.fbgemm.f8f8bf16_rowwise_batched( a_data, b_data, @@ -415,6 +415,7 @@ def _(func, types, args, kwargs): return res + @implements(aten.slice.Tensor) def _(func, types, args, kwargs): """Only supports slicing for dim == 1 and dim == 2 From cd3d1a3ee4e319a7a03ffb05417bdfe1dbe5ab15 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 12:30:32 -0700 
Subject: [PATCH 07/47] update --- torchao/quantization/quant_api.py | 34 +++++++++++++++++-------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f958296d3f..35fbfa8b2b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -16,11 +16,11 @@ """ import logging +import re import types import warnings from dataclasses import dataclass, field from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import re import torch import torch.nn as nn @@ -526,6 +526,9 @@ def quantize_( torch._C._log_api_usage_once("torchao.quantization.quantize_") filter_fn = _is_linear if filter_fn is None else filter_fn + if isinstance(model, nn.Parameter): + handler = _QUANTIZE_CONFIG_PARAM_HANDLER[type(config)] + return nn.Parameter(handler(model, config)) if isinstance(config, ModuleFqnToConfig): _replace_with_custom_fn_if_matches_filter_with_name( model, @@ -536,22 +539,19 @@ def quantize_( ) return if isinstance(config, ParamFqnToConfig): - def my_filter_fn(mod, fqn): - for name, _ in mod.named_parameters(): - if "." not in name: - return any(re.match(pattern, f"{fqn}.{name}") for pattern in config.param_fqn_to_config) - _replace_with_custom_fn_if_matches_filter_with_name( model, _param_fqn_to_config_handler, - my_filter_fn, + lambda mod, fqn: any( + re.match(pattern, f"{fqn}.{name}") + for pattern in config.param_fqn_to_config + for name, _ in mod.named_parameters() + if "." not in name + ), device=device, extra_args=(config,), ) return - if isinstance(model, nn.Parameter): - handler = _QUANTIZE_CONFIG_PARAM_HANDLER[type(config)] - return nn.Parameter(handler(model, config)) if isinstance(config, AOBaseConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] # for each linear in the model, apply the transform if filtering passes @@ -1761,6 +1761,7 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): version (int): the version of the config, version 1 is using AffineQuantizedTensor that we plan to deprecate/split, version 2 is using Float8Tensor (default) """ + activation_dtype: torch.dtype = e4m3_dtype weight_dtype: torch.dtype = e4m3_dtype granularity: Optional[Union[FP8Granularity, List[FP8Granularity]]] = None @@ -1789,6 +1790,7 @@ def __post_init__(self): "float8_dynamic_activation_float8_weight", Float8DynamicActivationFloat8WeightConfig ) + @register_quantize_param_handler(Float8DynamicActivationFloat8WeightConfig) def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype @@ -1870,7 +1872,6 @@ def _float8_dynamic_activation_float8_weight_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - param_name = getattr(config, "param_name") if param_name is None: param_name = "weight" @@ -1882,7 +1883,9 @@ def _float8_dynamic_activation_float8_weight_transform( quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( getattr(module, param_name), config ) - setattr(module, param_name, torch.nn.Parameter(quantized_weight, requires_grad=False)) + setattr( + module, param_name, torch.nn.Parameter(quantized_weight, requires_grad=False) + ) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -2375,6 +2378,7 @@ def _module_fqn_to_config_handler( return module + @dataclass class ParamFqnToConfig(AOBaseConfig): """Per param configurations for torchao quantize_ API @@ -2384,13 +2388,12 @@ class 
ParamFqnToConfig(AOBaseConfig): the fully qualified name of the parameter to the AOBaseConfig that we want to apply to that parameter. """ - param_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field( - default_factory=dict - ) + param_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field(default_factory=dict) def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.ParamFqnToConfig") + def _param_fqn_to_config_handler( mod_containg_param: torch.nn.Module, fqn: str, config: ParamFqnToConfig ): @@ -2405,6 +2408,7 @@ def _param_fqn_to_config_handler( return mod_containg_param + torch.serialization.add_safe_globals( [ _int8_asymm_per_token_quant, From a22e781eb1316015033aad5559f0b99d552111af Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 12:36:36 -0700 Subject: [PATCH 08/47] ruff format --- torchao/quantization/quant_api.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 35fbfa8b2b..68516dd841 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -1872,20 +1872,14 @@ def _float8_dynamic_activation_float8_weight_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - param_name = getattr(config, "param_name") - if param_name is None: - param_name = "weight" - - assert hasattr(module, param_name), ( + assert hasattr(module, "weight"), ( "applying float8 dynamic activation quant requires module to have weight attribute" + f"but {module} does not have one" ) quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( - getattr(module, param_name), config - ) - setattr( - module, param_name, torch.nn.Parameter(quantized_weight, requires_grad=False) + getattr(module, "weight"), config ) + setattr(module, "weight", torch.nn.Parameter(quantized_weight, requires_grad=False)) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module From 36775c8648efc70fcc25130341b65a7e7be0bc1e Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 12:44:57 -0700 Subject: [PATCH 09/47] update --- torchao/quantization/quant_api.py | 31 +++++++++++++++---------------- 1 file changed, 15 insertions(+), 16 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 68516dd841..8cd25cb0c0 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -529,15 +529,6 @@ def quantize_( if isinstance(model, nn.Parameter): handler = _QUANTIZE_CONFIG_PARAM_HANDLER[type(config)] return nn.Parameter(handler(model, config)) - if isinstance(config, ModuleFqnToConfig): - _replace_with_custom_fn_if_matches_filter_with_name( - model, - _module_fqn_to_config_handler, - filter_fn, - device=device, - extra_args=(config,), - ) - return if isinstance(config, ParamFqnToConfig): _replace_with_custom_fn_if_matches_filter_with_name( model, @@ -552,6 +543,15 @@ def quantize_( extra_args=(config,), ) return + if isinstance(config, ModuleFqnToConfig): + _replace_with_custom_fn_if_matches_filter_with_name( + model, + _module_fqn_to_config_handler, + filter_fn, + device=device, + extra_args=(config,), + ) + return if isinstance(config, AOBaseConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] # for each linear in the model, apply the transform if filtering passes @@ -1771,7 +1771,6 @@ class Float8DynamicActivationFloat8WeightConfig(AOBaseConfig): kernel_preference: KernelPreference = 
KernelPreference.AUTO set_inductor_config: bool = True version: int = 2 - param_name: Optional[str] = None def __post_init__(self): torch._C._log_api_usage_once( @@ -1877,9 +1876,9 @@ def _float8_dynamic_activation_float8_weight_transform( + f"but {module} does not have one" ) quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( - getattr(module, "weight"), config + module.weight, config ) - setattr(module, "weight", torch.nn.Parameter(quantized_weight, requires_grad=False)) + module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -2389,18 +2388,18 @@ def __post_init__(self): def _param_fqn_to_config_handler( - mod_containg_param: torch.nn.Module, fqn: str, config: ParamFqnToConfig + mod_containing_param: torch.nn.Module, fqn: str, config: ParamFqnToConfig ): - for name, param in list(mod_containg_param.named_parameters()): + for name, param in list(mod_containing_param.named_parameters()): # skip if not direct child if "." not in name: for pattern in config.param_fqn_to_config: if re.match(pattern, f"{fqn}.{name}"): param_config = config.param_fqn_to_config.get(pattern) assert param_config is not None - setattr(mod_containg_param, name, quantize_(param, param_config)) + setattr(mod_containing_param, name, quantize_(param, param_config)) - return mod_containg_param + return mod_containing_param torch.serialization.add_safe_globals( From b71e77fc588dd04d88228f90baf0ba2445897b3f Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 12:47:02 -0700 Subject: [PATCH 10/47] update --- torchao/quantization/quant_api.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8cd25cb0c0..b93374ca80 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -555,14 +555,13 @@ def quantize_( if isinstance(config, AOBaseConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] # for each linear in the model, apply the transform if filtering passes - if isinstance(model, nn.Module): - _replace_with_custom_fn_if_matches_filter( - model, - handler, - filter_fn, - device=device, - extra_args=(config,), - ) + _replace_with_custom_fn_if_matches_filter( + model, + handler, + filter_fn, + device=device, + extra_args=(config,), + ) else: raise AssertionError( """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. 
Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead.""" From 6c8bc804580029a496cd851f9caf79a42e9db3cc Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 12:51:26 -0700 Subject: [PATCH 11/47] remove old changes --- .../workflows/float8/float8_tensor.py | 37 +++++++++---------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index cf6543980c..46d4ca5426 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -23,7 +23,6 @@ preprocess_scale, ) from torchao.quantization.granularity import PerRow, PerTensor -from torchao.quantization.observer import get_block_size from torchao.quantization.quant_primitives import ( _choose_scale_float8, _dequantize_affine_float8, @@ -34,6 +33,7 @@ QuantizeTensorKwargs, _choose_quant_func_and_quantize_tensor, ) +from torchao.quantization.utils import get_block_size from torchao.utils import ( TorchAOBaseTensor, _is_fbgemm_genai_gpu_available, @@ -619,25 +619,24 @@ def _(func, types, args, kwargs): @implements(aten.select.int) def _(func, types, args, kwargs): - self, dim, index = args - qdata = self.qdata.select(dim, index) - # TODO we need to handle this case differently based on the scaling config - - scale = self.scale.select(dim, index).t() - """ - Without the transpose here, I run into the following runtime error: - E RuntimeError: Invalid scaling configuration. For TensorWise scaling, both scales should be scalar. For RowWise scaling, scale_a should be (512, 1) and scale_b should be (1, 1024). 
Got scale_a.size()=(512, 1) and scale_b.size()=(1024, 1) - """ - - new = self.__class__( - qdata, - scale, - self.block_size, - self.mm_config, - self.act_quant_kwargs, - self.kernel_preference, + old_float8_tensor, dim, index = args + assert dim == 0, f"Float8Tensor aten.select.int with {dim=} is not yet supported" + assert len(old_float8_tensor.qdata.shape) == len(old_float8_tensor.scale.shape), ( + "unsupported" ) - return return_and_correct_aliasing(func, args, kwargs, new) + assert len(old_float8_tensor.qdata.shape) == len(old_float8_tensor.block_size), ( + "unsupported" + ) + new_float8_tensor = old_float8_tensor.__class__( + old_float8_tensor.qdata[index], + old_float8_tensor.scale[index], + old_float8_tensor.block_size[1:], + old_float8_tensor.mm_config, + old_float8_tensor.act_quant_kwargs, + old_float8_tensor.kernel_preference, + old_float8_tensor.dtype, + ) + return return_and_correct_aliasing(func, args, kwargs, new_float8_tensor) Float8Tensor.__module__ = "torchao.quantization" From 62e0c5cb5c6c7f9bcd41962b87c742273d8add0c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 12:54:28 -0700 Subject: [PATCH 12/47] update --- test/prototype/test_parameter.py | 152 +------------------------------ 1 file changed, 5 insertions(+), 147 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index 86c6d601a3..77d50aea7e 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -34,6 +34,7 @@ import torch from torch.testing._internal.common_utils import TestCase, run_tests from torchao.utils import is_fbcode, is_sm_at_least_90 +from torchao.quantization.quant_api import ParamFqnToConfig if not is_fbcode(): from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig @@ -42,163 +43,22 @@ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) -class TestMoE2d(nn.Module): - def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: - """ - Args: - num_experts (int): - Number of experts. - input_size (int): - Size of the input. - output_size (int): - Size of the output. - """ - super().__init__() - self.num_experts = num_experts - self.input_size = input_size - self.output_size = output_size - self.experts = nn.ModuleList([nn.Linear(input_size, output_size, bias=None) for _ in range(self.num_experts)]) - - def forward(self, inputs, expert_size): - """ - Forward pass of the JetMoeParallelExperts module. - - Args: - inputs (Tensor): - Input tensor. - expert_size: - Expert size information. - - Returns: - Tensor: Output tensor. - """ - # return True - input_list = inputs.split(expert_size, dim=0) - output_list = [] - - assert len(input_list) == len(self.experts) - for expert, expert_input in zip(self.experts, input_list): - output_list.append(expert(expert_input)) - results = torch.cat(output_list, dim=0) - return results - - -class TestMoE3d(nn.Module): - def __init__(self, num_experts: int, input_size: int, output_size: int) -> None: - """ - This implementation is taken from: - https://github.com/huggingface/transformers/blob/6cade29278c4aee3f174f8950f97a3873bdb212f/src/transformers/models/jetmoe/modeling_jetmoe.py#L141 - - Args: - num_experts (int): - Number of experts. - input_size (int): - Size of the input. - output_size (int): - Size of the output. 
- """ - super().__init__() - self.moe_weight = nn.Parameter(torch.randn(num_experts, output_size, input_size)) - self.num_experts = num_experts - self.input_size = input_size - self.output_size = output_size - - def forward(self, inputs, expert_size): - """ - Forward pass of the JetMoeParallelExperts module. - - Args: - inputs (Tensor): - Input tensor. - expert_size: - Expert size information. - - Returns: - Tensor: Output tensor. - """ - # return True - input_list = inputs.split(expert_size, dim=0) - output_list = [] - for i in range(self.num_experts): - output_list.append(F.linear(input_list[i], self.moe_weight[i])) - results = torch.cat(output_list, dim=0) - return results - -def print_model_fqn(model): - print("=== Parameters ===") - for name, param in model.named_parameters(): - print(f"{name}: {param.shape}") - - print("\n=== Modules ===") - for name, module in model.named_modules(): - if name: # Skip empty name for root module - print(f"{name}: {type(module).__name__}") - - -class TestQuantizeParameterMoE(common_utils.TestCase): - - def test_2d_3d_moe_equivalent(self): - test_input = torch.randn(1024, 1024).cuda() - - model_2d = TestMoE2d(2, 1024, 1024).cuda() - model_3d = TestMoE3d(2, 1024, 1024).cuda() - - for i, expert in enumerate(model_2d.experts): - model_3d.moe_weight.data[i] = expert.weight.detach().clone() - - output_2d = model_2d(test_input, 512) - output_3d = model_3d(test_input, 512) - - torch.testing.assert_close(output_2d, output_3d, rtol=1e-3, atol=1e-3) - - - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") - def test_quantize_parameter(self): - test_input = torch.randn(1024, 1024).cuda().bfloat16() - - model_2d = TestMoE2d(2, 1024, 1024).cuda().bfloat16() - model_3d = TestMoE3d(2, 1024, 1024).cuda().bfloat16() - - for i, expert in enumerate(model_2d.experts): - model_3d.moe_weight.data[i] = expert.weight.detach().clone() - - # quantize all linears in 2d - quantize_( - model_2d, - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), - ) - - quantize_( - model_3d, - Float8DynamicActivationFloat8WeightConfig(granularity=PerRow(), param_name="moe_weight"), - filter_fn=lambda mod, fqn: fqn is '' # top level module has no fqn - ) - - output_3d = model_3d(test_input, 512) - output_2d = model_2d(test_input, 512) - - torch.testing.assert_close(output_2d, output_3d, rtol=1e-3, atol=1e-3) - @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") @unittest.skipIf( is_fbcode(), "Skipping the test in fbcode for now, not sure how to download from transformers", ) -class TestTorchAOCheckpoint(TestCase): +class TestQuantizeFQNParam (TestCase): - def test_comprehensive_checkpoint_loading(self): + def test_quantize_param_fqn_exact(self): from transformers import AutoConfig, AutoModel from transformers.models.llama4.modeling_llama4 import Llama4TextMoe config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct").text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() input_tensor = torch.randn(16, 128, config.hidden_size).cuda().bfloat16() - # print(model.experts) - for name, param in model.named_parameters(): - print(name) - from torchao.quantization.quant_api import ParamFqnToConfig quant_config = ParamFqnToConfig({ "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( @@ -214,7 +74,7 @@ def test_comprehensive_checkpoint_loading(self): assert isinstance(model.experts.gate_up_proj, Float8Tensor) - def 
test_regex(self): + def test_quantize_param_fqn_regex(self): from transformers import AutoConfig, AutoModel from transformers.models.llama4.modeling_llama4 import Llama4TextMoe @@ -241,9 +101,7 @@ def test_regex(self): assert isinstance(model.experts.gate_up_proj, Float8Tensor) - def test_top_level_param(self): + def test_quantize_param_root(self): param = nn.Parameter(torch.randn(1024, 1024).cuda().to(torch.bfloat16)) - new_param = quantize_(param, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) - assert isinstance(new_param, Float8Tensor) From cb760161db06851dfdd74b02edd6595c36ef8932 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 13:28:57 -0700 Subject: [PATCH 13/47] update --- .../workflows/float8/float8_tensor.py | 24 +------------------ 1 file changed, 1 insertion(+), 23 deletions(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index 46d4ca5426..ffff2cac18 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -23,6 +23,7 @@ preprocess_scale, ) from torchao.quantization.granularity import PerRow, PerTensor +from torchao.quantization.utils import get_block_size from torchao.quantization.quant_primitives import ( _choose_scale_float8, _dequantize_affine_float8, @@ -33,7 +34,6 @@ QuantizeTensorKwargs, _choose_quant_func_and_quantize_tensor, ) -from torchao.quantization.utils import get_block_size from torchao.utils import ( TorchAOBaseTensor, _is_fbgemm_genai_gpu_available, @@ -617,28 +617,6 @@ def _(func, types, args, kwargs): return return_and_correct_aliasing(func, args, kwargs, new) -@implements(aten.select.int) -def _(func, types, args, kwargs): - old_float8_tensor, dim, index = args - assert dim == 0, f"Float8Tensor aten.select.int with {dim=} is not yet supported" - assert len(old_float8_tensor.qdata.shape) == len(old_float8_tensor.scale.shape), ( - "unsupported" - ) - assert len(old_float8_tensor.qdata.shape) == len(old_float8_tensor.block_size), ( - "unsupported" - ) - new_float8_tensor = old_float8_tensor.__class__( - old_float8_tensor.qdata[index], - old_float8_tensor.scale[index], - old_float8_tensor.block_size[1:], - old_float8_tensor.mm_config, - old_float8_tensor.act_quant_kwargs, - old_float8_tensor.kernel_preference, - old_float8_tensor.dtype, - ) - return return_and_correct_aliasing(func, args, kwargs, new_float8_tensor) - - Float8Tensor.__module__ = "torchao.quantization" # Allow a model with Float8Tensor weights to be loaded with `weights_only=True` From ee36c40045df7cfc7722e353247831184184b923 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 13:30:41 -0700 Subject: [PATCH 14/47] undo --- .../quantization/quantize_/workflows/float8/float8_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py index ffff2cac18..49c8b1cd24 100644 --- a/torchao/quantization/quantize_/workflows/float8/float8_tensor.py +++ b/torchao/quantization/quantize_/workflows/float8/float8_tensor.py @@ -23,7 +23,7 @@ preprocess_scale, ) from torchao.quantization.granularity import PerRow, PerTensor -from torchao.quantization.utils import get_block_size +from torchao.quantization.observer import get_block_size from torchao.quantization.quant_primitives import ( _choose_scale_float8, _dequantize_affine_float8, 
From 7c5ab0417b36393e2ecdd419ba97fa38a97bf614 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 26 Sep 2025 13:53:42 -0700 Subject: [PATCH 15/47] update --- test/prototype/test_parameter.py | 76 ++++++++++++++------------------ 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index 77d50aea7e..5195104764 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -3,69 +3,56 @@ # # This source code is licensed under the BSD 3-Clause license found in the # LICENSE file in the root directory of this source tree. -import copy import logging import unittest import torch from torch import nn -from torch.testing._internal import common_utils +from torch.testing._internal.common_utils import TestCase -from torchao.dtypes import MarlinSparseLayout, SemiSparseLayout from torchao.quantization import ( - Float8DynamicActivationFloat8SemiSparseWeightConfig, Float8DynamicActivationFloat8WeightConfig, ) from torchao.quantization.quant_api import ( - Int4WeightOnlyConfig, - Int8DynamicActivationInt8WeightConfig, + ParamFqnToConfig, PerRow, - PerTensor, quantize_, ) from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor -from torchao.sparsity import apply_fake_sparsity, semi_sparse_weight, sparsify_ -from torchao.utils import is_sm_at_least_90 -import torch.nn.functional as F - -import re -import unittest -import warnings -import torch -from torch.testing._internal.common_utils import TestCase, run_tests from torchao.utils import is_fbcode, is_sm_at_least_90 -from torchao.quantization.quant_api import ParamFqnToConfig if not is_fbcode(): - from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig + pass logging.basicConfig( format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO ) + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") @unittest.skipIf( is_fbcode(), "Skipping the test in fbcode for now, not sure how to download from transformers", ) -class TestQuantizeFQNParam (TestCase): - +class TestQuantizeFQNParam(TestCase): def test_quantize_param_fqn_exact(self): - from transformers import AutoConfig, AutoModel + from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct").text_config + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - input_tensor = torch.randn(16, 128, config.hidden_size).cuda().bfloat16() - - - quant_config = ParamFqnToConfig({ - "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - ), - }) - + torch.randn(16, 128, config.hidden_size).cuda().bfloat16() + + quant_config = ParamFqnToConfig( + { + "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + } + ) quantize_( model, @@ -75,24 +62,27 @@ def test_quantize_param_fqn_exact(self): assert isinstance(model.experts.gate_up_proj, Float8Tensor) def test_quantize_param_fqn_regex(self): - from transformers import AutoConfig, AutoModel + from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - config = AutoConfig.from_pretrained("unsloth/Llama-4-Scout-17B-16E-Instruct").text_config + config = 
AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - input_tensor = torch.randn(16, 128, config.hidden_size).cuda().bfloat16() + torch.randn(16, 128, config.hidden_size).cuda().bfloat16() # print(model.experts) for name, param in model.named_parameters(): print(name) - from torchao.quantization.quant_api import ParamFqnToConfig - - quant_config = ParamFqnToConfig({ - ".*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - ), - }) + from torchao.quantization.quant_api import ParamFqnToConfig + quant_config = ParamFqnToConfig( + { + ".*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + } + ) quantize_( model, @@ -103,5 +93,7 @@ def test_quantize_param_fqn_regex(self): def test_quantize_param_root(self): param = nn.Parameter(torch.randn(1024, 1024).cuda().to(torch.bfloat16)) - new_param = quantize_(param, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())) + new_param = quantize_( + param, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + ) assert isinstance(new_param, Float8Tensor) From 57d2f21a41a3b45d2b051e200bea21123e8e65ec Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 2 Oct 2025 18:44:10 -0700 Subject: [PATCH 16/47] cr feedback --- test/prototype/test_parameter.py | 25 +++++---- torchao/quantization/quant_api.py | 67 ++++++++++++++---------- torchao/quantization/transform_module.py | 8 ++- 3 files changed, 58 insertions(+), 42 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index 5195104764..1bad54696a 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -14,7 +14,7 @@ Float8DynamicActivationFloat8WeightConfig, ) from torchao.quantization.quant_api import ( - ParamFqnToConfig, + ModuleOrParamFqnToConfig, PerRow, quantize_, ) @@ -46,7 +46,7 @@ def test_quantize_param_fqn_exact(self): model = Llama4TextMoe(config).to(torch.bfloat16).cuda() torch.randn(16, 128, config.hidden_size).cuda().bfloat16() - quant_config = ParamFqnToConfig( + quant_config = ModuleOrParamFqnToConfig( { "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), @@ -61,26 +61,24 @@ def test_quantize_param_fqn_exact(self): assert isinstance(model.experts.gate_up_proj, Float8Tensor) - def test_quantize_param_fqn_regex(self): + def test_quantize_param_and_module_fqn(self): from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + from torchao.quantization import PerTensor config = AutoConfig.from_pretrained( "unsloth/Llama-4-Scout-17B-16E-Instruct" ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() torch.randn(16, 128, config.hidden_size).cuda().bfloat16() - # print(model.experts) - for name, param in model.named_parameters(): - print(name) - - from torchao.quantization.quant_api import ParamFqnToConfig - - quant_config = ParamFqnToConfig( + quant_config = ModuleOrParamFqnToConfig( { - ".*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), + "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), } ) @@ -90,10 +88,11 @@ def test_quantize_param_fqn_regex(self): ) assert isinstance(model.experts.gate_up_proj, Float8Tensor) + assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) def 
test_quantize_param_root(self): param = nn.Parameter(torch.randn(1024, 1024).cuda().to(torch.bfloat16)) - new_param = quantize_( + quantize_( param, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) ) - assert isinstance(new_param, Float8Tensor) + assert isinstance(param, Float8Tensor) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b93374ca80..a0187b6d72 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -20,6 +20,7 @@ import types import warnings from dataclasses import dataclass, field +from functools import partial from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch @@ -94,6 +95,7 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( + TorchAOBaseTensor, _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, @@ -526,19 +528,18 @@ def quantize_( torch._C._log_api_usage_once("torchao.quantization.quantize_") filter_fn = _is_linear if filter_fn is None else filter_fn - if isinstance(model, nn.Parameter): - handler = _QUANTIZE_CONFIG_PARAM_HANDLER[type(config)] - return nn.Parameter(handler(model, config)) - if isinstance(config, ParamFqnToConfig): + if isinstance(config, ModuleOrParamFqnToConfig): + _replace_with_custom_fn_if_matches_filter_with_name( + model, + _module_fqn_to_config_handler, + filter_fn, + device=device, + extra_args=(config,), + ) _replace_with_custom_fn_if_matches_filter_with_name( model, _param_fqn_to_config_handler, - lambda mod, fqn: any( - re.match(pattern, f"{fqn}.{name}") - for pattern in config.param_fqn_to_config - for name, _ in mod.named_parameters() - if "." not in name - ), + partial(select_module_if_fqn_in_pattern, config=config), device=device, extra_args=(config,), ) @@ -2372,35 +2373,45 @@ def _module_fqn_to_config_handler( @dataclass -class ParamFqnToConfig(AOBaseConfig): - """Per param configurations for torchao quantize_ API - - Args: - `param_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from - the fully qualified name of the parameter to the AOBaseConfig that we want to apply to that parameter. - """ - - param_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field(default_factory=dict) +class ModuleOrParamFqnToConfig(AOBaseConfig): + module_or_param_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field( + default_factory=dict + ) def __post_init__(self): - torch._C._log_api_usage_once("torchao.quantization.ParamFqnToConfig") + torch._C._log_api_usage_once("torchao.quantization.ModuleOrParamFqnToConfig") + + @property + def module_fqn_to_config(self): + return self.module_or_param_fqn_to_config def _param_fqn_to_config_handler( - mod_containing_param: torch.nn.Module, fqn: str, config: ParamFqnToConfig + mod_containing_param: torch.nn.Module, fqn: str, config: ModuleOrParamFqnToConfig ): for name, param in list(mod_containing_param.named_parameters()): - # skip if not direct child - if "." 
not in name: - for pattern in config.param_fqn_to_config: - if re.match(pattern, f"{fqn}.{name}"): - param_config = config.param_fqn_to_config.get(pattern) - assert param_config is not None - setattr(mod_containing_param, name, quantize_(param, param_config)) + # check to see if top level param + if name in dir(mod_containing_param): + for pattern, param_config in config.module_or_param_fqn_to_config.items(): + if pattern == f"{fqn}.{name}" and not isinstance(param, TorchAOBaseTensor): + param_config_type = type(param_config) + if param_config_type in _QUANTIZE_CONFIG_PARAM_HANDLER: + handler = _QUANTIZE_CONFIG_PARAM_HANDLER[param_config_type] + new_param = handler(param, param_config) + setattr(mod_containing_param, name, new_param) + else: + raise NotImplementedError(f"Parameter quantization for {param_config_type} not supported currently!") return mod_containing_param +def select_module_if_fqn_in_pattern(mod, fqn, config): + for name, _ in mod.named_parameters(): + if "." not in name: # only want attribute parameters + if f"{fqn}.{name}" in config.module_or_param_fqn_to_config: + return True + + torch.serialization.add_safe_globals( [ _int8_asymm_per_token_quant, diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index 6a7d218442..d59790d3ad 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -73,7 +73,13 @@ class WorkflowFooConfig(AOBaseConfig): ... """ @functools.wraps(config_type) def decorator(func): - _QUANTIZE_CONFIG_PARAM_HANDLER[config_type] = func + + def func_supporting_param(tensor_or_param, config): + if type(tensor_or_param) is torch.nn.Parameter: + return torch.nn.Parameter(func(tensor_or_param, config)) + return func(tensor_or_param, config) + + _QUANTIZE_CONFIG_PARAM_HANDLER[config_type] = func_supporting_param return func # needed to make the functions usable externally return decorator From 8c43f45523884eb4b28accc0c3b7fb68a8e3bd47 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 2 Oct 2025 19:27:18 -0700 Subject: [PATCH 17/47] add test --- test/prototype/test_parameter.py | 62 +++++- torchao/quantization/quant_api.py | 316 +++++++++++++++++++----------- 2 files changed, 257 insertions(+), 121 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index 1bad54696a..db8550ba63 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -18,6 +18,7 @@ PerRow, quantize_, ) +from torchao.core.config import AOBaseConfig from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor from torchao.utils import is_fbcode, is_sm_at_least_90 @@ -87,12 +88,63 @@ def test_quantize_param_and_module_fqn(self): quant_config, ) + def test_quantize_param_and_module_fqn_regex(self): + from transformers import AutoConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + from torchao.quantization import PerTensor + + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + torch.randn(16, 128, config.hidden_size).cuda().bfloat16() + quant_config = ModuleOrParamFqnToConfig( + { + ".*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + assert isinstance(model.experts.gate_up_proj, Float8Tensor) 
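+        # ".*gate_up_proj" is a plain regex here: _param_fqn_to_config_handler matches it
+        # with re.search against each top-level parameter's full FQN, e.g. (assuming the
+        # experts module is reached as "experts"):
+        #     re.search(".*gate_up_proj", "experts.gate_up_proj")  # -> match
+        # so the bare nn.Parameter holding the stacked expert weights is replaced with a
+        # Float8Tensor, while "shared_expert.gate_proj" targets an ordinary nn.Linear and
+        # its weight is quantized per-tensor.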
assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) - def test_quantize_param_root(self): - param = nn.Parameter(torch.randn(1024, 1024).cuda().to(torch.bfloat16)) - quantize_( - param, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()) + def test_unsupported_param_config_raises_not_implemented_error(self): + """Test that using an unsupported parameter config raises NotImplementedError.""" + from dataclasses import dataclass + + # Create a custom config that doesn't have a registered parameter handler + @dataclass + class UnsupportedParamConfig(AOBaseConfig): + some_value: int = 42 + + # Create a simple model + model = nn.Linear(10, 5).cuda().bfloat16() + + # Create config with unsupported parameter handler + quant_config = ModuleOrParamFqnToConfig( + { + "weight": UnsupportedParamConfig(), + } ) - assert isinstance(param, Float8Tensor) + + # This should raise NotImplementedError + with self.assertRaises(NotImplementedError) as context: + quantize_(model, quant_config) + + # Check that the error message contains the expected text + self.assertIn("Parameter quantization for", str(context.exception)) + self.assertIn("not supported currently", str(context.exception)) + self.assertIn("UnsupportedParamConfig", str(context.exception)) + + +if __name__ == "__main__": + unittest.main() diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a0187b6d72..27e6b970cd 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -19,9 +19,19 @@ import re import types import warnings +from collections import OrderedDict from dataclasses import dataclass, field from functools import partial -from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + OrderedDict as OrderedDictType, + Tuple, + Union, +) import torch import torch.nn as nn @@ -44,32 +54,30 @@ QDQLayout, SemiSparseLayout, TensorCoreTiledLayout, - UintxLayout, to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, to_marlinqqq_quantized_intx, + UintxLayout, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - Target, make_packed_linear_int8_dynamic_activation_intx_weight_tensor, + Target, ) from torchao.dtypes.utils import Layout from torchao.float8.config import e4m3_dtype, e5m2_dtype from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import ( - Float8MMConfig, - FP8Granularity, _check_hardware_support, _normalize_granularity, + Float8MMConfig, + FP8Granularity, ) from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, ) from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size -from torchao.quantization.quantize_.common import ( - KernelPreference, -) +from torchao.quantization.quantize_.common import KernelPreference from torchao.quantization.quantize_.workflows import ( Float8Tensor, Int4ChooseQParamsAlgorithm, @@ -95,44 +103,29 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( - TorchAOBaseTensor, _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, is_sm_at_least_90, + TorchAOBaseTensor, ) -from .autoquant import AutoQuantizableLinearWeight, autoquant -from .GPTQ import ( - Int4WeightOnlyGPTQQuantizer, -) -from .granularity import ( - Granularity, - PerAxis, - PerGroup, - PerRow, - PerTensor, -) +from .autoquant import 
autoquant, AutoQuantizableLinearWeight +from .GPTQ import Int4WeightOnlyGPTQQuantizer +from .granularity import Granularity, PerAxis, PerGroup, PerRow, PerTensor from .linear_activation_quantized_tensor import ( LinearActivationQuantizedTensor, to_linear_activation_quantized, ) -from .linear_quant_modules import ( - Int4WeightOnlyQuantizer, - Int8DynActInt4WeightQuantizer, -) -from .qat import ( - intx_quantization_aware_training, -) +from .linear_quant_modules import Int4WeightOnlyQuantizer, Int8DynActInt4WeightQuantizer +from .qat import intx_quantization_aware_training from .quant_primitives import ( _DTYPE_TO_QVALUE_BOUNDS, MappingType, - ZeroPointDomain, quantize_affine, + ZeroPointDomain, ) -from .subclass import ( - QuantizedLinearWeightBase, -) +from .subclass import QuantizedLinearWeightBase from .unified import Quantizer, TwoStepQuantizer from .utils import _get_per_token_block_size @@ -783,34 +776,28 @@ def __post_init__(self): torch._C._log_api_usage_once( "torchao.quantization.Int8DynamicActivationIntxWeightConfig" ) - assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( - f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" - ) - assert isinstance(self.weight_granularity, (PerAxis, PerGroup)), ( - f"weight_granularity must be PerAxis or PerGroup, but got {self.weight_granularity}" - ) + assert self.weight_dtype in [ + getattr(torch, f"int{b}") for b in range(1, 9) + ], f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" + assert isinstance( + self.weight_granularity, (PerAxis, PerGroup) + ), f"weight_granularity must be PerAxis or PerGroup, but got {self.weight_granularity}" if isinstance(self.weight_granularity, PerAxis): - assert self.weight_granularity.axis == 0, ( - f"axis must be 0, but got {self.weight_granularity.axis}" - ) + assert ( + self.weight_granularity.axis == 0 + ), f"axis must be 0, but got {self.weight_granularity.axis}" assert self.weight_mapping_type in [ MappingType.ASYMMETRIC, MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR, - ], ( - f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}" - ) + ], f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}" assert self.act_mapping_type in [ MappingType.ASYMMETRIC, MappingType.SYMMETRIC, - ], ( - f"act_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.act_mapping_type}" - ) + ], f"act_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.act_mapping_type}" assert isinstance( self.layout, (PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout) - ), ( - f"layout must be PackedLinearInt8DynamicActivationIntxWeightLayout or QDQLayout, but got {self.layout}" - ) + ), f"layout must be PackedLinearInt8DynamicActivationIntxWeightLayout or QDQLayout, but got {self.layout}" if isinstance(self.layout, PackedLinearInt8DynamicActivationIntxWeightLayout): if self.layout.target in [Target.AUTO, Target.KLEIDIAI, Target.ATEN]: @@ -834,15 +821,15 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config): layout = config.layout intx_packing_format = config.intx_packing_format - assert weight.dim() == 2, ( - f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}" - ) + assert ( + weight.dim() == 2 + ), 
f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}" if isinstance(weight_granularity, PerGroup): group_size = weight_granularity.group_size elif isinstance(weight_granularity, PerAxis): - assert weight_granularity.axis == 0, ( - f"axis must be 0 with PerAxis, but got {weight_granularity.axis}" - ) + assert ( + weight_granularity.axis == 0 + ), f"axis must be 0 with PerAxis, but got {weight_granularity.axis}" group_size = weight.shape[-1] else: raise ValueError( @@ -928,9 +915,9 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config): elif isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout): # PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization # fused with the kernel and it should not be applied separately - assert act_mapping_type == MappingType.ASYMMETRIC, ( - "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC" - ) + assert ( + act_mapping_type == MappingType.ASYMMETRIC + ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC" data, scale, zero_point = weight.tensor_impl.get_plain() groups_per_row = weight.shape[-1] // group_size scale = scale.reshape(-1, groups_per_row) @@ -1231,16 +1218,16 @@ def _int4_weight_only_quantize_tensor(weight, config): ) # nonlocal zero_point_domain - assert type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys(), ( - f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" - ) + assert ( + type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() + ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" if zero_point_domain == ZeroPointDomain.NONE: # the first value is the default one zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] else: - assert zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)], ( - f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" - ) + assert ( + zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] + ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" if zero_point_domain == ZeroPointDomain.INT and isinstance(layout, Int4XPULayout): zero_point_dtype = torch.int32 @@ -1255,9 +1242,9 @@ def _int4_weight_only_quantize_tensor(weight, config): # we should consider moving this logic somewhere else. if isinstance(layout, MarlinSparseLayout): mapping_type = MappingType.SYMMETRIC - assert group_size == 128 or group_size == weight.shape[-1], ( - f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" - ) + assert ( + group_size == 128 or group_size == weight.shape[-1] + ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" new_weight = to_affine_quantized_intx( weight, @@ -1317,9 +1304,9 @@ def _float8_dynamic_activation_int4_weight_transform( ) int4_packing_format = config.int4_packing_format - assert int4_packing_format == "preshuffled", ( - f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}" - ) + assert ( + int4_packing_format == "preshuffled" + ), f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}" weight = module.weight group_size = 128 block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size]) @@ -1682,13 +1669,13 @@ def _input_activation_quant_func_fp8( """This function is used to quantize the input activation tensor for an aqt_float variant. 
If scale is not provided it will be dynamically calculate the scales otherwise it will use the provided scale. """ - assert zero_point is None, ( - "Zero point is not supported for dynamic FP8 quantization" - ) + assert ( + zero_point is None + ), "Zero point is not supported for dynamic FP8 quantization" if isinstance(activation_granularity, PerRow): - assert x.dtype == torch.bfloat16, ( - "PerRow quantization only works for bfloat16 precision input activation" - ) + assert ( + x.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input activation" block_size = get_block_size(x.shape, activation_granularity) if scale is None: @@ -1700,9 +1687,9 @@ def _input_activation_quant_func_fp8( _layout=Float8Layout(mm_config=None), # Config is stored on weight ) else: - assert isinstance(activation_granularity, PerTensor), ( - "Static quantization only supports PerTensor granularity" - ) + assert isinstance( + activation_granularity, PerTensor + ), "Static quantization only supports PerTensor granularity" activation = to_affine_quantized_floatx_static( input_float=x, block_size=block_size, @@ -1810,9 +1797,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): return weight if isinstance(weight_granularity, PerRow): - assert weight.dtype == torch.bfloat16, ( - "PerRow quantization only works for bfloat16 precision input weight" - ) + assert ( + weight.dtype == torch.bfloat16 + ), "PerRow quantization only works for bfloat16 precision input weight" if config.version == 1: warnings.warn( @@ -1865,9 +1852,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): def _float8_dynamic_activation_float8_weight_transform( module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig ): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" - ) + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() @@ -1970,9 +1957,9 @@ def __post_init__(self): def _float8_static_activation_float8_weight_transform( module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig ): - assert is_sm_at_least_89() or is_MI300(), ( - "Float8 static activation quantization is only supported on CUDA 8.9 and above" - ) + assert ( + is_sm_at_least_89() or is_MI300() + ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" scale = config.scale activation_dtype = config.activation_dtype @@ -1984,9 +1971,9 @@ def _float8_static_activation_float8_weight_transform( weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) - assert isinstance(activation_granularity, PerTensor), ( - "Static quantization only supports PerTensor granularity" - ) + assert isinstance( + activation_granularity, PerTensor + ), "Static quantization only supports PerTensor granularity" if not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently @@ -2181,22 +2168,20 @@ class IntxWeightOnlyConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig") - assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( - f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" - ) - assert 
isinstance(self.granularity, (PerAxis, PerGroup)), ( - f"granularity must be PerAxis or PerGroup, but got {self.granularity}" - ) + assert self.weight_dtype in [ + getattr(torch, f"int{b}") for b in range(1, 9) + ], f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" + assert isinstance( + self.granularity, (PerAxis, PerGroup) + ), f"granularity must be PerAxis or PerGroup, but got {self.granularity}" if isinstance(self.granularity, PerAxis): - assert self.granularity.axis == 0, ( - f"axis must be 0 with PerAxis, but got {self.granularity.axis}" - ) + assert ( + self.granularity.axis == 0 + ), f"axis must be 0 with PerAxis, but got {self.granularity.axis}" assert self.mapping_type in [ MappingType.ASYMMETRIC, MappingType.SYMMETRIC, - ], ( - f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" - ) + ], f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" def _intx_weight_only_quantize_tensor(weight, config): @@ -2207,15 +2192,15 @@ def _intx_weight_only_quantize_tensor(weight, config): layout = config.layout intx_packing_format = config.intx_packing_format - assert weight.dim() == 2, ( - f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" - ) + assert ( + weight.dim() == 2 + ), f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" if isinstance(granularity, PerGroup): group_size = granularity.group_size elif isinstance(granularity, PerAxis): - assert granularity.axis == 0, ( - f"axis must be 0 with PerAxis, but got {granularity.axis}" - ) + assert ( + granularity.axis == 0 + ), f"axis must be 0 with PerAxis, but got {granularity.axis}" group_size = weight.shape[-1] else: raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}") @@ -2374,8 +2359,51 @@ def _module_fqn_to_config_handler( @dataclass class ModuleOrParamFqnToConfig(AOBaseConfig): - module_or_param_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field( - default_factory=dict + """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs). + + This extends the functionality of ModuleFqnToConfig to support parameter-level quantization configurations + in addition to module-level configurations. It allows for fine-grained control over quantization by + specifying configurations for individual parameters or modules using regex pattern matching on their FQNs. + + Args: + module_or_param_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping + regex patterns (as strings) to quantization configurations. The patterns are matched against the + fully qualified names of modules and parameters. If a pattern matches multiple items, the + configuration is applied to all matches. Use None as the config value to skip quantization for + matching items. 
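+            When a module FQN and one of its parameter FQNs both match, the module-level
+            config is expected to win: the module pass runs first, and parameters that are
+            already quantized are skipped by the parameter pass.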
+ + Example:: + + import torch.nn as nn + from collections import OrderedDict + from torchao.quantization.quant_api import ModuleOrParamFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig + + # Create a model + model = nn.Sequential( + nn.Linear(10, 20, bias=True), # Will be "0.weight" and "0.bias" + nn.Linear(20, 5, bias=True), # Will be "1.weight" and "1.bias" + ) + + # Configure different quantization for different parameters + config = ModuleOrParamFqnToConfig( + module_or_param_fqn_to_config=OrderedDict([ + (r"0\.weight", Int4WeightOnlyConfig()), # 4-bit for first layer weight + (r"1\.weight", Int8WeightOnlyConfig()), # 8-bit for second layer weight + (r".*\.bias", None), # Skip bias quantization + ]) + ) + + # Apply quantization + quantize_(model, config) + + Note: + - The order of patterns in the OrderedDict matters as the first matching pattern is applied + - Regex patterns allow for flexible matching (e.g., r".*\.weight" matches all weight parameters) + - Parameters that are already TorchAOBaseTensor instances are skipped to avoid double quantization + """ + + module_or_param_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( + default_factory=OrderedDict ) def __post_init__(self): @@ -2383,33 +2411,89 @@ def __post_init__(self): @property def module_fqn_to_config(self): + """Compatibility property to maintain interface consistency with ModuleFqnToConfig.""" return self.module_or_param_fqn_to_config def _param_fqn_to_config_handler( mod_containing_param: torch.nn.Module, fqn: str, config: ModuleOrParamFqnToConfig ): + """Apply parameter-specific quantization configurations based on fully qualified name pattern matching. + + This function processes parameters within a module and applies quantization configurations + when the parameter's fully qualified name matches patterns defined in the config. + + Args: + mod_containing_param (torch.nn.Module): The module containing parameters to be processed. + fqn (str): The fully qualified name of the module containing the parameters. + config (ModuleOrParamFqnToConfig): Configuration object containing regex patterns mapped + to quantization configurations. + + Returns: + torch.nn.Module: The modified module with quantized parameters. + + Note: + - Only processes top-level parameters (those directly accessible as module attributes) + - Skips parameters that are already TorchAOBaseTensor instances to avoid double quantization + - Uses the first matching pattern for each parameter + - Sets quantized parameters as non-differentiable (requires_grad=False) + + Raises: + NotImplementedError: If a configuration type doesn't have a registered parameter handler. + """ for name, param in list(mod_containing_param.named_parameters()): # check to see if top level param if name in dir(mod_containing_param): for pattern, param_config in config.module_or_param_fqn_to_config.items(): - if pattern == f"{fqn}.{name}" and not isinstance(param, TorchAOBaseTensor): + if re.search(pattern, f"{fqn}.{name}") and not isinstance( + param, TorchAOBaseTensor + ): param_config_type = type(param_config) if param_config_type in _QUANTIZE_CONFIG_PARAM_HANDLER: handler = _QUANTIZE_CONFIG_PARAM_HANDLER[param_config_type] new_param = handler(param, param_config) setattr(mod_containing_param, name, new_param) else: - raise NotImplementedError(f"Parameter quantization for {param_config_type} not supported currently!") + raise NotImplementedError( + f"Parameter quantization for {param_config_type} not supported currently!" 
+ ) return mod_containing_param def select_module_if_fqn_in_pattern(mod, fqn, config): + """Check if a module should be selected for quantization based on parameter FQN pattern matching. + + This function determines whether a module should be processed for parameter-level quantization + by checking if any of its top-level parameters match the patterns defined in the configuration. + + Args: + mod (torch.nn.Module): The module to check for parameter pattern matches. + fqn (str): The fully qualified name of the module. + config (ModuleOrParamFqnToConfig): Configuration object containing regex patterns for + parameter quantization. + + Returns: + bool: True if any of the module's parameters match patterns in the configuration, + False otherwise. + + Note: + - Only checks top-level parameters (those directly accessible as module attributes) + - Uses the first pattern match found to determine selection + - The function returns immediately upon finding the first match + + Example:: + + # Given a module with parameters "weight" and "bias" and FQN "layer1" + # and config patterns [".*\.weight", ".*\.bias"] + # This would return True because "layer1.weight" matches ".*\.weight" + """ for name, _ in mod.named_parameters(): - if "." not in name: # only want attribute parameters - if f"{fqn}.{name}" in config.module_or_param_fqn_to_config: - return True + if name in dir(mod): + for pattern in config.module_or_param_fqn_to_config: + if re.search(pattern, f"{fqn}.{name}"): + return True + return False torch.serialization.add_safe_globals( From f04e1a07a0a070ee7b9df9295d4278b773500afb Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 2 Oct 2025 19:51:28 -0700 Subject: [PATCH 18/47] ruff --- torchao/quantization/quant_api.py | 166 ++++++++++++++++-------------- 1 file changed, 88 insertions(+), 78 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 27e6b970cd..87d42e6171 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -28,10 +28,12 @@ Dict, List, Optional, - OrderedDict as OrderedDictType, Tuple, Union, ) +from typing import ( + OrderedDict as OrderedDictType, +) import torch import torch.nn as nn @@ -54,24 +56,24 @@ QDQLayout, SemiSparseLayout, TensorCoreTiledLayout, + UintxLayout, to_affine_quantized_floatx, to_affine_quantized_floatx_static, to_affine_quantized_intx, to_marlinqqq_quantized_intx, - UintxLayout, ) from torchao.dtypes.uintx.packed_linear_int8_dynamic_activation_intx_weight_layout import ( - make_packed_linear_int8_dynamic_activation_intx_weight_tensor, Target, + make_packed_linear_int8_dynamic_activation_intx_weight_tensor, ) from torchao.dtypes.utils import Layout from torchao.float8.config import e4m3_dtype, e5m2_dtype from torchao.float8.float8_linear import Float8Linear from torchao.float8.inference import ( - _check_hardware_support, - _normalize_granularity, Float8MMConfig, FP8Granularity, + _check_hardware_support, + _normalize_granularity, ) from torchao.quantization.linear_activation_weight_observed_tensor import ( LinearActivationWeightObservedTensor, @@ -103,14 +105,14 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( + TorchAOBaseTensor, _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, is_sm_at_least_90, - TorchAOBaseTensor, ) -from .autoquant import autoquant, AutoQuantizableLinearWeight +from .autoquant import AutoQuantizableLinearWeight, autoquant from .GPTQ import Int4WeightOnlyGPTQQuantizer from .granularity import Granularity, 
PerAxis, PerGroup, PerRow, PerTensor from .linear_activation_quantized_tensor import ( @@ -122,8 +124,8 @@ from .quant_primitives import ( _DTYPE_TO_QVALUE_BOUNDS, MappingType, - quantize_affine, ZeroPointDomain, + quantize_affine, ) from .subclass import QuantizedLinearWeightBase from .unified import Quantizer, TwoStepQuantizer @@ -776,28 +778,34 @@ def __post_init__(self): torch._C._log_api_usage_once( "torchao.quantization.Int8DynamicActivationIntxWeightConfig" ) - assert self.weight_dtype in [ - getattr(torch, f"int{b}") for b in range(1, 9) - ], f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" - assert isinstance( - self.weight_granularity, (PerAxis, PerGroup) - ), f"weight_granularity must be PerAxis or PerGroup, but got {self.weight_granularity}" + assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( + f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" + ) + assert isinstance(self.weight_granularity, (PerAxis, PerGroup)), ( + f"weight_granularity must be PerAxis or PerGroup, but got {self.weight_granularity}" + ) if isinstance(self.weight_granularity, PerAxis): - assert ( - self.weight_granularity.axis == 0 - ), f"axis must be 0, but got {self.weight_granularity.axis}" + assert self.weight_granularity.axis == 0, ( + f"axis must be 0, but got {self.weight_granularity.axis}" + ) assert self.weight_mapping_type in [ MappingType.ASYMMETRIC, MappingType.SYMMETRIC, MappingType.SYMMETRIC_NO_CLIPPING_ERR, - ], f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}" + ], ( + f"weight_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC or MappingType.SYMMETRIC_NO_CLIPPING_ERR, but got {self.weight_mapping_type}" + ) assert self.act_mapping_type in [ MappingType.ASYMMETRIC, MappingType.SYMMETRIC, - ], f"act_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.act_mapping_type}" + ], ( + f"act_mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.act_mapping_type}" + ) assert isinstance( self.layout, (PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout) - ), f"layout must be PackedLinearInt8DynamicActivationIntxWeightLayout or QDQLayout, but got {self.layout}" + ), ( + f"layout must be PackedLinearInt8DynamicActivationIntxWeightLayout or QDQLayout, but got {self.layout}" + ) if isinstance(self.layout, PackedLinearInt8DynamicActivationIntxWeightLayout): if self.layout.target in [Target.AUTO, Target.KLEIDIAI, Target.ATEN]: @@ -821,15 +829,15 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config): layout = config.layout intx_packing_format = config.intx_packing_format - assert ( - weight.dim() == 2 - ), f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}" + assert weight.dim() == 2, ( + f"Int8DynamicActivationIntxWeightConfig only works for 2-d Tensor, got: {weight.dim()}" + ) if isinstance(weight_granularity, PerGroup): group_size = weight_granularity.group_size elif isinstance(weight_granularity, PerAxis): - assert ( - weight_granularity.axis == 0 - ), f"axis must be 0 with PerAxis, but got {weight_granularity.axis}" + assert weight_granularity.axis == 0, ( + f"axis must be 0 with PerAxis, but got {weight_granularity.axis}" + ) group_size = weight.shape[-1] else: raise ValueError( @@ -915,9 +923,9 @@ def 
_int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config): elif isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout): # PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization # fused with the kernel and it should not be applied separately - assert ( - act_mapping_type == MappingType.ASYMMETRIC - ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC" + assert act_mapping_type == MappingType.ASYMMETRIC, ( + "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC" + ) data, scale, zero_point = weight.tensor_impl.get_plain() groups_per_row = weight.shape[-1] // group_size scale = scale.reshape(-1, groups_per_row) @@ -1218,16 +1226,16 @@ def _int4_weight_only_quantize_tensor(weight, config): ) # nonlocal zero_point_domain - assert ( - type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys() - ), f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" + assert type(layout) in LAYOUT_TO_ZERO_POINT_DOMAIN.keys(), ( + f"Only support layout: {LAYOUT_TO_ZERO_POINT_DOMAIN.keys()}" + ) if zero_point_domain == ZeroPointDomain.NONE: # the first value is the default one zero_point_domain = LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)][0] else: - assert ( - zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)] - ), f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" + assert zero_point_domain in LAYOUT_TO_ZERO_POINT_DOMAIN[type(layout)], ( + f"Layout only support {LAYOUT_TO_ZERO_POINT_DOMAIN[layout]}" + ) if zero_point_domain == ZeroPointDomain.INT and isinstance(layout, Int4XPULayout): zero_point_dtype = torch.int32 @@ -1242,9 +1250,9 @@ def _int4_weight_only_quantize_tensor(weight, config): # we should consider moving this logic somewhere else. if isinstance(layout, MarlinSparseLayout): mapping_type = MappingType.SYMMETRIC - assert ( - group_size == 128 or group_size == weight.shape[-1] - ), f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" + assert group_size == 128 or group_size == weight.shape[-1], ( + f"MarlinSparseLayout only supports 128 group size or per channel quantization, got {group_size}" + ) new_weight = to_affine_quantized_intx( weight, @@ -1304,9 +1312,9 @@ def _float8_dynamic_activation_int4_weight_transform( ) int4_packing_format = config.int4_packing_format - assert ( - int4_packing_format == "preshuffled" - ), f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}" + assert int4_packing_format == "preshuffled", ( + f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}" + ) weight = module.weight group_size = 128 block_size = tuple([1 for _ in range(weight.ndim - 1)] + [group_size]) @@ -1669,13 +1677,13 @@ def _input_activation_quant_func_fp8( """This function is used to quantize the input activation tensor for an aqt_float variant. If scale is not provided it will be dynamically calculate the scales otherwise it will use the provided scale. 
""" - assert ( - zero_point is None - ), "Zero point is not supported for dynamic FP8 quantization" + assert zero_point is None, ( + "Zero point is not supported for dynamic FP8 quantization" + ) if isinstance(activation_granularity, PerRow): - assert ( - x.dtype == torch.bfloat16 - ), "PerRow quantization only works for bfloat16 precision input activation" + assert x.dtype == torch.bfloat16, ( + "PerRow quantization only works for bfloat16 precision input activation" + ) block_size = get_block_size(x.shape, activation_granularity) if scale is None: @@ -1687,9 +1695,9 @@ def _input_activation_quant_func_fp8( _layout=Float8Layout(mm_config=None), # Config is stored on weight ) else: - assert isinstance( - activation_granularity, PerTensor - ), "Static quantization only supports PerTensor granularity" + assert isinstance(activation_granularity, PerTensor), ( + "Static quantization only supports PerTensor granularity" + ) activation = to_affine_quantized_floatx_static( input_float=x, block_size=block_size, @@ -1797,9 +1805,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): return weight if isinstance(weight_granularity, PerRow): - assert ( - weight.dtype == torch.bfloat16 - ), "PerRow quantization only works for bfloat16 precision input weight" + assert weight.dtype == torch.bfloat16, ( + "PerRow quantization only works for bfloat16 precision input weight" + ) if config.version == 1: warnings.warn( @@ -1852,9 +1860,9 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): def _float8_dynamic_activation_float8_weight_transform( module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig ): - assert ( - is_sm_at_least_89() or is_MI300() - ), "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + assert is_sm_at_least_89() or is_MI300(), ( + "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" + ) if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() @@ -1957,9 +1965,9 @@ def __post_init__(self): def _float8_static_activation_float8_weight_transform( module: torch.nn.Module, config: Float8StaticActivationFloat8WeightConfig ): - assert ( - is_sm_at_least_89() or is_MI300() - ), "Float8 static activation quantization is only supported on CUDA 8.9 and above" + assert is_sm_at_least_89() or is_MI300(), ( + "Float8 static activation quantization is only supported on CUDA 8.9 and above" + ) scale = config.scale activation_dtype = config.activation_dtype @@ -1971,9 +1979,9 @@ def _float8_static_activation_float8_weight_transform( weight = module.weight activation_granularity, weight_granularity = _normalize_granularity(granularity) - assert isinstance( - activation_granularity, PerTensor - ), "Static quantization only supports PerTensor granularity" + assert isinstance(activation_granularity, PerTensor), ( + "Static quantization only supports PerTensor granularity" + ) if not _fp8_mm_compat(weight): # TODO(future PR): this should really throw an exception instead of silently @@ -2168,20 +2176,22 @@ class IntxWeightOnlyConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.IntxWeightOnlyConfig") - assert self.weight_dtype in [ - getattr(torch, f"int{b}") for b in range(1, 9) - ], f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" - assert isinstance( - self.granularity, (PerAxis, PerGroup) - ), f"granularity must be PerAxis or PerGroup, but got 
{self.granularity}" + assert self.weight_dtype in [getattr(torch, f"int{b}") for b in range(1, 9)], ( + f"weight_dtype must be torch.intx, where 1 <= x <= 8, but got {self.weight_dtype}" + ) + assert isinstance(self.granularity, (PerAxis, PerGroup)), ( + f"granularity must be PerAxis or PerGroup, but got {self.granularity}" + ) if isinstance(self.granularity, PerAxis): - assert ( - self.granularity.axis == 0 - ), f"axis must be 0 with PerAxis, but got {self.granularity.axis}" + assert self.granularity.axis == 0, ( + f"axis must be 0 with PerAxis, but got {self.granularity.axis}" + ) assert self.mapping_type in [ MappingType.ASYMMETRIC, MappingType.SYMMETRIC, - ], f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" + ], ( + f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}" + ) def _intx_weight_only_quantize_tensor(weight, config): @@ -2192,15 +2202,15 @@ def _intx_weight_only_quantize_tensor(weight, config): layout = config.layout intx_packing_format = config.intx_packing_format - assert ( - weight.dim() == 2 - ), f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" + assert weight.dim() == 2, ( + f"IntxWeightOnlyConfig only works for 2-d Tensor, got: {weight.dim()}" + ) if isinstance(granularity, PerGroup): group_size = granularity.group_size elif isinstance(granularity, PerAxis): - assert ( - granularity.axis == 0 - ), f"axis must be 0 with PerAxis, but got {granularity.axis}" + assert granularity.axis == 0, ( + f"axis must be 0 with PerAxis, but got {granularity.axis}" + ) group_size = weight.shape[-1] else: raise ValueError(f"granularity must be PerGroup or PerAxis, got {granularity}") From d706e29e0892ebfa7bdce7c59eac0ca8920fe4f6 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 06:49:53 -0700 Subject: [PATCH 19/47] update --- test/prototype/test_parameter.py | 53 +++++++++++++++++++++++++++----- 1 file changed, 46 insertions(+), 7 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index db8550ba63..0526f245f7 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -10,6 +10,7 @@ from torch import nn from torch.testing._internal.common_utils import TestCase +from torchao.core.config import AOBaseConfig from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, ) @@ -18,7 +19,6 @@ PerRow, quantize_, ) -from torchao.core.config import AOBaseConfig from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor from torchao.utils import is_fbcode, is_sm_at_least_90 @@ -65,6 +65,7 @@ def test_quantize_param_fqn_exact(self): def test_quantize_param_and_module_fqn(self): from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + from torchao.quantization import PerTensor config = AutoConfig.from_pretrained( @@ -91,6 +92,7 @@ def test_quantize_param_and_module_fqn(self): def test_quantize_param_and_module_fqn_regex(self): from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + from torchao.quantization import PerTensor config = AutoConfig.from_pretrained( @@ -100,7 +102,7 @@ def test_quantize_param_and_module_fqn_regex(self): torch.randn(16, 128, config.hidden_size).cuda().bfloat16() quant_config = ModuleOrParamFqnToConfig( { - ".*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + "re:.*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( 
granularity=PerRow(), ), "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( @@ -116,30 +118,67 @@ def test_quantize_param_and_module_fqn_regex(self): assert isinstance(model.experts.gate_up_proj, Float8Tensor) assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) + assert model.shared_expert.gate_proj.weight.scale.numel() == 1 + + def test_quantize_modle_param_double_specified(self): + from transformers import AutoConfig + + from torchao.quantization import PerTensor + + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config + model = ( + nn.Sequential( + nn.Linear(128, 128), + ) + .to(torch.bfloat16) + .cuda() + ) + input_tensor = torch.randn(16, 128).cuda().bfloat16() + quant_config = ModuleOrParamFqnToConfig( + { + "0.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + "0": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + model(input_tensor) + + assert isinstance(model[0].weight, Float8Tensor) + assert model[0].weight.scale.numel() == 1 def test_unsupported_param_config_raises_not_implemented_error(self): """Test that using an unsupported parameter config raises NotImplementedError.""" from dataclasses import dataclass - + # Create a custom config that doesn't have a registered parameter handler @dataclass class UnsupportedParamConfig(AOBaseConfig): some_value: int = 42 - + # Create a simple model model = nn.Linear(10, 5).cuda().bfloat16() - + # Create config with unsupported parameter handler quant_config = ModuleOrParamFqnToConfig( { "weight": UnsupportedParamConfig(), } ) - + # This should raise NotImplementedError with self.assertRaises(NotImplementedError) as context: quantize_(model, quant_config) - + # Check that the error message contains the expected text self.assertIn("Parameter quantization for", str(context.exception)) self.assertIn("not supported currently", str(context.exception)) From 9d3cfe48794b4e88260bfc6c700d72aebdb5a37b Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 07:04:44 -0700 Subject: [PATCH 20/47] update --- torchao/quantization/quant_api.py | 68 ++++++++++++------------ torchao/quantization/transform_module.py | 22 +++----- 2 files changed, 42 insertions(+), 48 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 87d42e6171..ab07bda689 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -97,7 +97,7 @@ ) from torchao.quantization.transform_module import ( _QUANTIZE_CONFIG_HANDLER, - _QUANTIZE_CONFIG_PARAM_HANDLER, + _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER, register_quantize_module_handler, register_quantize_param_handler, ) @@ -524,13 +524,13 @@ def quantize_( filter_fn = _is_linear if filter_fn is None else filter_fn if isinstance(config, ModuleOrParamFqnToConfig): - _replace_with_custom_fn_if_matches_filter_with_name( - model, - _module_fqn_to_config_handler, - filter_fn, - device=device, - extra_args=(config,), - ) + # _replace_with_custom_fn_if_matches_filter_with_name( + # model, + # _module_fqn_to_config_handler, + # filter_fn, + # device=device, + # extra_args=(config,), + # ) _replace_with_custom_fn_if_matches_filter_with_name( model, _param_fqn_to_config_handler, @@ -2398,8 +2398,7 @@ class ModuleOrParamFqnToConfig(AOBaseConfig): config = ModuleOrParamFqnToConfig( module_or_param_fqn_to_config=OrderedDict([ (r"0\.weight", Int4WeightOnlyConfig()), # 4-bit for first layer 
weight - (r"1\.weight", Int8WeightOnlyConfig()), # 8-bit for second layer weight - (r".*\.bias", None), # Skip bias quantization + (r"re:1\.weight", Int8WeightOnlyConfig()), # 8-bit for second layer weight, matching using regex ]) ) @@ -2452,15 +2451,20 @@ def _param_fqn_to_config_handler( NotImplementedError: If a configuration type doesn't have a registered parameter handler. """ for name, param in list(mod_containing_param.named_parameters()): - # check to see if top level param - if name in dir(mod_containing_param): + # check to see if top level param and hasn't been modified previously by module flow + if name in dir(mod_containing_param) and not isinstance( + param, TorchAOBaseTensor + ): for pattern, param_config in config.module_or_param_fqn_to_config.items(): - if re.search(pattern, f"{fqn}.{name}") and not isinstance( - param, TorchAOBaseTensor + full_param_fqn = f"{fqn}.{name}" + if (pattern == full_param_fqn) or ( + pattern[:3] == "re:" and re.search(pattern[3:], f"{fqn}.{name}") ): param_config_type = type(param_config) - if param_config_type in _QUANTIZE_CONFIG_PARAM_HANDLER: - handler = _QUANTIZE_CONFIG_PARAM_HANDLER[param_config_type] + if param_config_type in _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER: + handler = _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER[ + param_config_type + ] new_param = handler(param, param_config) setattr(mod_containing_param, name, new_param) else: @@ -2471,37 +2475,33 @@ def _param_fqn_to_config_handler( return mod_containing_param -def select_module_if_fqn_in_pattern(mod, fqn, config): - """Check if a module should be selected for quantization based on parameter FQN pattern matching. +def select_module_if_fqn_in_pattern( + mod: nn.Module, fqn: str, config: ModuleOrParamFqnToConfig +): + """Check if a module should be selected for quantization. This function determines whether a module should be processed for parameter-level quantization - by checking if any of its top-level parameters match the patterns defined in the configuration. + by checking if any of its top-level parameters match the patterns defined in ModuleOrParamFqnToConfig. + + We only check top-level parameters (those directly accessible as module attributes). Args: mod (torch.nn.Module): The module to check for parameter pattern matches. fqn (str): The fully qualified name of the module. - config (ModuleOrParamFqnToConfig): Configuration object containing regex patterns for + config (ModuleOrParamFqnToConfig): Configuration object containing regex patterns or raw FQNs for parameter quantization. Returns: bool: True if any of the module's parameters match patterns in the configuration, False otherwise. 
- - Note: - - Only checks top-level parameters (those directly accessible as module attributes) - - Uses the first pattern match found to determine selection - - The function returns immediately upon finding the first match - - Example:: - - # Given a module with parameters "weight" and "bias" and FQN "layer1" - # and config patterns [".*\.weight", ".*\.bias"] - # This would return True because "layer1.weight" matches ".*\.weight" """ - for name, _ in mod.named_parameters(): - if name in dir(mod): + for name, param in mod.named_parameters(): + if name in dir(mod) and not isinstance(param, TorchAOBaseTensor): for pattern in config.module_or_param_fqn_to_config: - if re.search(pattern, f"{fqn}.{name}"): + full_param_fqn = f"{fqn}.{name}" + if (pattern == full_param_fqn) or ( + pattern[:3] == "re:" and re.search(pattern[3:], f"{fqn}.{name}") + ): return True return False diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index d59790d3ad..d06610f54c 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -15,11 +15,12 @@ Callable[[torch.nn.Module, AOBaseConfig], torch.nn.Module], ] = {} -_QUANTIZE_CONFIG_PARAM_HANDLER: Dict[ +_QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER: Dict[ Type[AOBaseConfig], Callable[[torch.nn.Parameter, AOBaseConfig], torch.nn.Parameter], ] = {} + def register_quantize_module_handler(config_type): """ A decorator to register a transform function to map from a workflow @@ -55,31 +56,24 @@ def decorator(func): return decorator -def register_quantize_param_handler(config_type): + +def register_quantize_tensor_handler(config_type): """ A decorator to register a transform function to map from a workflow configuration (child of `AOBaseConfig`) to a function that transforms - a `torch.nn.Parameter` according to the specified configuration. - - For example:: - - # user facing code - class WorkflowFooConfig(AOBaseConfig): ... - # configuration for workflow `Foo` is defined here - bar = 'baz' + a `torch.Tensor` according to the specified configuration. - # non user facing code - @resgister_quantize + The wrapped function will be extended to support `torch.nn.Parameter` as well. 
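+
+    For example (a sketch only; ``WorkflowFooConfig`` is the same placeholder used in
+    the other docstrings in this file, not a real config)::
+
+        # user facing code
+        class WorkflowFooConfig(AOBaseConfig): ...
+
+        # non user facing code
+        @register_quantize_tensor_handler(WorkflowFooConfig)
+        def _workflow_foo_quantize_tensor(tensor, config):
+            # return the quantized tensor; if the input is an nn.Parameter, the
+            # registered wrapper re-wraps the result into an nn.Parameter
+            ...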
""" + @functools.wraps(config_type) def decorator(func): - def func_supporting_param(tensor_or_param, config): if type(tensor_or_param) is torch.nn.Parameter: return torch.nn.Parameter(func(tensor_or_param, config)) return func(tensor_or_param, config) - _QUANTIZE_CONFIG_PARAM_HANDLER[config_type] = func_supporting_param + _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER[config_type] = func_supporting_param return func # needed to make the functions usable externally return decorator From 768eb603f5feede51ce410866d20117be5018561 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 07:14:10 -0700 Subject: [PATCH 21/47] update --- test/prototype/test_parameter.py | 30 ++++++++---------------------- torchao/quantization/quant_api.py | 18 +++++++++--------- 2 files changed, 17 insertions(+), 31 deletions(-) diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py index 0526f245f7..b4fc5e808f 100644 --- a/test/prototype/test_parameter.py +++ b/test/prototype/test_parameter.py @@ -13,6 +13,7 @@ from torchao.core.config import AOBaseConfig from torchao.quantization import ( Float8DynamicActivationFloat8WeightConfig, + PerTensor, ) from torchao.quantization.quant_api import ( ModuleOrParamFqnToConfig, @@ -45,7 +46,6 @@ def test_quantize_param_fqn_exact(self): "unsloth/Llama-4-Scout-17B-16E-Instruct" ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - torch.randn(16, 128, config.hidden_size).cuda().bfloat16() quant_config = ModuleOrParamFqnToConfig( { @@ -66,13 +66,10 @@ def test_quantize_param_and_module_fqn(self): from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - from torchao.quantization import PerTensor - config = AutoConfig.from_pretrained( "unsloth/Llama-4-Scout-17B-16E-Instruct" ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - torch.randn(16, 128, config.hidden_size).cuda().bfloat16() quant_config = ModuleOrParamFqnToConfig( { "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( @@ -93,13 +90,10 @@ def test_quantize_param_and_module_fqn_regex(self): from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - from torchao.quantization import PerTensor - config = AutoConfig.from_pretrained( "unsloth/Llama-4-Scout-17B-16E-Instruct" ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - torch.randn(16, 128, config.hidden_size).cuda().bfloat16() quant_config = ModuleOrParamFqnToConfig( { "re:.*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( @@ -121,13 +115,6 @@ def test_quantize_param_and_module_fqn_regex(self): assert model.shared_expert.gate_proj.weight.scale.numel() == 1 def test_quantize_modle_param_double_specified(self): - from transformers import AutoConfig - - from torchao.quantization import PerTensor - - config = AutoConfig.from_pretrained( - "unsloth/Llama-4-Scout-17B-16E-Instruct" - ).text_config model = ( nn.Sequential( nn.Linear(128, 128), @@ -135,15 +122,15 @@ def test_quantize_modle_param_double_specified(self): .to(torch.bfloat16) .cuda() ) - input_tensor = torch.randn(16, 128).cuda().bfloat16() quant_config = ModuleOrParamFqnToConfig( { - "0.weight": Float8DynamicActivationFloat8WeightConfig( - granularity=PerTensor(), - ), + # only this config should be applied, as module fqn takes precedence "0": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), + "0.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), } ) @@ -151,10 
+138,9 @@ def test_quantize_modle_param_double_specified(self): model, quant_config, ) - model(input_tensor) assert isinstance(model[0].weight, Float8Tensor) - assert model[0].weight.scale.numel() == 1 + assert model[0].weight.scale.numel() == 128 def test_unsupported_param_config_raises_not_implemented_error(self): """Test that using an unsupported parameter config raises NotImplementedError.""" @@ -166,12 +152,12 @@ class UnsupportedParamConfig(AOBaseConfig): some_value: int = 42 # Create a simple model - model = nn.Linear(10, 5).cuda().bfloat16() + model = nn.Sequential(nn.Linear(10, 5).cuda().bfloat16()) # Create config with unsupported parameter handler quant_config = ModuleOrParamFqnToConfig( { - "weight": UnsupportedParamConfig(), + "0.weight": UnsupportedParamConfig(), } ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index ab07bda689..87716f1339 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -99,7 +99,7 @@ _QUANTIZE_CONFIG_HANDLER, _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER, register_quantize_module_handler, - register_quantize_param_handler, + register_quantize_tensor_handler, ) from torchao.quantization.weight_tensor_linear_activation_quantization import ( to_weight_tensor_with_linear_activation_quantization_metadata, @@ -524,13 +524,13 @@ def quantize_( filter_fn = _is_linear if filter_fn is None else filter_fn if isinstance(config, ModuleOrParamFqnToConfig): - # _replace_with_custom_fn_if_matches_filter_with_name( - # model, - # _module_fqn_to_config_handler, - # filter_fn, - # device=device, - # extra_args=(config,), - # ) + _replace_with_custom_fn_if_matches_filter_with_name( + model, + _module_fqn_to_config_handler, + filter_fn, + device=device, + extra_args=(config,), + ) _replace_with_custom_fn_if_matches_filter_with_name( model, _param_fqn_to_config_handler, @@ -1785,7 +1785,7 @@ def __post_init__(self): ) -@register_quantize_param_handler(Float8DynamicActivationFloat8WeightConfig) +@register_quantize_tensor_handler(Float8DynamicActivationFloat8WeightConfig) def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype From d55081a9b9dc6e3b25cbeb9906fc8e6b2cead767 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 07:20:25 -0700 Subject: [PATCH 22/47] update type signature --- torchao/quantization/quant_api.py | 78 ++++++++++++++++--------------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index c3d93a9bc5..adbe3f8c40 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2398,44 +2398,6 @@ def _fpx_weight_only_transform( return module -@dataclass -class ModuleFqnToConfig(AOBaseConfig): - """Per module configurations for torchao quantize_ API - - Args: - `module_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from - the fully qualified name of module to the AOBaseConfig that we want to apply to the module. - Also has a special key: "_default", if "_default" is present in the dictionary, - the config for "_default" will be applied to all the remaining modules that does not have - per module configuration specified. 
- """ - - module_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field( - default_factory=dict - ) - - def __post_init__(self): - torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig") - - -def _module_fqn_to_config_handler( - module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig -): - c = None - if module_fqn in config.module_fqn_to_config: - # Maybe: we can add module type specific config in the future, in needed - c = config.module_fqn_to_config[module_fqn] - else: - # fallback to use default if no module specific config is provided - c = config.module_fqn_to_config.get("_default", None) - - if c is not None: - handler = _QUANTIZE_CONFIG_HANDLER[type(c)] - return handler(module, c) - - return module - - @dataclass class ModuleOrParamFqnToConfig(AOBaseConfig): """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs). @@ -2575,6 +2537,46 @@ def select_module_if_fqn_in_pattern( return False +@dataclass +class ModuleFqnToConfig(AOBaseConfig): + """Per module configurations for torchao quantize_ API + + Args: + `module_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from + the fully qualified name of module to the AOBaseConfig that we want to apply to the module. + Also has a special key: "_default", if "_default" is present in the dictionary, + the config for "_default" will be applied to all the remaining modules that does not have + per module configuration specified. + """ + + module_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field( + default_factory=dict + ) + + def __post_init__(self): + torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig") + + +def _module_fqn_to_config_handler( + module: torch.nn.Module, + module_fqn: str, + config: Union[ModuleFqnToConfig, ModuleOrParamFqnToConfig], +): + c = None + if module_fqn in config.module_fqn_to_config: + # Maybe: we can add module type specific config in the future, in needed + c = config.module_fqn_to_config[module_fqn] + else: + # fallback to use default if no module specific config is provided + c = config.module_fqn_to_config.get("_default", None) + + if c is not None: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c) + + return module + + torch.serialization.add_safe_globals( [ _int8_asymm_per_token_quant, From 41fdbe9f0a2e2723b9cb88222c1dbadbd5f2795c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 07:29:51 -0700 Subject: [PATCH 23/47] move tests to test_quant_api --- test/prototype/test_parameter.py | 175 ---------------------------- test/quantization/test_quant_api.py | 154 ++++++++++++++++++++++++ 2 files changed, 154 insertions(+), 175 deletions(-) delete mode 100644 test/prototype/test_parameter.py diff --git a/test/prototype/test_parameter.py b/test/prototype/test_parameter.py deleted file mode 100644 index b4fc5e808f..0000000000 --- a/test/prototype/test_parameter.py +++ /dev/null @@ -1,175 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD 3-Clause license found in the -# LICENSE file in the root directory of this source tree. 
-import logging -import unittest - -import torch -from torch import nn -from torch.testing._internal.common_utils import TestCase - -from torchao.core.config import AOBaseConfig -from torchao.quantization import ( - Float8DynamicActivationFloat8WeightConfig, - PerTensor, -) -from torchao.quantization.quant_api import ( - ModuleOrParamFqnToConfig, - PerRow, - quantize_, -) -from torchao.quantization.quantize_.workflows.float8.float8_tensor import Float8Tensor -from torchao.utils import is_fbcode, is_sm_at_least_90 - -if not is_fbcode(): - pass - -logging.basicConfig( - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", level=logging.INFO -) - - -@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") -@unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") -@unittest.skipIf( - is_fbcode(), - "Skipping the test in fbcode for now, not sure how to download from transformers", -) -class TestQuantizeFQNParam(TestCase): - def test_quantize_param_fqn_exact(self): - from transformers import AutoConfig - from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - - config = AutoConfig.from_pretrained( - "unsloth/Llama-4-Scout-17B-16E-Instruct" - ).text_config - model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - - quant_config = ModuleOrParamFqnToConfig( - { - "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - ), - } - ) - - quantize_( - model, - quant_config, - ) - - assert isinstance(model.experts.gate_up_proj, Float8Tensor) - - def test_quantize_param_and_module_fqn(self): - from transformers import AutoConfig - from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - - config = AutoConfig.from_pretrained( - "unsloth/Llama-4-Scout-17B-16E-Instruct" - ).text_config - model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - quant_config = ModuleOrParamFqnToConfig( - { - "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - ), - "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( - granularity=PerTensor(), - ), - } - ) - - quantize_( - model, - quant_config, - ) - - def test_quantize_param_and_module_fqn_regex(self): - from transformers import AutoConfig - from transformers.models.llama4.modeling_llama4 import Llama4TextMoe - - config = AutoConfig.from_pretrained( - "unsloth/Llama-4-Scout-17B-16E-Instruct" - ).text_config - model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - quant_config = ModuleOrParamFqnToConfig( - { - "re:.*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - ), - "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( - granularity=PerTensor(), - ), - } - ) - - quantize_( - model, - quant_config, - ) - - assert isinstance(model.experts.gate_up_proj, Float8Tensor) - assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) - assert model.shared_expert.gate_proj.weight.scale.numel() == 1 - - def test_quantize_modle_param_double_specified(self): - model = ( - nn.Sequential( - nn.Linear(128, 128), - ) - .to(torch.bfloat16) - .cuda() - ) - quant_config = ModuleOrParamFqnToConfig( - { - # only this config should be applied, as module fqn takes precedence - "0": Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - ), - "0.weight": Float8DynamicActivationFloat8WeightConfig( - granularity=PerTensor(), - ), - } - ) - - quantize_( - model, - quant_config, - ) - - assert isinstance(model[0].weight, Float8Tensor) - assert 
model[0].weight.scale.numel() == 128 - - def test_unsupported_param_config_raises_not_implemented_error(self): - """Test that using an unsupported parameter config raises NotImplementedError.""" - from dataclasses import dataclass - - # Create a custom config that doesn't have a registered parameter handler - @dataclass - class UnsupportedParamConfig(AOBaseConfig): - some_value: int = 42 - - # Create a simple model - model = nn.Sequential(nn.Linear(10, 5).cuda().bfloat16()) - - # Create config with unsupported parameter handler - quant_config = ModuleOrParamFqnToConfig( - { - "0.weight": UnsupportedParamConfig(), - } - ) - - # This should raise NotImplementedError - with self.assertRaises(NotImplementedError) as context: - quantize_(model, quant_config) - - # Check that the error message contains the expected text - self.assertIn("Parameter quantization for", str(context.exception)) - self.assertIn("not supported currently", str(context.exception)) - self.assertIn("UnsupportedParamConfig", str(context.exception)) - - -if __name__ == "__main__": - unittest.main() diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index b5ea7bf09a..e2ff8403a2 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -51,6 +51,9 @@ Int8WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, + ModuleOrParamFqnToConfig, + PerRow, + PerTensor, Quantizer, TwoStepQuantizer, UIntXWeightOnlyConfig, @@ -808,5 +811,156 @@ def test_config_deprecation(self): common_utils.instantiate_parametrized_tests(TestQuantFlow) +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") +class TestModuleOrParamFqnToConfig(TestCase): + def test_quantize_param_fqn_exact(self): + from transformers import AutoConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + + quant_config = ModuleOrParamFqnToConfig( + { + "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + # Note: Need to import Float8Tensor from the correct location + from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( + Float8Tensor, + ) + + assert isinstance(model.experts.gate_up_proj, Float8Tensor) + + def test_quantize_param_and_module_fqn(self): + from transformers import AutoConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + quant_config = ModuleOrParamFqnToConfig( + { + "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + def test_quantize_param_and_module_fqn_regex(self): + from transformers import AutoConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + quant_config = ModuleOrParamFqnToConfig( + { + "re:.*gate_up_proj": 
Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( + Float8Tensor, + ) + + assert isinstance(model.experts.gate_up_proj, Float8Tensor) + assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) + assert model.shared_expert.gate_proj.weight.scale.numel() == 1 + + def test_quantize_modle_param_double_specified(self): + model = ( + torch.nn.Sequential( + torch.nn.Linear(128, 128), + ) + .to(torch.bfloat16) + .cuda() + ) + quant_config = ModuleOrParamFqnToConfig( + { + # only this config should be applied, as module fqn takes precedence + "0": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "0.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( + Float8Tensor, + ) + + assert isinstance(model[0].weight, Float8Tensor) + assert model[0].weight.scale.numel() == 128 + + def test_unsupported_param_config_raises_not_implemented_error(self): + """Test that using an unsupported parameter config raises NotImplementedError.""" + from dataclasses import dataclass + + from torchao.core.config import AOBaseConfig + + # Create a custom config that doesn't have a registered parameter handler + @dataclass + class UnsupportedParamConfig(AOBaseConfig): + some_value: int = 42 + + # Create a simple model + model = torch.nn.Sequential(torch.nn.Linear(10, 5).cuda().bfloat16()) + + # Create config with unsupported parameter handler + quant_config = ModuleOrParamFqnToConfig( + { + "0.weight": UnsupportedParamConfig(), + } + ) + + # This should raise NotImplementedError + with self.assertRaises(NotImplementedError) as context: + quantize_(model, quant_config) + + # Check that the error message contains the expected text + self.assertIn("Parameter quantization for", str(context.exception)) + self.assertIn("not supported currently", str(context.exception)) + self.assertIn("UnsupportedParamConfig", str(context.exception)) + + if __name__ == "__main__": unittest.main() From aa036b9110808d0dacf71a7a6e020cde8e81b446 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 08:07:55 -0700 Subject: [PATCH 24/47] only apply to non-transformed modules --- test/quantization/test_quant_api.py | 3 +- torchao/quantization/quant_api.py | 53 +++++++++++++----------- torchao/quantization/transform_module.py | 4 +- 3 files changed, 34 insertions(+), 26 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index e2ff8403a2..442a64dd54 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -903,6 +903,7 @@ def test_quantize_modle_param_double_specified(self): model = ( torch.nn.Sequential( torch.nn.Linear(128, 128), + torch.nn.Linear(128, 128), ) .to(torch.bfloat16) .cuda() @@ -913,7 +914,7 @@ def test_quantize_modle_param_double_specified(self): "0": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), - "0.weight": Float8DynamicActivationFloat8WeightConfig( + "re:.*weight": Float8DynamicActivationFloat8WeightConfig( granularity=PerTensor(), ), } diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index adbe3f8c40..8c207f903a 100644 --- 
a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -2458,9 +2458,7 @@ def _param_fqn_to_config_handler(
     mod_containing_param: torch.nn.Module, fqn: str, config: ModuleOrParamFqnToConfig
 ):
-    """Apply parameter-specific quantization configurations based on fully qualified name pattern matching.
-
-    This function processes parameters within a module and applies quantization configurations
+    """This function processes parameters within a module and applies quantization configurations
     when the parameter's fully qualified name matches patterns defined in the config.
 
     Args:
@@ -2481,27 +2479,34 @@ def _param_fqn_to_config_handler(
     Raises:
         NotImplementedError: If a configuration type doesn't have a registered parameter handler.
     """
-    for name, param in list(mod_containing_param.named_parameters()):
-        # check to see if top level param and hasn't been modified previously by module flow
-        if name in dir(mod_containing_param) and not isinstance(
-            param, TorchAOBaseTensor
-        ):
-            for pattern, param_config in config.module_or_param_fqn_to_config.items():
-                full_param_fqn = f"{fqn}.{name}"
-                if (pattern == full_param_fqn) or (
-                    pattern[:3] == "re:" and re.search(pattern[3:], f"{fqn}.{name}")
-                ):
-                    param_config_type = type(param_config)
-                    if param_config_type in _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER:
-                        handler = _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER[
-                            param_config_type
-                        ]
-                        new_param = handler(param, param_config)
-                        setattr(mod_containing_param, name, new_param)
-                    else:
-                        raise NotImplementedError(
-                            f"Parameter quantization for {param_config_type} not supported currently!"
-                        )
+    top_level_named_parameters_list = [
+        (name, param)
+        for name, param in mod_containing_param.named_parameters()
+        if name in dir(mod_containing_param)
+    ]
+
+    # return if modified previously by module flow
+    if any(
+        isinstance(param, TorchAOBaseTensor)
+        for _, param in top_level_named_parameters_list
+    ):
+        return mod_containing_param
+
+    for name, param in top_level_named_parameters_list:
+        for pattern, param_config in config.module_or_param_fqn_to_config.items():
+            full_param_fqn = f"{fqn}.{name}"
+            if (pattern == full_param_fqn) or (
+                pattern[:3] == "re:" and re.search(pattern[3:], f"{fqn}.{name}")
+            ):
+                param_config_type = type(param_config)
+                if param_config_type in _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER:
+                    handler = _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER[param_config_type]
+                    new_param = handler(param, param_config)
+                    setattr(mod_containing_param, name, new_param)
+                else:
+                    raise NotImplementedError(
+                        f"Parameter quantization for {param_config_type} not supported currently!"
+ ) return mod_containing_param diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index d06610f54c..cb27b4c753 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -70,7 +70,9 @@ def register_quantize_tensor_handler(config_type): def decorator(func): def func_supporting_param(tensor_or_param, config): if type(tensor_or_param) is torch.nn.Parameter: - return torch.nn.Parameter(func(tensor_or_param, config)) + return torch.nn.Parameter( + func(tensor_or_param, config), requires_grad=False + ) return func(tensor_or_param, config) _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER[config_type] = func_supporting_param From 9836c31e442abeb69daa3ed1248ca01b90644be3 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 08:20:20 -0700 Subject: [PATCH 25/47] update docstring --- torchao/quantization/quant_api.py | 33 ++++++++----------------------- 1 file changed, 8 insertions(+), 25 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 8c207f903a..96839151c9 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2408,37 +2408,20 @@ class ModuleOrParamFqnToConfig(AOBaseConfig): Args: module_or_param_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping - regex patterns (as strings) to quantization configurations. The patterns are matched against the - fully qualified names of modules and parameters. If a pattern matches multiple items, the - configuration is applied to all matches. Use None as the config value to skip quantization for - matching items. + regex patterns (as strings) to quantization configurations. - Example:: - - import torch.nn as nn - from collections import OrderedDict - from torchao.quantization.quant_api import ModuleOrParamFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig + The patterns can be one of the follows: + (1). fully qualified name (fqn) of module or param + (2). regex of fully qualified name (in python `re` module regex format) or + (3). "_default" - # Create a model - model = nn.Sequential( - nn.Linear(10, 20, bias=True), # Will be "0.weight" and "0.bias" - nn.Linear(20, 5, bias=True), # Will be "1.weight" and "1.bias" - ) - - # Configure different quantization for different parameters - config = ModuleOrParamFqnToConfig( - module_or_param_fqn_to_config=OrderedDict([ - (r"0\.weight", Int4WeightOnlyConfig()), # 4-bit for first layer weight - (r"re:1\.weight", Int8WeightOnlyConfig()), # 8-bit for second layer weight, matching using regex - ]) - ) + When passed this config, `quantize_` will first try to replace all modules in the model, matching the logic of ModuleFqnToConfig. + Then, quantize_ will attempt to replace any parameters specified by the fqn or that match regexs, ignoring modules that have already been transformed by the previous flow (Modules with an existing AOBaseTensor attached) - # Apply quantization - quantize_(model, config) + "_default" is ignored for parameter replacement. 
Note: - The order of patterns in the OrderedDict matters as the first matching pattern is applied - - Regex patterns allow for flexible matching (e.g., r".*\.weight" matches all weight parameters) - Parameters that are already TorchAOBaseTensor instances are skipped to avoid double quantization """ From 617b1f05dcf87e4cd437d2f67342db102b9c3e39 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 08:22:23 -0700 Subject: [PATCH 26/47] rename func --- torchao/quantization/quant_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 96839151c9..e4c7ec0fb3 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -538,7 +538,7 @@ def quantize_( _replace_with_custom_fn_if_matches_filter_with_name( model, _param_fqn_to_config_handler, - partial(select_module_if_fqn_in_pattern, config=config), + partial(select_module_if_top_level_params_match_pattern, config=config), device=device, extra_args=(config,), ) @@ -2494,7 +2494,7 @@ def _param_fqn_to_config_handler( return mod_containing_param -def select_module_if_fqn_in_pattern( +def select_module_if_top_level_params_match_pattern( mod: nn.Module, fqn: str, config: ModuleOrParamFqnToConfig ): """Check if a module should be selected for quantization. From c887989a1d4ef1f75df4eaeffa5f4dd64faf7b2d Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 08:37:56 -0700 Subject: [PATCH 27/47] add to top level API --- torchao/quantization/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index c8774e9426..e5154ad54e 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -61,6 +61,7 @@ Int8WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, + ModuleOrParamFqnToConfig, PlainLayout, TensorCoreTiledLayout, UIntXWeightOnlyConfig, @@ -162,6 +163,7 @@ "GemliteUIntXWeightOnlyConfig", "AOPerModuleConfig", "ModuleFqnToConfig", + "ModuleOrParamFqnToConfig", # tensor subclasses "Int4Tensor", "Int4PlainInt32Tensor", From 20dbf11747863a6151a10dd674073bed86ca9e7a Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 09:28:17 -0700 Subject: [PATCH 28/47] remove ModuleFqnToConfig --- torchao/quantization/quant_api.py | 32 +++---------------------------- 1 file changed, 3 insertions(+), 29 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index e4c7ec0fb3..1a798f5a99 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -25,7 +25,6 @@ from typing import ( Any, Callable, - Dict, List, Optional, Tuple, @@ -543,15 +542,6 @@ def quantize_( extra_args=(config,), ) return - if isinstance(config, ModuleFqnToConfig): - _replace_with_custom_fn_if_matches_filter_with_name( - model, - _module_fqn_to_config_handler, - filter_fn, - device=device, - extra_args=(config,), - ) - return if isinstance(config, AOBaseConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] # for each linear in the model, apply the transform if filtering passes @@ -2525,30 +2515,14 @@ def select_module_if_top_level_params_match_pattern( return False -@dataclass -class ModuleFqnToConfig(AOBaseConfig): - """Per module configurations for torchao quantize_ API - - Args: - `module_fqn_to_config`: Dict[str, Optional[AOBaseConfig]]: a dictionary from - the fully qualified name of module to the AOBaseConfig that we want to apply to the module. 
- Also has a special key: "_default", if "_default" is present in the dictionary, - the config for "_default" will be applied to all the remaining modules that does not have - per module configuration specified. - """ - - module_fqn_to_config: Dict[str, Optional[AOBaseConfig]] = field( - default_factory=dict - ) - - def __post_init__(self): - torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig") +# to maintain BC +ModuleFqnToConfig = ModuleOrParamFqnToConfig def _module_fqn_to_config_handler( module: torch.nn.Module, module_fqn: str, - config: Union[ModuleFqnToConfig, ModuleOrParamFqnToConfig], + config: ModuleOrParamFqnToConfig, ): c = None if module_fqn in config.module_fqn_to_config: From a90a8826a77593849a1cdd4bae104559fcaba47e Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Fri, 3 Oct 2025 09:29:46 -0700 Subject: [PATCH 29/47] update --- torchao/quantization/quant_api.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 1a798f5a99..a2e1e12843 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2428,6 +2428,10 @@ def module_fqn_to_config(self): return self.module_or_param_fqn_to_config +# maintain BC +ModuleFqnToConfig = ModuleOrParamFqnToConfig + + def _param_fqn_to_config_handler( mod_containing_param: torch.nn.Module, fqn: str, config: ModuleOrParamFqnToConfig ): @@ -2515,10 +2519,6 @@ def select_module_if_top_level_params_match_pattern( return False -# to maintain BC -ModuleFqnToConfig = ModuleOrParamFqnToConfig - - def _module_fqn_to_config_handler( module: torch.nn.Module, module_fqn: str, From ba53e66ac1820d8d1579e84ff5dd74bbcffd8511 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 6 Oct 2025 09:31:30 -0700 Subject: [PATCH 30/47] update logic --- torchao/quantization/quant_api.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index db9302a227..88d6413d36 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2475,7 +2475,7 @@ def _param_fqn_to_config_handler( for pattern, param_config in config.module_or_param_fqn_to_config.items(): full_param_fqn = f"{fqn}.{name}" if (pattern == full_param_fqn) or ( - pattern[:3] == "re:" and re.search(pattern[3:], f"{fqn}.{name}") + pattern.startswith("re:") and re.fullmatch(pattern[3:], f"{fqn}.{name}") ): param_config_type = type(param_config) if param_config_type in _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER: @@ -2515,7 +2515,7 @@ def select_module_if_top_level_params_match_pattern( for pattern in config.module_or_param_fqn_to_config: full_param_fqn = f"{fqn}.{name}" if (pattern == full_param_fqn) or ( - pattern[:3] == "re:" and re.search(pattern[3:], f"{fqn}.{name}") + pattern.startswith("re:") and re.fullmatch(pattern[3:], f"{fqn}.{name}") ): return True return False @@ -2527,7 +2527,7 @@ def _module_fqn_to_config_handler( config: ModuleOrParamFqnToConfig, ): c = None - if module_fqn in config.module_fqn_to_config: + if module_fqn in config.module_fqn_to_config and not maybe_module_fqn_pattern.startswith("re:"): # Maybe: we can add module type specific config in the future, in needed c = config.module_fqn_to_config[module_fqn] else: From 45ba181976832c3473efee3ebaf93d7e26dd8455 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 6 Oct 2025 09:36:52 -0700 Subject: [PATCH 31/47] fix ruff --- torchao/quantization/quant_api.py | 9 ++++++--- 1 file changed, 6 
insertions(+), 3 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 88d6413d36..01de2fb65b 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -21,6 +21,7 @@
 import warnings
 from collections import OrderedDict
 from dataclasses import dataclass, field
+from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
 from typing import OrderedDict as OrderedDictType
@@ -2417,7 +2418,7 @@ class ModuleOrParamFqnToConfig(AOBaseConfig):
         - Parameters that are already TorchAOBaseTensor instances are skipped to avoid double quantization
     """
 
-    module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field(
+    module_or_param_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field(
         default_factory=OrderedDict
     )
@@ -2515,7 +2516,8 @@ def select_module_if_top_level_params_match_pattern(
             for pattern in config.module_or_param_fqn_to_config:
                 full_param_fqn = f"{fqn}.{name}"
                 if (pattern == full_param_fqn) or (
-                    pattern.startswith("re:") and re.fullmatch(pattern[3:], f"{fqn}.{name}")
+                    pattern.startswith("re:")
+                    and re.fullmatch(pattern[3:], f"{fqn}.{name}")
                 ):
                     return True
     return False
@@ -2527,7 +2529,8 @@ def _module_fqn_to_config_handler(
     config: ModuleOrParamFqnToConfig,
 ):
     c = None
-    if module_fqn in config.module_fqn_to_config and not maybe_module_fqn_pattern.startswith("re:"):
+    # check to see if module_fqn is exact match
+    if module_fqn in config.module_fqn_to_config:
         # Maybe: we can add module type specific config in the future, in needed
         c = config.module_fqn_to_config[module_fqn]
     else:

From 0dc84dfaf792f226e26188831db3e78760666b66 Mon Sep 17 00:00:00 2001
From: Jesse Cai
Date: Mon, 6 Oct 2025 11:53:46 -0700
Subject: [PATCH 32/47] fix loading

---
 torchao/quantization/quant_api.py | 19 ++++++++++---------
 1 file changed, 10 insertions(+), 9 deletions(-)

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 01de2fb65b..a38088ef01 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -2387,8 +2387,14 @@ class ModuleOrParamFqnToConfig(AOBaseConfig):
     in addition to module-level configurations. It allows for fine-grained control over quantization by
     specifying configurations for individual parameters or modules using regex pattern matching on their FQNs.
 
+    When passed this config, `quantize_` will first try to replace all modules in the model, matching the logic of ModuleFqnToConfig.
- Then, quantize_ will attempt to replace any parameters specified by the fqn or that match regexs, ignoring modules that have already been transformed by the previous flow (Modules with an existing AOBaseTensor attached) - - "_default" is ignored for parameter replacement. - Note: - The order of patterns in the OrderedDict matters as the first matching pattern is applied - Parameters that are already TorchAOBaseTensor instances are skipped to avoid double quantization """ - module_or_param_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( + module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( default_factory=OrderedDict ) @@ -2426,9 +2427,9 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.ModuleOrParamFqnToConfig") @property - def module_fqn_to_config(self): + def module_or_param_fqn_to_config(self): """Compatibility property to maintain interface consistency with ModuleFqnToConfig.""" - return self.module_or_param_fqn_to_config + return self.module_fqn_to_config # maintain BC From 42bdf687e3adb6ee7dfcf483de39c83c4968c87b Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 7 Oct 2025 09:24:35 -0700 Subject: [PATCH 33/47] update --- torchao/quantization/quant_api.py | 33 ++++++++++++++---------- torchao/quantization/transform_module.py | 29 --------------------- 2 files changed, 20 insertions(+), 42 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a38088ef01..0348fe924d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -90,9 +90,7 @@ ) from torchao.quantization.transform_module import ( _QUANTIZE_CONFIG_HANDLER, - _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER, register_quantize_module_handler, - register_quantize_tensor_handler, ) from torchao.quantization.utils import get_block_size from torchao.quantization.weight_tensor_linear_activation_quantization import ( @@ -1806,7 +1804,6 @@ def __post_init__(self): ) -@register_quantize_tensor_handler(Float8DynamicActivationFloat8WeightConfig) def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): activation_dtype = config.activation_dtype weight_dtype = config.weight_dtype @@ -1879,7 +1876,10 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): @register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) def _float8_dynamic_activation_float8_weight_transform( - module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig + module: torch.nn.Module, + config: Float8DynamicActivationFloat8WeightConfig, + *, + param_name: str = "weight", ): assert is_sm_at_least_89() or is_MI300(), ( "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" @@ -1887,14 +1887,16 @@ def _float8_dynamic_activation_float8_weight_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - assert hasattr(module, "weight"), ( - "applying float8 dynamic activation quant requires module to have weight attribute" + assert hasattr(module, param_name), ( + "applying float8 dynamic activation quant requires module to have parameter {param_name} attribute" + f"but {module} does not have one" ) - quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( - module.weight, config + quantized_tensor = _float8_dynamic_activation_float8_weight_quantize_tensor( + getattr(module, param_name), config + ) + setattr( + module, param_name, 
torch.nn.Parameter(quantized_tensor, requires_grad=False) ) - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -2435,6 +2437,12 @@ def module_or_param_fqn_to_config(self): # maintain BC ModuleFqnToConfig = ModuleOrParamFqnToConfig +# for now, we need to keep track of what configs support custom param quantization. +# Once we've updated all the transform functions to take in a custom_param kwarg, we can delete this object and the subsequent check +CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS = { + Float8DynamicActivationFloat8WeightConfig, +} + def _param_fqn_to_config_handler( mod_containing_param: torch.nn.Module, fqn: str, config: ModuleOrParamFqnToConfig @@ -2480,10 +2488,9 @@ def _param_fqn_to_config_handler( pattern.startswith("re:") and re.fullmatch(pattern[3:], f"{fqn}.{name}") ): param_config_type = type(param_config) - if param_config_type in _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER: - handler = _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER[param_config_type] - new_param = handler(param, param_config) - setattr(mod_containing_param, name, new_param) + if param_config_type in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: + handler = _QUANTIZE_CONFIG_HANDLER[param_config_type] + handler(mod_containing_param, param_config, param_name=name) else: raise NotImplementedError( f"Parameter quantization for {param_config_type} not supported currently!" diff --git a/torchao/quantization/transform_module.py b/torchao/quantization/transform_module.py index cb27b4c753..52bc721f1f 100644 --- a/torchao/quantization/transform_module.py +++ b/torchao/quantization/transform_module.py @@ -15,11 +15,6 @@ Callable[[torch.nn.Module, AOBaseConfig], torch.nn.Module], ] = {} -_QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER: Dict[ - Type[AOBaseConfig], - Callable[[torch.nn.Parameter, AOBaseConfig], torch.nn.Parameter], -] = {} - def register_quantize_module_handler(config_type): """ @@ -55,27 +50,3 @@ def decorator(func): return func # needed to make the functions usable externally return decorator - - -def register_quantize_tensor_handler(config_type): - """ - A decorator to register a transform function to map from a workflow - configuration (child of `AOBaseConfig`) to a function that transforms - a `torch.Tensor` according to the specified configuration. - - The wrapped function will be extended to support `torch.nn.Parameter` as well. 
- """ - - @functools.wraps(config_type) - def decorator(func): - def func_supporting_param(tensor_or_param, config): - if type(tensor_or_param) is torch.nn.Parameter: - return torch.nn.Parameter( - func(tensor_or_param, config), requires_grad=False - ) - return func(tensor_or_param, config) - - _QUANTIZE_CONFIG_TENSOR_PARAM_HANDLER[config_type] = func_supporting_param - return func # needed to make the functions usable externally - - return decorator From 829d31f6640c0d62b97350500334c3b874ab958b Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Tue, 7 Oct 2025 11:35:04 -0700 Subject: [PATCH 34/47] update comment --- torchao/quantization/quant_api.py | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 0348fe924d..2d6e0b2487 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2385,29 +2385,24 @@ def _fpx_weight_only_transform( class ModuleOrParamFqnToConfig(AOBaseConfig): """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs). - This extends the functionality of ModuleFqnToConfig to support parameter-level quantization configurations - in addition to module-level configurations. It allows for fine-grained control over quantization by - specifying configurations for individual parameters or modules using regex pattern matching on their FQNs. - - When passed this config, `quantize_` will first try to replace all modules in the model, matching the logic of ModuleFqnToConfig. - Then, quantize_ will attempt to replace any parameters specified by the fqn or that match regexs, ignoring modules that have already been transformed by the previous flow (Modules with an existing AOBaseTensor attached) - - "_default" is ignored for parameter replacement. + Users can either explicitly specify a specific FQN or pass in a regex pattern to match FQNs. + `quantize_` will first try to replace all modules matching the keys of ModuleOrParamFqnToConfig.module_or_param_fqn_to_config, + It will then will try to replace any parameters that match the keys, ignoring modules that have already been transformed by the previous flow (modules that contain AOBaseTensor parameters): Args: module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping regex patterns (as strings) to quantization configurations. The patterns can be one of the follows: - (1). fully qualified name (fqn) of module or + (1). fully qualified name (fqn) of module or paramter or (2). regex of fully qualified name (in python `re` module regex format), should start with prefix "re:" or (3). "_default" Config key ordered by precedence: - * fully qualified module name, e.g. `language.layers.0.q_proj` - * regex for module names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`, + * fully qualified module or paramteter name, e.g. `language.layers.0.q_proj` + * regex for module or parameter names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`, whiever regex fully matches the module fqn first will be applied (order of keys for dictionary are kept consistent since we are using OrderedDict) * "_default", fallback for **all modules** if no match for all previous keys @@ -2417,10 +2412,12 @@ class ModuleOrParamFqnToConfig(AOBaseConfig): None, e.g. 
`{"re:.+norm.+": None, "_default": linear_config}`) Note: - - The order of patterns in the OrderedDict matters as the first matching pattern is applied + - The order of patterns in the OrderedDict may matter as only the first matching pattern is applied - Parameters that are already TorchAOBaseTensor instances are skipped to avoid double quantization + - "_default" is ignored for parameter replacement. """ + # to maintain BC, we keep the same name as ModuleFqnToConfig before module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( default_factory=OrderedDict ) From b40a03ad368e4aab26ba89aec72713f901e0484a Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 06:19:13 -0700 Subject: [PATCH 35/47] consolidate handler logic --- docs/source/api_ref_quantization.rst | 1 + docs/source/serving.rst | 4 +- docs/source/torchao_vllm_integration.md | 11 +- test/quantization/test_quant_api.py | 63 ++++++++---- torchao/quantization/__init__.py | 4 +- torchao/quantization/quant_api.py | 130 ++++++++++++------------ 6 files changed, 118 insertions(+), 95 deletions(-) diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index c163a4b06a..226de3f98c 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -34,6 +34,7 @@ Inference APIs for quantize\_ Int8DynamicActivationInt8WeightConfig UIntXWeightOnlyConfig FPXWeightOnlyConfig + ModuleOrParamFqnToConfig .. currentmodule:: torchao.quantization diff --git a/docs/source/serving.rst b/docs/source/serving.rst index d95132ded7..58c435ecb0 100644 --- a/docs/source/serving.rst +++ b/docs/source/serving.rst @@ -175,7 +175,7 @@ Quantizing the model for mobile deployment using TorchAO's ``Int8DynamicActivati from torchao.quantization.quant_api import ( IntxWeightOnlyConfig, Int8DynamicActivationIntxWeightConfig, - ModuleFqnToConfig, + FqnToConfig, quantize_, ) from torchao.quantization.granularity import PerGroup, PerAxis @@ -198,7 +198,7 @@ Quantizing the model for mobile deployment using TorchAO's ``Int8DynamicActivati weight_granularity=PerGroup(32), weight_scale_dtype=torch.bfloat16, ) - quant_config = ModuleFqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config}) + quant_config = FqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config}) quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[]) # either use `untied_model_id` or `untied_model_local_path` diff --git a/docs/source/torchao_vllm_integration.md b/docs/source/torchao_vllm_integration.md index 1ca027a124..ef3719cb9f 100644 --- a/docs/source/torchao_vllm_integration.md +++ b/docs/source/torchao_vllm_integration.md @@ -54,22 +54,21 @@ assert isinstance(config, AOBaseConfig) All quantization configurations inherit from {class}`torchao.core.config.AOBaseConfig`, which provides serialization and validation capabilities. ``` -(module-level-configuration)= -### 3. Module-Level Configuration +(fqn-configuration)= +### 3. 
FQN Configuration -For granular control, use `ModuleFqnToConfig`: +For granular control, use `FqnToConfig`: ```python -from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig +from torchao.quantization import FqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig -config = ModuleFqnToConfig({ +config = FqnToConfig({ "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64), "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64), "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(), "_default": Int4WeightOnlyConfig(group_size=128, version=1) # Default for other modules }) ``` - (usage-examples)= ## Usage Examples diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 0e877976eb..d87b12cb7b 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -34,6 +34,7 @@ TensorCoreTiledLayout, ) from torchao.quantization import ( + Float8Tensor, Int4TilePackedTo4dTensor, IntxUnpackedToInt8Tensor, LinearActivationQuantizedTensor, @@ -44,6 +45,7 @@ Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, + FqnToConfig, GemliteUIntXWeightOnlyConfig, Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, @@ -53,7 +55,6 @@ Int8WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, - ModuleOrParamFqnToConfig, PerRow, PerTensor, Quantizer, @@ -906,7 +907,7 @@ def test_config_deprecation(self): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") @unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") -class TestModuleOrParamFqnToConfig(TestCase): +class TestFqnToConfig(TestCase): def test_quantize_param_fqn_exact(self): from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe @@ -916,7 +917,7 @@ def test_quantize_param_fqn_exact(self): ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - quant_config = ModuleOrParamFqnToConfig( + quant_config = FqnToConfig( { "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), @@ -929,11 +930,6 @@ def test_quantize_param_fqn_exact(self): quant_config, ) - # Note: Need to import Float8Tensor from the correct location - from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( - Float8Tensor, - ) - assert isinstance(model.experts.gate_up_proj, Float8Tensor) def test_quantize_param_and_module_fqn(self): @@ -944,7 +940,7 @@ def test_quantize_param_and_module_fqn(self): "unsloth/Llama-4-Scout-17B-16E-Instruct" ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - quant_config = ModuleOrParamFqnToConfig( + quant_config = FqnToConfig( { "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), @@ -960,6 +956,10 @@ def test_quantize_param_and_module_fqn(self): quant_config, ) + assert isinstance(model.experts.gate_up_proj, Float8Tensor) + assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) + assert model.shared_expert.gate_proj.weight.scale.numel() == 1 + def test_quantize_param_and_module_fqn_regex(self): from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe @@ -968,7 +968,7 @@ def test_quantize_param_and_module_fqn_regex(self): "unsloth/Llama-4-Scout-17B-16E-Instruct" ).text_config model = Llama4TextMoe(config).to(torch.bfloat16).cuda() - quant_config = ModuleOrParamFqnToConfig( + quant_config = FqnToConfig( { 
"re:.*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), @@ -984,14 +984,41 @@ def test_quantize_param_and_module_fqn_regex(self): quant_config, ) - from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( - Float8Tensor, - ) - assert isinstance(model.experts.gate_up_proj, Float8Tensor) assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) assert model.shared_expert.gate_proj.weight.scale.numel() == 1 + def test_quantize_modle_exact_match_preference(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 128) + + def forward(self, x): + return self.linear(x) + + model = TestModule().to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + # only this config should be applied, as module fqn takes precedence + "linear": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + # "re:linear": Float8DynamicActivationFloat8WeightConfig( + # granularity=PerTensor(), + # ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model.linear.weight, Float8Tensor) + assert model.linear.weight.scale.numel() == 128 + def test_quantize_modle_param_double_specified(self): model = ( torch.nn.Sequential( @@ -1001,7 +1028,7 @@ def test_quantize_modle_param_double_specified(self): .to(torch.bfloat16) .cuda() ) - quant_config = ModuleOrParamFqnToConfig( + quant_config = FqnToConfig( { # only this config should be applied, as module fqn takes precedence "0": Float8DynamicActivationFloat8WeightConfig( @@ -1018,10 +1045,6 @@ def test_quantize_modle_param_double_specified(self): quant_config, ) - from torchao.quantization.quantize_.workflows.float8.float8_tensor import ( - Float8Tensor, - ) - assert isinstance(model[0].weight, Float8Tensor) assert model[0].weight.scale.numel() == 128 @@ -1040,7 +1063,7 @@ class UnsupportedParamConfig(AOBaseConfig): model = torch.nn.Sequential(torch.nn.Linear(10, 5).cuda().bfloat16()) # Create config with unsupported parameter handler - quant_config = ModuleOrParamFqnToConfig( + quant_config = FqnToConfig( { "0.weight": UnsupportedParamConfig(), } diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index e5154ad54e..44ee3226a4 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -52,6 +52,7 @@ Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, + FqnToConfig, GemliteUIntXWeightOnlyConfig, Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, @@ -61,7 +62,6 @@ Int8WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, - ModuleOrParamFqnToConfig, PlainLayout, TensorCoreTiledLayout, UIntXWeightOnlyConfig, @@ -162,8 +162,8 @@ "FPXWeightOnlyConfig", "GemliteUIntXWeightOnlyConfig", "AOPerModuleConfig", + "FqnToConfig", "ModuleFqnToConfig", - "ModuleOrParamFqnToConfig", # tensor subclasses "Int4Tensor", "Int4PlainInt32Tensor", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 2d6e0b2487..2b9bbe8ceb 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -515,18 +515,16 @@ def quantize_( torch._C._log_api_usage_once("torchao.quantization.quantize_") filter_fn = _is_linear if filter_fn is None else filter_fn - if isinstance(config, ModuleOrParamFqnToConfig): + if isinstance(config, FqnToConfig): _replace_with_custom_fn_if_matches_filter_with_name( model, - _module_fqn_to_config_handler, - filter_fn, - device=device, - 
extra_args=(config,), - ) - _replace_with_custom_fn_if_matches_filter_with_name( - model, - _param_fqn_to_config_handler, - partial(select_module_if_top_level_params_match_pattern, config=config), + _fqn_to_config_handler, + # filter_fn, + partial( + select_module_if_top_level_params_match_pattern, + config=config, + filter_fn=filter_fn, + ), device=device, extra_args=(config,), ) @@ -2382,12 +2380,12 @@ def _fpx_weight_only_transform( @dataclass -class ModuleOrParamFqnToConfig(AOBaseConfig): +class FqnToConfig(AOBaseConfig): """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs). Users can either explicitly specify a specific FQN or pass in a regex pattern to match FQNs. - `quantize_` will first try to replace all modules matching the keys of ModuleOrParamFqnToConfig.module_or_param_fqn_to_config, + `quantize_` will first try to replace all modules matching the keys of FqnToConfig.fqn_to_config, It will then will try to replace any parameters that match the keys, ignoring modules that have already been transformed by the previous flow (modules that contain AOBaseTensor parameters): Args: @@ -2418,21 +2416,24 @@ class ModuleOrParamFqnToConfig(AOBaseConfig): """ # to maintain BC, we keep the same name as ModuleFqnToConfig before + fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( + default_factory=OrderedDict + ) module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( default_factory=OrderedDict ) + version: int = 1 def __post_init__(self): - torch._C._log_api_usage_once("torchao.quantization.ModuleOrParamFqnToConfig") - - @property - def module_or_param_fqn_to_config(self): - """Compatibility property to maintain interface consistency with ModuleFqnToConfig.""" - return self.module_fqn_to_config + torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") + if self.version == 1: + warnings.warn( + "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig, see https://github.com/pytorch/ao/issues/2967 for more details" + ) # maintain BC -ModuleFqnToConfig = ModuleOrParamFqnToConfig +ModuleFqnToConfig = FqnToConfig # for now, we need to keep track of what configs support custom param quantization. # Once we've updated all the transform functions to take in a custom_param kwarg, we can delete this object and the subsequent check @@ -2441,8 +2442,8 @@ def module_or_param_fqn_to_config(self): } -def _param_fqn_to_config_handler( - mod_containing_param: torch.nn.Module, fqn: str, config: ModuleOrParamFqnToConfig +def _fqn_to_config_handler( + mod_containing_param: torch.nn.Module, fqn: str, config: FqnToConfig ): """This function processes parameters within a module and applies quantization configurations when the parameter's fully qualified name matches patterns defined in the config. @@ -2450,7 +2451,7 @@ def _param_fqn_to_config_handler( Args: mod_containing_param (torch.nn.Module): The module containing parameters to be processed. fqn (str): The fully qualified name of the module containing the parameters. - config (ModuleOrParamFqnToConfig): Configuration object containing regex patterns mapped + config (FqnToConfig): Configuration object containing regex patterns mapped to quantization configurations. Returns: @@ -2465,6 +2466,30 @@ def _param_fqn_to_config_handler( Raises: NotImplementedError: If a configuration type doesn't have a registered parameter handler. 
""" + print(mod_containing_param, fqn, config) + # breakpoint() + module = mod_containing_param + # c = None + # # check to see if module_fqn is exact match + if fqn in config.fqn_to_config: + # Maybe: we can add module type specific config in the future, if needed + c = config.fqn_to_config[fqn] + else: + for maybe_module_fqn_pattern in config.fqn_to_config: + if not maybe_module_fqn_pattern.startswith("re:"): + continue + elif re.fullmatch(maybe_module_fqn_pattern[3:], fqn): + # we'll apply the config for first fully matched pattern + c = config.fqn_to_config[maybe_module_fqn_pattern] + break + else: + # fallback to use default if no module specific config is provided + c = config.fqn_to_config.get("_default", None) + + if c is not None: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c) + top_level_named_parameters_list = [ (name, param) for name, param in mod_containing_param.named_parameters() @@ -2479,10 +2504,11 @@ def _param_fqn_to_config_handler( return mod_containing_param for name, param in top_level_named_parameters_list: - for pattern, param_config in config.module_or_param_fqn_to_config.items(): + for pattern, param_config in config.fqn_to_config.items(): full_param_fqn = f"{fqn}.{name}" + # Exact match takes precedence if (pattern == full_param_fqn) or ( - pattern.startswith("re:") and re.fullmatch(pattern[3:], f"{fqn}.{name}") + pattern.startswith("re:") and re.fullmatch(pattern[3:], full_param_fqn) ): param_config_type = type(param_config) if param_config_type in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: @@ -2497,64 +2523,38 @@ def _param_fqn_to_config_handler( def select_module_if_top_level_params_match_pattern( - mod: nn.Module, fqn: str, config: ModuleOrParamFqnToConfig + mod: nn.Module, fqn: str, config: FqnToConfig, filter_fn=None ): """Check if a module should be selected for quantization. This function determines whether a module should be processed for parameter-level quantization - by checking if any of its top-level parameters match the patterns defined in ModuleOrParamFqnToConfig. + by checking if any of its top-level parameters match the patterns defined in FqnToConfig. We only check top-level parameters (those directly accessible as module attributes). Args: mod (torch.nn.Module): The module to check for parameter pattern matches. fqn (str): The fully qualified name of the module. - config (ModuleOrParamFqnToConfig): Configuration object containing regex patterns or raw FQNs for + config (FqnToConfig): Configuration object containing regex patterns or raw FQNs for parameter quantization. Returns: bool: True if any of the module's parameters match patterns in the configuration, False otherwise. 
""" - for name, param in mod.named_parameters(): - if name in dir(mod) and not isinstance(param, TorchAOBaseTensor): - for pattern in config.module_or_param_fqn_to_config: - full_param_fqn = f"{fqn}.{name}" - if (pattern == full_param_fqn) or ( - pattern.startswith("re:") - and re.fullmatch(pattern[3:], f"{fqn}.{name}") - ): - return True - return False - - -def _module_fqn_to_config_handler( - module: torch.nn.Module, - module_fqn: str, - config: ModuleOrParamFqnToConfig, -): - c = None - # check to see if module_fqn is exact match - if module_fqn in config.module_fqn_to_config: - # Maybe: we can add module type specific config in the future, in needed - c = config.module_fqn_to_config[module_fqn] + if filter_fn(mod, fqn): + return True else: - for maybe_module_fqn_pattern in config.module_fqn_to_config: - if not maybe_module_fqn_pattern.startswith("re:"): - continue - elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn): - # we'll apply the config for first fully matched pattern - c = config.module_fqn_to_config[maybe_module_fqn_pattern] - break - else: - # fallback to use default if no module specific config is provided - c = config.module_fqn_to_config.get("_default", None) - - if c is not None: - handler = _QUANTIZE_CONFIG_HANDLER[type(c)] - return handler(module, c) - - return module + for name, param in mod.named_parameters(): + if name in dir(mod) and not isinstance(param, TorchAOBaseTensor): + for pattern in config.fqn_to_config: + full_param_fqn = f"{fqn}.{name}" + if (pattern == full_param_fqn) or ( + pattern.startswith("re:") + and re.fullmatch(pattern[3:], f"{fqn}.{name}") + ): + return True + return False torch.serialization.add_safe_globals( From 3566b6f0c9acf03dc486c2898b09b0f556858580 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 07:28:56 -0700 Subject: [PATCH 36/47] refactor --- docs/source/api_ref_quantization.rst | 2 +- test/quantization/test_quant_api.py | 40 +++++++-- torchao/quantization/quant_api.py | 121 ++++++++++++--------------- 3 files changed, 90 insertions(+), 73 deletions(-) diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index 226de3f98c..d4a661c6d8 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -34,7 +34,7 @@ Inference APIs for quantize\_ Int8DynamicActivationInt8WeightConfig UIntXWeightOnlyConfig FPXWeightOnlyConfig - ModuleOrParamFqnToConfig + FqnToConfig .. 
currentmodule:: torchao.quantization diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index d87b12cb7b..f23454620d 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -1005,9 +1005,37 @@ def forward(self, x): "linear": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), - # "re:linear": Float8DynamicActivationFloat8WeightConfig( - # granularity=PerTensor(), - # ), + "re:linear": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model.linear.weight, Float8Tensor) + assert model.linear.weight.scale.numel() == 128 + + def test_quantize_module_default_bc(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 128) + + def forward(self, x): + return self.linear(x) + + model = TestModule().to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + # only this config should be applied, as module fqn takes precedence + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), } ) @@ -1030,13 +1058,13 @@ def test_quantize_modle_param_double_specified(self): ) quant_config = FqnToConfig( { + "re:.*weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), # only this config should be applied, as module fqn takes precedence "0": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), - "re:.*weight": Float8DynamicActivationFloat8WeightConfig( - granularity=PerTensor(), - ), } ) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 2b9bbe8ceb..f8ad08e2fd 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -519,9 +519,8 @@ def quantize_( _replace_with_custom_fn_if_matches_filter_with_name( model, _fqn_to_config_handler, - # filter_fn, partial( - select_module_if_top_level_params_match_pattern, + select_module_if_filter_fn_or_contains_params_matching_pattern, config=config, filter_fn=filter_fn, ), @@ -1877,7 +1876,7 @@ def _float8_dynamic_activation_float8_weight_transform( module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig, *, - param_name: str = "weight", + parameter_name: str = "weight", ): assert is_sm_at_least_89() or is_MI300(), ( "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" @@ -1885,15 +1884,17 @@ def _float8_dynamic_activation_float8_weight_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - assert hasattr(module, param_name), ( - "applying float8 dynamic activation quant requires module to have parameter {param_name} attribute" + assert hasattr(module, parameter_name), ( + "applying float8 dynamic activation quant requires module to have parameter {parameter_name} attribute" + f"but {module} does not have one" ) quantized_tensor = _float8_dynamic_activation_float8_weight_quantize_tensor( - getattr(module, param_name), config + getattr(module, parameter_name), config ) setattr( - module, param_name, torch.nn.Parameter(quantized_tensor, requires_grad=False) + module, + parameter_name, + torch.nn.Parameter(quantized_tensor, requires_grad=False), ) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -2415,10 +2416,10 @@ class FqnToConfig(AOBaseConfig): - "_default" is ignored for parameter replacement. 
""" - # to maintain BC, we keep the same name as ModuleFqnToConfig before fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( default_factory=OrderedDict ) + # to maintain BC, we keep the same name as ModuleFqnToConfig before module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( default_factory=OrderedDict ) @@ -2442,9 +2443,7 @@ def __post_init__(self): } -def _fqn_to_config_handler( - mod_containing_param: torch.nn.Module, fqn: str, config: FqnToConfig -): +def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfig): """This function processes parameters within a module and applies quantization configurations when the parameter's fully qualified name matches patterns defined in the config. @@ -2466,11 +2465,35 @@ def _fqn_to_config_handler( Raises: NotImplementedError: If a configuration type doesn't have a registered parameter handler. """ - print(mod_containing_param, fqn, config) - # breakpoint() - module = mod_containing_param - # c = None - # # check to see if module_fqn is exact match + + # 1) module swap + c = _get_config_for_fqn( + fqn, config, default=config.fqn_to_config.get("_default", None) + ) + if c is not None: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c) + + # 2) handle custom parameter flow + for parameter_name, param in list(module.named_parameters()): + if parameter_name in dir(module): + c = _get_config_for_fqn(f"{fqn}.{parameter_name}", config) + if c is not None: + if type(c) in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c, parameter_name=parameter_name) + else: + raise NotImplementedError( + f"Parameter quantization for {type(c)} not supported currently!" + ) + + return module + + +def _get_config_for_fqn(fqn, config: FqnToConfig, default=None): + """Helper function to get the config for a given fqn from an FqnToConfig object.""" + c = None + # check to see if module_fqn is exact match if fqn in config.fqn_to_config: # Maybe: we can add module type specific config in the future, if needed c = config.fqn_to_config[fqn] @@ -2484,69 +2507,35 @@ def _fqn_to_config_handler( break else: # fallback to use default if no module specific config is provided - c = config.fqn_to_config.get("_default", None) - - if c is not None: - handler = _QUANTIZE_CONFIG_HANDLER[type(c)] - return handler(module, c) + if default is not None: + c = default + return c - top_level_named_parameters_list = [ - (name, param) - for name, param in mod_containing_param.named_parameters() - if name in dir(mod_containing_param) - ] - # return if modified previously by module flow - if any( - isinstance(param, TorchAOBaseTensor) - for _, param in top_level_named_parameters_list - ): - return mod_containing_param - - for name, param in top_level_named_parameters_list: - for pattern, param_config in config.fqn_to_config.items(): - full_param_fqn = f"{fqn}.{name}" - # Exact match takes precedence - if (pattern == full_param_fqn) or ( - pattern.startswith("re:") and re.fullmatch(pattern[3:], full_param_fqn) - ): - param_config_type = type(param_config) - if param_config_type in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: - handler = _QUANTIZE_CONFIG_HANDLER[param_config_type] - handler(mod_containing_param, param_config, param_name=name) - else: - raise NotImplementedError( - f"Parameter quantization for {param_config_type} not supported currently!" 
- ) - - return mod_containing_param - - -def select_module_if_top_level_params_match_pattern( - mod: nn.Module, fqn: str, config: FqnToConfig, filter_fn=None +def select_module_if_filter_fn_or_contains_params_matching_pattern( + module: nn.Module, + fqn: str, + config: FqnToConfig, + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, ): - """Check if a module should be selected for quantization. - - This function determines whether a module should be processed for parameter-level quantization - by checking if any of its top-level parameters match the patterns defined in FqnToConfig. - - We only check top-level parameters (those directly accessible as module attributes). + """Check if a module should be selected for quantization. We only check top-level parameters (those directly accessible as module attributes). Args: - mod (torch.nn.Module): The module to check for parameter pattern matches. + module (torch.nn.Module): The module to check for parameter pattern matches. fqn (str): The fully qualified name of the module. config (FqnToConfig): Configuration object containing regex patterns or raw FQNs for parameter quantization. + filter_fn (Optional[Callable[[torch.nn.Module], bool]]): A function that takes a module and fqn and return whether to quantize the module. Returns: - bool: True if any of the module's parameters match patterns in the configuration, - False otherwise. + bool: True if filter_fn is passed and filter_fn(module, fqn) is True, or if any of the top-level parameters match the patterns in config.fqn_to_config + False otherwise. """ - if filter_fn(mod, fqn): + if filter_fn is not None and filter_fn(module, fqn): return True else: - for name, param in mod.named_parameters(): - if name in dir(mod) and not isinstance(param, TorchAOBaseTensor): + for name, param in module.named_parameters(): + if name in dir(module) and not isinstance(param, TorchAOBaseTensor): for pattern in config.fqn_to_config: full_param_fqn = f"{fqn}.{name}" if (pattern == full_param_fqn) or ( From 3fe447406ace39815b828af97e55c3bac8668d06 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 08:23:50 -0700 Subject: [PATCH 37/47] add test for same-level param and module --- test/quantization/test_quant_api.py | 35 ++++++++++++++ torchao/quantization/quant_api.py | 73 +++++++++++++++++------------ 2 files changed, 77 insertions(+), 31 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index f23454620d..7379b79892 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -1047,6 +1047,41 @@ def forward(self, x): assert isinstance(model.linear.weight, Float8Tensor) assert model.linear.weight.scale.numel() == 128 + def test_quantize_module_default_param_quant(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 128) + self.param = torch.nn.Parameter(torch.randn(128, 128)) + + def forward(self, x): + return self.linear(x) + + model = TestModule().to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + # only this config should be applied, as module fqn takes precedence + # if we have both a linear and param at the same level, + "linear": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "param": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model.linear.weight, Float8Tensor) + assert 
model.linear.weight.scale.numel() == 128 + + assert isinstance(model.param, Float8Tensor) + def test_quantize_modle_param_double_specified(self): model = ( torch.nn.Sequential( diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f8ad08e2fd..f45300fb9b 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -246,23 +246,22 @@ def _replace_with_custom_fn_if_matches_filter_with_name( if device is not None: model.to(device=device) # move to device before quantization model = replacement_fn(model, cur_fqn[:-1], *extra_args) - return model - else: - named_children_list = list(model.named_children()) - for name, child in named_children_list: - new_child = _replace_with_custom_fn_if_matches_filter_with_name( - child, - replacement_fn, - filter_fn, - f"{cur_fqn}{name}.", - device, - extra_args, - ) - if new_child is not child: - setattr(model, name, new_child) - if device is not None: - model.to(device=device) # move parent module to device - return model + # For parameter quantization, filter_fn(model, cur_fqn) no longer is terminal, as a module may contain both a parameter we want to quantize and subsequent submodules. + named_children_list = list(model.named_children()) + for name, child in named_children_list: + new_child = _replace_with_custom_fn_if_matches_filter_with_name( + child, + replacement_fn, + filter_fn, + f"{cur_fqn}{name}.", + device, + extra_args, + ) + if new_child is not child: + setattr(model, name, new_child) + if device is not None: + model.to(device=device) # move parent module to device + return model def _is_linear(mod, *args): @@ -2419,7 +2418,7 @@ class FqnToConfig(AOBaseConfig): fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( default_factory=OrderedDict ) - # to maintain BC, we keep the same name as ModuleFqnToConfig before + # to maintain BC, we keep the previous module_fqn_to_config field module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( default_factory=OrderedDict ) @@ -2427,10 +2426,19 @@ class FqnToConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") - if self.version == 1: + if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0: + warnings.warn( + "Both module_fqn_to_config and fqn_to_config are specified, only fqn_to_config will be used" + ) + if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0: + self.fqn_to_config = self.module_fqn_to_config warnings.warn( "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig, see https://github.com/pytorch/ao/issues/2967 for more details" ) + elif len(self.fqn_to_config) > 0 and len(self.module_fqn_to_config) == 0: + self.module_fqn_to_config = self.fqn_to_config + else: + self.module_fqn_to_config = self.fqn_to_config # maintain BC @@ -2465,7 +2473,7 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi Raises: NotImplementedError: If a configuration type doesn't have a registered parameter handler. 
""" - + # breakpoint() # 1) module swap c = _get_config_for_fqn( fqn, config, default=config.fqn_to_config.get("_default", None) @@ -2477,7 +2485,8 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi # 2) handle custom parameter flow for parameter_name, param in list(module.named_parameters()): if parameter_name in dir(module): - c = _get_config_for_fqn(f"{fqn}.{parameter_name}", config) + full_param_fqn = f"{fqn}.{parameter_name}" if fqn != "" else parameter_name + c = _get_config_for_fqn(full_param_fqn, config) if c is not None: if type(c) in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: handler = _QUANTIZE_CONFIG_HANDLER[type(c)] @@ -2531,18 +2540,20 @@ def select_module_if_filter_fn_or_contains_params_matching_pattern( bool: True if filter_fn is passed and filter_fn(module, fqn) is True, or if any of the top-level parameters match the patterns in config.fqn_to_config False otherwise. """ + # breakpoint() if filter_fn is not None and filter_fn(module, fqn): + print(f"Selected module {fqn} for quantization") return True - else: - for name, param in module.named_parameters(): - if name in dir(module) and not isinstance(param, TorchAOBaseTensor): - for pattern in config.fqn_to_config: - full_param_fqn = f"{fqn}.{name}" - if (pattern == full_param_fqn) or ( - pattern.startswith("re:") - and re.fullmatch(pattern[3:], f"{fqn}.{name}") - ): - return True + for name, param in module.named_parameters(): + if name in dir(module) and not isinstance(param, TorchAOBaseTensor): + full_param_fqn = f"{fqn}.{name}" if fqn != "" else name + for pattern in config.fqn_to_config: + if (pattern == full_param_fqn) or ( + pattern.startswith("re:") + and re.fullmatch(pattern[3:], full_param_fqn) + ): + print(f"Found matching pattern for {full_param_fqn}") + return True return False From 84fa0a5908f78caf73deccc08ed8b5f1419ac378 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 08:57:39 -0700 Subject: [PATCH 38/47] update docstring --- torchao/quantization/quant_api.py | 64 ++++++++++++++++--------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f45300fb9b..545710965d 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2439,6 +2439,10 @@ def __post_init__(self): self.module_fqn_to_config = self.fqn_to_config else: self.module_fqn_to_config = self.fqn_to_config + if "_default" in self.fqn_to_config: + warnings.warn( + "Config Deprecation: _default is deprecated and will no longer be supported in a future release" + ) # maintain BC @@ -2452,29 +2456,21 @@ def __post_init__(self): def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfig): - """This function processes parameters within a module and applies quantization configurations - when the parameter's fully qualified name matches patterns defined in the config. + """This function expects a module that either is specified in FqnToConfig or has a parameter that is specified in FqnToConfig. Args: - mod_containing_param (torch.nn.Module): The module containing parameters to be processed. + module (torch.nn.Module): The module to be processed. fqn (str): The fully qualified name of the module containing the parameters. - config (FqnToConfig): Configuration object containing regex patterns mapped + config (FqnToConfig): Configuration object containing regex patterns / fqn mapped to quantization configurations. 
Returns: torch.nn.Module: The modified module with quantized parameters. - Note: - - Only processes top-level parameters (those directly accessible as module attributes) - - Skips parameters that are already TorchAOBaseTensor instances to avoid double quantization - - Uses the first matching pattern for each parameter - - Sets quantized parameters as non-differentiable (requires_grad=False) - Raises: - NotImplementedError: If a configuration type doesn't have a registered parameter handler. + NotImplementedError: If the quantization configuration is not yet supported for parameter quantization. """ - # breakpoint() - # 1) module swap + # First we attempt to apply the module config c = _get_config_for_fqn( fqn, config, default=config.fqn_to_config.get("_default", None) ) @@ -2482,11 +2478,11 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi handler = _QUANTIZE_CONFIG_HANDLER[type(c)] return handler(module, c) - # 2) handle custom parameter flow + # If there is no module config to apply, we attempt to match our parameters for parameter_name, param in list(module.named_parameters()): if parameter_name in dir(module): - full_param_fqn = f"{fqn}.{parameter_name}" if fqn != "" else parameter_name - c = _get_config_for_fqn(full_param_fqn, config) + parameter_fqn = f"{fqn}.{parameter_name}" if fqn != "" else parameter_name + c = _get_config_for_fqn(parameter_fqn, config) if c is not None: if type(c) in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: handler = _QUANTIZE_CONFIG_HANDLER[type(c)] @@ -2495,24 +2491,33 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi raise NotImplementedError( f"Parameter quantization for {type(c)} not supported currently!" ) - return module -def _get_config_for_fqn(fqn, config: FqnToConfig, default=None): - """Helper function to get the config for a given fqn from an FqnToConfig object.""" +def _get_config_for_fqn( + fqn: str, config: FqnToConfig, default: Optional[AOBaseConfig] = None +): + """Helper function to get the config for a given fqn from an FqnToConfig object. + + In order of precednece it will try to match + 1) the fqn exactly + 2) any regex that matches the fqn + 3) default, if specified (for mainitainig BC) + """ c = None - # check to see if module_fqn is exact match if fqn in config.fqn_to_config: + assert not fqn.startswith("re:"), ( + f"Error: Exact match but regex {fqn} specified." + ) # Maybe: we can add module type specific config in the future, if needed c = config.fqn_to_config[fqn] else: - for maybe_module_fqn_pattern in config.fqn_to_config: - if not maybe_module_fqn_pattern.startswith("re:"): + for maybe_module_or_param_fqn_pattern in config.fqn_to_config: + if not maybe_module_or_param_fqn_pattern.startswith("re:"): continue - elif re.fullmatch(maybe_module_fqn_pattern[3:], fqn): + elif re.fullmatch(maybe_module_or_param_fqn_pattern[3:], fqn): # we'll apply the config for first fully matched pattern - c = config.fqn_to_config[maybe_module_fqn_pattern] + c = config.fqn_to_config[maybe_module_or_param_fqn_pattern] break else: # fallback to use default if no module specific config is provided @@ -2527,7 +2532,7 @@ def select_module_if_filter_fn_or_contains_params_matching_pattern( config: FqnToConfig, filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, ): - """Check if a module should be selected for quantization. We only check top-level parameters (those directly accessible as module attributes). 
+ """Check if a module should be selected for quantization, if filter_fn(module, fqn) is True or if module contains any top-level parameters that match the fqns/regexs in FqnToConfig. Args: module (torch.nn.Module): The module to check for parameter pattern matches. @@ -2540,19 +2545,16 @@ def select_module_if_filter_fn_or_contains_params_matching_pattern( bool: True if filter_fn is passed and filter_fn(module, fqn) is True, or if any of the top-level parameters match the patterns in config.fqn_to_config False otherwise. """ - # breakpoint() if filter_fn is not None and filter_fn(module, fqn): - print(f"Selected module {fqn} for quantization") return True for name, param in module.named_parameters(): if name in dir(module) and not isinstance(param, TorchAOBaseTensor): - full_param_fqn = f"{fqn}.{name}" if fqn != "" else name + parameter_fqn = f"{fqn}.{name}" if fqn != "" else name for pattern in config.fqn_to_config: - if (pattern == full_param_fqn) or ( + if (pattern == parameter_fqn) or ( pattern.startswith("re:") - and re.fullmatch(pattern[3:], full_param_fqn) + and re.fullmatch(pattern[3:], parameter_fqn) ): - print(f"Found matching pattern for {full_param_fqn}") return True return False From 8b9db036b0e13e785edfadcb08d507f1c31e7140 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 09:16:18 -0700 Subject: [PATCH 39/47] more docstring updates --- torchao/quantization/quant_api.py | 54 ++++++++++++++++--------------- 1 file changed, 28 insertions(+), 26 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 545710965d..a6b18a69e4 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2389,29 +2389,21 @@ class FqnToConfig(AOBaseConfig): It will then will try to replace any parameters that match the keys, ignoring modules that have already been transformed by the previous flow (modules that contain AOBaseTensor parameters): Args: - module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): An ordered dictionary mapping - regex patterns (as strings) to quantization configurations. - + fqn_to_config The patterns can be one of the follows: - (1). fully qualified name (fqn) of module or paramter or - (2). regex of fully qualified name (in python `re` module regex format), should - start with prefix "re:" or - (3). "_default" - - Config key ordered by precedence: - * fully qualified module or paramteter name, e.g. `language.layers.0.q_proj` - * regex for module or parameter names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`, - whiever regex fully matches the module fqn first will be applied - (order of keys for dictionary are kept consistent since we are using OrderedDict) - * "_default", fallback for **all modules** if no match for all previous keys - (Note, when using `_default`, the config is applied to all modules, to apply - it to only a subset of modules, e.g. with some types, it's better to filter - the modules that we don't want to quantize before hand and configure them to - None, e.g. `{"re:.+norm.+": None, "_default": linear_config}`) + * fully qualified module or paramteter name, e.g. `language.layers.0.q_proj` + * regex for module or parameter names, must start with `re:`, e.g. 
`re:language\.layers\..+\.q_proj`, + whiever regex fully matches the module fqn first will be applied + (order of keys for dictionary are kept consistent since we are using OrderedDict) + * "_default", fallback for **all modules** if no match for all previous keys + (Note, when using `_default`, the config is applied to all modules, to apply + it to only a subset of modules, e.g. with some types, it's better to filter + the modules that we don't want to quantize before hand and configure them to + None, e.g. `{"re:.+norm.+": None, "_default": linear_config}`) + module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): For BC Note: - The order of patterns in the OrderedDict may matter as only the first matching pattern is applied - - Parameters that are already TorchAOBaseTensor instances are skipped to avoid double quantization - "_default" is ignored for parameter replacement. """ @@ -2426,22 +2418,27 @@ class FqnToConfig(AOBaseConfig): def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") + + # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0: warnings.warn( "Both module_fqn_to_config and fqn_to_config are specified, only fqn_to_config will be used" ) + self.module_fqn_to_config = self.fqn_to_config if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0: self.fqn_to_config = self.module_fqn_to_config warnings.warn( - "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig, see https://github.com/pytorch/ao/issues/2967 for more details" + "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig" ) elif len(self.fqn_to_config) > 0 and len(self.module_fqn_to_config) == 0: self.module_fqn_to_config = self.fqn_to_config else: self.module_fqn_to_config = self.fqn_to_config + + # TODO we plan to deprecate `_default later, so raise a warning if we find it passed in` if "_default" in self.fqn_to_config: warnings.warn( - "Config Deprecation: _default is deprecated and will no longer be supported in a future release" + "Config Deprecation: _default is deprecated and will no longer be supported in a future release." ) @@ -2499,10 +2496,15 @@ def _get_config_for_fqn( ): """Helper function to get the config for a given fqn from an FqnToConfig object. - In order of precednece it will try to match - 1) the fqn exactly - 2) any regex that matches the fqn - 3) default, if specified (for mainitainig BC) + Args: + fqn (str): The fully qualified name to match against the config patterns. + config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs. + default (Optional[AOBaseConfig]): The default config to use if no match is found. Defaults to None. + + Returns: + c (AOBaseConfig): If fqn is specified exactly in FqnToConfig, then fqn_to_config[fqn] will be returned. + Otherwise we will return the config of the first matching regex pattern in FqnToConfig. + Finally, we will return the default config if it is specified. 
""" c = None if fqn in config.fqn_to_config: @@ -2532,7 +2534,7 @@ def select_module_if_filter_fn_or_contains_params_matching_pattern( config: FqnToConfig, filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, ): - """Check if a module should be selected for quantization, if filter_fn(module, fqn) is True or if module contains any top-level parameters that match the fqns/regexs in FqnToConfig. + """Check if a module should be selected for quantization to be applied Args: module (torch.nn.Module): The module to check for parameter pattern matches. From 098327b5f4f8dedb87381323130bd8d01758a85f Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 10:04:19 -0700 Subject: [PATCH 40/47] update --- test/quantization/test_quant_api.py | 13 +++--- torchao/quantization/quant_api.py | 69 ++++++++++++++++------------- 2 files changed, 44 insertions(+), 38 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 7379b79892..235839a60c 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -1057,16 +1057,16 @@ def __init__(self): def forward(self, x): return self.linear(x) - model = TestModule().to(torch.bfloat16).cuda() + model = torch.nn.Sequential(TestModule()).to(torch.bfloat16).cuda() quant_config = FqnToConfig( { # only this config should be applied, as module fqn takes precedence # if we have both a linear and param at the same level, - "linear": Float8DynamicActivationFloat8WeightConfig( + "_default": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), - "param": Float8DynamicActivationFloat8WeightConfig( + "0.param": Float8DynamicActivationFloat8WeightConfig( granularity=PerTensor(), ), } @@ -1077,10 +1077,11 @@ def forward(self, x): quant_config, ) - assert isinstance(model.linear.weight, Float8Tensor) - assert model.linear.weight.scale.numel() == 128 + assert isinstance(model[0].linear.weight, Float8Tensor) + assert model[0].linear.weight.scale.numel() == 128 - assert isinstance(model.param, Float8Tensor) + assert isinstance(model[0].param, Float8Tensor) + assert model[0].param.scale.numel() == 1 def test_quantize_modle_param_double_specified(self): model = ( diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index a6b18a69e4..f7b9e9cef7 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -519,7 +519,7 @@ def quantize_( model, _fqn_to_config_handler, partial( - select_module_if_filter_fn_or_contains_params_matching_pattern, + _select_module_if_filter_fn_or_contains_params_matching_pattern, config=config, filter_fn=filter_fn, ), @@ -2389,18 +2389,26 @@ class FqnToConfig(AOBaseConfig): It will then will try to replace any parameters that match the keys, ignoring modules that have already been transformed by the previous flow (modules that contain AOBaseTensor parameters): Args: - fqn_to_config - The patterns can be one of the follows: - * fully qualified module or paramteter name, e.g. `language.layers.0.q_proj` - * regex for module or parameter names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`, - whiever regex fully matches the module fqn first will be applied - (order of keys for dictionary are kept consistent since we are using OrderedDict) - * "_default", fallback for **all modules** if no match for all previous keys - (Note, when using `_default`, the config is applied to all modules, to apply - it to only a subset of modules, e.g. 
with some types, it's better to filter - the modules that we don't want to quantize before hand and configure them to - None, e.g. `{"re:.+norm.+": None, "_default": linear_config}`) - module_fqn_to_config (OrderedDict[str, Optional[AOBaseConfig]]): For BC + `fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an + ordered dictionary from + (1). fully qualified name (fqn) of module or + (2). regex of fully qualified name (in python `re` module regex format), should + start with prefix "re:" or + (3). "_default" + to the config that we want to apply to the module or None + + Config key ordered by precedence: + * fully qualified module name, e.g. `language.layers.0.q_proj` + * regex for module names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`, + whiever regex fully matches the module fqn first will be applied + (order of keys for dictionary are kept consistent since we are using OrderedDict) + * fully qualified parameter name + * regex for parameter names + * "_default", fallback for **all modules** if no match for all previous keys + (Note, when using `_default`, the config is applied to all modules, to apply + it to only a subset of modules, e.g. with some types, it's better to filter + the modules that we don't want to quantize before hand and configure them to + None, e.g. `{"re:.+norm.+": None, "_default": linear_config}`) Note: - The order of patterns in the OrderedDict may matter as only the first matching pattern is applied @@ -2421,10 +2429,9 @@ def __post_init__(self): # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0: - warnings.warn( - "Both module_fqn_to_config and fqn_to_config are specified, only fqn_to_config will be used" + raise ValueError( + "Both module_fqn_to_config and fqn_to_config are non-empty, expected one to be" ) - self.module_fqn_to_config = self.fqn_to_config if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0: self.fqn_to_config = self.module_fqn_to_config warnings.warn( @@ -2467,15 +2474,13 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi Raises: NotImplementedError: If the quantization configuration is not yet supported for parameter quantization. """ - # First we attempt to apply the module config - c = _get_config_for_fqn( - fqn, config, default=config.fqn_to_config.get("_default", None) - ) + # First we see if our module fqn matches with FqnToConfig, if so, we apply the appropriate transform + c = _get_config_for_fqn(fqn, config) if c is not None: handler = _QUANTIZE_CONFIG_HANDLER[type(c)] return handler(module, c) - # If there is no module config to apply, we attempt to match our parameters + # If no config is found, we see if any of our top-level parameter FQNs matches with FqnToConfig for parameter_name, param in list(module.named_parameters()): if parameter_name in dir(module): parameter_fqn = f"{fqn}.{parameter_name}" if fqn != "" else parameter_name @@ -2488,23 +2493,27 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi raise NotImplementedError( f"Parameter quantization for {type(c)} not supported currently!" 
) + + # If no module_fqn or parameter_fqn matches, then we apply _default + c = config.fqn_to_config.get("_default", None) + if c is not None: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c) + + # Else return unmodified module return module -def _get_config_for_fqn( - fqn: str, config: FqnToConfig, default: Optional[AOBaseConfig] = None -): +def _get_config_for_fqn(fqn: str, config: FqnToConfig): """Helper function to get the config for a given fqn from an FqnToConfig object. Args: fqn (str): The fully qualified name to match against the config patterns. config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs. - default (Optional[AOBaseConfig]): The default config to use if no match is found. Defaults to None. Returns: c (AOBaseConfig): If fqn is specified exactly in FqnToConfig, then fqn_to_config[fqn] will be returned. - Otherwise we will return the config of the first matching regex pattern in FqnToConfig. - Finally, we will return the default config if it is specified. + Otherwise we will return the config of the first matching regex pattern in FqnToConfig. """ c = None if fqn in config.fqn_to_config: @@ -2521,14 +2530,10 @@ def _get_config_for_fqn( # we'll apply the config for first fully matched pattern c = config.fqn_to_config[maybe_module_or_param_fqn_pattern] break - else: - # fallback to use default if no module specific config is provided - if default is not None: - c = default return c -def select_module_if_filter_fn_or_contains_params_matching_pattern( +def _select_module_if_filter_fn_or_contains_params_matching_pattern( module: nn.Module, fqn: str, config: FqnToConfig, From 45b8133b3acd7057e36016c189daf7e4eeff4d15 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 10:35:52 -0700 Subject: [PATCH 41/47] cleanup --- test/quantization/test_quant_api.py | 3 +- torchao/quantization/quant_api.py | 43 ++++++++++++----------------- 2 files changed, 19 insertions(+), 27 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 235839a60c..509e2f3c85 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -1081,9 +1081,8 @@ def forward(self, x): assert model[0].linear.weight.scale.numel() == 128 assert isinstance(model[0].param, Float8Tensor) - assert model[0].param.scale.numel() == 1 - def test_quantize_modle_param_double_specified(self): + def test_quantize_model_param_double_specified(self): model = ( torch.nn.Sequential( torch.nn.Linear(128, 128), diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index f7b9e9cef7..079316a1ff 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2383,28 +2383,24 @@ def _fpx_weight_only_transform( class FqnToConfig(AOBaseConfig): """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs). - Users can either explicitly specify a specific FQN or pass in a regex pattern to match FQNs. - - `quantize_` will first try to replace all modules matching the keys of FqnToConfig.fqn_to_config, - It will then will try to replace any parameters that match the keys, ignoring modules that have already been transformed by the previous flow (modules that contain AOBaseTensor parameters): - Args: `fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an ordered dictionary from - (1). fully qualified name (fqn) of module or + (1). 
fully qualified name (fqn) of module or parameter (2). regex of fully qualified name (in python `re` module regex format), should start with prefix "re:" or (3). "_default" - to the config that we want to apply to the module or None + to the config that we want to apply to the module/param or None Config key ordered by precedence: * fully qualified module name, e.g. `language.layers.0.q_proj` * regex for module names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`, - whiever regex fully matches the module fqn first will be applied + whichever regex fully matches the module fqn first will be applied (order of keys for dictionary are kept consistent since we are using OrderedDict) - * fully qualified parameter name - * regex for parameter names - * "_default", fallback for **all modules** if no match for all previous keys + * fully qualified parameter name, e.g. `language.layers.0.q_proj.weight` + * regex for parameter names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj.weight`. + The first regex that matches will be applied. + * "_default", fallback if no match for all previous keys (Note, when using `_default`, the config is applied to all modules, to apply it to only a subset of modules, e.g. with some types, it's better to filter the modules that we don't want to quantize before hand and configure them to @@ -2430,17 +2426,14 @@ def __post_init__(self): # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0: raise ValueError( - "Both module_fqn_to_config and fqn_to_config are non-empty, expected one to be" + "Both module_fqn_to_config and fqn_to_config are non-empty, expected one to be empty!" ) - if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) == 0: - self.fqn_to_config = self.module_fqn_to_config + if len(self.module_fqn_to_config) > 0: + assert len(self.fqn_to_config) == 0 warnings.warn( "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig" ) - elif len(self.fqn_to_config) > 0 and len(self.module_fqn_to_config) == 0: - self.module_fqn_to_config = self.fqn_to_config - else: - self.module_fqn_to_config = self.fqn_to_config + self.fqn_to_config = self.module_fqn_to_config # TODO we plan to deprecate `_default later, so raise a warning if we find it passed in` if "_default" in self.fqn_to_config: @@ -2507,13 +2500,13 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi def _get_config_for_fqn(fqn: str, config: FqnToConfig): """Helper function to get the config for a given fqn from an FqnToConfig object. - Args: - fqn (str): The fully qualified name to match against the config patterns. - config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs. - - Returns: - c (AOBaseConfig): If fqn is specified exactly in FqnToConfig, then fqn_to_config[fqn] will be returned. - Otherwise we will return the config of the first matching regex pattern in FqnToConfig. + Args: + fqn (str): The fully qualified name to match against the config patterns. + config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs. + torchao/quantization/quant_api.py + Returns: + c (AOBaseConfig): If fqn is specified exactly in FqnToConfig, then fqn_to_config[fqn] will be returned. 
+ Otherwise we will return the config of the first matching regex pattern in FqnToConfig. """ c = None if fqn in config.fqn_to_config: From d1805efe28a1af1f08913f284478450a194fec96 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 10:45:28 -0700 Subject: [PATCH 42/47] update docstring --- test/quantization/test_quant_api.py | 3 +-- torchao/quantization/quant_api.py | 3 +++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 509e2f3c85..7464c7e427 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -1061,8 +1061,6 @@ def forward(self, x): quant_config = FqnToConfig( { - # only this config should be applied, as module fqn takes precedence - # if we have both a linear and param at the same level, "_default": Float8DynamicActivationFloat8WeightConfig( granularity=PerRow(), ), @@ -1081,6 +1079,7 @@ def forward(self, x): assert model[0].linear.weight.scale.numel() == 128 assert isinstance(model[0].param, Float8Tensor) + assert model[0].param.scale.numel() == 1 def test_quantize_model_param_double_specified(self): model = ( diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 079316a1ff..b47ce46e07 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2405,6 +2405,8 @@ class FqnToConfig(AOBaseConfig): it to only a subset of modules, e.g. with some types, it's better to filter the modules that we don't want to quantize before hand and configure them to None, e.g. `{"re:.+norm.+": None, "_default": linear_config}`) + `module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: To maintain BC with ModuleFqnToConfig, to be deprecated later + `version`: int: Version of config to use. 
Note: - The order of patterns in the OrderedDict may matter as only the first matching pattern is applied @@ -2434,6 +2436,7 @@ def __post_init__(self): "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig" ) self.fqn_to_config = self.module_fqn_to_config + self.module_fqn_to_config = OrderedDict() # TODO we plan to deprecate `_default later, so raise a warning if we find it passed in` if "_default" in self.fqn_to_config: From f68f572024425efce6095541beb547939775224c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Thu, 9 Oct 2025 12:21:11 -0700 Subject: [PATCH 43/47] update tests --- test/quantization/test_quant_api.py | 143 ++++++++++++++++++++++------ torchao/quantization/quant_api.py | 59 +++++++----- 2 files changed, 148 insertions(+), 54 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 7464c7e427..9eb40e031c 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -932,6 +932,121 @@ def test_quantize_param_fqn_exact(self): assert isinstance(model.experts.gate_up_proj, Float8Tensor) + def test_non_specified_unaffected(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 128) + self.linear2 = torch.nn.Linear(128, 128) + + def forward(self, x): + return self.linear(x) + + model = TestModule().to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + "linear1.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 1 + + # ensure linear2 is not quantized + assert not isinstance(model.linear2.weight, Float8Tensor) + + def test_precedence_pattern_order(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 128) + self.linear2 = torch.nn.Linear(128, 128) + + def forward(self, x): + return self.linear(x) + + model = TestModule().to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "linear1.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "linear1": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 128 + + model = TestModule().to(torch.bfloat16).cuda() + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "linear1.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 128 + + model = TestModule().to(torch.bfloat16).cuda() + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + 
granularity=PerTensor() + ), + "re:linear.*.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "linear1.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 128 + + model = TestModule().to(torch.bfloat16).cuda() + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 128 + def test_quantize_param_and_module_fqn(self): from transformers import AutoConfig from transformers.models.llama4.modeling_llama4 import Llama4TextMoe @@ -1019,34 +1134,6 @@ def forward(self, x): assert isinstance(model.linear.weight, Float8Tensor) assert model.linear.weight.scale.numel() == 128 - def test_quantize_module_default_bc(self): - class TestModule(torch.nn.Module): - def __init__(self): - super().__init__() - self.linear = torch.nn.Linear(128, 128) - - def forward(self, x): - return self.linear(x) - - model = TestModule().to(torch.bfloat16).cuda() - - quant_config = FqnToConfig( - { - # only this config should be applied, as module fqn takes precedence - "_default": Float8DynamicActivationFloat8WeightConfig( - granularity=PerRow(), - ), - } - ) - - quantize_( - model, - quant_config, - ) - - assert isinstance(model.linear.weight, Float8Tensor) - assert model.linear.weight.scale.numel() == 128 - def test_quantize_module_default_param_quant(self): class TestModule(torch.nn.Module): def __init__(self): diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index b47ce46e07..064f3f0d46 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2426,17 +2426,15 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. - if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0: - raise ValueError( - "Both module_fqn_to_config and fqn_to_config are non-empty, expected one to be empty!" - ) if len(self.module_fqn_to_config) > 0: assert len(self.fqn_to_config) == 0 warnings.warn( "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig" ) self.fqn_to_config = self.module_fqn_to_config - self.module_fqn_to_config = OrderedDict() + if len(self.fqn_to_config) > 0: + assert len(self.module_fqn_to_config) == 0 + self.module_fqn_to_config = self.fqn_to_config # TODO we plan to deprecate `_default later, so raise a warning if we find it passed in` if "_default" in self.fqn_to_config: @@ -2471,24 +2469,31 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi NotImplementedError: If the quantization configuration is not yet supported for parameter quantization. 
""" # First we see if our module fqn matches with FqnToConfig, if so, we apply the appropriate transform - c = _get_config_for_fqn(fqn, config) - if c is not None: - handler = _QUANTIZE_CONFIG_HANDLER[type(c)] - return handler(module, c) + config_contains_match, c = _get_config_for_fqn(fqn, config) + if config_contains_match: + # special case to handle None in config + if c is not None: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c) + else: + return module # If no config is found, we see if any of our top-level parameter FQNs matches with FqnToConfig for parameter_name, param in list(module.named_parameters()): if parameter_name in dir(module): parameter_fqn = f"{fqn}.{parameter_name}" if fqn != "" else parameter_name - c = _get_config_for_fqn(parameter_fqn, config) - if c is not None: - if type(c) in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: - handler = _QUANTIZE_CONFIG_HANDLER[type(c)] - return handler(module, c, parameter_name=parameter_name) + config_contains_match, c = _get_config_for_fqn(parameter_fqn, config) + if config_contains_match: + if c is not None: + if type(c) in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c, parameter_name=parameter_name) + else: + raise NotImplementedError( + f"Parameter quantization for {type(c)} not supported currently!" + ) else: - raise NotImplementedError( - f"Parameter quantization for {type(c)} not supported currently!" - ) + return module # If no module_fqn or parameter_fqn matches, then we apply _default c = config.fqn_to_config.get("_default", None) @@ -2503,21 +2508,22 @@ def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfi def _get_config_for_fqn(fqn: str, config: FqnToConfig): """Helper function to get the config for a given fqn from an FqnToConfig object. - Args: - fqn (str): The fully qualified name to match against the config patterns. - config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs. - torchao/quantization/quant_api.py - Returns: - c (AOBaseConfig): If fqn is specified exactly in FqnToConfig, then fqn_to_config[fqn] will be returned. - Otherwise we will return the config of the first matching regex pattern in FqnToConfig. + Args: + fqn (str): The fully qualified name to match against the config patterns. + config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs. + + Returns: + c (AOBaseConfig): If fqn is specified exactly in FqnToConfig, then fqn_to_config[fqn] will be returned. + Otherwise we will return the config of the first matching regex pattern in FqnToConfig. """ - c = None + found, c = False, None if fqn in config.fqn_to_config: assert not fqn.startswith("re:"), ( f"Error: Exact match but regex {fqn} specified." 
) # Maybe: we can add module type specific config in the future, if needed c = config.fqn_to_config[fqn] + found = True else: for maybe_module_or_param_fqn_pattern in config.fqn_to_config: if not maybe_module_or_param_fqn_pattern.startswith("re:"): @@ -2525,8 +2531,9 @@ def _get_config_for_fqn(fqn: str, config: FqnToConfig): elif re.fullmatch(maybe_module_or_param_fqn_pattern[3:], fqn): # we'll apply the config for first fully matched pattern c = config.fqn_to_config[maybe_module_or_param_fqn_pattern] + found = True break - return c + return found, c def _select_module_if_filter_fn_or_contains_params_matching_pattern( From 05c9518644a591d0990c1ca75c0305d5c19bf84e Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 13 Oct 2025 10:56:47 -0700 Subject: [PATCH 44/47] update CR feedback --- torchao/quantization/quant_api.py | 51 ++++++++++++++++++------------- 1 file changed, 29 insertions(+), 22 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 064f3f0d46..c7f7d265ea 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -21,7 +21,6 @@ import warnings from collections import OrderedDict from dataclasses import dataclass, field -from functools import partial from typing import Any, Callable, List, Optional, Tuple, Union from typing import OrderedDict as OrderedDictType @@ -478,7 +477,7 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, config: AOBaseConfig, - filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = _is_linear, device: Optional[torch.types.Device] = None, ): """Convert the weight of linear modules in the model with `config`, model is modified inplace @@ -513,19 +512,19 @@ def quantize_( """ torch._C._log_api_usage_once("torchao.quantization.quantize_") - filter_fn = _is_linear if filter_fn is None else filter_fn if isinstance(config, FqnToConfig): - _replace_with_custom_fn_if_matches_filter_with_name( - model, - _fqn_to_config_handler, - partial( - _select_module_if_filter_fn_or_contains_params_matching_pattern, - config=config, - filter_fn=filter_fn, - ), - device=device, - extra_args=(config,), - ) + if filter_fn is None or filter_fn is _is_linear: + _replace_with_custom_fn_if_matches_filter_with_name( + model, + _fqn_to_config_handler, + _filter_fn_and_param_in_fqn_config, + device=device, + extra_args=(config,), + ) + else: + raise ValueError( + "Only filter_fn= `is_linear` or `None` is supported for FqnToConfig!" + ) return if isinstance(config, AOBaseConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] @@ -2426,15 +2425,17 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. + if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0: + raise ValueError( + "Both module_fqn_to_config and fqn_to_config are set. Only one should be set." 
+ ) if len(self.module_fqn_to_config) > 0: assert len(self.fqn_to_config) == 0 warnings.warn( "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig" ) self.fqn_to_config = self.module_fqn_to_config - if len(self.fqn_to_config) > 0: - assert len(self.module_fqn_to_config) == 0 - self.module_fqn_to_config = self.fqn_to_config + self.module_fqn_to_config = None # TODO we plan to deprecate `_default later, so raise a warning if we find it passed in` if "_default" in self.fqn_to_config: @@ -2536,11 +2537,10 @@ def _get_config_for_fqn(fqn: str, config: FqnToConfig): return found, c -def _select_module_if_filter_fn_or_contains_params_matching_pattern( +def _select_module_if_contains_params_matching_pattern( module: nn.Module, fqn: str, config: FqnToConfig, - filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, ): """Check if a module should be selected for quantization to be applied @@ -2549,14 +2549,11 @@ def _select_module_if_filter_fn_or_contains_params_matching_pattern( fqn (str): The fully qualified name of the module. config (FqnToConfig): Configuration object containing regex patterns or raw FQNs for parameter quantization. - filter_fn (Optional[Callable[[torch.nn.Module], bool]]): A function that takes a module and fqn and return whether to quantize the module. Returns: bool: True if filter_fn is passed and filter_fn(module, fqn) is True, or if any of the top-level parameters match the patterns in config.fqn_to_config False otherwise. """ - if filter_fn is not None and filter_fn(module, fqn): - return True for name, param in module.named_parameters(): if name in dir(module) and not isinstance(param, TorchAOBaseTensor): parameter_fqn = f"{fqn}.{name}" if fqn != "" else name @@ -2569,6 +2566,16 @@ def _select_module_if_filter_fn_or_contains_params_matching_pattern( return False +def _filter_fn_and_param_in_fqn_config(mod, fqn, config, filter_fn): + param_in_fqn_config = _select_module_if_contains_params_matching_pattern( + mod, fqn, config=config + ) + if filter_fn is None: + return param_in_fqn_config + else: + return filter_fn(mod, fqn) and param_in_fqn_config + + torch.serialization.add_safe_globals( [ _int8_asymm_per_token_quant, From addfc11772546dcc525ee03f8129f63a14acbd2c Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 13 Oct 2025 11:02:55 -0700 Subject: [PATCH 45/47] fix import --- torchao/quantization/quant_api.py | 1 - 1 file changed, 1 deletion(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 9c9ff13f3e..5cf96302a3 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -97,7 +97,6 @@ ) from torchao.utils import ( TorchAOBaseTensor, - _ConfigDeprecationWrapper, is_MI300, is_sm_at_least_89, is_sm_at_least_90, From 041e00b583fd0abd083cb4fe1e8c2fe77326b119 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 13 Oct 2025 11:36:45 -0700 Subject: [PATCH 46/47] fix parq --- test/prototype/test_parq.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 7d8716204c..119bfb8d05 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -588,11 +588,9 @@ def test_int8_dynamic_activation_intx_e2e( n for n, m in model.named_modules() if isinstance(m, nn.Embedding) } reg_param_names.add("_default") - module_fqn_to_config = ( - model.config.quantization_config.quant_type.module_fqn_to_config - ) - 
self.assertEqual(set(module_fqn_to_config.keys()), reg_param_names) - for torchao_config in module_fqn_to_config.values(): + fqn_to_config = model.config.quantization_config.quant_type.fqn_to_config + self.assertEqual(set(fqn_to_config.keys()), reg_param_names) + for torchao_config in fqn_to_config.values(): self.assertTrue(isinstance(torchao_config, config.__class__)) From ef20a863430d37679f1c769fcc1341360d5b3293 Mon Sep 17 00:00:00 2001 From: Jesse Cai Date: Mon, 13 Oct 2025 12:05:59 -0700 Subject: [PATCH 47/47] update --- torchao/quantization/quant_api.py | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 5cf96302a3..c4e2904130 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -2355,17 +2355,7 @@ def __post_init__(self): torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. - if len(self.module_fqn_to_config) > 0 and len(self.fqn_to_config) > 0: - raise ValueError( - "Both module_fqn_to_config and fqn_to_config are set. Only one should be set." - ) - if len(self.module_fqn_to_config) > 0: - assert len(self.fqn_to_config) == 0 - warnings.warn( - "Config Deprecation: ModuleFqnToConfig is deprecated and will no longer be supported in a future release, please use FqnToConfig" - ) - self.fqn_to_config = self.module_fqn_to_config - self.module_fqn_to_config = None + self.module_fqn_to_config = self.fqn_to_config # TODO we plan to deprecate `_default later, so raise a warning if we find it passed in` if "_default" in self.fqn_to_config:
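For readers following the API changes above, here is a minimal usage sketch of the FqnToConfig flow that this series converges on. It is not part of any commit in the series; the TwoLinear module and the FQN keys are illustrative, and the expected behavior mirrors what test_non_specified_unaffected asserts (an exact parameter-FQN key quantizes only that weight and leaves unlisted modules untouched). It assumes torchao built with these patches and a CUDA device supported by the float8 dynamic activation path (SM >= 8.9).

# Illustrative sketch, not part of the diffs above.
import torch
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization.quant_api import FqnToConfig, PerTensor, quantize_


class TwoLinear(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.linear1 = torch.nn.Linear(128, 128)
        self.linear2 = torch.nn.Linear(128, 128)

    def forward(self, x):
        return self.linear2(self.linear1(x))


model = TwoLinear().to(torch.bfloat16).cuda()

# Keys may be exact module or parameter FQNs, or "re:"-prefixed regexes; exact
# module FQNs are checked before parameter FQNs, with "_default" as the fallback.
config = FqnToConfig(
    {
        "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
            granularity=PerTensor()
        ),
    }
)
quantize_(model, config)

# linear1.weight is now a float8 tensor subclass with a single tensorwise scale;
# linear2.weight is untouched because its FQN matches no key in the config.
print(type(model.linear1.weight), model.linear1.weight.scale.numel())
print(type(model.linear2.weight))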