15 changes: 15 additions & 0 deletions docs/source/models/lora.rst
@@ -107,3 +107,18 @@ The following is an example request
"max_tokens": 7,
"temperature": 0
}' | jq


Alternatively, the request can specify a LoRA adapter to load dynamically from the server's local disk storage:

.. code-block:: bash

    curl http://localhost:8000/v1/completions \
        -H "Content-Type: application/json" \
        -d '{
            "model": "sql-lora",
            "prompt": "San Francisco is a",
            "max_tokens": 7,
            "temperature": 0,
            "lora_request": {"lora_name": "sql-lora", "lora_local_path": "/data/adapters/sql-lora"}
        }' | jq
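
The same request can be issued from Python. A minimal sketch using the official
``openai`` client, assuming its ``extra_body`` parameter is used to carry the
non-standard ``lora_request`` field (the adapter path is illustrative):

.. code-block:: python

    from openai import OpenAI

    client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

    completion = client.completions.create(
        model="sql-lora",
        prompt="San Francisco is a",
        max_tokens=7,
        temperature=0,
        # "lora_request" is not part of the OpenAI schema, so it is sent
        # via extra_body and forwarded verbatim to the server.
        extra_body={
            "lora_request": {
                "lora_name": "sql-lora",
                "lora_local_path": "/data/adapters/sql-lora",
            }
        },
    )
    print(completion.choices[0].text)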
15 changes: 14 additions & 1 deletion vllm/entrypoints/openai/protocol.py
Expand Up @@ -14,6 +14,7 @@
from vllm.lora.request import LoRARequest
from vllm.pooling_params import PoolingParams
from vllm.sampling_params import LogitsProcessor, SamplingParams
from vllm.utils import random_uuid

# torch is mocked during docs generation,
# so we have to provide the values as literals
@@ -218,6 +219,8 @@ class ChatCompletionRequest(OpenAIBaseModel):
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    lora_request: Optional[dict] = Field(
        default=None,
        description=(
            "If specified, a LoRA adapter to load dynamically from the "
            "server's local disk."),
    )

    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
@@ -232,6 +235,11 @@

# doc: end-chat-completion-extra-params

    def to_lora_params(self) -> Optional[LoRARequest]:
        if not self.lora_request:
            return None
        return LoRARequest(**self.lora_request)

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessor],
@@ -403,6 +411,7 @@ class CompletionRequest(OpenAIBaseModel):
        description=(
            "If specified, the output will follow the context free grammar."),
    )
    lora_request: Optional[dict] = Field(
        default=None,
        description=(
            "If specified, a LoRA adapter to load dynamically from the "
            "server's local disk."),
    )
    guided_decoding_backend: Optional[str] = Field(
        default=None,
        description=(
@@ -417,14 +426,18 @@

# doc: end-completion-extra-params

    def to_lora_params(self) -> Optional[LoRARequest]:
        if not self.lora_request:
            return None
        return LoRARequest(**self.lora_request)
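
For illustration, a sketch of the intended round trip, assuming the field and
method above (the payload values are hypothetical):

    request = CompletionRequest(
        model="sql-lora",
        prompt="San Francisco is a",
        lora_request={"lora_name": "sql-lora",
                      "lora_local_path": "/data/adapters/sql-lora"},
    )
    lora = request.to_lora_params()
    assert lora is not None and lora.lora_name == "sql-lora"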

    def to_sampling_params(
            self, tokenizer: PreTrainedTokenizer,
            guided_decode_logits_processor: Optional[LogitsProcessor],
            default_max_tokens: int) -> SamplingParams:
        max_tokens = self.max_tokens
        if max_tokens is None:
            max_tokens = default_max_tokens

        echo_without_generation = self.echo and self.max_tokens == 0

        logits_processors = get_logits_processors(
10 changes: 10 additions & 0 deletions vllm/entrypoints/openai/serving_engine.py
@@ -3,6 +3,7 @@
import os
from dataclasses import dataclass
from http import HTTPStatus
from typing import Iterable, Iterator, List, Optional, Tuple, TypedDict, Union

from pydantic import Field
from typing_extensions import Annotated
@@ -169,6 +170,8 @@ async def _check_model(
            return None
        if request.model in [lora.lora_name for lora in self.lora_requests]:
            return None
        lora_request = getattr(request, "lora_request", None)
        if lora_request and os.path.exists(
                lora_request.get("lora_local_path", "")):
            return None
        if request.model in [
                prompt_adapter.prompt_adapter_name
                for prompt_adapter in self.prompt_adapter_requests
@@ -188,6 +191,13 @@ def _maybe_get_adapters(
        for lora in self.lora_requests:
            if request.model == lora.lora_name:
                return lora, None
        lora_request = getattr(request, "lora_request", None)
        if lora_request and os.path.exists(
                lora_request.get("lora_local_path", "")):
            # Register the adapter so subsequent requests resolve by name.
            new_lora = LoRARequest(
                lora_name=request.model,
                lora_local_path=lora_request.get("lora_local_path"))
            self.lora_requests.append(new_lora)
            return new_lora, None
        for prompt_adapter in self.prompt_adapter_requests:
            if request.model == prompt_adapter.prompt_adapter_name:
                return None, prompt_adapter
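
The ``.get(..., "")`` guard above matters: ``os.path.exists(None)`` raises
``TypeError``, so a missing ``lora_local_path`` must fall back to a non-path
value. A hypothetical standalone helper capturing the same check
(``_resolve_dynamic_lora_path`` is illustrative, not part of this diff):

    import os
    from typing import Optional

    def _resolve_dynamic_lora_path(
            lora_request: Optional[dict]) -> Optional[str]:
        """Return a validated local adapter path, or None."""
        # A missing dict, missing key, or non-existent path all yield None.
        path = (lora_request or {}).get("lora_local_path", "")
        return path if path and os.path.exists(path) else None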
20 changes: 18 additions & 2 deletions vllm/lora/request.py
@@ -1,10 +1,19 @@
import hashlib
import warnings
from dataclasses import dataclass, field
from typing import Optional

from vllm.adapter_commons.request import AdapterRequest


def positive_hash_sha256(input_string: str) -> int:
    """Generate a positive, deterministic hash for the given string.

    Used to derive a stable ``lora_int_id`` from an adapter's name.
    SHA-256 keeps the value consistent across Python versions and
    external tooling (e.g. the sheets addon), unlike the built-in
    ``hash()``, which is salted per process.
    """
    return int(hashlib.sha256(input_string.encode("utf-8")).hexdigest(),
               16) % (2**63)
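
A quick determinism sketch, assuming the function above: the same name always
maps to the same id across runs and Python versions:

    a = positive_hash_sha256("sql-lora")
    b = positive_hash_sha256("sql-lora")
    assert a == b and 0 <= a < 2**63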


@dataclass
class LoRARequest(AdapterRequest):
"""
@@ -20,7 +29,7 @@ class LoRARequest(AdapterRequest):
"""

lora_name: str
lora_int_id: int
lora_int_id: Optional[int] = 0
lora_path: str = ""
lora_local_path: Optional[str] = field(default=None, repr=False)
long_lora_max_len: Optional[int] = None
@@ -37,6 +46,13 @@ def __post_init__(self):
        if not self.lora_path:
            self.lora_path = self.lora_local_path or ""

        # If no explicit int id was given, derive one from the name hash.
        if not self.lora_int_id:
            self.lora_int_id = positive_hash_sha256(self.lora_name)
        if self.lora_int_id < 1:
            raise ValueError(
                f"lora_int_id must be > 0, got {self.lora_int_id}")

        # Ensure lora_path is not empty
        assert self.lora_path, "lora_path cannot be empty"
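
With the defaults above, a sketch of the new call pattern (the path is
illustrative): callers may omit ``lora_int_id`` and receive a stable,
name-derived id:

    req = LoRARequest(lora_name="sql-lora",
                      lora_local_path="/data/adapters/sql-lora")
    assert req.lora_int_id > 0
    assert req.lora_int_id == positive_hash_sha256("sql-lora")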
