Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 40 additions & 5 deletions pydantic_ai_slim/pydantic_ai/_parts_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@

from collections.abc import Hashable
from dataclasses import dataclass, field, replace
from typing import Any, Union
from typing import Any, Literal, Union, overload

from pydantic_ai._thinking_part import END_THINK_TAG, START_THINK_TAG
from pydantic_ai.exceptions import UnexpectedModelBehavior
from pydantic_ai.messages import (
ModelResponsePart,
Expand Down Expand Up @@ -66,12 +67,30 @@ def get_parts(self) -> list[ModelResponsePart]:
"""
return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)]

@overload
def handle_text_delta(
self,
*,
vendor_part_id: Hashable | None,
vendor_part_id: VendorId | None,
content: str,
) -> ModelResponseStreamEvent:
) -> ModelResponseStreamEvent: ...

@overload
def handle_text_delta(
self,
*,
vendor_part_id: VendorId,
content: str,
extract_think_tags: Literal[True],
) -> ModelResponseStreamEvent | None: ...

def handle_text_delta(
self,
*,
vendor_part_id: VendorId | None,
content: str,
extract_think_tags: bool = False,
) -> ModelResponseStreamEvent | None:
"""Handle incoming text content, creating or updating a TextPart in the manager as appropriate.

When `vendor_part_id` is None, the latest part is updated if it exists and is a TextPart;
Expand All @@ -83,6 +102,7 @@ def handle_text_delta(
of text. If None, a new part will be created unless the latest part is already
a TextPart.
content: The text content to append to the appropriate TextPart.
extract_think_tags: Whether to extract `<think>` tags from the text content and handle them as thinking parts.

Returns:
A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated.
Expand All @@ -104,9 +124,24 @@ def handle_text_delta(
part_index = self._vendor_id_to_part_index.get(vendor_part_id)
if part_index is not None:
existing_part = self._parts[part_index]
if not isinstance(existing_part, TextPart):

if extract_think_tags and isinstance(existing_part, ThinkingPart):
# We may be building a thinking part instead of a text part if we had previously seen a `<think>` tag
if content == END_THINK_TAG:
# When we see `</think>`, we're done with the thinking part and the next text delta will need a new part
self._vendor_id_to_part_index.pop(vendor_part_id)
return None
else:
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content=content)
elif isinstance(existing_part, TextPart):
existing_text_part_and_index = existing_part, part_index
else:
raise UnexpectedModelBehavior(f'Cannot apply a text delta to {existing_part=}')
existing_text_part_and_index = existing_part, part_index

if extract_think_tags and content == START_THINK_TAG:
# When we see a `<think>` tag (which is a single token), we'll build a new thinking part instead
self._vendor_id_to_part_index.pop(vendor_part_id, None)
return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='')

if existing_text_part_and_index is None:
# There is no existing text part that should be updated, so create a new one
Expand Down
8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# Handle the text part of the response
content = choice.delta.content
if content is not None:
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
maybe_event = self._parts_manager.handle_text_delta(
vendor_part_id='content', content=content, extract_think_tags=True
)
if maybe_event is not None: # pragma: no branch
yield maybe_event

# Handle the tool calls
for dtc in choice.delta.tool_calls or []:
Expand Down Expand Up @@ -444,7 +448,7 @@ def _map_usage(completion: chat.ChatCompletionChunk | chat.ChatCompletion) -> us
if isinstance(completion, chat.ChatCompletion):
response_usage = completion.usage
elif completion.x_groq is not None:
response_usage = completion.x_groq.usage # pragma: no cover
response_usage = completion.x_groq.usage

if response_usage is None:
return usage.Usage()
Expand Down
8 changes: 6 additions & 2 deletions pydantic_ai_slim/pydantic_ai/models/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:

# Handle the text part of the response
content = choice.delta.content
if content is not None:
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
if content:
maybe_event = self._parts_manager.handle_text_delta(
vendor_part_id='content', content=content, extract_think_tags=True
)
if maybe_event is not None: # pragma: no branch
yield maybe_event

for dtc in choice.delta.tool_calls or []:
maybe_event = self._parts_manager.handle_tool_call_delta(
Expand Down
6 changes: 5 additions & 1 deletion pydantic_ai_slim/pydantic_ai/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -1004,7 +1004,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
# Handle the text part of the response
content = choice.delta.content
if content:
yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content)
maybe_event = self._parts_manager.handle_text_delta(
vendor_part_id='content', content=content, extract_think_tags=True
)
if maybe_event is not None: # pragma: no branch
yield maybe_event

# Handle reasoning part of the response, present in DeepSeek models
if reasoning_content := getattr(choice.delta, 'reasoning_content', None):
Expand Down
2 changes: 1 addition & 1 deletion pydantic_ai_slim/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ anthropic = ["anthropic>=0.52.0"]
groq = ["groq>=0.19.0"]
mistral = ["mistralai>=1.9.2"]
bedrock = ["boto3>=1.37.24"]
huggingface = ["huggingface-hub[inference]>=0.33.2"]
huggingface = ["huggingface-hub[inference]>=0.33.5"]
# Tools
duckduckgo = ["ddgs>=9.0.0"]
tavily = ["tavily-python>=0.5.0"]
Expand Down
Loading