Skip to content

Commit 1aa3165

Browse files
committed
Refactor activity worker
1 parent 8449a35 commit 1aa3165

File tree

1 file changed

+176
-177
lines changed

1 file changed

+176
-177
lines changed

temporalio/worker/_activity.py

Lines changed: 176 additions & 177 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,7 @@ async def drain_poll_queue(self) -> None:
201201

202202
# Only call this after run()/drain_poll_queue() have returned. This will not
203203
# raise an exception.
204+
# TODO(dan): check accuracy of this comment; I would say it *does* raise an exception.
204205
async def wait_all_completed(self) -> None:
205206
running_tasks = [v.task for v in self._running_activities.values() if v.task]
206207
if running_tasks:
@@ -281,183 +282,7 @@ async def _run_activity(
281282
task_token=task_token
282283
)
283284
try:
284-
# Find activity or fail
285-
activity_def = self._activities.get(
286-
start.activity_type, self._dynamic_activity
287-
)
288-
if not activity_def:
289-
activity_names = ", ".join(sorted(self._activities.keys()))
290-
raise temporalio.exceptions.ApplicationError(
291-
f"Activity function {start.activity_type} for workflow {start.workflow_execution.workflow_id} "
292-
f"is not registered on this worker, available activities: {activity_names}",
293-
type="NotFoundError",
294-
)
295-
296-
# Create the worker shutdown event if not created
297-
if not self._worker_shutdown_event:
298-
self._worker_shutdown_event = temporalio.activity._CompositeEvent(
299-
thread_event=threading.Event(), async_event=asyncio.Event()
300-
)
301-
302-
# Setup events
303-
sync_non_threaded = False
304-
if not activity_def.is_async:
305-
running_activity.sync = True
306-
# If we're in a thread-pool executor we can use threading events
307-
# otherwise we must use manager events
308-
if isinstance(
309-
self._activity_executor, concurrent.futures.ThreadPoolExecutor
310-
):
311-
running_activity.cancelled_event = (
312-
temporalio.activity._CompositeEvent(
313-
thread_event=threading.Event(),
314-
# No async event
315-
async_event=None,
316-
)
317-
)
318-
if not activity_def.no_thread_cancel_exception:
319-
running_activity.cancel_thread_raiser = _ThreadExceptionRaiser()
320-
else:
321-
sync_non_threaded = True
322-
manager = self._shared_state_manager
323-
# Pre-checked on worker init
324-
assert manager
325-
running_activity.cancelled_event = (
326-
temporalio.activity._CompositeEvent(
327-
thread_event=manager.new_event(),
328-
# No async event
329-
async_event=None,
330-
)
331-
)
332-
# We also must set the worker shutdown thread event to a
333-
# manager event if this is the first sync event. We don't
334-
# want to create if there never is a sync event.
335-
if not self._seen_sync_activity:
336-
self._worker_shutdown_event.thread_event = manager.new_event()
337-
# Say we've seen a sync activity
338-
self._seen_sync_activity = True
339-
else:
340-
# We have to set the async form of events
341-
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
342-
thread_event=threading.Event(),
343-
async_event=asyncio.Event(),
344-
)
345-
346-
# Convert arguments. We use raw value for dynamic. Otherwise, we
347-
# only use arg type hints if they match the input count.
348-
arg_types = activity_def.arg_types
349-
if not activity_def.name:
350-
# Dynamic is just the raw value for each input value
351-
arg_types = [temporalio.common.RawValue] * len(start.input)
352-
elif arg_types is not None and len(arg_types) != len(start.input):
353-
arg_types = None
354-
try:
355-
args = (
356-
[]
357-
if not start.input
358-
else await self._data_converter.decode(
359-
start.input, type_hints=arg_types
360-
)
361-
)
362-
except Exception as err:
363-
raise temporalio.exceptions.ApplicationError(
364-
"Failed decoding arguments"
365-
) from err
366-
# Put the args inside a list if dynamic
367-
if not activity_def.name:
368-
args = [args]
369-
370-
# Convert heartbeat details
371-
# TODO(cretz): Allow some way to configure heartbeat type hinting?
372-
try:
373-
heartbeat_details = (
374-
[]
375-
if not start.heartbeat_details
376-
else await self._data_converter.decode(start.heartbeat_details)
377-
)
378-
except Exception as err:
379-
raise temporalio.exceptions.ApplicationError(
380-
"Failed decoding heartbeat details", non_retryable=True
381-
) from err
382-
383-
# Build info
384-
info = temporalio.activity.Info(
385-
activity_id=start.activity_id,
386-
activity_type=start.activity_type,
387-
attempt=start.attempt,
388-
current_attempt_scheduled_time=_proto_to_datetime(
389-
start.current_attempt_scheduled_time
390-
),
391-
heartbeat_details=heartbeat_details,
392-
heartbeat_timeout=_proto_to_non_zero_timedelta(start.heartbeat_timeout)
393-
if start.HasField("heartbeat_timeout")
394-
else None,
395-
is_local=start.is_local,
396-
schedule_to_close_timeout=_proto_to_non_zero_timedelta(
397-
start.schedule_to_close_timeout
398-
)
399-
if start.HasField("schedule_to_close_timeout")
400-
else None,
401-
scheduled_time=_proto_to_datetime(start.scheduled_time),
402-
start_to_close_timeout=_proto_to_non_zero_timedelta(
403-
start.start_to_close_timeout
404-
)
405-
if start.HasField("start_to_close_timeout")
406-
else None,
407-
started_time=_proto_to_datetime(start.started_time),
408-
task_queue=self._task_queue,
409-
task_token=task_token,
410-
workflow_id=start.workflow_execution.workflow_id,
411-
workflow_namespace=start.workflow_namespace,
412-
workflow_run_id=start.workflow_execution.run_id,
413-
workflow_type=start.workflow_type,
414-
priority=temporalio.common.Priority._from_proto(start.priority),
415-
)
416-
running_activity.info = info
417-
input = ExecuteActivityInput(
418-
fn=activity_def.fn,
419-
args=args,
420-
executor=None if not running_activity.sync else self._activity_executor,
421-
headers=start.header_fields,
422-
)
423-
424-
# Set the context early so the logging adapter works and
425-
# interceptors have it
426-
temporalio.activity._Context.set(
427-
temporalio.activity._Context(
428-
info=lambda: info,
429-
heartbeat=None,
430-
cancelled_event=running_activity.cancelled_event,
431-
worker_shutdown_event=self._worker_shutdown_event,
432-
shield_thread_cancel_exception=None
433-
if not running_activity.cancel_thread_raiser
434-
else running_activity.cancel_thread_raiser.shielded,
435-
payload_converter_class_or_instance=self._data_converter.payload_converter,
436-
runtime_metric_meter=None
437-
if sync_non_threaded
438-
else self._metric_meter,
439-
)
440-
)
441-
temporalio.activity.logger.debug("Starting activity")
442-
443-
# Build the interceptors chaining in reverse. We build a context right
444-
# now even though the info() can't be intercepted and heartbeat() will
445-
# fail. The interceptors may want to use the info() during init.
446-
impl: ActivityInboundInterceptor = _ActivityInboundImpl(
447-
self, running_activity
448-
)
449-
for interceptor in reversed(list(self._interceptors)):
450-
impl = interceptor.intercept_activity(impl)
451-
# Init
452-
impl.init(_ActivityOutboundImpl(self, running_activity.info))
453-
# Exec
454-
result = await impl.execute_activity(input)
455-
# Convert result even if none. Since Python essentially only
456-
# supports single result types (even if they are tuples), we will do
457-
# the same.
458-
completion.result.completed.result.CopyFrom(
459-
(await self._data_converter.encode([result]))[0]
460-
)
285+
await self._execute_activity(start, running_activity, completion)
461286
except BaseException as err:
462287
try:
463288
if isinstance(err, temporalio.activity._CompleteAsyncError):
@@ -545,6 +370,180 @@ async def _run_activity(
545370
except Exception:
546371
temporalio.activity.logger.exception("Failed completing activity task")
547372

373+
async def _execute_activity(
374+
self,
375+
start: temporalio.bridge.proto.activity_task.Start,
376+
running_activity: _RunningActivity,
377+
completion: temporalio.bridge.proto.ActivityTaskCompletion,
378+
):
379+
# Find activity or fail
380+
activity_def = self._activities.get(start.activity_type, self._dynamic_activity)
381+
if not activity_def:
382+
activity_names = ", ".join(sorted(self._activities.keys()))
383+
raise temporalio.exceptions.ApplicationError(
384+
f"Activity function {start.activity_type} for workflow {start.workflow_execution.workflow_id} "
385+
f"is not registered on this worker, available activities: {activity_names}",
386+
type="NotFoundError",
387+
)
388+
389+
# Create the worker shutdown event if not created
390+
if not self._worker_shutdown_event:
391+
self._worker_shutdown_event = temporalio.activity._CompositeEvent(
392+
thread_event=threading.Event(), async_event=asyncio.Event()
393+
)
394+
395+
# Setup events
396+
sync_non_threaded = False
397+
if not activity_def.is_async:
398+
running_activity.sync = True
399+
# If we're in a thread-pool executor we can use threading events
400+
# otherwise we must use manager events
401+
if isinstance(
402+
self._activity_executor, concurrent.futures.ThreadPoolExecutor
403+
):
404+
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
405+
thread_event=threading.Event(),
406+
# No async event
407+
async_event=None,
408+
)
409+
if not activity_def.no_thread_cancel_exception:
410+
running_activity.cancel_thread_raiser = _ThreadExceptionRaiser()
411+
else:
412+
sync_non_threaded = True
413+
manager = self._shared_state_manager
414+
# Pre-checked on worker init
415+
assert manager
416+
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
417+
thread_event=manager.new_event(),
418+
# No async event
419+
async_event=None,
420+
)
421+
# We also must set the worker shutdown thread event to a
422+
# manager event if this is the first sync event. We don't
423+
# want to create if there never is a sync event.
424+
if not self._seen_sync_activity:
425+
self._worker_shutdown_event.thread_event = manager.new_event()
426+
# Say we've seen a sync activity
427+
self._seen_sync_activity = True
428+
else:
429+
# We have to set the async form of events
430+
running_activity.cancelled_event = temporalio.activity._CompositeEvent(
431+
thread_event=threading.Event(),
432+
async_event=asyncio.Event(),
433+
)
434+
435+
# Convert arguments. We use raw value for dynamic. Otherwise, we
436+
# only use arg type hints if they match the input count.
437+
arg_types = activity_def.arg_types
438+
if not activity_def.name:
439+
# Dynamic is just the raw value for each input value
440+
arg_types = [temporalio.common.RawValue] * len(start.input)
441+
elif arg_types is not None and len(arg_types) != len(start.input):
442+
arg_types = None
443+
try:
444+
args = (
445+
[]
446+
if not start.input
447+
else await self._data_converter.decode(
448+
start.input, type_hints=arg_types
449+
)
450+
)
451+
except Exception as err:
452+
raise temporalio.exceptions.ApplicationError(
453+
"Failed decoding arguments"
454+
) from err
455+
# Put the args inside a list if dynamic
456+
if not activity_def.name:
457+
args = [args]
458+
459+
# Convert heartbeat details
460+
# TODO(cretz): Allow some way to configure heartbeat type hinting?
461+
try:
462+
heartbeat_details = (
463+
[]
464+
if not start.heartbeat_details
465+
else await self._data_converter.decode(start.heartbeat_details)
466+
)
467+
except Exception as err:
468+
raise temporalio.exceptions.ApplicationError(
469+
"Failed decoding heartbeat details", non_retryable=True
470+
) from err
471+
472+
# Build info
473+
info = temporalio.activity.Info(
474+
activity_id=start.activity_id,
475+
activity_type=start.activity_type,
476+
attempt=start.attempt,
477+
current_attempt_scheduled_time=_proto_to_datetime(
478+
start.current_attempt_scheduled_time
479+
),
480+
heartbeat_details=heartbeat_details,
481+
heartbeat_timeout=_proto_to_non_zero_timedelta(start.heartbeat_timeout)
482+
if start.HasField("heartbeat_timeout")
483+
else None,
484+
is_local=start.is_local,
485+
schedule_to_close_timeout=_proto_to_non_zero_timedelta(
486+
start.schedule_to_close_timeout
487+
)
488+
if start.HasField("schedule_to_close_timeout")
489+
else None,
490+
scheduled_time=_proto_to_datetime(start.scheduled_time),
491+
start_to_close_timeout=_proto_to_non_zero_timedelta(
492+
start.start_to_close_timeout
493+
)
494+
if start.HasField("start_to_close_timeout")
495+
else None,
496+
started_time=_proto_to_datetime(start.started_time),
497+
task_queue=self._task_queue,
498+
task_token=completion.task_token,
499+
workflow_id=start.workflow_execution.workflow_id,
500+
workflow_namespace=start.workflow_namespace,
501+
workflow_run_id=start.workflow_execution.run_id,
502+
workflow_type=start.workflow_type,
503+
priority=temporalio.common.Priority._from_proto(start.priority),
504+
)
505+
running_activity.info = info
506+
input = ExecuteActivityInput(
507+
fn=activity_def.fn,
508+
args=args,
509+
executor=None if not running_activity.sync else self._activity_executor,
510+
headers=start.header_fields,
511+
)
512+
513+
# Set the context early so the logging adapter works and
514+
# interceptors have it
515+
temporalio.activity._Context.set(
516+
temporalio.activity._Context(
517+
info=lambda: info,
518+
heartbeat=None,
519+
cancelled_event=running_activity.cancelled_event,
520+
worker_shutdown_event=self._worker_shutdown_event,
521+
shield_thread_cancel_exception=None
522+
if not running_activity.cancel_thread_raiser
523+
else running_activity.cancel_thread_raiser.shielded,
524+
payload_converter_class_or_instance=self._data_converter.payload_converter,
525+
runtime_metric_meter=None if sync_non_threaded else self._metric_meter,
526+
)
527+
)
528+
temporalio.activity.logger.debug("Starting activity")
529+
530+
# Build the interceptors chaining in reverse. We build a context right
531+
# now even though the info() can't be intercepted and heartbeat() will
532+
# fail. The interceptors may want to use the info() during init.
533+
impl: ActivityInboundInterceptor = _ActivityInboundImpl(self, running_activity)
534+
for interceptor in reversed(list(self._interceptors)):
535+
impl = interceptor.intercept_activity(impl)
536+
# Init
537+
impl.init(_ActivityOutboundImpl(self, running_activity.info))
538+
# Exec
539+
result = await impl.execute_activity(input)
540+
# Convert result even if none. Since Python essentially only
541+
# supports single result types (even if they are tuples), we will do
542+
# the same.
543+
completion.result.completed.result.CopyFrom(
544+
(await self._data_converter.encode([result]))[0]
545+
)
546+
548547
def assert_activity_valid(self, activity) -> None:
549548
if self._dynamic_activity:
550549
return

0 commit comments

Comments
 (0)