Skip to content

Commit d26ffb5

Browse files
dandavisontconley1428
authored andcommitted
Activity worker: refactoring part 2 (#899)
1 parent 6618aa1 commit d26ffb5

File tree

1 file changed

+31
-26
lines changed

1 file changed

+31
-26
lines changed

temporalio/worker/_activity.py

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -138,24 +138,21 @@ def __init__(
138138
self._dynamic_activity = defn
139139

140140
async def run(self) -> None:
141-
# Create a task that fails when we get a failure on the queue
142-
async def raise_from_queue() -> NoReturn:
141+
"""Continually poll for activity tasks and dispatch to handlers."""
142+
143+
async def raise_from_exception_queue() -> NoReturn:
143144
raise await self._fail_worker_exception_queue.get()
144145

145-
exception_task = asyncio.create_task(raise_from_queue())
146+
exception_task = asyncio.create_task(raise_from_exception_queue())
146147

147-
# Continually poll for activity work
148148
while True:
149149
try:
150-
# Poll for a task
151150
poll_task = asyncio.create_task(
152151
self._bridge_worker().poll_activity_task()
153152
)
154153
await asyncio.wait(
155154
[poll_task, exception_task], return_when=asyncio.FIRST_COMPLETED
156-
) # type: ignore
157-
# If exception for failing the worker happened, raise it.
158-
# Otherwise, the poll succeeded.
155+
)
159156
if exception_task.done():
160157
poll_task.cancel()
161158
await exception_task
@@ -169,11 +166,14 @@ async def raise_from_queue() -> NoReturn:
169166
# size of 1000 should be plenty for the heartbeat queue.
170167
activity = _RunningActivity(pending_heartbeats=asyncio.Queue(1000))
171168
activity.task = asyncio.create_task(
172-
self._run_activity(task.task_token, task.start, activity)
169+
self._handle_start_activity_task(
170+
task.task_token, task.start, activity
171+
)
173172
)
174173
self._running_activities[task.task_token] = activity
175174
elif task.HasField("cancel"):
176-
self._cancel(task.task_token, task.cancel)
175+
# TODO(nexus-prerelease): does the task get removed from running_activities?
176+
self._handle_cancel_activity_task(task.task_token, task.cancel)
177177
else:
178178
raise RuntimeError(f"Unrecognized activity task: {task}")
179179
except temporalio.bridge.worker.PollShutdownError:
@@ -210,9 +210,10 @@ async def wait_all_completed(self) -> None:
210210
if running_tasks:
211211
await asyncio.gather(*running_tasks, return_exceptions=False)
212212

213-
def _cancel(
213+
def _handle_cancel_activity_task(
214214
self, task_token: bytes, cancel: temporalio.bridge.proto.activity_task.Cancel
215215
) -> None:
216+
"""Request cancellation of a running activity task."""
216217
activity = self._running_activities.get(task_token)
217218
if not activity:
218219
warnings.warn(f"Cannot find activity to cancel for token {task_token!r}")
@@ -277,12 +278,17 @@ async def _heartbeat_async(
277278
)
278279
activity.cancel(cancelled_due_to_heartbeat_error=err)
279280

280-
async def _run_activity(
281+
async def _handle_start_activity_task(
281282
self,
282283
task_token: bytes,
283284
start: temporalio.bridge.proto.activity_task.Start,
284285
running_activity: _RunningActivity,
285286
) -> None:
287+
"""Handle a start activity task.
288+
289+
Attempt to execute the user activity function and invoke the data converter on
290+
the result. Handle errors and send the task completion.
291+
"""
286292
logger.debug("Running activity %s (token %s)", start.activity_type, task_token)
287293
# We choose to surround interceptor creation and activity invocation in
288294
# a try block so we can mark the workflow as failed on any error instead
@@ -291,7 +297,9 @@ async def _run_activity(
291297
task_token=task_token
292298
)
293299
try:
294-
await self._execute_activity(start, running_activity, completion)
300+
result = await self._execute_activity(start, running_activity, task_token)
301+
[payload] = await self._data_converter.encode([result])
302+
completion.result.completed.result.CopyFrom(payload)
295303
except BaseException as err:
296304
try:
297305
if isinstance(err, temporalio.activity._CompleteAsyncError):
@@ -320,7 +328,7 @@ async def _run_activity(
320328
and running_activity.cancellation_details.details.paused
321329
):
322330
temporalio.activity.logger.warning(
323-
f"Completing as failure due to unhandled cancel error produced by activity pause",
331+
"Completing as failure due to unhandled cancel error produced by activity pause",
324332
)
325333
await self._data_converter.encode_failure(
326334
temporalio.exceptions.ApplicationError(
@@ -404,8 +412,12 @@ async def _execute_activity(
404412
self,
405413
start: temporalio.bridge.proto.activity_task.Start,
406414
running_activity: _RunningActivity,
407-
completion: temporalio.bridge.proto.ActivityTaskCompletion,
408-
):
415+
task_token: bytes,
416+
) -> Any:
417+
"""Invoke the user's activity function.
418+
419+
Exceptions are handled by a caller of this function.
420+
"""
409421
# Find activity or fail
410422
activity_def = self._activities.get(start.activity_type, self._dynamic_activity)
411423
if not activity_def:
@@ -525,7 +537,7 @@ async def _execute_activity(
525537
else None,
526538
started_time=_proto_to_datetime(start.started_time),
527539
task_queue=self._task_queue,
528-
task_token=completion.task_token,
540+
task_token=task_token,
529541
workflow_id=start.workflow_execution.workflow_id,
530542
workflow_namespace=start.workflow_namespace,
531543
workflow_run_id=start.workflow_execution.run_id,
@@ -564,16 +576,9 @@ async def _execute_activity(
564576
impl: ActivityInboundInterceptor = _ActivityInboundImpl(self, running_activity)
565577
for interceptor in reversed(list(self._interceptors)):
566578
impl = interceptor.intercept_activity(impl)
567-
# Init
579+
568580
impl.init(_ActivityOutboundImpl(self, running_activity.info))
569-
# Exec
570-
result = await impl.execute_activity(input)
571-
# Convert result even if none. Since Python essentially only
572-
# supports single result types (even if they are tuples), we will do
573-
# the same.
574-
completion.result.completed.result.CopyFrom(
575-
(await self._data_converter.encode([result]))[0]
576-
)
581+
return await impl.execute_activity(input)
577582

578583
def assert_activity_valid(self, activity) -> None:
579584
if self._dynamic_activity:

0 commit comments

Comments
 (0)