Skip to content

Commit 9da1795

Browse files
committed
Support for method activities conversion into tools (#968)
* Support for method activities conversion into tools * Fix test
1 parent 0f19e49 commit 9da1795

File tree

2 files changed

+77
-8
lines changed

2 files changed

+77
-8
lines changed

temporalio/contrib/openai_agents/workflow.py

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,27 @@
11
"""Workflow-specific primitives for working with the OpenAI Agents SDK in a workflow context"""
22

3+
import functools
4+
import inspect
35
import json
46
from datetime import timedelta
5-
from typing import Any, Callable, Optional, Type
7+
from typing import Any, Callable, Optional, Type, Union, overload
68

79
import nexusrpc
810
from agents import (
11+
Agent,
912
RunContextWrapper,
1013
Tool,
1114
)
12-
from agents.function_schema import function_schema
15+
from agents.function_schema import DocstringStyle, function_schema
1316
from agents.tool import (
1417
FunctionTool,
18+
ToolErrorFunction,
19+
ToolFunction,
20+
ToolParams,
21+
default_tool_error_function,
22+
function_tool,
1523
)
24+
from agents.util._types import MaybeAwaitable
1625

1726
from temporalio import activity
1827
from temporalio import workflow as temporal_workflow
@@ -78,6 +87,25 @@ def activity_as_tool(
7887
"Bare function without tool and activity decorators is not supported",
7988
"invalid_tool",
8089
)
90+
if ret.name is None:
91+
raise ApplicationError(
92+
"Input activity must have a name to be made into a tool",
93+
"invalid_tool",
94+
)
95+
# If the provided callable has a first argument of `self`, partially apply it with the same metadata
96+
# The actual instance will be picked up by the activity execution, the partially applied function will never actually be executed
97+
params = list(inspect.signature(fn).parameters.keys())
98+
if len(params) > 0 and params[0] == "self":
99+
partial = functools.partial(fn, None)
100+
setattr(partial, "__name__", fn.__name__)
101+
partial.__annotations__ = getattr(fn, "__annotations__")
102+
setattr(
103+
partial,
104+
"__temporal_activity_definition",
105+
getattr(fn, "__temporal_activity_definition"),
106+
)
107+
partial.__doc__ = fn.__doc__
108+
fn = partial
81109
schema = function_schema(fn)
82110

83111
async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
@@ -94,9 +122,8 @@ async def run_activity(ctx: RunContextWrapper[Any], input: str) -> Any:
94122
# Add the context to the arguments if it takes that
95123
if schema.takes_context:
96124
args = [ctx] + args
97-
98125
result = await temporal_workflow.execute_activity(
99-
fn,
126+
ret.name, # type: ignore
100127
args=args,
101128
task_queue=task_queue,
102129
schedule_to_close_timeout=schedule_to_close_timeout,

tests/contrib/openai_agents/test_openai.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,17 @@ async def get_weather_context(ctx: RunContextWrapper[str], city: str) -> Weather
195195
return Weather(city=city, temperature_range="14-20C", conditions=ctx.context)
196196

197197

198+
class ActivityWeatherService:
199+
@activity.defn
200+
async def get_weather_method(self, city: str) -> Weather:
201+
"""
202+
Get the weather for a given city.
203+
"""
204+
return Weather(
205+
city=city, temperature_range="14-20C", conditions="Sunny with wind."
206+
)
207+
208+
198209
@nexusrpc.service
199210
class WeatherService:
200211
get_weather_nexus_operation: nexusrpc.Operation[WeatherInput, Weather]
@@ -269,6 +280,20 @@ class TestWeatherModel(StaticTestModel):
269280
usage=Usage(),
270281
response_id=None,
271282
),
283+
ModelResponse(
284+
output=[
285+
ResponseFunctionToolCall(
286+
arguments='{"city":"Tokyo"}',
287+
call_id="call",
288+
name="get_weather_method",
289+
type="function_call",
290+
id="id",
291+
status="completed",
292+
)
293+
],
294+
usage=Usage(),
295+
response_id=None,
296+
),
272297
ModelResponse(
273298
output=[
274299
ResponseOutputMessage(
@@ -333,7 +358,7 @@ class TestNexusWeatherModel(StaticTestModel):
333358
class ToolsWorkflow:
334359
@workflow.run
335360
async def run(self, question: str) -> str:
336-
agent = Agent(
361+
agent: Agent = Agent(
337362
name="Tools Workflow",
338363
instructions="You are a helpful agent.",
339364
tools=[
@@ -349,8 +374,12 @@ async def run(self, question: str) -> str:
349374
openai_agents.workflow.activity_as_tool(
350375
get_weather_context, start_to_close_timeout=timedelta(seconds=10)
351376
),
377+
openai_agents.workflow.activity_as_tool(
378+
ActivityWeatherService.get_weather_method,
379+
start_to_close_timeout=timedelta(seconds=10),
380+
),
352381
],
353-
) # type: Agent
382+
)
354383
result = await Runner.run(
355384
starting_agent=agent, input=question, context="Stormy"
356385
)
@@ -406,6 +435,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
406435
get_weather_object,
407436
get_weather_country,
408437
get_weather_context,
438+
ActivityWeatherService().get_weather_method,
409439
],
410440
interceptors=[OpenAIAgentsTracingInterceptor()],
411441
) as worker:
@@ -426,7 +456,7 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
426456
if e.HasField("activity_task_completed_event_attributes"):
427457
events.append(e)
428458

429-
assert len(events) == 9
459+
assert len(events) == 11
430460
assert (
431461
"function_call"
432462
in events[0]
@@ -476,11 +506,23 @@ async def test_tool_workflow(client: Client, use_local_model: bool):
476506
.data.decode()
477507
)
478508
assert (
479-
"Test weather result"
509+
"function_call"
480510
in events[8]
481511
.activity_task_completed_event_attributes.result.payloads[0]
482512
.data.decode()
483513
)
514+
assert (
515+
"Sunny with wind"
516+
in events[9]
517+
.activity_task_completed_event_attributes.result.payloads[0]
518+
.data.decode()
519+
)
520+
assert (
521+
"Test weather result"
522+
in events[10]
523+
.activity_task_completed_event_attributes.result.payloads[0]
524+
.data.decode()
525+
)
484526

485527

486528
@pytest.mark.parametrize("use_local_model", [True, False])

0 commit comments

Comments
 (0)