@@ -62,7 +62,7 @@ def __init__(self, quantization_config, **kwargs):
     def validate_environment(self, *args, **kwargs):
         if not (is_hqq_available()):
             raise ImportError(
-                "HQQ is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`"
+                "A valid HQQ version (>=0.2.1) is not available. Please follow the instructions to install it: `https://github.com/mobiusml/hqq/`."
             )

         if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
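The bumped error message ties `validate_environment` to hqq >= 0.2.1, presumably the first release exposing the serialization helpers this diff relies on (`state_dict_keys()`, `load_state_dict()`). A standalone check equivalent in spirit (a sketch with a hypothetical helper, not the actual transformers code; assumes `packaging` is installed):

```python
# Sketch of the implied version gate (hypothetical helper, not from the diff).
from importlib.metadata import PackageNotFoundError, version

from packaging.version import parse


def check_hqq_version(minimum: str = "0.2.1") -> None:
    try:
        if parse(version("hqq")) < parse(minimum):
            raise ImportError(f"A valid HQQ version (>={minimum}) is not available.")
    except PackageNotFoundError:
        raise ImportError("HQQ is not installed: `pip install hqq`.")
```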
@@ -91,6 +91,65 @@ def validate_environment(self, *args, **kwargs):
         else:
             self.using_multi_gpu = len(set(device_map.values())) > 1

+    def update_missing_keys(
+        self, model: "PreTrainedModel", missing_keys: List[str], prefix: str, **kwargs
+    ) -> List[str]:
+        if self.pre_quantized:
+            return [key for key in missing_keys if ("weight" not in key)]
+        else:
+            return missing_keys
+
+    # Adds missing keys for HQQLinear modules that are loaded, but the model was initialized with torch.nn.Linear
+    def update_expected_keys(
+        self, model: "PreTrainedModel", expected_keys: List[str], loaded_keys: List[str]
+    ) -> List[str]:
+        if not self.pre_quantized:
+            return expected_keys
+
+        # Collects all quantizable (linear) layers
+        def _find_hqq_quantizable_layers(model, layers):
+            for name, module in model.named_children():
+                if isinstance(module, torch.nn.Linear):
+                    layers.add(module.name)
+                _find_hqq_quantizable_layers(module, layers)
+
+        new_keys = set(expected_keys)
+        if is_hqq_available():
+            from hqq.core.quantize import HQQLinear
+
+            # Name modules
+            for name, module in model.named_modules():
+                module.name = name
+
+            # Valid modules are Linear layers whose checkpoint entries match HQQLinear's state_dict keys; we ignore skip_modules and any layers saved with plain Linear state_dict() params
+            _valid_modules = set()
+            _find_hqq_quantizable_layers(model, _valid_modules)
+            _valid_modules -= set(model.config.quantization_config["skip_modules"])
+
+            # Append new expected layers based on _ref_keys
+            _ref_keys = HQQLinear(
+                linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu"
+            ).state_dict_keys() - {"bias"}
+
+            # Clean-up
+            _rm_keys = set()
+            for key in new_keys:
+                if any(_module in key for _module in _valid_modules):
+                    _rm_keys.add(key)
+            new_keys -= _rm_keys
+            # At this point, new_keys contains all the keys of the layers that are NOT HQQLinear or torch.nn.Linear
+
+            # Re-populate Linear/HQQLinear
+            for _module in _valid_modules:
+                if _module + ".weight" in loaded_keys:
+                    new_keys.add(_module + ".weight")
+                else:
+                    new_keys.update({_module + "." + _ref_key for _ref_key in _ref_keys})
+                if _module + ".bias" in loaded_keys:
+                    new_keys.add(_module + ".bias")
+
+        return list(new_keys)
+
     def check_quantized_param(
         self,
         model: "PreTrainedModel",
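`update_expected_keys` in the hunk above is the core of checkpoint loading: when `pre_quantized` is set, every quantizable `nn.Linear` that was saved as an `HQQLinear` no longer contributes a `.weight` key; instead it contributes the keys of `HQQLinear.state_dict_keys()` (the packed weight plus quantization metadata). A self-contained sketch of that remapping, with illustrative key names (`W_q`, `meta` are assumptions; the real set comes from the installed hqq version):

```python
# Toy remapping in the spirit of update_expected_keys (key names assumed).
expected_keys = {"lm_head.weight", "layers.0.q_proj.weight"}
loaded_keys = {"lm_head.weight", "layers.0.q_proj.W_q", "layers.0.q_proj.meta"}
ref_keys = {"W_q", "meta"}  # stand-in for HQQLinear.state_dict_keys() - {"bias"}
quantizable = {"layers.0.q_proj"}  # nn.Linear layers not in skip_modules

new_keys = {k for k in expected_keys if not any(m in k for m in quantizable)}
for m in quantizable:
    if m + ".weight" in loaded_keys:  # this layer was saved unquantized
        new_keys.add(m + ".weight")
    else:  # saved as HQQLinear: expect its serialized tensors instead
        new_keys.update(m + "." + r for r in ref_keys)

assert new_keys == {"lm_head.weight", "layers.0.q_proj.W_q", "layers.0.q_proj.meta"}
```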
@@ -99,9 +158,18 @@ def check_quantized_param(
         state_dict: Dict[str, Any],
         **kwargs,
     ) -> bool:
+        if is_hqq_available():
+            from hqq.core.quantize import HQQLinear
         module, tensor_name = get_module_from_name(model, param_name)

-        return isinstance(module, torch.nn.Linear) and (tensor_name == "weight")
+        if self.pre_quantized:
+            return (
+                (isinstance(module, torch.nn.Linear) or isinstance(module, HQQLinear))
+                and tensor_name != "weight"
+                and tensor_name != "bias"
+            )
+        else:
+            return isinstance(module, torch.nn.Linear) and tensor_name == "weight"

     def create_quantized_param(
         self,
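The rewritten predicate routes tensors in two modes: when quantizing on the fly, only `weight` tensors of `nn.Linear` modules are intercepted; when loading a pre-quantized checkpoint, everything except `weight`/`bias` (i.e. HQQ's serialized tensors) goes through `create_quantized_param`. Stripped of the module-type checks, the rule reduces to this sketch:

```python
# Routing rule minus the isinstance() checks (tensor names illustrative).
def needs_quantized_load(tensor_name: str, pre_quantized: bool) -> bool:
    if pre_quantized:
        # serialized HQQ tensors (e.g. "W_q") are rebuilt by create_quantized_param
        return tensor_name not in ("weight", "bias")
    return tensor_name == "weight"  # float weights get quantized on the fly
```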
@@ -122,21 +190,50 @@ def create_quantized_param(
         from hqq.core.quantize import HQQLinear

         module, tensor_name = get_module_from_name(model, param_name)
-
-        layer_name = param_name.replace(".weight", "").replace(".bias", "")
+        layer_name = ".".join(param_name.split(".")[:-1])
         parent_module = find_parent(model, layer_name)
         node = layer_name.split(".")[-1]

-        # Step 0: set module state_dict
-        module_state_dict = {key.split(".")[-1]: state_dict[key] for key in state_dict if layer_name in key}
+        # set module state_dict
+        module_state_dict = {}
+        for k, v in state_dict.items():
+            if layer_name + "." in k:
+                module_state_dict[k.split(".")[-1]] = v
+                if unexpected_keys is not None and k in unexpected_keys:
+                    unexpected_keys.remove(k)
+
+        if self.pre_quantized:
+            if isinstance(module, HQQLinear):
+                return
+            else:
+                hqq_layer = HQQLinear(
+                    linear_layer=None,
+                    quant_config=None,
+                    compute_dtype=self.torch_dtype,
+                    device=target_device,
+                )
+
+                hqq_layer.load_state_dict(module_state_dict)
+
+            if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+                hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+            if self.using_multi_gpu:
+                hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+            setattr(parent_module, node, hqq_layer)
+
+            # cleanup
+            del module.__dict__, module
+            torch.cuda.empty_cache()
+            return

         # Step 1: populate module with weight/bias from module state dict
         for key in module_state_dict:
             setattr(module, key, torch.nn.Parameter(module_state_dict[key]))

         # Step 2: Replace module with either HQQLinear or move it to device. We do this via setattr on the parent,
         # as doing it on the module directly doesn't work.
-
         if hasattr(module, "quant_config"):
             hqq_layer = HQQLinear(
                 module,
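The new pre-quantized branch rebuilds an empty `HQQLinear` (no float layer, no quant config) and restores it purely from the checkpoint slice via `load_state_dict`, replacing the placeholder module on the parent. The same save/restore roundtrip can be exercised on a single layer; a minimal sketch, assuming hqq >= 0.2.1 and that the installed version supports quantization on CPU:

```python
import torch
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear

# Quantize a float layer and serialize it...
layer = torch.nn.Linear(64, 64)
cfg = BaseQuantizeConfig(nbits=4, group_size=64)
saved = HQQLinear(layer, quant_config=cfg, compute_dtype=torch.float16, device="cpu").state_dict()

# ...then restore it the way create_quantized_param does for pre-quantized checkpoints.
restored = HQQLinear(linear_layer=None, quant_config=None, compute_dtype=torch.float16, device="cpu")
restored.load_state_dict(saved)
```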
@@ -192,7 +289,7 @@ def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs
         return model

     def is_serializable(self, safe_serialization=None):
-        return False
+        return True

     @property
     def is_trainable(self) -> bool:
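With `is_serializable` now returning `True`, an HQQ-quantized model can round-trip through `save_pretrained`/`from_pretrained`; on reload the quantizer runs with `pre_quantized` set and the hooks above rebuild the `HQQLinear` modules. A usage sketch (model id and output path are illustrative):

```python
from transformers import AutoModelForCausalLM, HqqConfig

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # illustrative model id
    quantization_config=HqqConfig(nbits=4, group_size=64),
    device_map="cuda",
)
model.save_pretrained("opt-125m-hqq")  # previously refused while is_serializable was False
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-hqq", device_map="cuda")
```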