From 131f2cbb9e311f94ef8833dcaffbdb83bc9df895 Mon Sep 17 00:00:00 2001 From: xin3he Date: Wed, 19 Jun 2024 10:54:44 +0800 Subject: [PATCH 01/14] start from RTNConfig Signed-off-by: xin3he --- neural_compressor/torch/quantization/config.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 27a056d3284..2cd86986fdc 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -104,6 +104,8 @@ class RTNConfig(BaseConfig): "double_quant_bits", "double_quant_use_sym", "double_quant_group_size", + # quant_lm_head + "quant_lm_head", ] supported_configs: List[OperatorConfig] = [] @@ -125,6 +127,8 @@ def __init__( double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' double_quant_use_sym: bool = False, double_quant_group_size: int = 256, + # double quant + quant_lm_head: bool = False, # Tuning space white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): @@ -145,6 +149,7 @@ def __init__( double_quant_bits (int): Number of bits used to represent double_quant scale. Default is 4. double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric. Default is True. double_quant_group_size (int): Size of double_quant groups. Default is 32. + quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False. """ super().__init__(white_list=white_list) self.dtype = dtype @@ -162,7 +167,12 @@ def __init__( self.double_quant_dtype = double_quant_dtype self.double_quant_use_sym = double_quant_use_sym self.double_quant_group_size = double_quant_group_size + self.quant_lm_head = quant_lm_head self._post_init() # initialize global & local configuration + if not self.quant_lm_head: + # use .* for re.match + usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] + self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32")) @classmethod def register_supported_configs(cls) -> List[OperatorConfig]: @@ -193,8 +203,9 @@ def register_supported_configs(cls) -> List[OperatorConfig]: double_quant_dtype=["int"], double_quant_use_sym=[True, False], double_quant_group_size=[32, -1, 1, 4, 8, 16, 64, 128, 256, 512, 1024], + quant_lm_head=[False, True], ) - operators = [torch.nn.Linear] + operators = list(WOQ_WHITE_LIST) supported_configs.append(OperatorConfig(config=linear_rtn_config, operators=operators)) cls.supported_configs = supported_configs From b4612910469a60ae0accf13c4677c6c4c2041a55 Mon Sep 17 00:00:00 2001 From: xin3he Date: Wed, 19 Jun 2024 15:54:53 +0800 Subject: [PATCH 02/14] fix bug Signed-off-by: xin3he --- neural_compressor/common/base_config.py | 14 ++++++++++---- neural_compressor/torch/quantization/config.py | 13 +++++++++---- neural_compressor/torch/utils/utility.py | 1 - 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 3f65a2ea9c0..54e7bf53e71 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -198,10 +198,16 @@ def local_config(self): def local_config(self, config): self._local_config = config - def set_local(self, operator_name: Union[str, Callable], config: BaseConfig) -> BaseConfig: - if operator_name in self.local_config: - logger.warning("The configuration for %s has already been set, update it.", operator_name) - self.local_config[operator_name] = config + def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig: + if hasattr(operator_name_or_list, "__iter__"): + for operator_name in operator_name_or_list: + if operator_name in self.local_config: + logger.warning("The configuration for %s has already been set, update it.", operator_name) + self.local_config[operator_name] = config + else: + if operator_name_or_list in self.local_config: + logger.warning("The configuration for %s has already been set, update it.", operator_name) + self.local_config[operator_name_or_list] = config return self def to_dict(self): diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 2cd86986fdc..07d15c20603 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -169,10 +169,6 @@ def __init__( self.double_quant_group_size = double_quant_group_size self.quant_lm_head = quant_lm_head self._post_init() # initialize global & local configuration - if not self.quant_lm_head: - # use .* for re.match - usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] - self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32")) @classmethod def register_supported_configs(cls) -> List[OperatorConfig]: @@ -209,6 +205,15 @@ def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs.append(OperatorConfig(config=linear_rtn_config, operators=operators)) cls.supported_configs = supported_configs + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: + if not self.quant_lm_head: + usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] + self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32")) + config_mapping = super().to_config_mapping(config_list, model_info) + return config_mapping + @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: filter_result = [] diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index e1c869dca45..2b40317ec38 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -225,7 +225,6 @@ def dump_model_op_stats(mode, tune_cfg): for op, config in tune_cfg.items(): op_type = op[1] config = config.to_dict() - # import pdb; pdb.set_trace() if not config["dtype"] == "fp32": num_bits = config["bits"] group_size = config["group_size"] From 0ae3c7b119a6d3d4fbee4fdb4477e3eaa074388a Mon Sep 17 00:00:00 2001 From: xin3he Date: Wed, 19 Jun 2024 17:10:56 +0800 Subject: [PATCH 03/14] add UTs Signed-off-by: xin3he --- .../torch/algorithms/weight_only/rtn.py | 21 ++++++++++++-- .../torch/quantization/algorithm_entry.py | 2 +- neural_compressor/torch/utils/utility.py | 4 +++ .../quantization/weight_only/test_rtn.py | 28 +++++++++++++++++++ 4 files changed, 52 insertions(+), 3 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/rtn.py b/neural_compressor/torch/algorithms/weight_only/rtn.py index fc083191ffe..ca6f3e499e5 100644 --- a/neural_compressor/torch/algorithms/weight_only/rtn.py +++ b/neural_compressor/torch/algorithms/weight_only/rtn.py @@ -19,12 +19,20 @@ # limitations under the License. +import copy from collections import OrderedDict import torch from neural_compressor.torch.algorithms import Quantizer -from neural_compressor.torch.utils import get_accelerator, is_transformers_imported, logger, set_module +from neural_compressor.torch.utils import ( + get_accelerator, + get_attr, + is_transformers_imported, + logger, + set_attr, + set_module, +) from .utility import cast_fp8, quant_tensor, search_clip @@ -64,6 +72,7 @@ def convert( quantile=1.0, use_full_range=False, use_mse_search=False, + quant_lm_head=False, *args, **kwargs, ): @@ -80,8 +89,10 @@ def convert( quantile (float, optional): percentile of clip. Defaults to 1.0. use_full_range (bool, optional): Choose sym range whether use -2**(bits-1). Defaults to False. - use_mse_search (bool, optional): Whether search clip range. + use_mse_search (bool, optional): Whether to search clip range. Defaults to True. + quant_lm_head (bool, optional): Whether to quantize the lm_head layer. + Defaults to False. Returns: model: fake quantized torch module @@ -93,6 +104,12 @@ def convert( # TODO: refine it later, Put module on device one by one instead of the whole model model.to(device) + # for transformers model. If lm_head is tied from embedding, we deepcopy it. + if quant_lm_head and getattr(getattr(model, "config", None), "tie_word_embeddings", False): + for key in model._tied_weights_keys: + weight = get_attr(model, key) + set_attr(model, key, copy.deepcopy(weight)) + assert isinstance(model, torch.nn.Module), "only support torch module" if is_transformers_imported(): supported_layers = (torch.nn.Linear, transformers.Conv1D) diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index 733e4409b91..856961af532 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -92,7 +92,7 @@ def rtn_entry( } quantizer = get_quantizer(model, quantizer_cls=RTNQuantizer, quant_config=weight_config) - model = quantizer.execute(model, mode=mode) + model = quantizer.execute(model, mode=mode, quant_lm_head=quant_config.quant_lm_head) model.qconfig = configs_mapping model.save = MethodType(save, model) postprocess_model(model, mode, quantizer) diff --git a/neural_compressor/torch/utils/utility.py b/neural_compressor/torch/utils/utility.py index 2b40317ec38..5c764cff7d3 100644 --- a/neural_compressor/torch/utils/utility.py +++ b/neural_compressor/torch/utils/utility.py @@ -102,6 +102,10 @@ def set_module(model, op_name, new_module): setattr(second_last_module, name_list[-1], new_module) +get_attr = fetch_module +set_attr = set_module + + def get_model_info(model: torch.nn.Module, white_module_list: List[Callable]) -> List[Tuple[str, str]]: module_dict = dict(model.named_modules()) filter_result = [] diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index f82185cc82e..eed2c2b9205 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -138,6 +138,34 @@ def test_mse_search(self): except: assert torch.allclose(atol_false, atol_true, atol=0.012), "atol is very close, double checked the logic." + def test_quant_lm_head(self): + # tie_word_embeddings=false + gptj_model = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + device_map=device, + ) + lm_head_id = id(gptj_model.lm_head.weight) + assert id(gptj_model.transformer.wte.weight) != lm_head_id, "The lm_head weight is tied, please check!" + quant_config = RTNConfig(quant_lm_head=True) + model = prepare(gptj_model, quant_config) + model = convert(model) + + # tie_word_embeddings=true + opt_model = transformers.AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-random-OPTForCausalLM", + device_map=device, + ) + lm_head_id = id(opt_model.lm_head.weight) + assert ( + id(opt_model.model.decoder.embed_tokens.weight) == lm_head_id + ), "The lm_head weight is not tied, please check!" + quant_config = RTNConfig(quant_lm_head=True) + model = prepare(opt_model, quant_config) + model = convert(model) + assert ( + id(model.model.decoder.embed_tokens.weight) == lm_head_id + ), "The tied lm_head weight is not deep copied, please check!" + def test_layer_wise(self): model = copy.deepcopy(self.tiny_gptj) quant_config = RTNConfig( From 39f649ce8dc139d60c933e357a23deb795b2be1c Mon Sep 17 00:00:00 2001 From: xin3he Date: Thu, 20 Jun 2024 17:17:01 +0800 Subject: [PATCH 04/14] update UT Signed-off-by: xin3he --- .../torch/quantization/config.py | 54 ++++++++++++-- .../weight_only/test_hqq_quantizer.py} | 70 +++++++++++++++---- .../weight_only/hqq/test_hqq_config.py | 44 ------------ .../weight_only/hqq/test_packer.py | 16 ----- .../quantization/weight_only/test_awq.py | 54 +++++++++----- 5 files changed, 144 insertions(+), 94 deletions(-) rename test/3x/torch/{quantization/weight_only/hqq/test_q_tensor.py => algorithms/weight_only/test_hqq_quantizer.py} (51%) delete mode 100644 test/3x/torch/quantization/weight_only/hqq/test_hqq_config.py delete mode 100644 test/3x/torch/quantization/weight_only/hqq/test_packer.py diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 07d15c20603..01a72c658f7 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -127,7 +127,7 @@ def __init__( double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' double_quant_use_sym: bool = False, double_quant_group_size: int = 256, - # double quant + # quant lm_head quant_lm_head: bool = False, # Tuning space white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, @@ -272,6 +272,8 @@ class GPTQConfig(BaseConfig): # layer wise params "use_layer_wise", "model_path", + # quant lm_head + "quant_lm_head", # gptq params "act_order", "percdamp", @@ -295,6 +297,8 @@ def __init__( double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' double_quant_use_sym: bool = False, double_quant_group_size: int = 256, + # double quant + quant_lm_head: bool = False, # gptq params act_order: bool = False, percdamp: float = 0.01, @@ -318,6 +322,7 @@ def __init__( double_quant_bits (int): Number of bits used to represent double_quant scale. Default is 4. double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric. Default is True. double_quant_group_size (int): Size of double_quant groups. Default is 32. + quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False. act_order (bool): Whether to sort Hessian's diagonal values to rearrange channel-wise quantization order. Default is False. percdamp (float): Percentage of Hessian's diagonal values' average, which will be added to @@ -328,6 +333,7 @@ def __init__( This option mitigate actorder's extra computational requirements. Default is False. """ + assert not quant_lm_head, "GPTQ doesn't support lm_head quantization currently, it's coming soon!" super().__init__(white_list=white_list) self.dtype = dtype self.bits = bits @@ -348,6 +354,7 @@ def __init__( self.percdamp = percdamp self.block_size = block_size self.static_groups = static_groups + self.quant_lm_head = quant_lm_head self._post_init() # initialize global & local configuration @classmethod @@ -355,10 +362,19 @@ def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs = [] # TODO(Yi) linear_gptq_config = GPTQConfig() - operators = [torch.nn.Linear] + operators = list(WOQ_WHITE_LIST) supported_configs.append(OperatorConfig(config=linear_gptq_config, operators=operators)) cls.supported_configs = supported_configs + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: + if not self.quant_lm_head: + usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] + self.set_local(usual_lm_head_names, GPTQConfig(dtype="fp32")) + config_mapping = super().to_config_mapping(config_list, model_info) + return config_mapping + @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: filter_result = [] @@ -408,6 +424,8 @@ class AWQConfig(BaseConfig): "double_quant_bits", "double_quant_use_sym", "double_quant_group_size", + # quant_lm_head + "quant_lm_head", # AWQ params "use_auto_scale", "use_auto_clip", @@ -431,6 +449,8 @@ def __init__( double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' double_quant_use_sym: bool = True, double_quant_group_size: int = 256, + # quant lm_head + quant_lm_head: bool = False, # awq use_auto_scale: bool = True, use_auto_clip: bool = True, @@ -453,6 +473,7 @@ def __init__( double_quant_bits (int): Number of bits used to represent double_quant scale, default is 4. double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True. double_quant_group_size (int): Size of double_quant groups, default is 32. + quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False. use_auto_scale (bool): Enables best scales search based on activation distribution, default is True. use_auto_clip (bool): Enables clip range search. Defaults to True. folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer, @@ -473,6 +494,7 @@ def __init__( self.double_quant_dtype = double_quant_dtype self.double_quant_use_sym = double_quant_use_sym self.double_quant_group_size = double_quant_group_size + self.quant_lm_head = quant_lm_head self.use_auto_scale = use_auto_scale self.use_auto_clip = use_auto_clip self.folding = folding @@ -483,10 +505,19 @@ def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs = [] # TODO(Yi) linear_awq_config = AWQConfig() - operators = [torch.nn.Linear, torch.nn.functional.linear] + operators = list(WOQ_WHITE_LIST) supported_configs.append(OperatorConfig(config=linear_awq_config, operators=operators)) cls.supported_configs = supported_configs + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: + if not self.quant_lm_head: + usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] + self.set_local(usual_lm_head_names, AWQConfig(dtype="fp32")) + config_mapping = super().to_config_mapping(config_list, model_info) + return config_mapping + @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: filter_result = [] @@ -536,6 +567,8 @@ class TEQConfig(BaseConfig): "double_quant_bits", "double_quant_use_sym", "double_quant_group_size", + # quant_lm_head + "quant_lm_head", # TEQ params "absorb_to_layer", "folding", @@ -558,6 +591,8 @@ def __init__( double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' double_quant_use_sym: bool = True, double_quant_group_size: int = 256, + # double quant + quant_lm_head: bool = False, # teq absorb_to_layer: dict = {}, folding: bool = True, @@ -579,6 +614,7 @@ def __init__( double_quant_bits (int): Number of bits used to represent double_quant scale, default is 4. double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True. double_quant_group_size (int): Size of double_quant groups, default is 32. + quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False. absorb_to_layer (bool): The layer dict that scale can be absorbed, default is {}. folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer, default is False. @@ -598,6 +634,7 @@ def __init__( self.double_quant_dtype = double_quant_dtype self.double_quant_use_sym = double_quant_use_sym self.double_quant_group_size = double_quant_group_size + self.quant_lm_head = quant_lm_head self.absorb_to_layer = absorb_to_layer self.folding = folding self._post_init() @@ -607,10 +644,19 @@ def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs = [] # TODO(Yi) linear_teq_config = TEQConfig() - operators = [torch.nn.Linear, torch.nn.functional.linear] + operators = list(WOQ_WHITE_LIST) supported_configs.append(OperatorConfig(config=linear_teq_config, operators=operators)) cls.supported_configs = supported_configs + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: + if not self.quant_lm_head: + usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] + self.set_local(usual_lm_head_names, TEQConfig(dtype="fp32")) + config_mapping = super().to_config_mapping(config_list, model_info) + return config_mapping + @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: filter_result = [] diff --git a/test/3x/torch/quantization/weight_only/hqq/test_q_tensor.py b/test/3x/torch/algorithms/weight_only/test_hqq_quantizer.py similarity index 51% rename from test/3x/torch/quantization/weight_only/hqq/test_q_tensor.py rename to test/3x/torch/algorithms/weight_only/test_hqq_quantizer.py index 0548c10e3f1..f17717f8f71 100644 --- a/test/3x/torch/quantization/weight_only/hqq/test_q_tensor.py +++ b/test/3x/torch/algorithms/weight_only/test_hqq_quantizer.py @@ -1,21 +1,65 @@ -# Copyright (c) 2024 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +import pytest import torch +from neural_compressor.torch.algorithms.weight_only.hqq.bitpack import Packer +from neural_compressor.torch.algorithms.weight_only.hqq.config import ( + HQQModuleConfig, + QTensorConfig, + default_hqq_module_config, + default_scale_quant_config, + default_weight_quant_config, + default_zero_quant_config, +) from neural_compressor.torch.algorithms.weight_only.hqq.qtensor import QTensor, QTensorMetaInfo +def test_default_hqq_module_config(): + config = default_hqq_module_config + print(config) + assert isinstance(config, HQQModuleConfig) + assert config.weight == default_weight_quant_config + assert config.zero == default_zero_quant_config + assert config.scale == default_scale_quant_config + + +def test_default_weight_quant_config(): + config = default_weight_quant_config + assert isinstance(config, QTensorConfig) + assert config.nbits == 4 + assert config.channel_wise is True + + +def test_default_zero_quant_config(): + config = default_zero_quant_config + assert isinstance(config, QTensorConfig) + assert config.nbits == 8 + assert config.channel_wise is False + + +def test_default_scale_quant_config(): + config = default_scale_quant_config + assert isinstance(config, QTensorConfig) + assert config.nbits == 8 + assert config.channel_wise is True + + +def test_qtensor_meta_info(): + meta_info = QTensorMetaInfo + print(meta_info) + + +@pytest.mark.parametrize("nbits", [2, 3, 4, 8]) +def test_packer(nbits): + # TODO: add test for 3 bits + range_max = 2**nbits + dims = 16 if nbits != 3 else 10 + W = torch.randint(0, range_max, (dims, dims)).to(torch.uint8) + W_pack = Packer.get_pack_fn(nbits)(W) + W_pack_unpack = Packer.get_unpack_fn(nbits)(W_pack) + assert torch.allclose(W, W_pack_unpack) + print("Packer test passed!") + + class TestQTensor: def test_q_tensor(self): in_feats = 3 diff --git a/test/3x/torch/quantization/weight_only/hqq/test_hqq_config.py b/test/3x/torch/quantization/weight_only/hqq/test_hqq_config.py deleted file mode 100644 index bdfd2145aff..00000000000 --- a/test/3x/torch/quantization/weight_only/hqq/test_hqq_config.py +++ /dev/null @@ -1,44 +0,0 @@ -from neural_compressor.torch.algorithms.weight_only.hqq.config import ( - HQQModuleConfig, - QTensorConfig, - default_hqq_module_config, - default_scale_quant_config, - default_weight_quant_config, - default_zero_quant_config, -) -from neural_compressor.torch.algorithms.weight_only.hqq.qtensor import QTensorMetaInfo - - -def test_default_hqq_module_config(): - config = default_hqq_module_config - print(config) - assert isinstance(config, HQQModuleConfig) - assert config.weight == default_weight_quant_config - assert config.zero == default_zero_quant_config - assert config.scale == default_scale_quant_config - - -def test_default_weight_quant_config(): - config = default_weight_quant_config - assert isinstance(config, QTensorConfig) - assert config.nbits == 4 - assert config.channel_wise is True - - -def test_default_zero_quant_config(): - config = default_zero_quant_config - assert isinstance(config, QTensorConfig) - assert config.nbits == 8 - assert config.channel_wise is False - - -def test_default_scale_quant_config(): - config = default_scale_quant_config - assert isinstance(config, QTensorConfig) - assert config.nbits == 8 - assert config.channel_wise is True - - -def test_qtensor_meta_info(): - meta_info = QTensorMetaInfo - print(meta_info) diff --git a/test/3x/torch/quantization/weight_only/hqq/test_packer.py b/test/3x/torch/quantization/weight_only/hqq/test_packer.py deleted file mode 100644 index be471c0e440..00000000000 --- a/test/3x/torch/quantization/weight_only/hqq/test_packer.py +++ /dev/null @@ -1,16 +0,0 @@ -import pytest -import torch - -from neural_compressor.torch.algorithms.weight_only.hqq.bitpack import Packer - - -@pytest.mark.parametrize("nbits", [2, 3, 4, 8]) -def test_packer(nbits): - # TODO: add test for 3 bits - range_max = 2**nbits - dims = 16 if nbits != 3 else 10 - W = torch.randint(0, range_max, (dims, dims)).to(torch.uint8) - W_pack = Packer.get_pack_fn(nbits)(W) - W_pack_unpack = Packer.get_unpack_fn(nbits)(W_pack) - assert torch.allclose(W, W_pack_unpack) - print("Packer test passed!") diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 6e44a14acca..edf591c4f9d 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -24,6 +24,13 @@ def get_gpt_j(): return tiny_gptj +@torch.no_grad() +def calib_func(model): + example_inputs = torch.ones([1, 10], dtype=torch.long).to(device) + for i in range(2): + model(example_inputs) + + class TestAWQQuant: @classmethod def setup_class(self): @@ -50,12 +57,6 @@ def teardown_class(self): ) def test_awq(self, bits, use_sym, group_size): model = copy.deepcopy(self.tiny_gptj) - - @torch.no_grad() - def calib_func(model): - for i in range(2): - model(self.example_inputs) - quant_config = AWQConfig(bits=8, group_size=-1) logger.info(f"Test AWQ with config {quant_config}") model = prepare( @@ -77,11 +78,6 @@ def calib_func(model): assert torch.allclose(out, self.label, atol=1e-1), "Accuracy gap atol > 0.01 is unexpected." def test_awq_with_quantize_API(self): - @torch.no_grad() - def calib_func(model): - for i in range(2): - model(self.example_inputs) - quant_config = get_default_awq_config() logger.info(f"Test AWQ with config {quant_config}") @@ -110,14 +106,8 @@ def calib_func(model): ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal." def test_save_and_load(self): - @torch.no_grad() - def calib_func(model): - for i in range(2): - model(self.example_inputs) - fp32_model = copy.deepcopy(self.tiny_gptj) quant_config = get_default_awq_config() - # prepare + convert API model = prepare( model=fp32_model, @@ -137,3 +127,33 @@ def calib_func(model): loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." + + def test_quant_lm_head(self): + # tie_word_embeddings=false + gptj_model = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + device_map=device, + ) + lm_head_id = id(gptj_model.lm_head.weight) + assert id(gptj_model.transformer.wte.weight) != lm_head_id, "The lm_head weight is tied, please check!" + quant_config = AWQConfig(quant_lm_head=True) + model = prepare(gptj_model, quant_config, example_inputs=self.example_inputs) + calib_func(model) + model = convert(model) + + # tie_word_embeddings=true + opt_model = transformers.AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-random-OPTForCausalLM", + device_map=device, + ) + lm_head_id = id(opt_model.lm_head.weight) + assert ( + id(opt_model.model.decoder.embed_tokens.weight) == lm_head_id + ), "The lm_head weight is not tied, please check!" + quant_config = AWQConfig(quant_lm_head=True) + model = prepare(opt_model, quant_config, example_inputs=self.example_inputs) + calib_func(model) + model = convert(model) + assert ( + id(model.model.decoder.embed_tokens.weight) == lm_head_id + ), "The tied lm_head weight is not deep copied, please check!" From 26b643d9362a98d0312567268e865b1e8c7aa757 Mon Sep 17 00:00:00 2001 From: xin3he Date: Thu, 20 Jun 2024 17:17:22 +0800 Subject: [PATCH 05/14] enhance gptq forward as awq Signed-off-by: xin3he --- .../weight_only/run_clm_no_trainer.py | 15 ++++++--------- .../torch/algorithms/weight_only/gptq.py | 13 +++++++++++++ .../torch/quantization/weight_only/test_gptq.py | 12 ++---------- 3 files changed, 21 insertions(+), 19 deletions(-) diff --git a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py index 8655c47a8da..abd8228354e 100644 --- a/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py +++ b/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only/run_clm_no_trainer.py @@ -272,15 +272,12 @@ def get_user_model(): def run_fn_for_gptq(model, dataloader_for_calibration, *args): for batch in tqdm(dataloader_for_calibration): batch = move_input_to_device(batch, device=None) - try: - if isinstance(batch, tuple) or isinstance(batch, list): - model(batch[0]) - elif isinstance(batch, dict): - model(**batch) - else: - model(batch) - except ValueError: - pass + if isinstance(batch, tuple) or isinstance(batch, list): + model(batch[0]) + elif isinstance(batch, dict): + model(**batch) + else: + model(batch) return if args.double_quant_type is not None: double_quant_config_dict.update( diff --git a/neural_compressor/torch/algorithms/weight_only/gptq.py b/neural_compressor/torch/algorithms/weight_only/gptq.py index 4e2c19a8815..4cd9918d93d 100644 --- a/neural_compressor/torch/algorithms/weight_only/gptq.py +++ b/neural_compressor/torch/algorithms/weight_only/gptq.py @@ -345,6 +345,18 @@ def forward(layer, *args, **kwargs): self.gptq_related_blocks["transformers"][0].forward = partial( forward, self.gptq_related_blocks["transformers"][0] ) + # Step 3: replace model_forward to avoid ValueError + self.orig_model_forward_cache = self.model.forward + model_forward_cache = self.model.forward + + def model_forward(model, *args, **kwargs): + nonlocal model_forward_cache + try: + model_forward_cache(*args, **kwargs) + except ValueError: + pass + + self.model.forward = partial(model_forward, self.model) @torch.no_grad() def remove_prepare_for_calibration(self): @@ -359,6 +371,7 @@ def remove_prepare_for_calibration(self): logger.info("Done.") # Step 4: restore original forward function, relocate layers back to cpu. + self.model.forward = self.orig_model_forward_cache self.gptq_related_blocks["transformers"][0].forward = self.forward_cache if not self.use_layer_wise: # pragma: no cover self.gptq_related_blocks["transformers"][0] = self.gptq_related_blocks["transformers"][0].cpu() diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index be408af2564..0f3bcbbfe99 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -19,19 +19,11 @@ device = accelerator.current_device_name() -def run_fn_for_rtn(model): +def run_fn(model): model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device)) model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device)) -def run_fn(model): - # GPTQ uses ValueError to reduce computation when collecting input data of the first block - # It's special for UTs, no need to add this wrapper in examples. - with pytest.raises(ValueError): - model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device)) - model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device)) - - class TestGPTQQuant: def setup_class(self): self.tiny_gptj = transformers.AutoModelForCausalLM.from_pretrained( @@ -50,7 +42,7 @@ def test_accuracy_improvement(self): model = copy.deepcopy(self.tiny_gptj) quant_config = get_default_rtn_config() model = prepare(model, quant_config) - run_fn_for_rtn(model) + run_fn(model) model = convert(model) rtn_label = model(self.example_inputs)[0] rtn_atol = (rtn_label - self.label).amax() From b208ae3b21a67fc127174bb0a6a8384a50d9adc9 Mon Sep 17 00:00:00 2001 From: xin3he Date: Thu, 20 Jun 2024 19:32:44 +0800 Subject: [PATCH 06/14] update hqq Signed-off-by: xin3he --- docs/3x/PT_WeightOnlyQuant.md | 31 ++++++++++--- .../algorithms/weight_only/hqq/quantizer.py | 13 ------ .../torch/quantization/config.py | 35 +++++++++----- .../weight_only/hqq/test_hqq_cpu.py | 46 ++++++++++++++----- 4 files changed, 82 insertions(+), 43 deletions(-) diff --git a/docs/3x/PT_WeightOnlyQuant.md b/docs/3x/PT_WeightOnlyQuant.md index b115b38fce3..1d95690349a 100644 --- a/docs/3x/PT_WeightOnlyQuant.md +++ b/docs/3x/PT_WeightOnlyQuant.md @@ -1,6 +1,7 @@ PyTorch Weight Only Quantization =============== + - [Introduction](#introduction) - [Supported Matrix](#supported-matrix) - [Usage](#usage) @@ -28,7 +29,6 @@ Besides, as mentioned in many papers[1][2], activation quantization is the main Theoretically, round-to-nearest (RTN) is the most straightforward way to quantize weight using scale maps. However, when the number of bits is small (e.g. 3), the MSE loss is larger than expected. A group size is introduced to reduce elements using the same scale to improve accuracy. - ## Supported Matrix | Algorithms/Backend | PyTorch eager mode | @@ -58,12 +58,14 @@ Theoretically, round-to-nearest (RTN) is the most straightforward way to quantiz WeightOnlyQuant quantization for PyTorch is using prepare and convert [APIs](./PyTorch.md#quantization-apis). #### Common arguments + | Config | Capability | |---|---| | dtype (str)| ['int', 'nf4', 'fp4'] | | bits (int)| [1, ..., 8] | | group_size (int)| [-1, 1, ..., $C_{in}$] | | use_sym (bool)| [True, False] | +| quant_lm_head (bool)| [False, True] | | use_double_quant (bool) | [True, False] | | double_quant_dtype (str) | ['int'] | | double_quant_bits (int) | [1, ..., bits] | @@ -71,12 +73,13 @@ WeightOnlyQuant quantization for PyTorch is using prepare and convert [APIs](./P | double_quant_group_size (int) | [-1, 1, ..., $C_{in}$] | Notes: + - *group_size = -1* refers to **per output channel quantization**. Taking a linear layer (input channel = $C_{in}$, output channel = $C_{out}$) for instance, when *group size = -1*, quantization will calculate total $C_{out}$ quantization parameters. Otherwise, when *group_size = gs* quantization parameters are calculate with every $gs$ elements along with the input channel, leading to total $C_{out} \times (C_{in} / gs)$ quantization parameters. - 4-bit NormalFloat(NF4) is proposed in QLoRA[7]. 'fp4' includes [fp4_e2m1](../../neural_compressor/adaptor/torch_utils/weight_only.py#L37) and [fp4_e2m1_bnb](https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L735). By default, fp4 refers to fp4_e2m1_bnb. -- Only RTN and GPTQ support double quant. - +- Only RTN and GPTQ support double quant. #### RTN + | rtn_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| | group_dim (int) | Dimension for grouping | 1 | @@ -86,6 +89,7 @@ Notes: | model_path (str) | Model path that is used to load state_dict per layer | | > **Notes:** `model_path` is only used when use_layer_wise=True. `layer-wise` is stay-tuned. + ``` python # Quantization code from neural_compressor.torch.quantization import prepare, convert, RTNConfig @@ -96,6 +100,7 @@ model = convert(model) ``` #### GPTQ + | gptq_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| | use_mse_search (bool) | Enables mean squared error (MSE) search | False @@ -107,6 +112,7 @@ model = convert(model) | block_size (int) | Execute GPTQ quantization per block, block shape = [C_out, block_size] | 128 | | static_groups (bool) | Whether to calculate group wise quantization parameters in advance. This option mitigate actorder's extra computational requirements. | False. | > **Note:** `model_path` is only used when use_layer_wise=True. `layer-wise` is stay-tuned. + ``` python # Quantization code from neural_compressor.torch.quantization import prepare, convert, GPTQConfig @@ -118,6 +124,7 @@ model = convert(model) ``` #### AutoRound + | autoround_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| | enable_full_range (bool) | Whether to enable full range quantization | False @@ -138,6 +145,7 @@ model = convert(model) | not_use_best_mse (bool) | Whether to use mean squared error | False | | dynamic_max_gap (int) | The dynamic maximum gap | -1 | | scale_dtype (str) | The data type of quantization scale to be used, different kernels have different choices | "float16" | + ``` python # Quantization code from neural_compressor.torch.quantization import prepare, convert, AutoRoundConfig @@ -149,6 +157,7 @@ model = convert(model) ``` #### AWQ + | awq_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| | group_dim (int) | Dimension for grouping | 1 | @@ -159,6 +168,7 @@ model = convert(model) | use_auto_clip (bool) | Enables clip range search | True | | folding(bool) | Allow insert mul before linear when the scale cannot be absorbed by last layer | False. | > **Notes:** `layer-wise` is stay-tuned. + ``` python # Quantization code from neural_compressor.torch.quantization import prepare, convert, AWQConfig @@ -170,6 +180,7 @@ model = convert(model) ``` #### TEQ + | teq_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| | group_dim (int) | Dimension for grouping | 1 | @@ -179,6 +190,7 @@ model = convert(model) | use_double_quant (bool) | Enables double quantization | False | | folding(bool) | Allow insert mul before linear when the scale cannot be absorbed by last layer | False | > **Notes:** `layer-wise` is stay-tuned. + ``` python # Quantization code from neural_compressor.torch.quantization import prepare, convert, TEQConfig @@ -190,12 +202,13 @@ model = convert(model) ``` #### HQQ + | hqq_args | comments | default value | |----------|-------------|-------------------------------------------------------------------| | quant_zero (bool) | Whether to quantize zero point | True | | quant_scale: (bool) | Whether to quantize scale: point | False | | scale_quant_group_size (int) | The group size for quantizing scale | 128 | -| skip_lm_head (bool) | Whether to skip for quantizing lm_head | True | + ``` python # Quantization code from neural_compressor.torch.quantization import prepare, convert, HQQConfig @@ -205,10 +218,13 @@ model = prepare(model, quant_config) run_fn(model) # calibration model = convert(model) ``` + ### Specify Quantization Rules + Intel(R) Neural Compressor support specify quantization rules by operator name or operator type. Users can set `local` in dict or use `set_local` method of config class to achieve the above purpose. 1. Example of setting `local` from a dict + ```python quant_config = { "rtn": { @@ -226,7 +242,9 @@ quant_config = { } } ``` + 2. Example of using `set_local` + ```python quant_config = RTNConfig() lm_head_config = RTNConfig(dtype="fp32") @@ -234,7 +252,9 @@ quant_config.set_local("lm_head", lm_head_config) ``` ### Saving and Loading + The saved_results folder contains two files: quantized_model.pt and qconfig.json, and the generated model is a quantized model. The quantitative model will include WeightOnlyLinear. To support low memory inference, Intel(R) Neural Compressor implemented WeightOnlyLinear, a torch.nn.Module, to compress the fake quantized fp32 model. Since torch does not provide flexible data type storage, WeightOnlyLinear combines low bits data into a long date type, such as torch.int8 and torch.int32. Low bits data includes weights and zero points. When using WeightOnlyLinear for inference, it will restore the compressed data to float32 and run torch linear function. + ```python # Quantization code from neural_compressor.torch.quantization import prepare, convert, RTNConfig @@ -255,7 +275,6 @@ loaded_model = load( ) # Please note that the original_model parameter passes the original model. ``` - ## Examples Users can also refer to [examples](https://github.com/intel/neural-compressor/blob/master/examples/3.x_api/pytorch/nlp/huggingface_models/language-modeling/quantization/weight_only) on how to quantize a model with WeightOnlyQuant. @@ -272,6 +291,6 @@ Users can also refer to [examples](https://github.com/intel/neural-compressor/bl [5]. Cheng, Wenhua, et al. "Optimize Weight Rounding via Signed Gradient Descent for the Quantization of LLMs" arXiv preprint arXiv:2309.05516 (2023). -[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: https://mobiusml.github.io/hqq_blog/ (2023). +[6]. Badri, Hicham and Shaji, Appu. "Half-Quadratic Quantization of Large Machine Learning Models." [Online] Available: (2023). [7]. Dettmers, Tim, et al. "Qlora: Efficient finetuning of quantized llms." arXiv preprint arXiv:2305.14314 (2023). diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py index c90b15d425c..0c73b09e3f4 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py @@ -85,7 +85,6 @@ def __init__(self, quant_config: ConfigMappingType) -> None: Args: quant_config (ConfigMappingType): quantization config for ops. """ - quant_config = self._parse_hqq_configs_mapping(quant_config) super().__init__(quant_config=quant_config) @torch.no_grad() @@ -142,15 +141,3 @@ def _convert_hqq_module_config(self, config) -> HQQModuleConfig: hqq_module_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig) logger.debug(hqq_module_config) return hqq_module_config - - def _parse_hqq_configs_mapping(self, configs_mapping): - qconfig_mapping = {} - for (op_name, op_type), quant_config in configs_mapping.items(): - if quant_config.skip_lm_head and "lm_head" in op_name: - logger.warning("Skip quantizing %s due to `skip_lm_head` is True.", op_name) - continue - if quant_config is not None and quant_config.dtype == "fp32": - logger.warning("Fallback %s.", op_name) - continue - qconfig_mapping[op_name] = self._convert_hqq_module_config(quant_config) - return qconfig_mapping diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 01a72c658f7..e5554f0505f 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -591,7 +591,7 @@ def __init__( double_quant_bits: int = 8, # not available when double_quant_dtype is not 'int' double_quant_use_sym: bool = True, double_quant_group_size: int = 256, - # double quant + # quant lm_head quant_lm_head: bool = False, # teq absorb_to_layer: dict = {}, @@ -1231,7 +1231,8 @@ class HQQConfig(BaseConfig): "quant_zero", "quant_scale", "scale_quant_group_size", - "skip_lm_head", + # quant_lm_head + "quant_lm_head", ] supported_configs: List[OperatorConfig] = [] @@ -1243,7 +1244,8 @@ def __init__( quant_zero: bool = True, quant_scale: bool = False, scale_quant_group_size: int = 128, - skip_lm_head: bool = True, + # quant lm_head + quant_lm_head: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, ): super().__init__(white_list=white_list) @@ -1253,9 +1255,18 @@ def __init__( self.quant_zero = quant_zero self.quant_scale = quant_scale self.scale_quant_group_size = scale_quant_group_size - self.skip_lm_head = skip_lm_head + self.quant_lm_head = quant_lm_head self._post_init() + @classmethod + def register_supported_configs(cls) -> List[OperatorConfig]: + # TODO: to be refined + supported_configs = [] + linear_hqq_config = HQQConfig() + operators = list(WOQ_WHITE_LIST) + supported_configs.append(OperatorConfig(config=linear_hqq_config, operators=operators)) + cls.supported_configs = supported_configs + @staticmethod def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: filter_result = [] @@ -1265,14 +1276,14 @@ def get_model_info(model: torch.nn.Module) -> List[Tuple[str, Callable]]: filter_result.append(pair) return filter_result - @classmethod - def register_supported_configs(cls) -> List[OperatorConfig]: - # TODO: to be refined - supported_configs = [] - linear_hqq_config = HQQConfig() - operators = [torch.nn.Linear] - supported_configs.append(OperatorConfig(config=linear_hqq_config, operators=operators)) - cls.supported_configs = supported_configs + def to_config_mapping( + self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None + ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: + if not self.quant_lm_head: + usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] + self.set_local(usual_lm_head_names, HQQConfig(dtype="fp32")) + config_mapping = super().to_config_mapping(config_list, model_info) + return config_mapping @classmethod def get_config_set_for_tuning(cls) -> Union[None, "HQQConfig", List["HQQConfig"]]: diff --git a/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py b/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py index 16e390318d9..880b2bb7009 100644 --- a/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py +++ b/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py @@ -3,10 +3,15 @@ import pytest import torch +import transformers from transformers import AutoModelForCausalLM from neural_compressor.torch.algorithms.weight_only.hqq.config import HQQModuleConfig, QTensorConfig, hqq_global_option from neural_compressor.torch.algorithms.weight_only.hqq.core import HQQLinear +from neural_compressor.torch.quantization import HQQConfig, convert, get_default_hqq_config, prepare, quantize +from neural_compressor.torch.utils import accelerator + +device = accelerator.current_device_name() def _common_cpu_test(nbits=4, group_size=64, quant_zero=True, quant_scale=False, scale_quant_group_size=128): @@ -65,10 +70,9 @@ def force_not_half(self, monkeypatch): monkeypatch.setattr(hqq_global_option, "use_half", False) def test_hqq_quant(self, force_use_cpu, force_not_half): - from neural_compressor.torch.quantization import convert, get_default_hqq_config, prepare, quantize hqq_global_option.use_half = False - fp32_model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m") + fp32_model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-random-OPTForCausalLM") example_inputs = torch.tensor([[10, 20, 30, 40, 50, 60]], dtype=torch.long, device="cpu") # test_default_config quant_config = get_default_hqq_config() @@ -88,7 +92,6 @@ def test_hqq_quant(self, force_use_cpu, force_not_half): ), "The results of calling `convert` + `prepare` and calling `quantize` should be equal." def test_hqq_fallback(self, force_use_cpu, force_not_half): - from neural_compressor.torch.quantization import HQQConfig, convert, prepare class ToyModel(torch.nn.Module): def __init__(self): @@ -106,6 +109,34 @@ def forward(self, x): assert type(qmodel.fc1).__name__ == torch.nn.Linear.__name__, f"Expect fallback fc1, but get {type(qmodel.fc1)}" assert type(qmodel.fc2).__name__ != torch.nn.Linear.__name__, f"Expect quantize fc2, but get {type(qmodel.fc2)}" + def test_quant_lm_head(self, force_use_cpu, force_not_half): + # tie_word_embeddings=false + gptj_model = transformers.AutoModelForCausalLM.from_pretrained( + "hf-internal-testing/tiny-random-GPTJForCausalLM", + device_map=device, + ) + lm_head_id = id(gptj_model.lm_head.weight) + assert id(gptj_model.transformer.wte.weight) != lm_head_id, "The lm_head weight is tied, please check!" + quant_config = HQQConfig(quant_lm_head=True) + model = prepare(gptj_model, quant_config) + model = convert(model) + + # tie_word_embeddings=true + opt_model = transformers.AutoModelForCausalLM.from_pretrained( + "trl-internal-testing/tiny-random-OPTForCausalLM", + device_map=device, + ) + lm_head_id = id(opt_model.lm_head.weight) + assert ( + id(opt_model.model.decoder.embed_tokens.weight) == lm_head_id + ), "The lm_head weight is not tied, please check!" + quant_config = HQQConfig(quant_lm_head=True) + model = prepare(opt_model, quant_config) + model = convert(model) + assert ( + id(model.model.decoder.embed_tokens.weight) == lm_head_id + ), "The tied lm_head weight is not deep copied, please check!" + @pytest.mark.parametrize( "nbits, group_size, quant_zero, quant_scale, scale_quant_group_size", [ @@ -134,12 +165,3 @@ def test_hqq_module_cpu( quant_scale=quant_scale, scale_quant_group_size=scale_quant_group_size, ) - - -# _common_cpu_test( -# nbits=4, -# group_size=64, -# quant_zero=False, -# quant_scale=False, -# scale_quant_group_size=128 -# ) From 7073400a61c430f81697a61bbfd8afa63277142f Mon Sep 17 00:00:00 2001 From: xin3he Date: Thu, 20 Jun 2024 20:22:19 +0800 Subject: [PATCH 07/14] fix bug Signed-off-by: xin3he --- neural_compressor/common/base_config.py | 2 +- .../torch/algorithms/weight_only/hqq/quantizer.py | 13 ++++++++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 54e7bf53e71..81e93f51953 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -199,7 +199,7 @@ def local_config(self, config): self._local_config = config def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig: - if hasattr(operator_name_or_list, "__iter__"): + if isinstance(operator_name_or_list, list): for operator_name in operator_name_or_list: if operator_name in self.local_config: logger.warning("The configuration for %s has already been set, update it.", operator_name) diff --git a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py index 0c73b09e3f4..43b1dda1b4a 100644 --- a/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py +++ b/neural_compressor/torch/algorithms/weight_only/hqq/quantizer.py @@ -85,6 +85,7 @@ def __init__(self, quant_config: ConfigMappingType) -> None: Args: quant_config (ConfigMappingType): quantization config for ops. """ + quant_config = self._parse_hqq_configs_mapping(quant_config) super().__init__(quant_config=quant_config) @torch.no_grad() @@ -118,7 +119,8 @@ def save(self, model, path): pass def _convert_hqq_module_config(self, config) -> HQQModuleConfig: - # * 3.x API use `bits` for woq while HQQ internal API use `nbits` + # TODO: (Yi) Please note that the configuration defined by INC should be separated from the algorithm. + # * 3.x API use `bits` for woq while HQQ internal API use `nbits`, we should change it in algorithm_entry.py nbits = config.bits group_size = config.group_size quant_zero = config.quant_zero @@ -141,3 +143,12 @@ def _convert_hqq_module_config(self, config) -> HQQModuleConfig: hqq_module_config = HQQModuleConfig(weight=weight_qconfig, scale=scale_qconfig, zero=zero_qconfig) logger.debug(hqq_module_config) return hqq_module_config + + def _parse_hqq_configs_mapping(self, configs_mapping): + qconfig_mapping = {} + for (op_name, op_type), quant_config in configs_mapping.items(): + if quant_config is not None and quant_config.dtype == "fp32": + logger.warning("Fallback %s.", op_name) + continue + qconfig_mapping[op_name] = self._convert_hqq_module_config(quant_config) + return qconfig_mapping From 98ae9167a19e46a2dfef143bfc56c73c961a21f5 Mon Sep 17 00:00:00 2001 From: xin3he Date: Fri, 21 Jun 2024 13:26:13 +0800 Subject: [PATCH 08/14] update doc and UTs for set_local Signed-off-by: xin3he --- docs/3x/PyTorch.md | 21 +++++++++++++++++++ neural_compressor/common/base_config.py | 9 ++++++++ .../quantization/weight_only/test_rtn.py | 13 ++++++++++++ 3 files changed, 43 insertions(+) diff --git a/docs/3x/PyTorch.md b/docs/3x/PyTorch.md index b8c4ea2c7c5..cafc306d9be 100644 --- a/docs/3x/PyTorch.md +++ b/docs/3x/PyTorch.md @@ -223,3 +223,24 @@ def load(output_dir="./saved_results", model=None): + +2. How to set different configuration for specific op_name or op_type? + > INC extends a `set_local` method based on the global configuration object to set custom configuration. + + ```python + def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig: + """Set custom configuration based on the global configuration object. + + Args: + operator_name_or_list (Union[List, str, Callable]): specific operator + config (BaseConfig): specific configuration + """ + ``` + + > Demo: + + ```python + quant_config = RTNConfig() # Initialize global configuration with default bits=4 + quant_config.set_local(".*mlp.*", RTNConfig(bits=8)) # For layers with "mlp" in their names, set bits=8 + quant_config.set_local("Conv1d", RTNConfig(dtype="fp32")) # For Conv1d layers, do not quantize them. + ``` diff --git a/neural_compressor/common/base_config.py b/neural_compressor/common/base_config.py index 81e93f51953..267a1ed5deb 100644 --- a/neural_compressor/common/base_config.py +++ b/neural_compressor/common/base_config.py @@ -199,6 +199,15 @@ def local_config(self, config): self._local_config = config def set_local(self, operator_name_or_list: Union[List, str, Callable], config: BaseConfig) -> BaseConfig: + """Set custom configuration based on the global configuration object. + + Args: + operator_name_or_list (Union[List, str, Callable]): specific operator + config (BaseConfig): specific configuration + + Returns: + Updated Config + """ if isinstance(operator_name_or_list, list): for operator_name in operator_name_or_list: if operator_name in self.local_config: diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index eed2c2b9205..9936a0f5090 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -195,6 +195,19 @@ def test_dtype_params(self, dtype): assert torch.allclose(out, self.label, atol=0.11), "Accuracy gap atol > 0.11 is unexpected." assert torch.allclose(out, out_next), "output should be same" + def test_mix_dtype(self): + model = copy.deepcopy(self.tiny_gptj) + quant_config = RTNConfig() + quant_config.set_local(".*mlp.*", RTNConfig(bits=8)) + quant_config.set_local(".*.out_proj", RTNConfig(bits=6)) + quant_config.set_local(".*.k_proj", RTNConfig(dtype="nf4")) + model = prepare(model, quant_config) + model = convert(model) + out = model(self.example_inputs)[0] + out_next = model(self.example_inputs)[0] + assert torch.allclose(out, self.label, atol=0.08), "Accuracy gap atol > 0.08 is unexpected." + assert torch.allclose(out, out_next), "output should be same" + @pytest.mark.parametrize("dtype", ["int4", "nf4"]) @pytest.mark.parametrize("double_quant_bits", [6]) @pytest.mark.parametrize("double_quant_group_size", [8, 256]) From b8429a546d825b18cd7bbe659003e37e181aa0c5 Mon Sep 17 00:00:00 2001 From: xin3he Date: Fri, 21 Jun 2024 14:18:10 +0800 Subject: [PATCH 09/14] fix UTs Signed-off-by: xin3he --- .../torch/quantization/weight_only/hqq/test_hqq_cpu.py | 2 +- test/3x/torch/quantization/weight_only/test_awq.py | 2 +- test/3x/torch/quantization/weight_only/test_gptq.py | 5 +---- .../torch/quantization/weight_only/test_mixed_algos.py | 8 ++------ test/3x/torch/quantization/weight_only/test_rtn.py | 9 +++++---- 5 files changed, 10 insertions(+), 16 deletions(-) diff --git a/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py b/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py index 880b2bb7009..9a0290ffe29 100644 --- a/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py +++ b/test/3x/torch/quantization/weight_only/hqq/test_hqq_cpu.py @@ -123,7 +123,7 @@ def test_quant_lm_head(self, force_use_cpu, force_not_half): # tie_word_embeddings=true opt_model = transformers.AutoModelForCausalLM.from_pretrained( - "trl-internal-testing/tiny-random-OPTForCausalLM", + "facebook/opt-125m", # group_size should be divisible by tensor.numel(). Dummy model cannot work. device_map=device, ) lm_head_id = id(opt_model.lm_head.weight) diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index edf591c4f9d..6d33eb1a913 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -126,7 +126,7 @@ def test_save_and_load(self): loaded_model = load("saved_results", copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." - assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." + assert isinstance(loaded_model.transformer.h[0].mlp.fc_in, WeightOnlyLinear), "loading compressed model failed." def test_quant_lm_head(self): # tie_word_embeddings=false diff --git a/test/3x/torch/quantization/weight_only/test_gptq.py b/test/3x/torch/quantization/weight_only/test_gptq.py index 0f3bcbbfe99..7fbe0ad7737 100644 --- a/test/3x/torch/quantization/weight_only/test_gptq.py +++ b/test/3x/torch/quantization/weight_only/test_gptq.py @@ -21,7 +21,6 @@ def run_fn(model): model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device)) - model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device)) class TestGPTQQuant: @@ -221,9 +220,7 @@ def test_conv1d(self): encoded_input = tokenizer(text, return_tensors="pt") def run_fn_conv1d(model): - with pytest.raises(ValueError): - for i in range(2): - model(**encoded_input) + model(**encoded_input) quant_config = get_default_gptq_config() out1 = model(**encoded_input)[0] diff --git a/test/3x/torch/quantization/weight_only/test_mixed_algos.py b/test/3x/torch/quantization/weight_only/test_mixed_algos.py index 0d354f29728..098c4b496bf 100644 --- a/test/3x/torch/quantization/weight_only/test_mixed_algos.py +++ b/test/3x/torch/quantization/weight_only/test_mixed_algos.py @@ -12,17 +12,13 @@ def run_fn(model): - # GPTQ uses ValueError to reduce computation when collecting input data of the first block - # It's special for UTs, no need to add this wrapper in examples. - with pytest.raises(ValueError): - model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device)) - model(torch.tensor([[40, 50, 60]], dtype=torch.long).to(device)) + model(torch.tensor([[10, 20, 30]], dtype=torch.long).to(device)) class TestMixedTwoAlgo: def test_mixed_gptq_and_rtn(self): with patch.object(logger, "info") as mock_info: - rtn_config = RTNConfig(white_list=["lm_head"]) + rtn_config = RTNConfig(quant_lm_head=True) gptq_config = GPTQConfig(double_quant_bits=4, white_list=["transformer.*"]) combined_config = rtn_config + gptq_config logger.info(combined_config) diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index 9936a0f5090..1c6c0a2c9d5 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -241,9 +241,10 @@ def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_ out = model(self.example_inputs)[0] atol_true = (out - self.q_label).amax() # compare atol, this case is an ideal case. - assert ( - atol_false < atol_true - ), "asym for double quant should have smaller atol because scales is bigger than zero, please double check." + if not (dtype, double_quant_bits, double_quant_group_size) == (256, 6, "nf4"): + assert ( + atol_false < atol_true + ), "asym for double quant should have smaller atol because scales is bigger than zero, please double check." def test_double_quant_constants(self): model = copy.deepcopy(self.tiny_gptj) @@ -336,7 +337,7 @@ def test_save_and_load(self): loaded_model = load("saved_results", copy.deepcopy(self.tiny_gptj)) loaded_out = loaded_model(self.example_inputs)[0] assert torch.allclose(inc_out, loaded_out), "Unexpected result. Please double check." - assert isinstance(loaded_model.lm_head, WeightOnlyLinear), "loading compressed model failed." + assert isinstance(loaded_model.transformer.h[0].mlp.fc_in, WeightOnlyLinear), "loading compressed model failed." def test_no_transformers(self, monkeypatch): def mock_is_transformers_imported(): From 1104f44aee13e5ca983890f19f0eba200023eefc Mon Sep 17 00:00:00 2001 From: xin3he Date: Fri, 21 Jun 2024 14:32:21 +0800 Subject: [PATCH 10/14] fix bug Signed-off-by: xin3he --- test/3x/torch/quantization/weight_only/test_mixed_algos.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/3x/torch/quantization/weight_only/test_mixed_algos.py b/test/3x/torch/quantization/weight_only/test_mixed_algos.py index 098c4b496bf..7b64824bc1b 100644 --- a/test/3x/torch/quantization/weight_only/test_mixed_algos.py +++ b/test/3x/torch/quantization/weight_only/test_mixed_algos.py @@ -18,8 +18,8 @@ def run_fn(model): class TestMixedTwoAlgo: def test_mixed_gptq_and_rtn(self): with patch.object(logger, "info") as mock_info: - rtn_config = RTNConfig(quant_lm_head=True) - gptq_config = GPTQConfig(double_quant_bits=4, white_list=["transformer.*"]) + rtn_config = RTNConfig(white_list=[".*mlp.*"]) + gptq_config = GPTQConfig(double_quant_bits=4, white_list=[".*attn.*"]) combined_config = rtn_config + gptq_config logger.info(combined_config) From 33bb948a3528530eb257e5f745ef8c684513d18d Mon Sep 17 00:00:00 2001 From: xin3he Date: Fri, 21 Jun 2024 14:37:28 +0800 Subject: [PATCH 11/14] add document for quant_lm_head Signed-off-by: xin3he --- docs/3x/PT_WeightOnlyQuant.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/3x/PT_WeightOnlyQuant.md b/docs/3x/PT_WeightOnlyQuant.md index 1d95690349a..5a84a2d3474 100644 --- a/docs/3x/PT_WeightOnlyQuant.md +++ b/docs/3x/PT_WeightOnlyQuant.md @@ -76,6 +76,7 @@ Notes: - *group_size = -1* refers to **per output channel quantization**. Taking a linear layer (input channel = $C_{in}$, output channel = $C_{out}$) for instance, when *group size = -1*, quantization will calculate total $C_{out}$ quantization parameters. Otherwise, when *group_size = gs* quantization parameters are calculate with every $gs$ elements along with the input channel, leading to total $C_{out} \times (C_{in} / gs)$ quantization parameters. - 4-bit NormalFloat(NF4) is proposed in QLoRA[7]. 'fp4' includes [fp4_e2m1](../../neural_compressor/adaptor/torch_utils/weight_only.py#L37) and [fp4_e2m1_bnb](https://github.com/TimDettmers/bitsandbytes/blob/18e827d666fa2b70a12d539ccedc17aa51b2c97c/bitsandbytes/functional.py#L735). By default, fp4 refers to fp4_e2m1_bnb. +- *quant_lm_head* defaults to False. This means that, except for transformer blocks, the last layer in transformer models will not be quantized by default. The last layer may be named "lm_head", "output_layer" or "embed_out". - Only RTN and GPTQ support double quant. #### RTN From c7808e44d5ccc2390503a1517e1d8b01e6f4db14 Mon Sep 17 00:00:00 2001 From: xin3he Date: Fri, 21 Jun 2024 14:55:58 +0800 Subject: [PATCH 12/14] fix bug Signed-off-by: xin3he --- test/3x/torch/quantization/weight_only/test_rtn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/3x/torch/quantization/weight_only/test_rtn.py b/test/3x/torch/quantization/weight_only/test_rtn.py index 1c6c0a2c9d5..206bd20aa10 100644 --- a/test/3x/torch/quantization/weight_only/test_rtn.py +++ b/test/3x/torch/quantization/weight_only/test_rtn.py @@ -241,7 +241,7 @@ def test_double_quant_params(self, dtype, double_quant_bits, double_quant_group_ out = model(self.example_inputs)[0] atol_true = (out - self.q_label).amax() # compare atol, this case is an ideal case. - if not (dtype, double_quant_bits, double_quant_group_size) == (256, 6, "nf4"): + if not (dtype, double_quant_bits, double_quant_group_size) == ("nf4", 6, 256): assert ( atol_false < atol_true ), "asym for double quant should have smaller atol because scales is bigger than zero, please double check." From ff997656cf7486334096fd5c6b31b72b43945378 Mon Sep 17 00:00:00 2001 From: xin3he Date: Tue, 25 Jun 2024 16:49:58 +0800 Subject: [PATCH 13/14] remove useless __all__ Signed-off-by: xin3he --- neural_compressor/torch/quantization/config.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index e5554f0505f..02ac9634b32 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -55,17 +55,6 @@ PT2E_DYNAMIC_QUANT, ) -__all__ = [ - "RTNConfig", - "get_default_rtn_config", - "GPTQConfig", - "get_default_gptq_config", - "HQQConfig", - "get_default_hqq_config", - "get_woq_tuning_config", -] - - FRAMEWORK_NAME = "torch" if is_transformers_imported(): import transformers From 2351abbb3f9fab541a84c7aec569d50096a58852 Mon Sep 17 00:00:00 2001 From: xin3he Date: Tue, 25 Jun 2024 16:54:11 +0800 Subject: [PATCH 14/14] update per review Signed-off-by: xin3he --- neural_compressor/torch/quantization/config.py | 16 ++++++---------- neural_compressor/torch/utils/constants.py | 3 +++ 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 02ac9634b32..71b01353d5a 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -46,6 +46,7 @@ ) from neural_compressor.torch.utils import is_hpex_available, is_ipex_imported, is_transformers_imported, logger from neural_compressor.torch.utils.constants import ( + LM_HEAD_NAMES, PRIORITY_AUTOROUND, PRIORITY_AWQ, PRIORITY_GPTQ, @@ -198,8 +199,7 @@ def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: if not self.quant_lm_head: - usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] - self.set_local(usual_lm_head_names, RTNConfig(dtype="fp32")) + self.set_local(LM_HEAD_NAMES, RTNConfig(dtype="fp32")) config_mapping = super().to_config_mapping(config_list, model_info) return config_mapping @@ -359,8 +359,7 @@ def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: if not self.quant_lm_head: - usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] - self.set_local(usual_lm_head_names, GPTQConfig(dtype="fp32")) + self.set_local(LM_HEAD_NAMES, GPTQConfig(dtype="fp32")) config_mapping = super().to_config_mapping(config_list, model_info) return config_mapping @@ -502,8 +501,7 @@ def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: if not self.quant_lm_head: - usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] - self.set_local(usual_lm_head_names, AWQConfig(dtype="fp32")) + self.set_local(LM_HEAD_NAMES, AWQConfig(dtype="fp32")) config_mapping = super().to_config_mapping(config_list, model_info) return config_mapping @@ -641,8 +639,7 @@ def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: if not self.quant_lm_head: - usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] - self.set_local(usual_lm_head_names, TEQConfig(dtype="fp32")) + self.set_local(LM_HEAD_NAMES, TEQConfig(dtype="fp32")) config_mapping = super().to_config_mapping(config_list, model_info) return config_mapping @@ -1269,8 +1266,7 @@ def to_config_mapping( self, config_list: List[BaseConfig] = None, model_info: List[Tuple[str, str]] = None ) -> OrderedDictType[Union[str, str], OrderedDictType[str, BaseConfig]]: if not self.quant_lm_head: - usual_lm_head_names = [".*lm_head", ".*output_layer", ".*embed_out"] - self.set_local(usual_lm_head_names, HQQConfig(dtype="fp32")) + self.set_local(LM_HEAD_NAMES, HQQConfig(dtype="fp32")) config_mapping = super().to_config_mapping(config_list, model_info) return config_mapping diff --git a/neural_compressor/torch/utils/constants.py b/neural_compressor/torch/utils/constants.py index a655a70b8ed..dee05af4088 100644 --- a/neural_compressor/torch/utils/constants.py +++ b/neural_compressor/torch/utils/constants.py @@ -62,3 +62,6 @@ class LoadFormat(Enum): DEFAULT = "default" HUGGINGFACE = "huggingface" + + +LM_HEAD_NAMES = [".*lm_head", ".*output_layer", ".*embed_out"]