Skip to content
59 changes: 34 additions & 25 deletions temporalio/contrib/openai_agents/_temporal_model_stub.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

logger = logging.getLogger(__name__)

from typing import Any, AsyncIterator, Sequence, Union, cast
from typing import Any, AsyncIterator, Union, cast

from agents import (
AgentOutputSchema,
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
async def get_response(
self,
system_instructions: Optional[str],
input: Union[str, list[TResponseInputItem], dict[str, str]],
input: Union[str, list[TResponseInputItem]],
model_settings: ModelSettings,
tools: list[Tool],
output_schema: Optional[AgentOutputSchemaBase],
Expand All @@ -64,28 +64,6 @@ async def get_response(
previous_response_id: Optional[str],
prompt: Optional[ResponsePromptParam],
) -> ModelResponse:
def get_summary(
input: Union[str, list[TResponseInputItem], dict[str, str]],
) -> str:
### Activity summary shown in the UI
try:
max_size = 100
if isinstance(input, str):
return input[:max_size]
elif isinstance(input, list):
seq_input = cast(Sequence[Any], input)
last_item = seq_input[-1]
if isinstance(last_item, dict):
return last_item.get("content", "")[:max_size]
elif hasattr(last_item, "content"):
return str(getattr(last_item, "content"))[:max_size]
return str(last_item)[:max_size]
elif isinstance(input, dict):
return input.get("content", "")[:max_size]
except Exception as e:
logger.error(f"Error getting summary: {e}")
return ""

def make_tool_info(tool: Tool) -> ToolInput:
if isinstance(tool, (FileSearchTool, WebSearchTool)):
return tool
Expand Down Expand Up @@ -150,7 +128,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
return await workflow.execute_activity_method(
ModelActivity.invoke_model_activity,
activity_input,
summary=self.model_params.summary_override or get_summary(input),
summary=self.model_params.summary_override or _extract_summary(input),
task_queue=self.model_params.task_queue,
schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
Expand All @@ -176,3 +154,34 @@ def stream_response(
prompt: ResponsePromptParam | None,
) -> AsyncIterator[TResponseStreamEvent]:
raise NotImplementedError("Temporal model doesn't support streams yet")


def _extract_summary(input: Union[str, list[TResponseInputItem]]) -> str:
### Activity summary shown in the UI
try:
max_size = 100
if isinstance(input, str):
return input[:max_size]
elif isinstance(input, list):
# Find all message inputs, which are reasonably summarizable
messages: list[TResponseInputItem] = [
item for item in input if item.get("type", "message") == "message"
]
if not messages:
return ""

content: Any = messages[-1].get("content", "")

# In the case of multiple contents, take the last one
if isinstance(content, list):
if not content:
return ""
content = content[-1]

# Take the text field from the content if present
if isinstance(content, dict) and content.get("text") is not None:
content = content.get("text")
return str(content)[:max_size]
except Exception as e:
logger.error(f"Error getting summary: {e}")
return ""
2 changes: 1 addition & 1 deletion temporalio/contrib/openai_agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
cancellation_type=cancellation_type,
activity_id=activity_id,
versioning_intent=versioning_intent,
summary=summary,
summary=summary or schema.description,
priority=priority,
)
try:
Expand Down
44 changes: 42 additions & 2 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,16 @@
)
from openai import APIStatusError, AsyncOpenAI, BaseModel
from openai.types.responses import (
EasyInputMessageParam,
ResponseFunctionToolCall,
ResponseFunctionToolCallParam,
ResponseFunctionWebSearch,
ResponseInputTextParam,
ResponseOutputMessage,
ResponseOutputText,
)
from openai.types.responses.response_function_web_search import ActionSearch
from openai.types.responses.response_input_item_param import Message
from openai.types.responses.response_prompt_param import ResponsePromptParam
from pydantic import ConfigDict, Field, TypeAdapter

Expand All @@ -63,6 +67,7 @@
TestModel,
TestModelProvider,
)
from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
from temporalio.contrib.pydantic import pydantic_data_converter
from temporalio.exceptions import ApplicationError, CancelledError
from temporalio.testing import WorkflowEnvironment
Expand Down Expand Up @@ -680,7 +685,8 @@ async def test_research_workflow(client: Client, use_local_model: bool):
new_config["plugins"] = [
openai_agents.OpenAIAgentsPlugin(
model_params=ModelActivityParameters(
start_to_close_timeout=timedelta(seconds=30)
start_to_close_timeout=timedelta(seconds=120),
schedule_to_close_timeout=timedelta(seconds=120),
),
model_provider=TestModelProvider(TestResearchModel())
if use_local_model
Expand Down Expand Up @@ -1687,7 +1693,7 @@ class WorkflowToolModel(StaticTestModel):
id="",
content=[
ResponseOutputText(
text="",
text="Workflow tool was used",
annotations=[],
type="output_text",
)
Expand Down Expand Up @@ -1938,3 +1944,37 @@ async def test_heartbeat(client: Client, env: WorkflowEnvironment):
execution_timeout=timedelta(seconds=5.0),
)
await workflow_handle.result()


def test_summary_extraction():
input: list[TResponseInputItem] = [
EasyInputMessageParam(
content="First message",
role="user",
)
]

assert _extract_summary(input) == "First message"

input.append(
Message(
content=[
ResponseInputTextParam(
text="Second message",
type="input_text",
)
],
role="user",
)
)
assert _extract_summary(input) == "Second message"

input.append(
ResponseFunctionToolCallParam(
arguments="",
call_id="",
name="",
type="function_call",
)
)
assert _extract_summary(input) == "Second message"