diff --git a/docs/source/api_ref_quantization.rst b/docs/source/api_ref_quantization.rst index c163a4b06a..d4a661c6d8 100644 --- a/docs/source/api_ref_quantization.rst +++ b/docs/source/api_ref_quantization.rst @@ -34,6 +34,7 @@ Inference APIs for quantize\_ Int8DynamicActivationInt8WeightConfig UIntXWeightOnlyConfig FPXWeightOnlyConfig + FqnToConfig .. currentmodule:: torchao.quantization diff --git a/docs/source/serving.rst b/docs/source/serving.rst index d95132ded7..58c435ecb0 100644 --- a/docs/source/serving.rst +++ b/docs/source/serving.rst @@ -175,7 +175,7 @@ Quantizing the model for mobile deployment using TorchAO's ``Int8DynamicActivati from torchao.quantization.quant_api import ( IntxWeightOnlyConfig, Int8DynamicActivationIntxWeightConfig, - ModuleFqnToConfig, + FqnToConfig, quantize_, ) from torchao.quantization.granularity import PerGroup, PerAxis @@ -198,7 +198,7 @@ Quantizing the model for mobile deployment using TorchAO's ``Int8DynamicActivati weight_granularity=PerGroup(32), weight_scale_dtype=torch.bfloat16, ) - quant_config = ModuleFqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config}) + quant_config = FqnToConfig({"_default": linear_config, "model.embed_tokens": embedding_config}) quantization_config = TorchAoConfig(quant_type=quant_config, include_embedding=True, untie_embedding_weights=True, modules_to_not_convert=[]) # either use `untied_model_id` or `untied_model_local_path` diff --git a/docs/source/torchao_vllm_integration.md b/docs/source/torchao_vllm_integration.md index 1ca027a124..ef3719cb9f 100644 --- a/docs/source/torchao_vllm_integration.md +++ b/docs/source/torchao_vllm_integration.md @@ -54,22 +54,21 @@ assert isinstance(config, AOBaseConfig) All quantization configurations inherit from {class}`torchao.core.config.AOBaseConfig`, which provides serialization and validation capabilities. ``` -(module-level-configuration)= -### 3. Module-Level Configuration +(fqn-configuration)= +### 3. 
FQN Configuration -For granular control, use `ModuleFqnToConfig`: +For granular control, use `FqnToConfig`: ```python -from torchao.quantization import ModuleFqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig +from torchao.quantization import FqnToConfig, Int4WeightOnlyConfig, Int8WeightOnlyConfig -config = ModuleFqnToConfig({ +config = FqnToConfig({ "model.layers.0.self_attn.q_proj": Int4WeightOnlyConfig(group_size=64), "model.layers.0.self_attn.k_proj": Int4WeightOnlyConfig(group_size=64), "model.layers.0.mlp.gate_proj": Int8WeightOnlyConfig(), "_default": Int4WeightOnlyConfig(group_size=128, version=1) # Default for other modules }) ``` - (usage-examples)= ## Usage Examples diff --git a/test/prototype/test_parq.py b/test/prototype/test_parq.py index 7d8716204c..119bfb8d05 100644 --- a/test/prototype/test_parq.py +++ b/test/prototype/test_parq.py @@ -588,11 +588,9 @@ def test_int8_dynamic_activation_intx_e2e( n for n, m in model.named_modules() if isinstance(m, nn.Embedding) } reg_param_names.add("_default") - module_fqn_to_config = ( - model.config.quantization_config.quant_type.module_fqn_to_config - ) - self.assertEqual(set(module_fqn_to_config.keys()), reg_param_names) - for torchao_config in module_fqn_to_config.values(): + fqn_to_config = model.config.quantization_config.quant_type.fqn_to_config + self.assertEqual(set(fqn_to_config.keys()), reg_param_names) + for torchao_config in fqn_to_config.values(): self.assertTrue(isinstance(torchao_config, config.__class__)) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 2b3538195e..ef377567a0 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -33,6 +33,7 @@ TensorCoreTiledLayout, ) from torchao.quantization import ( + Float8Tensor, Int4TilePackedTo4dTensor, IntxUnpackedToInt8Tensor, LinearActivationQuantizedTensor, @@ -43,6 +44,7 @@ Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, + FqnToConfig, GemliteUIntXWeightOnlyConfig, Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, @@ -52,6 +54,8 @@ Int8WeightOnlyConfig, IntxWeightOnlyConfig, ModuleFqnToConfig, + PerRow, + PerTensor, Quantizer, TwoStepQuantizer, UIntXWeightOnlyConfig, @@ -850,5 +854,328 @@ def test_int4wo_cuda_serialization(self): common_utils.instantiate_parametrized_tests(TestQuantFlow) +@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") +@unittest.skipIf(not is_sm_at_least_90(), "Checkpoints are produced in SM90+") +class TestFqnToConfig(TestCase): + def test_quantize_param_fqn_exact(self): + from transformers import AutoConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model.experts.gate_up_proj, Float8Tensor) + + def test_non_specified_unaffected(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 128) + self.linear2 = torch.nn.Linear(128, 128) + + def forward(self, x): + return self.linear(x) + + model = TestModule().to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + "linear1.weight": Float8DynamicActivationFloat8WeightConfig( + 
granularity=PerTensor() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 1 + + # ensure linear2 is not quantized + assert not isinstance(model.linear2.weight, Float8Tensor) + + def test_precedence_pattern_order(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear1 = torch.nn.Linear(128, 128) + self.linear2 = torch.nn.Linear(128, 128) + + def forward(self, x): + return self.linear(x) + + model = TestModule().to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "linear1.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "linear1": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 128 + + model = TestModule().to(torch.bfloat16).cuda() + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "linear1.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 128 + + model = TestModule().to(torch.bfloat16).cuda() + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "linear1.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 128 + + model = TestModule().to(torch.bfloat16).cuda() + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor() + ), + "re:linear.*.weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow() + ), + } + ) + quantize_(model, quant_config) + assert isinstance(model.linear1.weight, Float8Tensor) + assert model.linear1.weight.scale.numel() == 128 + + def test_quantize_param_and_module_fqn(self): + from transformers import AutoConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + quant_config = FqnToConfig( + { + "experts.gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model.experts.gate_up_proj, Float8Tensor) + assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) + assert model.shared_expert.gate_proj.weight.scale.numel() == 
1 + + def test_quantize_param_and_module_fqn_regex(self): + from transformers import AutoConfig + from transformers.models.llama4.modeling_llama4 import Llama4TextMoe + + config = AutoConfig.from_pretrained( + "unsloth/Llama-4-Scout-17B-16E-Instruct" + ).text_config + model = Llama4TextMoe(config).to(torch.bfloat16).cuda() + quant_config = FqnToConfig( + { + "re:.*gate_up_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "shared_expert.gate_proj": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model.experts.gate_up_proj, Float8Tensor) + assert isinstance(model.shared_expert.gate_proj.weight, Float8Tensor) + assert model.shared_expert.gate_proj.weight.scale.numel() == 1 + + def test_quantize_modle_exact_match_preference(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 128) + + def forward(self, x): + return self.linear(x) + + model = TestModule().to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + # only this config should be applied, as module fqn takes precedence + "linear": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "re:linear": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model.linear.weight, Float8Tensor) + assert model.linear.weight.scale.numel() == 128 + + def test_quantize_module_default_param_quant(self): + class TestModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.linear = torch.nn.Linear(128, 128) + self.param = torch.nn.Parameter(torch.randn(128, 128)) + + def forward(self, x): + return self.linear(x) + + model = torch.nn.Sequential(TestModule()).to(torch.bfloat16).cuda() + + quant_config = FqnToConfig( + { + "_default": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + "0.param": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model[0].linear.weight, Float8Tensor) + assert model[0].linear.weight.scale.numel() == 128 + + assert isinstance(model[0].param, Float8Tensor) + assert model[0].param.scale.numel() == 1 + + def test_quantize_model_param_double_specified(self): + model = ( + torch.nn.Sequential( + torch.nn.Linear(128, 128), + torch.nn.Linear(128, 128), + ) + .to(torch.bfloat16) + .cuda() + ) + quant_config = FqnToConfig( + { + "re:.*weight": Float8DynamicActivationFloat8WeightConfig( + granularity=PerTensor(), + ), + # only this config should be applied, as module fqn takes precedence + "0": Float8DynamicActivationFloat8WeightConfig( + granularity=PerRow(), + ), + } + ) + + quantize_( + model, + quant_config, + ) + + assert isinstance(model[0].weight, Float8Tensor) + assert model[0].weight.scale.numel() == 128 + + def test_unsupported_param_config_raises_not_implemented_error(self): + """Test that using an unsupported parameter config raises NotImplementedError.""" + from dataclasses import dataclass + + from torchao.core.config import AOBaseConfig + + # Create a custom config that doesn't have a registered parameter handler + @dataclass + class UnsupportedParamConfig(AOBaseConfig): + some_value: int = 42 + + # Create a simple model + model = torch.nn.Sequential(torch.nn.Linear(10, 5).cuda().bfloat16()) + + # Create config with unsupported parameter handler + quant_config = 
FqnToConfig( + { + "0.weight": UnsupportedParamConfig(), + } + ) + + # This should raise NotImplementedError + with self.assertRaises(NotImplementedError) as context: + quantize_(model, quant_config) + + # Check that the error message contains the expected text + self.assertIn("Parameter quantization for", str(context.exception)) + self.assertIn("not supported currently", str(context.exception)) + self.assertIn("UnsupportedParamConfig", str(context.exception)) + + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py index aa19aa1890..106328721f 100644 --- a/torchao/quantization/__init__.py +++ b/torchao/quantization/__init__.py @@ -52,6 +52,7 @@ Float8StaticActivationFloat8WeightConfig, Float8WeightOnlyConfig, FPXWeightOnlyConfig, + FqnToConfig, GemliteUIntXWeightOnlyConfig, Int4DynamicActivationInt4WeightConfig, Int4WeightOnlyConfig, @@ -137,6 +138,7 @@ "FPXWeightOnlyConfig", "GemliteUIntXWeightOnlyConfig", "AOPerModuleConfig", + "FqnToConfig", "ModuleFqnToConfig", # tensor subclasses "Int4Tensor", diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py index 139b14cf3f..c4e2904130 100644 --- a/torchao/quantization/quant_api.py +++ b/torchao/quantization/quant_api.py @@ -96,42 +96,28 @@ to_weight_tensor_with_linear_activation_quantization_metadata, ) from torchao.utils import ( + TorchAOBaseTensor, is_MI300, is_sm_at_least_89, is_sm_at_least_90, ) from .autoquant import AutoQuantizableLinearWeight, autoquant -from .GPTQ import ( - Int4WeightOnlyGPTQQuantizer, -) -from .granularity import ( - Granularity, - PerAxis, - PerGroup, - PerRow, - PerTensor, -) +from .GPTQ import Int4WeightOnlyGPTQQuantizer +from .granularity import Granularity, PerAxis, PerGroup, PerRow, PerTensor from .linear_activation_quantized_tensor import ( LinearActivationQuantizedTensor, to_linear_activation_quantized, ) -from .linear_quant_modules import ( - Int4WeightOnlyQuantizer, - Int8DynActInt4WeightQuantizer, -) -from .qat import ( - intx_quantization_aware_training, -) +from .linear_quant_modules import Int4WeightOnlyQuantizer, Int8DynActInt4WeightQuantizer +from .qat import intx_quantization_aware_training from .quant_primitives import ( _DTYPE_TO_QVALUE_BOUNDS, MappingType, ZeroPointDomain, quantize_affine, ) -from .subclass import ( - QuantizedLinearWeightBase, -) +from .subclass import QuantizedLinearWeightBase from .unified import Quantizer, TwoStepQuantizer from .utils import _get_per_token_block_size @@ -235,23 +221,22 @@ def _replace_with_custom_fn_if_matches_filter_with_name( if device is not None: model.to(device=device) # move to device before quantization model = replacement_fn(model, cur_fqn[:-1], *extra_args) - return model - else: - named_children_list = list(model.named_children()) - for name, child in named_children_list: - new_child = _replace_with_custom_fn_if_matches_filter_with_name( - child, - replacement_fn, - filter_fn, - f"{cur_fqn}{name}.", - device, - extra_args, - ) - if new_child is not child: - setattr(model, name, new_child) - if device is not None: - model.to(device=device) # move parent module to device - return model + # For parameter quantization, filter_fn(model, cur_fqn) no longer is terminal, as a module may contain both a parameter we want to quantize and subsequent submodules. 
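+    # (Illustrative case, taken from the tests in this diff: a wrapper module that owns a raw
+    # nn.Parameter targeted by an fqn key, e.g. "0.param", and that also has child submodules
+    # with their own matching keys, needs both the transform above and the recursion below.)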
+ named_children_list = list(model.named_children()) + for name, child in named_children_list: + new_child = _replace_with_custom_fn_if_matches_filter_with_name( + child, + replacement_fn, + filter_fn, + f"{cur_fqn}{name}.", + device, + extra_args, + ) + if new_child is not child: + setattr(model, name, new_child) + if device is not None: + model.to(device=device) # move parent module to device + return model def _is_linear(mod, *args): @@ -468,7 +453,7 @@ def insert_subclass(lin): def quantize_( model: torch.nn.Module, config: AOBaseConfig, - filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None, + filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = _is_linear, device: Optional[torch.types.Device] = None, ): """Convert the weight of linear modules in the model with `config`, model is modified inplace @@ -503,17 +488,20 @@ def quantize_( """ torch._C._log_api_usage_once("torchao.quantization.quantize_") - filter_fn = _is_linear if filter_fn is None else filter_fn - if isinstance(config, ModuleFqnToConfig): - _replace_with_custom_fn_if_matches_filter_with_name( - model, - _module_fqn_to_config_handler, - filter_fn, - device=device, - extra_args=(config,), - ) + if isinstance(config, FqnToConfig): + if filter_fn is None or filter_fn is _is_linear: + _replace_with_custom_fn_if_matches_filter_with_name( + model, + _fqn_to_config_handler, + _filter_fn_and_param_in_fqn_config, + device=device, + extra_args=(config,), + ) + else: + raise ValueError( + "Only filter_fn= `is_linear` or `None` is supported for FqnToConfig!" + ) return - if isinstance(config, AOBaseConfig): handler = _QUANTIZE_CONFIG_HANDLER[type(config)] # for each linear in the model, apply the transform if filtering passes @@ -524,7 +512,6 @@ def quantize_( device=device, extra_args=(config,), ) - else: raise AssertionError( """Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. 
Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead.""" @@ -1819,7 +1806,10 @@ def _float8_dynamic_activation_float8_weight_quantize_tensor(weight, config): @register_quantize_module_handler(Float8DynamicActivationFloat8WeightConfig) def _float8_dynamic_activation_float8_weight_transform( - module: torch.nn.Module, config: Float8DynamicActivationFloat8WeightConfig + module: torch.nn.Module, + config: Float8DynamicActivationFloat8WeightConfig, + *, + parameter_name: str = "weight", ): assert is_sm_at_least_89() or is_MI300(), ( "Float8 dynamic activation quantization is only supported on CUDA>=8.9 and MI300+" @@ -1827,17 +1817,20 @@ def _float8_dynamic_activation_float8_weight_transform( if config.set_inductor_config: torchao.quantization.utils.recommended_inductor_config_setter() - assert hasattr(module, "weight"), ( - "applying float8 dynamic activation quant requires module to have weight attribute" + assert hasattr(module, parameter_name), ( + "applying float8 dynamic activation quant requires module to have parameter {parameter_name} attribute" + f"but {module} does not have one" ) if isinstance(module, Float8Linear): module = _unwrap_float8_linear(module) - - quantized_weight = _float8_dynamic_activation_float8_weight_quantize_tensor( - module.weight, config + quantized_tensor = _float8_dynamic_activation_float8_weight_quantize_tensor( + getattr(module, parameter_name), config + ) + setattr( + module, + parameter_name, + torch.nn.Parameter(quantized_tensor, requires_grad=False), ) - module.weight = torch.nn.Parameter(quantized_weight, requires_grad=False) module.extra_repr = types.MethodType(_linear_extra_repr, module) return module @@ -2316,64 +2309,193 @@ def _fpx_weight_only_transform( @dataclass -class ModuleFqnToConfig(AOBaseConfig): - """Per module configurations for torchao quantize_ API +class FqnToConfig(AOBaseConfig): + """Configuration class for applying different quantization configs to modules or parameters based on their fully qualified names (FQNs). Args: - `module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an + `fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: an ordered dictionary from - (1). fully qualified name (fqn) of module or + (1). fully qualified name (fqn) of module or parameter (2). regex of fully qualified name (in python `re` module regex format), should start with prefix "re:" or (3). "_default" - to the config that we want to apply to the module or None + to the config that we want to apply to the module/param or None Config key ordered by precedence: * fully qualified module name, e.g. `language.layers.0.q_proj` * regex for module names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj`, - whiever regex fully matches the module fqn first will be applied + whichever regex fully matches the module fqn first will be applied (order of keys for dictionary are kept consistent since we are using OrderedDict) - * "_default", fallback for **all modules** if no match for all previous keys + * fully qualified parameter name, e.g. `language.layers.0.q_proj.weight` + * regex for parameter names, must start with `re:`, e.g. `re:language\.layers\..+\.q_proj.weight`. + The first regex that matches will be applied. + * "_default", fallback if no match for all previous keys (Note, when using `_default`, the config is applied to all modules, to apply it to only a subset of modules, e.g. 
with some types, it's better to filter the modules that we don't want to quantize before hand and configure them to None, e.g. `{"re:.+norm.+": None, "_default": linear_config}`) + `module_fqn_to_config`: typing.OrderedDict[str, Optional[AOBaseConfig]]: To maintain BC with ModuleFqnToConfig, to be deprecated later + `version`: int: Version of config to use. + + Note: + - The order of patterns in the OrderedDict may matter as only the first matching pattern is applied + - "_default" is ignored for parameter replacement. """ + fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( + default_factory=OrderedDict + ) + # to maintain BC, we keep the previous module_fqn_to_config field module_fqn_to_config: OrderedDictType[str, Optional[AOBaseConfig]] = field( default_factory=OrderedDict ) + version: int = 1 def __post_init__(self): - torch._C._log_api_usage_once("torchao.quantization.ModuleFqnToConfig") + torch._C._log_api_usage_once("torchao.quantization.FqnToConfig") + # This code handles BC compatibility with `ModuleFqnToConfig`. It ensures that `self.module_fqn_to_config` and `self.fqn_to_config` share the same object. + self.module_fqn_to_config = self.fqn_to_config + + # TODO we plan to deprecate `_default later, so raise a warning if we find it passed in` + if "_default" in self.fqn_to_config: + warnings.warn( + "Config Deprecation: _default is deprecated and will no longer be supported in a future release." + ) + + +# maintain BC +ModuleFqnToConfig = FqnToConfig + +# for now, we need to keep track of what configs support custom param quantization. +# Once we've updated all the transform functions to take in a custom_param kwarg, we can delete this object and the subsequent check +CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS = { + Float8DynamicActivationFloat8WeightConfig, +} -def _module_fqn_to_config_handler( - module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig -): - c = None - if module_fqn in config.module_fqn_to_config: - # Maybe: we can add module type specific config in the future, in needed - c = config.module_fqn_to_config[module_fqn] - else: - for maybe_module_fqn_pattern in config.module_fqn_to_config: - if not maybe_module_fqn_pattern.startswith("re:"): - continue - elif re.fullmatch(maybe_module_fqn_pattern[3:], module_fqn): - # we'll apply the config for first fully matched pattern - c = config.module_fqn_to_config[maybe_module_fqn_pattern] - break - else: - # fallback to use default if no module specific config is provided - c = config.module_fqn_to_config.get("_default", None) +def _fqn_to_config_handler(module: torch.nn.Module, fqn: str, config: FqnToConfig): + """This function expects a module that either is specified in FqnToConfig or has a parameter that is specified in FqnToConfig. + + Args: + module (torch.nn.Module): The module to be processed. + fqn (str): The fully qualified name of the module containing the parameters. + config (FqnToConfig): Configuration object containing regex patterns / fqn mapped + to quantization configurations. + + Returns: + torch.nn.Module: The modified module with quantized parameters. + + Raises: + NotImplementedError: If the quantization configuration is not yet supported for parameter quantization. 
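+
+    Example (illustrative sketch only; ``linear1`` is a hypothetical fqn)::
+
+        config = FqnToConfig({
+            "linear1.weight": Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerRow(),
+            ),
+        })
+        # quantize_ routes each selected module through this handler; a direct
+        # call for the module registered under "linear1" would look like:
+        module = _fqn_to_config_handler(module, "linear1", config)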
+ """ + # First we see if our module fqn matches with FqnToConfig, if so, we apply the appropriate transform + config_contains_match, c = _get_config_for_fqn(fqn, config) + if config_contains_match: + # special case to handle None in config + if c is not None: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c) + else: + return module + + # If no config is found, we see if any of our top-level parameter FQNs matches with FqnToConfig + for parameter_name, param in list(module.named_parameters()): + if parameter_name in dir(module): + parameter_fqn = f"{fqn}.{parameter_name}" if fqn != "" else parameter_name + config_contains_match, c = _get_config_for_fqn(parameter_fqn, config) + if config_contains_match: + if c is not None: + if type(c) in CUSTOM_PARAM_QUANTIZATION_SUPPOTED_CONFIGS: + handler = _QUANTIZE_CONFIG_HANDLER[type(c)] + return handler(module, c, parameter_name=parameter_name) + else: + raise NotImplementedError( + f"Parameter quantization for {type(c)} not supported currently!" + ) + else: + return module + + # If no module_fqn or parameter_fqn matches, then we apply _default + c = config.fqn_to_config.get("_default", None) if c is not None: handler = _QUANTIZE_CONFIG_HANDLER[type(c)] return handler(module, c) + # Else return unmodified module return module +def _get_config_for_fqn(fqn: str, config: FqnToConfig): + """Helper function to get the config for a given fqn from an FqnToConfig object. + + Args: + fqn (str): The fully qualified name to match against the config patterns. + config (FqnToConfig): The FqnToConfig object containing mapping of FQNs or regex patterns to quantization configs. + + Returns: + c (AOBaseConfig): If fqn is specified exactly in FqnToConfig, then fqn_to_config[fqn] will be returned. + Otherwise we will return the config of the first matching regex pattern in FqnToConfig. + """ + found, c = False, None + if fqn in config.fqn_to_config: + assert not fqn.startswith("re:"), ( + f"Error: Exact match but regex {fqn} specified." + ) + # Maybe: we can add module type specific config in the future, if needed + c = config.fqn_to_config[fqn] + found = True + else: + for maybe_module_or_param_fqn_pattern in config.fqn_to_config: + if not maybe_module_or_param_fqn_pattern.startswith("re:"): + continue + elif re.fullmatch(maybe_module_or_param_fqn_pattern[3:], fqn): + # we'll apply the config for first fully matched pattern + c = config.fqn_to_config[maybe_module_or_param_fqn_pattern] + found = True + break + return found, c + + +def _select_module_if_contains_params_matching_pattern( + module: nn.Module, + fqn: str, + config: FqnToConfig, +): + """Check if a module should be selected for quantization to be applied + + Args: + module (torch.nn.Module): The module to check for parameter pattern matches. + fqn (str): The fully qualified name of the module. + config (FqnToConfig): Configuration object containing regex patterns or raw FQNs for + parameter quantization. + + Returns: + bool: True if filter_fn is passed and filter_fn(module, fqn) is True, or if any of the top-level parameters match the patterns in config.fqn_to_config + False otherwise. 
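+
+    Example (illustrative; assumes a Linear module registered under fqn "linear1")::
+
+        config = FqnToConfig({
+            "re:.*weight": Float8DynamicActivationFloat8WeightConfig(
+                granularity=PerRow(),
+            ),
+        })
+        # "linear1.weight" fully matches the regex, so the module is selected
+        _select_module_if_contains_params_matching_pattern(module, "linear1", config)  # True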
+ """ + for name, param in module.named_parameters(): + if name in dir(module) and not isinstance(param, TorchAOBaseTensor): + parameter_fqn = f"{fqn}.{name}" if fqn != "" else name + for pattern in config.fqn_to_config: + if (pattern == parameter_fqn) or ( + pattern.startswith("re:") + and re.fullmatch(pattern[3:], parameter_fqn) + ): + return True + return False + + +def _filter_fn_and_param_in_fqn_config(mod, fqn, config, filter_fn): + param_in_fqn_config = _select_module_if_contains_params_matching_pattern( + mod, fqn, config=config + ) + if filter_fn is None: + return param_in_fqn_config + else: + return filter_fn(mod, fqn) and param_in_fqn_config + + def _unwrap_float8_linear(module: Float8Linear) -> nn.Linear: """ Unwrap a torchao Float8Linear by returning a nn.Linear with the same weights and bias.