 
 Run `pytest tests/quantization/test_bitsandbytes.py`.
 '''
+
+import gc
+
 import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from vllm import SamplingParams
 
-models_to_test = [
+models_4bit_to_test = [
     ('huggyllama/llama-7b', 'quantize model inflight'),
-    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
 ]
 
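+# The remaining lists cover checkpoints whose weights were quantized ahead of
+# time, so the test only exercises reading the pre-quantized bitsandbytes
+# tensors.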
+models_pre_quant_4bit_to_test = [
+    ('lllyasviel/omost-llama-3-8b-4bits',
+     'read pre-quantized 4-bit NF4 model'),
+    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
+     'read pre-quantized 4-bit FP4 model'),
+]
+
+models_pre_quant_8bit_to_test = [
+    ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
+]
+
+
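+# In-flight case: start from the fp16 checkpoint and let bitsandbytes quantize
+# it to 4-bit at load time on both the HF side (load_in_4bit) and the vLLM
+# side.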
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                             model_name, description) -> None:
+
+    hf_model_kwargs = {"load_in_4bit": True}
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name, hf_model_kwargs)
+
+
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description",
+                         models_pre_quant_4bit_to_test)
+def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                                       model_name, description) -> None:
+
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name)
+
 
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
-@pytest.mark.parametrize("model_name, description", models_to_test)
-def test_load_bnb_model(vllm_runner, model_name, description) -> None:
+@pytest.mark.parametrize("model_name, description",
+                         models_pre_quant_8bit_to_test)
+def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                             model_name, description) -> None:
+
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name)
+
+
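+# Bundle each prompt with the text a runner produced for it, tagged with the
+# runner's name, so the two runs can be compared entry by entry.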
+def log_generated_texts(prompts, outputs, runner_name):
+    logged_texts = []
+    for i, (_, generated_text) in enumerate(outputs):
+        log_entry = {
+            "prompt": prompts[i],
+            "runner_name": runner_name,
+            "generated_text": generated_text,
+        }
+        logged_texts.append(log_entry)
+    return logged_texts
+
+
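+# Generate greedily (8 new tokens) with the HF runner and then the vLLM
+# runner, freeing GPU memory in between, and assert the decoded texts match
+# exactly.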
+def validate_generated_texts(hf_runner,
+                             vllm_runner,
+                             prompts,
+                             model_name,
+                             hf_model_kwargs=None):
+
+    if hf_model_kwargs is None:
+        hf_model_kwargs = {}
+
+    # Run with HF runner
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Run with vLLM runner
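+    # quantization='bitsandbytes' and load_format='bitsandbytes' select vLLM's
+    # bitsandbytes weight loader for this checkpoint.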
     with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
-                     enforce_eager=True) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-
-        # check the weights in MLP & SelfAttention are quantized to torch.uint8
-        qweight = model.model.layers[0].mlp.gate_up_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].mlp.down_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].self_attn.o_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].self_attn.qkv_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        # some weights should not be quantized
-        weight = model.lm_head.weight
-        assert weight.dtype != torch.uint8, (
-            'lm_head weight dtype should not be torch.uint8')
-
-        weight = model.model.embed_tokens.weight
-        assert weight.dtype != torch.uint8, (
-            'embed_tokens weight dtype should not be torch.uint8')
-
-        weight = model.model.layers[0].input_layernorm.weight
-        assert weight.dtype != torch.uint8, (
-            'input_layernorm weight dtype should not be torch.uint8')
-
-        weight = model.model.layers[0].post_attention_layernorm.weight
-        assert weight.dtype != torch.uint8, (
-            'input_layernorm weight dtype should not be torch.uint8')
-
-        # check the output of the model is expected
-        sampling_params = SamplingParams(temperature=0.0,
-                                         logprobs=1,
-                                         prompt_logprobs=1,
-                                         max_tokens=8)
-
-        prompts = ['That which does not kill us', 'To be or not to be,']
-        expected_outputs = [
-            'That which does not kill us makes us stronger.',
-            'To be or not to be, that is the question.'
-        ]
-        outputs = llm.generate(prompts, sampling_params=sampling_params)
-        assert len(outputs) == len(prompts)
-
-        for index in range(len(outputs)):
-            # compare the first line of the output
-            actual_output = outputs[index][1][0].split('\n', 1)[0]
-            expected_output = expected_outputs[index].split('\n', 1)[0]
-
-            assert len(actual_output) >= len(expected_output), (
-                f'Actual {actual_output} should be larger than or equal to '
-                f'expected {expected_output}')
-            actual_output = actual_output[:len(expected_output)]
-
-            assert actual_output == expected_output, (
-                f'Expected: {expected_output}, but got: {actual_output}')
+                     enforce_eager=True,
+                     gpu_memory_utilization=0.8) as llm:
+        vllm_outputs = llm.generate_greedy(prompts, 8)
+        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Compare the generated strings
+    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
+        hf_str = hf_log["generated_text"]
+        vllm_str = vllm_log["generated_text"]
+        prompt = hf_log["prompt"]
+        assert hf_str == vllm_str, (f"Model: {model_name}\n"
+                                    f"Mismatch between HF and vLLM outputs:\n"
+                                    f"Prompt: {prompt}\n"
+                                    f"HF Output: '{hf_str}'\n"
+                                    f"vLLM Output: '{vllm_str}'")