1414import threading
1515import warnings
1616from abc import ABC , abstractmethod
17+ from collections .abc import Iterator , Sequence
1718from contextlib import contextmanager
1819from dataclasses import dataclass , field
1920from datetime import datetime , timedelta , timezone
2021from typing import (
2122 Any ,
2223 Callable ,
23- Dict ,
24- Iterator ,
2524 NoReturn ,
2625 Optional ,
27- Sequence ,
28- Tuple ,
29- Type ,
3026 Union ,
3127)
3228
3329import google .protobuf .duration_pb2
3430import google .protobuf .timestamp_pb2
3531
3632import temporalio .activity
37- import temporalio .api .common .v1
38- import temporalio .bridge .client
39- import temporalio .bridge .proto
40- import temporalio .bridge .proto .activity_result
41- import temporalio .bridge .proto .activity_task
42- import temporalio .bridge .proto .common
4333import temporalio .bridge .runtime
4434import temporalio .bridge .worker
4535import temporalio .client
@@ -76,7 +66,7 @@ def __init__(
7666 self ._task_queue = task_queue
7767 self ._activity_executor = activity_executor
7868 self ._shared_state_manager = shared_state_manager
79- self ._running_activities : Dict [bytes , _RunningActivity ] = {}
69+ self ._running_activities : dict [bytes , _RunningActivity ] = {}
8070 self ._data_converter = data_converter
8171 self ._interceptors = interceptors
8272 self ._metric_meter = metric_meter
@@ -90,7 +80,7 @@ def __init__(
9080 self ._client = client
9181
9282 # Validate and build activity dict
93- self ._activities : Dict [str , temporalio .activity ._Definition ] = {}
83+ self ._activities : dict [str , temporalio .activity ._Definition ] = {}
9484 self ._dynamic_activity : Optional [temporalio .activity ._Definition ] = None
9585 for activity in activities :
9686 # Get definition
@@ -178,7 +168,7 @@ async def raise_from_exception_queue() -> NoReturn:
178168 self ._handle_cancel_activity_task (task .task_token , task .cancel )
179169 else :
180170 raise RuntimeError (f"Unrecognized activity task: { task } " )
181- except temporalio .bridge .worker .PollShutdownError :
171+ except temporalio .bridge .worker .PollShutdownError : # type: ignore[reportPrivateLocalImportUsage]
182172 exception_task .cancel ()
183173 return
184174 except Exception as err :
@@ -195,12 +185,12 @@ async def drain_poll_queue(self) -> None:
195185 try :
196186 # Just take all tasks and say we can't handle them
197187 task = await self ._bridge_worker ().poll_activity_task ()
198- completion = temporalio .bridge .proto .ActivityTaskCompletion (
188+ completion = temporalio .bridge .proto .ActivityTaskCompletion ( # type: ignore[reportAttributeAccessIssue]
199189 task_token = task .task_token
200190 )
201191 completion .result .failed .failure .message = "Worker shutting down"
202192 await self ._bridge_worker ().complete_activity_task (completion )
203- except temporalio .bridge .worker .PollShutdownError :
193+ except temporalio .bridge .worker .PollShutdownError : # type: ignore[reportPrivateLocalImportUsage]
204194 return
205195
206196 # Only call this after run()/drain_poll_queue() have returned. This will not
@@ -214,7 +204,9 @@ async def wait_all_completed(self) -> None:
214204 await asyncio .gather (* running_tasks , return_exceptions = False )
215205
216206 def _handle_cancel_activity_task (
217- self , task_token : bytes , cancel : temporalio .bridge .proto .activity_task .Cancel
207+ self ,
208+ task_token : bytes ,
209+ cancel : temporalio .bridge .proto .activity_task .Cancel , # type: ignore[reportAttributeAccessIssue]
218210 ) -> None :
219211 """Request cancellation of a running activity task."""
220212 activity = self ._running_activities .get (task_token )
@@ -262,7 +254,9 @@ async def _heartbeat_async(
262254
263255 # Perform the heartbeat
264256 try :
265- heartbeat = temporalio .bridge .proto .ActivityHeartbeat (task_token = task_token )
257+ heartbeat = temporalio .bridge .proto .ActivityHeartbeat ( # type: ignore[reportAttributeAccessIssue]
258+ task_token = task_token
259+ )
266260 if details :
267261 # Convert to core payloads
268262 heartbeat .details .extend (await self ._data_converter .encode (details ))
@@ -284,7 +278,7 @@ async def _heartbeat_async(
284278 async def _handle_start_activity_task (
285279 self ,
286280 task_token : bytes ,
287- start : temporalio .bridge .proto .activity_task .Start ,
281+ start : temporalio .bridge .proto .activity_task .Start , # type: ignore[reportAttributeAccessIssue]
288282 running_activity : _RunningActivity ,
289283 ) -> None :
290284 """Handle a start activity task.
@@ -296,7 +290,7 @@ async def _handle_start_activity_task(
296290 # We choose to surround interceptor creation and activity invocation in
297291 # a try block so we can mark the workflow as failed on any error instead
298292 # of having error handling in the interceptor
299- completion = temporalio .bridge .proto .ActivityTaskCompletion (
293+ completion = temporalio .bridge .proto .ActivityTaskCompletion ( # type: ignore[reportAttributeAccessIssue]
300294 task_token = task_token
301295 )
302296 try :
@@ -413,7 +407,7 @@ async def _handle_start_activity_task(
413407
414408 async def _execute_activity (
415409 self ,
416- start : temporalio .bridge .proto .activity_task .Start ,
410+ start : temporalio .bridge .proto .activity_task .Start , # type: ignore[reportAttributeAccessIssue]
417411 running_activity : _RunningActivity ,
418412 task_token : bytes ,
419413 ) -> Any :
@@ -649,14 +643,14 @@ class _ThreadExceptionRaiser:
649643 def __init__ (self ) -> None :
650644 self ._lock = threading .Lock ()
651645 self ._thread_id : Optional [int ] = None
652- self ._pending_exception : Optional [Type [Exception ]] = None
646+ self ._pending_exception : Optional [type [Exception ]] = None
653647 self ._shield_depth = 0
654648
655649 def set_thread_id (self , thread_id : int ) -> None :
656650 with self ._lock :
657651 self ._thread_id = thread_id
658652
659- def raise_in_thread (self , exc_type : Type [Exception ]) -> None :
653+ def raise_in_thread (self , exc_type : type [Exception ]) -> None :
660654 with self ._lock :
661655 self ._pending_exception = exc_type
662656 self ._raise_in_thread_if_pending_unlocked ()
@@ -812,7 +806,7 @@ def _execute_sync_activity(
812806 cancelled_event : threading .Event ,
813807 worker_shutdown_event : threading .Event ,
814808 payload_converter_class_or_instance : Union [
815- Type [temporalio .converter .PayloadConverter ],
809+ type [temporalio .converter .PayloadConverter ],
816810 temporalio .converter .PayloadConverter ,
817811 ],
818812 runtime_metric_meter : Optional [temporalio .common .MetricMeter ],
@@ -824,13 +818,10 @@ def _execute_sync_activity(
824818 thread_id = threading .current_thread ().ident
825819 if thread_id is not None :
826820 cancel_thread_raiser .set_thread_id (thread_id )
827- heartbeat_fn : Callable [..., None ]
828821 if isinstance (heartbeat , SharedHeartbeatSender ):
829- # To make mypy happy
830- heartbeat_sender = heartbeat
831- heartbeat_fn = lambda * details : heartbeat_sender .send_heartbeat (
832- info .task_token , * details
833- )
822+
823+ def heartbeat_fn (* details : Any ) -> None :
824+ heartbeat .send_heartbeat (info .task_token , * details )
834825 else :
835826 heartbeat_fn = heartbeat
836827 temporalio .activity ._Context .set (
@@ -940,11 +931,11 @@ def __init__(
940931 self ._mgr = mgr
941932 self ._queue_poller_executor = queue_poller_executor
942933 # 1000 in-flight heartbeats should be plenty
943- self ._heartbeat_queue : queue .Queue [Tuple [bytes , Sequence [Any ]]] = mgr .Queue (
934+ self ._heartbeat_queue : queue .Queue [tuple [bytes , Sequence [Any ]]] = mgr .Queue (
944935 1000
945936 )
946- self ._heartbeats : Dict [bytes , Callable [..., None ]] = {}
947- self ._heartbeat_completions : Dict [bytes , Callable ] = {}
937+ self ._heartbeats : dict [bytes , Callable [..., None ]] = {}
938+ self ._heartbeat_completions : dict [bytes , Callable ] = {}
948939
949940 def new_event (self ) -> threading .Event :
950941 return self ._mgr .Event ()
@@ -1002,7 +993,7 @@ def _heartbeat_processor(self) -> None:
1002993
1003994class _MultiprocessingSharedHeartbeatSender (SharedHeartbeatSender ):
1004995 def __init__ (
1005- self , heartbeat_queue : queue .Queue [Tuple [bytes , Sequence [Any ]]]
996+ self , heartbeat_queue : queue .Queue [tuple [bytes , Sequence [Any ]]]
1006997 ) -> None :
1007998 super ().__init__ ()
1008999 self ._heartbeat_queue = heartbeat_queue
0 commit comments