Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 84 additions & 41 deletions dspy/streaming/streaming_listener.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from queue import Queue
from typing import TYPE_CHECKING, Any

import jiter
from litellm import ModelResponseStream

from dspy.adapters.chat_adapter import ChatAdapter
Expand Down Expand Up @@ -49,6 +50,11 @@ def __init__(
self.cache_hit = False
self.allow_reuse = allow_reuse

self.json_adapter_state = {
"field_accumulated_tokens": "",
"curly_bracket_diff": 0,
}

self.adapter_identifiers = {
"ChatAdapter": {
"start_identifier": f"[[ ## {self.signature_field_name} ## ]]",
Expand Down Expand Up @@ -91,6 +97,8 @@ def receive(self, chunk: ModelResponseStream):
self.cache_hit = False
self.field_start_queue = []
self.field_end_queue = Queue()
self.json_adapter_state["field_accumulated_tokens"] = ""
self.json_adapter_state["curly_bracket_diff"] = 0
self.stream_start = False
else:
return
Expand All @@ -112,7 +120,7 @@ def receive(self, chunk: ModelResponseStream):
is_last_chunk=self.stream_end,
)

if chunk_message and start_identifier in chunk_message:
if chunk_message and start_identifier in chunk_message and not isinstance(settings.adapter, JSONAdapter):
# If the cache is hit, the chunk_message could be the full response. When it happens we can
# directly end the stream listening. In some models like gemini, each stream chunk can be multiple
# tokens, so it's possible that response only has one chunk, we also fall back to this logic.
Expand Down Expand Up @@ -145,10 +153,16 @@ def receive(self, chunk: ModelResponseStream):
# Keep the part after the start_identifier from the concat_message, we need to write it to the buffer.
value_start_index = concat_message.find(start_identifier) + len(start_identifier)
chunk_message = concat_message[value_start_index:].lstrip()
if isinstance(settings.adapter, JSONAdapter) and chunk_message.startswith('"'):
# For JSONAdapter, we need to remove the leading ". We cannot do this with the start_identifier
# because there could be a few splitters between ':' and '"', e.g., '"name": "value"'.
chunk_message = chunk_message[1:]

if isinstance(settings.adapter, JSONAdapter):
# For JSONAdapter, we rely on partial json parsing to detect the end of the field we are listening
# to, so we need to maintain a few extra states to help us with that.
# 1. We add an extra "{" to the beginning of the field_accumulated_tokens, so we can detect the
# appearance of the next key.
# 2. We maintain a curly_bracket_diff to help us detect the balance of curly brackets, so we can
# detect when the streaming for the entire dspy.Predict is finished.
self.json_adapter_state["field_accumulated_tokens"] += "{" + start_identifier
self.json_adapter_state["curly_bracket_diff"] = 1

elif self._buffered_message_end_with_start_identifier(concat_message.strip(), start_identifier):
# If the buffered message ends with part of the start_identifier, we keep looking for the
Expand All @@ -161,28 +175,79 @@ def receive(self, chunk: ModelResponseStream):

if self.stream_start:
# The stream is started, we keep returning the token until we see the start of the next field.
token = None
self.field_end_queue.put(chunk_message)

token = None
if self.field_end_queue.qsize() > 10:
# We keep the last 10 tokens in the buffer to check if they form a valid identifier for end_identifier,
# i.e., "[[ ## {next_field_name} ## ]]" for ChatAdapter to identify the end of the current field.
# In most cases 10 tokens are enough to cover the end_identifier for all adapters.
token = self.field_end_queue.get()
concat_message = "".join(self.field_end_queue.queue).strip()
if re.search(end_identifier, concat_message):
# The next field is identified, we can end the stream and flush out all tokens in the buffer.

if isinstance(settings.adapter, JSONAdapter):
return self._json_adapter_handle_stream_chunk(token, chunk_message)
else:
return self._default_handle_stream_chunk(token, end_identifier)

def _json_adapter_handle_stream_chunk(self, token: str | None, chunk_message: str) -> "StreamResponse | None":
    """Handle one streamed chunk when the active adapter is a JSONAdapter.

    Accumulates the chunk into ``self.json_adapter_state`` and uses partial
    JSON parsing to detect either the end of the listened field (a second key
    appears) or the end of the whole prediction (curly brackets balance out).

    Args:
        token: Token popped from the bounded field-end buffer, or ``None``
            when the buffer has not filled up yet.
        chunk_message: Raw text of the incoming stream chunk.

    Returns:
        A ``StreamResponse`` carrying the tokens to surface, or ``None`` when
        there is nothing to emit for this chunk.
    """
    self.json_adapter_state["field_accumulated_tokens"] += chunk_message
    self.json_adapter_state["curly_bracket_diff"] += chunk_message.count("{") - chunk_message.count("}")

    if self.json_adapter_state["curly_bracket_diff"] == 0:
        # An extra "{" was prepended to field_accumulated_tokens when the field
        # started, so a balanced bracket count means the streaming for the
        # entire dspy.Predict is finished.
        self.stream_end = True
        last_token = self.flush()
        right_curly_bracket_index = last_token.rfind("}")
        token = token + last_token[:right_curly_bracket_index] if token else last_token[:right_curly_bracket_index]

    try:
        parsed = jiter.from_json(
            self.json_adapter_state["field_accumulated_tokens"].encode("utf-8"),
            partial_mode="trailing-strings",
        )
        if len(parsed) > 1:
            # Partial parsing surfaced a second key, so the field we are
            # listening to is finished; flush the buffered tokens.
            self.stream_end = True
            last_token = self.flush()

            next_field_name = None
            for key in parsed:
                if key != self.signature_field_name:
                    next_field_name = key
                    break

            # Trim the part of the flushed tail that already belongs to the
            # next field. NOTE(review): str.find returns -1 when the key is
            # absent, which would drop the final character — confirm the
            # flushed tail always contains the next key.
            last_token_index = last_token.find(next_field_name)
            token = token + last_token[:last_token_index] if token else last_token[:last_token_index]
    except ValueError:
        # Accumulated tokens are not parseable JSON yet; keep streaming.
        pass

    if token:
        return StreamResponse(
            self.predict_name,
            self.signature_field_name,
            token,
            is_last_chunk=self.stream_end,
        )

def _default_handle_stream_chunk(self, token: str | None, end_identifier: str) -> "StreamResponse | None":
    """Handle one streamed chunk for non-JSON adapters (Chat/XML style).

    Scans the buffered tokens for ``end_identifier`` (the start marker of the
    next field) to decide when the listened field has finished streaming.

    Args:
        token: Token popped from the bounded field-end buffer, or ``None``
            when the buffer has not filled up yet.
        end_identifier: Regex pattern that marks the start of the next field.

    Returns:
        A ``StreamResponse`` carrying the tokens to surface, or ``None`` when
        there is nothing to emit for this chunk.
    """
    concat_message = "".join(self.field_end_queue.queue).strip()

    if re.search(end_identifier, concat_message):
        # The next field is identified: end the stream and flush out all
        # tokens remaining in the buffer.
        self.stream_end = True
        last_token = self.flush()
        token = token + last_token if token else last_token
        token = token.rstrip()  # Remove the trailing \n\n

    if token:
        return StreamResponse(
            self.predict_name,
            self.signature_field_name,
            token,
            is_last_chunk=self.stream_end,
        )

def flush(self) -> str:
"""Flush all tokens in the field end queue.
Expand All @@ -194,12 +259,7 @@ def flush(self) -> str:
last_tokens = "".join(self.field_end_queue.queue)
self.field_end_queue = Queue()
if isinstance(settings.adapter, JSONAdapter):
match = re.search(r'",|"\s*}', last_tokens)
if match:
boundary_index = match.start()
else:
boundary_index = len(last_tokens)
return last_tokens[:boundary_index]
return last_tokens
elif isinstance(settings.adapter, XMLAdapter):
boundary_index = last_tokens.find(f"</{self.signature_field_name}>")
if boundary_index == -1:
Expand All @@ -222,7 +282,6 @@ def _output_type(self) -> type | None:
return None



def find_predictor_for_stream_listeners(program: "Module", stream_listeners: list[StreamListener]):
"""Find the predictor for each stream listener.

Expand All @@ -248,13 +307,6 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis
f"Signature field {field_name} is not unique in the program, cannot automatically determine which "
"predictor to use for streaming. Please specify the predictor to listen to."
)

if not _is_streamable(field_info.annotation):
raise ValueError(
f"Stream listener can only be applied to string or subclass of `dspy.Type` that has `is_streamable() == True`, "
f"but your field {field_name} is of type {field_info.annotation}."
)

field_name_to_named_predictor[field_name] = (name, predictor)

predict_id_to_listener = defaultdict(list)
Expand All @@ -271,12 +323,3 @@ def find_predictor_for_stream_listeners(program: "Module", stream_listeners: lis
listener.predict_name, listener.predict = field_name_to_named_predictor[listener.signature_field_name]
predict_id_to_listener[id(listener.predict)].append(listener)
return predict_id_to_listener

def _is_streamable(field_type: type | None) -> bool:
if field_type is None:
return False
if field_type is str:
return True
if issubclass(field_type, Type):
return field_type.is_streamable()
return False
Loading