diff --git a/temporalio/contrib/openai_agents/_temporal_model_stub.py b/temporalio/contrib/openai_agents/_temporal_model_stub.py
index b053c259b..e4b7664ca 100644
--- a/temporalio/contrib/openai_agents/_temporal_model_stub.py
+++ b/temporalio/contrib/openai_agents/_temporal_model_stub.py
@@ -8,7 +8,7 @@
 
 logger = logging.getLogger(__name__)
 
-from typing import Any, AsyncIterator, Sequence, Union, cast
+from typing import Any, AsyncIterator, Union, cast
 
 from agents import (
     AgentOutputSchema,
@@ -54,7 +54,7 @@ def __init__(
     async def get_response(
         self,
         system_instructions: Optional[str],
-        input: Union[str, list[TResponseInputItem], dict[str, str]],
+        input: Union[str, list[TResponseInputItem]],
         model_settings: ModelSettings,
         tools: list[Tool],
         output_schema: Optional[AgentOutputSchemaBase],
@@ -64,28 +64,6 @@ async def get_response(
         previous_response_id: Optional[str],
         prompt: Optional[ResponsePromptParam],
     ) -> ModelResponse:
-        def get_summary(
-            input: Union[str, list[TResponseInputItem], dict[str, str]],
-        ) -> str:
-            ### Activity summary shown in the UI
-            try:
-                max_size = 100
-                if isinstance(input, str):
-                    return input[:max_size]
-                elif isinstance(input, list):
-                    seq_input = cast(Sequence[Any], input)
-                    last_item = seq_input[-1]
-                    if isinstance(last_item, dict):
-                        return last_item.get("content", "")[:max_size]
-                    elif hasattr(last_item, "content"):
-                        return str(getattr(last_item, "content"))[:max_size]
-                    return str(last_item)[:max_size]
-                elif isinstance(input, dict):
-                    return input.get("content", "")[:max_size]
-            except Exception as e:
-                logger.error(f"Error getting summary: {e}")
-            return ""
-
         def make_tool_info(tool: Tool) -> ToolInput:
             if isinstance(tool, (FileSearchTool, WebSearchTool)):
                 return tool
@@ -150,7 +128,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
         return await workflow.execute_activity_method(
             ModelActivity.invoke_model_activity,
             activity_input,
-            summary=self.model_params.summary_override or get_summary(input),
+            summary=self.model_params.summary_override or _extract_summary(input),
             task_queue=self.model_params.task_queue,
             schedule_to_close_timeout=self.model_params.schedule_to_close_timeout,
             schedule_to_start_timeout=self.model_params.schedule_to_start_timeout,
@@ -176,3 +154,34 @@ def stream_response(
         prompt: ResponsePromptParam | None,
     ) -> AsyncIterator[TResponseStreamEvent]:
         raise NotImplementedError("Temporal model doesn't support streams yet")
+
+
+def _extract_summary(input: Union[str, list[TResponseInputItem]]) -> str:
+    ### Activity summary shown in the UI
+    try:
+        max_size = 100
+        if isinstance(input, str):
+            return input[:max_size]
+        elif isinstance(input, list):
+            # Find all message inputs, which are reasonably summarizable
+            messages: list[TResponseInputItem] = [
+                item for item in input if item.get("type", "message") == "message"
+            ]
+            if not messages:
+                return ""
+
+            content: Any = messages[-1].get("content", "")
+
+            # In the case of multiple contents, take the last one
+            if isinstance(content, list):
+                if not content:
+                    return ""
+                content = content[-1]
+
+            # Take the text field from the content if present
+            if isinstance(content, dict) and content.get("text") is not None:
+                content = content.get("text")
+            return str(content)[:max_size]
+    except Exception as e:
+        logger.error(f"Error getting summary: {e}")
+    return ""
diff --git a/temporalio/contrib/openai_agents/workflow.py b/temporalio/contrib/openai_agents/workflow.py
index 35d7c0311..d9f27e679 100644
--- a/temporalio/contrib/openai_agents/workflow.py
+++ b/temporalio/contrib/openai_agents/workflow.py
@@ -134,7 +134,7 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
                 cancellation_type=cancellation_type,
                 activity_id=activity_id,
                 versioning_intent=versioning_intent,
-                summary=summary,
+                summary=summary or schema.description,
                 priority=priority,
             )
             try:
diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py
index 31cb5eda9..7147d3ca6 100644
--- a/tests/contrib/openai_agents/test_openai.py
+++ b/tests/contrib/openai_agents/test_openai.py
@@ -44,12 +44,16 @@
 )
 from openai import APIStatusError, AsyncOpenAI, BaseModel
 from openai.types.responses import (
+    EasyInputMessageParam,
     ResponseFunctionToolCall,
+    ResponseFunctionToolCallParam,
     ResponseFunctionWebSearch,
+    ResponseInputTextParam,
     ResponseOutputMessage,
     ResponseOutputText,
 )
 from openai.types.responses.response_function_web_search import ActionSearch
+from openai.types.responses.response_input_item_param import Message
 from openai.types.responses.response_prompt_param import ResponsePromptParam
 from pydantic import ConfigDict, Field, TypeAdapter
 
@@ -63,6 +67,7 @@
     TestModel,
     TestModelProvider,
 )
+from temporalio.contrib.openai_agents._temporal_model_stub import _extract_summary
 from temporalio.contrib.pydantic import pydantic_data_converter
 from temporalio.exceptions import ApplicationError, CancelledError
 from temporalio.testing import WorkflowEnvironment
@@ -680,7 +685,8 @@ async def test_research_workflow(client: Client, use_local_model: bool):
     new_config["plugins"] = [
         openai_agents.OpenAIAgentsPlugin(
             model_params=ModelActivityParameters(
-                start_to_close_timeout=timedelta(seconds=30)
+                start_to_close_timeout=timedelta(seconds=120),
+                schedule_to_close_timeout=timedelta(seconds=120),
             ),
             model_provider=TestModelProvider(TestResearchModel())
             if use_local_model
@@ -1687,7 +1693,7 @@ class WorkflowToolModel(StaticTestModel):
                 id="",
                 content=[
                     ResponseOutputText(
-                        text="",
+                        text="Workflow tool was used",
                         annotations=[],
                         type="output_text",
                     )
@@ -1938,3 +1944,37 @@ async def test_heartbeat(client: Client, env: WorkflowEnvironment):
         execution_timeout=timedelta(seconds=5.0),
     )
     await workflow_handle.result()
+
+
+def test_summary_extraction():
+    input: list[TResponseInputItem] = [
+        EasyInputMessageParam(
+            content="First message",
+            role="user",
+        )
+    ]
+
+    assert _extract_summary(input) == "First message"
+
+    input.append(
+        Message(
+            content=[
+                ResponseInputTextParam(
+                    text="Second message",
+                    type="input_text",
+                )
+            ],
+            role="user",
+        )
+    )
+    assert _extract_summary(input) == "Second message"
+
+    input.append(
+        ResponseFunctionToolCallParam(
+            arguments="",
+            call_id="",
+            name="",
+            type="function_call",
+        )
+    )
+    assert _extract_summary(input) == "Second message"