11import re
22import warnings
3+ from collections .abc import Callable
34from dataclasses import dataclass , field
45from datetime import datetime , timedelta
56from enum import Enum
67from functools import lru_cache
78from pathlib import Path
9+ from typing import Any
810from uuid import UUID
911
1012import boto3
@@ -108,19 +110,19 @@ def to_message(self) -> core_pb2.TaskLease:
108110
109111
110112@dataclass (order = True )
111- class ProgressBar :
113+ class ProgressIndicator :
112114 label : str | None
113115 total : int
114116 done : int
115117
116118 @classmethod
117- def from_message (cls , progress_bar : core_pb2 .ProgressBar ) -> "ProgressBar " :
118- """Convert a ProgressBar protobuf message to a ProgressBar object."""
119- return cls (label = progress_bar .label or None , total = progress_bar .total , done = progress_bar .done )
119+ def from_message (cls , progress_indicator : core_pb2 .Progress ) -> "ProgressIndicator " :
120+ """Convert a ProgressIndicator protobuf message to a ProgressIndicator object."""
121+ return cls (label = progress_indicator .label or None , total = progress_indicator .total , done = progress_indicator .done )
120122
121- def to_message (self ) -> core_pb2 .ProgressBar :
122- """Convert a ProgressBar object to a ProgressBar protobuf message."""
123- return core_pb2 .ProgressBar (label = self .label , total = self .total , done = self .done )
123+ def to_message (self ) -> core_pb2 .Progress :
124+ """Convert a ProgressIndicator object to a ProgressIndicator protobuf message."""
125+ return core_pb2 .Progress (label = self .label , total = self .total , done = self .done )
124126
125127
126128@dataclass (order = True )
@@ -195,7 +197,7 @@ class JobState(Enum):
195197_JOB_STATES = {state .value : state for state in JobState }
196198
197199
198- @dataclass (order = True )
200+ @dataclass (order = True , frozen = True )
199201class Job :
200202 id : UUID
201203 name : str
@@ -204,10 +206,12 @@ class Job:
204206 submitted_at : datetime
205207 started_at : datetime | None
206208 canceled : bool
207- progress_bars : list [ProgressBar ]
209+ progress : list [ProgressIndicator ]
208210
209211 @classmethod
210- def from_message (cls , job : core_pb2 .Job ) -> "Job" : # lets use typing.Self once we require python >= 3.11
212+ def from_message (
213+ cls , job : core_pb2 .Job , ** extra_kwargs : Any
214+ ) -> "Job" : # lets use typing.Self once we require python >= 3.11
211215 """Convert a Job protobuf message to a Job object."""
212216 return cls (
213217 id = uuid_message_to_uuid (job .id ),
@@ -217,7 +221,8 @@ def from_message(cls, job: core_pb2.Job) -> "Job": # lets use typing.Self once
217221 submitted_at = timestamp_to_datetime (job .submitted_at ),
218222 started_at = timestamp_to_datetime (job .started_at ) if job .HasField ("started_at" ) else None ,
219223 canceled = job .canceled ,
220- progress_bars = [ProgressBar .from_message (progress_bar ) for progress_bar in job .progress_bars ],
224+ progress = [ProgressIndicator .from_message (progress ) for progress in job .progress ],
225+ ** extra_kwargs ,
221226 )
222227
223228 def to_message (self ) -> core_pb2 .Job :
@@ -230,7 +235,7 @@ def to_message(self) -> core_pb2.Job:
230235 submitted_at = datetime_to_timestamp (self .submitted_at ),
231236 started_at = datetime_to_timestamp (self .started_at ) if self .started_at else None ,
232237 canceled = self .canceled ,
233- progress_bars = [ progress_bar .to_message () for progress_bar in self .progress_bars ],
238+ progress = [ progress .to_message () for progress in self .progress ],
234239 )
235240
236241
@@ -303,7 +308,7 @@ class ComputedTask:
303308 id : UUID
304309 display : str | None
305310 sub_tasks : list [TaskSubmission ]
306- progress_updates : list [ProgressBar ]
311+ progress_updates : list [ProgressIndicator ]
307312
308313 @classmethod
309314 def from_message (cls , computed_task : task_pb2 .ComputedTask ) -> "ComputedTask" :
@@ -312,7 +317,7 @@ def from_message(cls, computed_task: task_pb2.ComputedTask) -> "ComputedTask":
312317 id = uuid_message_to_uuid (computed_task .id ),
313318 display = computed_task .display ,
314319 sub_tasks = [TaskSubmission .from_message (sub_task ) for sub_task in computed_task .sub_tasks ],
315- progress_updates = [ProgressBar .from_message (progress ) for progress in computed_task .progress_updates ],
320+ progress_updates = [ProgressIndicator .from_message (progress ) for progress in computed_task .progress_updates ],
316321 )
317322
318323 def to_message (self ) -> task_pb2 .ComputedTask :
@@ -571,9 +576,13 @@ class QueryJobsResponse:
571576 next_page : Pagination
572577
573578 @classmethod
574- def from_message (cls , page : job_pb2 .QueryJobsResponse ) -> "QueryJobsResponse" :
579+ def from_message (
580+ cls ,
581+ page : job_pb2 .QueryJobsResponse ,
582+ job_factory : Callable [[core_pb2 .Job ], Job ] = Job .from_message ,
583+ ) -> "QueryJobsResponse" :
575584 return cls (
576- jobs = [Job . from_message (job ) for job in page .jobs ],
585+ jobs = [job_factory (job ) for job in page .jobs ],
577586 next_page = Pagination .from_message (page .next_page ),
578587 )
579588
0 commit comments