diff --git a/tests/nexus/test_workflow_caller_cancellation_types.py b/tests/nexus/test_workflow_caller_cancellation_types.py index 7cbbb95c2..2d44e6416 100644 --- a/tests/nexus/test_workflow_caller_cancellation_types.py +++ b/tests/nexus/test_workflow_caller_cancellation_types.py @@ -28,9 +28,7 @@ class TestContext: __test__ = False cancellation_type: workflow.NexusOperationCancellationType - caller_op_future_resolved: asyncio.Future[datetime] = field( - default_factory=asyncio.Future - ) + caller_workflow_id: str cancel_handler_released: asyncio.Future[datetime] = field( default_factory=asyncio.Future ) @@ -96,7 +94,15 @@ async def cancel( # by the caller server. At that point, the caller server will write # NexusOperationCancelRequestCompleted. For TRY_CANCEL we want to prove that the nexus # op handle future can be resolved as cancelled before any of that. - await test_context.caller_op_future_resolved + caller_wf: WorkflowHandle[Any, CancellationResult] = ( + nexus.client().get_workflow_handle_for( + CallerWorkflow.run, + workflow_id=test_context.caller_workflow_id, + ) + ) + await caller_wf.execute_update( + CallerWorkflow.wait_caller_op_future_resolved + ) test_context.cancel_handler_released.set_result(datetime.now(timezone.utc)) await super().cancel(ctx, token) @@ -117,6 +123,7 @@ class Input: @dataclass class CancellationResult: operation_token: str + caller_op_future_resolved: datetime @workflow.defn(sandboxed=False) @@ -129,6 +136,7 @@ def __init__(self, input: Input): ) self.released = False self.operation_token: Optional[str] = None + self.caller_op_future_resolved: asyncio.Future[datetime] = asyncio.Future() @workflow.signal def release(self): @@ -140,6 +148,10 @@ async def get_operation_token(self) -> str: assert self.operation_token return self.operation_token + @workflow.update + async def wait_caller_op_future_resolved(self) -> None: + await self.caller_op_future_resolved + @workflow.run async def run(self, input: Input) -> CancellationResult: op_handle = await ( @@ -188,9 +200,7 @@ async def run(self, input: Input) -> CancellationResult: try: await op_handle except exceptions.NexusOperationError: - test_context.caller_op_future_resolved.set_result( - datetime.now(timezone.utc) - ) + self.caller_op_future_resolved.set_result(workflow.now()) assert op_handle.operation_token if input.cancellation_type in [ workflow.NexusOperationCancellationType.TRY_CANCEL, @@ -210,6 +220,7 @@ async def run(self, input: Input) -> CancellationResult: await workflow.wait_condition(lambda: self.released) return CancellationResult( operation_token=op_handle.operation_token, + caller_op_future_resolved=self.caller_op_future_resolved.result(), ) else: pytest.fail("Expected NexusOperationError") @@ -233,7 +244,10 @@ async def test_cancellation_type( cancellation_type = workflow.NexusOperationCancellationType[cancellation_type_name] global test_context - test_context = TestContext(cancellation_type=cancellation_type) + test_context = TestContext( + cancellation_type=cancellation_type, + caller_workflow_id="caller-wf-" + str(uuid.uuid4()), + ) client = env.client @@ -253,7 +267,7 @@ async def test_cancellation_type( endpoint=make_nexus_endpoint_name(worker.task_queue), cancellation_type=cancellation_type, ), - id="caller-wf-" + str(uuid.uuid4()), + id=test_context.caller_workflow_id, task_queue=worker.task_queue, id_conflict_policy=WorkflowIDConflictPolicy.FAIL, ) @@ -314,8 +328,12 @@ async def check_behavior_for_try_cancel( ) -> None: """ Check that a cancellation request is sent and the caller workflow nexus op future is unblocked - as cancelled before the cancel handler returns (i.e. before the - NexusOperationCancelRequestCompleted in the caller workflow history). + as cancelled before the caller server writes CANCEL_REQUESTED. + + There is a race between (a) the caller server writing CANCEL_REQUEST_COMPLETED in response to + the cancel handler returning, and (b) the caller server writing CANCELED in response to the + handler workflow exiting as canceled. If (b) happens first then (a) may never happen, therefore + we do not make any assertions regarding CANCEL_REQUEST_COMPLETED. """ try: await handler_wf.result() @@ -324,15 +342,13 @@ async def check_behavior_for_try_cancel( else: pytest.fail("Expected WorkflowFailureError") await caller_wf.signal(CallerWorkflow.release) - await caller_wf.result() + result = await caller_wf.result() handler_status = (await handler_wf.describe()).status assert handler_status == WorkflowExecutionStatus.CANCELED - caller_op_future_resolved = test_context.caller_op_future_resolved.result() await assert_event_subsequence( [ (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED), - (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED), (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED), ] ) @@ -340,15 +356,7 @@ async def check_behavior_for_try_cancel( caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED, ) - op_cancel_request_completed_event = await get_event_time( - caller_wf, - EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED, - ) - assert ( - caller_op_future_resolved - < op_cancel_requested_event - < op_cancel_request_completed_event - ) + assert result.caller_op_future_resolved < op_cancel_requested_event async def check_behavior_for_wait_cancellation_requested( @@ -369,7 +377,7 @@ async def check_behavior_for_wait_cancellation_requested( pytest.fail("Expected WorkflowFailureError") await caller_wf.signal(CallerWorkflow.release) - await caller_wf.result() + result = await caller_wf.result() handler_status = (await handler_wf.describe()).status assert handler_status == WorkflowExecutionStatus.CANCELED @@ -380,7 +388,6 @@ async def check_behavior_for_wait_cancellation_requested( (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED), ] ) - caller_op_future_resolved = test_context.caller_op_future_resolved.result() op_cancel_request_completed = await get_event_time( caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED, @@ -389,7 +396,7 @@ async def check_behavior_for_wait_cancellation_requested( handler_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED, ) - assert op_cancel_request_completed < caller_op_future_resolved < op_canceled + assert op_cancel_request_completed < result.caller_op_future_resolved < op_canceled async def check_behavior_for_wait_cancellation_completed( @@ -411,7 +418,7 @@ async def check_behavior_for_wait_cancellation_completed( assert handler_status == WorkflowExecutionStatus.CANCELED await caller_wf.signal(CallerWorkflow.release) - await caller_wf.result() + result = await caller_wf.result() await assert_event_subsequence( [ @@ -426,12 +433,11 @@ async def check_behavior_for_wait_cancellation_completed( (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED), ] ) - caller_op_future_resolved = test_context.caller_op_future_resolved.result() - handler_wf_canceled_event_time = await get_event_time( + handler_wf_canceled_event = await get_event_time( handler_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED, ) - assert caller_op_future_resolved > handler_wf_canceled_event_time + assert handler_wf_canceled_event < result.caller_op_future_resolved async def has_event(wf_handle: WorkflowHandle, event_type: EventType.ValueType): diff --git a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py index 9585a8445..33b167245 100644 --- a/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py +++ b/tests/nexus/test_workflow_caller_cancellation_types_when_cancel_handler_fails.py @@ -35,9 +35,6 @@ class TestContext: __test__ = False cancellation_type: workflow.NexusOperationCancellationType - caller_op_future_resolved: asyncio.Future[datetime] = field( - default_factory=asyncio.Future - ) cancel_handler_released: asyncio.Future[datetime] = field( default_factory=asyncio.Future ) @@ -129,6 +126,7 @@ class Input: @dataclass class CancellationResult: operation_token: str + caller_op_future_resolved: datetime error_type: Optional[str] = None error_cause_type: Optional[str] = None @@ -143,6 +141,7 @@ def __init__(self, input: Input): ) self.released = False self.operation_token: Optional[str] = None + self.caller_op_future_resolved: asyncio.Future[datetime] = asyncio.Future() @workflow.signal def release(self): @@ -184,13 +183,14 @@ async def run(self, input: Input) -> CancellationResult: error_type = err.__class__.__name__ error_cause_type = err.__cause__.__class__.__name__ - test_context.caller_op_future_resolved.set_result(datetime.now(timezone.utc)) + self.caller_op_future_resolved.set_result(workflow.now()) assert op_handle.operation_token await workflow.wait_condition(lambda: self.released) return CancellationResult( operation_token=op_handle.operation_token, error_type=error_type, error_cause_type=error_cause_type, + caller_op_future_resolved=self.caller_op_future_resolved.result(), ) @@ -300,7 +300,6 @@ async def check_behavior_for_try_cancel( assert result.error_type == "NexusOperationError" assert result.error_cause_type == "CancelledError" - caller_op_future_resolved = test_context.caller_op_future_resolved.result() await assert_event_subsequence( [ (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED), @@ -317,7 +316,7 @@ async def check_behavior_for_try_cancel( EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED, ) assert ( - caller_op_future_resolved + result.caller_op_future_resolved < op_cancel_requested_event < op_cancel_request_failed_event ) @@ -341,7 +340,6 @@ async def check_behavior_for_wait_cancellation_requested( (caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED), ] ) - caller_op_future_resolved = test_context.caller_op_future_resolved.result() op_cancel_request_failed = await get_event_time( caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED, @@ -350,7 +348,11 @@ async def check_behavior_for_wait_cancellation_requested( handler_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED, ) - assert op_cancel_request_failed < caller_op_future_resolved < handler_wf_completed + assert ( + op_cancel_request_failed + < result.caller_op_future_resolved + < handler_wf_completed + ) async def check_behavior_for_wait_cancellation_completed( @@ -373,9 +375,8 @@ async def check_behavior_for_wait_cancellation_completed( (caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED), ] ) - caller_op_future_resolved = test_context.caller_op_future_resolved.result() handler_wf_completed = await get_event_time( handler_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED, ) - assert handler_wf_completed < caller_op_future_resolved + assert handler_wf_completed < result.caller_op_future_resolved