Commit de43d85

Support absorb dict for awq (#1920)
Signed-off-by: Kaihui-intel <[email protected]>
1 parent e976595 commit de43d85

4 files changed (+113 −30 lines)


neural_compressor/torch/algorithms/weight_only/awq.py

Lines changed: 46 additions & 8 deletions
@@ -89,6 +89,36 @@ def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}
     return block_absorb_dict, absorb_layer_dict


+def _get_absorb_dict(model, absorb_layer_dict):
+    """Get absorbed layer per block from absorbed layer dict.
+
+    Args:
+        model (torch.nn.Module): input model
+        absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.
+
+    Returns:
+        block_absorb_dict: dict of absorbed layer per block. eg. {0, [[absorbed_1, xx], [xx]], ...}
+    """
+    block_absorb_dict = {}
+    block_prefix, block_num = get_block_prefix(model)
+    new_absorb_layer_dict = {}
+    for i in range(block_num):
+        block_absorb_dict[i] = []
+        block_name = block_prefix + "." + str(i) + "."
+
+        for k, v in absorb_layer_dict.items():
+
+            if isinstance(v, str):
+                name_list = (block_name + v,)
+            else:
+                name_list = tuple(block_name + vv for vv in v)
+            block_absorb_dict[i].append(name_list)
+            new_absorb_layer_dict[name_list] = block_name + k
+    logger.debug(f"The absorbed layers per block: {block_absorb_dict}")
+    logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}")
+    return block_absorb_dict, new_absorb_layer_dict
+
+
 @torch.no_grad()
 def _get_weight_scale(weight, q_group_size=-1):
     org_shape = weight.shape
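
For illustration only (not part of the commit): a minimal sketch of the expansion `_get_absorb_dict` performs, assuming a GPT-J-style model whose decoder blocks live under `transformer.h` and pretending `get_block_prefix(model)` reports two blocks. Keys and values in the user-supplied dict are module names relative to each block; the helper prefixes them with the block path, collects them per block, and builds a reverse lookup.

# Sketch (not part of the commit): what _get_absorb_dict builds for a 2-block model,
# assuming get_block_prefix(model) would return ("transformer.h", 2) as for GPT-J.
absorb_layer_dict = {
    "ln_1": ("attn.q_proj", "attn.k_proj"),  # one absorbing layer, several absorbed linears
    "attn.out_proj": "attn.out_proj",        # a plain string means self-absorption
}

block_prefix, block_num = "transformer.h", 2
block_absorb_dict, new_absorb_layer_dict = {}, {}
for i in range(block_num):
    block_absorb_dict[i] = []
    block_name = f"{block_prefix}.{i}."
    for k, v in absorb_layer_dict.items():
        # strings become one-element tuples; tuples are prefixed element-wise
        name_list = (block_name + v,) if isinstance(v, str) else tuple(block_name + vv for vv in v)
        block_absorb_dict[i].append(name_list)
        new_absorb_layer_dict[name_list] = block_name + k

# block_absorb_dict[0] == [
#     ("transformer.h.0.attn.q_proj", "transformer.h.0.attn.k_proj"),
#     ("transformer.h.0.attn.out_proj",),
# ]
# new_absorb_layer_dict[("transformer.h.0.attn.out_proj",)] == "transformer.h.0.attn.out_proj"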
@@ -123,6 +153,7 @@ def __init__(
         total_block_args=[],
         total_block_kwargs=[],
         device="auto",
+        absorb_layer_dict={},
     ):

         self.example_inputs = example_inputs
@@ -140,6 +171,7 @@
         self.scheme = scheme
         self.use_full_range = use_full_range
         self.weight_config = weight_config
+        self.absorb_layer_dict = absorb_layer_dict

     def _move_model_and_data_to_device(self):
         # Put the model and example_inputs into target device
@@ -164,13 +196,16 @@ def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, retu
         # Step 1: get absorbed module list per block, includes self-absorption
         # block_absorb_dict is split per block, includes all absorb relationship.
         # absorb_layer_dict is the inverse of block_absorb_dict for all blocks
-        self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block(
-            self.model,
-            self.example_inputs,
-            # for only use_mse_search, folding is useless.
-            folding=folding if use_auto_scale else False,
-            weight_config=self.weight_config,
-        )
+        if not self.absorb_layer_dict:
+            self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block(
+                self.model,
+                self.example_inputs,
+                # for only use_mse_search, folding is useless.
+                folding=folding if use_auto_scale else False,
+                weight_config=self.weight_config,
+            )
+        else:
+            self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_dict(self.model, self.absorb_layer_dict)
         # process per block
         for i, module_list in self.block_absorb_dict.items():
             logger.info(f"Processing block: {i+1}/{self.block_num}")
@@ -491,13 +526,15 @@ def module_inference(self, model, inputs):


 class AWQQuantizer(Quantizer):
-    def __init__(self, quant_config: OrderedDict = {}):
+    def __init__(self, quant_config: OrderedDict = {}, absorb_layer_dict: dict = {}):
         """Init an AWQQuantizer object.

         Args:
             quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
+            absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.
         """
         super().__init__(quant_config)
+        self.absorb_layer_dict = absorb_layer_dict

     @torch.no_grad()
     def prepare(self, model, *args, **kwargs):
@@ -566,6 +603,7 @@ def convert(
             weight_config=self.quant_config,
             total_block_args=total_block_args,
             total_block_kwargs=total_block_kwargs,
+            absorb_layer_dict=self.absorb_layer_dict,
         )
         qdq_model = awq.quantize(
             use_auto_scale=use_auto_scale,

neural_compressor/torch/quantization/algorithm_entry.py

Lines changed: 24 additions & 21 deletions
@@ -323,10 +323,10 @@ def awq_quantize_entry(
     from neural_compressor.torch.algorithms.weight_only.save_load import save

     weight_config = {}
-    for (op_name, op_type), op_config in configs_mapping.items():
-        if op_config.name != AWQ:
+    for (op_name, op_type), quant_config in configs_mapping.items():
+        if quant_config.name != AWQ:
             continue
-        if op_config.dtype == "fp32":
+        if quant_config.dtype == "fp32":
             weight_config[op_name] = {
                 "bits": -1,
                 "dtype": "fp32",  # skip quantization
@@ -335,31 +335,34 @@
             }
         else:
             weight_config[op_name] = {
-                "dtype": op_config.dtype,
-                "bits": op_config.bits,
-                "group_size": op_config.group_size,
-                "group_dim": op_config.group_dim,
-                "scheme": "sym" if op_config.use_sym else "asym",
-                "use_full_range": op_config.use_full_range,
-                "use_mse_search": op_config.use_mse_search,
-                "use_layer_wise": op_config.use_layer_wise,
-                "use_double_quant": op_config.use_double_quant,
-                "double_quant_dtype": op_config.double_quant_dtype,
-                "double_quant_bits": op_config.double_quant_bits,
-                "double_quant_scheme": op_config.double_quant_use_sym,
-                "double_quant_group_size": op_config.double_quant_group_size,
+                "dtype": quant_config.dtype,
+                "bits": quant_config.bits,
+                "group_size": quant_config.group_size,
+                "group_dim": quant_config.group_dim,
+                "scheme": "sym" if quant_config.use_sym else "asym",
+                "use_full_range": quant_config.use_full_range,
+                "use_mse_search": quant_config.use_mse_search,
+                "use_layer_wise": quant_config.use_layer_wise,
+                "use_double_quant": quant_config.use_double_quant,
+                "double_quant_dtype": quant_config.double_quant_dtype,
+                "double_quant_bits": quant_config.double_quant_bits,
+                "double_quant_scheme": quant_config.double_quant_use_sym,
+                "double_quant_group_size": quant_config.double_quant_group_size,
             }
-        use_auto_scale = op_config.use_auto_scale
-        use_mse_search = op_config.use_auto_clip  # for awq clip
-        folding = op_config.folding
-        use_full_range = op_config.use_full_range
+        use_auto_scale = quant_config.use_auto_scale
+        use_mse_search = quant_config.use_auto_clip  # for awq clip
+        folding = quant_config.folding
+        use_full_range = quant_config.use_full_range
+        absorb_layer_dict = quant_config.absorb_layer_dict

     run_fn = kwargs.get("run_fn", None)
     run_args = kwargs.get("run_args", None)
     example_inputs = kwargs.get("example_inputs", None)
     assert example_inputs is not None, "Please provide example_inputs for AWQ quantization."

-    quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config)
+    quantizer = get_quantizer(
+        model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_layer_dict=absorb_layer_dict
+    )
     model = quantizer.execute(
         model,
         mode=mode,
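
For orientation, here is an editor's summary (not part of the commit) of how the new option travels through the stack; every name below comes from the hunks in this commit:

# AWQConfig(absorb_layer_dict=...)               config.py: new user-facing field
#   -> awq_quantize_entry(...)                   algorithm_entry.py: reads quant_config.absorb_layer_dict
#   -> get_quantizer(..., absorb_layer_dict=...) constructs AWQQuantizer with the dict
#   -> AWQQuantizer.convert(...)                 awq.py: forwards the dict into the AWQ engine
#   -> quantize(): _get_absorb_dict(...)         replaces traced _get_absorb_per_block(...) when the dict is non-empty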

neural_compressor/torch/quantization/config.py

Lines changed: 5 additions & 1 deletion
@@ -442,6 +442,7 @@ class AWQConfig(TorchBaseConfig):
         "use_auto_scale",
         "use_auto_clip",
         "folding",
+        "absorb_layer_dict",
     ]
     name = AWQ

@@ -468,6 +469,7 @@ def __init__(
         use_auto_clip: bool = True,
         folding: bool = False,
         white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
+        absorb_layer_dict: dict = {},
     ):
         """Init AWQ weight-only quantization config.

@@ -490,6 +492,7 @@
             use_auto_clip (bool): Enables clip range search. Defaults to True.
             folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
                 default is False.
+            absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.
         """
         super().__init__(white_list=white_list)
         self.dtype = dtype
@@ -510,6 +513,7 @@
         self.use_auto_scale = use_auto_scale
         self.use_auto_clip = use_auto_clip
         self.folding = folding
+        self.absorb_layer_dict = absorb_layer_dict
         self._post_init()

     @classmethod
@@ -626,7 +630,7 @@ def __init__(
             double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True.
             double_quant_group_size (int): Size of double_quant groups, default is 32.
             quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False.
-            absorb_to_layer (bool): The layer dict that scale can be absorbed, default is {}.
+            absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}.
             folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
                 default is False.
         """

test/3x/torch/quantization/weight_only/test_awq.py

Lines changed: 38 additions & 0 deletions
@@ -157,3 +157,41 @@ def test_quant_lm_head(self):
         assert (
             id(model.model.decoder.embed_tokens.weight) == lm_head_id
         ), "The tied lm_head weight is not deep copied, please check!"
+
+    def test_awq_absorb_to_layer(self):
+        absorb_layer_dict = {
+            "ln_1": (
+                "attn.q_proj",
+                "attn.k_proj",
+                "attn.v_proj",
+                "mlp.fc_in",
+            ),
+            "attn.out_proj": "attn.out_proj",
+            "mlp.fc_out": ("mlp.fc_out"),
+        }
+
+        quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict)
+        logger.info(f"Test AWQ with config {quant_config}")
+        # prepare + convert API
+        model = prepare(
+            model=copy.deepcopy(self.tiny_gptj),
+            quant_config=quant_config,
+            example_inputs=self.example_inputs,
+        )
+        calib_func(model)
+        model = convert(model)
+        out1 = model(self.example_inputs)
+        quant_config = AWQConfig()
+        logger.info(f"Test AWQ with config {quant_config}")
+
+        # prepare + convert API
+        model = prepare(
+            model=copy.deepcopy(self.tiny_gptj),
+            quant_config=quant_config,
+            example_inputs=self.example_inputs,
+        )
+        calib_func(model)
+        model = convert(model)
+        out2 = model(self.example_inputs)
+
+        assert torch.all(out1[0].eq(out2[0])), "The results should be equal."
