From 6f91e65acc4feca6ecd2d42cd68e352369a6524e Mon Sep 17 00:00:00 2001 From: Jake Howard Date: Fri, 30 Aug 2024 14:53:02 +0100 Subject: [PATCH] Prevent writing to task results The code which actually does it gets slower, but it prevents a number of potential foot-guns --- django_tasks/backends/database/backend.py | 2 +- django_tasks/backends/database/models.py | 4 ++-- django_tasks/backends/immediate.py | 20 ++++++++++++-------- django_tasks/task.py | 6 +++--- tests/tests/test_dummy_backend.py | 4 ++-- 5 files changed, 20 insertions(+), 16 deletions(-) diff --git a/django_tasks/backends/database/backend.py b/django_tasks/backends/database/backend.py index b6756b8..e4761a7 100644 --- a/django_tasks/backends/database/backend.py +++ b/django_tasks/backends/database/backend.py @@ -21,7 +21,7 @@ P = ParamSpec("P") -@dataclass +@dataclass(frozen=True) class TaskResult(BaseTaskResult[T]): db_result: "DBTaskResult" diff --git a/django_tasks/backends/database/models.py b/django_tasks/backends/database/models.py index 42cfb85..32b6e2a 100644 --- a/django_tasks/backends/database/models.py +++ b/django_tasks/backends/database/models.py @@ -148,8 +148,8 @@ def task_result(self) -> "TaskResult[T]": backend=self.backend_name, ) - result._return_value = self.return_value - result._exception_data = self.exception_data + object.__setattr__(result, "_return_value", self.return_value) + object.__setattr__(result, "_exception_data", self.exception_data) return result diff --git a/django_tasks/backends/immediate.py b/django_tasks/backends/immediate.py index b94aa45..7402584 100644 --- a/django_tasks/backends/immediate.py +++ b/django_tasks/backends/immediate.py @@ -34,15 +34,19 @@ def _execute_task(self, task_result: TaskResult) -> None: async_to_sync(task.func) if iscoroutinefunction(task.func) else task.func ) - task_result.started_at = timezone.now() + object.__setattr__(task_result, "started_at", timezone.now()) try: - task_result._return_value = json_normalize( - calling_task_func(*task_result.args, **task_result.kwargs) + object.__setattr__( + task_result, + "_return_value", + json_normalize( + calling_task_func(*task_result.args, **task_result.kwargs) + ), ) except BaseException as e: - task_result.finished_at = timezone.now() + object.__setattr__(task_result, "finished_at", timezone.now()) try: - task_result._exception_data = exception_to_dict(e) + object.__setattr__(task_result, "_exception_data", exception_to_dict(e)) except Exception: logger.exception("Task id=%s unable to save exception", task_result.id) @@ -53,14 +57,14 @@ def _execute_task(self, task_result: TaskResult) -> None: task.module_path, ResultStatus.FAILED, ) - task_result.status = ResultStatus.FAILED + object.__setattr__(task_result, "status", ResultStatus.FAILED) # If the user tried to terminate, let them if isinstance(e, KeyboardInterrupt): raise else: - task_result.finished_at = timezone.now() - task_result.status = ResultStatus.COMPLETE + object.__setattr__(task_result, "finished_at", timezone.now()) + object.__setattr__(task_result, "status", ResultStatus.COMPLETE) def enqueue( self, task: Task[P, T], args: P.args, kwargs: P.kwargs diff --git a/django_tasks/task.py b/django_tasks/task.py index fb42adb..e2f4b52 100644 --- a/django_tasks/task.py +++ b/django_tasks/task.py @@ -218,7 +218,7 @@ def wrapper(f: Callable[P, T]) -> Task[P, T]: return wrapper -@dataclass +@dataclass(frozen=True) class TaskResult(Generic[T]): task: Task """The task for which this is a result""" @@ -292,7 +292,7 @@ def refresh(self) -> None: refreshed_task = self.task.get_backend().get_result(self.id) for attr in TASK_REFRESH_ATTRS: - setattr(self, attr, getattr(refreshed_task, attr)) + object.__setattr__(self, attr, getattr(refreshed_task, attr)) async def arefresh(self) -> None: """ @@ -301,4 +301,4 @@ async def arefresh(self) -> None: refreshed_task = await self.task.get_backend().aget_result(self.id) for attr in TASK_REFRESH_ATTRS: - setattr(self, attr, getattr(refreshed_task, attr)) + object.__setattr__(self, attr, getattr(refreshed_task, attr)) diff --git a/tests/tests/test_dummy_backend.py b/tests/tests/test_dummy_backend.py index 3fbefd2..d5b23dd 100644 --- a/tests/tests/test_dummy_backend.py +++ b/tests/tests/test_dummy_backend.py @@ -74,7 +74,7 @@ def test_refresh_result(self) -> None: ) enqueued_result = default_task_backend.results[0] # type:ignore[attr-defined] - enqueued_result.status = ResultStatus.COMPLETE + object.__setattr__(enqueued_result, "status", ResultStatus.COMPLETE) self.assertEqual(result.status, ResultStatus.NEW) result.refresh() @@ -86,7 +86,7 @@ async def test_refresh_result_async(self) -> None: ) enqueued_result = default_task_backend.results[0] # type:ignore[attr-defined] - enqueued_result.status = ResultStatus.COMPLETE + object.__setattr__(enqueued_result, "status", ResultStatus.COMPLETE) self.assertEqual(result.status, ResultStatus.NEW) await result.arefresh()