Skip to content

Commit 8a8b30e

Browse files
Varun Sundar Rabindranath (varun-sundar-rabindranath)
authored and committed
[Bugfix] LoRA V0 - Fix case where max_num_seqs is between cudagraph capture sizes (#15308)
Signed-off-by: Varun Sundar Rabindranath <[email protected]> Co-authored-by: Varun Sundar Rabindranath <[email protected]>
1 parent 2fa0e13 commit 8a8b30e

File tree

2 files changed

+17
-7
lines changed

2 files changed

+17
-7
lines changed

tests/lora/test_llama_tp.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -84,12 +84,14 @@ def v1(run_with_both_engines_lora):
8484
@create_new_process_for_each_test()
8585
def test_llama_lora(sql_lora_files):
8686

87-
llm = vllm.LLM(MODEL_PATH,
88-
enable_lora=True,
89-
max_num_seqs=16,
90-
max_loras=4,
91-
tensor_parallel_size=1,
92-
enable_chunked_prefill=True)
87+
llm = vllm.LLM(
88+
MODEL_PATH,
89+
enable_lora=True,
90+
# also test odd max_num_seqs
91+
max_num_seqs=13,
92+
max_loras=4,
93+
tensor_parallel_size=1,
94+
enable_chunked_prefill=True)
9395
generate_and_test(llm, sql_lora_files)
9496

9597

vllm/lora/punica_wrapper/punica_gpu.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
import torch
1212

13+
import vllm.envs as envs
1314
from vllm.lora.layers import LoRAMapping
1415
from vllm.triton_utils import HAS_TRITON
1516

@@ -42,8 +43,15 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int,
4243
self.token_mapping_meta = LoRAKernelMeta.make(self.max_loras,
4344
max_num_batched_tokens,
4445
device=device)
46+
47+
# When cudagraph capture size is greater than max_num_seqs (max_batches,
48+
# here), V0 captures the graph as if max_num_seqs is set to
49+
# the capture size.
50+
# V1 doesn't have this problem and always respects max_num_seqs.
51+
max_num_prompts = (max_batches
52+
if envs.VLLM_USE_V1 else max_num_batched_tokens)
4553
self.prompt_mapping_meta = LoRAKernelMeta.make(self.max_loras,
46-
max_batches,
54+
max_num_prompts,
4755
device=device)
4856

4957
def update_metadata(

0 commit comments

Comments (0)