diff --git a/temporalio/contrib/openai_agents/__init__.py b/temporalio/contrib/openai_agents/__init__.py index 274f5b98b..998cf61eb 100644 --- a/temporalio/contrib/openai_agents/__init__.py +++ b/temporalio/contrib/openai_agents/__init__.py @@ -17,10 +17,12 @@ from temporalio.contrib.openai_agents._trace_interceptor import ( OpenAIAgentsTracingInterceptor, ) +from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError from . import workflow __all__ = [ + "AgentsWorkflowError", "OpenAIAgentsPlugin", "ModelActivityParameters", "workflow", diff --git a/temporalio/contrib/openai_agents/_openai_runner.py b/temporalio/contrib/openai_agents/_openai_runner.py index 396d74546..4f9dbbf65 100644 --- a/temporalio/contrib/openai_agents/_openai_runner.py +++ b/temporalio/contrib/openai_agents/_openai_runner.py @@ -5,6 +5,7 @@ from agents import ( Agent, + AgentsException, Handoff, RunConfig, RunContextWrapper, @@ -21,6 +22,7 @@ from temporalio import workflow from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub +from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError class TemporalOpenAIRunner(AgentRunner): @@ -136,16 +138,28 @@ async def on_invoke( handoffs=new_handoffs, ) - return await self._runner.run( - starting_agent=convert_agent(starting_agent, None), - input=input, - context=context, - max_turns=max_turns, - hooks=hooks, - run_config=run_config, - previous_response_id=previous_response_id, - session=session, - ) + try: + return await self._runner.run( + starting_agent=convert_agent(starting_agent, None), + input=input, + context=context, + max_turns=max_turns, + hooks=hooks, + run_config=run_config, + previous_response_id=previous_response_id, + session=session, + ) + except AgentsException as e: + # In order for workflow failures to properly fail the workflow, we need to rewrap them in + # a Temporal error + if e.__cause__ and workflow.is_failure_exception(e.__cause__): + reraise = AgentsWorkflowError( + f"Workflow failure exception in Agents Framework: {e}" + ) + reraise.__traceback__ = e.__traceback__ + raise reraise from e.__cause__ + else: + raise e def run_sync( self, diff --git a/temporalio/contrib/openai_agents/_temporal_openai_agents.py b/temporalio/contrib/openai_agents/_temporal_openai_agents.py index 73b9723d0..0c698aa98 100644 --- a/temporalio/contrib/openai_agents/_temporal_openai_agents.py +++ b/temporalio/contrib/openai_agents/_temporal_openai_agents.py @@ -24,16 +24,19 @@ import temporalio.client import temporalio.worker -from temporalio.client import ClientConfig, Plugin +from temporalio.client import ClientConfig from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters -from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner +from temporalio.contrib.openai_agents._openai_runner import ( + TemporalOpenAIRunner, +) from temporalio.contrib.openai_agents._temporal_trace_provider import ( TemporalTraceProvider, ) from temporalio.contrib.openai_agents._trace_interceptor import ( OpenAIAgentsTracingInterceptor, ) +from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError from temporalio.contrib.pydantic import ( PydanticPayloadConverter, ToJsonOptions, @@ -284,6 +287,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig: config["activities"] = list(config.get("activities") or []) + [ ModelActivity(self._model_provider).invoke_model_activity ] + config["workflow_failure_exception_types"] = list( + config.get("workflow_failure_exception_types") or [] + ) + [AgentsWorkflowError] return self.next_worker_plugin.configure_worker(config) async def run_worker(self, worker: Worker) -> None: diff --git a/temporalio/contrib/openai_agents/workflow.py b/temporalio/contrib/openai_agents/workflow.py index d9f27e679..2f69866ce 100644 --- a/temporalio/contrib/openai_agents/workflow.py +++ b/temporalio/contrib/openai_agents/workflow.py @@ -263,3 +263,12 @@ class ToolSerializationError(TemporalError): To fix this error, ensure your tool returns string-convertible values or modify the tool to return a string representation of the result. """ + + +class AgentsWorkflowError(TemporalError): + """Error that occurs when the agents SDK raises an error which should terminate the calling workflow or update. + + .. warning:: + This exception is experimental and may change in future versions. + Use with caution in production environments. + """ diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index c93155672..f0984cc84 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -414,7 +414,7 @@ def activate( # We want some errors during activation, like those that can happen # during payload conversion, to be able to fail the workflow not the # task - if self._is_workflow_failure_exception(err): + if self.workflow_is_failure_exception(err): try: self._set_workflow_failure(err) except Exception as inner_err: @@ -629,7 +629,7 @@ async def run_update() -> None: # Validation failures are always update failures. We reuse # workflow failure logic to decide task failure vs update # failure after validation. - if not past_validation or self._is_workflow_failure_exception(err): + if not past_validation or self.workflow_is_failure_exception(err): if command is None: command = self._add_command() command.update_response.protocol_instance_id = ( @@ -1686,6 +1686,23 @@ def workflow_set_current_details(self, details: str): self._assert_not_read_only("set current details") self._current_details = details + def workflow_is_failure_exception(self, err: BaseException) -> bool: + # An exception is a failure instead of a task fail if it's already a + # failure error or if it is a timeout error or if it is an instance of + # any of the failure types in the worker or workflow-level setting + wf_failure_exception_types = self._defn.failure_exception_types + if self._dynamic_failure_exception_types is not None: + wf_failure_exception_types = self._dynamic_failure_exception_types + return ( + isinstance(err, temporalio.exceptions.FailureError) + or isinstance(err, asyncio.TimeoutError) + or any(isinstance(err, typ) for typ in wf_failure_exception_types) + or any( + isinstance(err, typ) + for typ in self._worker_level_failure_exception_types + ) + ) + #### Calls from outbound impl #### # These are in alphabetical order and all start with "_outbound_". @@ -1939,7 +1956,7 @@ def _convert_payloads( # Don't wrap payload conversion errors that would fail the workflow raise except Exception as err: - if self._is_workflow_failure_exception(err): + if self.workflow_is_failure_exception(err): raise raise RuntimeError("Failed decoding arguments") from err @@ -1982,23 +1999,6 @@ def _instantiate_workflow_object(self) -> Any: return workflow_instance - def _is_workflow_failure_exception(self, err: BaseException) -> bool: - # An exception is a failure instead of a task fail if it's already a - # failure error or if it is a timeout error or if it is an instance of - # any of the failure types in the worker or workflow-level setting - wf_failure_exception_types = self._defn.failure_exception_types - if self._dynamic_failure_exception_types is not None: - wf_failure_exception_types = self._dynamic_failure_exception_types - return ( - isinstance(err, temporalio.exceptions.FailureError) - or isinstance(err, asyncio.TimeoutError) - or any(isinstance(err, typ) for typ in wf_failure_exception_types) - or any( - isinstance(err, typ) - for typ in self._worker_level_failure_exception_types - ) - ) - def _warn_if_unfinished_handlers(self) -> None: def warnable(handler_executions: Iterable[HandlerExecution]): return [ @@ -2192,7 +2192,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None: err ): self._add_command().cancel_workflow_execution.SetInParent() - elif self._is_workflow_failure_exception(err): + elif self.workflow_is_failure_exception(err): # All other failure errors fail the workflow self._set_workflow_failure(err) else: diff --git a/temporalio/workflow.py b/temporalio/workflow.py index 423d5289b..50118a2bb 100644 --- a/temporalio/workflow.py +++ b/temporalio/workflow.py @@ -897,6 +897,9 @@ def workflow_get_current_details(self) -> str: ... @abstractmethod def workflow_set_current_details(self, details: str): ... + @abstractmethod + def workflow_is_failure_exception(self, err: BaseException) -> bool: ... + _current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar( "__temporal_current_update_info" @@ -981,6 +984,15 @@ def memo() -> Mapping[str, Any]: return _Runtime.current().workflow_memo() +def is_failure_exception(err: BaseException) -> bool: + """Checks if the given exception is a workflow failure in the current workflow. + + Returns: + True if the given exception is a workflow failure in the current workflow. + """ + return _Runtime.current().workflow_is_failure_exception(err) + + @overload def memo_value(key: str, default: Any = temporalio.common._arg_unset) -> Any: ... diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 7c3df0897..11613296b 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -318,6 +318,10 @@ async def run(self, question: str) -> str: ActivityWeatherService.get_weather_method, start_to_close_timeout=timedelta(seconds=10), ), + openai_agents.workflow.activity_as_tool( + get_weather_failure, + start_to_close_timeout=timedelta(seconds=10), + ), ], ) result = await Runner.run( @@ -462,6 +466,53 @@ async def test_tool_workflow(client: Client, use_local_model: bool): ) +@activity.defn +async def get_weather_failure(city: str) -> Weather: + """ + Get the weather for a given city. + """ + raise ApplicationError("No weather", non_retryable=True) + + +class TestWeatherFailureModel(StaticTestModel): + responses = [ + ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_failure"), + ] + + +async def test_tool_failure_workflow(client: Client): + new_config = client.config() + new_config["plugins"] = [ + openai_agents.OpenAIAgentsPlugin( + model_params=ModelActivityParameters( + start_to_close_timeout=timedelta(seconds=30) + ), + model_provider=TestModelProvider(TestWeatherFailureModel()), + ) + ] + client = Client(**new_config) + + async with new_worker( + client, + ToolsWorkflow, + activities=[ + get_weather_failure, + ], + ) as worker: + workflow_handle = await client.start_workflow( + ToolsWorkflow.run, + "What is the weather in Tokio?", + id=f"tools-failure-workflow-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=2), + ) + with pytest.raises(WorkflowFailureError) as e: + result = await workflow_handle.result() + cause = e.value.cause + assert isinstance(cause, ApplicationError) + assert "Workflow failure exception in Agents Framework" in cause.message + + @pytest.mark.parametrize("use_local_model", [True, False]) async def test_nexus_tool_workflow( client: Client, env: WorkflowEnvironment, use_local_model: bool @@ -1909,20 +1960,14 @@ async def run(self, question: str) -> str: return result.final_output -@pytest.mark.parametrize("use_local_model", [True, False]) -async def test_code_interpreter_tool(client: Client, use_local_model): - if not use_local_model and not os.environ.get("OPENAI_API_KEY"): - pytest.skip("No openai API key") - +async def test_code_interpreter_tool(client: Client): new_config = client.config() new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=60) ), - model_provider=TestModelProvider(CodeInterpreterModel()) - if use_local_model - else None, + model_provider=TestModelProvider(CodeInterpreterModel()), ) ] client = Client(**new_config) @@ -1939,8 +1984,7 @@ async def test_code_interpreter_tool(client: Client, use_local_model): execution_timeout=timedelta(seconds=60), ) result = await workflow_handle.result() - if use_local_model: - assert result == "Over 9000" + assert result == "Over 9000" class HostedMCPModel(StaticTestModel): @@ -2011,20 +2055,14 @@ def approve(_: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult: return result.final_output -@pytest.mark.parametrize("use_local_model", [True, False]) -async def test_hosted_mcp_tool(client: Client, use_local_model): - if not use_local_model and not os.environ.get("OPENAI_API_KEY"): - pytest.skip("No openai API key") - +async def test_hosted_mcp_tool(client: Client): new_config = client.config() new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( start_to_close_timeout=timedelta(seconds=120) ), - model_provider=TestModelProvider(HostedMCPModel()) - if use_local_model - else None, + model_provider=TestModelProvider(HostedMCPModel()), ) ] client = Client(**new_config) @@ -2041,8 +2079,7 @@ async def test_hosted_mcp_tool(client: Client, use_local_model): execution_timeout=timedelta(seconds=120), ) result = await workflow_handle.result() - if use_local_model: - assert result == "Some language" + assert result == "Some language" class AssertDifferentModelProvider(ModelProvider):