
Commit b775a39

Support Responses Streaming (vllm-project#21)
1 parent 076cfce commit b775a39

File tree: 6 files changed, +635 -36 lines changed

responses_api.py

Lines changed: 48 additions & 11 deletions
@@ -1,7 +1,12 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """
+<<<<<<< HEAD
 vllm serve /data/woosuk/os-mini-weights/pytorch-rc-20b --enforce-eager
+=======
+vllm serve /data/woosuk/os-mini-weights/pytorch-rc-20b \
+    --tokenizer /data/xmo/os-mini/models/hf-converted --enforce-eager
+>>>>>>> 4a52d33d8 (streaming support)
 """
 import argparse
 import json
@@ -230,16 +235,48 @@ def test_stateful_multi_turn():
 
 
 def test_streaming():
-    response = client.responses.create(
-        model=MODEL,
-        input="What is 13 * 24? Explain your answer.",
-        stream=True,
-    )
+    prompts = [
+        "tell me a story about a cat in 20 words",
+        "What is 13 * 24? Use python to calculate the result.",
+        "When did Jensen found NVIDIA? Search it and answer the year only."
+    ]
+    for prompt in prompts:
+        print(f"\n{prompt}\n")
+        response = client.responses.create(
+            model=MODEL,
+            input=prompt,
+            reasoning={"effort": "low"},
+            tools=[{
+                "type": "web_search_preview"
+            }, {
+                "type": "code_interpreter",
+                "container": {
+                    "type": "auto"
+                }
+            }],
+            stream=True,
+        )
+
+        events = []
+        current_event_mode = None
+
+        for event in response:
+            if current_event_mode != event.type:
+                current_event_mode = event.type
+                print(f"\n[{event.type}] ", end="", flush=True)
+
+            if "text.delta" in event.type:
+                print(event.delta, end="", flush=True)
+            elif "reasoning_text.delta" in event.type:
+                print(f"{event.delta}", end="", flush=True)
+            elif "response.code_interpreter_call_code.done" in event.type:
+                print(f"Code: {event.code}", end="", flush=True)
+            elif ("response.output_item.added" in event.type
+                  and event.item.type == "web_search_call"):
+                print(f"Web search: {event.item.action}", end="", flush=True)
+            events.append(event)
 
-    for event in response:
-        if "text.delta" in event.type:
-            print(event.delta, end="", flush=True)
-    print()
+        print("\n--------------------------------\n")
 
 
 def test_web_search():
@@ -600,8 +637,8 @@ def test_function_calling_full_history():
     test_stateful_multi_turn()
 
     # 3. Streaming tests:
-    # print("===test_streaming:")
-    # test_streaming()
+    print("===test_streaming:")
+    test_streaming()
 
     # 4. Tool tests:
     print("===test_web_search:")

vllm/entrypoints/context.py

Lines changed: 68 additions & 5 deletions
@@ -3,8 +3,11 @@
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Optional
 
-from vllm.entrypoints.harmony_utils import (parse_output_into_messages,
-                                            render_for_completion)
+from openai_harmony import Role, StreamState
+
+from vllm.entrypoints.harmony_utils import (
+    get_encoding, get_streamable_parser_for_assistant,
+    parse_output_into_messages, render_for_completion)
 from vllm.outputs import RequestOutput
 
 if TYPE_CHECKING:
@@ -51,7 +54,7 @@ def __init__(
         browser_tool,
         python_tool,
     ):
-        self.messages = messages
+        self._messages = messages
         self.browser_tool = browser_tool
         self.python_tool = python_tool
 
@@ -60,16 +63,20 @@ def __init__(
         self.num_prompt_tokens = 0
         self.num_cached_tokens = 0
         self.num_output_tokens = 0
+        self.num_reasoning_tokens = 0
 
     def append_output(self, output) -> None:
-        # TODO: Support streaming.
         if isinstance(output, RequestOutput):
             output_token_ids = output.outputs[0].token_ids
             output_msgs = parse_output_into_messages(output_token_ids)
         else:
             # Tool output.
             output_msgs = output
-        self.messages.extend(output_msgs)
+        self._messages.extend(output_msgs)
+
+    @property
+    def messages(self) -> list:
+        return self._messages
 
     def get_tool_call(self) -> Optional["Tool"]:
         last_msg = self.messages[-1]
@@ -83,3 +90,59 @@ def get_tool_call(self) -> Optional["Tool"]:
 
     def render_for_completion(self) -> list[int]:
         return render_for_completion(self.messages)
+
+
+class StreamingHarmonyContext(HarmonyContext):
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.last_output = None
+
+        self.parser = get_streamable_parser_for_assistant()
+        self.encoding = get_encoding()
+        self.last_tok = None
+
+    @property
+    def messages(self) -> list:
+        return self.parser.messages
+
+    def append_output(self, output) -> None:
+        if isinstance(output, RequestOutput):
+            tok = output.outputs[0].token_ids[0]
+            self.parser.process(tok)
+            self.last_tok = tok
+        else:
+            # Handle the case of tool output in direct message format.
+            assert len(output) == 1, "Tool output should be a single message"
+            msg = output[0]
+            # Sometimes the recipient is not set for tool messages,
+            # so we set it to "assistant".
+            if msg.author.role == Role.TOOL and msg.recipient is None:
+                msg.recipient = "assistant"
+            toks = self.encoding.render(msg)
+            for tok in toks:
+                self.parser.process(tok)
+            self.last_tok = toks[-1]
+
+    def is_expecting_start(self) -> bool:
+        return self.parser.state == StreamState.EXPECT_START
+
+    def is_assistant_action_turn(self) -> bool:
+        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions(
+        )
+
+    def render_for_completion(self) -> list[int]:
+        # Rendering appends the next turn's starting tokens
+        # (e.g. `<|start|>assistant`), so feed them through the
+        # parser as well to keep its state in sync.
+        rendered_tokens = super().render_for_completion()
+
+        last_n = -1
+        to_process = []
+        while rendered_tokens[last_n] != self.last_tok:
+            to_process.append(rendered_tokens[last_n])
+            last_n -= 1
+        for tok in reversed(to_process):
+            self.parser.process(tok)
+
+        return rendered_tokens
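
The subtle part of `StreamingHarmonyContext` is `render_for_completion`: rendering appends the next turn's header tokens (e.g. `<|start|>assistant`) after the last token the parser has already consumed, and that suffix must be replayed through the parser so its state stays in sync. The toy sketch below illustrates that replay logic with simplified stand-in types; it is not the `openai_harmony` API.

```python
# Self-contained toy: replay the rendered suffix that the parser has not seen.


class ToyParser:
    def __init__(self):
        self.seen: list[str] = []

    def process(self, tok: str) -> None:
        self.seen.append(tok)


def replay_tail(rendered: list[str], last_tok: str, parser: ToyParser) -> None:
    # Walk backwards until we hit the last token the parser already consumed,
    # then feed the remaining suffix forward, in order.
    idx = -1
    tail = []
    while rendered[idx] != last_tok:
        tail.append(rendered[idx])
        idx -= 1
    for tok in reversed(tail):
        parser.process(tok)


parser = ToyParser()
for tok in ["<|start|>", "user", "hello", "<|end|>"]:
    parser.process(tok)

# Renderer output for the next turn: everything so far plus the new header.
rendered = ["<|start|>", "user", "hello", "<|end|>", "<|start|>", "assistant"]
replay_tail(rendered, last_tok="<|end|>", parser=parser)
assert parser.seen == rendered
```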

vllm/entrypoints/harmony_utils.py

Lines changed: 6 additions & 5 deletions
@@ -209,15 +209,16 @@ def parse_output_message(message: Message):
             raise ValueError("Invalid number of contents in browser message")
         content = message.content[0]
         browser_call = json.loads(content.text)
+        # TODO: translate to url properly!
         if recipient == "browser.search":
-            action = ActionSearch(query=browser_call["query"], type="search")
+            action = ActionSearch(
+                query=f"cursor:{browser_call.get('query', '')}", type="search")
         elif recipient == "browser.open":
-            url = ""  # FIXME: browser_call["url"]
-            action = ActionOpenPage(url=url, type="open_page")
+            action = ActionOpenPage(
+                url=f"cursor:{browser_call.get('url', '')}", type="open_page")
         elif recipient == "browser.find":
-            url = ""  # FIXME: browser_call["url"]
             action = ActionFind(pattern=browser_call["pattern"],
-                                url=url,
+                                url=f"cursor:{browser_call.get('url', '')}",
                                 type="find")
        else:
             raise ValueError(f"Unknown browser action: {recipient}")

vllm/entrypoints/openai/protocol.py

Lines changed: 56 additions & 1 deletion
@@ -18,6 +18,7 @@
                                     Annotation as OpenAIAnnotation)
 # yapf: enable
 from openai.types.responses import (ResponseFunctionToolCall,
+                                    ResponseFunctionToolCallOutputItem,
                                     ResponseInputItemParam, ResponseOutputItem,
                                     ResponsePrompt, ResponseStatus,
                                     ResponseTextConfig)
@@ -1731,6 +1732,60 @@ class ResponseReasoningItem(OpenAIBaseModel):
     status: Optional[Literal["in_progress", "completed", "incomplete"]]
 
 
+class InputTokensDetails(OpenAIBaseModel):
+    cached_tokens: int
+
+
+class OutputTokensDetails(OpenAIBaseModel):
+    reasoning_tokens: int
+
+
+class ResponseUsage(OpenAIBaseModel):
+    input_tokens: int
+    input_tokens_details: InputTokensDetails
+    output_tokens: int
+    output_tokens_details: OutputTokensDetails
+    total_tokens: int
+
+
+class ResponseReasoningTextDeltaEvent(OpenAIBaseModel):
+    type: Literal[
+        "response.reasoning_text.delta"] = "response.reasoning_text.delta"
+    item_id: str = "item_1234"
+    output_index: int
+    content_index: int
+    delta: str
+    sequence_number: int = -1
+
+
+class ResponseReasoningTextDoneEvent(OpenAIBaseModel):
+    type: Literal[
+        "response.reasoning_text.done"] = "response.reasoning_text.done"
+    item_id: str = "item_1234"
+    output_index: int
+    content_index: int
+    text: str
+    sequence_number: int = -1
+
+
+class ResponseContentPartDoneEvent(OpenAIBaseModel):
+    type: Literal["response.content_part.done"] = "response.content_part.done"
+    item_id: str = "item_1234"
+    output_index: int
+    content_index: int
+    part: Union[ResponseOutputItem, ResponseReasoningItem]
+    sequence_number: int = -1
+
+
+class ResponseOutputItemDoneEvent(OpenAIBaseModel):
+    type: Literal["response.output_item.done"] = "response.output_item.done"
+    item_id: str = "item_1234"
+    output_index: int
+    item: Union[ResponseOutputItem, ResponseReasoningItem,
+                ResponseFunctionToolCallOutputItem]
+    sequence_number: int = -1
+
+
 class ResponsesResponse(OpenAIBaseModel):
     id: str = Field(default_factory=lambda: f"resp_{random_uuid()}")
     created_at: int = Field(default_factory=lambda: int(time.time()))
@@ -1757,7 +1812,7 @@ class ResponsesResponse(OpenAIBaseModel):
     text: Optional[ResponseTextConfig] = None
     top_logprobs: int
     truncation: Literal["auto", "disabled"]
-    usage: Optional[UsageInfo] = None
+    usage: Optional[ResponseUsage] = None
     user: Optional[str] = None
 
     @classmethod
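
As a rough picture of what these new payloads look like on the wire, the sketch below uses simplified pydantic stand-ins mirroring the fields added above; it does not import vLLM's actual protocol module, and the SSE framing shown is an assumption rather than something this diff specifies.

```python
from typing import Literal

from pydantic import BaseModel


# Simplified stand-in for the model added in protocol.py above.
class ResponseReasoningTextDeltaEvent(BaseModel):
    type: Literal[
        "response.reasoning_text.delta"] = "response.reasoning_text.delta"
    item_id: str = "item_1234"
    output_index: int
    content_index: int
    delta: str
    sequence_number: int = -1


event = ResponseReasoningTextDeltaEvent(output_index=0,
                                        content_index=0,
                                        delta="13 * 24 = ",
                                        sequence_number=7)
# One plausible server-sent-event framing of the payload:
print(f"event: {event.type}\ndata: {event.model_dump_json()}\n")
```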

vllm/entrypoints/openai/serving_engine.py

Lines changed: 3 additions & 0 deletions
@@ -990,6 +990,9 @@ async def _generate_with_builtin_tools(
         tool_output = await tool.get_result(context)
         context.append_output(tool_output)
 
+        # TODO: uncomment this and enable tool output streaming
+        # yield context
+
         # Create inputs for the next turn.
         # Render the next prompt token ids.
         prompt_token_ids = context.render_for_completion()
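
The commented-out `yield context` marks where tool turns would also be streamed. The surrounding loop is an async generator that yields the shared context after each model turn so the serving layer can translate it into events; the toy below sketches that turn-level pattern with made-up names and data, and is not vLLM's actual `_generate_with_builtin_tools`.

```python
import asyncio


async def generate_with_tools(turns):
    # Hypothetical stand-in for the engine loop: yield a snapshot of the
    # context after each model turn; the TODO above would extend the same
    # pattern to tool-output turns.
    context = []
    for kind, output in turns:
        context.append(output)
        if kind == "model":
            yield list(context)
        # elif kind == "tool":
        #     yield list(context)  # tool output streaming, once enabled


async def main():
    turns = [("model", "call python"), ("tool", "312"), ("model", "13*24=312")]
    async for snapshot in generate_with_tools(turns):
        print("streamed snapshot:", snapshot)


asyncio.run(main())
```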
