8 changes: 6 additions & 2 deletions src/transformers/modeling_utils.py
@@ -2546,7 +2546,7 @@ def from_pretrained(
logger.warning(
"You passed `quantization_config` to `from_pretrained` but the model you're loading already has a "
"`quantization_config` attribute and has already quantized weights. However, loading attributes"
" (e.g. disable_exllama, use_cuda_fp16) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
" (e.g. disable_exllama, use_cuda_fp16, max_input_length) will be overwritten with the one you passed to `from_pretrained`. The rest will be ignored."
)
if (
quantization_method_from_args == QuantizationMethod.GPTQ
@@ -2556,7 +2556,11 @@ def from_pretrained(
raise RuntimeError("GPU is required to quantize or run quantize model.")
elif not (is_optimum_available() and is_auto_gptq_available()):
raise ImportError(
"Loading GPTQ quantized model requires optimum library : `pip install optimum` and auto-gptq library 'pip install auto-gptq'"
"Loading a GPTQ quantized model requires optimum (`pip install optimum`) and auto-gptq library (`pip install auto-gptq`)"
)
elif version.parse(importlib.metadata.version("auto_gptq")) < version.parse("0.4.2"):
raise ImportError(
"You need a version of auto_gptq >= 0.4.2 to use GPTQ: `pip install --upgrade auto-gptq`"
)
else:
# Need to protect the import
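The new guard in `from_pretrained` requires auto-gptq >= 0.4.2 in addition to optimum. A minimal sketch, not part of this PR, of checking the installed version ahead of time with the same `importlib.metadata` / `packaging` calls the guard relies on (the helper name is illustrative):

import importlib.metadata

from packaging import version

def auto_gptq_is_recent_enough() -> bool:
    # True when auto-gptq is installed and new enough for the loading checks above.
    try:
        installed = version.parse(importlib.metadata.version("auto_gptq"))
    except importlib.metadata.PackageNotFoundError:
        # auto-gptq is missing entirely: `pip install auto-gptq`
        return False
    return installed >= version.parse("0.4.2")

if not auto_gptq_is_recent_enough():
    print("Upgrade with: pip install --upgrade auto-gptq")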
7 changes: 6 additions & 1 deletion src/transformers/utils/quantization_config.py
@@ -346,6 +346,9 @@ class GPTQConfig(QuantizationConfigMixin):
The pad token id. Needed to prepare the dataset when `batch_size` > 1.
disable_exllama (`bool`, *optional*, defaults to `False`):
Whether to use exllama backend. Only works with `bits` = 4.
max_input_length (`int`, *optional*):
The maximum input length. This is needed to initialize a buffer that depends on the maximum expected input
length. It is specific to the exllama backend with act-order.
"""

def __init__(
@@ -365,6 +368,7 @@ def __init__(
batch_size: int = 1,
pad_token_id: Optional[int] = None,
disable_exllama: bool = False,
max_input_length: Optional[int] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.GPTQ
@@ -383,11 +387,12 @@ def __init__
self.batch_size = batch_size
self.pad_token_id = pad_token_id
self.disable_exllama = disable_exllama
self.max_input_length = max_input_length
self.post_init()

def get_loading_attributes(self):
attibutes_dict = copy.deepcopy(self.__dict__)
loading_attibutes = ["disable_exllama", "use_cuda_fp16"]
loading_attibutes = ["disable_exllama", "use_cuda_fp16", "max_input_length"]
loading_attibutes_dict = {i: j for i, j in attibutes_dict.items() if i in loading_attibutes}
return loading_attibutes_dict

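Together with the modeling_utils.py change above, `max_input_length` is now forwarded as a loading attribute. A usage sketch based on the test added further down in this PR (the model id, revision, and the 4028 value mirror that test; adapt them to your own checkpoint):

import torch

from transformers import AutoModelForCausalLM, AutoTokenizer, GPTQConfig

# Cap the exllama temp buffer at 4028 tokens; only relevant for exllama with act-order.
quantization_config = GPTQConfig(bits=4, disable_exllama=False, max_input_length=4028)

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/Llama-2-7B-GPTQ",
    revision="gptq-4bit-128g-actorder_True",
    torch_dtype=torch.float16,
    device_map={"": 0},
    quantization_config=quantization_config,
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/Llama-2-7B-GPTQ", use_fast=True)

# Prompts longer than max_input_length raise a "temp_state buffer is too small"
# RuntimeError from the exllama kernel, as exercised by test_max_input_length below.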
78 changes: 78 additions & 0 deletions tests/quantization/gptq/test_gptq.py
@@ -86,6 +86,8 @@ class GPTQTest(unittest.TestCase):

EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is John and I am a professional photographer. I")
EXPECTED_OUTPUTS.add("Hello my name is John, I am a professional photographer and I")
EXPECTED_OUTPUTS.add("Hello my name is John, I am a student in the University of")
EXPECTED_OUTPUTS.add("Hello my name is John and I am a very good looking man.")
EXPECTED_OUTPUTS.add("Hello my name is Alyson, I am a student in the")
EXPECTED_OUTPUTS.add("Hello my name is Alyson and I am a very sweet,")
@@ -236,6 +238,82 @@ class GPTQTestDeviceMapExllama(GPTQTest):
disable_exllama = False


@slow
@require_optimum
@require_auto_gptq
@require_torch_gpu
@require_accelerate
class GPTQTestActOrderExllama(unittest.TestCase):
"""
Test GPTQ model with exllama kernel and desc_act=True (also known as act-order).
More information on those arguments here:
https://huggingface.co/docs/transformers/main_classes/quantization#transformers.GPTQConfig
"""

EXPECTED_OUTPUTS = set()
EXPECTED_OUTPUTS.add("Hello my name is Katie and I am a 20 year")
model_name = "hf-internal-testing/Llama-2-7B-GPTQ"
revision = "gptq-4bit-128g-actorder_True"
input_text = "Hello my name is"

@classmethod
def setUpClass(cls):
"""
Setup quantized model
"""

cls.quantization_config = GPTQConfig(bits=4, disable_exllama=False, max_input_length=4028)
cls.quantized_model = AutoModelForCausalLM.from_pretrained(
cls.model_name,
revision=cls.revision,
torch_dtype=torch.float16,
device_map={"": 0},
quantization_config=cls.quantization_config,
)
cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name, use_fast=True)

def check_inference_correctness(self, model):
"""
Test the generation quality of the quantized model and see that we are matching the expected output.
Given that we are operating on small numbers and the test model is relatively small, we might not get
the same output across GPUs. So we generate a few tokens (5-10) and check their output.
"""

# Check that inference pass works on the model
encoded_input = self.tokenizer(self.input_text, return_tensors="pt")

# Check the exactness of the results
output_sequences = model.generate(input_ids=encoded_input["input_ids"].to(0), max_new_tokens=10)

# Get the generation
self.assertIn(self.tokenizer.decode(output_sequences[0], skip_special_tokens=True), self.EXPECTED_OUTPUTS)

def test_generate_quality(self):
"""
Simple test to check the quality of the model by comparing the generated tokens with the expected tokens
"""
self.check_inference_correctness(self.quantized_model)

# this test will fail until the next release of optimum
@pytest.mark.skip
def test_max_input_length(self):
"""
Test that max_input_length works. It modifies the maximum input length of the model that runs with the exllama backend.
"""

prompt = "I am in Paris and" * 1000
inp = self.tokenizer(prompt, return_tensors="pt").to(0)
self.assertTrue(inp["input_ids"].shape[1] > 4028)
with self.assertRaises(RuntimeError) as cm:
self.quantized_model.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)
self.assertTrue("temp_state buffer is too small" in str(cm.exception))

prompt = "I am in Paris and" * 500
inp = self.tokenizer(prompt, return_tensors="pt").to(0)
self.assertTrue(inp["input_ids"].shape[1] < 4028)
self.quantized_model.generate(**inp, num_beams=1, min_new_tokens=3, max_new_tokens=3)


# fails when all tests are run together
@pytest.mark.skip
@require_accelerate