Skip to content

Commit 2f29d96

Browse files
committed
Add pause and reset to base class
1 parent 5faa4dd commit 2f29d96

File tree

1 file changed

+105
-2
lines changed

1 file changed

+105
-2
lines changed

temporalio/client.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3128,6 +3128,111 @@ async def report_cancellation(
31283128
),
31293129
)
31303130

3131+
async def pause(
3132+
self,
3133+
*,
3134+
reason: Optional[str] = None,
3135+
rpc_metadata: Mapping[str, Union[str, bytes]] = {},
3136+
rpc_timeout: Optional[timedelta] = None,
3137+
) -> None:
3138+
"""Pause the activity.
3139+
3140+
Args:
3141+
reason: Reason for pausing the activity.
3142+
rpc_metadata: Headers used on the RPC call. Keys here override
3143+
client-level RPC metadata keys.
3144+
rpc_timeout: Optional RPC deadline to set for the RPC call.
3145+
"""
3146+
if not isinstance(self._id_or_token, ActivityIDReference):
3147+
raise ValueError("Cannot pause activity with task token")
3148+
3149+
await self._client.workflow_service.pause_activity(
3150+
temporalio.api.workflowservice.v1.PauseActivityRequest(
3151+
namespace=self._client.namespace,
3152+
execution=temporalio.api.common.v1.WorkflowExecution(
3153+
workflow_id=self._id_or_token.workflow_id or "",
3154+
run_id=self._id_or_token.run_id or "",
3155+
),
3156+
identity=self._client.identity,
3157+
id=self._id_or_token.activity_id,
3158+
reason=reason or "",
3159+
),
3160+
retry=True,
3161+
metadata=rpc_metadata,
3162+
timeout=rpc_timeout,
3163+
)
3164+
3165+
async def unpause(
3166+
self,
3167+
*,
3168+
reset_attempts: bool = False,
3169+
rpc_metadata: Mapping[str, Union[str, bytes]] = {},
3170+
rpc_timeout: Optional[timedelta] = None,
3171+
) -> None:
3172+
"""Unpause the activity.
3173+
3174+
Args:
3175+
reset_attempts: Whether to reset the number of attempts.
3176+
rpc_metadata: Headers used on the RPC call. Keys here override
3177+
client-level RPC metadata keys.
3178+
rpc_timeout: Optional RPC deadline to set for the RPC call.
3179+
"""
3180+
if not isinstance(self._id_or_token, ActivityIDReference):
3181+
raise ValueError("Cannot unpause activity with task token")
3182+
3183+
await self._client.workflow_service.unpause_activity(
3184+
temporalio.api.workflowservice.v1.UnpauseActivityRequest(
3185+
namespace=self._client.namespace,
3186+
execution=temporalio.api.common.v1.WorkflowExecution(
3187+
workflow_id=self._id_or_token.workflow_id or "",
3188+
run_id=self._id_or_token.run_id or "",
3189+
),
3190+
identity=self._client.identity,
3191+
id=self._id_or_token.activity_id,
3192+
reset_attempts=reset_attempts,
3193+
),
3194+
retry=True,
3195+
metadata=rpc_metadata,
3196+
timeout=rpc_timeout,
3197+
)
3198+
3199+
async def reset(
3200+
self,
3201+
*,
3202+
reset_heartbeat: bool = False,
3203+
keep_paused: bool = False,
3204+
rpc_metadata: Mapping[str, Union[str, bytes]] = {},
3205+
rpc_timeout: Optional[timedelta] = None,
3206+
) -> None:
3207+
"""Reset the activity.
3208+
3209+
Args:
3210+
reset_heartbeat: Whether to reset heartbeat details.
3211+
keep_paused: If activity is paused, whether to keep it paused after reset.
3212+
rpc_metadata: Headers used on the RPC call. Keys here override
3213+
client-level RPC metadata keys.
3214+
rpc_timeout: Optional RPC deadline to set for the RPC call.
3215+
"""
3216+
if not isinstance(self._id_or_token, ActivityIDReference):
3217+
raise ValueError("Cannot reset activity with task token")
3218+
3219+
await self._client.workflow_service.reset_activity(
3220+
temporalio.api.workflowservice.v1.ResetActivityRequest(
3221+
namespace=self._client.namespace,
3222+
execution=temporalio.api.common.v1.WorkflowExecution(
3223+
workflow_id=self._id_or_token.workflow_id or "",
3224+
run_id=self._id_or_token.run_id or "",
3225+
),
3226+
identity=self._client.identity,
3227+
id=self._id_or_token.activity_id,
3228+
reset_heartbeat=reset_heartbeat,
3229+
keep_paused=keep_paused,
3230+
),
3231+
retry=True,
3232+
metadata=rpc_metadata,
3233+
timeout=rpc_timeout,
3234+
)
3235+
31313236

31323237
class WorkflowActivityHandle(_BaseActivityHandle):
31333238
"""Handle representing an activity started by a workflow."""
@@ -3248,8 +3353,6 @@ async def describe(
32483353
raise NotImplementedError
32493354

32503355
# TODO:
3251-
# pause
3252-
# reset
32533356
# update_options
32543357

32553358

0 commit comments

Comments
 (0)