diff --git a/temporalio/activity.py b/temporalio/activity.py index f259a675d..e08081f19 100644 --- a/temporalio/activity.py +++ b/temporalio/activity.py @@ -154,6 +154,7 @@ class ActivityCancellationDetails: not_found: bool = False cancel_requested: bool = False paused: bool = False + reset: bool = False timed_out: bool = False worker_shutdown: bool = False @@ -167,6 +168,7 @@ def _from_proto( paused=proto.is_paused, timed_out=proto.is_timed_out, worker_shutdown=proto.is_worker_shutdown, + reset=proto.is_reset, ) diff --git a/temporalio/bridge/src/client.rs b/temporalio/bridge/src/client.rs index 1ae1967ce..2f4ab867e 100644 --- a/temporalio/bridge/src/client.rs +++ b/temporalio/bridge/src/client.rs @@ -262,6 +262,9 @@ impl ClientRef { "request_cancel_workflow_execution" => { rpc_call!(retry_client, call, request_cancel_workflow_execution) } + "reset_activity" => { + rpc_call!(retry_client, call, reset_activity) + } "reset_sticky_task_queue" => { rpc_call!(retry_client, call, reset_sticky_task_queue) } diff --git a/temporalio/client.py b/temporalio/client.py index 50fab46ba..4eb1dc868 100644 --- a/temporalio/client.py +++ b/temporalio/client.py @@ -6363,11 +6363,16 @@ async def heartbeat_async_activity( metadata=input.rpc_metadata, timeout=input.rpc_timeout, ) - if resp_by_id.cancel_requested or resp_by_id.activity_paused: + if ( + resp_by_id.cancel_requested + or resp_by_id.activity_paused + or resp_by_id.activity_reset + ): raise AsyncActivityCancelledError( details=ActivityCancellationDetails( cancel_requested=resp_by_id.cancel_requested, paused=resp_by_id.activity_paused, + reset=resp_by_id.activity_reset, ) ) @@ -6388,6 +6393,7 @@ async def heartbeat_async_activity( details=ActivityCancellationDetails( cancel_requested=resp.cancel_requested, paused=resp.activity_paused, + reset=resp.activity_reset, ) ) diff --git a/temporalio/worker/_activity.py b/temporalio/worker/_activity.py index c76c8f005..19e406e6f 100644 --- a/temporalio/worker/_activity.py +++ b/temporalio/worker/_activity.py @@ -334,6 +334,24 @@ async def _handle_start_activity_task( ), completion.result.failed.failure, ) + elif ( + isinstance( + err, + (asyncio.CancelledError, temporalio.exceptions.CancelledError), + ) + and running_activity.cancellation_details.details + and running_activity.cancellation_details.details.reset + ): + temporalio.activity.logger.warning( + "Completing as failure due to unhandled cancel error produced by activity reset", + ) + await self._data_converter.encode_failure( + temporalio.exceptions.ApplicationError( + type="ActivityReset", + message="Unhandled activity cancel error produced by activity reset", + ), + completion.result.failed.failure, + ) elif ( isinstance( err, diff --git a/temporalio/worker/_workflow_instance.py b/temporalio/worker/_workflow_instance.py index c93155672..795f1226d 100644 --- a/temporalio/worker/_workflow_instance.py +++ b/temporalio/worker/_workflow_instance.py @@ -2791,6 +2791,7 @@ def _apply_schedule_command( command.user_metadata.summary.CopyFrom( self._instance._payload_converter.to_payload(self._input.summary) ) + print("Activity summary: ", command.user_metadata.summary) if self._input.priority: command.schedule_activity.priority.CopyFrom( self._input.priority._to_proto() diff --git a/tests/worker/test_activity.py b/tests/worker/test_activity.py index 8d39069a7..7868fc281 100644 --- a/tests/worker/test_activity.py +++ b/tests/worker/test_activity.py @@ -9,12 +9,15 @@ import threading import time import uuid +from concurrent.futures import ThreadPoolExecutor from concurrent.futures.process import BrokenProcessPool from contextvars import ContextVar from dataclasses import dataclass from datetime import datetime, timedelta, timezone +from time import sleep from typing import Any, Callable, List, NoReturn, Optional, Sequence, Type +import temporalio.api.workflowservice.v1 from temporalio import activity, workflow from temporalio.client import ( AsyncActivityHandle, @@ -1486,3 +1489,105 @@ async def h(): client, worker, heartbeat, retry_max_attempts=2 ) assert result.result == "details: Some detail" + + +async def test_activity_reset_catch( + client: Client, worker: ExternalWorker, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Time skipping server doesn't support activity reset") + + @activity.defn + async def wait_cancel() -> str: + req = temporalio.api.workflowservice.v1.ResetActivityRequest( + namespace=client.namespace, + execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=activity.info().workflow_id, + run_id=activity.info().workflow_run_id, + ), + id=activity.info().activity_id, + ) + await client.workflow_service.reset_activity(req) + try: + while True: + await asyncio.sleep(0.3) + activity.heartbeat() + except asyncio.CancelledError: + details = activity.cancellation_details() + assert details is not None + return "Got cancelled error, reset? " + str(details.reset) + + @activity.defn + def sync_wait_cancel() -> str: + req = temporalio.api.workflowservice.v1.ResetActivityRequest( + namespace=client.namespace, + execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=activity.info().workflow_id, + run_id=activity.info().workflow_run_id, + ), + id=activity.info().activity_id, + ) + asyncio.run(client.workflow_service.reset_activity(req)) + try: + while True: + sleep(0.3) + activity.heartbeat() + except temporalio.exceptions.CancelledError: + details = activity.cancellation_details() + assert details is not None + return "Got cancelled error, reset? " + str(details.reset) + except Exception as e: + return str(type(e)) + str(e) + + result = await _execute_workflow_with_activity( + client, + worker, + wait_cancel, + ) + assert result.result == "Got cancelled error, reset? True" + + config = WorkerConfig( + activity_executor=ThreadPoolExecutor(max_workers=1), + ) + result = await _execute_workflow_with_activity( + client, + worker, + sync_wait_cancel, + worker_config=config, + ) + assert result.result == "Got cancelled error, reset? True" + + +async def test_activity_reset_history( + client: Client, worker: ExternalWorker, env: WorkflowEnvironment +): + if env.supports_time_skipping: + pytest.skip("Time skipping server doesn't support activity reset") + + @activity.defn + async def wait_cancel() -> str: + req = temporalio.api.workflowservice.v1.ResetActivityRequest( + namespace=client.namespace, + execution=temporalio.api.common.v1.WorkflowExecution( + workflow_id=activity.info().workflow_id, + run_id=activity.info().workflow_run_id, + ), + id=activity.info().activity_id, + ) + await client.workflow_service.reset_activity(req) + while True: + await asyncio.sleep(0.3) + activity.heartbeat() + + with pytest.raises(WorkflowFailureError) as e: + result = await _execute_workflow_with_activity( + client, + worker, + wait_cancel, + ) + assert isinstance(e.value.cause, ActivityError) + assert isinstance(e.value.cause.cause, ApplicationError) + assert ( + e.value.cause.cause.message + == "Unhandled activity cancel error produced by activity reset" + ) diff --git a/tests/worker/test_workflow.py b/tests/worker/test_workflow.py index e97bf3e02..862f9a456 100644 --- a/tests/worker/test_workflow.py +++ b/tests/worker/test_workflow.py @@ -814,7 +814,10 @@ class SimpleActivityWorkflow: @workflow.run async def run(self, name: str) -> str: return await workflow.execute_activity( - say_hello, name, schedule_to_close_timeout=timedelta(seconds=5) + say_hello, + name, + schedule_to_close_timeout=timedelta(seconds=5), + summary="Do a thing", )