-
-
Notifications
You must be signed in to change notification settings - Fork 11.1k
[Core] Support load and unload LoRA in api server #6566
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
Changes from all commits
bfcdb4e
c90dbd8
0f9bd90
9b6bb28
2d77a0b
72aa3fb
bfcb68d
40d64d9
7ae737a
54272e3
729b620
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,107 @@ | ||
| from http import HTTPStatus | ||
| from unittest.mock import MagicMock | ||
|
|
||
| import pytest | ||
|
|
||
| from vllm.config import ModelConfig | ||
| from vllm.engine.protocol import AsyncEngineClient | ||
| from vllm.entrypoints.openai.protocol import (ErrorResponse, | ||
| LoadLoraAdapterRequest, | ||
| UnloadLoraAdapterRequest) | ||
| from vllm.entrypoints.openai.serving_engine import OpenAIServing | ||
|
|
||
| MODEL_NAME = "meta-llama/Llama-2-7b" | ||
| LORA_LOADING_SUCCESS_MESSAGE = ( | ||
| "Success: LoRA adapter '{lora_name}' added successfully.") | ||
| LORA_UNLOADING_SUCCESS_MESSAGE = ( | ||
| "Success: LoRA adapter '{lora_name}' removed successfully.") | ||
|
|
||
|
|
||
| async def _async_serving_engine_init(): | ||
| mock_engine_client = MagicMock(spec=AsyncEngineClient) | ||
| mock_model_config = MagicMock(spec=ModelConfig) | ||
| # Set the max_model_len attribute to avoid missing attribute | ||
| mock_model_config.max_model_len = 2048 | ||
|
|
||
| serving_engine = OpenAIServing(mock_engine_client, | ||
| mock_model_config, | ||
| served_model_names=[MODEL_NAME], | ||
| lora_modules=None, | ||
| prompt_adapters=None, | ||
| request_logger=None) | ||
| return serving_engine | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_load_lora_adapter_success(): | ||
| serving_engine = await _async_serving_engine_init() | ||
| request = LoadLoraAdapterRequest(lora_name="adapter", | ||
| lora_path="/path/to/adapter2") | ||
| response = await serving_engine.load_lora_adapter(request) | ||
| assert response == LORA_LOADING_SUCCESS_MESSAGE.format(lora_name='adapter') | ||
| assert len(serving_engine.lora_requests) == 1 | ||
| assert serving_engine.lora_requests[0].lora_name == "adapter" | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_load_lora_adapter_missing_fields(): | ||
| serving_engine = await _async_serving_engine_init() | ||
| request = LoadLoraAdapterRequest(lora_name="", lora_path="") | ||
| response = await serving_engine.load_lora_adapter(request) | ||
| assert isinstance(response, ErrorResponse) | ||
| assert response.type == "InvalidUserInput" | ||
| assert response.code == HTTPStatus.BAD_REQUEST | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_load_lora_adapter_duplicate(): | ||
| serving_engine = await _async_serving_engine_init() | ||
| request = LoadLoraAdapterRequest(lora_name="adapter1", | ||
| lora_path="/path/to/adapter1") | ||
| response = await serving_engine.load_lora_adapter(request) | ||
| assert response == LORA_LOADING_SUCCESS_MESSAGE.format( | ||
| lora_name='adapter1') | ||
| assert len(serving_engine.lora_requests) == 1 | ||
|
|
||
| request = LoadLoraAdapterRequest(lora_name="adapter1", | ||
| lora_path="/path/to/adapter1") | ||
| response = await serving_engine.load_lora_adapter(request) | ||
| assert isinstance(response, ErrorResponse) | ||
| assert response.type == "InvalidUserInput" | ||
| assert response.code == HTTPStatus.BAD_REQUEST | ||
| assert len(serving_engine.lora_requests) == 1 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_unload_lora_adapter_success(): | ||
| serving_engine = await _async_serving_engine_init() | ||
| request = LoadLoraAdapterRequest(lora_name="adapter1", | ||
| lora_path="/path/to/adapter1") | ||
| response = await serving_engine.load_lora_adapter(request) | ||
| assert len(serving_engine.lora_requests) == 1 | ||
|
|
||
| request = UnloadLoraAdapterRequest(lora_name="adapter1") | ||
| response = await serving_engine.unload_lora_adapter(request) | ||
| assert response == LORA_UNLOADING_SUCCESS_MESSAGE.format( | ||
| lora_name='adapter1') | ||
| assert len(serving_engine.lora_requests) == 0 | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_unload_lora_adapter_missing_fields(): | ||
| serving_engine = await _async_serving_engine_init() | ||
| request = UnloadLoraAdapterRequest(lora_name="", lora_int_id=None) | ||
| response = await serving_engine.unload_lora_adapter(request) | ||
| assert isinstance(response, ErrorResponse) | ||
| assert response.type == "InvalidUserInput" | ||
| assert response.code == HTTPStatus.BAD_REQUEST | ||
|
|
||
|
|
||
| @pytest.mark.asyncio | ||
| async def test_unload_lora_adapter_not_found(): | ||
| serving_engine = await _async_serving_engine_init() | ||
| request = UnloadLoraAdapterRequest(lora_name="nonexistent_adapter") | ||
| response = await serving_engine.unload_lora_adapter(request) | ||
| assert isinstance(response, ErrorResponse) | ||
| assert response.type == "InvalidUserInput" | ||
| assert response.code == HTTPStatus.BAD_REQUEST |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,11 +16,13 @@ | |
| CompletionRequest, | ||
| DetokenizeRequest, | ||
| EmbeddingRequest, ErrorResponse, | ||
| LoadLoraAdapterRequest, | ||
| ModelCard, ModelList, | ||
| ModelPermission, | ||
| TokenizeChatRequest, | ||
| TokenizeCompletionRequest, | ||
| TokenizeRequest) | ||
| TokenizeRequest, | ||
| UnloadLoraAdapterRequest) | ||
| # yapf: enable | ||
| from vllm.inputs.parse import parse_and_batch_prompt | ||
| from vllm.logger import init_logger | ||
|
|
@@ -32,6 +34,7 @@ | |
| from vllm.sampling_params import LogitsProcessor, SamplingParams | ||
| from vllm.sequence import Logprob | ||
| from vllm.transformers_utils.tokenizer import AnyTokenizer | ||
| from vllm.utils import AtomicCounter | ||
|
|
||
| logger = init_logger(__name__) | ||
|
|
||
|
|
@@ -78,6 +81,7 @@ def __init__( | |
|
|
||
| self.served_model_names = served_model_names | ||
|
|
||
| self.lora_id_counter = AtomicCounter(0) | ||
|
Review comment (@njhill): @Jeffwan this is not required here. asyncio operations all happen in the same thread. Can change this to be a simple int field. Reply (@Jeffwan): @njhill yeah, that makes sense. Let me file a follow-up PR to improve it. |
||
| self.lora_requests = [] | ||
| if lora_modules is not None: | ||
| self.lora_requests = [ | ||
|
|
@@ -403,3 +407,76 @@ def _get_decoded_token(logprob: Logprob, | |
| if logprob.decoded_token is not None: | ||
| return logprob.decoded_token | ||
| return tokenizer.decode(token_id) | ||
|
|
||
| async def _check_load_lora_adapter_request( | ||
| self, request: LoadLoraAdapterRequest) -> Optional[ErrorResponse]: | ||
| # Check if both 'lora_name' and 'lora_path' are provided | ||
| if not request.lora_name or not request.lora_path: | ||
| return self.create_error_response( | ||
| message="Both 'lora_name' and 'lora_path' must be provided.", | ||
| err_type="InvalidUserInput", | ||
| status_code=HTTPStatus.BAD_REQUEST) | ||
|
|
||
| # Check if the lora adapter with the given name already exists | ||
| if any(lora_request.lora_name == request.lora_name | ||
| for lora_request in self.lora_requests): | ||
| return self.create_error_response( | ||
| message= | ||
| f"The lora adapter '{request.lora_name}' has already been" | ||
| "loaded.", | ||
| err_type="InvalidUserInput", | ||
| status_code=HTTPStatus.BAD_REQUEST) | ||
|
|
||
| return None | ||
|
|
||
| async def _check_unload_lora_adapter_request( | ||
| self, | ||
| request: UnloadLoraAdapterRequest) -> Optional[ErrorResponse]: | ||
| # Check if either 'lora_name' or 'lora_int_id' is provided | ||
| if not request.lora_name and not request.lora_int_id: | ||
| return self.create_error_response( | ||
| message= | ||
| "either 'lora_name' and 'lora_int_id' needs to be provided.", | ||
| err_type="InvalidUserInput", | ||
| status_code=HTTPStatus.BAD_REQUEST) | ||
|
|
||
| # Check if the lora adapter with the given name exists | ||
| if not any(lora_request.lora_name == request.lora_name | ||
| for lora_request in self.lora_requests): | ||
| return self.create_error_response( | ||
| message= | ||
| f"The lora adapter '{request.lora_name}' cannot be found.", | ||
| err_type="InvalidUserInput", | ||
| status_code=HTTPStatus.BAD_REQUEST) | ||
|
|
||
| return None | ||
|
|
||
| async def load_lora_adapter( | ||
| self, | ||
| request: LoadLoraAdapterRequest) -> Union[ErrorResponse, str]: | ||
| error_check_ret = await self._check_load_lora_adapter_request(request) | ||
| if error_check_ret is not None: | ||
| return error_check_ret | ||
|
|
||
| lora_name, lora_path = request.lora_name, request.lora_path | ||
| unique_id = self.lora_id_counter.inc(1) | ||
| self.lora_requests.append( | ||
| LoRARequest(lora_name=lora_name, | ||
| lora_int_id=unique_id, | ||
| lora_path=lora_path)) | ||
| return f"Success: LoRA adapter '{lora_name}' added successfully." | ||
|
|
||
| async def unload_lora_adapter( | ||
| self, | ||
| request: UnloadLoraAdapterRequest) -> Union[ErrorResponse, str]: | ||
| error_check_ret = await self._check_unload_lora_adapter_request(request | ||
| ) | ||
| if error_check_ret is not None: | ||
| return error_check_ret | ||
|
|
||
| lora_name = request.lora_name | ||
| self.lora_requests = [ | ||
| lora_request for lora_request in self.lora_requests | ||
| if lora_request.lora_name != lora_name | ||
| ] | ||
| return f"Success: LoRA adapter '{lora_name}' removed successfully." | ||
Uh oh!
There was an error while loading. Please reload this page.