Skip to content
2 changes: 2 additions & 0 deletions temporalio/contrib/openai_agents/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError

from . import workflow

__all__ = [
"AgentsWorkflowError",
"OpenAIAgentsPlugin",
"ModelActivityParameters",
"workflow",
Expand Down
34 changes: 24 additions & 10 deletions temporalio/contrib/openai_agents/_openai_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from agents import (
Agent,
AgentsException,
Handoff,
RunConfig,
RunContextWrapper,
Expand All @@ -21,6 +22,7 @@
from temporalio import workflow
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._temporal_model_stub import _TemporalModelStub
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError


class TemporalOpenAIRunner(AgentRunner):
Expand Down Expand Up @@ -136,16 +138,28 @@ async def on_invoke(
handoffs=new_handoffs,
)

return await self._runner.run(
starting_agent=convert_agent(starting_agent, None),
input=input,
context=context,
max_turns=max_turns,
hooks=hooks,
run_config=run_config,
previous_response_id=previous_response_id,
session=session,
)
try:
return await self._runner.run(
starting_agent=convert_agent(starting_agent, None),
input=input,
context=context,
max_turns=max_turns,
hooks=hooks,
run_config=run_config,
previous_response_id=previous_response_id,
session=session,
)
except AgentsException as e:
# In order for workflow failures to properly fail the workflow, we need to rewrap them in
# a Temporal error
if e.__cause__ and workflow.is_failure_exception(e.__cause__):
reraise = AgentsWorkflowError(
f"Workflow failure exception in Agents Framework: {e}"
)
reraise.__traceback__ = e.__traceback__
raise reraise from e.__cause__
else:
raise e

def run_sync(
self,
Expand Down
10 changes: 8 additions & 2 deletions temporalio/contrib/openai_agents/_temporal_openai_agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,19 @@

import temporalio.client
import temporalio.worker
from temporalio.client import ClientConfig, Plugin
from temporalio.client import ClientConfig
from temporalio.contrib.openai_agents._invoke_model_activity import ModelActivity
from temporalio.contrib.openai_agents._model_parameters import ModelActivityParameters
from temporalio.contrib.openai_agents._openai_runner import TemporalOpenAIRunner
from temporalio.contrib.openai_agents._openai_runner import (
TemporalOpenAIRunner,
)
from temporalio.contrib.openai_agents._temporal_trace_provider import (
TemporalTraceProvider,
)
from temporalio.contrib.openai_agents._trace_interceptor import (
OpenAIAgentsTracingInterceptor,
)
from temporalio.contrib.openai_agents.workflow import AgentsWorkflowError
from temporalio.contrib.pydantic import (
PydanticPayloadConverter,
ToJsonOptions,
Expand Down Expand Up @@ -284,6 +287,9 @@ def configure_worker(self, config: WorkerConfig) -> WorkerConfig:
config["activities"] = list(config.get("activities") or []) + [
ModelActivity(self._model_provider).invoke_model_activity
]
config["workflow_failure_exception_types"] = list(
config.get("workflow_failure_exception_types") or []
) + [AgentsWorkflowError]
return self.next_worker_plugin.configure_worker(config)

async def run_worker(self, worker: Worker) -> None:
Expand Down
9 changes: 9 additions & 0 deletions temporalio/contrib/openai_agents/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,3 +263,12 @@ class ToolSerializationError(TemporalError):
To fix this error, ensure your tool returns string-convertible values or
modify the tool to return a string representation of the result.
"""


class AgentsWorkflowError(TemporalError):
"""Error that occurs when the agents SDK raises an error which should terminate the calling workflow or update.

.. warning::
This exception is experimental and may change in future versions.
Use with caution in production environments.
"""
42 changes: 21 additions & 21 deletions temporalio/worker/_workflow_instance.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,7 +414,7 @@ def activate(
# We want some errors during activation, like those that can happen
# during payload conversion, to be able to fail the workflow not the
# task
if self._is_workflow_failure_exception(err):
if self.workflow_is_failure_exception(err):
try:
self._set_workflow_failure(err)
except Exception as inner_err:
Expand Down Expand Up @@ -629,7 +629,7 @@ async def run_update() -> None:
# Validation failures are always update failures. We reuse
# workflow failure logic to decide task failure vs update
# failure after validation.
if not past_validation or self._is_workflow_failure_exception(err):
if not past_validation or self.workflow_is_failure_exception(err):
if command is None:
command = self._add_command()
command.update_response.protocol_instance_id = (
Expand Down Expand Up @@ -1686,6 +1686,23 @@ def workflow_set_current_details(self, details: str):
self._assert_not_read_only("set current details")
self._current_details = details

def workflow_is_failure_exception(self, err: BaseException) -> bool:
# An exception is a failure instead of a task fail if it's already a
# failure error or if it is a timeout error or if it is an instance of
# any of the failure types in the worker or workflow-level setting
wf_failure_exception_types = self._defn.failure_exception_types
if self._dynamic_failure_exception_types is not None:
wf_failure_exception_types = self._dynamic_failure_exception_types
return (
isinstance(err, temporalio.exceptions.FailureError)
or isinstance(err, asyncio.TimeoutError)
or any(isinstance(err, typ) for typ in wf_failure_exception_types)
or any(
isinstance(err, typ)
for typ in self._worker_level_failure_exception_types
)
)

#### Calls from outbound impl ####
# These are in alphabetical order and all start with "_outbound_".

Expand Down Expand Up @@ -1939,7 +1956,7 @@ def _convert_payloads(
# Don't wrap payload conversion errors that would fail the workflow
raise
except Exception as err:
if self._is_workflow_failure_exception(err):
if self.workflow_is_failure_exception(err):
raise
raise RuntimeError("Failed decoding arguments") from err

Expand Down Expand Up @@ -1982,23 +1999,6 @@ def _instantiate_workflow_object(self) -> Any:

return workflow_instance

def _is_workflow_failure_exception(self, err: BaseException) -> bool:
# An exception is a failure instead of a task fail if it's already a
# failure error or if it is a timeout error or if it is an instance of
# any of the failure types in the worker or workflow-level setting
wf_failure_exception_types = self._defn.failure_exception_types
if self._dynamic_failure_exception_types is not None:
wf_failure_exception_types = self._dynamic_failure_exception_types
return (
isinstance(err, temporalio.exceptions.FailureError)
or isinstance(err, asyncio.TimeoutError)
or any(isinstance(err, typ) for typ in wf_failure_exception_types)
or any(
isinstance(err, typ)
for typ in self._worker_level_failure_exception_types
)
)

def _warn_if_unfinished_handlers(self) -> None:
def warnable(handler_executions: Iterable[HandlerExecution]):
return [
Expand Down Expand Up @@ -2192,7 +2192,7 @@ async def _run_top_level_workflow_function(self, coro: Awaitable[None]) -> None:
err
):
self._add_command().cancel_workflow_execution.SetInParent()
elif self._is_workflow_failure_exception(err):
elif self.workflow_is_failure_exception(err):
# All other failure errors fail the workflow
self._set_workflow_failure(err)
else:
Expand Down
12 changes: 12 additions & 0 deletions temporalio/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,9 @@ def workflow_get_current_details(self) -> str: ...
@abstractmethod
def workflow_set_current_details(self, details: str): ...

@abstractmethod
def workflow_is_failure_exception(self, err: BaseException) -> bool: ...


_current_update_info: contextvars.ContextVar[UpdateInfo] = contextvars.ContextVar(
"__temporal_current_update_info"
Expand Down Expand Up @@ -981,6 +984,15 @@ def memo() -> Mapping[str, Any]:
return _Runtime.current().workflow_memo()


def is_failure_exception(err: BaseException) -> bool:
"""Checks if the given exception is a workflow failure in the current workflow.

Returns:
True if the given exception is a workflow failure in the current workflow.
"""
return _Runtime.current().workflow_is_failure_exception(err)


@overload
def memo_value(key: str, default: Any = temporalio.common._arg_unset) -> Any: ...

Expand Down
77 changes: 57 additions & 20 deletions tests/contrib/openai_agents/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ async def run(self, question: str) -> str:
ActivityWeatherService.get_weather_method,
start_to_close_timeout=timedelta(seconds=10),
),
openai_agents.workflow.activity_as_tool(
get_weather_failure,
start_to_close_timeout=timedelta(seconds=10),
),
],
)
result = await Runner.run(
Expand Down Expand Up @@ -462,6 +466,53 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
)


@activity.defn
async def get_weather_failure(city: str) -> Weather:
"""
Get the weather for a given city.
"""
raise ApplicationError("No weather", non_retryable=True)


class TestWeatherFailureModel(StaticTestModel):
responses = [
ResponseBuilders.tool_call('{"city":"Tokyo"}', "get_weather_failure"),
]


async def test_tool_failure_workflow(client: Client):
new_config = client.config()
new_config["plugins"] = [
openai_agents.OpenAIAgentsPlugin(
model_params=ModelActivityParameters(
start_to_close_timeout=timedelta(seconds=30)
),
model_provider=TestModelProvider(TestWeatherFailureModel()),
)
]
client = Client(**new_config)

async with new_worker(
client,
ToolsWorkflow,
activities=[
get_weather_failure,
],
) as worker:
workflow_handle = await client.start_workflow(
ToolsWorkflow.run,
"What is the weather in Tokio?",
id=f"tools-failure-workflow-{uuid.uuid4()}",
task_queue=worker.task_queue,
execution_timeout=timedelta(seconds=2),
)
with pytest.raises(WorkflowFailureError) as e:
result = await workflow_handle.result()
cause = e.value.cause
assert isinstance(cause, ApplicationError)
assert "Workflow failure exception in Agents Framework" in cause.message


@pytest.mark.parametrize("use_local_model", [True, False])
async def test_nexus_tool_workflow(
client: Client, env: WorkflowEnvironment, use_local_model: bool
Expand Down Expand Up @@ -1909,20 +1960,14 @@ async def run(self, question: str) -> str:
return result.final_output


@pytest.mark.parametrize("use_local_model", [True, False])
async def test_code_interpreter_tool(client: Client, use_local_model):
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
pytest.skip("No openai API key")

async def test_code_interpreter_tool(client: Client):
new_config = client.config()
new_config["plugins"] = [
openai_agents.OpenAIAgentsPlugin(
model_params=ModelActivityParameters(
start_to_close_timeout=timedelta(seconds=60)
),
model_provider=TestModelProvider(CodeInterpreterModel())
if use_local_model
else None,
model_provider=TestModelProvider(CodeInterpreterModel()),
)
]
client = Client(**new_config)
Expand All @@ -1939,8 +1984,7 @@ async def test_code_interpreter_tool(client: Client, use_local_model):
execution_timeout=timedelta(seconds=60),
)
result = await workflow_handle.result()
if use_local_model:
assert result == "Over 9000"
assert result == "Over 9000"


class HostedMCPModel(StaticTestModel):
Expand Down Expand Up @@ -2011,20 +2055,14 @@ def approve(_: MCPToolApprovalRequest) -> MCPToolApprovalFunctionResult:
return result.final_output


@pytest.mark.parametrize("use_local_model", [True, False])
async def test_hosted_mcp_tool(client: Client, use_local_model):
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
pytest.skip("No openai API key")

async def test_hosted_mcp_tool(client: Client):
new_config = client.config()
new_config["plugins"] = [
openai_agents.OpenAIAgentsPlugin(
model_params=ModelActivityParameters(
start_to_close_timeout=timedelta(seconds=120)
),
model_provider=TestModelProvider(HostedMCPModel())
if use_local_model
else None,
model_provider=TestModelProvider(HostedMCPModel()),
)
]
client = Client(**new_config)
Expand All @@ -2041,8 +2079,7 @@ async def test_hosted_mcp_tool(client: Client, use_local_model):
execution_timeout=timedelta(seconds=120),
)
result = await workflow_handle.result()
if use_local_model:
assert result == "Some language"
assert result == "Some language"


class AssertDifferentModelProvider(ModelProvider):
Expand Down