
Commit ccc16cc

Jeffwan authored and jeejeelee committed
[Core] Support load and unload LoRA in api server (vllm-project#6566)
Co-authored-by: Jee Jee Li <[email protected]>
Signed-off-by: Alvant <[email protected]>
1 parent b8aaf26 commit ccc16cc

10 files changed: +336 additions, −6 deletions

docs/requirements-docs.txt

Lines changed: 0 additions & 1 deletion
@@ -11,6 +11,5 @@ pydantic >= 2.8
 torch
 py-cpuinfo
 transformers
-openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args
 mistral_common >= 1.3.4
 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args

docs/source/models/lora.rst

Lines changed: 52 additions & 0 deletions
@@ -107,3 +107,55 @@ The following is an example request
         "max_tokens": 7,
         "temperature": 0
     }' | jq
+
+
+Dynamically serving LoRA Adapters
+---------------------------------
+
+In addition to serving LoRA adapters at server startup, the vLLM server now supports dynamically loading and unloading
+LoRA adapters at runtime through dedicated API endpoints. This feature can be particularly useful when the flexibility
+to change models on-the-fly is needed.
+
+Note: Enabling this feature in production environments is risky, since it allows users to add and remove model adapters.
+
+To enable dynamic LoRA loading and unloading, ensure that the environment variable `VLLM_ALLOW_RUNTIME_LORA_UPDATING`
+is set to `True`. When this option is enabled, the API server will log a warning to indicate that dynamic loading is active.
+
+.. code-block:: bash
+
+    export VLLM_ALLOW_RUNTIME_LORA_UPDATING=True
+
+
+Loading a LoRA Adapter:
+
+To dynamically load a LoRA adapter, send a POST request to the `/v1/load_lora_adapter` endpoint with the necessary
+details of the adapter to be loaded. The request payload should include the name and path to the LoRA adapter.
+
+Example request to load a LoRA adapter:
+
+.. code-block:: bash
+
+    curl -X POST http://localhost:8000/v1/load_lora_adapter \
+    -H "Content-Type: application/json" \
+    -d '{
+        "lora_name": "sql_adapter",
+        "lora_path": "/path/to/sql-lora-adapter"
+    }'
+
+Upon a successful request, the API will respond with a 200 OK status code. If an error occurs, such as if the adapter
+cannot be found or loaded, an appropriate error message will be returned.
+
+Unloading a LoRA Adapter:
+
+To unload a LoRA adapter that has been previously loaded, send a POST request to the `/v1/unload_lora_adapter` endpoint
+with the name or ID of the adapter to be unloaded.
+
+Example request to unload a LoRA adapter:
+
+.. code-block:: bash
+
+    curl -X POST http://localhost:8000/v1/unload_lora_adapter \
+    -H "Content-Type: application/json" \
+    -d '{
+        "lora_name": "sql_adapter"
+    }'
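
The documentation above drives the new endpoints with curl. As a purely illustrative alternative, the sketch below performs the same load/unload round trip from Python with the requests package; the base URL, adapter name, and adapter path are placeholder values, not part of this change.

    # Illustrative only: mirrors the curl examples from the docs above.
    # Assumes a vLLM OpenAI-compatible server on localhost:8000 started with
    # VLLM_ALLOW_RUNTIME_LORA_UPDATING=True; names and paths are placeholders.
    import requests

    BASE_URL = "http://localhost:8000"

    def load_lora_adapter(name: str, path: str) -> None:
        # 200 with a plain-text success message on success,
        # a JSON error body with a 4xx status code otherwise.
        resp = requests.post(f"{BASE_URL}/v1/load_lora_adapter",
                             json={"lora_name": name, "lora_path": path})
        print(resp.status_code, resp.text)

    def unload_lora_adapter(name: str) -> None:
        resp = requests.post(f"{BASE_URL}/v1/unload_lora_adapter",
                             json={"lora_name": name})
        print(resp.status_code, resp.text)

    if __name__ == "__main__":
        load_lora_adapter("sql_adapter", "/path/to/sql-lora-adapter")
        unload_lora_adapter("sql_adapter")

Once loaded, the adapter should be addressable by its name in completion requests, the same way as adapters configured at server startup.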

tests/entrypoints/llm/test_generate_multiple_loras.py

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ def zephyr_lora_files():
 @pytest.mark.skip_global_cleanup
 def test_multiple_lora_requests(llm: LLM, zephyr_lora_files):
     lora_request = [
-        LoRARequest(LORA_NAME, idx + 1, zephyr_lora_files)
+        LoRARequest(LORA_NAME + str(idx), idx + 1, zephyr_lora_files)
         for idx in range(len(PROMPTS))
     ]
     # Multiple SamplingParams should be matched with each prompt
Lines changed: 107 additions & 0 deletions
@@ -0,0 +1,107 @@
+from http import HTTPStatus
+from unittest.mock import MagicMock
+
+import pytest
+
+from vllm.config import ModelConfig
+from vllm.engine.protocol import AsyncEngineClient
+from vllm.entrypoints.openai.protocol import (ErrorResponse,
+                                              LoadLoraAdapterRequest,
+                                              UnloadLoraAdapterRequest)
+from vllm.entrypoints.openai.serving_engine import OpenAIServing
+
+MODEL_NAME = "meta-llama/Llama-2-7b"
+LORA_LOADING_SUCCESS_MESSAGE = (
+    "Success: LoRA adapter '{lora_name}' added successfully.")
+LORA_UNLOADING_SUCCESS_MESSAGE = (
+    "Success: LoRA adapter '{lora_name}' removed successfully.")
+
+
+async def _async_serving_engine_init():
+    mock_engine_client = MagicMock(spec=AsyncEngineClient)
+    mock_model_config = MagicMock(spec=ModelConfig)
+    # Set the max_model_len attribute to avoid missing attribute
+    mock_model_config.max_model_len = 2048
+
+    serving_engine = OpenAIServing(mock_engine_client,
+                                   mock_model_config,
+                                   served_model_names=[MODEL_NAME],
+                                   lora_modules=None,
+                                   prompt_adapters=None,
+                                   request_logger=None)
+    return serving_engine
+
+
+@pytest.mark.asyncio
+async def test_load_lora_adapter_success():
+    serving_engine = await _async_serving_engine_init()
+    request = LoadLoraAdapterRequest(lora_name="adapter",
+                                     lora_path="/path/to/adapter2")
+    response = await serving_engine.load_lora_adapter(request)
+    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter')
+    assert len(serving_engine.lora_requests) == 1
+    assert serving_engine.lora_requests[0].lora_name == "adapter"
+
+
+@pytest.mark.asyncio
+async def test_load_lora_adapter_missing_fields():
+    serving_engine = await _async_serving_engine_init()
+    request = LoadLoraAdapterRequest(lora_name="", lora_path="")
+    response = await serving_engine.load_lora_adapter(request)
+    assert isinstance(response, ErrorResponse)
+    assert response.type == "InvalidUserInput"
+    assert response.code == HTTPStatus.BAD_REQUEST
+
+
+@pytest.mark.asyncio
+async def test_load_lora_adapter_duplicate():
+    serving_engine = await _async_serving_engine_init()
+    request = LoadLoraAdapterRequest(lora_name="adapter1",
+                                     lora_path="/path/to/adapter1")
+    response = await serving_engine.load_lora_adapter(request)
+    assert response == LORA_LOADING_SUCCESS_MESSAGE.format(
+        lora_name='adapter1')
+    assert len(serving_engine.lora_requests) == 1
+
+    request = LoadLoraAdapterRequest(lora_name="adapter1",
+                                     lora_path="/path/to/adapter1")
+    response = await serving_engine.load_lora_adapter(request)
+    assert isinstance(response, ErrorResponse)
+    assert response.type == "InvalidUserInput"
+    assert response.code == HTTPStatus.BAD_REQUEST
+    assert len(serving_engine.lora_requests) == 1
+
+
+@pytest.mark.asyncio
+async def test_unload_lora_adapter_success():
+    serving_engine = await _async_serving_engine_init()
+    request = LoadLoraAdapterRequest(lora_name="adapter1",
+                                     lora_path="/path/to/adapter1")
+    response = await serving_engine.load_lora_adapter(request)
+    assert len(serving_engine.lora_requests) == 1
+
+    request = UnloadLoraAdapterRequest(lora_name="adapter1")
+    response = await serving_engine.unload_lora_adapter(request)
+    assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format(
+        lora_name='adapter1')
+    assert len(serving_engine.lora_requests) == 0
+
+
+@pytest.mark.asyncio
+async def test_unload_lora_adapter_missing_fields():
+    serving_engine = await _async_serving_engine_init()
+    request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None)
+    response = await serving_engine.unload_lora_adapter(request)
+    assert isinstance(response, ErrorResponse)
+    assert response.type == "InvalidUserInput"
+    assert response.code == HTTPStatus.BAD_REQUEST
+
+
+@pytest.mark.asyncio
+async def test_unload_lora_adapter_not_found():
+    serving_engine = await _async_serving_engine_init()
+    request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter")
+    response = await serving_engine.unload_lora_adapter(request)
+    assert isinstance(response, ErrorResponse)
+    assert response.type == "InvalidUserInput"
+    assert response.code == HTTPStatus.BAD_REQUEST

vllm/entrypoints/openai/api_server.py

Lines changed: 38 additions & 2 deletions
@@ -35,11 +35,13 @@
                                               DetokenizeResponse,
                                               EmbeddingRequest,
                                               EmbeddingResponse, ErrorResponse,
+                                              LoadLoraAdapterRequest,
                                               TokenizeRequest,
-                                              TokenizeResponse)
-# yapf: enable
+                                              TokenizeResponse,
+                                              UnloadLoraAdapterRequest)
 from vllm.entrypoints.openai.rpc.client import AsyncEngineRPCClient
 from vllm.entrypoints.openai.rpc.server import run_rpc_server
+# yapf: enable
 from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
 from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
 from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -343,6 +345,40 @@ async def stop_profile():
     return Response(status_code=200)
 
 
+if envs.VLLM_ALLOW_RUNTIME_LORA_UPDATING:
+    logger.warning(
+        "Lora dynamic loading & unloading is enabled in the API server. "
+        "This should ONLY be used for local development!")
+
+    @router.post("/v1/load_lora_adapter")
+    async def load_lora_adapter(request: LoadLoraAdapterRequest):
+        response = await openai_serving_chat.load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        response = await openai_serving_completion.load_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        return Response(status_code=200, content=response)
+
+    @router.post("/v1/unload_lora_adapter")
+    async def unload_lora_adapter(request: UnloadLoraAdapterRequest):
+        response = await openai_serving_chat.unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        response = await openai_serving_completion.unload_lora_adapter(request)
+        if isinstance(response, ErrorResponse):
+            return JSONResponse(content=response.model_dump(),
+                                status_code=response.code)
+
+        return Response(status_code=200, content=response)
+
+
 def build_app(args: Namespace) -> FastAPI:
     app = FastAPI(lifespan=lifespan)
     app.include_router(router)
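
For a sense of how the two response branches above look from the client side, here is a small, purely illustrative sketch (placeholder URL and adapter values): a successful call returns HTTP 200 with the plain-text success string, while a rejected call, such as loading the same adapter name twice, returns the serialized ErrorResponse as JSON with a 400 status.

    # Illustrative only; assumes a local server started with
    # VLLM_ALLOW_RUNTIME_LORA_UPDATING=True. Values are placeholders.
    import requests

    payload = {"lora_name": "sql_adapter",
               "lora_path": "/path/to/sql-lora-adapter"}

    first = requests.post("http://localhost:8000/v1/load_lora_adapter",
                          json=payload)
    print(first.status_code)   # 200
    print(first.text)          # Success: LoRA adapter 'sql_adapter' added successfully.

    # Loading the same adapter name again trips the duplicate check in the
    # serving engine and comes back as a JSON ErrorResponse.
    second = requests.post("http://localhost:8000/v1/load_lora_adapter",
                           json=payload)
    print(second.status_code)  # 400
    print(second.json())       # {'message': "... already been loaded.", 'type': 'InvalidUserInput', ...}

If VLLM_ALLOW_RUNTIME_LORA_UPDATING is not set, the routes are never registered, so the same POST falls through to FastAPI's default 404 handling.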

vllm/entrypoints/openai/protocol.py

Lines changed: 10 additions & 0 deletions
@@ -878,3 +878,13 @@ class DetokenizeRequest(OpenAIBaseModel):
 
 class DetokenizeResponse(OpenAIBaseModel):
     prompt: str
+
+
+class LoadLoraAdapterRequest(BaseModel):
+    lora_name: str
+    lora_path: str
+
+
+class UnloadLoraAdapterRequest(BaseModel):
+    lora_name: str
+    lora_int_id: Optional[int] = Field(default=None)
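
The two request models above are plain Pydantic models, so their JSON shapes match the curl payloads in the documentation. A minimal standalone sketch of the same shapes (outside vLLM, with placeholder values) behaves like this under Pydantic v2:

    # Standalone sketch of the payload shapes; values are placeholders.
    from typing import Optional
    from pydantic import BaseModel, Field, ValidationError

    class LoadLoraAdapterRequest(BaseModel):
        lora_name: str
        lora_path: str

    class UnloadLoraAdapterRequest(BaseModel):
        lora_name: str
        lora_int_id: Optional[int] = Field(default=None)

    print(LoadLoraAdapterRequest(lora_name="sql_adapter",
                                 lora_path="/path/to/sql-lora-adapter").model_dump())
    # {'lora_name': 'sql_adapter', 'lora_path': '/path/to/sql-lora-adapter'}

    print(UnloadLoraAdapterRequest(lora_name="sql_adapter").model_dump())
    # {'lora_name': 'sql_adapter', 'lora_int_id': None}

    try:
        LoadLoraAdapterRequest.model_validate({"lora_name": "sql_adapter"})
    except ValidationError as exc:
        # 'lora_path' is required, so this payload is rejected before it ever
        # reaches the serving engine's own empty-string check.
        print(exc.error_count(), "validation error")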

vllm/entrypoints/openai/serving_engine.py

Lines changed: 78 additions & 1 deletion
@@ -16,11 +16,13 @@
                                               CompletionRequest,
                                               DetokenizeRequest,
                                               EmbeddingRequest, ErrorResponse,
+                                              LoadLoraAdapterRequest,
                                               ModelCard, ModelList,
                                               ModelPermission,
                                               TokenizeChatRequest,
                                               TokenizeCompletionRequest,
-                                              TokenizeRequest)
+                                              TokenizeRequest,
+                                              UnloadLoraAdapterRequest)
 # yapf: enable
 from vllm.inputs.parse import parse_and_batch_prompt
 from vllm.logger import init_logger
@@ -32,6 +34,7 @@
 from vllm.sampling_params import LogitsProcessor, SamplingParams
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.utils import AtomicCounter
 
 logger = init_logger(__name__)
 
@@ -78,6 +81,7 @@ def __init__(
 
         self.served_model_names = served_model_names
 
+        self.lora_id_counter = AtomicCounter(0)
         self.lora_requests = []
         if lora_modules is not None:
             self.lora_requests = [
@@ -403,3 +407,76 @@ def _get_decoded_token(logprob: Logprob,
         if logprob.decoded_token is not None:
             return logprob.decoded_token
         return tokenizer.decode(token_id)
+
+    async def _check_load_lora_adapter_request(
+            self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if both 'lora_name' and 'lora_path' are provided
+        if not request.lora_name or not request.lora_path:
+            return self.create_error_response(
+                message="Both 'lora_name' and 'lora_path' must be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name already exists
+        if any(lora_request.lora_name == request.lora_name
+               for lora_request in self.lora_requests):
+            return self.create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' has already been "
+                "loaded.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+    async def _check_unload_lora_adapter_request(
+            self,
+            request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]:
+        # Check if either 'lora_name' or 'lora_int_id' is provided
+        if not request.lora_name and not request.lora_int_id:
+            return self.create_error_response(
+                message=
+                "Either 'lora_name' or 'lora_int_id' needs to be provided.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        # Check if the lora adapter with the given name exists
+        if not any(lora_request.lora_name == request.lora_name
+                   for lora_request in self.lora_requests):
+            return self.create_error_response(
+                message=
+                f"The lora adapter '{request.lora_name}' cannot be found.",
+                err_type="InvalidUserInput",
+                status_code=HTTPStatus.BAD_REQUEST)
+
+        return None
+
+    async def load_lora_adapter(
+            self,
+            request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_load_lora_adapter_request(request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name, lora_path = request.lora_name, request.lora_path
+        unique_id = self.lora_id_counter.inc(1)
+        self.lora_requests.append(
+            LoRARequest(lora_name=lora_name,
+                        lora_int_id=unique_id,
+                        lora_path=lora_path))
+        return f"Success: LoRA adapter '{lora_name}' added successfully."
+
+    async def unload_lora_adapter(
+            self,
+            request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]:
+        error_check_ret = await self._check_unload_lora_adapter_request(
+            request)
+        if error_check_ret is not None:
+            return error_check_ret
+
+        lora_name = request.lora_name
+        self.lora_requests = [
+            lora_request for lora_request in self.lora_requests
+            if lora_request.lora_name != lora_name
+        ]
+        return f"Success: LoRA adapter '{lora_name}' removed successfully."

vllm/envs.py

Lines changed: 7 additions & 0 deletions
@@ -61,6 +61,7 @@
     VLLM_ALLOW_ENGINE_USE_RAY: bool = False
     VLLM_PLUGINS: Optional[List[str]] = None
     VLLM_TORCH_PROFILER_DIR: Optional[str] = None
+    VLLM_ALLOW_RUNTIME_LORA_UPDATING: bool = False
 
 
 def get_default_cache_root():
@@ -409,6 +410,12 @@ def get_default_config_root():
     # If set, vLLM will use Triton implementations of AWQ.
     "VLLM_USE_TRITON_AWQ":
     lambda: bool(int(os.getenv("VLLM_USE_TRITON_AWQ", "0"))),
+
+    # If set, allow loading or unloading LoRA adapters at runtime.
+    "VLLM_ALLOW_RUNTIME_LORA_UPDATING":
+    lambda:
+    (os.environ.get("VLLM_ALLOW_RUNTIME_LORA_UPDATING", "0").strip().lower() in
+     ("1", "true")),
 }
 
 # end-env-vars-definition
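
The value parsing for the new flag is stricter than a plain bool(int(...)): only the strings "1" and "true" (case-insensitive, after stripping whitespace) enable it. A tiny standalone check of that rule, mirroring the lambda above rather than importing vLLM:

    # Mirrors the truthiness rule of the lambda above; not vLLM code.
    def allow_runtime_lora_updating(raw: str) -> bool:
        return raw.strip().lower() in ("1", "true")

    for value in ("1", "true", " True ", "0", "false", "yes", ""):
        print(repr(value), "->", allow_runtime_lora_updating(value))
    # Only "1", "true", and " True " enable the feature; "yes" and "" do not.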
