
Commit 6afe796

chenqianfzh authored and Alvant committed
support bitsandbytes 8-bit and FP4 quantized models (vllm-project#7445)
Signed-off-by: Alvant <[email protected]>
1 parent ae39746 commit 6afe796
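
As a quick orientation before the diff: the new tests below drive this support through the vllm_runner fixture with quantization='bitsandbytes' and load_format='bitsandbytes'. Outside the test suite, loading one of the newly covered pre-quantized checkpoints would look roughly like the following sketch (model name taken from the new test lists; this is an illustrative assumption about end-user usage, not part of the commit):

# Illustrative sketch only; not part of this commit.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-Guard-3-8B-INT8",  # pre-quantized 8-bit checkpoint from the new test list
    quantization="bitsandbytes",               # flags mirrored from the vllm_runner calls below
    load_format="bitsandbytes",
    enforce_eager=True,                        # CUDA graphs are not supported with bitsandbytes yet
)

params = SamplingParams(temperature=0.0, max_tokens=8)  # greedy, matching the tests' 8-token budget
outputs = llm.generate(["To be or not to be,"], params)
print(outputs[0].outputs[0].text)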

File tree

6 files changed: 437 additions & 191 deletions


tests/conftest.py

Lines changed: 6 additions & 0 deletions
@@ -209,8 +209,14 @@ class HfRunner:
 
     def wrap_device(self, input: _T) -> _T:
         if not is_cpu():
+            # Check if the input is already on the GPU
+            if hasattr(input, 'device') and input.device.type == "cuda":
+                return input  # Already on GPU, no need to move
             return input.to("cuda")
         else:
+            # Check if the input is already on the CPU
+            if hasattr(input, 'device') and input.device.type == "cpu":
+                return input  # Already on CPU, no need to move
             return input.to("cpu")
 
     def __init__(
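
A brief note on the intent of this guard (my reading, not stated in the commit): bitsandbytes-quantized HF models are typically materialized on the GPU at load time, so the objects handed to the fixture may already live on the right device and should not be moved again. A minimal standalone sketch of the same pattern, using a hypothetical helper name:

import torch

def move_if_needed(x, target: str):
    # Hypothetical helper mirroring HfRunner.wrap_device: only call .to() on
    # objects that are not already on the target device.
    if hasattr(x, "device") and x.device.type == target:
        return x
    return x.to(target)

t = torch.ones(2)                  # starts on CPU
t = move_if_needed(t, "cpu")       # no-op: already on CPU
if torch.cuda.is_available():
    t = move_if_needed(t, "cuda")  # moved once
    t = move_if_needed(t, "cuda")  # second call returns t unchanged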

tests/quantization/test_bitsandbytes.py

Lines changed: 98 additions & 68 deletions
@@ -2,85 +2,115 @@
 
 Run `pytest tests/quantization/test_bitsandbytes.py`.
 '''
+
+import gc
+
 import pytest
 import torch
 
 from tests.quantization.utils import is_quant_method_supported
-from vllm import SamplingParams
 
-models_to_test = [
+models_4bit_to_test = [
     ('huggyllama/llama-7b', 'quantize model inflight'),
-    ('lllyasviel/omost-llama-3-8b-4bits', 'read pre-quantized model'),
 ]
 
+models_pre_qaunt_4bit_to_test = [
+    ('lllyasviel/omost-llama-3-8b-4bits',
+     'read pre-quantized 4-bit NF4 model'),
+    ('PrunaAI/Einstein-v6.1-Llama3-8B-bnb-4bit-smashed',
+     'read pre-quantized 4-bit FP4 model'),
+]
+
+models_pre_quant_8bit_to_test = [
+    ('meta-llama/Llama-Guard-3-8B-INT8', 'read pre-quantized 8-bit model'),
+]
+
+
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                             model_name, description) -> None:
+
+    hf_model_kwargs = {"load_in_4bit": True}
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name, hf_model_kwargs)
+
+
+@pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
+                    reason='bitsandbytes is not supported on this GPU type.')
+@pytest.mark.parametrize("model_name, description",
+                         models_pre_qaunt_4bit_to_test)
+def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                                       model_name, description) -> None:
+
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name)
+
 
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
-@pytest.mark.parametrize("model_name, description", models_to_test)
-def test_load_bnb_model(vllm_runner, model_name, description) -> None:
+@pytest.mark.parametrize("model_name, description",
+                         models_pre_quant_8bit_to_test)
+def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
+                             model_name, description) -> None:
+
+    validate_generated_texts(hf_runner, vllm_runner, example_prompts[:1],
+                             model_name)
+
+
+def log_generated_texts(prompts, outputs, runner_name):
+    logged_texts = []
+    for i, (_, generated_text) in enumerate(outputs):
+        log_entry = {
+            "prompt": prompts[i],
+            "runner_name": runner_name,
+            "generated_text": generated_text,
+        }
+        logged_texts.append(log_entry)
+    return logged_texts
+
+
+def validate_generated_texts(hf_runner,
+                             vllm_runner,
+                             prompts,
+                             model_name,
+                             hf_model_kwargs=None):
+
+    if hf_model_kwargs is None:
+        hf_model_kwargs = {}
+
+    # Run with HF runner
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    #Run with vLLM runner
     with vllm_runner(model_name,
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
-                     enforce_eager=True) as llm:
-        model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model  # noqa: E501
-
-        # check the weights in MLP & SelfAttention are quantized to torch.uint8
-        qweight = model.model.layers[0].mlp.gate_up_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].mlp.down_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].self_attn.o_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        qweight = model.model.layers[0].self_attn.qkv_proj.qweight
-        assert qweight.dtype == torch.uint8, (
-            f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')
-
-        # some weights should not be quantized
-        weight = model.lm_head.weight
-        assert weight.dtype != torch.uint8, (
-            'lm_head weight dtype should not be torch.uint8')
-
-        weight = model.model.embed_tokens.weight
-        assert weight.dtype != torch.uint8, (
-            'embed_tokens weight dtype should not be torch.uint8')
-
-        weight = model.model.layers[0].input_layernorm.weight
-        assert weight.dtype != torch.uint8, (
-            'input_layernorm weight dtype should not be torch.uint8')
-
-        weight = model.model.layers[0].post_attention_layernorm.weight
-        assert weight.dtype != torch.uint8, (
-            'input_layernorm weight dtype should not be torch.uint8')
-
-        # check the output of the model is expected
-        sampling_params = SamplingParams(temperature=0.0,
-                                         logprobs=1,
-                                         prompt_logprobs=1,
-                                         max_tokens=8)
-
-        prompts = ['That which does not kill us', 'To be or not to be,']
-        expected_outputs = [
-            'That which does not kill us makes us stronger.',
-            'To be or not to be, that is the question.'
-        ]
-        outputs = llm.generate(prompts, sampling_params=sampling_params)
-        assert len(outputs) == len(prompts)
-
-        for index in range(len(outputs)):
-            # compare the first line of the output
-            actual_output = outputs[index][1][0].split('\n', 1)[0]
-            expected_output = expected_outputs[index].split('\n', 1)[0]
-
-            assert len(actual_output) >= len(expected_output), (
-                f'Actual {actual_output} should be larger than or equal to '
-                f'expected {expected_output}')
-            actual_output = actual_output[:len(expected_output)]
-
-            assert actual_output == expected_output, (
-                f'Expected: {expected_output}, but got: {actual_output}')
+                     enforce_eager=True,
+                     gpu_memory_utilization=0.8) as llm:
+        vllm_outputs = llm.generate_greedy(prompts, 8)
+        vllm_logs = log_generated_texts(prompts, vllm_outputs, "VllmRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
+    # Compare the generated strings
+    for hf_log, vllm_log in zip(hf_logs, vllm_logs):
+        hf_str = hf_log["generated_text"]
+        vllm_str = vllm_log["generated_text"]
+        prompt = hf_log["prompt"]
+        assert hf_str == vllm_str, (f"Model: {model_name}"
+                                    f"Mismatch between HF and vLLM outputs:\n"
+                                    f"Prompt: {prompt}\n"
+                                    f"HF Output: '{hf_str}'\n"
+                                    f"vLLM Output: '{vllm_str}'")

vllm/config.py

Lines changed: 2 additions & 0 deletions
@@ -405,6 +405,8 @@ def verify_with_parallel_config(
             raise ValueError(
                 "BitAndBytes quantization with TP or PP is not supported yet.")
 
+        # Remove the constraint after the bitsandbytes issue is fixed:
+        # https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1308
         if self.quantization == "bitsandbytes" and self.enforce_eager is False:
             logger.warning("CUDA graph is not supported on BitAndBytes yet, "
                            "fallback to the eager mode.")

vllm/model_executor/layers/linear.py

Lines changed: 10 additions & 8 deletions
@@ -36,9 +36,9 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-def adjust_bitsandbytes_shard(param: Parameter,
-                              qkv_offsets: Dict[str, Tuple[int, int]],
-                              loaded_shard_id: str) -> Tuple[int, int]:
+def adjust_bitsandbytes_4bit_shard(param: Parameter,
+                                   qkv_offsets: Dict[str, Tuple[int, int]],
+                                   loaded_shard_id: str) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
 
     total, _ = qkv_offsets["total"]
@@ -505,8 +505,9 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
+            if use_bitsandbytes_4bit:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * \
                     loaded_shard_id
@@ -858,8 +859,9 @@ def weight_loader(self,
                 shard_size, shard_offset = adjust_marlin_shard(
                     param, shard_size, shard_offset)
 
-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit",
+                                            False)
+            if use_bitsandbytes_4bit:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
                     "k": (self.num_heads * self.head_size,
@@ -871,7 +873,7 @@ def weight_loader(self,
                     ((self.num_heads + 2 * self.num_kv_heads) * self.head_size,
                      0)
                 }
-                shard_size, shard_offset = adjust_bitsandbytes_shard(
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id)
 
             if is_gguf_weight:
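
Only the signature and the first line of the renamed helper are visible in this diff. As a rough illustration of why QKV shard offsets need adjusting at all (a hypothetical simplified variant, not the committed implementation): a packed 4-bit weight's first dimension is smaller than the logical QKV output size, so each shard's (offset, size) has to be rescaled to the packed dimension.

from typing import Dict, Tuple

def rescale_shard_for_packed_weight(packed_rows: int,
                                    qkv_offsets: Dict[str, Tuple[int, int]],
                                    loaded_shard_id: str) -> Tuple[int, int]:
    # Hypothetical illustration: scale the logical (offset, size) of the
    # requested shard down to the packed tensor's first dimension.
    total, _ = qkv_offsets["total"]
    offset, size = qkv_offsets[loaded_shard_id]
    return size * packed_rows // total, offset * packed_rows // total

# Example: 32 query heads, 8 KV heads, head_size 128, packed dim of 3072 rows.
offsets = {
    "q": (0, 32 * 128),
    "k": (32 * 128, 8 * 128),
    "v": ((32 + 8) * 128, 8 * 128),
    "total": ((32 + 2 * 8) * 128, 0),
}
print(rescale_shard_for_packed_weight(3072, offsets, "k"))  # (512, 2048)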
