|
4 | 4 | from datetime import timedelta |
5 | 5 | from typing import Any, Optional, Union, no_type_check |
6 | 6 |
|
| 7 | +import nexusrpc |
7 | 8 | import pytest |
8 | 9 | from agents import ( |
9 | 10 | Agent, |
|
59 | 60 | ) |
60 | 61 | from temporalio.contrib.pydantic import pydantic_data_converter |
61 | 62 | from temporalio.exceptions import CancelledError |
| 63 | +from temporalio.testing import WorkflowEnvironment |
62 | 64 | from tests.contrib.openai_agents.research_agents.research_manager import ( |
63 | 65 | ResearchManager, |
64 | 66 | ) |
65 | 67 | from tests.helpers import new_worker |
| 68 | +from tests.helpers.nexus import create_nexus_endpoint, make_nexus_endpoint_name |
66 | 69 |
|
67 | 70 | response_index: int = 0 |
68 | 71 |
|
@@ -192,6 +195,22 @@ async def get_weather_context(ctx: RunContextWrapper[str], city: str) -> Weather |
192 | 195 | return Weather(city=city, temperature_range="14-20C", conditions=ctx.context) |
193 | 196 |
|
194 | 197 |
|
| 198 | +@nexusrpc.service |
| 199 | +class WeatherService: |
| 200 | + get_weather_nexus_operation: nexusrpc.Operation[WeatherInput, Weather] |
| 201 | + |
| 202 | + |
| 203 | +@nexusrpc.handler.service_handler(service=WeatherService) |
| 204 | +class WeatherServiceHandler: |
| 205 | + @nexusrpc.handler.sync_operation |
| 206 | + async def get_weather_nexus_operation( |
| 207 | + self, ctx: nexusrpc.handler.StartOperationContext, input: WeatherInput |
| 208 | + ) -> Weather: |
| 209 | + return Weather( |
| 210 | + city=input.city, temperature_range="14-20C", conditions="Sunny with wind." |
| 211 | + ) |
| 212 | + |
| 213 | + |
195 | 214 | class TestWeatherModel(StaticTestModel): |
196 | 215 | responses = [ |
197 | 216 | ModelResponse( |
@@ -272,6 +291,44 @@ class TestWeatherModel(StaticTestModel): |
272 | 291 | ] |
273 | 292 |
|
274 | 293 |
|
| 294 | +class TestNexusWeatherModel(StaticTestModel): |
| 295 | + responses = [ |
| 296 | + ModelResponse( |
| 297 | + output=[ |
| 298 | + ResponseFunctionToolCall( |
| 299 | + arguments='{"input":{"city":"Tokyo"}}', |
| 300 | + call_id="call", |
| 301 | + name="get_weather_nexus_operation", |
| 302 | + type="function_call", |
| 303 | + id="id", |
| 304 | + status="completed", |
| 305 | + ) |
| 306 | + ], |
| 307 | + usage=Usage(), |
| 308 | + response_id=None, |
| 309 | + ), |
| 310 | + ModelResponse( |
| 311 | + output=[ |
| 312 | + ResponseOutputMessage( |
| 313 | + id="", |
| 314 | + content=[ |
| 315 | + ResponseOutputText( |
| 316 | + text="Test nexus weather result", |
| 317 | + annotations=[], |
| 318 | + type="output_text", |
| 319 | + ) |
| 320 | + ], |
| 321 | + role="assistant", |
| 322 | + status="completed", |
| 323 | + type="message", |
| 324 | + ) |
| 325 | + ], |
| 326 | + usage=Usage(), |
| 327 | + response_id=None, |
| 328 | + ), |
| 329 | + ] |
| 330 | + |
| 331 | + |
275 | 332 | @workflow.defn |
276 | 333 | class ToolsWorkflow: |
277 | 334 | @workflow.run |
@@ -300,6 +357,28 @@ async def run(self, question: str) -> str: |
300 | 357 | return result.final_output |
301 | 358 |
|
302 | 359 |
|
| 360 | +@workflow.defn |
| 361 | +class NexusToolsWorkflow: |
| 362 | + @workflow.run |
| 363 | + async def run(self, question: str) -> str: |
| 364 | + agent = Agent( |
| 365 | + name="Nexus Tools Workflow", |
| 366 | + instructions="You are a helpful agent.", |
| 367 | + tools=[ |
| 368 | + openai_agents.workflow.nexus_operation_as_tool( |
| 369 | + WeatherService.get_weather_nexus_operation, |
| 370 | + service=WeatherService, |
| 371 | + endpoint=make_nexus_endpoint_name(workflow.info().task_queue), |
| 372 | + schedule_to_close_timeout=timedelta(seconds=10), |
| 373 | + ), |
| 374 | + ], |
| 375 | + ) # type: Agent |
| 376 | + result = await Runner.run( |
| 377 | + starting_agent=agent, input=question, context="Stormy" |
| 378 | + ) |
| 379 | + return result.final_output |
| 380 | + |
| 381 | + |
303 | 382 | @pytest.mark.parametrize("use_local_model", [True, False]) |
304 | 383 | async def test_tool_workflow(client: Client, use_local_model: bool): |
305 | 384 | if not use_local_model and not os.environ.get("OPENAI_API_KEY"): |
@@ -404,6 +483,81 @@ async def test_tool_workflow(client: Client, use_local_model: bool): |
404 | 483 | ) |
405 | 484 |
|
406 | 485 |
|
| 486 | +@pytest.mark.parametrize("use_local_model", [True, False]) |
| 487 | +async def test_nexus_tool_workflow( |
| 488 | + client: Client, env: WorkflowEnvironment, use_local_model: bool |
| 489 | +): |
| 490 | + if not use_local_model and not os.environ.get("OPENAI_API_KEY"): |
| 491 | + pytest.skip("No openai API key") |
| 492 | + |
| 493 | + if env.supports_time_skipping: |
| 494 | + pytest.skip("Nexus tests don't work with time-skipping server") |
| 495 | + |
| 496 | + new_config = client.config() |
| 497 | + new_config["data_converter"] = pydantic_data_converter |
| 498 | + client = Client(**new_config) |
| 499 | + |
| 500 | + model_params = ModelActivityParameters(start_to_close_timeout=timedelta(seconds=30)) |
| 501 | + with set_open_ai_agent_temporal_overrides(model_params): |
| 502 | + model_activity = ModelActivity( |
| 503 | + TestModelProvider( |
| 504 | + TestNexusWeatherModel( # type: ignore |
| 505 | + ) |
| 506 | + ) |
| 507 | + if use_local_model |
| 508 | + else None |
| 509 | + ) |
| 510 | + async with new_worker( |
| 511 | + client, |
| 512 | + NexusToolsWorkflow, |
| 513 | + activities=[ |
| 514 | + model_activity.invoke_model_activity, |
| 515 | + ], |
| 516 | + nexus_service_handlers=[WeatherServiceHandler()], |
| 517 | + interceptors=[OpenAIAgentsTracingInterceptor()], |
| 518 | + ) as worker: |
| 519 | + await create_nexus_endpoint(worker.task_queue, client) |
| 520 | + |
| 521 | + workflow_handle = await client.start_workflow( |
| 522 | + NexusToolsWorkflow.run, |
| 523 | + "What is the weather in Tokio?", |
| 524 | + id=f"nexus-tools-workflow-{uuid.uuid4()}", |
| 525 | + task_queue=worker.task_queue, |
| 526 | + execution_timeout=timedelta(seconds=30), |
| 527 | + ) |
| 528 | + result = await workflow_handle.result() |
| 529 | + |
| 530 | + if use_local_model: |
| 531 | + assert result == "Test nexus weather result" |
| 532 | + |
| 533 | + events = [] |
| 534 | + async for e in workflow_handle.fetch_history_events(): |
| 535 | + if e.HasField( |
| 536 | + "activity_task_completed_event_attributes" |
| 537 | + ) or e.HasField("nexus_operation_completed_event_attributes"): |
| 538 | + events.append(e) |
| 539 | + |
| 540 | + assert len(events) == 3 |
| 541 | + assert ( |
| 542 | + "function_call" |
| 543 | + in events[0] |
| 544 | + .activity_task_completed_event_attributes.result.payloads[0] |
| 545 | + .data.decode() |
| 546 | + ) |
| 547 | + assert ( |
| 548 | + "Sunny with wind" |
| 549 | + in events[ |
| 550 | + 1 |
| 551 | + ].nexus_operation_completed_event_attributes.result.data.decode() |
| 552 | + ) |
| 553 | + assert ( |
| 554 | + "Test nexus weather result" |
| 555 | + in events[2] |
| 556 | + .activity_task_completed_event_attributes.result.payloads[0] |
| 557 | + .data.decode() |
| 558 | + ) |
| 559 | + |
| 560 | + |
407 | 561 | @no_type_check |
408 | 562 | class TestResearchModel(StaticTestModel): |
409 | 563 | responses = [ |
|
0 commit comments