Skip to content
66 changes: 36 additions & 30 deletions tests/nexus/test_workflow_caller_cancellation_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
class TestContext:
__test__ = False
cancellation_type: workflow.NexusOperationCancellationType
caller_op_future_resolved: asyncio.Future[datetime] = field(
default_factory=asyncio.Future
)
caller_workflow_id: str
cancel_handler_released: asyncio.Future[datetime] = field(
default_factory=asyncio.Future
)
Expand Down Expand Up @@ -96,7 +94,15 @@ async def cancel(
# by the caller server. At that point, the caller server will write
# NexusOperationCancelRequestCompleted. For TRY_CANCEL we want to prove that the nexus
# op handle future can be resolved as cancelled before any of that.
await test_context.caller_op_future_resolved
caller_wf: WorkflowHandle[Any, CancellationResult] = (
nexus.client().get_workflow_handle_for(
CallerWorkflow.run,
workflow_id=test_context.caller_workflow_id,
)
)
await caller_wf.execute_update(
CallerWorkflow.wait_caller_op_future_resolved
)
test_context.cancel_handler_released.set_result(datetime.now(timezone.utc))
await super().cancel(ctx, token)

Expand All @@ -117,6 +123,7 @@ class Input:
@dataclass
class CancellationResult:
operation_token: str
caller_op_future_resolved: datetime


@workflow.defn(sandboxed=False)
Expand All @@ -129,6 +136,7 @@ def __init__(self, input: Input):
)
self.released = False
self.operation_token: Optional[str] = None
self.caller_op_future_resolved: asyncio.Future[datetime] = asyncio.Future()

@workflow.signal
def release(self):
Expand All @@ -140,6 +148,10 @@ async def get_operation_token(self) -> str:
assert self.operation_token
return self.operation_token

@workflow.update
async def wait_caller_op_future_resolved(self) -> None:
await self.caller_op_future_resolved

@workflow.run
async def run(self, input: Input) -> CancellationResult:
op_handle = await (
Expand Down Expand Up @@ -188,9 +200,7 @@ async def run(self, input: Input) -> CancellationResult:
try:
await op_handle
except exceptions.NexusOperationError:
test_context.caller_op_future_resolved.set_result(
datetime.now(timezone.utc)
)
self.caller_op_future_resolved.set_result(workflow.now())
assert op_handle.operation_token
if input.cancellation_type in [
workflow.NexusOperationCancellationType.TRY_CANCEL,
Expand All @@ -210,6 +220,7 @@ async def run(self, input: Input) -> CancellationResult:
await workflow.wait_condition(lambda: self.released)
return CancellationResult(
operation_token=op_handle.operation_token,
caller_op_future_resolved=self.caller_op_future_resolved.result(),
)
else:
pytest.fail("Expected NexusOperationError")
Expand All @@ -233,7 +244,10 @@ async def test_cancellation_type(

cancellation_type = workflow.NexusOperationCancellationType[cancellation_type_name]
global test_context
test_context = TestContext(cancellation_type=cancellation_type)
test_context = TestContext(
cancellation_type=cancellation_type,
caller_workflow_id="caller-wf-" + str(uuid.uuid4()),
)

client = env.client

Expand All @@ -253,7 +267,7 @@ async def test_cancellation_type(
endpoint=make_nexus_endpoint_name(worker.task_queue),
cancellation_type=cancellation_type,
),
id="caller-wf-" + str(uuid.uuid4()),
id=test_context.caller_workflow_id,
task_queue=worker.task_queue,
id_conflict_policy=WorkflowIDConflictPolicy.FAIL,
)
Expand Down Expand Up @@ -314,8 +328,12 @@ async def check_behavior_for_try_cancel(
) -> None:
"""
Check that a cancellation request is sent and the caller workflow nexus op future is unblocked
as cancelled before the cancel handler returns (i.e. before the
NexusOperationCancelRequestCompleted in the caller workflow history).
as cancelled before the caller server writes CANCEL_REQUESTED.

There is a race between (a) the caller server writing CANCEL_REQUEST_COMPLETED in response to
the cancel handler returning, and (b) the caller server writing CANCELED in response to the
handler workflow exiting as canceled. If (b) happens first then (a) may never happen, therefore
we do not make any assertions regarding CANCEL_REQUEST_COMPLETED.
"""
try:
await handler_wf.result()
Expand All @@ -324,31 +342,21 @@ async def check_behavior_for_try_cancel(
else:
pytest.fail("Expected WorkflowFailureError")
await caller_wf.signal(CallerWorkflow.release)
await caller_wf.result()
result = await caller_wf.result()

handler_status = (await handler_wf.describe()).status
assert handler_status == WorkflowExecutionStatus.CANCELED
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
await assert_event_subsequence(
[
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED),
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED),
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED),
]
)
op_cancel_requested_event = await get_event_time(
caller_wf,
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUESTED,
)
op_cancel_request_completed_event = await get_event_time(
caller_wf,
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED,
)
assert (
caller_op_future_resolved
< op_cancel_requested_event
< op_cancel_request_completed_event
)
assert result.caller_op_future_resolved < op_cancel_requested_event


