-
Notifications
You must be signed in to change notification settings - Fork 31k
HFQuantizer implementation for compressed-tensors library #31704
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
d695ec3
f468964
41224d3
b61bfb9
ff8f1c5
c1cb55d
ef9d3f1
117d050
1901c3e
3ca270d
9a14b09
ec59052
520ded8
7dec8fc
afb550d
d9b3660
e51ac59
ccb5442
bfd9220
71a80f9
547f9cc
8acbc09
eaa5f20
4ba75fb
94ea0d3
c48840d
ab74d26
2ecf711
e1ae504
ea9e927
1c3ad5c
aa1a4f9
81a13dd
f53d7b9
d8f7073
c4fbf70
1992a88
298a638
3cb4415
a943157
64f475a
fabe8a3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,230 @@ | ||
| <!--Copyright 2024 The HuggingFace Team. All rights reserved. | ||
|
|
||
| Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
| the License. You may obtain a copy of the License at | ||
|
|
||
| http://www.apache.org/licenses/LICENSE-2.0 | ||
|
|
||
| Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
| an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
| specific language governing permissions and limitations under the License. | ||
|
|
||
| ⚠️ Note that this file is in Markdown but contain specific syntax for our doc-builder (similar to MDX) that may not be | ||
| rendered properly in your Markdown viewer. | ||
|
|
||
| --> | ||
| # Compressed Tensors | ||
|
|
||
| The [`compressed-tensors`](https://github.com/neuralmagic/compressed-tensors) library provides a versatile and efficient way to store and manage compressed model checkpoints. This library supports various quantization and sparsity schemes, making it a unified format for handling different model optimizations like GPTQ, AWQ, SmoothQuant, INT8, FP8, SparseGPT, and more. | ||
|
|
||
| Some of the supported formats include: | ||
| 1. `dense` | ||
| 2. `int-quantized`: INT8 quantized models | ||
| - sample [model/config](https://huggingface.co/nm-testing/tinyllama-w8a8-compressed-hf-quantizer) | ||
| 3. `float-quantized`: FP8 quantized models; currently support E4M3 | ||
| - sample [model/config](https://huggingface.co/nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat/tree/main) | ||
| 4. `pack-quantized`: INT4 or INT8 weight-quantized models, packed into INT32. For INT4, the weights have an INT4 range but are stored as INT8 and then packed into INT32. | ||
| - sample [model/config](nm-testing/tinyllama-w4a16-compressed-hf-quantizer) | ||
|
|
||
| Compressed models can be easily created using [llm-compressor](https://github.com/vllm-project/llm-compressor). | ||
| Alternatively models can be created indepedenty and serialized with a compressed tensors config. | ||
|
|
||
| To find existing models on the Hugging Face Model Hub, search for the [`compressed-tensors` tag](https://huggingface.co/models?other=compressed-tensors). | ||
|
|
||
| #### Features: | ||
| - Weight and activation precisions: FP8, INT4, INT8 (for Q/DQ arbitrary precision is allowed for INT) | ||
| - Quantization scales and zero-points strategies: [tensor, channel, group, block, token](https://github.com/neuralmagic/compressed-tensors/blob/83b2e7a969d70606421a76b9a3d112646077c8de/src/compressed_tensors/quantization/quant_args.py#L43-L52) | ||
| - Dynamic per-token activation quantization (or any static strategy) | ||
| - Sparsity can be | ||
| - Supports quantization of arbitrary modules, not just Linear modules | ||
| - Targeted support or ignoring of modules by name or class | ||
|
|
||
| ## Installation | ||
|
|
||
| It is recommended to install stable releases of compressed-tensors from [PyPI](https://pypi.org/project/compressed-tensors): | ||
| ```bash | ||
| pip install compressed-tensors | ||
| ``` | ||
|
|
||
| Developers who want to experiment with the latest features can also install the package from source: | ||
| ```bash | ||
| git clone https://github.com/neuralmagic/compressed-tensors | ||
| cd compressed-tensors | ||
| pip install -e . | ||
| ``` | ||
|
|
||
| ## Quickstart Model Load | ||
| Quantized models can be easily loaded for inference as shown below. Only models that have already been quantized can be loaded at the moment. To quantize a model into the compressed-tensors format see [llm-compressor](https://github.com/vllm-project/llm-compressor). | ||
|
|
||
| ```python | ||
| from transformers import AutoModelForCausalLM | ||
|
|
||
| # Load the model in compressed-tensors format | ||
| ct_model = AutoModelForCausalLM.from_pretrained("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf") | ||
|
|
||
| # Measure memory usage | ||
| mem_params = sum([param.nelement()*param.element_size() for param in ct_model.parameters()]) | ||
| print(f"{mem/2**30:.4f} GB") | ||
| # 8.4575 GB | ||
| ``` | ||
|
|
||
| We can see just above that the compressed-tensors FP8 checkpoint of Llama 3.1 8B is able to be loaded for inference using half of the memory of the unquantized reference checkpoint. | ||
|
|
||
| ## Sample Use Cases - Load and run an FP8 model | ||
|
|
||
| ```python | ||
| from transformers import AutoModelForCausalLM, AutoTokenizer | ||
|
|
||
| prompt = [ | ||
| "Hello, my name is", | ||
| "The capital of France is", | ||
| "The future of AI is" | ||
| ] | ||
|
|
||
| model_name = "nm-testing/Meta-Llama-3-8B-Instruct-fp8-hf_compat" | ||
|
|
||
| quantized_model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto") | ||
| tokenizer = AutoTokenizer.from_pretrained(model_name) | ||
|
|
||
| inputs = tokenizer(prompt, return_tensors="pt") | ||
| generated_ids = quantized_model.generate(**inputs, max_length=50, do_sample=False) | ||
| outputs = tokenizer.batch_decode(generated_ids) | ||
|
|
||
| print(outputs) | ||
|
|
||
| """ | ||
| ['<|begin_of_text|>Hello, my name is [Name]. I am a [Your Profession/Student] and I am here to learn about the [Course/Program] at [University/Institution]. I am excited to be here and I am looking forward to', '<|begin_of_text|>The capital of France is Paris, which is located in the north-central part of the country. Paris is the most populous city in France and is known for its stunning architecture, art museums, fashion, and romantic atmosphere. The city is home to', "<|begin_of_text|>The future of AI is here, and it's already changing the way we live and work. From virtual assistants to self-driving cars, AI is transforming industries and revolutionizing the way we interact with technology. But what does the future of AI hold"] | ||
| """ | ||
|
|
||
| ``` | ||
|
|
||
| The above shows a quick example for running generation using a `compressed-tensors` | ||
| model. Currently, once loaded the model cannot be saved. | ||
|
|
||
| ## Deep dive into a compressed-tensors model checkpoint | ||
|
|
||
| In this example we will examine how the compressed-tensors model nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf is defined through its configuration entry and see how this translates to the loaded model representation. | ||
|
|
||
| First, let us look at the [`quantization_config` of the model](https://huggingface.co/nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf/blob/main/config.json). At a glance it looks overwhelming with the number of entries but this is because compressed-tensors is a format that allows for flexible expression both during and after model compression. | ||
|
|
||
| In practice for checkpoint loading and inference the configuration can be simplified to not include all the default or empty entries, so we will do that here to focus on what compression is actually represented. | ||
|
|
||
| ```yaml | ||
| "quantization_config": { | ||
| "config_groups": { | ||
| "group_0": { | ||
| "input_activations": { | ||
| "num_bits": 8, | ||
| "strategy": "tensor", | ||
| "type": "float" | ||
| }, | ||
| "targets": ["Linear"], | ||
| "weights": { | ||
| "num_bits": 8, | ||
| "strategy": "tensor", | ||
| "type": "float" | ||
| } | ||
| } | ||
| }, | ||
| "format": "naive-quantized", | ||
| "ignore": ["lm_head"], | ||
| "quant_method": "compressed-tensors", | ||
| "quantization_status": "frozen" | ||
| }, | ||
| ``` | ||
|
|
||
| We can see from the above configuration that it is specifying one config group that includes weight and activation quantization to FP8 with a static per-tensor strategy. It is also worth noting that in the `ignore` list there is an entry to skip quantization of the `lm_head` module, so that module should be untouched in the checkpoint. | ||
|
|
||
| To see the result of the configuration in practice, we can simply use the [safetensors viewer](https://huggingface.co/nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf?show_file_info=model.safetensors.index.json) on the model card to see the quantized weights, input_scale, and weight_scale for all of the Linear modules in the first model layer (and so on for the rest of the layers). | ||
|
|
||
| | Tensors | Shape | Precision | | ||
| | ------- | ----- | --------- | | ||
| model.layers.0.input_layernorm.weight | [4 096] | BF16 | ||
| model.layers.0.mlp.down_proj.input_scale | [1] | BF16 | ||
| model.layers.0.mlp.down_proj.weight | [4 096, 14 336] | F8_E4M3 | ||
| model.layers.0.mlp.down_proj.weight_scale | [1] | BF16 | ||
| model.layers.0.mlp.gate_proj.input_scale | [1] | BF16 | ||
| model.layers.0.mlp.gate_proj.weight | [14 336, 4 096] | F8_E4M3 | ||
| model.layers.0.mlp.gate_proj.weight_scale | [1] | BF16 | ||
| model.layers.0.mlp.up_proj.input_scale| [1] |BF16 | ||
| model.layers.0.mlp.up_proj.weight | [14 336, 4 096] | F8_E4M3 | ||
| model.layers.0.mlp.up_proj.weight_scale | [1] | BF16 | ||
| model.layers.0.post_attention_layernorm.weight | [4 096] |BF16 | ||
| model.layers.0.self_attn.k_proj.input_scale | [1] | BF16 | ||
| model.layers.0.self_attn.k_proj.weight | [1 024, 4 096]| F8_E4M3 | ||
| model.layers.0.self_attn.k_proj.weight_scale |[1] | BF16 | ||
| model.layers.0.self_attn.o_proj.input_scale | [1] | BF16 | ||
| model.layers.0.self_attn.o_proj.weight | [4 096, 4 096] | F8_E4M3 | ||
| model.layers.0.self_attn.o_proj.weight_scale | [1] | BF16 | ||
| model.layers.0.self_attn.q_proj.input_scale | [1] | BF16 | ||
| model.layers.0.self_attn.q_proj.weight | [4 096, 4 096] | F8_E4M3 | ||
| model.layers.0.self_attn.q_proj.weight_scale | [1] | BF16 | ||
| model.layers.0.self_attn.v_proj.input_scale | [1] | BF16 | ||
| model.layers.0.self_attn.v_proj.weight | [1 024, 4 096] | F8_E4M3 | ||
| model.layers.0.self_attn.v_proj.weight_scale | [1] | BF16 | ||
|
|
||
| When we load the model with the compressed-tensors HFQuantizer integration, we can see that all of the Linear modules that are specified within the quantization configuration have been replaced by `CompressedLinear` modules that manage the compressed weights and forward pass for inference. Note that the `lm_head` mentioned before in the ignore list is still kept as an unquantized Linear module. | ||
|
|
||
| ```python | ||
| from transformers import AutoModelForCausalLM | ||
|
|
||
| ct_model = AutoModelForCausalLM.from_pretrained("nm-testing/Meta-Llama-3.1-8B-Instruct-FP8-hf") | ||
| print(ct_model) | ||
| """ | ||
| LlamaForCausalLM( | ||
| (model): LlamaModel( | ||
| (embed_tokens): Embedding(128256, 4096) | ||
| (layers): ModuleList( | ||
| (0-31): 32 x LlamaDecoderLayer( | ||
| (self_attn): LlamaSdpaAttention( | ||
| (q_proj): CompressedLinear( | ||
| in_features=4096, out_features=4096, bias=False | ||
| (input_observer): MovingAverageMinMaxObserver() | ||
| (weight_observer): MovingAverageMinMaxObserver() | ||
| ) | ||
| (k_proj): CompressedLinear( | ||
| in_features=4096, out_features=1024, bias=False | ||
| (input_observer): MovingAverageMinMaxObserver() | ||
| (weight_observer): MovingAverageMinMaxObserver() | ||
| ) | ||
| (v_proj): CompressedLinear( | ||
| in_features=4096, out_features=1024, bias=False | ||
| (input_observer): MovingAverageMinMaxObserver() | ||
| (weight_observer): MovingAverageMinMaxObserver() | ||
| ) | ||
| (o_proj): CompressedLinear( | ||
| in_features=4096, out_features=4096, bias=False | ||
| (input_observer): MovingAverageMinMaxObserver() | ||
| (weight_observer): MovingAverageMinMaxObserver() | ||
| ) | ||
| (rotary_emb): LlamaRotaryEmbedding() | ||
| ) | ||
| (mlp): LlamaMLP( | ||
| (gate_proj): CompressedLinear( | ||
| in_features=4096, out_features=14336, bias=False | ||
| (input_observer): MovingAverageMinMaxObserver() | ||
| (weight_observer): MovingAverageMinMaxObserver() | ||
| ) | ||
| (up_proj): CompressedLinear( | ||
| in_features=4096, out_features=14336, bias=False | ||
| (input_observer): MovingAverageMinMaxObserver() | ||
| (weight_observer): MovingAverageMinMaxObserver() | ||
| ) | ||
| (down_proj): CompressedLinear( | ||
| in_features=14336, out_features=4096, bias=False | ||
| (input_observer): MovingAverageMinMaxObserver() | ||
| (weight_observer): MovingAverageMinMaxObserver() | ||
| ) | ||
| (act_fn): SiLU() | ||
| ) | ||
| (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05) | ||
| (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05) | ||
| ) | ||
| ) | ||
| (norm): LlamaRMSNorm((4096,), eps=1e-05) | ||
| (rotary_emb): LlamaRotaryEmbedding() | ||
| ) | ||
| (lm_head): Linear(in_features=4096, out_features=128256, bias=False) | ||
| ) | ||
| """ | ||
| ``` |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,77 @@ | ||
| # Copyright 2024 The HuggingFace Inc. team. All rights reserved. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from ..utils import is_compressed_tensors_available, is_torch_available, logging | ||
| from ..utils.quantization_config import QuantizationConfigMixin | ||
| from .base import HfQuantizer | ||
|
|
||
|
|
||
| if is_torch_available(): | ||
| import torch | ||
|
|
||
| logger = logging.get_logger(__name__) | ||
|
|
||
|
|
||
| class CompressedTensorsHfQuantizer(HfQuantizer): | ||
| """ | ||
| Quantizer for the compressed_tensors package. Loads and restores models to | ||
| quantized state with compressed_tensors | ||
| """ | ||
|
|
||
| requires_calibration = True | ||
| required_packages = ["compressed_tensors"] | ||
|
|
||
| def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs): | ||
| super().__init__(quantization_config, **kwargs) | ||
|
|
||
| from compressed_tensors.compressors import ModelCompressor | ||
|
|
||
| self.compressor = ModelCompressor.from_compression_config(quantization_config) | ||
|
|
||
| def validate_environment(self, *args, **kwargs): | ||
| if not is_compressed_tensors_available(): | ||
| raise ImportError( | ||
| "Using `compressed_tensors` quantized models requires the compressed-tensors library: " | ||
| "`pip install compressed-tensors`" | ||
| ) | ||
| if not is_torch_available(): | ||
| # torch already should be installed as part of compressed tensors | ||
| raise ImportError("torch is required for using compressed-tensors quantization") | ||
|
|
||
| def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": | ||
| if torch_dtype is None: | ||
| logger.info("Loading model using torch.float16 for compressed-tensors quantization") | ||
| torch_dtype = torch.float16 | ||
| elif torch_dtype != torch.float16: | ||
| logger.info( | ||
| "We suggest you to set `torch_dtype=torch.float16` for better efficiency with compressed_tensors." | ||
| ) | ||
| return torch_dtype | ||
|
Comment on lines
+53
to
+60
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there an issue with bfloat16? We should try to allow this for llama models There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No issue with bfloat16, we just recommend float16 as a default since that is what vLLM expects for the scale/zp |
||
|
|
||
| def _process_model_before_weight_loading(self, model, **kwargs): | ||
| from compressed_tensors.quantization import apply_quantization_config | ||
|
|
||
| ct_quantization_config = self.compressor.quantization_config | ||
| apply_quantization_config(model, ct_quantization_config, run_compressed=True) | ||
SunMarc marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def _process_model_after_weight_loading(self, model, **kwargs): | ||
| pass | ||
|
|
||
| @property | ||
| def is_trainable(self): | ||
| return False | ||
|
|
||
| @property | ||
| def is_serializable(self): | ||
| return False | ||
Uh oh!
There was an error while loading. Please reload this page.