This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit cb6b7a0

chenqianfzh authored and Robert Shaw committed
[Feature][Kernel] Support bitsandbytes quantization and QLoRA (vllm-project#4776)
1 parent 5b5c2b9 commit cb6b7a0

File tree

11 files changed: +752 −8 lines changed

Lines changed: 140 additions & 0 deletions
@@ -0,0 +1,140 @@
"""
This example shows how to use LoRA with different quantization techniques
for offline inference.

Requires HuggingFace credentials for access.
"""

import gc
from typing import List, Optional, Tuple

import torch
from huggingface_hub import snapshot_download

from vllm import EngineArgs, LLMEngine, RequestOutput, SamplingParams
from vllm.lora.request import LoRARequest


def create_test_prompts(
        lora_path: str
) -> List[Tuple[str, SamplingParams, Optional[LoRARequest]]]:
    return [
        # This is an example of using quantization without LoRA
        ("My name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128), None),
        # The next three examples use quantization with LoRA
        ("my name is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-1", 1, lora_path)),
        ("The capital of USA is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-2", 1, lora_path)),
        ("The capital of France is",
         SamplingParams(temperature=0.0,
                        logprobs=1,
                        prompt_logprobs=1,
                        max_tokens=128),
         LoRARequest("lora-test-3", 1, lora_path)),
    ]


def process_requests(engine: LLMEngine,
                     test_prompts: List[Tuple[str, SamplingParams,
                                              Optional[LoRARequest]]]):
    """Continuously process a list of prompts and handle the outputs."""
    request_id = 0

    while test_prompts or engine.has_unfinished_requests():
        if test_prompts:
            prompt, sampling_params, lora_request = test_prompts.pop(0)
            engine.add_request(str(request_id),
                               prompt,
                               sampling_params,
                               lora_request=lora_request)
            request_id += 1

        request_outputs: List[RequestOutput] = engine.step()
        for request_output in request_outputs:
            if request_output.finished:
                print("----------------------------------------------------")
                print(f"Prompt: {request_output.prompt}")
                print(f"Output: {request_output.outputs[0].text}")


def initialize_engine(model: str, quantization: str,
                      lora_repo: Optional[str]) -> LLMEngine:
    """Initialize the LLMEngine."""

    if quantization == "bitsandbytes":
        # QLoRA (https://arxiv.org/abs/2305.14314) is a quantization technique.
        # It quantizes the model at load time, using some config info from the
        # LoRA adapter repo, so the load_format and qlora_adapter_name_or_path
        # parameters must be set as below.
        engine_args = EngineArgs(
            model=model,
            quantization=quantization,
            qlora_adapter_name_or_path=lora_repo,
            load_format="bitsandbytes",
            enable_lora=True,
            max_lora_rank=64,
            # Set this only on GPUs with limited memory
            enforce_eager=True)
    else:
        engine_args = EngineArgs(
            model=model,
            quantization=quantization,
            enable_lora=True,
            max_loras=4,
            # Set this only on GPUs with limited memory
            enforce_eager=True)
    return LLMEngine.from_engine_args(engine_args)


def main():
    """Main function that sets up and runs the prompt processing."""

    test_configs = [{
        "name": "qlora_inference_example",
        'model': "huggyllama/llama-7b",
        'quantization': "bitsandbytes",
        'lora_repo': 'timdettmers/qlora-flan-7b'
    }, {
        "name": "AWQ_inference_with_lora_example",
        'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ',
        'quantization': "awq",
        'lora_repo': 'jashing/tinyllama-colorist-lora'
    }, {
        "name": "GPTQ_inference_with_lora_example",
        'model': 'TheBloke/TinyLlama-1.1B-Chat-v0.3-GPTQ',
        'quantization': "gptq",
        'lora_repo': 'jashing/tinyllama-colorist-lora'
    }]

    for test_config in test_configs:
        print(
            f"~~~~~~~~~~~~~~~~ Running: {test_config['name']} ~~~~~~~~~~~~~~~~"
        )
        engine = initialize_engine(test_config['model'],
                                   test_config['quantization'],
                                   test_config['lora_repo'])
        lora_path = snapshot_download(repo_id=test_config['lora_repo'])
        test_prompts = create_test_prompts(lora_path)
        process_requests(engine, test_prompts)

        # Clean up the GPU memory for the next test
        del engine
        gc.collect()
        torch.cuda.empty_cache()


if __name__ == '__main__':
    main()
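For a quicker smoke test than the full example above, the same bitsandbytes path can presumably also be driven through vLLM's high-level LLM API. The snippet below is a minimal sketch under that assumption (the model choice mirrors the example; the LoRA/QLoRA options are left out for brevity):

from vllm import LLM, SamplingParams

# Minimal sketch (not part of this commit): load a model with in-flight
# bitsandbytes quantization through the high-level LLM API. Both quantization
# and load_format must be "bitsandbytes", matching the consistency check added
# in vllm/engine/arg_utils.py further down.
llm = LLM(model="huggyllama/llama-7b",
          quantization="bitsandbytes",
          load_format="bitsandbytes",
          enforce_eager=True)

outputs = llm.generate(["My name is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)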

requirements-dev.txt

Lines changed: 3 additions & 0 deletions
@@ -37,3 +37,6 @@ aiohttp
 
 # Multimodal
 pillow
+
+# quantization
+bitsandbytes==0.42.0
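requirements-dev.txt pins bitsandbytes to 0.42.0. As a small sanity check of my own (treating the exact pin as load-bearing is an assumption), the installed version can be confirmed before running the new tests below:

from importlib.metadata import version

# The dev requirements pin bitsandbytes==0.42.0; warn early if the local
# environment has drifted from that pin.
installed = version("bitsandbytes")
if installed != "0.42.0":
    print(f"Warning: bitsandbytes {installed} installed; tests expect 0.42.0")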
tests/quantization/test_bitsandbytes.py

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
'''Tests whether bitsandbytes computation is enabled correctly.

Run `pytest tests/quantization/test_bitsandbytes.py`.
'''
import pytest
import torch

from vllm import SamplingParams
from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS

capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]


@pytest.mark.skipif(
    capability < QUANTIZATION_METHODS['bitsandbytes'].get_min_capability(),
    reason='bitsandbytes is not supported on this GPU type.')
def test_load_bnb_model(vllm_runner) -> None:
    llm = vllm_runner('huggyllama/llama-7b',
                      quantization='bitsandbytes',
                      load_format='bitsandbytes',
                      enforce_eager=True)

    model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model

    # Check that the weights in MLP & SelfAttention are quantized to torch.uint8
    qweight = model.model.layers[0].mlp.gate_up_proj.qweight
    assert qweight.dtype == torch.uint8, (
        f'Expected gate_up_proj dtype torch.uint8 but got {qweight.dtype}')

    qweight = model.model.layers[0].mlp.down_proj.qweight
    assert qweight.dtype == torch.uint8, (
        f'Expected down_proj dtype torch.uint8 but got {qweight.dtype}')

    qweight = model.model.layers[0].self_attn.o_proj.qweight
    assert qweight.dtype == torch.uint8, (
        f'Expected o_proj dtype torch.uint8 but got {qweight.dtype}')

    qweight = model.model.layers[0].self_attn.qkv_proj.qweight
    assert qweight.dtype == torch.uint8, (
        f'Expected qkv_proj dtype torch.uint8 but got {qweight.dtype}')

    # Some weights should not be quantized
    weight = model.lm_head.weight
    assert weight.dtype != torch.uint8, (
        'lm_head weight dtype should not be torch.uint8')

    weight = model.model.embed_tokens.weight
    assert weight.dtype != torch.uint8, (
        'embed_tokens weight dtype should not be torch.uint8')

    weight = model.model.layers[0].input_layernorm.weight
    assert weight.dtype != torch.uint8, (
        'input_layernorm weight dtype should not be torch.uint8')

    weight = model.model.layers[0].post_attention_layernorm.weight
    assert weight.dtype != torch.uint8, (
        'post_attention_layernorm weight dtype should not be torch.uint8')

    # Check that the output of the model is as expected
    sampling_params = SamplingParams(temperature=0.0,
                                     logprobs=1,
                                     prompt_logprobs=1,
                                     max_tokens=8)

    prompts = ['That which does not kill us', 'To be or not to be,']
    expected_outputs = [
        'That which does not kill us makes us stronger.',
        'To be or not to be, that is the question.'
    ]
    outputs = llm.generate(prompts, sampling_params=sampling_params)

    assert len(outputs) == len(prompts)

    for index in range(len(outputs)):
        # Compare the first line of the output
        actual_output = outputs[index][1][0].split('\n', 1)[0]
        expected_output = expected_outputs[index].split('\n', 1)[0]
        assert actual_output == expected_output, (
            f'Expected: {expected_output}, but got: {actual_output}')
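The torch.uint8 assertions above reflect how bitsandbytes stores 4-bit weights: two 4-bit values are packed into each byte of a uint8 tensor. A standalone illustration of that packing (my own sketch, not part of the commit; it assumes a CUDA GPU and the pinned bitsandbytes release):

import torch
import bitsandbytes.functional as bnb_func

# NF4-quantize a random half-precision matrix; the packed result is stored as
# torch.uint8, which is exactly what the test asserts for the qweight tensors.
weight = torch.randn(256, 256, dtype=torch.float16, device="cuda")
qweight, quant_state = bnb_func.quantize_4bit(weight, quant_type="nf4")
print(qweight.dtype)                    # torch.uint8
print(weight.numel(), qweight.numel())  # packed tensor holds about half the elements

# Dequantize to confirm the round-trip error stays small.
recovered = bnb_func.dequantize_4bit(qweight, quant_state, quant_type="nf4")
print((weight - recovered).abs().mean())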

vllm/config.py

Lines changed: 8 additions & 1 deletion
@@ -273,6 +273,12 @@ def verify_with_parallel_config(
                 "must be divisible by pipeline parallel size "
                 f"({pipeline_parallel_size}).")
 
+        if self.quantization == "bitsandbytes" and (
+                parallel_config.tensor_parallel_size > 1
+                or parallel_config.pipeline_parallel_size > 1):
+            raise ValueError(
+                "BitAndBytes quantization with TP or PP is not supported yet.")
+
     def get_hf_config_sliding_window(self) -> Optional[int]:
         """Get the sliding window size, or None if disabled.
         """
@@ -359,7 +365,7 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int:
     def get_num_attention_heads(self,
                                 parallel_config: "ParallelConfig") -> int:
         return self.hf_text_config.num_attention_heads // \
-            parallel_config.tensor_parallel_size
+                parallel_config.tensor_parallel_size
 
     def get_num_layers(self, parallel_config: "ParallelConfig") -> int:
         total_num_hidden_layers = self.hf_text_config.num_hidden_layers
@@ -519,6 +525,7 @@ class LoadFormat(str, enum.Enum):
     DUMMY = "dummy"
     TENSORIZER = "tensorizer"
     SHARDED_STATE = "sharded_state"
+    BITSANDBYTES = "bitsandbytes"
 
 
 @dataclass
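To see the new guard in verify_with_parallel_config trip, the sketch below (my own, not from the commit) builds an engine config that combines bitsandbytes with tensor parallelism; it assumes the validation runs while the config is assembled, before any workers start, and it needs access to the model's Hugging Face config:

from vllm import EngineArgs

# Bitsandbytes quantization together with TP (or PP) is rejected by the new
# check, so building the engine config is expected to raise a ValueError.
args = EngineArgs(model="huggyllama/llama-7b",
                  quantization="bitsandbytes",
                  load_format="bitsandbytes",
                  tensor_parallel_size=2)
try:
    args.create_engine_config()
except ValueError as err:
    print(err)  # "BitAndBytes quantization with TP or PP is not supported yet."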

vllm/engine/arg_utils.py

Lines changed: 35 additions & 3 deletions
@@ -96,6 +96,8 @@ class EngineArgs:
     ngram_prompt_lookup_max: Optional[int] = None
     ngram_prompt_lookup_min: Optional[int] = None
 
+    qlora_adapter_name_or_path: Optional[str] = None
+
     def __post_init__(self):
         if self.tokenizer is None:
             self.tokenizer = self.model
@@ -163,7 +165,8 @@ def add_cli_args(
             type=str,
             default=EngineArgs.load_format,
             choices=[
-                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer'
+                'auto', 'pt', 'safetensors', 'npcache', 'dummy', 'tensorizer',
+                'bitsandbytes'
             ],
             help='The format of the model weights to load.\n\n'
             '* "auto" will try to load the weights in the safetensors format '
@@ -177,7 +180,9 @@ def add_cli_args(
             'which is mainly for profiling.\n'
             '* "tensorizer" will load the weights using tensorizer from '
             'CoreWeave. See the Tensorize vLLM Model script in the Examples'
-            'section for more information.\n')
+            'section for more information.\n'
+            '* "bitsandbytes" will load the weights using bitsandbytes '
+            'quantization.\n')
         parser.add_argument(
             '--dtype',
             type=str,
@@ -558,7 +563,10 @@ add_cli_args(
             "will also be used in `model_name` tag content of "
             "prometheus metrics, if multiple names provided, metrics"
             "tag will take the first one.")
-
+        parser.add_argument('--qlora-adapter-name-or-path',
+                            type=str,
+                            default=None,
+                            help='Name or path of the QLoRA adapter.')
         return parser
 
     @classmethod
@@ -570,6 +578,23 @@ def from_cli_args(cls, args: argparse.Namespace):
         return engine_args
 
     def create_engine_config(self, ) -> EngineConfig:
+
+        # bitsandbytes quantization needs a specific model loader,
+        # so we make sure the quant method and the load format are consistent
+        if (self.quantization == "bitsandbytes" or
+           self.qlora_adapter_name_or_path is not None) and \
+           self.load_format != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes quantization and QLoRA adapter only support "
+                f"'bitsandbytes' load format, but got {self.load_format}")
+
+        if (self.load_format == "bitsandbytes" or
+           self.qlora_adapter_name_or_path is not None) and \
+           self.quantization != "bitsandbytes":
+            raise ValueError(
+                "BitsAndBytes load format and QLoRA adapter only support "
+                f"'bitsandbytes' quantization, but got {self.quantization}")
+
         device_config = DeviceConfig(self.device)
         model_config = ModelConfig(
             self.model, self.tokenizer, self.tokenizer_mode,
@@ -637,6 +662,13 @@ def create_engine_config(self, ) -> EngineConfig:
             max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras
             and self.max_cpu_loras > 0 else None) if self.enable_lora else None
 
+        if self.qlora_adapter_name_or_path is not None and \
+            self.qlora_adapter_name_or_path != "":
+            if self.model_loader_extra_config is None:
+                self.model_loader_extra_config = {}
+            self.model_loader_extra_config[
+                "qlora_adapter_name_or_path"] = self.qlora_adapter_name_or_path
+
         load_config = LoadConfig(
             load_format=self.load_format,
             download_dir=self.download_dir,
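The two consistency checks added at the top of create_engine_config() fail fast, before any weights are fetched. A minimal sketch of my own showing the first one trip when the load format is left at its default:

from vllm import EngineArgs

# Requesting bitsandbytes quantization while keeping load_format="auto" is now
# rejected immediately, steering the user toward --load-format bitsandbytes.
args = EngineArgs(model="huggyllama/llama-7b", quantization="bitsandbytes")
try:
    args.create_engine_config()
except ValueError as err:
    print(err)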

0 commit comments
