From 1ebe253237fdea80d037a11f30f2993dac53e464 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Mon, 21 Jul 2025 14:55:32 -0700 Subject: [PATCH 1/5] Add heartbeat test and fix bug --- .../openai_agents/_heartbeat_decorator.py | 5 +- tests/contrib/openai_agents/test_openai.py | 81 ++++++++++++++++++- 2 files changed, 84 insertions(+), 2 deletions(-) diff --git a/temporalio/contrib/openai_agents/_heartbeat_decorator.py b/temporalio/contrib/openai_agents/_heartbeat_decorator.py index bce015ed8..fb645d7c2 100644 --- a/temporalio/contrib/openai_agents/_heartbeat_decorator.py +++ b/temporalio/contrib/openai_agents/_heartbeat_decorator.py @@ -24,7 +24,10 @@ async def wrapper(*args, **kwargs): if heartbeat_task: heartbeat_task.cancel() # Wait for heartbeat cancellation to complete - await heartbeat_task + try: + await heartbeat_task + except asyncio.CancelledError: + pass return cast(F, wrapper) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 57dc5c252..20ed91570 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -1,8 +1,9 @@ +import asyncio import os import uuid from dataclasses import dataclass from datetime import timedelta -from typing import Any, Optional, Union, no_type_check +from typing import Any, AsyncIterator, Optional, Union, no_type_check import nexusrpc import pytest @@ -14,6 +15,7 @@ InputGuardrailTripwireTriggered, ItemHelpers, MessageOutputItem, + Model, ModelResponse, ModelSettings, ModelTracing, @@ -35,6 +37,7 @@ HandoffOutputItem, ToolCallItem, ToolCallOutputItem, + TResponseStreamEvent, ) from openai import AsyncOpenAI, BaseModel from openai.types.responses import ( @@ -1778,3 +1781,79 @@ async def test_workflow_method_tools(client: Client): execution_timeout=timedelta(seconds=10), ) await workflow_handle.result() + + +class WaitModel(Model): + async def get_response( + self, + system_instructions: Union[str, None], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Union[AgentOutputSchemaBase, None], + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: Union[str, None], + prompt: Union[ResponsePromptParam, None] = None, + ) -> ModelResponse: + activity.logger.info("Waiting") + await asyncio.sleep(5.0) + activity.logger.info("Returning") + return ModelResponse( + output=[ + ResponseOutputMessage( + id="", + content=[ + ResponseOutputText( + text="test", annotations=[], type="output_text" + ) + ], + role="assistant", + status="completed", + type="message", + ) + ], + usage=Usage(), + response_id=None, + ) + + def stream_response( + self, + system_instructions: Optional[str], + input: Union[str, list[TResponseInputItem]], + model_settings: ModelSettings, + tools: list[Tool], + output_schema: Optional[AgentOutputSchemaBase], + handoffs: list[Handoff], + tracing: ModelTracing, + *, + previous_response_id: Optional[str], + prompt: Optional[ResponsePromptParam], + ) -> AsyncIterator[TResponseStreamEvent]: + raise NotImplementedError() + + +async def test_heartbeat(client: Client): + new_config = client.config() + new_config["data_converter"] = pydantic_data_converter + client = Client(**new_config) + + with set_open_ai_agent_temporal_overrides( + model_params=ModelActivityParameters(heartbeat_timeout=timedelta(seconds=2)) + ): + model_activity = ModelActivity(TestModelProvider(WaitModel())) + async with new_worker( + client, + HelloWorldAgent, + activities=[model_activity.invoke_model_activity], + interceptors=[OpenAIAgentsTracingInterceptor()], + ) as worker: + workflow_handle = await client.start_workflow( + HelloWorldAgent.run, + "Tell me about recursion in programming.", + id=f"workflow-tool-{uuid.uuid4()}", + task_queue=worker.task_queue, + execution_timeout=timedelta(seconds=10), + ) + await workflow_handle.result() From 8b94f6df119ce5511d9a35f0301a2b33aa10d86a Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 22 Jul 2025 15:43:35 -0700 Subject: [PATCH 2/5] Reduce intervals for speed --- tests/contrib/openai_agents/test_openai.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 20ed91570..e7ef43a60 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -1798,7 +1798,7 @@ async def get_response( prompt: Union[ResponsePromptParam, None] = None, ) -> ModelResponse: activity.logger.info("Waiting") - await asyncio.sleep(5.0) + await asyncio.sleep(.5) activity.logger.info("Returning") return ModelResponse( output=[ @@ -1840,7 +1840,7 @@ async def test_heartbeat(client: Client): client = Client(**new_config) with set_open_ai_agent_temporal_overrides( - model_params=ModelActivityParameters(heartbeat_timeout=timedelta(seconds=2)) + model_params=ModelActivityParameters(heartbeat_timeout=timedelta(seconds=.2)) ): model_activity = ModelActivity(TestModelProvider(WaitModel())) async with new_worker( @@ -1854,6 +1854,6 @@ async def test_heartbeat(client: Client): "Tell me about recursion in programming.", id=f"workflow-tool-{uuid.uuid4()}", task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=10), + execution_timeout=timedelta(seconds=1), ) await workflow_handle.result() From 0a9f03ad8af5b18933d0eeaa5e12230037e8e6ae Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Tue, 22 Jul 2025 15:47:39 -0700 Subject: [PATCH 3/5] Linting --- tests/contrib/openai_agents/test_openai.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 29cc38943..dd9609678 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -1859,7 +1859,7 @@ async def get_response( prompt: Union[ResponsePromptParam, None] = None, ) -> ModelResponse: activity.logger.info("Waiting") - await asyncio.sleep(.5) + await asyncio.sleep(0.5) activity.logger.info("Returning") return ModelResponse( output=[ @@ -1900,7 +1900,7 @@ async def test_heartbeat(client: Client): new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( - heartbeat_timeout=timedelta(seconds=.2), + heartbeat_timeout=timedelta(seconds=0.2), ), model_provider=TestModelProvider(WaitModel()), ) From b054542958535873c0028ba6a2d60dbbfb3fe9df Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 23 Jul 2025 10:27:20 -0700 Subject: [PATCH 4/5] Skip new test on time skipping server --- tests/contrib/openai_agents/test_openai.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index ab64606f3..97a1bec72 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -1931,7 +1931,10 @@ def stream_response( raise NotImplementedError() -async def test_heartbeat(client: Client): +async def test_heartbeat(client: Client, env: WorkflowEnvironment): + if env.supports_time_skipping: + pytest.skip("Relies on real timing, skip.") + new_config = client.config() new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( From 226d2fb76cb1d5ad07287dd8d36304609aa7a1d4 Mon Sep 17 00:00:00 2001 From: Tim Conley Date: Wed, 23 Jul 2025 10:28:14 -0700 Subject: [PATCH 5/5] Update timings --- tests/contrib/openai_agents/test_openai.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/contrib/openai_agents/test_openai.py b/tests/contrib/openai_agents/test_openai.py index 97a1bec72..bb7ed38a3 100644 --- a/tests/contrib/openai_agents/test_openai.py +++ b/tests/contrib/openai_agents/test_openai.py @@ -1895,7 +1895,7 @@ async def get_response( prompt: Union[ResponsePromptParam, None] = None, ) -> ModelResponse: activity.logger.info("Waiting") - await asyncio.sleep(0.5) + await asyncio.sleep(1.0) activity.logger.info("Returning") return ModelResponse( output=[ @@ -1939,7 +1939,7 @@ async def test_heartbeat(client: Client, env: WorkflowEnvironment): new_config["plugins"] = [ openai_agents.OpenAIAgentsPlugin( model_params=ModelActivityParameters( - heartbeat_timeout=timedelta(seconds=0.2), + heartbeat_timeout=timedelta(seconds=0.5), ), model_provider=TestModelProvider(WaitModel()), ) @@ -1955,6 +1955,6 @@ async def test_heartbeat(client: Client, env: WorkflowEnvironment): "Tell me about recursion in programming.", id=f"workflow-tool-{uuid.uuid4()}", task_queue=worker.task_queue, - execution_timeout=timedelta(seconds=1), + execution_timeout=timedelta(seconds=5.0), ) await workflow_handle.result()