54 changes: 46 additions & 8 deletions neural_compressor/torch/algorithms/weight_only/awq.py
@@ -89,6 +89,36 @@ def _get_absorb_per_block(model, example_inputs, folding=False, weight_config={}
return block_absorb_dict, absorb_layer_dict


def _get_absorb_dict(model, absorb_layer_dict):
"""Get absorbed layer per block from absorbed layer dict.

Args:
model (torch.nn.Module): input model
absorb_layer_dict (dict): The layer dict that scale can be absorbed, default is {}.

Returns:
block_absorb_dict: dict of absorbed layer per block. eg. {0, [[absorbed_1, xx], [xx]], ...}
"""
block_absorb_dict = {}
block_prefix, block_num = get_block_prefix(model)
new_absorb_layer_dict = {}
for i in range(block_num):
block_absorb_dict[i] = []
block_name = block_prefix + "." + str(i) + "."

for k, v in absorb_layer_dict.items():

if isinstance(v, str):
name_list = (block_name + v,)
else:
name_list = tuple(block_name + vv for vv in v)
block_absorb_dict[i].append(name_list)
new_absorb_layer_dict[name_list] = block_name + k
logger.debug(f"The absorbed layers per block: {block_absorb_dict}")
logger.debug(f"The absorb_layer_dict: {absorb_layer_dict}")
return block_absorb_dict, new_absorb_layer_dict

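For illustration, a rough sketch of how a user-supplied dict is expanded by _get_absorb_dict, assuming a GPT-J-style model whose block prefix resolves to transformer.h (the prefix and layer names are assumptions for this example, not part of the patch):

# Hypothetical expansion for block 0, assuming get_block_prefix(model) yields ("transformer.h", block_num).
absorb_layer_dict = {
    "ln_1": ("attn.q_proj", "attn.k_proj", "attn.v_proj"),
    "attn.out_proj": "attn.out_proj",
}
# block_absorb_dict[0] then holds fully qualified name tuples:
#   [("transformer.h.0.attn.q_proj", "transformer.h.0.attn.k_proj", "transformer.h.0.attn.v_proj"),
#    ("transformer.h.0.attn.out_proj",)]
# and new_absorb_layer_dict maps each tuple back to the layer that absorbs its scales:
#   {("transformer.h.0.attn.q_proj", ..., "transformer.h.0.attn.v_proj"): "transformer.h.0.ln_1",
#    ("transformer.h.0.attn.out_proj",): "transformer.h.0.attn.out_proj"}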

@torch.no_grad()
def _get_weight_scale(weight, q_group_size=-1):
org_shape = weight.shape
@@ -123,6 +153,7 @@ def __init__(
total_block_args=[],
total_block_kwargs=[],
device="auto",
absorb_layer_dict={},
):

self.example_inputs = example_inputs
@@ -140,6 +171,7 @@ def __init__(
self.scheme = scheme
self.use_full_range = use_full_range
self.weight_config = weight_config
self.absorb_layer_dict = absorb_layer_dict

def _move_model_and_data_to_device(self):
# Put the model and example_inputs into target device
@@ -164,13 +196,16 @@ def quantize(self, use_auto_scale=True, use_mse_search=True, folding=False, retu
# Step 1: get absorbed module list per block, includes self-absorption
# block_absorb_dict is split per block, includes all absorb relationship.
# absorb_layer_dict is the inverse of block_absorb_dict for all blocks
self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block(
self.model,
self.example_inputs,
# for only use_mse_search, folding is useless.
folding=folding if use_auto_scale else False,
weight_config=self.weight_config,
)
if not self.absorb_layer_dict:
self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_per_block(
self.model,
self.example_inputs,
# for only use_mse_search, folding is useless.
folding=folding if use_auto_scale else False,
weight_config=self.weight_config,
)
else:
self.block_absorb_dict, self.absorb_layer_dict = _get_absorb_dict(self.model, self.absorb_layer_dict)
# process per block
for i, module_list in self.block_absorb_dict.items():
logger.info(f"Processing block: {i+1}/{self.block_num}")
@@ -491,13 +526,15 @@ def module_inference(self, model, inputs):


class AWQQuantizer(Quantizer):
def __init__(self, quant_config: OrderedDict = {}):
def __init__(self, quant_config: OrderedDict = {}, absorb_layer_dict: dict = {}):
"""Init an AWQQuantizer object.

Args:
quant_config (OrderedDict, optional): quantization config for ops. Defaults to {}.
absorb_layer_dict (dict): Mapping from the layer that absorbs the scales to the layer(s) whose scales are absorbed into it, default is {}.
"""
super().__init__(quant_config)
self.absorb_layer_dict = absorb_layer_dict

@torch.no_grad()
def prepare(self, model, *args, **kwargs):
@@ -566,6 +603,7 @@ def convert(
weight_config=self.quant_config,
total_block_args=total_block_args,
total_block_kwargs=total_block_kwargs,
absorb_layer_dict=self.absorb_layer_dict,
)
qdq_model = awq.quantize(
use_auto_scale=use_auto_scale,
45 changes: 24 additions & 21 deletions neural_compressor/torch/quantization/algorithm_entry.py
@@ -317,10 +317,10 @@ def awq_quantize_entry(
from neural_compressor.torch.algorithms.weight_only.save_load import save

weight_config = {}
for (op_name, op_type), op_config in configs_mapping.items():
if op_config.name != AWQ:
for (op_name, op_type), quant_config in configs_mapping.items():
if quant_config.name != AWQ:
continue
if op_config.dtype == "fp32":
if quant_config.dtype == "fp32":
weight_config[op_name] = {
"bits": -1,
"dtype": "fp32", # skip quantization
@@ -329,31 +329,34 @@
}
else:
weight_config[op_name] = {
"dtype": op_config.dtype,
"bits": op_config.bits,
"group_size": op_config.group_size,
"group_dim": op_config.group_dim,
"scheme": "sym" if op_config.use_sym else "asym",
"use_full_range": op_config.use_full_range,
"use_mse_search": op_config.use_mse_search,
"use_layer_wise": op_config.use_layer_wise,
"use_double_quant": op_config.use_double_quant,
"double_quant_dtype": op_config.double_quant_dtype,
"double_quant_bits": op_config.double_quant_bits,
"double_quant_scheme": op_config.double_quant_use_sym,
"double_quant_group_size": op_config.double_quant_group_size,
"dtype": quant_config.dtype,
"bits": quant_config.bits,
"group_size": quant_config.group_size,
"group_dim": quant_config.group_dim,
"scheme": "sym" if quant_config.use_sym else "asym",
"use_full_range": quant_config.use_full_range,
"use_mse_search": quant_config.use_mse_search,
"use_layer_wise": quant_config.use_layer_wise,
"use_double_quant": quant_config.use_double_quant,
"double_quant_dtype": quant_config.double_quant_dtype,
"double_quant_bits": quant_config.double_quant_bits,
"double_quant_scheme": quant_config.double_quant_use_sym,
"double_quant_group_size": quant_config.double_quant_group_size,
}
use_auto_scale = op_config.use_auto_scale
use_mse_search = op_config.use_auto_clip # for awq clip
folding = op_config.folding
use_full_range = op_config.use_full_range
use_auto_scale = quant_config.use_auto_scale
use_mse_search = quant_config.use_auto_clip # for awq clip
folding = quant_config.folding
use_full_range = quant_config.use_full_range
absorb_layer_dict = quant_config.absorb_layer_dict

run_fn = kwargs.get("run_fn", None)
run_args = kwargs.get("run_args", None)
example_inputs = kwargs.get("example_inputs", None)
assert example_inputs is not None, "Please provide example_inputs for AWQ quantization."

quantizer = get_quantizer(model, quantizer_cls=AWQQuantizer, quant_config=weight_config)
quantizer = get_quantizer(
model, quantizer_cls=AWQQuantizer, quant_config=weight_config, absorb_layer_dict=absorb_layer_dict
)
model = quantizer.execute(
model,
mode=mode,
6 changes: 5 additions & 1 deletion neural_compressor/torch/quantization/config.py
@@ -425,6 +425,7 @@ class AWQConfig(BaseConfig):
"use_auto_scale",
"use_auto_clip",
"folding",
"absorb_layer_dict",
]
name = AWQ

@@ -451,6 +452,7 @@ def __init__(
use_auto_clip: bool = True,
folding: bool = False,
white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST,
absorb_layer_dict: dict = {},
):
"""Init AWQ weight-only quantization config.

@@ -473,6 +475,7 @@ def __init__(
use_auto_clip (bool): Enables clip range search. Defaults to True.
folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
default is False.
absorb_layer_dict (dict): Mapping from the layer that absorbs the scales to the layer(s) whose scales are absorbed into it, default is {}.
"""
super().__init__(white_list=white_list)
self.dtype = dtype
@@ -493,6 +496,7 @@ def __init__(
self.use_auto_scale = use_auto_scale
self.use_auto_clip = use_auto_clip
self.folding = folding
self.absorb_layer_dict = absorb_layer_dict
self._post_init()

@classmethod
@@ -609,7 +613,7 @@ def __init__(
double_quant_use_sym (bool): Indicates whether double_quant scale are symmetric, default is True.
double_quant_group_size (int): Size of double_quant groups, default is 32.
quant_lm_head (bool): Indicates whether quantize the lm_head layer in transformers。 Default is False.
absorb_to_layer (bool): The layer dict that scale can be absorbed, default is {}.
absorb_to_layer (dict): The layer dict that scale can be absorbed, default is {}.
folding(bool): Allow insert mul before linear when the scale cannot be absorbed by last layer,
default is False.
"""
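A minimal usage sketch of the new absorb_layer_dict option, mirroring the test added below; the float model, calibration function, and example inputs are placeholders:

import copy

from neural_compressor.torch.quantization import AWQConfig, convert, prepare

# Keys are the layers that absorb the scales; values are the layer(s) whose scales they absorb.
absorb_layer_dict = {
    "ln_1": ("attn.q_proj", "attn.k_proj", "attn.v_proj", "mlp.fc_in"),
    "attn.out_proj": "attn.out_proj",
}
quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict)
model = prepare(
    model=copy.deepcopy(float_model),  # placeholder FP32 model
    quant_config=quant_config,
    example_inputs=example_inputs,  # placeholder calibration inputs
)
calib_func(model)  # placeholder calibration loop
q_model = convert(model)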
38 changes: 38 additions & 0 deletions test/3x/torch/quantization/weight_only/test_awq.py
@@ -157,3 +157,41 @@ def test_quant_lm_head(self):
assert (
id(model.model.decoder.embed_tokens.weight) == lm_head_id
), "The tied lm_head weight is not deep copied, please check!"

def test_awq_absorb_to_layer(self):
absorb_layer_dict = {
"ln_1": (
"attn.q_proj",
"attn.k_proj",
"attn.v_proj",
"mlp.fc_in",
),
"attn.out_proj": "attn.out_proj",
"mlp.fc_out": ("mlp.fc_out"),
}

quant_config = AWQConfig(absorb_layer_dict=absorb_layer_dict)
logger.info(f"Test AWQ with config {quant_config}")
# prepare + convert API
model = prepare(
model=copy.deepcopy(self.tiny_gptj),
quant_config=quant_config,
example_inputs=self.example_inputs,
)
calib_func(model)
model = convert(model)
out1 = model(self.example_inputs)
quant_config = AWQConfig()
logger.info(f"Test AWQ with config {quant_config}")

# prepare + convert API
model = prepare(
model=copy.deepcopy(self.tiny_gptj),
quant_config=quant_config,
example_inputs=self.example_inputs,
)
calib_func(model)
model = convert(model)
out2 = model(self.example_inputs)

assert torch.all(out1[0].eq(out2[0])), "The results should be equal."