
Commit 2398646

fgreinacher committed (co-authored by LiuXiaoxuanPKU and DarkLight1337)
feat: support complex message content for chat completions endpoint
Co-authored-by: Lily Liu <[email protected]>
Co-authored-by: Cyrus Leung <[email protected]>
1 parent df29793 commit 2398646
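
In practice, this change lets clients send OpenAI-style content parts instead of a bare string in a message's "content" field. Below is a minimal sketch of such a request using the openai Python client; the base_url, api_key, and model name are placeholders for illustration, not part of this commit.

# Minimal sketch: message "content" as a list of typed parts rather than
# a plain string. base_url, api_key, and model name are placeholders.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="my-model",  # placeholder; use the served model name
    messages=[{
        "role": "user",
        "content": [
            # The server now accepts "text" parts and joins them with
            # newlines; other part types still raise NotImplementedError.
            {"type": "text", "text": "what is 1+1?"},
            {"type": "text", "text": "please answer with just the number."},
        ],
    }],
    temperature=0,
)
print(resp.choices[0].message.content)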

File tree

2 files changed: +46 -21 lines


tests/entrypoints/test_openai_server.py

Lines changed: 19 additions & 0 deletions

@@ -786,6 +786,25 @@ async def test_extra_fields(server, client: openai.AsyncOpenAI):
     assert "extra_forbidden" in exc_info.value.message
 
 
+async def test_complex_message_content(server, client: openai.AsyncOpenAI):
+    resp = await client.chat.completions.create(
+        model=MODEL_NAME,
+        messages=[{
+            "role":
+            "user",
+            "content": [{
+                "type":
+                "text",
+                "text":
+                "what is 1+1? please provide the result without any other text."
+            }]
+        }],
+        temperature=0,
+        seed=0)
+    content = resp.choices[0].message.content
+    assert content == "2"
+
+
 async def test_guided_grammar(server, client: openai.AsyncOpenAI):
     simple_sql_grammar = """
 start: select_statement

vllm/entrypoints/openai/serving_chat.py

Lines changed: 27 additions & 21 deletions

@@ -55,9 +55,16 @@ def _parse_chat_message_content(
         if isinstance(content, str):
             return [ConversationMessage(role=role, content=content)], []
 
-        # To be implemented: https://github.com/vllm-project/vllm/pull/3467
-        # To be implemented: https://github.com/vllm-project/vllm/pull/4200
-        raise NotImplementedError("Complex input not supported yet")
+        texts: List[str] = []
+        for _, part in enumerate(content):
+            if part["type"] == "text":
+                text = part["text"]
+
+                texts.append(text)
+            else:
+                raise NotImplementedError(f"Unknown part type: {part['type']}")
+
+        return [ConversationMessage(role=role, content="\n".join(texts))], []
 
     async def create_chat_completion(
         self, request: ChatCompletionRequest, raw_request: Request
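
Taken on its own, the new parsing logic flattens a list of text parts into one newline-joined string. Here is a standalone sketch of that behavior; parse_content is a hypothetical helper name, and the real code lives in _parse_chat_message_content and returns ConversationMessage objects rather than plain dicts.

from typing import Dict, List, Union

def parse_content(role: str,
                  content: Union[str, List[Dict[str, str]]]) -> Dict[str, str]:
    # Plain strings pass through unchanged.
    if isinstance(content, str):
        return {"role": role, "content": content}
    texts: List[str] = []
    for part in content:
        if part["type"] == "text":
            texts.append(part["text"])
        else:
            # Non-text parts (e.g. images) are still unsupported here.
            raise NotImplementedError(f"Unknown part type: {part['type']}")
    # Multiple text parts collapse into a single newline-joined string.
    return {"role": role, "content": "\n".join(texts)}

result = parse_content("user", [{"type": "text", "text": "a"},
                                {"type": "text", "text": "b"}])
assert result == {"role": "user", "content": "a\nb"}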
@@ -122,11 +129,12 @@ async def create_chat_completion(
         # Streaming response
         if request.stream:
             return self.chat_completion_stream_generator(
-                request, result_generator, request_id)
+                request, result_generator, request_id, conversation)
         else:
             try:
                 return await self.chat_completion_full_generator(
-                    request, raw_request, result_generator, request_id)
+                    request, raw_request, result_generator, request_id,
+                    conversation)
             except ValueError as e:
                 # TODO: Use a vllm-specific Validation Error
                 return self.create_error_response(str(e))
@@ -139,8 +147,9 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str:
 
     async def chat_completion_stream_generator(
             self, request: ChatCompletionRequest,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> AsyncGenerator[str, None]:
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> AsyncGenerator[str, None]:
         model_name = self.served_model_names[0]
         created_time = int(time.time())
         chunk_object_type = "chat.completion.chunk"

@@ -179,12 +188,10 @@ async def chat_completion_stream_generator(
             # last message
             if request.echo:
                 last_msg_content = ""
-                if request.messages and isinstance(
-                        request.messages,
-                        list) and request.messages[-1].get(
-                            "content") and request.messages[-1].get(
-                                "role") == role:
-                    last_msg_content = request.messages[-1]["content"]
+                if conversation and conversation[-1].get(
+                        "content") and conversation[-1].get(
+                            "role") == role:
+                    last_msg_content = conversation[-1]["content"]
 
                 if last_msg_content:
                     for i in range(request.n):
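
The echo paths in both generators now read the last message from the parsed conversation rather than from request.messages: after this commit the raw message content may be a list of parts, whereas a parsed ConversationMessage always carries a flattened string that is safe to prepend to the generated text. A hypothetical illustration with placeholder values:

# Hypothetical values illustrating why echo reads the parsed form.
raw_message = {
    "role": "user",
    "content": [{"type": "text", "text": "what is 1+1?"}],  # list, not str
}
parsed_message = {"role": "user", "content": "what is 1+1?"}  # flattened

# Echoing from the raw request could try to concatenate a list with the
# generated string; the parsed conversation is always a plain str.
assert isinstance(raw_message["content"], list)
assert isinstance(parsed_message["content"] + "2", str)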
@@ -279,9 +286,10 @@ async def chat_completion_stream_generator(
             yield "data: [DONE]\n\n"
 
     async def chat_completion_full_generator(
-            self, request: ChatCompletionRequest, raw_request: Request,
-            result_generator: AsyncIterator[RequestOutput],
-            request_id: str) -> Union[ErrorResponse, ChatCompletionResponse]:
+            self, request: ChatCompletionRequest, raw_request: Request,
+            result_generator: AsyncIterator[RequestOutput], request_id: str,
+            conversation: List[ConversationMessage]
+    ) -> Union[ErrorResponse, ChatCompletionResponse]:
 
         model_name = self.served_model_names[0]
         created_time = int(time.time())

@@ -322,11 +330,9 @@ async def chat_completion_full_generator(
 
         if request.echo:
             last_msg_content = ""
-            if request.messages and isinstance(
-                    request.messages, list) and request.messages[-1].get(
-                        "content") and request.messages[-1].get(
-                            "role") == role:
-                last_msg_content = request.messages[-1]["content"]
+            if conversation and conversation[-1].get(
+                    "content") and conversation[-1].get("role") == role:
+                last_msg_content = conversation[-1]["content"]
 
         for choice in choices:
             full_message = last_msg_content + choice.message.content
