From f6f2cdaa6f7d034e6a860d0e1a9ef76e44a216f2 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Mon, 15 Jul 2024 11:58:25 +0800 Subject: [PATCH 1/6] support absorb dict for awq Signed-off-by: Kaihui-intel --- .../torch/algorithms/weight_only/awq.py | 45 ++++++++++++++---- .../torch/quantization/algorithm_entry.py | 43 ++++++++--------- .../torch/quantization/config.py | 6 ++- .../quantization/weight_only/test_awq.py | 47 +++++++++++++++++++ 4 files changed, 111 insertions(+), 30 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/awq.py b/neural_compressor/torch/algorithms/weight_only/awq.py index 940e9826785..01db9748106 100644 --- a/neural_compressor/torch/algorithms/weight_only/awq.py +++ b/neural_compressor/torch/algorithms/weight_only/awq.py @@ -88,6 +88,27 @@ def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={} logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}") return block_absorb_dict, absorb_layer_dict +def _get_block_absorb_dict(model, absorb_layer_dict): + """Get absorbed layer per block from absorbed layer dict. + + Args: + model (torch.nn.Module): input model + absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}. + + Returns: + block_absorb_dict: dict of absorbed layer per block. eg. {0, [[absorbed_1, xx], [xx]], ...} + """ + block_absorb_dict = {} + block_prefix, block_num = get_block_prefix(model) + for i in range(block_num): + block_absorb_dict[i] = [] + block_name = block_prefix + "." + str(i) + "." + for k, v in absorb_layer_dict.items(): + if all(block_name in elem for elem in k): + block_absorb_dict[i].append(k) + logger.debug(f"The absorbed layers per block: {block_absorb_dict}") + logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}") + return block_absorb_dict @torch.no_grad() def _get_weight_scale(weight, q_group_size=-1): @@ -123,6 +144,7 @@ def __init__( total_block_args=[], total_block_kwargs=[], device="auto", + absorb_layer_dict={}, ): self.example_inputs = example_inputs @@ -140,6 +162,7 @@ def __init__( self.scheme = scheme self.use_full_range = use_full_range self.weight_config = weight_config + self.absorb_layer_dict = absorb_layer_dict def _move_model_and_data_to_device(self): # Put the model and example_inputs into target device @@ -164,13 +187,16 @@ def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, retu # Step 1: get absorbed module list per block, includes self-absorption # block_absorb_dict is split per block, includes all absorb relationship. # absorb_layer_dict is the inverse of block_absorb_dict for all blocks - self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block( - self.model, - self.example_inputs, - # for only use_mse_search, folding is useless. - folding=folding if use_auto_scale else False, - weight_config=self.weight_config, - ) + if not self.absorb_layer_dict: + self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block( + self.model, + self.example_inputs, + # for only use_mse_search, folding is useless. 
+ folding=folding if use_auto_scale else False, + weight_config=self.weight_config, + ) + else: + self.block_absorb_dict = _get_block_absorb_dict(self.model, self.absorb_layer_dict) # process per block for i, module_list in self.block_absorb_dict.items(): logger.info(f"Processing block: {i+1}/{self.block_num}") @@ -491,13 +517,15 @@ def module_inference(self, model, inputs): class AWQQuantizer(Quantizer): - def __init__(self, quant_config: OrderedDict = {}): + def __init__(self, quant_config: OrderedDict = {}, absorb_to_layer: dict = {}): """Init an AWQQuantizer object. Args: quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}. + absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}. """ super().__init__(quant_config) + self.absorb_to_layer = absorb_to_layer @torch.no_grad() def prepare(self, model, *args, **kwargs): @@ -566,6 +594,7 @@ def convert( weight_config=self.quant_config, total_block_args=total_block_args, total_block_kwargs=total_block_kwargs, + absorb_layer_dict=self.absorb_to_layer, ) qdq_model = awq.quantize( use_auto_scale=use_auto_scale, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index b8a1e3b9202..ea9b81e53f2 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -317,10 +317,10 @@ def awq_quantize_entry( from neural_compressor.torch.algorithms.weight_only.save_load import save weight_config = {} - for (op_name, op_type), op_config in configs_mapping.items(): - if op_config.name != AWQ: + for (op_name, op_type), quant_config in configs_mapping.items(): + if quant_config.name != AWQ: continue - if op_config.dtype == "fp32": + if quant_config.dtype == "fp32": weight_config[op_name] = { "bits": -1, "dtype": "fp32", # skip quantization @@ -329,31 +329,32 @@ def awq_quantize_entry( } else: weight_config[op_name] = { - "dtype": op_config.dtype, - "bits": op_config.bits, - "group_size": op_config.group_size, - "group_dim": op_config.group_dim, - "scheme": "sym" if op_config.use_sym else "asym", - "use_full_range": op_config.use_full_range, - "use_mse_search": op_config.use_mse_search, - "use_layer_wise": op_config.use_layer_wise, - "use_double_quant": op_config.use_double_quant, - "double_quant_dtype": op_config.double_quant_dtype, - "double_quant_bits": op_config.double_quant_bits, - "double_quant_scheme": op_config.double_quant_use_sym, - "double_quant_group_size": op_config.double_quant_group_size, + "dtype": quant_config.dtype, + "bits": quant_config.bits, + "group_size": quant_config.group_size, + "group_dim": quant_config.group_dim, + "scheme": "sym" if quant_config.use_sym else "asym", + "use_full_range": quant_config.use_full_range, + "use_mse_search": quant_config.use_mse_search, + "use_layer_wise": quant_config.use_layer_wise, + "use_double_quant": quant_config.use_double_quant, + "double_quant_dtype": quant_config.double_quant_dtype, + "double_quant_bits": quant_config.double_quant_bits, + "double_quant_scheme": quant_config.double_quant_use_sym, + "double_quant_group_size": quant_config.double_quant_group_size, } - use_auto_scale = op_config.use_auto_scale - use_mse_search = op_config.use_auto_clip # for awq clip - folding = op_config.folding - use_full_range = op_config.use_full_range + use_auto_scale = quant_config.use_auto_scale + use_mse_search = quant_config.use_auto_clip # for awq clip + folding = quant_config.folding + use_full_range = 
quant_config.use_full_range + absorb_to_layer = quant_config.absorb_to_layer run_fn = kwargs.get("run_fn", None) run_args = kwargs.get("run_args", None) example_inputs = kwargs.get("example_inputs", None) assert example_inputs is not None, "Please provide example_inputs for AWQ quantization." - quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config) + quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_to_layer=absorb_to_layer) model = quantizer.execute( model, mode=mode, diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index 9014f1576a3..a56223411dc 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -425,6 +425,7 @@ class AWQConfig(BaseConfig): "use_auto_scale", "use_auto_clip", "folding", + "absorb_to_layer", ] name = AWQ @@ -451,6 +452,7 @@ def __init__( use_auto_clip: bool = True, folding: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, + absorb_to_layer: dict = {}, ): """Init AWQ weight-only quantization config. @@ -473,6 +475,7 @@ def __init__( use_auto_clip (bool): Enables clip range search. Defaults to True. folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer, default is False. + absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}. """ super().__init__(white_list=white_list) self.dtype = dtype @@ -493,6 +496,7 @@ def __init__( self.use_auto_scale = use_auto_scale self.use_auto_clip = use_auto_clip self.folding = folding + self.absorb_to_layer = absorb_to_layer self._post_init() @classmethod @@ -609,7 +613,7 @@ def __init__( double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True. double_quant_group_size (int): Size of double_quant groups, default is 32. quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False. - absorb_to_layer (bool): The layer dict that scale can be absorbed, default is {}. + absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}. folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer, default is False. """ diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 6d33eb1a913..40eb5e64f56 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -157,3 +157,50 @@ def test_quant_lm_head(self): assert ( id(model.model.decoder.embed_tokens.weight) == lm_head_id ), "The tied lm_head weight is not deep copied, please check!" 
+ + def test_awq_absorb_to_layer(self): + absorb_to_layer = { + ('transformer.h.0.attn.q_proj', 'transformer.h.0.attn.k_proj', 'transformer.h.0.attn.v_proj', 'transformer.h.0.mlp.fc_in'): 'transformer.h.0.ln_1', + ('transformer.h.0.attn.out_proj',): 'transformer.h.0.attn.out_proj', + ('transformer.h.0.mlp.fc_out',): 'transformer.h.0.mlp.fc_out', + ('transformer.h.1.attn.q_proj', 'transformer.h.1.attn.k_proj', 'transformer.h.1.attn.v_proj', 'transformer.h.1.mlp.fc_in'): 'transformer.h.1.ln_1', + ('transformer.h.1.attn.out_proj',): 'transformer.h.1.attn.out_proj', + ('transformer.h.1.mlp.fc_out',): 'transformer.h.1.mlp.fc_out', + ('transformer.h.2.attn.q_proj', 'transformer.h.2.attn.k_proj', 'transformer.h.2.attn.v_proj', 'transformer.h.2.mlp.fc_in'): 'transformer.h.2.ln_1', + ('transformer.h.2.attn.out_proj',): 'transformer.h.2.attn.out_proj', + ('transformer.h.2.mlp.fc_out',): 'transformer.h.2.mlp.fc_out', + ('transformer.h.3.attn.q_proj', 'transformer.h.3.attn.k_proj', 'transformer.h.3.attn.v_proj', 'transformer.h.3.mlp.fc_in'): 'transformer.h.3.ln_1', + ('transformer.h.3.attn.out_proj',): 'transformer.h.3.attn.out_proj', + ('transformer.h.3.mlp.fc_out',): 'transformer.h.3.mlp.fc_out', + ('transformer.h.4.attn.q_proj', 'transformer.h.4.attn.k_proj', 'transformer.h.4.attn.v_proj', 'transformer.h.4.mlp.fc_in'): 'transformer.h.4.ln_1', + ('transformer.h.4.attn.out_proj',): 'transformer.h.4.attn.out_proj', + ('transformer.h.4.mlp.fc_out',): 'transformer.h.4.mlp.fc_out' + } + quant_config = AWQConfig(absorb_to_layer=absorb_to_layer) + logger.info(f"Test AWQ with config {quant_config}") + # prepare + convert API + model = prepare( + model=copy.deepcopy(self.tiny_gptj), + quant_config=quant_config, + example_inputs=self.example_inputs, + ) + calib_func(model) + model = convert(model) + out1 = model(self.example_inputs) + + quant_config = AWQConfig() + logger.info(f"Test AWQ with config {quant_config}") + + # prepare + convert API + model = prepare( + model=copy.deepcopy(self.tiny_gptj), + quant_config=quant_config, + example_inputs=self.example_inputs, + ) + calib_func(model) + model = convert(model) + out2 = model(self.example_inputs) + + assert torch.all( + out1[0].eq(out2[0]) + ), "The results should be equal." \ No newline at end of file From 1bf8db0d1fb5881bdc535c9d97cdf4fe90ad344d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 05:04:26 +0000 Subject: [PATCH 2/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../torch/algorithms/weight_only/awq.py | 2 + .../torch/quantization/algorithm_entry.py | 4 +- .../quantization/weight_only/test_awq.py | 63 +++++++++++++------ 3 files changed, 48 insertions(+), 21 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/awq.py b/neural_compressor/torch/algorithms/weight_only/awq.py index 01db9748106..f16a18893bf 100644 --- a/neural_compressor/torch/algorithms/weight_only/awq.py +++ b/neural_compressor/torch/algorithms/weight_only/awq.py @@ -88,6 +88,7 @@ def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={} logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}") return block_absorb_dict, absorb_layer_dict + def _get_block_absorb_dict(model, absorb_layer_dict): """Get absorbed layer per block from absorbed layer dict. 
@@ -110,6 +111,7 @@ def _get_block_absorb_dict(model, absorb_layer_dict): logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}") return block_absorb_dict + @torch.no_grad() def _get_weight_scale(weight, q_group_size=-1): org_shape = weight.shape diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index ea9b81e53f2..d4e4071600f 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -354,7 +354,9 @@ def awq_quantize_entry( example_inputs = kwargs.get("example_inputs", None) assert example_inputs is not None, "Please provide example_inputs for AWQ quantization." - quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_to_layer=absorb_to_layer) + quantizer = get_quantizer( + model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_to_layer=absorb_to_layer + ) model = quantizer.execute( model, mode=mode, diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 40eb5e64f56..96f55d60e8c 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -160,21 +160,46 @@ def test_quant_lm_head(self): def test_awq_absorb_to_layer(self): absorb_to_layer = { - ('transformer.h.0.attn.q_proj', 'transformer.h.0.attn.k_proj', 'transformer.h.0.attn.v_proj', 'transformer.h.0.mlp.fc_in'): 'transformer.h.0.ln_1', - ('transformer.h.0.attn.out_proj',): 'transformer.h.0.attn.out_proj', - ('transformer.h.0.mlp.fc_out',): 'transformer.h.0.mlp.fc_out', - ('transformer.h.1.attn.q_proj', 'transformer.h.1.attn.k_proj', 'transformer.h.1.attn.v_proj', 'transformer.h.1.mlp.fc_in'): 'transformer.h.1.ln_1', - ('transformer.h.1.attn.out_proj',): 'transformer.h.1.attn.out_proj', - ('transformer.h.1.mlp.fc_out',): 'transformer.h.1.mlp.fc_out', - ('transformer.h.2.attn.q_proj', 'transformer.h.2.attn.k_proj', 'transformer.h.2.attn.v_proj', 'transformer.h.2.mlp.fc_in'): 'transformer.h.2.ln_1', - ('transformer.h.2.attn.out_proj',): 'transformer.h.2.attn.out_proj', - ('transformer.h.2.mlp.fc_out',): 'transformer.h.2.mlp.fc_out', - ('transformer.h.3.attn.q_proj', 'transformer.h.3.attn.k_proj', 'transformer.h.3.attn.v_proj', 'transformer.h.3.mlp.fc_in'): 'transformer.h.3.ln_1', - ('transformer.h.3.attn.out_proj',): 'transformer.h.3.attn.out_proj', - ('transformer.h.3.mlp.fc_out',): 'transformer.h.3.mlp.fc_out', - ('transformer.h.4.attn.q_proj', 'transformer.h.4.attn.k_proj', 'transformer.h.4.attn.v_proj', 'transformer.h.4.mlp.fc_in'): 'transformer.h.4.ln_1', - ('transformer.h.4.attn.out_proj',): 'transformer.h.4.attn.out_proj', - ('transformer.h.4.mlp.fc_out',): 'transformer.h.4.mlp.fc_out' + ( + "transformer.h.0.attn.q_proj", + "transformer.h.0.attn.k_proj", + "transformer.h.0.attn.v_proj", + "transformer.h.0.mlp.fc_in", + ): "transformer.h.0.ln_1", + ("transformer.h.0.attn.out_proj",): "transformer.h.0.attn.out_proj", + ("transformer.h.0.mlp.fc_out",): "transformer.h.0.mlp.fc_out", + ( + "transformer.h.1.attn.q_proj", + "transformer.h.1.attn.k_proj", + "transformer.h.1.attn.v_proj", + "transformer.h.1.mlp.fc_in", + ): "transformer.h.1.ln_1", + ("transformer.h.1.attn.out_proj",): "transformer.h.1.attn.out_proj", + ("transformer.h.1.mlp.fc_out",): "transformer.h.1.mlp.fc_out", + ( + "transformer.h.2.attn.q_proj", + "transformer.h.2.attn.k_proj", + "transformer.h.2.attn.v_proj", + 
"transformer.h.2.mlp.fc_in", + ): "transformer.h.2.ln_1", + ("transformer.h.2.attn.out_proj",): "transformer.h.2.attn.out_proj", + ("transformer.h.2.mlp.fc_out",): "transformer.h.2.mlp.fc_out", + ( + "transformer.h.3.attn.q_proj", + "transformer.h.3.attn.k_proj", + "transformer.h.3.attn.v_proj", + "transformer.h.3.mlp.fc_in", + ): "transformer.h.3.ln_1", + ("transformer.h.3.attn.out_proj",): "transformer.h.3.attn.out_proj", + ("transformer.h.3.mlp.fc_out",): "transformer.h.3.mlp.fc_out", + ( + "transformer.h.4.attn.q_proj", + "transformer.h.4.attn.k_proj", + "transformer.h.4.attn.v_proj", + "transformer.h.4.mlp.fc_in", + ): "transformer.h.4.ln_1", + ("transformer.h.4.attn.out_proj",): "transformer.h.4.attn.out_proj", + ("transformer.h.4.mlp.fc_out",): "transformer.h.4.mlp.fc_out", } quant_config = AWQConfig(absorb_to_layer=absorb_to_layer) logger.info(f"Test AWQ with config {quant_config}") @@ -187,7 +212,7 @@ def test_awq_absorb_to_layer(self): calib_func(model) model = convert(model) out1 = model(self.example_inputs) - + quant_config = AWQConfig() logger.info(f"Test AWQ with config {quant_config}") @@ -200,7 +225,5 @@ def test_awq_absorb_to_layer(self): calib_func(model) model = convert(model) out2 = model(self.example_inputs) - - assert torch.all( - out1[0].eq(out2[0]) - ), "The results should be equal." \ No newline at end of file + + assert torch.all(out1[0].eq(out2[0])), "The results should be equal." From cc61136b8d42a853339f854e3164a8bca3a4824a Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Mon, 15 Jul 2024 16:36:30 +0800 Subject: [PATCH 3/6] update entry Signed-off-by: Kaihui-intel --- .../torch/algorithms/weight_only/awq.py | 25 ++++++++++++------- .../torch/quantization/algorithm_entry.py | 4 +-- .../torch/quantization/config.py | 8 +++--- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/awq.py b/neural_compressor/torch/algorithms/weight_only/awq.py index f16a18893bf..ef0a82c7142 100644 --- a/neural_compressor/torch/algorithms/weight_only/awq.py +++ b/neural_compressor/torch/algorithms/weight_only/awq.py @@ -89,7 +89,7 @@ def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={} return block_absorb_dict, absorb_layer_dict -def _get_block_absorb_dict(model, absorb_layer_dict): +def _get_absorb_dict(model, absorb_layer_dict): """Get absorbed layer per block from absorbed layer dict. Args: @@ -101,15 +101,22 @@ def _get_block_absorb_dict(model, absorb_layer_dict): """ block_absorb_dict = {} block_prefix, block_num = get_block_prefix(model) + new_absorb_layer_dict = {} for i in range(block_num): block_absorb_dict[i] = [] block_name = block_prefix + "." + str(i) + "." 
+ for k, v in absorb_layer_dict.items(): - if all(block_name in elem for elem in k): - block_absorb_dict[i].append(k) + + if isinstance(v, str): + name_list = (block_name + v, ) + else: + name_list = tuple(block_name + vv for vv in v) + block_absorb_dict[i].append(name_list) + new_absorb_layer_dict[name_list] = block_name + k logger.debug(f"The absorbed layers per block: {block_absorb_dict}") logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}") - return block_absorb_dict + return block_absorb_dict, new_absorb_layer_dict @torch.no_grad() @@ -198,7 +205,7 @@ def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, retu weight_config=self.weight_config, ) else: - self.block_absorb_dict = _get_block_absorb_dict(self.model, self.absorb_layer_dict) + self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_dict(self.model, self.absorb_layer_dict) # process per block for i, module_list in self.block_absorb_dict.items(): logger.info(f"Processing block: {i+1}/{self.block_num}") @@ -519,15 +526,15 @@ def module_inference(self, model, inputs): class AWQQuantizer(Quantizer): - def __init__(self, quant_config: OrderedDict = {}, absorb_to_layer: dict = {}): + def __init__(self, quant_config: OrderedDict = {}, absorb_layer_dict: dict = {}): """Init an AWQQuantizer object. Args: quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}. - absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}. + absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}. """ super().__init__(quant_config) - self.absorb_to_layer = absorb_to_layer + self.absorb_layer_dict = absorb_layer_dict @torch.no_grad() def prepare(self, model, *args, **kwargs): @@ -596,7 +603,7 @@ def convert( weight_config=self.quant_config, total_block_args=total_block_args, total_block_kwargs=total_block_kwargs, - absorb_layer_dict=self.absorb_to_layer, + absorb_layer_dict=self.absorb_layer_dict, ) qdq_model = awq.quantize( use_auto_scale=use_auto_scale, diff --git a/neural_compressor/torch/quantization/algorithm_entry.py b/neural_compressor/torch/quantization/algorithm_entry.py index d4e4071600f..ea2d53e7353 100644 --- a/neural_compressor/torch/quantization/algorithm_entry.py +++ b/neural_compressor/torch/quantization/algorithm_entry.py @@ -347,7 +347,7 @@ def awq_quantize_entry( use_mse_search = quant_config.use_auto_clip # for awq clip folding = quant_config.folding use_full_range = quant_config.use_full_range - absorb_to_layer = quant_config.absorb_to_layer + absorb_layer_dict = quant_config.absorb_layer_dict run_fn = kwargs.get("run_fn", None) run_args = kwargs.get("run_args", None) @@ -355,7 +355,7 @@ def awq_quantize_entry( assert example_inputs is not None, "Please provide example_inputs for AWQ quantization." 
quantizer = get_quantizer( - model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_to_layer=absorb_to_layer + model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_layer_dict=absorb_layer_dict ) model = quantizer.execute( model, diff --git a/neural_compressor/torch/quantization/config.py b/neural_compressor/torch/quantization/config.py index a56223411dc..f2b12f89b5f 100644 --- a/neural_compressor/torch/quantization/config.py +++ b/neural_compressor/torch/quantization/config.py @@ -425,7 +425,7 @@ class AWQConfig(BaseConfig): "use_auto_scale", "use_auto_clip", "folding", - "absorb_to_layer", + "absorb_layer_dict", ] name = AWQ @@ -452,7 +452,7 @@ def __init__( use_auto_clip: bool = True, folding: bool = False, white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, - absorb_to_layer: dict = {}, + absorb_layer_dict: dict = {}, ): """Init AWQ weight-only quantization config. @@ -475,7 +475,7 @@ def __init__( use_auto_clip (bool): Enables clip range search. Defaults to True. folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer, default is False. - absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}. + absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}. """ super().__init__(white_list=white_list) self.dtype = dtype @@ -496,7 +496,7 @@ def __init__( self.use_auto_scale = use_auto_scale self.use_auto_clip = use_auto_clip self.folding = folding - self.absorb_to_layer = absorb_to_layer + self.absorb_layer_dict = absorb_layer_dict self._post_init() @classmethod From 6f2342512cd5136b9eb38dab4ca7bf0b0bca99a8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 08:39:02 +0000 Subject: [PATCH 4/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- neural_compressor/torch/algorithms/weight_only/awq.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/neural_compressor/torch/algorithms/weight_only/awq.py b/neural_compressor/torch/algorithms/weight_only/awq.py index ef0a82c7142..b8c4329de3b 100644 --- a/neural_compressor/torch/algorithms/weight_only/awq.py +++ b/neural_compressor/torch/algorithms/weight_only/awq.py @@ -105,11 +105,11 @@ def _get_absorb_dict(model, absorb_layer_dict): for i in range(block_num): block_absorb_dict[i] = [] block_name = block_prefix + "." + str(i) + "." - + for k, v in absorb_layer_dict.items(): - + if isinstance(v, str): - name_list = (block_name + v, ) + name_list = (block_name + v,) else: name_list = tuple(block_name + vv for vv in v) block_absorb_dict[i].append(name_list) From 089a49c8d1d6e77eb0ec434b5e7f041a2bee90d4 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Mon, 15 Jul 2024 16:45:09 +0800 Subject: [PATCH 5/6] update ut Signed-off-by: Kaihui-intel --- .../quantization/weight_only/test_awq.py | 55 ++++--------------- 1 file changed, 12 insertions(+), 43 deletions(-) diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 96f55d60e8c..97c4c6747ab 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -159,49 +159,19 @@ def test_quant_lm_head(self): ), "The tied lm_head weight is not deep copied, please check!" 
def test_awq_absorb_to_layer(self): - absorb_to_layer = { - ( - "transformer.h.0.attn.q_proj", - "transformer.h.0.attn.k_proj", - "transformer.h.0.attn.v_proj", - "transformer.h.0.mlp.fc_in", - ): "transformer.h.0.ln_1", - ("transformer.h.0.attn.out_proj",): "transformer.h.0.attn.out_proj", - ("transformer.h.0.mlp.fc_out",): "transformer.h.0.mlp.fc_out", - ( - "transformer.h.1.attn.q_proj", - "transformer.h.1.attn.k_proj", - "transformer.h.1.attn.v_proj", - "transformer.h.1.mlp.fc_in", - ): "transformer.h.1.ln_1", - ("transformer.h.1.attn.out_proj",): "transformer.h.1.attn.out_proj", - ("transformer.h.1.mlp.fc_out",): "transformer.h.1.mlp.fc_out", - ( - "transformer.h.2.attn.q_proj", - "transformer.h.2.attn.k_proj", - "transformer.h.2.attn.v_proj", - "transformer.h.2.mlp.fc_in", - ): "transformer.h.2.ln_1", - ("transformer.h.2.attn.out_proj",): "transformer.h.2.attn.out_proj", - ("transformer.h.2.mlp.fc_out",): "transformer.h.2.mlp.fc_out", - ( - "transformer.h.3.attn.q_proj", - "transformer.h.3.attn.k_proj", - "transformer.h.3.attn.v_proj", - "transformer.h.3.mlp.fc_in", - ): "transformer.h.3.ln_1", - ("transformer.h.3.attn.out_proj",): "transformer.h.3.attn.out_proj", - ("transformer.h.3.mlp.fc_out",): "transformer.h.3.mlp.fc_out", - ( - "transformer.h.4.attn.q_proj", - "transformer.h.4.attn.k_proj", - "transformer.h.4.attn.v_proj", - "transformer.h.4.mlp.fc_in", - ): "transformer.h.4.ln_1", - ("transformer.h.4.attn.out_proj",): "transformer.h.4.attn.out_proj", - ("transformer.h.4.mlp.fc_out",): "transformer.h.4.mlp.fc_out", + absorb_layer_dict = { + "ln_1": ( + "attn.q_proj", + "attn.k_proj", + "attn.v_proj", + "mlp.fc_in", + ), + "attn.out_proj": "attn.out_proj", + "mlp.fc_out": ("mlp.fc_out"), } - quant_config = AWQConfig(absorb_to_layer=absorb_to_layer) + + + quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict) logger.info(f"Test AWQ with config {quant_config}") # prepare + convert API model = prepare( @@ -212,7 +182,6 @@ def test_awq_absorb_to_layer(self): calib_func(model) model = convert(model) out1 = model(self.example_inputs) - quant_config = AWQConfig() logger.info(f"Test AWQ with config {quant_config}") From 88ce9756b006c1e0d56cf0f3e240502311278ff9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 15 Jul 2024 08:49:22 +0000 Subject: [PATCH 6/6] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/3x/torch/quantization/weight_only/test_awq.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/test/3x/torch/quantization/weight_only/test_awq.py b/test/3x/torch/quantization/weight_only/test_awq.py index 97c4c6747ab..c877288f7dc 100644 --- a/test/3x/torch/quantization/weight_only/test_awq.py +++ b/test/3x/torch/quantization/weight_only/test_awq.py @@ -169,8 +169,7 @@ def test_awq_absorb_to_layer(self): "attn.out_proj": "attn.out_proj", "mlp.fc_out": ("mlp.fc_out"), } - - + quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict) logger.info(f"Test AWQ with config {quant_config}") # prepare + convert API
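
A usage sketch for the new absorb_layer_dict option, distilled from the updated unit test; it is illustrative only and not part of the patch. The tiny GPT-J checkpoint, the dummy input, and the single calibration forward pass are placeholder assumptions, and the relative layer names assume a GPT-J-style decoder block (ln_1 feeding q_proj/k_proj/v_proj and mlp.fc_in), as in the test. Keys name the layer into which the scale is absorbed; values list the layers whose scales are folded into it. When absorb_layer_dict is left at its default {}, AWQ falls back to the trace-based _get_absorb_per_block search, which is why the test asserts that both paths produce identical outputs.

    import torch
    from transformers import AutoModelForCausalLM

    from neural_compressor.torch.quantization import AWQConfig, convert, prepare

    # Placeholder model and input; any decoder-only model whose blocks match the
    # relative names below should behave the same way.
    model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")
    example_inputs = torch.tensor([[10, 20, 30]], dtype=torch.long)

    # Relative names are expanded per block by _get_absorb_dict, e.g. for block 0:
    #   ("transformer.h.0.attn.q_proj", ..., "transformer.h.0.mlp.fc_in") -> "transformer.h.0.ln_1"
    absorb_layer_dict = {
        "ln_1": ("attn.q_proj", "attn.k_proj", "attn.v_proj", "mlp.fc_in"),
        "attn.out_proj": "attn.out_proj",
        "mlp.fc_out": ("mlp.fc_out",),
    }

    quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict)
    model = prepare(model=model, quant_config=quant_config, example_inputs=example_inputs)
    model(example_inputs)  # calibration pass; the unit test runs calib_func(model) here
    model = convert(model)
    output = model(example_inputs)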