5 changes: 3 additions & 2 deletions vllm/entrypoints/openai/api_server.py

@@ -116,7 +116,7 @@ async def detokenize(request: DetokenizeRequest):
 
 @app.get("/v1/models")
 async def show_available_models():
-    models = await openai_serving_chat.show_available_models()
+    models = await openai_serving_completion.show_available_models()
     return JSONResponse(content=models.model_dump())
 
 
@@ -236,7 +236,8 @@ async def authentication(request: Request, call_next):
         args.lora_modules,
         args.chat_template)
     openai_serving_completion = OpenAIServingCompletion(
-        engine, model_config, served_model_names, args.lora_modules)
+        engine, model_config, served_model_names, args.lora_modules,
+        args.prompt_adapters)
     openai_serving_embedding = OpenAIServingEmbedding(engine, model_config,
                                                       served_model_names)
     app.root_path = args.root_path
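For illustration only (not part of the diff): with the wiring above, a prompt adapter registered at startup is listed by /v1/models and can be selected through the standard completion endpoint by using its name as the model id. The sketch below assumes a server started with --prompt-adapters my_adapter=/path/to/adapter and listening on localhost:8000; the adapter name and address are placeholders, not values taken from this PR.

# Client-side sketch; "my_adapter" and the address are assumptions.
import requests

BASE_URL = "http://localhost:8000"

# The prompt adapter is listed next to the base model and any LoRA modules.
models = requests.get(f"{BASE_URL}/v1/models").json()
print([m["id"] for m in models["data"]])

# Select the adapter by passing its registered name as the model id.
resp = requests.post(f"{BASE_URL}/v1/completions",
                     json={
                         "model": "my_adapter",
                         "prompt": "Hello, my name is",
                         "max_tokens": 16,
                     })
print(resp.json()["choices"][0]["text"])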
21 changes: 20 additions & 1 deletion vllm/entrypoints/openai/cli_args.py

@@ -9,7 +9,8 @@
 import ssl
 
 from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str
-from vllm.entrypoints.openai.serving_engine import LoRAModulePath
+from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
+                                                    PromptAdapterPath)
 from vllm.utils import FlexibleArgumentParser
 
 
@@ -23,6 +24,16 @@ def __call__(self, parser, namespace, values, option_string=None):
         setattr(namespace, self.dest, lora_list)
 
 
+class PromptAdapterParserAction(argparse.Action):
+
+    def __call__(self, parser, namespace, values, option_string=None):
+        adapter_list = []
+        for item in values:
+            name, path = item.split('=')
+            adapter_list.append(PromptAdapterPath(name, path))
+        setattr(namespace, self.dest, adapter_list)
+
+
 def make_arg_parser():
     parser = FlexibleArgumentParser(
         description="vLLM OpenAI-Compatible RESTful API server.")
@@ -65,6 +76,14 @@ def make_arg_parser():
                         action=LoRAParserAction,
                         help="LoRA module configurations in the format name=path. "
                         "Multiple modules can be specified.")
+    parser.add_argument(
+        "--prompt-adapters",
+        type=nullable_str,
+        default=None,
+        nargs='+',
+        action=PromptAdapterParserAction,
+        help="Prompt adapter configurations in the format name=path. "
+        "Multiple adapters can be specified.")
     parser.add_argument("--chat-template",
                         type=nullable_str,
                         default=None,
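A self-contained sketch of how the new flag parses: each name=path token on the command line becomes one PromptAdapterPath entry. The dataclass below is a local stand-in for PromptAdapterPath from vllm.entrypoints.openai.serving_engine; the rest mirrors the action added above and runs without vLLM installed.

# Standalone sketch of PromptAdapterParserAction.
import argparse
from dataclasses import dataclass


@dataclass
class PromptAdapterPath:  # local stand-in for the vLLM dataclass
    name: str
    local_path: str


class PromptAdapterParserAction(argparse.Action):

    def __call__(self, parser, namespace, values, option_string=None):
        adapter_list = []
        for item in values:
            name, path = item.split('=')
            adapter_list.append(PromptAdapterPath(name, path))
        setattr(namespace, self.dest, adapter_list)


parser = argparse.ArgumentParser()
parser.add_argument("--prompt-adapters",
                    default=None,
                    nargs='+',
                    action=PromptAdapterParserAction)

args = parser.parse_args(
    ["--prompt-adapters", "pa_one=/adapters/one", "pa_two=/adapters/two"])
print(args.prompt_adapters)
# [PromptAdapterPath(name='pa_one', local_path='/adapters/one'),
#  PromptAdapterPath(name='pa_two', local_path='/adapters/two')]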
11 changes: 8 additions & 3 deletions vllm/entrypoints/openai/serving_completion.py

@@ -22,7 +22,8 @@
     TokenizeResponse, UsageInfo)
 # yapf: enable
 from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
-                                                    OpenAIServing)
+                                                    OpenAIServing,
+                                                    PromptAdapterPath)
 from vllm.logger import init_logger
 from vllm.model_executor.guided_decoding import (
     get_guided_decoding_logits_processor)
@@ -67,11 +68,13 @@ class OpenAIServingCompletion(OpenAIServing):
 
     def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
                  served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+                 lora_modules: Optional[List[LoRAModulePath]],
+                 prompt_adapters: Optional[List[PromptAdapterPath]]):
         super().__init__(engine=engine,
                          model_config=model_config,
                          served_model_names=served_model_names,
-                         lora_modules=lora_modules)
+                         lora_modules=lora_modules,
+                         prompt_adapters=prompt_adapters)
 
     async def create_completion(self, request: CompletionRequest,
                                 raw_request: Request):
@@ -102,6 +105,7 @@ async def create_completion(self, request: CompletionRequest,
         try:
             sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            prompt_adapter_request = self._maybe_get_prompt_adapter(request)
             decoding_config = await self.engine.get_decoding_config()
             guided_decoding_backend = request.guided_decoding_backend \
                 or decoding_config.guided_decoding_backend
@@ -147,6 +151,7 @@ async def create_completion(self, request: CompletionRequest,
                     sampling_params,
                     f"{request_id}-{i}",
                     lora_request=lora_request,
+                    prompt_adapter_request=prompt_adapter_request,
                     trace_headers=trace_headers,
                 )
 
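A minimal sketch of the per-request resolution that create_completion now performs: request.model is matched against the served model names, the LoRA names, and the new prompt adapter names, and whichever request object matches is forwarded to engine.generate(). The simplified records and names below are illustrative, not vLLM's classes, and the error path normally handled by _check_model is folded into the helper for brevity.

from typing import NamedTuple, Optional, Tuple


class AdapterRecord(NamedTuple):  # stand-in for LoRARequest / PromptAdapterRequest
    name: str
    adapter_id: int


SERVED_MODEL_NAMES = ["base-model"]
LORA_REQUESTS = [AdapterRecord("sql_lora", 1)]
PROMPT_ADAPTER_REQUESTS = [AdapterRecord("my_prompt_adapter", 1)]


def resolve(model: str) -> Tuple[Optional[AdapterRecord], Optional[AdapterRecord]]:
    """Return (lora_request, prompt_adapter_request) for a requested model id."""
    if model in SERVED_MODEL_NAMES:
        return None, None
    for lora in LORA_REQUESTS:
        if model == lora.name:
            return lora, None
    for adapter in PROMPT_ADAPTER_REQUESTS:
        if model == adapter.name:
            return None, adapter
    raise ValueError(f"The model `{model}` does not exist.")


print(resolve("base-model"))         # (None, None)
print(resolve("my_prompt_adapter"))  # (None, AdapterRecord(name='my_prompt_adapter', adapter_id=1))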
59 changes: 55 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py

@@ -16,12 +16,19 @@
     ModelPermission, TokenizeRequest)
 from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
+from vllm.prompt_adapter.request import PromptAdapterRequest
 from vllm.sequence import Logprob
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
 logger = init_logger(__name__)
 
 
+@dataclass
+class PromptAdapterPath:
+    name: str
+    local_path: str
+
+
 @dataclass
 class LoRAModulePath:
     name: str
@@ -30,9 +37,14 @@ class LoRAModulePath:
 
 class OpenAIServing:
 
-    def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
-                 served_model_names: List[str],
-                 lora_modules: Optional[List[LoRAModulePath]]):
+    def __init__(
+        self,
+        engine: AsyncLLMEngine,
+        model_config: ModelConfig,
+        served_model_names: List[str],
+        lora_modules: Optional[List[LoRAModulePath]],
+        prompt_adapters: Optional[List[PromptAdapterPath]] = None,
+    ):
         super().__init__()
 
         self.engine = engine
@@ -59,6 +71,19 @@ def __init__(self, engine: AsyncLLMEngine, model_config: ModelConfig,
             ) for i, lora in enumerate(lora_modules, start=1)
         ]
 
+        self.prompt_adapter_requests = []
+        if prompt_adapters is not None:
+            for i, prompt_adapter in enumerate(prompt_adapters, start=1):
+                with open(f"./{prompt_adapter.local_path}/adapter_config.json") as f:
+                    adapter_config = json.load(f)
+                    num_virtual_tokens = adapter_config["num_virtual_tokens"]
+                self.prompt_adapter_requests.append(
+                    PromptAdapterRequest(
+                        prompt_adapter_name=prompt_adapter.name,
+                        prompt_adapter_id=i,
+                        prompt_adapter_local_path=prompt_adapter.local_path,
+                        prompt_adapter_num_virtual_tokens=num_virtual_tokens))
+
     async def show_available_models(self) -> ModelList:
         """Show available models. Right now we only have one model."""
         model_cards = [
@@ -74,6 +99,13 @@ async def show_available_models(self) -> ModelList:
                       permission=[ModelPermission()])
             for lora in self.lora_requests
         ]
+        prompt_adapter_cards = [
+            ModelCard(id=prompt_adapter.prompt_adapter_name,
+                      root=self.served_model_names[0],
+                      permission=[ModelPermission()])
+            for prompt_adapter in self.prompt_adapter_requests
+        ]
+        model_cards.extend(prompt_adapter_cards)
         model_cards.extend(lora_cards)
         return ModelList(data=model_cards)
 
@@ -108,6 +140,11 @@ async def _check_model(
             return None
         if request.model in [lora.lora_name for lora in self.lora_requests]:
             return None
+        if request.model in [
+                prompt_adapter.prompt_adapter_name
+                for prompt_adapter in self.prompt_adapter_requests
+        ]:
+            return None
         return self.create_error_response(
             message=f"The model `{request.model}` does not exist.",
             err_type="NotFoundError",
@@ -122,8 +159,22 @@ def _maybe_get_lora(
         for lora in self.lora_requests:
             if request.model == lora.lora_name:
                 return lora
+        return None
+        # if _check_model has been called earlier, this will be unreachable
+        #raise ValueError(f"The model `{request.model}` does not exist.")
+
+    def _maybe_get_prompt_adapter(
+            self, request: Union[CompletionRequest, ChatCompletionRequest,
+                                 EmbeddingRequest]
+    ) -> Optional[PromptAdapterRequest]:
+        if request.model in self.served_model_names:
+            return None
+        for prompt_adapter in self.prompt_adapter_requests:
+            if request.model == prompt_adapter.prompt_adapter_name:
+                return prompt_adapter
+        return None
         # if _check_model has been called earlier, this will be unreachable
-        raise ValueError(f"The model `{request.model}` does not exist.")
+        #raise ValueError(f"The model `{request.model}` does not exist.")
 
     def _validate_prompt_and_tokenize(
         self,
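A small sketch of the startup-time read that the new __init__ code performs. It assumes the adapter directory follows PEFT's prompt-tuning layout, where adapter_config.json carries a num_virtual_tokens field; the temporary directory and values below exist only to make the snippet self-contained.

import json
import tempfile
from pathlib import Path

with tempfile.TemporaryDirectory() as tmp:
    adapter_dir = Path(tmp) / "my_prompt_adapter"
    adapter_dir.mkdir()
    # Fabricated stand-in for a PEFT prompt-tuning adapter_config.json.
    (adapter_dir / "adapter_config.json").write_text(
        json.dumps({"peft_type": "PROMPT_TUNING", "num_virtual_tokens": 8}))

    # Same read as in OpenAIServing.__init__ above.
    with open(adapter_dir / "adapter_config.json") as f:
        adapter_config = json.load(f)
    num_virtual_tokens = adapter_config["num_virtual_tokens"]
    print(num_virtual_tokens)  # 8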