
Commit 8214d6e

add exllamav2 arg (#26437)
* add exllamav2 arg
* add test
* style
* add check
* add doc
* replace by use_exllama_v2
* fix tests
* fix doc
* style
* better condition
* fix logic
* add deprecate msg
1 parent d7cb5e1 commit 8214d6e


4 files changed (+93 -5 lines changed)


docs/source/en/main_classes/quantization.md

Lines changed: 11 additions & 1 deletion
@@ -128,12 +128,22 @@ For 4-bit model, you can use the exllama kernels in order to a faster inference
 
 ```py
 import torch
-gptq_config = GPTQConfig(bits=4, disable_exllama=False)
+gptq_config = GPTQConfig(bits=4)
+model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config = gptq_config)
+```
+
+With the release of the exllamav2 kernels, you can get faster inference speed compared to the exllama kernels. You just need to
+pass `use_exllama_v2=True` in [`GPTQConfig`] and disable exllama kernels:
+
+```py
+import torch
+gptq_config = GPTQConfig(bits=4, use_exllama_v2=True)
 model = AutoModelForCausalLM.from_pretrained("{your_username}/opt-125m-gptq", device_map="auto", quantization_config = gptq_config)
 ```
 
 Note that only 4-bit models are supported for now. Furthermore, it is recommended to deactivate the exllama kernels if you are finetuning a quantized model with peft.
 
+You can find the benchmark of these kernels [here](https://github.com/huggingface/optimum/tree/main/tests/benchmark#gptq-benchmark)
 #### Fine-tune a quantized model
 
 With the official support of adapters in the Hugging Face ecosystem, you can fine-tune models that have been quantized with GPTQ.
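As a side note on the peft recommendation in the doc above: a minimal sketch of loading a GPTQ checkpoint with the exllama kernels deactivated before fine-tuning, assuming the same placeholder repo id used in the docs, might look like this:

```py
from transformers import AutoModelForCausalLM, GPTQConfig

# Sketch only: keep the exllama kernels off while fine-tuning a GPTQ model with peft,
# as recommended in the doc above. "{your_username}/opt-125m-gptq" is the placeholder
# repo id from the docs, not a real checkpoint.
gptq_config = GPTQConfig(bits=4, disable_exllama=True)
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq",
    device_map="auto",
    quantization_config=gptq_config,
)
```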

src/transformers/modeling_utils.py

Lines changed: 1 addition & 1 deletion
@@ -2759,7 +2759,7 @@ def from_pretrained(
             logger.warning(
                 "You passed `quantization_config` to `from_pretrained` but the model you're loading already has a "
                 "`quantization_config` attribute and has already quantized weights. However, loading attributes"
-                " (e.g. disable_exllama, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
+                " (e.g. disable_exllama, use_cuda_fp16, max_input_length, use_exllama_v2) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
             )
         if (
             quantization_method_from_args == QuantizationMethod.GPTQ
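For readers wondering what the reworded warning covers: when a checkpoint already stores a `quantization_config`, a config passed to `from_pretrained` only overrides the listed loading attributes, which now include `use_exllama_v2`. A rough sketch of that pattern, reusing the placeholder repo id from the docs:

```py
from transformers import AutoModelForCausalLM, GPTQConfig

# Sketch only: the placeholder checkpoint is assumed to be already quantized.
# Passing this config does not re-quantize it; per the warning above, only the
# loading attributes (disable_exllama, use_cuda_fp16, max_input_length,
# use_exllama_v2) are taken from it, the rest of the stored config is kept.
loading_overrides = GPTQConfig(bits=4, use_exllama_v2=True)
model = AutoModelForCausalLM.from_pretrained(
    "{your_username}/opt-125m-gptq",
    device_map="auto",
    quantization_config=loading_overrides,
)
```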

src/transformers/utils/quantization_config.py

Lines changed: 23 additions & 1 deletion
@@ -349,6 +349,8 @@ class GPTQConfig(QuantizationConfigMixin):
         max_input_length (`int`, *optional*):
             The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
             length. It is specific to the exllama backend with act-order.
+        use_exllama_v2 (`bool`, *optional*, defaults to `False`):
+            Whether to use exllamav2 backend. Only works with `bits` = 4.
     """
 
     def __init__(
@@ -369,6 +371,7 @@ def __init__(
         pad_token_id: Optional[int] = None,
         disable_exllama: bool = False,
         max_input_length: Optional[int] = None,
+        use_exllama_v2: bool = False,
         **kwargs,
     ):
         self.quant_method = QuantizationMethod.GPTQ
@@ -388,11 +391,14 @@ def __init__(
         self.pad_token_id = pad_token_id
         self.disable_exllama = disable_exllama
         self.max_input_length = max_input_length
+        self.use_exllama_v2 = use_exllama_v2
+        # needed for compatibility with optimum gptq config
+        self.disable_exllamav2 = not use_exllama_v2
         self.post_init()
 
     def get_loading_attributes(self):
         attibutes_dict = copy.deepcopy(self.__dict__)
-        loading_attibutes = ["disable_exllama", "use_cuda_fp16", "max_input_length"]
+        loading_attibutes = ["disable_exllama", "use_exllama_v2", "use_cuda_fp16", "max_input_length"]
         loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
         return loading_attibutes_dict
 
@@ -418,3 +424,19 @@ def post_init(self):
                     f"""dataset needs to be either a list of string or a value in
                     ['wikitext2','c4','c4-new','ptb','ptb-new'], but we found {self.dataset}"""
                 )
+        if self.bits == 4:
+            if self.use_exllama_v2:
+                optimum_version = version.parse(importlib.metadata.version("optimum"))
+                autogptq_version = version.parse(importlib.metadata.version("auto_gptq"))
+                if optimum_version <= version.parse("1.13.2") or autogptq_version <= version.parse("0.4.2"):
+                    raise ValueError(
+                        f"You need optimum > 1.13.2 and auto-gptq > 0.4.2 . Make sure to have that version installed - detected version : optimum {optimum_version} and autogptq {autogptq_version}"
+                    )
+                self.disable_exllama = True
+                logger.warning("You have activated exllamav2 kernels. Exllama kernels will be disabled.")
+            if not self.disable_exllama:
+                logger.warning(
+                    """You have activated exllama backend. Note that you can get better inference
+                    speed using exllamav2 kernel by setting `use_exllama_v2=True`.`disable_exllama` will be deprecated
+                    in future version."""
+                )
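Taken together, the config changes above mean that requesting the exllamav2 backend flips the exllama flag and keeps an optimum-compatible `disable_exllamav2` mirror. A small sketch of the resulting attribute values, assuming optimum > 1.13.2 and auto-gptq > 0.4.2 are installed:

```py
from transformers import GPTQConfig

# Sketch only: with bits=4 and use_exllama_v2=True, post_init() verifies the
# optimum / auto-gptq versions, forces disable_exllama=True (with a warning),
# and keeps disable_exllamav2 = not use_exllama_v2 for optimum compatibility.
config = GPTQConfig(bits=4, use_exllama_v2=True)

print(config.disable_exllama)     # True  -> exllama (v1) kernels are off
print(config.disable_exllamav2)   # False -> exllamav2 kernels stay enabled
# Only the loading attributes are forwarded when loading a quantized checkpoint:
print(config.get_loading_attributes())
```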

tests/quantization/gptq/test_gptq.py

Lines changed: 58 additions & 2 deletions
@@ -178,6 +178,7 @@ def test_quantized_layers_class(self):
             group_size=self.group_size,
             bits=self.bits,
             disable_exllama=self.disable_exllama,
+            disable_exllamav2=True,
         )
         self.assertTrue(self.quantized_model.transformer.h[0].mlp.dense_4h_to_h.__class__ == QuantLinear)
 
@@ -281,8 +282,7 @@ def setUpClass(cls):
         """
         Setup quantized model
         """
-
-        cls.quantization_config = GPTQConfig(bits=4, disable_exllama=False, max_input_length=4028)
+        cls.quantization_config = GPTQConfig(bits=4, max_input_length=4028)
         cls.quantized_model = AutoModelForCausalLM.from_pretrained(
             cls.model_name,
             revision=cls.revision,
@@ -334,6 +334,62 @@ def test_max_input_length(self):
         self.quantized_model.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
 
 
+@slow
+@require_optimum
+@require_auto_gptq
+@require_torch_gpu
+@require_accelerate
+class GPTQTestExllamaV2(unittest.TestCase):
+    """
+    Test GPTQ model with exllamav2 kernel and desc_act=True (also known as act-order).
+    More information on those arguments here:
+    https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig
+    """
+
+    EXPECTED_OUTPUTS = set()
+    EXPECTED_OUTPUTS.add("Hello my name is Katie and I am a 20 year")
+    model_name = "hf-internal-testing/Llama-2-7B-GPTQ"
+    revision = "gptq-4bit-128g-actorder_True"
+    input_text = "Hello my name is"
+
+    @classmethod
+    def setUpClass(cls):
+        """
+        Setup quantized model
+        """
+        cls.quantization_config = GPTQConfig(bits=4, use_exllama_v2=True)
+        cls.quantized_model = AutoModelForCausalLM.from_pretrained(
+            cls.model_name,
+            revision=cls.revision,
+            torch_dtype=torch.float16,
+            device_map={"": 0},
+            quantization_config=cls.quantization_config,
+        )
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)
+
+    def check_inference_correctness(self, model):
+        """
+        Test the generation quality of the quantized model and see that we are matching the expected output.
+        Given that we are operating on small numbers + the testing model is relatively small, we might not get
+        the same output across GPUs. So we'll generate few tokens (5-10) and check their output.
+        """
+
+        # Check that inference pass works on the model
+        encoded_input = self.tokenizer(self.input_text, return_tensors="pt")
+
+        # Check the exactness of the results
+        output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)
+
+        # Get the generation
+        self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)
+
+    def test_generate_quality(self):
+        """
+        Simple test to check the quality of the model by comapring the the generated tokens with the expected tokens
+        """
+        self.check_inference_correctness(self.quantized_model)
+
+
 # fail when run all together
 @pytest.mark.skip
 @require_accelerate
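The new `GPTQTestExllamaV2` class is marked `@slow`, so it only runs when slow tests are enabled. One way to run just this class locally (a sketch; it assumes a CUDA GPU plus recent optimum and auto-gptq installs) is:

```py
import os

import pytest

# RUN_SLOW=1 is the transformers convention for enabling @slow tests; it must be
# set before the test module (and transformers.testing_utils) is imported.
os.environ["RUN_SLOW"] = "1"
pytest.main(["-v", "tests/quantization/gptq/test_gptq.py::GPTQTestExllamaV2"])
```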
