Skip to content

Commit e73d6b5

Browse files
authored
Support Nexus tool calls in OpenAI Agents integration (#949)
1 parent 5d0e9b1 commit e73d6b5

File tree

2 files changed

+244
-9
lines changed

2 files changed

+244
-9
lines changed

temporalio/contrib/openai_agents/temporal_openai_agents.py

Lines changed: 90 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import json
44
from contextlib import contextmanager
55
from datetime import timedelta
6-
from typing import Any, AsyncIterator, Callable, Optional, Union, overload
6+
from typing import Any, AsyncIterator, Callable, Optional, Type, Union
77

8+
import nexusrpc
89
from agents import (
9-
Agent,
1010
AgentOutputSchemaBase,
1111
Handoff,
1212
Model,
@@ -19,20 +19,14 @@
1919
TResponseInputItem,
2020
set_trace_provider,
2121
)
22-
from agents.function_schema import DocstringStyle, function_schema
22+
from agents.function_schema import function_schema
2323
from agents.items import TResponseStreamEvent
2424
from agents.run import get_default_agent_runner, set_default_agent_runner
2525
from agents.tool import (
2626
FunctionTool,
27-
ToolErrorFunction,
28-
ToolFunction,
29-
ToolParams,
30-
default_tool_error_function,
31-
function_tool,
3227
)
3328
from agents.tracing import get_trace_provider
3429
from agents.tracing.provider import DefaultTraceProvider
35-
from agents.util._types import MaybeAwaitable
3630
from openai.types.responses import ResponsePromptParam
3731

3832
from temporalio import activity
@@ -266,3 +260,90 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
266260
on_invoke_tool=run_activity,
267261
strict_json_schema=True,
268262
)
263+
264+
@classmethod
265+
def nexus_operation_as_tool(
266+
cls,
267+
operation: nexusrpc.Operation[Any, Any],
268+
*,
269+
service: Type[Any],
270+
endpoint: str,
271+
schedule_to_close_timeout: Optional[timedelta] = None,
272+
) -> Tool:
273+
"""Convert a Nexus operation into an OpenAI agent tool.
274+
275+
.. warning::
276+
This API is experimental and may change in future versions.
277+
Use with caution in production environments.
278+
279+
This function takes a Nexus operation and converts it into an
280+
OpenAI agent tool that can be used by the agent to execute the operation
281+
during workflow execution. The tool will automatically handle the conversion
282+
of inputs and outputs between the agent and the operation.
283+
284+
Args:
285+
fn: A Nexus operation to convert into a tool.
286+
service: The Nexus service class that contains the operation.
287+
endpoint: The Nexus endpoint to use for the operation.
288+
289+
Returns:
290+
An OpenAI agent tool that wraps the provided operation.
291+
292+
Example:
293+
>>> @nexusrpc.service
294+
... class WeatherService:
295+
... get_weather_object_nexus_operation: nexusrpc.Operation[WeatherInput, Weather]
296+
>>>
297+
>>> # Create tool with custom activity options
298+
>>> tool = nexus_operation_as_tool(
299+
... WeatherService.get_weather_object_nexus_operation,
300+
... service=WeatherService,
301+
... endpoint="weather-service",
302+
... )
303+
>>> # Use tool with an OpenAI agent
304+
"""
305+
306+
def operation_callable(input):
307+
raise NotImplementedError("This function definition is used as a type only")
308+
309+
operation_callable.__annotations__ = {
310+
"input": operation.input_type,
311+
"return": operation.output_type,
312+
}
313+
operation_callable.__name__ = operation.name
314+
315+
schema = function_schema(operation_callable)
316+
317+
async def run_operation(ctx: RunContextWrapper[Any], input: str) -> Any:
318+
try:
319+
json_data = json.loads(input)
320+
except Exception as e:
321+
raise ApplicationError(
322+
f"Invalid JSON input for tool {schema.name}: {input}"
323+
) from e
324+
325+
nexus_client = temporal_workflow.create_nexus_client(
326+
service=service, endpoint=endpoint
327+
)
328+
args, _ = schema.to_call_args(schema.params_pydantic_model(**json_data))
329+
assert len(args) == 1, "Nexus operations must have exactly one argument"
330+
[arg] = args
331+
result = await nexus_client.execute_operation(
332+
operation,
333+
arg,
334+
schedule_to_close_timeout=schedule_to_close_timeout,
335+
)
336+
try:
337+
return str(result)
338+
except Exception as e:
339+
raise ToolSerializationError(
340+
"You must return a string representation of the tool output, or something we can call str() on"
341+
) from e
342+
343+
return FunctionTool(
344+
name=schema.name,
345+
description=schema.description or "",
346+
params_json_schema=schema.params_json_schema,
347+
on_invoke_tool=run_operation,
348+
strict_json_schema=True,
349+
)

tests/contrib/openai_agents/test_openai.py

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from datetime import timedelta
55
from typing import Any, Optional, Union, no_type_check
66

7+
import nexusrpc
78
import pytest
89
from agents import (
910
Agent,
@@ -59,10 +60,12 @@
5960
)
6061
from temporalio.contrib.pydantic import pydantic_data_converter
6162
from temporalio.exceptions import CancelledError
63+
from temporalio.testing import WorkflowEnvironment
6264
from tests.contrib.openai_agents.research_agents.research_manager import (
6365
ResearchManager,
6466
)
6567
from tests.helpers import new_worker
68+
from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name
6669

6770
response_index: int = 0
6871

@@ -192,6 +195,22 @@ async def get_weather_context(ctx: RunContextWrapper[str], city: str) -> Weather
192195
return Weather(city=city, temperature_range="14-20C", conditions=ctx.context)
193196

