
Commit 5cb7d81

remove check_old_param
1 parent a8704d2 commit 5cb7d81


3 files changed: +34 -44 lines changed


src/transformers/integrations/hqq.py

Lines changed: 4 additions & 0 deletions
@@ -66,6 +66,10 @@ def _prepare_for_hqq_linear(model, patch_params, has_been_replaced, current_key_
 
                 has_been_replaced = True
 
+                # Add these fake parameters to avoid loading fail
+                for att in ["W_q", "meta"]:
+                    setattr(module, att, None)
+
         if len(list(module.children())) > 0:
             _, has_been_replaced = _prepare_for_hqq_linear(
                 module,
src/transformers/modeling_utils.py

Lines changed: 15 additions & 15 deletions
@@ -858,10 +858,14 @@ def _load_state_dict_into_meta_model(
 
     is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
 
+    # We add this because HQQLinear dict has a very large state_dict (19 params/per module), which makes loading extremely slow
+    run_expected_keys_check = True
+    if isinstance(hf_quantizer, HqqHfQuantizer):
+        run_expected_keys_check = False
+
     for param_name, param in state_dict.items():
-        # print('param_name', param_name, param_name in loaded_state_dict_keys, param_name in expected_keys)
         # First part of the test is always true as load_state_dict_keys always contains state_dict keys.
-        if param_name not in loaded_state_dict_keys: # or param_name not in expected_keys: #TODO @mobicham
+        if param_name not in loaded_state_dict_keys or ((param_name not in expected_keys) and run_expected_keys_check):
             continue
 
         if param_name.startswith(start_prefix):
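
The new flag only changes which keys survive the filter at the top of the loop. A standalone toy of the gating logic (the `should_skip` helper and sample keys are illustrative, not the transformers function), showing that HQQ-specific fields such as `W_q`/`meta` are kept even though they are absent from `expected_keys`:

    def should_skip(param_name, loaded_state_dict_keys, expected_keys, run_expected_keys_check):
        if param_name not in loaded_state_dict_keys:
            return True
        if run_expected_keys_check and param_name not in expected_keys:
            return True
        return False

    loaded = {"layer.W_q", "layer.meta", "layer.bias"}
    expected = {"layer.weight", "layer.bias"}

    # HQQ path: keep the quantized fields even though they are not "expected"
    print(should_skip("layer.W_q", loaded, expected, run_expected_keys_check=False))  # False -> loaded
    # Default path: unexpected keys are skipped as before
    print(should_skip("layer.W_q", loaded, expected, run_expected_keys_check=True))   # True  -> skipped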
@@ -894,19 +898,15 @@ def _load_state_dict_into_meta_model(
         # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
         # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
 
-        # TODO @mobicham: We need this for Hqq Quantizer otherwise it would break because state_dict fields (W_q, etc.) are not in nn.Linear
-        check_old_param = True
-        if is_quantized:
-            if isinstance(hf_quantizer, HqqHfQuantizer):
-                check_old_param, old_param = False, None
-
-        if check_old_param:
-            old_param = model
-            splits = param_name.split(".")
-            for split in splits:
-                old_param = getattr(old_param, split)
-                if old_param is None:
-                    break
+        old_param = model
+        splits = param_name.split(".")
+        for split in splits:
+            old_param = getattr(old_param, split)
+            # Not all the attributes of a module are Parameters/Tensor
+            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
+                old_param = None
+            if old_param is None:
+                break
 
         if old_param is not None:
             if dtype is None:
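
With the quantizer special-case gone, the same traversal now serves every code path: any attribute that resolves but is not a Parameter/Tensor (for example the `None` placeholders registered in `integrations/hqq.py` above, or a `quant_config` attribute) is normalized to `None`. A standalone sketch of that behaviour; `resolve_old_param` is a made-up name for illustration:

    import torch

    def resolve_old_param(model, param_name):
        # Mirrors the new loop: walk the dotted name, but only keep the result
        # if it is an actual Parameter/Tensor; anything else counts as "no param".
        old_param = model
        for split in param_name.split("."):
            old_param = getattr(old_param, split)
            if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
                old_param = None
            if old_param is None:
                break
        return old_param

    layer = torch.nn.Linear(4, 4)
    layer.W_q = None                    # placeholder registered by the HQQ integration
    layer.quant_config = {"nbits": 4}   # illustrative non-tensor attribute

    print(type(resolve_old_param(layer, "weight")))   # torch.nn.parameter.Parameter
    print(resolve_old_param(layer, "W_q"))            # None (placeholder)
    print(resolve_old_param(layer, "quant_config"))   # None (not a Parameter/Tensor)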

src/transformers/quantizers/quantizer_hqq.py

Lines changed: 15 additions & 29 deletions
@@ -143,8 +143,6 @@ def create_quantized_param(
         parent_module = find_parent(model, layer_name)
         node = layer_name.split(".")[-1]
 
-        # print("create_quantized_param | ", 'layer_name', layer_name, type(module), hasattr(module, "quant_config")) #model.layers.0.mlp.down_proj
-
         # set module state_dict
         module_state_dict = {}
         for k, v in state_dict.items():
@@ -154,39 +152,27 @@ def create_quantized_param(
                     unexpected_keys.remove(k)
 
         if self.pre_quantized:
-            if isinstance(module, HQQLinear):
-                return
-            else:
+            if isinstance(module, (torch.nn.Linear, HQQLinear)):
                 hqq_layer = HQQLinear(
                     linear_layer=None,
-                    quant_config=None, # module.quant_config
+                    quant_config=None,
                     compute_dtype=self.torch_dtype,
                     device=target_device,
                 )
 
-                try:
-                    hqq_layer.load_state_dict(module_state_dict)
-                except Exception:
-                    # TODO @mobicham: Llama3 break with model.layers.28.mlp.down_proj because its parameters are split across 2 safetensors. How to fix this?
-                    # Currently setting a fake layer so that loading doesn't break
-                    print("Error loading, setting a fake layer for", layer_name, module_state_dict.keys())
-                    hqq_layer = HQQLinear(
-                        torch.nn.Linear(in_features=module.in_features, out_features=module.out_features, bias=False),
-                        module.quant_config,
-                        compute_dtype=self.torch_dtype,
-                        device=target_device,
-                        del_orig=True,
-                    )
-
-                if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
-                    hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
-
-                if self.using_multi_gpu:
-                    hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
-
-                setattr(parent_module, node, hqq_layer)
-                torch.cuda.empty_cache()
-                return
+                hqq_layer.axis = None
+                hqq_layer.channel_wise = None
+                hqq_layer.load_state_dict(module_state_dict)
+
+            if hqq_layer.bias is not None and isinstance(hqq_layer.bias, torch.Tensor):
+                hqq_layer.bias = torch.nn.Parameter(hqq_layer.bias)
+
+            if self.using_multi_gpu:
+                hqq_layer = self._patch_layer_for_multigpu(hqq_layer)
+
+            setattr(parent_module, node, hqq_layer)
+            torch.cuda.empty_cache()
+            return
 
         # Step 1: populate module with weight/bias from module state dict
         for key in module_state_dict:
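
After deserialization, the quantizer swaps the original layer for the freshly built `HQQLinear` via `find_parent` + `setattr`. A standalone sketch of that swap pattern with plain torch modules; the `swap_module` helper and `Block` class are hypothetical, not the transformers `find_parent` utility:

    import torch

    def swap_module(model: torch.nn.Module, layer_name: str, new_layer: torch.nn.Module) -> None:
        # Resolve the parent of the dotted layer name, then rebind the child
        # attribute so the replacement becomes part of the model.
        parent_name, _, node = layer_name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, node, new_layer)

    class Block(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = torch.nn.Linear(8, 8)

    model = torch.nn.Module()
    model.block = Block()

    swap_module(model, "block.proj", torch.nn.Linear(8, 8, bias=False))
    print(model.block.proj.bias)  # -> None, the swapped-in layer is now in place

In the quantizer itself, `torch.cuda.empty_cache()` follows the swap so the memory of the replaced module can be reclaimed promptly.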

0 commit comments
