Commit ae18da2: lora + chunked prefill

1 parent: fc6c274

12 files changed, +57 -24 lines

tests/lora/test_chatglm3.py

Lines changed: 6 additions & 2 deletions
@@ -1,5 +1,7 @@
 from typing import List
 
+import pytest
+
 import vllm
 from vllm.lora.request import LoRARequest
 
@@ -37,13 +39,15 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
     return generated_texts
 
 
-def test_chatglm3_lora(chatglm3_lora_files):
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
+def test_chatglm3_lora(chatglm3_lora_files, enable_chunked_prefill):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
                    enable_lora=True,
                    max_loras=4,
                    max_lora_rank=64,
-                   trust_remote_code=True)
+                   trust_remote_code=True,
+                   enable_chunked_prefill=enable_chunked_prefill)
 
     expected_lora_output = [
         "SELECT count(*) FROM singer",

tests/lora/test_gemma.py

Lines changed: 4 additions & 2 deletions
@@ -32,11 +32,13 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
 
 
 @pytest.mark.xfail(is_hip(), reason="There can be output mismatch on ROCm")
-def test_gemma_lora(gemma_lora_files):
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
+def test_gemma_lora(gemma_lora_files, enable_chunked_prefill):
     llm = vllm.LLM(MODEL_PATH,
                    max_model_len=1024,
                    enable_lora=True,
-                   max_loras=4)
+                   max_loras=4,
+                   enable_chunked_prefill=enable_chunked_prefill)
 
     expected_lora_output = [
         "more important than knowledge.\nAuthor: Albert Einstein\n",

tests/lora/test_llama.py

Lines changed: 18 additions & 7 deletions
@@ -38,15 +38,18 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]:
 
 
 @pytest.mark.parametrize("tp_size", [1, 2, 4])
-def test_llama_lora(sql_lora_files, tp_size, num_gpus_available):
+@pytest.mark.parametrize("enable_chunked_prefill", [False, True])
+def test_llama_lora(sql_lora_files, tp_size, enable_chunked_prefill,
+                    num_gpus_available):
     if num_gpus_available < tp_size:
         pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}")
 
     llm = vllm.LLM(MODEL_PATH,
                    enable_lora=True,
                    max_num_seqs=16,
                    max_loras=4,
-                   tensor_parallel_size=tp_size)
+                   tensor_parallel_size=tp_size,
+                   enable_chunked_prefill=enable_chunked_prefill)
 
     expected_no_lora_output = [
         "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501
@@ -88,7 +91,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
                        enable_lora=True,
                        max_num_seqs=16,
                        max_loras=4,
-                       tensor_parallel_size=1)
+                       tensor_parallel_size=1,
+                       enable_chunked_prefill=True)
     output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1)
 
     del llm_tp1
@@ -98,7 +102,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
                        enable_lora=True,
                        max_num_seqs=16,
                        max_loras=4,
-                       tensor_parallel_size=2)
+                       tensor_parallel_size=2,
+                       enable_chunked_prefill=True)
     output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1)
 
     del llm_tp2
@@ -110,7 +115,8 @@ def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available):
                        enable_lora=True,
                        max_num_seqs=16,
                        max_loras=4,
-                       tensor_parallel_size=4)
+                       tensor_parallel_size=4,
+                       enable_chunked_prefill=True)
     output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1)
 
     del llm_tp4
@@ -125,13 +131,18 @@ def test_llama_lora_warmup(sql_lora_files):
 
     @ray.remote(num_gpus=1)
     def get_num_gpu_blocks_lora():
-        llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16)
+        llm = vllm.LLM(MODEL_PATH,
+                       enable_lora=True,
+                       max_num_seqs=16,
+                       enable_chunked_prefill=True)
         num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks
         return num_gpu_blocks_lora_warmup
 
     @ray.remote(num_gpus=1)
     def get_num_gpu_blocks_no_lora():
-        llm = vllm.LLM(MODEL_PATH, max_num_seqs=16)
+        llm = vllm.LLM(MODEL_PATH,
+                       max_num_seqs=16,
+                       enable_chunked_prefill=True)
         num_gpu_blocks_no_lora_warmup = (
             llm.llm_engine.cache_config.num_gpu_blocks)
         return num_gpu_blocks_no_lora_warmup

tests/lora/test_long_context.py

Lines changed: 2 additions & 1 deletion
@@ -124,7 +124,8 @@ def lora_llm(long_context_infos):
         tensor_parallel_size=4,
         # FIXME enable async output processor
         disable_async_output_proc=True,
-        distributed_executor_backend="mp")
+        distributed_executor_backend="mp",
+        enable_chunked_prefill=True)
     yield llm
     del llm
 
tests/lora/test_minicpmv.py

Lines changed: 2 additions & 1 deletion
@@ -61,7 +61,8 @@ def test_minicpmv_lora(minicpmv_lora_files):
         max_loras=4,
         max_lora_rank=64,
         trust_remote_code=True,
-        gpu_memory_utilization=0.97 # This model is pretty big for CI gpus
+        gpu_memory_utilization=0.97, # This model is pretty big for CI gpus
+        enable_chunked_prefill=True,
     )
 
     output1 = do_sample(llm, minicpmv_lora_files, lora_id=1)

tests/lora/test_minicpmv_tp.py

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,7 @@ def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded):
         tensor_parallel_size=2,
         trust_remote_code=True,
         fully_sharded_loras=fully_sharded,
+        enable_chunked_prefill=True,
     )
 
     output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
@@ -89,6 +90,7 @@ def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded):
         tensor_parallel_size=4,
         trust_remote_code=True,
         fully_sharded_loras=fully_sharded,
+        enable_chunked_prefill=True,
     )
     output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1)
     for i in range(len(EXPECTED_OUTPUT)):

tests/lora/test_mixtral.py

Lines changed: 1 addition & 0 deletions
@@ -47,6 +47,7 @@ def test_mixtral_lora(mixtral_lora_files, tp_size):
         max_loras=4,
         distributed_executor_backend="ray",
         tensor_parallel_size=tp_size,
+        enable_chunked_prefill=True,
     )
 
     expected_lora_output = [

tests/lora/test_phi.py

Lines changed: 2 additions & 1 deletion
@@ -53,7 +53,8 @@ def test_phi2_lora(phi2_lora_files):
                    max_model_len=1024,
                    enable_lora=True,
                    max_loras=2,
-                   enforce_eager=True)
+                   enforce_eager=True,
+                   enable_chunked_prefill=True)
 
     expected_lora_output = [
         "SELECT catalog_publisher, COUNT(*) as num_catalogs FROM catalogs GROUP BY catalog_publisher ORDER BY num_catalogs DESC LIMIT 1;", # noqa: E501

tests/lora/test_quant_model.py

Lines changed: 6 additions & 3 deletions
@@ -84,7 +84,8 @@ def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model,
         tensor_parallel_size=tp_size,
         gpu_memory_utilization=0.2, #avoid OOM
         quantization=model.quantization,
-        trust_remote_code=True)
+        trust_remote_code=True,
+        enable_chunked_prefill=True)
 
     if model.quantization is None:
         expected_no_lora_output = [
@@ -176,7 +177,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
         tensor_parallel_size=1,
         gpu_memory_utilization=0.2, #avoid OOM
         quantization=model.quantization,
-        trust_remote_code=True)
+        trust_remote_code=True,
+        enable_chunked_prefill=True)
     output_tp1 = do_sample(llm_tp1, tinyllama_lora_files, lora_id=1)
 
     del llm_tp1
@@ -189,7 +191,8 @@ def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available,
         max_loras=4,
         tensor_parallel_size=2,
         gpu_memory_utilization=0.2, #avoid OOM
-        quantization=model.quantization)
+        quantization=model.quantization,
+        enable_chunked_prefill=True)
     output_tp2 = do_sample(llm_tp2, tinyllama_lora_files, lora_id=1)
 
     del llm_tp2

vllm/config.py

Lines changed: 2 additions & 1 deletion
@@ -1603,7 +1603,8 @@ def verify_with_scheduler_config(self, scheduler_config: SchedulerConfig):
         # Reminder: Please update docs/source/serving/compatibility_matrix.rst
         # If the feature combo become valid
         if scheduler_config.chunked_prefill_enabled:
-            raise ValueError("LoRA is not supported with chunked prefill yet.")
+            logger.warning("LoRA with chunked prefill is "
+                           "experimental and may be unstable.")
 
 
 @dataclass
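
With this change, enabling LoRA no longer raises when chunked prefill is on; the combination is only flagged as experimental via a warning. A minimal usage sketch of the now-permitted combination, assuming a placeholder base model and LoRA adapter path (the names below are illustrative and not taken from the diff):

import vllm
from vllm.lora.request import LoRARequest

MODEL_PATH = "meta-llama/Llama-2-7b-hf"   # assumed base model (placeholder)
LORA_PATH = "/path/to/sql_lora_adapter"   # assumed adapter directory (placeholder)

# Both flags can now be set together; vLLM emits the "experimental" warning
# above instead of raising a ValueError.
llm = vllm.LLM(MODEL_PATH,
               enable_lora=True,
               max_loras=4,
               max_lora_rank=64,
               enable_chunked_prefill=True)

outputs = llm.generate(
    ["Write a SQL query that counts all singers."],
    lora_request=LoRARequest("sql_lora", 1, LORA_PATH))
print(outputs[0].outputs[0].text)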
