
Commit f5247ac

Authored by mobicham, SunMarc, and ArthurZucker
Hqq serialization (#33141)
* HQQ model serialization attempt
* fix hqq dispatch and unexpected keys
* style
* remove check_old_param
* revert to check HQQLinear in quantizer_hqq.py
* revert to check HQQLinear in quantizer_hqq.py
* update HqqConfig default params
* make ci happy
* make ci happy
* revert to HQQLinear check in quantizer_hqq.py
* check hqq_min version 0.2.0
* set axis=1 as default in quantization_config.py
* validate_env with hqq>=0.2.0 version message
* deprecated hqq kwargs message
* make ci happy
* remove run_expected_keys_check hack + bump to 0.2.1 min hqq version
* fix unexpected_keys hqq update
* add pre_quantized check
* add update_expected_keys to base quantizer
* ci base.py fix?
* ci base.py fix?
* fix "quantization" typo in src/transformers/utils/quantization_config.py (Co-authored-by: Arthur <[email protected]>)
* fix post merge

Co-authored-by: Marc Sun <[email protected]>
Co-authored-by: Arthur <[email protected]>
1 parent 4d5b458 commit f5247ac

File tree

8 files changed: +214 additions, −60 deletions


docs/source/en/quantization/hqq.md

mode change 100644 → 100755
Lines changed: 3 additions & 3 deletions
@@ -30,13 +30,13 @@ To quantize a model, you need to create an [`HqqConfig`]. There are two ways of
 from transformers import AutoModelForCausalLM, AutoTokenizer, HqqConfig
 
 # Method 1: all linear layers will use the same quantization config
-quant_config = HqqConfig(nbits=8, group_size=64, quant_zero=False, quant_scale=False, axis=0) #axis=0 is used by default
+quant_config = HqqConfig(nbits=8, group_size=64)
 ```
 
 ``` Python
 # Method 2: each linear layer with the same tag will use a dedicated quantization config
-q4_config = {'nbits':4, 'group_size':64, 'quant_zero':False, 'quant_scale':False}
-q3_config = {'nbits':3, 'group_size':32, 'quant_zero':False, 'quant_scale':False}
+q4_config = {'nbits':4, 'group_size':64}
+q3_config = {'nbits':3, 'group_size':32}
 quant_config = HqqConfig(dynamic_config={
     'self_attn.q_proj':q4_config,
     'self_attn.k_proj':q4_config,
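With the simplified defaults above, end-to-end usage is unchanged. The snippet below is a minimal sketch (not part of the diff) of quantizing with the new default config and writing the result to disk, which this PR makes possible; the model id and output path are illustrative placeholders.

```python
import torch
from transformers import AutoModelForCausalLM, HqqConfig

# axis now defaults to 1; quant_zero / quant_scale / offload_meta are no longer needed
quant_config = HqqConfig(nbits=4, group_size=64)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",            # illustrative model id
    quantization_config=quant_config,
    torch_dtype=torch.float16,
    device_map="cuda",
)

# New in this PR: HQQ-quantized weights can be serialized back to disk
model.save_pretrained("llama-hqq-4bit")   # illustrative output path
```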

src/transformers/integrations/hqq.py

Lines changed: 10 additions & 2 deletions
@@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_
 
             has_been_replaced = True
 
+            # Add these fake parameters to avoid loading fail
+            for att in ["W_q", "meta"]:
+                setattr(module, att, None)
+
         if len(list(module.children())) > 0:
             _, has_been_replaced = _prepare_for_hqq_linear(
                 module,
@@ -97,7 +101,7 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
 
     # Convert quantization_config to layer-wise config
     skip_modules = quantization_config.skip_modules
-    quant_config = quantization_config.to_dict()
+    quant_config = quantization_config.quant_config
     linear_tags = list(set(linear_tags) - set(skip_modules) - set(modules_to_not_convert))
 
     if any(key in linear_tags for key in quant_config.keys()):
@@ -113,7 +117,11 @@ def prepare_for_hqq_linear(model, quantization_config=None, modules_to_not_conve
     )
 
     # We store quantization config as linear_tag -> hqq quant config
-    model.config.quantization_config = patch_params
+    model.config.quantization_config = {
+        "quant_config": quant_config,
+        "quant_method": quantization_config.quant_method,
+        "skip_modules": skip_modules,
+    }
 
     if not has_been_replaced:
         logger.warning("No linear modules were found in your model for quantization.")

src/transformers/modeling_utils.py

Lines changed: 10 additions & 1 deletion
@@ -934,12 +934,17 @@ def _load_state_dict_into_meta_model(
         # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
+
         old_param = model
         splits = param_name.split(".")
         for split in splits:
             old_param = getattr(old_param, split)
+            # Not all the attributes of a module are Parameters/Tensor
+            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+                old_param = None
             if old_param is None:
                 break
+
         if old_param is not None:
             if dtype is None:
                 param = param.to(old_param.dtype)
@@ -3819,6 +3824,7 @@ def from_pretrained(
         from_pt = not (from_tf | from_flax)
 
         # load pt weights early so that we know which dtype to init the model under
+
         if from_pt:
             if not is_sharded and state_dict is None:
                 # Time to load the checkpoint
@@ -4176,6 +4182,9 @@ def _load_pretrained_model(
         expected_keys = list(model_state_dict.keys())
         prefix = model.base_model_prefix
 
+        if hf_quantizer is not None:
+            expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, loaded_keys)
+
         def _fix_key(key):
             if "beta" in key:
                 return key.replace("beta", "bias")
@@ -4290,7 +4299,7 @@ def _fix_key(key):
                     value = torch.empty(*param.size(), dtype=target_dtype)
                     if (
                         not is_quantized
-                        or getattr(hf_quantizer, "requires_parameters_quantization", False)
+                        or (getattr(hf_quantizer, "requires_parameters_quantization", False))
                         or not hf_quantizer.check_quantized_param(
                             model, param_value=value, param_name=key, state_dict={}
                         )
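The new isinstance guard matters precisely because of the fake `W_q`/`meta` attributes added in `integrations/hqq.py`: walking a checkpoint key such as `model.layers.0.self_attn.q_proj.W_q` on the meta model now resolves to a plain `None` rather than a tensor, so it must not be used for dtype or contiguity handling. A standalone sketch of the guarded walk; `resolve_old_param` is a hypothetical helper written for illustration, not code from the diff:

```python
import torch

def resolve_old_param(model, param_name: str):
    """Return the existing Parameter/Tensor for `param_name`, or None if the dotted
    path resolves to something that is not a tensor (e.g. the fake W_q/meta attributes)."""
    old_param = model
    for split in param_name.split("."):
        old_param = getattr(old_param, split, None)
        if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
            return None
    return old_param
```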

src/transformers/quantizers/base.py

mode change 100644 → 100755
Lines changed: 12 additions & 0 deletions
@@ -109,6 +109,18 @@ def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> Li
         """
         return missing_keys
 
+    def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
+        """
+        Override this method if you want to adjust the `update_expected_keys`.
+
+        Args:
+            expected_keys (`List[str]`, *optional*):
+                The list of the expected keys in the initialized model.
+            loaded_keys (`List[str]`, *optional*):
+                The list of the loaded keys in the checkpoint.
+        """
+        return expected_keys
+
     def get_special_dtypes_update(self, model, torch_dtype: "torch.dtype") -> Dict[str, "torch.dtype"]:
         """
         returns dtypes for modules that are not quantized - used for the computation of the device_map in case
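`update_expected_keys` runs in `_load_pretrained_model` before missing/unexpected keys are computed (see the `modeling_utils.py` hunk above), so a quantizer can swap the parameter names it expects when loading a pre-quantized checkpoint. A minimal sketch of overriding the hook; `MyQuantizer` and the `.scales` suffix are made up for illustration, and only the hook's signature comes from this diff:

```python
from typing import List

from transformers.quantizers.base import HfQuantizer


class MyQuantizer(HfQuantizer):
    # other required HfQuantizer methods omitted in this sketch
    def update_expected_keys(self, model, expected_keys: List[str], loaded_keys: List[str]) -> List[str]:
        if not self.pre_quantized:
            return expected_keys
        # Pre-quantized checkpoints store ".scales" tensors instead of ".weight"
        return [
            k[: -len(".weight")] + ".scales" if k.endswith(".weight") else k
            for k in expected_keys
        ]
```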

src/transformers/quantizers/quantizer_hqq.py

Lines changed: 105 additions & 8 deletions
@@ -62,7 +62,7 @@ def __init__(self, quantization_config, **kwargs):
     def validate_environment(self, *args, **kwargs):
         if not (is_hqq_available()):
             raise ImportError(
-                "HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`"
+                "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
             )
 
         if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
@@ -91,6 +91,65 @@ def validate_environment(self, *args, **kwargs):
         else:
             self.using_multi_gpu = len(set(device_map.values())) > 1
 
+    def update_missing_keys(
+        self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
+    ) -> List[str]:
+        if self.pre_quantized:
+            return [key for key in missing_keys if ("weight" not in key)]
+        else:
+            return missing_keys
+
+    # Adds missing keys for HQQLinear modules that are loaded but the model with initialized with torch.nn.Linear
+    def update_expected_keys(
+        self, model: "PreTrainedModel", expected_keys: List[str], loaded_keys: List[str]
+    ) -> List[str]:
+        if not self.pre_quantized:
+            return expected_keys
+
+        # Collects all quantizable (linear) layers
+        def _find_hqq_quantizable_layers(model, layers):
+            for name, module in model.named_children():
+                if isinstance(module, (torch.nn.Linear)):
+                    layers.add(module.name)
+                _find_hqq_quantizable_layers(module, layers)
+
+        new_keys = set(expected_keys)
+        if is_hqq_available():
+            from hqq.core.quantize import HQQLinear
+
+            # Name modules
+            for name, module in model.named_modules():
+                module.name = name
+
+            # valid modules are Linear layers that have HQQLinear state_dict. We ignore skip_modules and any layers with Linear state_dict() params
+            _valid_modules = set()
+            _find_hqq_quantizable_layers(model, _valid_modules)
+            _valid_modules -= set(model.config.quantization_config["skip_modules"])
+
+            # Append new expected layers based on _ref_keys
+            _ref_keys = HQQLinear(
+                linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu"
+            ).state_dict_keys() - {"bias"}
+
+            # Clean-up
+            _rm_keys = set()
+            for key in new_keys:
+                if any(_module in key for _module in _valid_modules):
+                    _rm_keys.add(key)
+            new_keys -= _rm_keys
+            # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
+
+            # Re-populate Linear/HQQLinear
+            for _module in _valid_modules:
+                if _module + ".weight" in loaded_keys:
+                    new_keys.add(_module + ".weight")
+                else:
+                    new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
+                if _module + ".bias" in loaded_keys:
+                    new_keys.add(_module + ".bias")
+
+        return list(new_keys)
+
     def check_quantized_param(
         self,
         model: "PreTrainedModel",
@@ -99,9 +158,18 @@ def check_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ) -> bool:
+        if is_hqq_available():
+            from hqq.core.quantize import HQQLinear
         module, tensor_name = get_module_from_name(model, param_name)
 
-        return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
+        if self.pre_quantized:
+            return (
+                (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
+                and tensor_name != "weight"
+                and tensor_name != "bias"
+            )
+        else:
+            return isinstance(module, torch.nn.Linear) and tensor_name == "weight"
 
     def create_quantized_param(
         self,
@@ -122,21 +190,50 @@ def create_quantized_param(
             from hqq.core.quantize import HQQLinear
 
         module, tensor_name = get_module_from_name(model, param_name)
-
-        layer_name = param_name.replace(".weight", "").replace(".bias", "")
+        layer_name = ".".join(param_name.split(".")[:-1])
         parent_module = find_parent(model, layer_name)
         node = layer_name.split(".")[-1]
 
-        # Step 0: set module state_dict
-        module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
+        # set module state_dict
+        module_state_dict = {}
+        for k, v in state_dict.items():
+            if layer_name + "." in k:
+                module_state_dict[k.split(".")[-1]] = v
+                if unexpected_keys is not None and k in unexpected_keys:
+                    unexpected_keys.remove(k)
+
+        if self.pre_quantized:
+            if isinstance(module, HQQLinear):
+                return
+            else:
+                hqq_layer = HQQLinear(
+                    linear_layer=None,
+                    quant_config=None,
+                    compute_dtype=self.torch_dtype,
+                    device=target_device,
+                )
+
+                hqq_layer.load_state_dict(module_state_dict)
+
+                if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+                    hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+                if self.using_multi_gpu:
+                    hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+                setattr(parent_module, node, hqq_layer)
+
+                # cleanup
+                del module.__dict__, module
+                torch.cuda.empty_cache()
+                return
 
         # Step 1: populate module with weight/bias from module state dict
         for key in module_state_dict:
             setattr(module, key, torch.nn.Parameter(module_state_dict[key]))
 
         # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent as doing on it on the module
         # directly doesn't work.
-
         if hasattr(module, "quant_config"):
             hqq_layer = HQQLinear(
                 module,
@@ -192,7 +289,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
         return model
 
     def is_serializable(self, safe_serialization=None):
-        return False
+        return True
 
     @property
     def is_trainable(self) -> bool:
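Together with `is_serializable` now returning `True`, the `pre_quantized` branches above are what make the round trip work: on reload, the `torch.nn.Linear` shells prepared by `_prepare_for_hqq_linear` are replaced with `HQQLinear` layers rebuilt from the serialized `W_q`/`meta` tensors. A minimal sketch of the reload side; the checkpoint path is an illustrative placeholder and assumes the model was previously saved with `save_pretrained` after HQQ quantization:

```python
import torch
from transformers import AutoModelForCausalLM

# Loading a saved HQQ checkpoint takes the pre_quantized path: no HqqConfig is needed,
# because config.json already carries quant_config / quant_method / skip_modules.
model = AutoModelForCausalLM.from_pretrained(
    "llama-hqq-4bit",                 # illustrative path, written earlier by save_pretrained
    torch_dtype=torch.float16,
    device_map="cuda",
)
```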

src/transformers/utils/import_utils.py

Lines changed: 4 additions & 3 deletions
@@ -92,6 +92,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 FSDP_MIN_VERSION = "1.12.0"
 GGUF_MIN_VERSION = "0.10.0"
 XLA_FSDPV2_MIN_VERSION = "2.2.0"
+HQQ_MIN_VERSION = "0.2.1"
 
 
 _accelerate_available, _accelerate_version = _is_package_available("accelerate", return_version=True)
@@ -181,7 +182,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
 _torchdistx_available = _is_package_available("torchdistx")
 _torchvision_available = _is_package_available("torchvision")
 _mlx_available = _is_package_available("mlx")
-_hqq_available = _is_package_available("hqq")
+_hqq_available, _hqq_version = _is_package_available("hqq", return_version=True)
 _tiktoken_available = _is_package_available("tiktoken")
 _blobfile_available = _is_package_available("blobfile")
 _liger_kernel_available = _is_package_available("liger_kernel")
@@ -323,8 +324,8 @@ def is_torch_deterministic():
     return True
 
 
-def is_hqq_available():
-    return _hqq_available
+def is_hqq_available(min_version: str = HQQ_MIN_VERSION):
+    return _hqq_available and version.parse(_hqq_version) >= version.parse(min_version)
 
 
 def is_pygments_available():
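This follows the version-gated availability pattern used for other optional backends: callers keep calling `is_hqq_available()` with no arguments, and the minimum supported version is enforced in one place. A short sketch of how downstream code would use it, assuming the helper is imported from `transformers.utils` as with the other availability checks:

```python
from transformers.utils import is_hqq_available

if is_hqq_available():  # True only when hqq >= 0.2.1 is installed after this change
    from hqq.core.quantize import HQQLinear
else:
    raise ImportError("Please install or upgrade hqq: pip install 'hqq>=0.2.1'")
```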

src/transformers/utils/quantization_config.py

Lines changed: 27 additions & 15 deletions
@@ -193,15 +193,9 @@ class HqqConfig(QuantizationConfigMixin):
             Number of bits. Supported values are (8, 4, 3, 2, 1).
         group_size (`int`, *optional*, defaults to 64):
             Group-size value. Supported values are any value that is divisble by weight.shape[axis]).
-        quant_zero (`bool`, *optional*, defaults to `True`):
-            Quantize the zero-point if set to `True`.
-        quant_scale (`bool`, *optional*, defaults to `False`):
-            Quantize the scaling if set to `True`.
-        offload_meta (`bool`, *optional*, defaults to `False`):
-            Offload the meta-data to the CPU if set to `True`.
         view_as_float (`bool`, *optional*, defaults to `False`):
             View the quantized weight as float (used in distributed training) if set to `True`.
-        axis (`int`, *optional*, defaults to 0):
+        axis (`Optional[int]`, *optional*):
             Axis along which grouping is performed. Supported values are 0 or 1.
         dynamic_config (dict, *optional*):
             Parameters for dynamic configuration. The key is the name tag of the layer and the value is a quantization config.
@@ -216,18 +210,25 @@ def __init__(
         self,
         nbits: int = 4,
         group_size: int = 64,
-        quant_zero: bool = True,
-        quant_scale: bool = False,
-        offload_meta: bool = False,
         view_as_float: bool = False,
-        axis: int = 0,
+        axis: Optional[int] = None,
         dynamic_config: Optional[dict] = None,
         skip_modules: List[str] = ["lm_head"],
         **kwargs,
     ):
         if is_hqq_available():
             from hqq.core.quantize import BaseQuantizeConfig as HQQBaseQuantizeConfig
 
+        for deprecated_key in ["quant_zero", "quant_scale", "offload_meta"]:
+            if deprecated_key in kwargs:
+                logger.info(
+                    deprecated_key + " is deprecated. This parameter will be ignored in quantization settings."
+                )
+
+        if axis is None:
+            axis = 1
+            logger.info("Setting axis=1 as faster backends such as TorchAO or BitBlas are only compatible with it.")
+
         if axis not in [0, 1]:
             raise ValueError("Invalid axis value. Only 0 and 1 are allowed.")
 
@@ -240,9 +241,6 @@ def __init__(
             **{
                 "nbits": nbits,
                 "group_size": group_size,
-                "quant_zero": quant_zero,
-                "quant_scale": quant_scale,
-                "offload_meta": offload_meta,
                 "view_as_float": view_as_float,
                 "axis": axis,
             }
@@ -259,12 +257,26 @@ def post_init(self):
         """
         pass
 
+    @classmethod
+    def from_dict(cls, config: Dict[str, Any]):
+        """
+        Override from_dict, used in AutoQuantizationConfig.from_dict in quantizers/auto.py
+        """
+        instance = cls()
+        instance.quant_config = config["quant_config"]
+        instance.skip_modules = config["skip_modules"]
+        return instance
+
     def to_dict(self) -> Dict[str, Any]:
         """
         Serializes this instance to a Python dictionary. Returns:
         `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
         """
-        return self.quant_config
+        return {
+            "quant_config": self.quant_config,
+            "quant_method": self.quant_method,
+            "skip_modules": self.skip_modules,
+        }
 
     def __repr__(self):
         config_dict = self.to_dict()
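The net effect is that `to_dict` no longer flattens the config down to `quant_config` alone, and `from_dict` can rebuild an `HqqConfig` from exactly what gets serialized. A small round-trip sketch, assuming `hqq >= 0.2.1` is installed; the commented output is indicative of the shape rather than exact:

```python
from transformers import HqqConfig

cfg = HqqConfig(nbits=4, group_size=64)   # axis now defaults to 1

serialized = cfg.to_dict()
# roughly: {"quant_config": {...}, "quant_method": "hqq", "skip_modules": ["lm_head"]}

# This is the round trip AutoQuantizationConfig performs when loading a saved checkpoint
restored = HqqConfig.from_dict(serialized)
assert restored.skip_modules == ["lm_head"]
assert restored.quant_config == serialized["quant_config"]
```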
