From 692e081042e89354bbcc63ec665d98b808294f95 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Thu, 27 Jun 2024 17:37:18 +0000 Subject: [PATCH 1/6] [Frontend] Support for chat completions input in the tokenize endpoint --- tests/entrypoints/test_openai_server.py | 52 ++++++++++++++++++- vllm/entrypoints/openai/api_server.py | 4 +- vllm/entrypoints/openai/protocol.py | 8 +-- vllm/entrypoints/openai/serving_chat.py | 51 +++++++++++++++++- vllm/entrypoints/openai/serving_completion.py | 31 +---------- 5 files changed, 107 insertions(+), 39 deletions(-) diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 14f59ea66e6a..fa8c13f7b0ff 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -619,7 +619,8 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, "model_name", [MODEL_NAME], ) -async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): +async def test_tokenize_completions(client: openai.AsyncOpenAI, + model_name: str): base_url = str(client.base_url)[:-3] tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") @@ -627,13 +628,14 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): prompt = "This is a test prompt." tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - response = requests.post(base_url + "/tokenize", + response = requests.post(base_url + "tokenize", json={ "add_special_tokens": add_special, "model": model_name, "prompt": prompt }) response.raise_for_status() + assert response.json() == { "tokens": tokens, "count": len(tokens), @@ -641,6 +643,51 @@ async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): } +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str): + base_url = str(client.base_url)[:-3] + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") + + for add_generation in [False, True]: + for add_special in [False, True]: + conversation = [{ + "role": "user", + "content": "Hi there!" + }, { + "role": "assistant", + "content": "Nice to meet you!" + }, { + "role": "user", + "content": "Can I ask a question?" 
+ }] + + prompt = tokenizer.apply_chat_template( + add_generation_prompt=add_generation, + conversation=conversation, + tokenize=False) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) + + response = requests.post(base_url + "tokenize", + json={ + "add_generation_prompt": + add_generation, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name + }) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192 + } + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name", @@ -659,6 +706,7 @@ async def test_detokenize(client: openai.AsyncOpenAI, model_name: str): "tokens": tokens }) response.raise_for_status() + assert response.json() == {"prompt": prompt} diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a708176c254e..e87c3189d878 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -94,7 +94,7 @@ async def health() -> Response: @app.post("/tokenize") async def tokenize(request: TokenizeRequest): - generator = await openai_serving_completion.create_tokenize(request) + generator = await openai_serving_chat.create_tokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -105,7 +105,7 @@ async def tokenize(request: TokenizeRequest): @app.post("/detokenize") async def detokenize(request: DetokenizeRequest): - generator = await openai_serving_completion.create_detokenize(request) + generator = await openai_serving_chat.create_detokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 7fb1af158531..56c7af91382a 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -702,15 +702,17 @@ class BatchRequestOutput(OpenAIBaseModel): class TokenizeRequest(OpenAIBaseModel): + add_generation_prompt: bool = Field(default=True) + add_special_tokens: bool = Field(default=False) + prompt: Optional[str] = Field(default=None) + messages: Optional[List[ChatCompletionMessageParam]] = Field(default=None) model: str - prompt: str - add_special_tokens: bool = Field(default=True) class TokenizeResponse(OpenAIBaseModel): - tokens: List[int] count: int max_model_len: int + tokens: List[int] class DetokenizeRequest(OpenAIBaseModel): diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 744e1d94511b..c900c3fc4774 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -18,8 +18,9 @@ ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, - FunctionCall, ToolCall, UsageInfo) + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, DetokenizeRequest, + DetokenizeResponse, ErrorResponse, FunctionCall, TokenizeRequest, + TokenizeResponse, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.inputs import PromptInputs @@ -611,3 +612,49 @@ def _create_chat_logprobs( step_top_logprobs, num_output_top_logprobs))) return ChatCompletionLogProbs(content=logprobs_content) + + async 
def create_tokenize(self, + request: TokenizeRequest) -> TokenizeResponse: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + if not (request.prompt or request.messages): + return self.create_error_response( + "Either `prompt` or `messages` should be provided.") + + if (request.prompt and request.messages): + return self.create_error_response( + "Only one of `prompt` or `messages` should be provided.") + + if request.messages: + conversation: List[ConversationMessage] = [] + + for message in request.messages: + conversation.extend( + self._parse_chat_message_content(message).messages) + + request.prompt = self.tokenizer.apply_chat_template( + add_generation_prompt=request.add_generation_prompt, + conversation=conversation, + tokenize=False) + + (input_ids, input_text) = self._validate_prompt_and_tokenize( + request, + prompt=request.prompt, + add_special_tokens=request.add_special_tokens) + + return TokenizeResponse(tokens=input_ids, + count=len(input_ids), + max_model_len=self.max_model_len) + + async def create_detokenize( + self, request: DetokenizeRequest) -> DetokenizeResponse: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + (input_ids, input_text) = self._validate_prompt_and_tokenize( + request, prompt_ids=request.tokens) + + return DetokenizeResponse(prompt=input_text) diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 8741893c9271..4ad24acc3132 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -16,10 +16,7 @@ CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, - DetokenizeRequest, - DetokenizeResponse, - TokenizeRequest, - TokenizeResponse, UsageInfo) + UsageInfo) # yapf: enable from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) @@ -446,29 +443,3 @@ def _create_completion_logprobs( tokens=out_tokens, top_logprobs=out_top_logprobs, ) - - async def create_tokenize(self, - request: TokenizeRequest) -> TokenizeResponse: - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - (input_ids, input_text) = self._validate_prompt_and_tokenize( - request, - prompt=request.prompt, - add_special_tokens=request.add_special_tokens) - - return TokenizeResponse(tokens=input_ids, - count=len(input_ids), - max_model_len=self.max_model_len) - - async def create_detokenize( - self, request: DetokenizeRequest) -> DetokenizeResponse: - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - (input_ids, input_text) = self._validate_prompt_and_tokenize( - request, prompt_ids=request.tokens) - - return DetokenizeResponse(prompt=input_text) From 179d0f869a006cb092c8efda743611835583a6c0 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 16 Jul 2024 06:31:15 +0000 Subject: [PATCH 2/6] Refactor --- tests/entrypoints/openai/test_chat.py | 48 ---- tests/entrypoints/openai/test_completion.py | 51 ----- tests/entrypoints/openai/test_tokenization.py | 104 +++++++++ vllm/entrypoints/openai/api_server.py | 9 +- vllm/entrypoints/openai/chat_utils.py | 146 ++++++++++++ vllm/entrypoints/openai/serving_chat.py | 214 ++---------------- .../openai/serving_tokenization.py | 74 ++++++ 7 files changed, 345 insertions(+), 301 deletions(-) create mode 100644 
tests/entrypoints/openai/test_tokenization.py create mode 100644 vllm/entrypoints/openai/chat_utils.py create mode 100644 vllm/entrypoints/openai/serving_tokenization.py diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 9bb4a9da5399..d370c63c0c7b 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -6,14 +6,11 @@ import jsonschema import openai # use the official client for correctness check import pytest -import requests import torch # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError -from vllm.transformers_utils.tokenizer import get_tokenizer - from ...utils import RemoteOpenAIServer # any model with a chat template should work here @@ -824,48 +821,3 @@ async def test_long_seed(client: openai.AsyncOpenAI): assert ("greater_than_equal" in exc_info.value.message or "less_than_equal" in exc_info.value.message) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): - base_url = str(client.base_url)[:-3] - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") - - for add_generation in [False, True]: - for add_special in [False, True]: - conversation = [{ - "role": "user", - "content": "Hi there!" - }, { - "role": "assistant", - "content": "Nice to meet you!" - }, { - "role": "user", - "content": "Can I ask a question?" - }] - - prompt = tokenizer.apply_chat_template( - add_generation_prompt=add_generation, - conversation=conversation, - tokenize=False) - tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - - response = requests.post(base_url + "tokenize", - json={ - "add_generation_prompt": - add_generation, - "add_special_tokens": add_special, - "messages": conversation, - "model": model_name - }) - response.raise_for_status() - - assert response.json() == { - "tokens": tokens, - "count": len(tokens), - "max_model_len": 8192 - } diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index e0ad210677de..35af0b02747e 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -6,7 +6,6 @@ import jsonschema import openai # use the official client for correctness check import pytest -import requests # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -636,53 +635,3 @@ async def test_guided_decoding_type_error(client: openai.AsyncOpenAI, prompt="Give an example string that fits this regex", extra_body=dict(guided_regex=sample_regex, guided_json=sample_json_schema)) - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_tokenize(client: openai.AsyncOpenAI, model_name: str): - base_url = str(client.base_url)[:-3].strip("/") - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") - - for add_special in [False, True]: - prompt = "This is a test prompt." 
- tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - - response = requests.post(base_url + "tokenize", - json={ - "add_special_tokens": add_special, - "model": model_name, - "prompt": prompt - }) - response.raise_for_status() - - assert response.json() == { - "tokens": tokens, - "count": len(tokens), - "max_model_len": 8192 - } - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model_name", - [MODEL_NAME], -) -async def test_detokenize(client: openai.AsyncOpenAI, model_name: str): - base_url = str(client.base_url)[:-3] - tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") - - prompt = "This is a test prompt." - tokens = tokenizer.encode(prompt, add_special_tokens=False) - - response = requests.post(base_url + "detokenize", - json={ - "model": model_name, - "tokens": tokens - }) - response.raise_for_status() - - assert response.json() == {"prompt": prompt} diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py new file mode 100644 index 000000000000..3ee0b3aaa956 --- /dev/null +++ b/tests/entrypoints/openai/test_tokenization.py @@ -0,0 +1,104 @@ +import openai # use the official client for correctness check +import pytest +import requests + +from vllm.transformers_utils.tokenizer import get_tokenizer + +# any model with a chat template should work here +MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_tokenize_completions(client: openai.AsyncOpenAI, + model_name: str): + base_url = str(client.base_url)[:-3].strip("/") + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + + for add_special in [False, True]: + prompt = "This is a test prompt." + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) + + response = requests.post(base_url + "/tokenize", + json={ + "add_special_tokens": add_special, + "model": model_name, + "prompt": prompt + }) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192 + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str): + base_url = str(client.base_url)[:-3].strip("/") + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") + + for add_generation in [False, True]: + for add_special in [False, True]: + conversation = [{ + "role": "user", + "content": "Hi there!" + }, { + "role": "assistant", + "content": "Nice to meet you!" + }, { + "role": "user", + "content": "Can I ask a question?" 
+ }] + + prompt = tokenizer.apply_chat_template( + add_generation_prompt=add_generation, + conversation=conversation, + tokenize=False) + tokens = tokenizer.encode(prompt, add_special_tokens=add_special) + + response = requests.post(base_url + "/tokenize", + json={ + "add_generation_prompt": + add_generation, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name + }) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192 + } + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [MODEL_NAME], +) +async def test_detokenize(client: openai.AsyncOpenAI, model_name: str): + base_url = str(client.base_url)[:-3].strip("/") + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") + + prompt = "This is a test prompt." + tokens = tokenizer.encode(prompt, add_special_tokens=False) + + response = requests.post(base_url + "/detokenize", + json={ + "model": model_name, + "tokens": tokens + }) + response.raise_for_status() + + assert response.json() == {"prompt": prompt} diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index c3df080ea04b..a56dcf648da5 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -33,6 +33,7 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding +from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser @@ -46,6 +47,7 @@ openai_serving_chat: OpenAIServingChat openai_serving_completion: OpenAIServingCompletion openai_serving_embedding: OpenAIServingEmbedding +openai_serving_tokenization: OpenAIServingTokenization logger = init_logger('vllm.entrypoints.openai.api_server') @@ -86,7 +88,7 @@ async def health() -> Response: @router.post("/tokenize") async def tokenize(request: TokenizeRequest): - generator = await openai_serving_chat.create_tokenize(request) + generator = await openai_serving_tokenization.create_tokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -97,7 +99,7 @@ async def tokenize(request: TokenizeRequest): @router.post("/detokenize") async def detokenize(request: DetokenizeRequest): - generator = await openai_serving_chat.create_detokenize(request) + generator = await openai_serving_tokenization.create_detokenize(request) if isinstance(generator, ErrorResponse): return JSONResponse(content=generator.model_dump(), status_code=generator.code) @@ -241,6 +243,7 @@ def run_server(args, llm_engine=None): global openai_serving_chat global openai_serving_completion global openai_serving_embedding + global openai_serving_tokenization openai_serving_chat = OpenAIServingChat(engine, model_config, served_model_names, @@ -252,6 +255,8 @@ def run_server(args, llm_engine=None): args.prompt_adapters) openai_serving_embedding = OpenAIServingEmbedding(engine, model_config, served_model_names) + openai_serving_tokenization = OpenAIServingTokenization( + engine, model_config, served_model_names, args.chat_template) app.root_path = args.root_path logger.info("Available routes are:") diff --git a/vllm/entrypoints/openai/chat_utils.py 
b/vllm/entrypoints/openai/chat_utils.py new file mode 100644 index 000000000000..32678844495a --- /dev/null +++ b/vllm/entrypoints/openai/chat_utils.py @@ -0,0 +1,146 @@ +import codecs +from dataclasses import dataclass, field +from typing import (Awaitable, Iterable, List, Optional, TypedDict, cast, + final) + +from openai.types.chat import (ChatCompletionContentPartImageParam, + ChatCompletionContentPartTextParam) + +from vllm.config import ModelConfig +from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam, + ChatCompletionMessageParam) +from vllm.logger import init_logger +from vllm.multimodal import MultiModalDataDict +from vllm.multimodal.utils import async_get_and_parse_image + +logger = init_logger(__name__) + + +@final # So that it should be compatible with Dict[str, str] +class ConversationMessage(TypedDict): + role: str + content: str + + +@dataclass(frozen=True) +class ChatMessageParseResult: + messages: List[ConversationMessage] + mm_futures: List[Awaitable[MultiModalDataDict]] = field( + default_factory=list) + + +def load_chat_template(tokenizer, chat_template: Optional[str]): + if chat_template is not None: + try: + with open(chat_template, "r") as f: + tokenizer.chat_template = f.read() + except OSError as e: + JINJA_CHARS = "{}\n" + if not any(c in chat_template for c in JINJA_CHARS): + msg = (f"The supplied chat template ({chat_template}) " + f"looks like a file path, but it failed to be " + f"opened. Reason: {e}") + raise ValueError(msg) from e + + # If opening a file fails, set chat template to be args to + # ensure we decode so our escape are interpreted correctly + tokenizer.chat_template = codecs.decode(chat_template, + "unicode_escape") + + logger.info("Using supplied chat template:\n%s", + tokenizer.chat_template) + elif tokenizer.chat_template is not None: + logger.info("Using default chat template:\n%s", + tokenizer.chat_template) + else: + logger.warning("No chat template provided. Chat API will not work.") + + +def image_token_str(model_config: ModelConfig, tokenizer) -> Optional[str]: + # TODO: Let user specify how to insert image tokens into prompt + # (similar to chat template) + model_type = model_config.hf_config.model_type + if model_type == "phi3_v": + # Workaround since this token is not defined in the tokenizer + return "<|image_1|>" + if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", "paligemma"): + # These models do not use image tokens in the prompt + return None + if model_type.startswith("llava"): + return tokenizer.decode(model_config.hf_config.image_token_index) + + else: + raise TypeError("Unknown model type: {model_type}") + + +# TODO: Let user specify how to insert image tokens into prompt +# (similar to chat template) +def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str: + """Combine image and text prompts for vision language model""" + + # NOTE: For now we assume all model architectures use the same + # image + text prompt format. This may change in the future. 
+ return f"{image_token_str}\n{text_prompt}" + + +def _parse_chat_message_content_parts( + role: str, + parts: Iterable[ChatCompletionContentPartParam], + image_token_str: Optional[str], +) -> ChatMessageParseResult: + texts: List[str] = [] + mm_futures: List[Awaitable[MultiModalDataDict]] = [] + + for part in parts: + part_type = part["type"] + if part_type == "text": + text = cast(ChatCompletionContentPartTextParam, part)["text"] + texts.append(text) + elif part_type == "image_url": + if len(mm_futures) > 0: + raise NotImplementedError( + "Multiple 'image_url' input is currently not supported.") + + image_url = cast(ChatCompletionContentPartImageParam, + part)["image_url"] + + if image_url.get("detail", "auto") != "auto": + logger.warning( + "'image_url.detail' is currently not supported and " + "will be ignored.") + + image_future = async_get_and_parse_image(image_url["url"]) + mm_futures.append(image_future) + else: + raise NotImplementedError(f"Unknown part type: {part_type}") + + text_prompt = "\n".join(texts) + + if mm_futures and image_token_str is not None: + if image_token_str in text_prompt: + logger.warning("Detected image token string in the text prompt. " + "Skipping prompt formatting.") + else: + text_prompt = _get_full_image_text_prompt( + image_token_str=image_token_str, + text_prompt=text_prompt, + ) + + messages = [ConversationMessage(role=role, content=text_prompt)] + + return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) + + +def parse_chat_message_content( + message: ChatCompletionMessageParam, + image_token_str: Optional[str] = None) -> ChatMessageParseResult: + role = message["role"] + content = message.get("content") + + if content is None: + return ChatMessageParseResult(messages=[], mm_futures=[]) + if isinstance(content, str): + messages = [ConversationMessage(role=role, content=content)] + return ChatMessageParseResult(messages=messages, mm_futures=[]) + + return _parse_chat_message_content_parts(role, content, image_token_str) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 07d10f00fa7e..b108b3bf12f5 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,27 +1,23 @@ -import codecs import time -from dataclasses import dataclass, field -from functools import cached_property -from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, - List, Optional) +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, + Optional, Union) from typing import Sequence as GenericSequence -from typing import TypedDict, Union, cast, final from fastapi import Request -from openai.types.chat import (ChatCompletionContentPartImageParam, - ChatCompletionContentPartTextParam) from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.chat_utils import (load_chat_template, + parse_chat_message_content, + image_token_str, + ConversationMessage) from vllm.entrypoints.openai.protocol import ( - ChatCompletionContentPartParam, ChatCompletionLogProb, - ChatCompletionLogProbs, ChatCompletionLogProbsContent, - ChatCompletionMessageParam, ChatCompletionNamedToolChoiceParam, + ChatCompletionLogProb, ChatCompletionLogProbs, + ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, - ChatCompletionStreamResponse, ChatMessage, DeltaMessage, DetokenizeRequest, 
- DetokenizeResponse, ErrorResponse, FunctionCall, TokenizeRequest, - TokenizeResponse, ToolCall, UsageInfo) + ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, + FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing) from vllm.inputs import PromptInputs @@ -29,7 +25,6 @@ from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.multimodal import MultiModalDataDict -from vllm.multimodal.utils import async_get_and_parse_image from vllm.outputs import RequestOutput from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, @@ -39,19 +34,6 @@ logger = init_logger(__name__) -@final # So that it should be compatible with Dict[str, str] -class ConversationMessage(TypedDict): - role: str - content: str - - -@dataclass(frozen=True) -class ChatMessageParseResult: - messages: List[ConversationMessage] - mm_futures: List[Awaitable[MultiModalDataDict]] = field( - default_factory=list) - - class OpenAIServingChat(OpenAIServing): def __init__(self, @@ -67,131 +49,8 @@ def __init__(self, lora_modules=lora_modules) self.response_role = response_role - self._load_chat_template(chat_template) - - def _load_chat_template(self, chat_template: Optional[str]): - tokenizer = self.tokenizer - - if chat_template is not None: - try: - with open(chat_template, "r") as f: - tokenizer.chat_template = f.read() - except OSError as e: - JINJA_CHARS = "{}\n" - if not any(c in chat_template for c in JINJA_CHARS): - msg = (f"The supplied chat template ({chat_template}) " - f"looks like a file path, but it failed to be " - f"opened. Reason: {e}") - raise ValueError(msg) from e - - # If opening a file fails, set chat template to be args to - # ensure we decode so our escape are interpreted correctly - tokenizer.chat_template = codecs.decode( - chat_template, "unicode_escape") - - logger.info("Using supplied chat template:\n%s", - tokenizer.chat_template) - elif tokenizer.chat_template is not None: - logger.info("Using default chat template:\n%s", - tokenizer.chat_template) - else: - logger.warning( - "No chat template provided. Chat API will not work.") - - @cached_property - def image_token_str(self) -> Optional[str]: - # TODO: Let user specify how to insert image tokens into prompt - # (similar to chat template) - model_type = self.model_config.hf_config.model_type - if model_type == "phi3_v": - # Workaround since this token is not defined in the tokenizer - return "<|image_1|>" - if model_type in ("blip-2", "chatglm", "fuyu", "minicpmv", - "paligemma"): - # These models do not use image tokens in the prompt - return None - if model_type.startswith("llava"): - return self.tokenizer.decode( - self.model_config.hf_config.image_token_index) - - else: - raise TypeError("Unknown model type: {model_type}") - - # TODO: Let user specify how to insert image tokens into prompt - # (similar to chat template) - def _get_full_image_text_prompt(self, image_token_str: str, - text_prompt: str) -> str: - """Combine image and text prompts for vision language model""" - - # NOTE: For now we assume all model architectures use the same - # image + text prompt format. This may change in the future. 
- return f"{image_token_str}\n{text_prompt}" - - def _parse_chat_message_content_parts( - self, - role: str, - parts: Iterable[ChatCompletionContentPartParam], - ) -> ChatMessageParseResult: - texts: List[str] = [] - mm_futures: List[Awaitable[MultiModalDataDict]] = [] - - for part in parts: - part_type = part["type"] - if part_type == "text": - text = cast(ChatCompletionContentPartTextParam, part)["text"] - texts.append(text) - elif part_type == "image_url": - if len(mm_futures) > 0: - raise NotImplementedError( - "Multiple 'image_url' input is currently not supported." - ) - - image_url = cast(ChatCompletionContentPartImageParam, - part)["image_url"] - - if image_url.get("detail", "auto") != "auto": - logger.warning( - "'image_url.detail' is currently not supported and " - "will be ignored.") - - image_future = async_get_and_parse_image(image_url["url"]) - mm_futures.append(image_future) - else: - raise NotImplementedError(f"Unknown part type: {part_type}") - - text_prompt = "\n".join(texts) - - if mm_futures: - image_token_str = self.image_token_str - if image_token_str is not None: - if image_token_str in text_prompt: - logger.warning( - "Detected image token string in the text prompt. " - "Skipping prompt formatting.") - else: - text_prompt = self._get_full_image_text_prompt( - image_token_str=image_token_str, - text_prompt=text_prompt, - ) - - messages = [ConversationMessage(role=role, content=text_prompt)] - - return ChatMessageParseResult(messages=messages, mm_futures=mm_futures) - - def _parse_chat_message_content( - self, - message: ChatCompletionMessageParam, - ) -> ChatMessageParseResult: - role = message["role"] - content = message.get("content") - - if content is None: - return ChatMessageParseResult(messages=[], mm_futures=[]) - if isinstance(content, str): - messages = [ConversationMessage(role=role, content=content)] - return ChatMessageParseResult(messages=messages, mm_futures=[]) - - return self._parse_chat_message_content_parts(role, content) + self.image_token_str = image_token_str(model_config, self.tokenizer) + load_chat_template(self.tokenizer, chat_template) async def create_chat_completion( self, @@ -217,7 +76,8 @@ async def create_chat_completion( mm_futures: List[Awaitable[MultiModalDataDict]] = [] for msg in request.messages: - chat_parsed_result = self._parse_chat_message_content(msg) + chat_parsed_result = parse_chat_message_content( + msg, self.image_token_str) conversation.extend(chat_parsed_result.messages) mm_futures.extend(chat_parsed_result.mm_futures) @@ -624,49 +484,3 @@ def _create_chat_logprobs( step_top_logprobs, num_output_top_logprobs))) return ChatCompletionLogProbs(content=logprobs_content) - - async def create_tokenize(self, - request: TokenizeRequest) -> TokenizeResponse: - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - if not (request.prompt or request.messages): - return self.create_error_response( - "Either `prompt` or `messages` should be provided.") - - if (request.prompt and request.messages): - return self.create_error_response( - "Only one of `prompt` or `messages` should be provided.") - - if request.messages: - conversation: List[ConversationMessage] = [] - - for message in request.messages: - conversation.extend( - self._parse_chat_message_content(message).messages) - - request.prompt = self.tokenizer.apply_chat_template( - add_generation_prompt=request.add_generation_prompt, - conversation=conversation, - tokenize=False) - - (input_ids, input_text) = 
self._validate_prompt_and_tokenize( - request, - prompt=request.prompt, - add_special_tokens=request.add_special_tokens) - - return TokenizeResponse(tokens=input_ids, - count=len(input_ids), - max_model_len=self.max_model_len) - - async def create_detokenize( - self, request: DetokenizeRequest) -> DetokenizeResponse: - error_check_ret = await self._check_model(request) - if error_check_ret is not None: - return error_check_ret - - (input_ids, input_text) = self._validate_prompt_and_tokenize( - request, prompt_ids=request.tokens) - - return DetokenizeResponse(prompt=input_text) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py new file mode 100644 index 000000000000..00f5c9df1a4d --- /dev/null +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -0,0 +1,74 @@ +from typing import (List, Optional) + +from vllm.config import ModelConfig +from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.protocol import (DetokenizeRequest, + DetokenizeResponse, + TokenizeRequest, + TokenizeResponse) + +from vllm.entrypoints.openai.serving_engine import OpenAIServing +from vllm.entrypoints.openai.chat_utils import (load_chat_template, + parse_chat_message_content, + ConversationMessage) + + +class OpenAIServingTokenization(OpenAIServing): + + def __init__(self, + engine: AsyncLLMEngine, + model_config: ModelConfig, + served_model_names: List[str], + chat_template: Optional[str] = None): + super().__init__(engine=engine, + model_config=model_config, + served_model_names=served_model_names, + lora_modules=None) + + load_chat_template(self.tokenizer, chat_template) + + async def create_tokenize(self, + request: TokenizeRequest) -> TokenizeResponse: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + if not (request.prompt or request.messages): + return self.create_error_response( + "Either `prompt` or `messages` should be provided.") + + if (request.prompt and request.messages): + return self.create_error_response( + "Only one of `prompt` or `messages` should be provided.") + + if request.messages: + conversation: List[ConversationMessage] = [] + + for message in request.messages: + conversation.extend( + parse_chat_message_content(message).messages) + + request.prompt = self.tokenizer.apply_chat_template( + add_generation_prompt=request.add_generation_prompt, + conversation=conversation, + tokenize=False) + + (input_ids, input_text) = self._validate_prompt_and_tokenize( + request, + prompt=request.prompt, + add_special_tokens=request.add_special_tokens) + + return TokenizeResponse(tokens=input_ids, + count=len(input_ids), + max_model_len=self.max_model_len) + + async def create_detokenize( + self, request: DetokenizeRequest) -> DetokenizeResponse: + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + (input_ids, input_text) = self._validate_prompt_and_tokenize( + request, prompt_ids=request.tokens) + + return DetokenizeResponse(prompt=input_text) From 8caa1e4daa852e64cc7536447179a84cbf60d65a Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 16 Jul 2024 06:33:38 +0000 Subject: [PATCH 3/6] fix ci --- vllm/entrypoints/openai/api_server.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index a56dcf648da5..a35dcbbd6545 100644 --- a/vllm/entrypoints/openai/api_server.py +++ 
b/vllm/entrypoints/openai/api_server.py @@ -33,7 +33,8 @@ from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding -from vllm.entrypoints.openai.serving_tokenization import OpenAIServingTokenization +from vllm.entrypoints.openai.serving_tokenization import ( + OpenAIServingTokenization) from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser From 63b43865673ec1bca748e8f7319f7c64143c3ae4 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 16 Jul 2024 06:35:16 +0000 Subject: [PATCH 4/6] format --- vllm/entrypoints/openai/chat_utils.py | 3 +-- vllm/entrypoints/openai/serving_chat.py | 9 +++++---- vllm/entrypoints/openai/serving_tokenization.py | 9 ++++----- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/vllm/entrypoints/openai/chat_utils.py b/vllm/entrypoints/openai/chat_utils.py index 32678844495a..96ee2b4f20a9 100644 --- a/vllm/entrypoints/openai/chat_utils.py +++ b/vllm/entrypoints/openai/chat_utils.py @@ -1,7 +1,6 @@ import codecs from dataclasses import dataclass, field -from typing import (Awaitable, Iterable, List, Optional, TypedDict, cast, - final) +from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final from openai.types.chat import (ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index b108b3bf12f5..f58fc93416e3 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,16 +1,17 @@ import time from typing import (AsyncGenerator, AsyncIterator, Awaitable, Dict, List, - Optional, Union) + Optional) from typing import Sequence as GenericSequence +from typing import Union from fastapi import Request from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.chat_utils import (load_chat_template, - parse_chat_message_content, +from vllm.entrypoints.openai.chat_utils import (ConversationMessage, image_token_str, - ConversationMessage) + load_chat_template, + parse_chat_message_content) from vllm.entrypoints.openai.protocol import ( ChatCompletionLogProb, ChatCompletionLogProbs, ChatCompletionLogProbsContent, ChatCompletionNamedToolChoiceParam, diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 00f5c9df1a4d..2b9131de5265 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -1,16 +1,15 @@ -from typing import (List, Optional) +from typing import List, Optional from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine +from vllm.entrypoints.openai.chat_utils import (ConversationMessage, + load_chat_template, + parse_chat_message_content) from vllm.entrypoints.openai.protocol import (DetokenizeRequest, DetokenizeResponse, TokenizeRequest, TokenizeResponse) - from vllm.entrypoints.openai.serving_engine import OpenAIServing -from vllm.entrypoints.openai.chat_utils import (load_chat_template, - parse_chat_message_content, - ConversationMessage) class OpenAIServingTokenization(OpenAIServing): From ca374f890987552c12504cc1d895f2910727935e Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 16 Jul 2024 07:44:55 +0000 Subject: [PATCH 5/6] 
fixes --- tests/async_engine/test_chat_template.py | 14 ++---- tests/entrypoints/openai/test_tokenization.py | 2 +- vllm/entrypoints/openai/chat_utils.py | 49 ++++++++++++------- vllm/entrypoints/openai/serving_chat.py | 7 +-- .../openai/serving_tokenization.py | 4 +- 5 files changed, 40 insertions(+), 36 deletions(-) diff --git a/tests/async_engine/test_chat_template.py b/tests/async_engine/test_chat_template.py index 55b730812ea9..536a7c96a1e9 100644 --- a/tests/async_engine/test_chat_template.py +++ b/tests/async_engine/test_chat_template.py @@ -4,8 +4,8 @@ import pytest +from vllm.entrypoints.openai.chat_utils import load_chat_template from vllm.entrypoints.openai.protocol import ChatCompletionRequest -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat from vllm.transformers_utils.tokenizer import get_tokenizer chatml_jinja_path = pathlib.Path(os.path.dirname(os.path.abspath( @@ -64,8 +64,7 @@ def test_load_chat_template(): # Testing chatml template tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=chatml_jinja_path) + load_chat_template(mock_serving_chat, chat_template=chatml_jinja_path) template_content = tokenizer.chat_template @@ -84,8 +83,7 @@ def test_no_load_chat_template_filelike(): mock_serving_chat = MockServingChat(tokenizer) with pytest.raises(ValueError, match="looks like a file path"): - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + load_chat_template(mock_serving_chat, chat_template=template) def test_no_load_chat_template_literallike(): @@ -94,8 +92,7 @@ def test_no_load_chat_template_literallike(): tokenizer = MockTokenizer() mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + load_chat_template(mock_serving_chat, chat_template=template) template_content = tokenizer.chat_template assert template_content == template @@ -109,8 +106,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) mock_serving_chat = MockServingChat(tokenizer) - OpenAIServingChat._load_chat_template(mock_serving_chat, - chat_template=template) + load_chat_template(mock_serving_chat, chat_template=template) # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 3ee0b3aaa956..ca3153c4bf35 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -44,7 +44,7 @@ async def test_tokenize_completions(client: openai.AsyncOpenAI, ) async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str): base_url = str(client.base_url)[:-3].strip("/") - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast") + tokenizer = get_tokenizer(tokenizer_name=model_name, tokenizer_mode="fast") for add_generation in [False, True]: for add_special in [False, True]: diff --git a/vllm/entrypoints/openai/chat_utils.py b/vllm/entrypoints/openai/chat_utils.py index 96ee2b4f20a9..27115391d5b2 100644 --- a/vllm/entrypoints/openai/chat_utils.py +++ b/vllm/entrypoints/openai/chat_utils.py @@ -1,13 +1,14 @@ import codecs from dataclasses import dataclass, field +from functools import lru_cache from typing import Awaitable, Iterable, List, Optional, TypedDict, cast, final from 
openai.types.chat import (ChatCompletionContentPartImageParam, ChatCompletionContentPartTextParam) -from vllm.config import ModelConfig from vllm.entrypoints.openai.protocol import (ChatCompletionContentPartParam, ChatCompletionMessageParam) +from vllm.entrypoints.openai.serving_engine import OpenAIServing from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict from vllm.multimodal.utils import async_get_and_parse_image @@ -28,7 +29,9 @@ class ChatMessageParseResult: default_factory=list) -def load_chat_template(tokenizer, chat_template: Optional[str]): +def load_chat_template(engine: OpenAIServing, chat_template: Optional[str]): + tokenizer = engine.tokenizer + if chat_template is not None: try: with open(chat_template, "r") as f: @@ -55,10 +58,11 @@ def load_chat_template(tokenizer, chat_template: Optional[str]): logger.warning("No chat template provided. Chat API will not work.") -def image_token_str(model_config: ModelConfig, tokenizer) -> Optional[str]: +@lru_cache(maxsize=None) +def _image_token_str(engine: OpenAIServing) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) - model_type = model_config.hf_config.model_type + model_type = engine.model_config.hf_config.model_type if model_type == "phi3_v": # Workaround since this token is not defined in the tokenizer return "<|image_1|>" @@ -66,7 +70,8 @@ def image_token_str(model_config: ModelConfig, tokenizer) -> Optional[str]: # These models do not use image tokens in the prompt return None if model_type.startswith("llava"): - return tokenizer.decode(model_config.hf_config.image_token_index) + return engine.tokenizer.decode( + engine.model_config.hf_config.image_token_index) else: raise TypeError("Unknown model type: {model_type}") @@ -74,7 +79,8 @@ def image_token_str(model_config: ModelConfig, tokenizer) -> Optional[str]: # TODO: Let user specify how to insert image tokens into prompt # (similar to chat template) -def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str: +def _get_full_image_text_prompt(engine: OpenAIServing, image_token_str: str, + text_prompt: str) -> str: """Combine image and text prompts for vision language model""" # NOTE: For now we assume all model architectures use the same @@ -83,9 +89,9 @@ def _get_full_image_text_prompt(image_token_str: str, text_prompt: str) -> str: def _parse_chat_message_content_parts( + engine: OpenAIServing, role: str, parts: Iterable[ChatCompletionContentPartParam], - image_token_str: Optional[str], ) -> ChatMessageParseResult: texts: List[str] = [] mm_futures: List[Awaitable[MultiModalDataDict]] = [] @@ -115,15 +121,19 @@ def _parse_chat_message_content_parts( text_prompt = "\n".join(texts) - if mm_futures and image_token_str is not None: - if image_token_str in text_prompt: - logger.warning("Detected image token string in the text prompt. " - "Skipping prompt formatting.") - else: - text_prompt = _get_full_image_text_prompt( - image_token_str=image_token_str, - text_prompt=text_prompt, - ) + if mm_futures: + image_token_str = _image_token_str(engine) + if image_token_str is not None: + if image_token_str in text_prompt: + logger.warning( + "Detected image token string in the text prompt. 
" + "Skipping prompt formatting.") + else: + text_prompt = _get_full_image_text_prompt( + engine, + image_token_str=image_token_str, + text_prompt=text_prompt, + ) messages = [ConversationMessage(role=role, content=text_prompt)] @@ -131,8 +141,9 @@ def _parse_chat_message_content_parts( def parse_chat_message_content( - message: ChatCompletionMessageParam, - image_token_str: Optional[str] = None) -> ChatMessageParseResult: + engine: OpenAIServing, + message: ChatCompletionMessageParam, +) -> ChatMessageParseResult: role = message["role"] content = message.get("content") @@ -142,4 +153,4 @@ def parse_chat_message_content( messages = [ConversationMessage(role=role, content=content)] return ChatMessageParseResult(messages=messages, mm_futures=[]) - return _parse_chat_message_content_parts(role, content, image_token_str) + return _parse_chat_message_content_parts(engine, role, content) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f58fc93416e3..dbd4521073da 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -9,7 +9,6 @@ from vllm.config import ModelConfig from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.chat_utils import (ConversationMessage, - image_token_str, load_chat_template, parse_chat_message_content) from vllm.entrypoints.openai.protocol import ( @@ -50,8 +49,7 @@ def __init__(self, lora_modules=lora_modules) self.response_role = response_role - self.image_token_str = image_token_str(model_config, self.tokenizer) - load_chat_template(self.tokenizer, chat_template) + load_chat_template(self, chat_template) async def create_chat_completion( self, @@ -77,8 +75,7 @@ async def create_chat_completion( mm_futures: List[Awaitable[MultiModalDataDict]] = [] for msg in request.messages: - chat_parsed_result = parse_chat_message_content( - msg, self.image_token_str) + chat_parsed_result = parse_chat_message_content(self, msg) conversation.extend(chat_parsed_result.messages) mm_futures.extend(chat_parsed_result.mm_futures) diff --git a/vllm/entrypoints/openai/serving_tokenization.py b/vllm/entrypoints/openai/serving_tokenization.py index 2b9131de5265..f441e940c5e5 100644 --- a/vllm/entrypoints/openai/serving_tokenization.py +++ b/vllm/entrypoints/openai/serving_tokenization.py @@ -24,7 +24,7 @@ def __init__(self, served_model_names=served_model_names, lora_modules=None) - load_chat_template(self.tokenizer, chat_template) + load_chat_template(self, chat_template) async def create_tokenize(self, request: TokenizeRequest) -> TokenizeResponse: @@ -45,7 +45,7 @@ async def create_tokenize(self, for message in request.messages: conversation.extend( - parse_chat_message_content(message).messages) + parse_chat_message_content(self, message).messages) request.prompt = self.tokenizer.apply_chat_template( add_generation_prompt=request.add_generation_prompt, From ee910ccd4b0656a72bcf9c3e3dc208370e60fa05 Mon Sep 17 00:00:00 2001 From: sasha0552 Date: Tue, 16 Jul 2024 08:29:44 +0000 Subject: [PATCH 6/6] fix tests --- tests/entrypoints/openai/test_tokenization.py | 24 +++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index ca3153c4bf35..d33fd222ee15 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -4,10 +4,34 @@ from vllm.transformers_utils.tokenizer import get_tokenizer +from ...utils import 
RemoteOpenAIServer + # any model with a chat template should work here MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta" +@pytest.fixture(scope="module") +def server(): + with RemoteOpenAIServer([ + "--model", + MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "bfloat16", + "--max-model-len", + "8192", + "--enforce-eager", + "--max-num-seqs", + "128", + ]) as remote_server: + yield remote_server + + +@pytest.fixture(scope="module") +def client(server): + return server.get_async_client() + + @pytest.mark.asyncio @pytest.mark.parametrize( "model_name",
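
A minimal usage sketch of the endpoints touched by this series, assuming a vLLM OpenAI-compatible server is already running locally (the base URL "http://localhost:8000" and the zephyr model name below are placeholders, not part of the patches; the request and response fields mirror TokenizeRequest, TokenizeResponse, and DetokenizeRequest in vllm/entrypoints/openai/protocol.py, and the /tokenize and /detokenize routes are mounted at the server root as in the tests above):

# Sketch only: assumes a local vLLM OpenAI-compatible server on port 8000
# serving "HuggingFaceH4/zephyr-7b-beta" (both values are assumptions).
import requests

BASE_URL = "http://localhost:8000"  # assumed host/port
MODEL = "HuggingFaceH4/zephyr-7b-beta"  # assumed served model name

# Tokenize a plain prompt (behavior that existed before this series).
resp = requests.post(BASE_URL + "/tokenize",
                     json={
                         "model": MODEL,
                         "prompt": "This is a test prompt.",
                         "add_special_tokens": False,
                     })
resp.raise_for_status()
print(resp.json())  # {"tokens": [...], "count": ..., "max_model_len": ...}

# Tokenize chat-style messages (added by this series): the server applies
# the model's chat template before encoding.
resp = requests.post(BASE_URL + "/tokenize",
                     json={
                         "model": MODEL,
                         "messages": [
                             {"role": "user", "content": "Hi there!"},
                             {"role": "assistant", "content": "Nice to meet you!"},
                             {"role": "user", "content": "Can I ask a question?"},
                         ],
                         "add_generation_prompt": True,
                         "add_special_tokens": False,
                     })
resp.raise_for_status()
tokens = resp.json()["tokens"]

# Round-trip the token ids through /detokenize.
resp = requests.post(BASE_URL + "/detokenize",
                     json={
                         "model": MODEL,
                         "tokens": tokens,
                     })
resp.raise_for_status()
print(resp.json()["prompt"])

Note that a request may supply either "prompt" or "messages" but not both; the serving code in serving_tokenization.py returns an error response if neither or both are given.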