async def check_behavior_for_wait_cancellation_requested(
Expand All @@ -369,7 +377,7 @@ async def check_behavior_for_wait_cancellation_requested(
pytest.fail("Expected WorkflowFailureError")

await caller_wf.signal(CallerWorkflow.release)
await caller_wf.result()
result = await caller_wf.result()

handler_status = (await handler_wf.describe()).status
assert handler_status == WorkflowExecutionStatus.CANCELED
Expand All @@ -380,7 +388,6 @@ async def check_behavior_for_wait_cancellation_requested(
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_CANCELED),
]
)
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
op_cancel_request_completed = await get_event_time(
caller_wf,
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_COMPLETED,
Expand All @@ -389,7 +396,7 @@ async def check_behavior_for_wait_cancellation_requested(
handler_wf,
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
)
assert op_cancel_request_completed < caller_op_future_resolved < op_canceled
assert op_cancel_request_completed < result.caller_op_future_resolved < op_canceled


async def check_behavior_for_wait_cancellation_completed(
Expand All @@ -411,7 +418,7 @@ async def check_behavior_for_wait_cancellation_completed(
assert handler_status == WorkflowExecutionStatus.CANCELED

await caller_wf.signal(CallerWorkflow.release)
await caller_wf.result()
result = await caller_wf.result()

await assert_event_subsequence(
[
Expand All @@ -426,12 +433,11 @@ async def check_behavior_for_wait_cancellation_completed(
(caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
]
)
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
handler_wf_canceled_event_time = await get_event_time(
handler_wf_canceled_event = await get_event_time(
handler_wf,
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_CANCELED,
)
assert caller_op_future_resolved > handler_wf_canceled_event_time
assert handler_wf_canceled_event < result.caller_op_future_resolved


async def has_event(wf_handle: WorkflowHandle, event_type: EventType.ValueType):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@
class TestContext:
__test__ = False
cancellation_type: workflow.NexusOperationCancellationType
caller_op_future_resolved: asyncio.Future[datetime] = field(
default_factory=asyncio.Future
)
cancel_handler_released: asyncio.Future[datetime] = field(
default_factory=asyncio.Future
)
Expand Down Expand Up @@ -129,6 +126,7 @@ class Input:
@dataclass
class CancellationResult:
operation_token: str
caller_op_future_resolved: datetime
error_type: Optional[str] = None
error_cause_type: Optional[str] = None

Expand All @@ -143,6 +141,7 @@ def __init__(self, input: Input):
)
self.released = False
self.operation_token: Optional[str] = None
self.caller_op_future_resolved: asyncio.Future[datetime] = asyncio.Future()

@workflow.signal
def release(self):
Expand Down Expand Up @@ -184,13 +183,14 @@ async def run(self, input: Input) -> CancellationResult:
error_type = err.__class__.__name__
error_cause_type = err.__cause__.__class__.__name__

test_context.caller_op_future_resolved.set_result(datetime.now(timezone.utc))
self.caller_op_future_resolved.set_result(workflow.now())
assert op_handle.operation_token
await workflow.wait_condition(lambda: self.released)
return CancellationResult(
operation_token=op_handle.operation_token,
error_type=error_type,
error_cause_type=error_cause_type,
caller_op_future_resolved=self.caller_op_future_resolved.result(),
)


Expand Down Expand Up @@ -300,7 +300,6 @@ async def check_behavior_for_try_cancel(
assert result.error_type == "NexusOperationError"
assert result.error_cause_type == "CancelledError"

caller_op_future_resolved = test_context.caller_op_future_resolved.result()
await assert_event_subsequence(
[
(caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_STARTED),
Expand All @@ -317,7 +316,7 @@ async def check_behavior_for_try_cancel(
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED,
)
assert (
caller_op_future_resolved
result.caller_op_future_resolved
< op_cancel_requested_event
< op_cancel_request_failed_event
)
Expand All @@ -341,7 +340,6 @@ async def check_behavior_for_wait_cancellation_requested(
(caller_wf, EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED),
]
)
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
op_cancel_request_failed = await get_event_time(
caller_wf,
EventType.EVENT_TYPE_NEXUS_OPERATION_CANCEL_REQUEST_FAILED,
Expand All @@ -350,7 +348,11 @@ async def check_behavior_for_wait_cancellation_requested(
handler_wf,
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
)
assert op_cancel_request_failed < caller_op_future_resolved < handler_wf_completed
assert (
op_cancel_request_failed
< result.caller_op_future_resolved
< handler_wf_completed
)


async def check_behavior_for_wait_cancellation_completed(
Expand All @@ -373,9 +375,8 @@ async def check_behavior_for_wait_cancellation_completed(
(caller_wf, EventType.EVENT_TYPE_NEXUS_OPERATION_COMPLETED),
]
)
caller_op_future_resolved = test_context.caller_op_future_resolved.result()
handler_wf_completed = await get_event_time(
handler_wf,
EventType.EVENT_TYPE_WORKFLOW_EXECUTION_COMPLETED,
)
assert handler_wf_completed < caller_op_future_resolved
assert handler_wf_completed < result.caller_op_future_resolved
Loading