33from abc import ABC , abstractmethod
44from typing import TYPE_CHECKING , Optional
55
6- from vllm .entrypoints .harmony_utils import (parse_output_into_messages ,
7- render_for_completion )
6+ from openai_harmony import Role , StreamState
7+
8+ from vllm .entrypoints .harmony_utils import (
9+ get_encoding , get_streamable_parser_for_assistant ,
10+ parse_output_into_messages , render_for_completion )
811from vllm .outputs import RequestOutput
912
1013if TYPE_CHECKING :
@@ -51,7 +54,7 @@ def __init__(
5154 browser_tool ,
5255 python_tool ,
5356 ):
54- self .messages = messages
57+ self ._messages = messages
5558 self .browser_tool = browser_tool
5659 self .python_tool = python_tool
5760
@@ -60,16 +63,20 @@ def __init__(
6063 self .num_prompt_tokens = 0
6164 self .num_cached_tokens = 0
6265 self .num_output_tokens = 0
66+ self .num_reasoning_tokens = 0
6367
6468 def append_output (self , output ) -> None :
65- # TODO: Support streaming.
6669 if isinstance (output , RequestOutput ):
6770 output_token_ids = output .outputs [0 ].token_ids
6871 output_msgs = parse_output_into_messages (output_token_ids )
6972 else :
7073 # Tool output.
7174 output_msgs = output
72- self .messages .extend (output_msgs )
75+ self ._messages .extend (output_msgs )
76+
    @property
    def messages(self) -> list:
        """The accumulated harmony conversation messages (read-only view)."""
        return self._messages
7380
7481 def get_tool_call (self ) -> Optional ["Tool" ]:
7582 last_msg = self .messages [- 1 ]
@@ -83,3 +90,59 @@ def get_tool_call(self) -> Optional["Tool"]:
8390
    def render_for_completion(self) -> list[int]:
        """Render the current messages into prompt token ids for the
        next completion turn (delegates to the module-level helper)."""
        return render_for_completion(self.messages)
93+
94+
class StreamingHarmonyContext(HarmonyContext):
    """Harmony conversation context for streaming generation.

    Instead of parsing complete outputs in one shot, tokens are fed
    incrementally into a streamable harmony parser, so ``messages`` is
    always derived from the parser's current state.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_output = None

        # Incremental parser that consumes tokens one at a time.
        self.parser = get_streamable_parser_for_assistant()
        self.encoding = get_encoding()
        # Most recently processed token id; None until first output.
        self.last_tok = None

    @property
    def messages(self) -> list:
        """Messages reconstructed so far by the streaming parser."""
        return self.parser.messages

    def append_output(self, output) -> None:
        """Feed new output into the streaming parser.

        ``output`` is either a ``RequestOutput`` carrying newly generated
        token ids, or a single tool message that is rendered back into
        tokens before being processed.
        """
        if isinstance(output, RequestOutput):
            # Process every token in this chunk, not only the first:
            # a streaming delta may carry more than one new token, and
            # dropping any would desynchronize the parser.
            token_ids = output.outputs[0].token_ids
            for tok in token_ids:
                self.parser.process(tok)
            if token_ids:
                self.last_tok = token_ids[-1]
        else:
            # Handle the case of tool output in direct message format
            assert len(output) == 1, "Tool output should be a single message"
            msg = output[0]
            # Sometimes the recipient is not set for tool messages,
            # so we set it to "assistant"
            if msg.author.role == Role.TOOL and msg.recipient is None:
                msg.recipient = "assistant"
            toks = self.encoding.render(msg)
            for tok in toks:
                self.parser.process(tok)
            self.last_tok = toks[-1]

    def is_expecting_start(self) -> bool:
        """True when the parser is waiting for a new message start token."""
        return self.parser.state == StreamState.EXPECT_START

    def is_assistant_action_turn(self) -> bool:
        """True when the last processed token ended an assistant action turn."""
        stop_tokens = self.encoding.stop_tokens_for_assistant_actions()
        return self.last_tok in stop_tokens

    def render_for_completion(self) -> list[int]:
        """Render prompt token ids for the next turn and sync the parser.

        Rendering appends next-turn header tokens (e.g.
        ``<|start|>assistant``) after the last token we already processed;
        those trailing tokens must also be fed to the parser so its state
        matches the rendered prompt.
        """
        rendered_tokens = super().render_for_completion()

        # Walk backwards to collect the tokens appended after last_tok.
        # NOTE(review): assumes last_tok is present in the rendering;
        # if it is not (e.g. no output appended yet), this scan would
        # run off the front of the list — confirm callers uphold this.
        last_n = -1
        to_process = []
        while rendered_tokens[last_n] != self.last_tok:
            to_process.append(rendered_tokens[last_n])
            last_n -= 1
        # Replay in original (forward) order to keep parser state consistent.
        for tok in reversed(to_process):
            self.parser.process(tok)

        return rendered_tokens
0 commit comments