
Conversation

@fahadh4ilyas
Contributor

Because quant_config is gone when you load a model using from_quantized, I tried to re-add the quant_config here so that when we call prepare_for_inference on a loaded quantized model, it will not crash because quant_config is not found.

@mobicham
Collaborator

Thanks a lot for the effort @fahadh4ilyas !

That is correct. As a temporary solution, there's this patching function that adds a quant_config: https://github.com/mobiusml/hqq/blob/master/hqq/utils/patching.py#L29

There's an easy way to do this, without needing a separate json:

  • Add self.quant_config to the dict returned by HQQLinear.state_dict()
  • In load_state_dict, simply do self.quant_config = state_dict['quant_config'] (see the sketch below)
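
A minimal sketch of that idea (the class name and the exact overrides here are illustrative, not the actual hqq implementation):

import torch

class QuantConfigCarrier(torch.nn.Module):
    #Illustrative only: carries a plain-dict quant_config through state_dict round-trips
    def __init__(self, quant_config=None):
        super().__init__()
        self.quant_config = quant_config

    def state_dict(self, *args, **kwargs):
        out = super().state_dict(*args, **kwargs)
        out['quant_config'] = self.quant_config  #non-tensor entry
        return out

    def load_state_dict(self, state_dict, strict=True):
        state_dict = dict(state_dict)  #shallow copy so the caller's dict is untouched
        self.quant_config = state_dict.pop('quant_config', None)
        return super().load_state_dict(state_dict, strict=strict)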

However, this is going in a different direction; I will explain below.

Current Direction

I am currently refactoring the whole serialization logic to make it compatible with safetensors. The goal is to be able to directly save/load HQQ-quantized models with HF transformers.
Safetensors has many limitations: we can only put torch.Tensor as values, and nested dictionaries are not allowed. So we can't just put quant_config directly in the state_dict.

For the moment, I added support for quant_config loading from a safetensors-compatible state_dict, but it doesn't support quantized scale/zero just yet: https://github.com/mobiusml/hqq/blob/master/hqq/core/quantize.py#L569-L601

The way it works right now is that state_dict supports both the old format and a new encoded format that encodes anything that is not a torch.Tensor into a torch.Tensor, controlled by the flag self.encoded_state_dict, which now defaults to True.
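
For illustration, here is one way a non-tensor value can be round-tripped through a uint8 tensor. This is only a sketch of the general idea, not necessarily the exact encoding hqq uses, and the helper names are made up:

import json
import torch

def encode_value(value):
    #Pack any JSON-serializable value into a uint8 tensor so safetensors accepts it
    raw = json.dumps(value).encode('utf-8')
    return torch.tensor(list(raw), dtype=torch.uint8)

def decode_value(tensor):
    #Reverse of encode_value: bytes -> JSON -> original Python object
    return json.loads(bytes(tensor.tolist()).decode('utf-8'))

print(decode_value(encode_value({'nbits': 4, 'group_size': 64})))  #{'nbits': 4, 'group_size': 64}

The full example below quantizes a model, round-trips a single layer through safetensors, and then saves/loads the whole model with the hqq lib: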

import torch
compute_dtype = torch.float16
device = 'cuda:0'

from transformers import AutoModelForCausalLM
model = AutoModelForCausalLM.from_pretrained('meta-llama/Meta-Llama-3-8B', torch_dtype=compute_dtype, cache_dir='/nas/hicham/tmp/')

#Quantize
from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
from hqq.models.hf.base import AutoHQQHFModel
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_scale=False, quant_zero=False, axis=1, offload_meta=False) 
model = AutoHQQHFModel.quantize_model(model, quant_config=quant_config, compute_dtype=compute_dtype, device=device)

##########################################
#Safetensors save/load layer check
from safetensors import safe_open
from safetensors.torch import save_file

_state_dict = model.model.layers[0].self_attn.q_proj.state_dict()
save_file(_state_dict, "layer.safetensors")

_state_dict_loaded = {}
with safe_open("layer.safetensors", framework="pt") as f:
    for key in f.keys():
        _state_dict_loaded[key] = f.get_tensor(key)
#######################################
#Model save/load check (with hqq lib)

AutoHQQHFModel.save_quantized(model, 'llama3-hqq')
model_loaded = AutoHQQHFModel.from_quantized("llama3-hqq")

#quant_config loaded
print(model_loaded.model.layers[0].self_attn.q_proj.quant_config)

The next step is to use this logic to save/load HQQ-quantized models with HF transformers. Then we can get back to supporting quantized scale/zero.

Happy to hear suggestions from you regarding this!

@fahadh4ilyas
Contributor Author

Doesn't safetensors support metadata? How about putting the meta and quant_config inside the metadata?

@mobicham
Collaborator

Yeah I thought about it, but it will make things even more complicated, since it will require more work on the transformers lib side. Putting everything in state_dict simplifies the process since iterating is much quicker and we have more freedom.

@fahadh4ilyas
Contributor Author

> Yeah I thought about it, but it will make things even more complicated, since it will require more work on the transformers lib side. Putting everything in state_dict simplifies the process since iterating is much quicker and we have more freedom.

What do you mean by "will require more work on the transformers"? The current save_quantized approach doesn't require changing transformers.

By using metadata, we would only split the current state_dict into two parts: "tensors" and "non-tensors". Later, when we want to load the weights, we just combine the split dictionaries, right?
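
For illustration, a minimal sketch of that split, assuming quant_config is JSON-serializable (the tensor key and config values are placeholders, not hqq's actual layout):

import json
import torch
from safetensors import safe_open
from safetensors.torch import save_file

#Save: tensors go into the file body, quant_config goes into the (str -> str) metadata
tensors = {'W_q': torch.zeros(8, 8, dtype=torch.uint8)}
quant_config = {'nbits': 4, 'group_size': 64}
save_file(tensors, 'layer.safetensors', metadata={'quant_config': json.dumps(quant_config)})

#Load: read tensors and metadata back and combine them into one dict
with safe_open('layer.safetensors', framework='pt') as f:
    state_dict = {key: f.get_tensor(key) for key in f.keys()}
    state_dict['quant_config'] = json.loads(f.metadata()['quant_config'])
print(state_dict['quant_config'])  #{'nbits': 4, 'group_size': 64}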

@mobicham
Collaborator

hqq's save_quantized wouldn't require changes in transformers, that's correct, but the goal is to have official serialization support in HF transformers directly, so we would be able to save models via save_pretrained, not just via the hqq lib.

I am trying to figure out the right way of doing this with @SunMarc.

For now, the logic is working just fine with state_dict: we can save the whole model as a safetensors file, and the hqq lib's save_quantized also works with the same state_dict format, which is good. The only limitation is that meta-data offloading / quantized scale-zero is not supported. Since the fast backends don't support meta-data offloading / quantized scale-zero anyway, I was even thinking of completely dropping support for it, because we can't use it for fast inference.
So from 0.2.0 on, we keep things simple and only have floating-point scale/zero on the same device.

@mobicham
Collaborator

I also tried loading a model saved with the previous version (https://huggingface.co/mobiuslabsgmbh/Llama-2-7b-chat-hf_4bitnogs_hqq) and it worked without any issue, which is good news for backward compatibility.
Now we just need to see how HQQLinear.load_state_dict behaves when used inside HF transformers.

@mobicham
Collaborator

Draft pull request here: huggingface/transformers#32056

@mobicham
Collaborator

Closing this since we are very close to full transformers serialization support: huggingface/transformers#33141
