11from __future__ import annotations
22
33import logging
4- from datetime import timedelta
54from typing import Optional
65
76from temporalio import workflow
8- from temporalio .common import Priority , RetryPolicy
97from temporalio .contrib .openai_agents ._model_parameters import ModelActivityParameters
10- from temporalio .workflow import ActivityCancellationType , VersioningIntent
118
129logger = logging .getLogger (__name__ )
1310
14- from typing import Any , AsyncIterator , Optional , Sequence , Union , cast
11+ from typing import Any , AsyncIterator , Sequence , Union , cast
1512
1613from agents import (
1714 AgentOutputSchema ,
@@ -57,7 +54,7 @@ def __init__(
5754 async def get_response (
5855 self ,
5956 system_instructions : Optional [str ],
60- input : Union [str , list [TResponseInputItem ]],
57+ input : Union [str , list [TResponseInputItem ], dict [ str , str ] ],
6158 model_settings : ModelSettings ,
6259 tools : list [Tool ],
6360 output_schema : Optional [AgentOutputSchemaBase ],
@@ -67,7 +64,9 @@ async def get_response(
6764 previous_response_id : Optional [str ],
6865 prompt : Optional [ResponsePromptParam ],
6966 ) -> ModelResponse :
70- def get_summary (input : Union [str , list [TResponseInputItem ]]) -> str :
67+ def get_summary (
68+ input : Union [str , list [TResponseInputItem ], dict [str , str ]],
69+ ) -> str :
7170 ### Activity summary shown in the UI
7271 try :
7372 max_size = 100
@@ -88,21 +87,18 @@ def get_summary(input: Union[str, list[TResponseInputItem]]) -> str:
8887 return ""
8988
9089 def make_tool_info (tool : Tool ) -> ToolInput :
91- if isinstance (tool , FileSearchTool ):
92- return cast (FileSearchTool , tool )
93- elif isinstance (tool , WebSearchTool ):
94- return cast (WebSearchTool , tool )
90+ if isinstance (tool , (FileSearchTool , WebSearchTool )):
91+ return tool
9592 elif isinstance (tool , ComputerTool ):
9693 raise NotImplementedError (
9794 "Computer search preview is not supported in Temporal model"
9895 )
9996 elif isinstance (tool , FunctionTool ):
100- t = cast (FunctionToolInput , tool )
10197 return FunctionToolInput (
102- name = t .name ,
103- description = t .description ,
104- params_json_schema = t .params_json_schema ,
105- strict_json_schema = t .strict_json_schema ,
98+ name = tool .name ,
99+ description = tool .description ,
100+ params_json_schema = tool .params_json_schema ,
101+ strict_json_schema = tool .strict_json_schema ,
106102 )
107103 else :
108104 raise ValueError (f"Unknown tool type: { tool .name } " )
@@ -141,7 +137,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
141137 activity_input = ActivityModelInput (
142138 model_name = self .model_name ,
143139 system_instructions = system_instructions ,
144- input = input ,
140+ input = cast ( Union [ str , list [ TResponseInputItem ]], input ) ,
145141 model_settings = model_settings ,
146142 tools = tool_infos ,
147143 output_schema = output_schema_input ,
@@ -169,7 +165,7 @@ def make_tool_info(tool: Tool) -> ToolInput:
169165 def stream_response (
170166 self ,
171167 system_instructions : Optional [str ],
172- input : Union [str , list ] [TResponseInputItem ], # type: ignore
168+ input : Union [str , list [TResponseInputItem ]],
173169 model_settings : ModelSettings ,
174170 tools : list [Tool ],
175171 output_schema : Optional [AgentOutputSchemaBase ],
0 commit comments