@@ -143,8 +143,6 @@ def create_quantized_param(
         parent_module = find_parent(model, layer_name)
         node = layer_name.split(".")[-1]
 
-        # print("create_quantized_param | ", 'layer_name', layer_name, type(module), hasattr(module, "quant_config")) #model.layers.0.mlp.down_proj
-
         # set module state_dict
         module_state_dict = {}
         for k, v in state_dict.items():
@@ -154,39 +152,27 @@ def create_quantized_param(
                     unexpected_keys.remove(k)
 
         if self.pre_quantized:
-            if isinstance(module, HQQLinear):
-                return
-            else:
+            if isinstance(module, (torch.nn.Linear, HQQLinear)):
                 hqq_layer = HQQLinear(
                     linear_layer=None,
-                    quant_config=None,  # module.quant_config
+                    quant_config=None,
                     compute_dtype=self.torch_dtype,
                     device=target_device,
                 )
 
-                try:
-                    hqq_layer.load_state_dict(module_state_dict)
-                except Exception:
-                    # TODO @mobicham: Llama3 break with model.layers.28.mlp.down_proj because its parameters are split across 2 safetensors. How to fix this?
-                    # Currently setting a fake layer so that loading doesn't break
-                    print("Error loading, setting a fake layer for", layer_name, module_state_dict.keys())
-                    hqq_layer = HQQLinear(
-                        torch.nn.Linear(in_features=module.in_features, out_features=module.out_features, bias=False),
-                        module.quant_config,
-                        compute_dtype=self.torch_dtype,
-                        device=target_device,
-                        del_orig=True,
-                    )
-
-                if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
-                    hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
-
-                if self.using_multi_gpu:
-                    hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
-
-                setattr(parent_module, node, hqq_layer)
-                torch.cuda.empty_cache()
-                return
+                hqq_layer.axis = None
+                hqq_layer.channel_wise = None
+                hqq_layer.load_state_dict(module_state_dict)
+
+                if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+                    hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+                if self.using_multi_gpu:
+                    hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+                setattr(parent_module, node, hqq_layer)
+                torch.cuda.empty_cache()
+                return
 
         # Step 1: populate module with weight/bias from module state dict
         for key in module_state_dict:
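
For context, the sketch below illustrates the loading path this patch takes for a pre-quantized checkpoint: the layer's serialized tensors are gathered from the flat state dict, an empty HQQLinear is rebuilt from them, and the result is swapped into the parent module. It is an illustration only; `model`, `state_dict`, and the example `layer_name` are assumed inputs, and the HQQLinear calls mirror the diff above rather than documenting the hqq API exhaustively.

```python
# Hypothetical, standalone sketch of restoring one pre-quantized HQQ layer
# from a flat checkpoint state dict and attaching it to the model.
import torch
from hqq.core.quantize import HQQLinear

layer_name = "model.layers.0.mlp.down_proj"  # example module path (assumption)
target_device = "cuda:0"
compute_dtype = torch.float16

# Gather this layer's entries, keyed by their last path component
# (e.g. "W_q", "meta", "bias"), mirroring the filtering done in the patch.
module_state_dict = {
    k.split(".")[-1]: v for k, v in state_dict.items() if layer_name + "." in k
}

# Build an empty HQQLinear and populate it from the serialized tensors,
# resetting the quantization metadata first as the new branch does.
hqq_layer = HQQLinear(
    linear_layer=None,
    quant_config=None,
    compute_dtype=compute_dtype,
    device=target_device,
)
hqq_layer.axis = None
hqq_layer.channel_wise = None
hqq_layer.load_state_dict(module_state_dict)

# Wrap a plain bias tensor in a Parameter so it is tracked by the module.
if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
    hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)

# Replace the original submodule on its parent (the find_parent/setattr step above).
parent_name, _, node = layer_name.rpartition(".")
parent_module = model.get_submodule(parent_name)
setattr(parent_module, node, hqq_layer)
torch.cuda.empty_cache()
```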