194197

198+
@nexusrpc.service
199+
class WeatherService:
200+
get_weather_nexus_operation: nexusrpc.Operation[WeatherInput, Weather]
201+
202+
203+
@nexusrpc.handler.service_handler(service=WeatherService)
204+
class WeatherServiceHandler:
205+
@nexusrpc.handler.sync_operation
206+
async def get_weather_nexus_operation(
207+
self, ctx: nexusrpc.handler.StartOperationContext, input: WeatherInput
208+
) -> Weather:
209+
return Weather(
210+
city=input.city, temperature_range="14-20C", conditions="Sunny with wind."
211+
)
212+
213+
195214
class TestWeatherModel(StaticTestModel):
196215
responses = [
197216
ModelResponse(
@@ -272,6 +291,44 @@ class TestWeatherModel(StaticTestModel):
272291
]
273292

274293

294+
class TestNexusWeatherModel(StaticTestModel):
295+
responses = [
296+
ModelResponse(
297+
output=[
298+
ResponseFunctionToolCall(
299+
arguments='{"input":{"city":"Tokyo"}}',
300+
call_id="call",
301+
name="get_weather_nexus_operation",
302+
type="function_call",
303+
id="id",
304+
status="completed",
305+
)
306+
],
307+
usage=Usage(),
308+
response_id=None,
309+
),
310+
ModelResponse(
311+
output=[
312+
ResponseOutputMessage(
313+
id="",
314+
content=[
315+
ResponseOutputText(
316+
text="Test nexus weather result",
317+
annotations=[],
318+
type="output_text",
319+
)
320+
],
321+
role="assistant",
322+
status="completed",
323+
type="message",
324+
)
325+
],
326+
usage=Usage(),
327+
response_id=None,
328+
),
329+
]
330+
331+
275332
@workflow.defn
276333
class ToolsWorkflow:
277334
@workflow.run
@@ -300,6 +357,28 @@ async def run(self, question: str) -> str:
300357
return result.final_output
301358

302359

360+
@workflow.defn
361+
class NexusToolsWorkflow:
362+
@workflow.run
363+
async def run(self, question: str) -> str:
364+
agent = Agent(
365+
name="Nexus Tools Workflow",
366+
instructions="You are a helpful agent.",
367+
tools=[
368+
openai_agents.workflow.nexus_operation_as_tool(
369+
WeatherService.get_weather_nexus_operation,
370+
service=WeatherService,
371+
endpoint=make_nexus_endpoint_name(workflow.info().task_queue),
372+
schedule_to_close_timeout=timedelta(seconds=10),
373+
),
374+
],
375+
) # type: Agent
376+
result = await Runner.run(
377+
starting_agent=agent, input=question, context="Stormy"
378+
)
379+
return result.final_output
380+
381+
303382
@pytest.mark.parametrize("use_local_model", [True, False])
304383
async def test_tool_workflow(client: Client, use_local_model: bool):
305384
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
@@ -404,6 +483,81 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
404483
)
405484

406485

486+
@pytest.mark.parametrize("use_local_model", [True, False])
487+
async def test_nexus_tool_workflow(
488+
client: Client, env: WorkflowEnvironment, use_local_model: bool
489+
):
490+
if not use_local_model and not os.environ.get("OPENAI_API_KEY"):
491+
pytest.skip("No openai API key")
492+
493+
if env.supports_time_skipping:
494+
pytest.skip("Nexus tests don't work with time-skipping server")
495+
496+
new_config = client.config()
497+
new_config["data_converter"] = pydantic_data_converter
498+
client = Client(**new_config)
499+
500+
model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30))
501+
with set_open_ai_agent_temporal_overrides(model_params):
502+
model_activity = ModelActivity(
503+
TestModelProvider(
504+
TestNexusWeatherModel( # type: ignore
505+
)
506+
)
507+
if use_local_model
508+
else None
509+
)
510+
async with new_worker(
511+
client,
512+
NexusToolsWorkflow,
513+
activities=[
514+
model_activity.invoke_model_activity,
515+
],
516+
nexus_service_handlers=[WeatherServiceHandler()],
517+
interceptors=[OpenAIAgentsTracingInterceptor()],
518+
) as worker:
519+
await create_nexus_endpoint(worker.task_queue, client)
520+
521+
workflow_handle = await client.start_workflow(
522+
NexusToolsWorkflow.run,
523+
"What is the weather in Tokio?",
524+
id=f"nexus-tools-workflow-{uuid.uuid4()}",
525+
task_queue=worker.task_queue,
526+
execution_timeout=timedelta(seconds=30),
527+
)
528+
result = await workflow_handle.result()
529+
530+
if use_local_model:
531+
assert result == "Test nexus weather result"
532+
533+
events = []
534+
async for e in workflow_handle.fetch_history_events():
535+
if e.HasField(
536+
"activity_task_completed_event_attributes"
537+
) or e.HasField("nexus_operation_completed_event_attributes"):
538+
events.append(e)
539+
540+
assert len(events) == 3
541+
assert (
542+
"function_call"
543+
in events[0]
544+
.activity_task_completed_event_attributes.result.payloads[0]
545+
.data.decode()
546+
)
547+
assert (
548+
"Sunny with wind"
549+
in events[
550+
1
551+
].nexus_operation_completed_event_attributes.result.data.decode()
552+
)
553+
assert (
554+
"Test nexus weather result"
555+
in events[2]
556+
.activity_task_completed_event_attributes.result.payloads[0]
557+
.data.decode()
558+
)
559+
560+
407561
@no_type_check
408562
class TestResearchModel(StaticTestModel):
409563
responses = [

0 commit comments

Comments
 (0)