From 1119e3755614eb8b906cddfe31da9461a03e3236 Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Fri, 28 Jun 2024 16:06:12 -0300 Subject: [PATCH 1/5] Add initial support for a celery backend --- django_tasks/backends/celery/__init__.py | 3 ++ django_tasks/backends/celery/app.py | 21 ++++++++++++ django_tasks/backends/celery/backend.py | 43 ++++++++++++++++++++++++ pyproject.toml | 1 + 4 files changed, 68 insertions(+) create mode 100644 django_tasks/backends/celery/__init__.py create mode 100644 django_tasks/backends/celery/app.py create mode 100644 django_tasks/backends/celery/backend.py diff --git a/django_tasks/backends/celery/__init__.py b/django_tasks/backends/celery/__init__.py new file mode 100644 index 0000000..9fbaf65 --- /dev/null +++ b/django_tasks/backends/celery/__init__.py @@ -0,0 +1,3 @@ +from .backend import CeleryBackend + +__all__ = ["CeleryBackend"] diff --git a/django_tasks/backends/celery/app.py b/django_tasks/backends/celery/app.py new file mode 100644 index 0000000..864796a --- /dev/null +++ b/django_tasks/backends/celery/app.py @@ -0,0 +1,21 @@ +import os + +from celery import Celery + +from django_tasks.task import DEFAULT_QUEUE_NAME + + +# Set the default Django settings module for the 'celery' program. +django_settings = os.environ.get('DJANGO_SETTINGS_MODULE') +if django_settings is None: + raise ValueError('DJANGO_SETTINGS_MODULE environment variable is not set') + +app = Celery('django_tasks') + +# Using a string here means the worker doesn't have to serialize +# the configuration object to child processes. +# - namespace='CELERY' means all celery-related configuration keys +# should have a `CELERY_` prefix. +app.config_from_object('django.conf:settings', namespace='CELERY') + +app.conf.task_default_queue = DEFAULT_QUEUE_NAME diff --git a/django_tasks/backends/celery/backend.py b/django_tasks/backends/celery/backend.py new file mode 100644 index 0000000..075de82 --- /dev/null +++ b/django_tasks/backends/celery/backend.py @@ -0,0 +1,43 @@ +from typing import TypeVar + +from typing_extensions import ParamSpec + +from celery import shared_task +from celery.local import Proxy as CeleryTaskProxy +from django_tasks.backends.base import BaseTaskBackend +from django_tasks.task import Task, TaskResult + + +T = TypeVar("T") +P = ParamSpec("P") + + +class CeleryTask(Task): + + celery_task: CeleryTaskProxy + """Celery proxy to the task in the current celery app task registry.""" + + def __post_init__(self) -> None: + # TODO: allow passing extra celery specific parameters? + celery_task = shared_task()(self.func) + self.celery_task = celery_task + return super().__post_init__() + + +class CeleryBackend(BaseTaskBackend): + task_class = CeleryTask + supports_defer = True + + def enqueue( + self, task: Task[P, T], args: P.args, kwargs: P.kwargs + ) -> TaskResult[T]: + self.validate_task(task) + + apply_async_kwargs = { + "eta": task.run_after, + } + if task.queue_name: + apply_async_kwargs["queue"] = task.queue_name + if task.priority: + apply_async_kwargs["priority"] = task.priority + task.celery_task.apply_async(args, kwargs=kwargs, **apply_async_kwargs) diff --git a/pyproject.toml b/pyproject.toml index e445b5a..1ac1202 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,7 @@ dev = [ "coverage", "django-stubs[compatible-mypy]", "dj-database-url", + "celery", ] mysql = [ "mysqlclient" From 05c2bd2cabd75b4b8028878510b7488d49076227 Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Mon, 1 Jul 2024 21:48:10 -0300 Subject: [PATCH 2/5] Fix lint issues, minor celery app update --- django_tasks/backends/celery/app.py | 3 ++- django_tasks/backends/celery/backend.py | 9 +++++++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/django_tasks/backends/celery/app.py b/django_tasks/backends/celery/app.py index 864796a..c219e81 100644 --- a/django_tasks/backends/celery/app.py +++ b/django_tasks/backends/celery/app.py @@ -4,7 +4,6 @@ from django_tasks.task import DEFAULT_QUEUE_NAME - # Set the default Django settings module for the 'celery' program. django_settings = os.environ.get('DJANGO_SETTINGS_MODULE') if django_settings is None: @@ -19,3 +18,5 @@ app.config_from_object('django.conf:settings', namespace='CELERY') app.conf.task_default_queue = DEFAULT_QUEUE_NAME + +app.autodiscover_tasks() diff --git a/django_tasks/backends/celery/backend.py b/django_tasks/backends/celery/backend.py index 075de82..8b99a62 100644 --- a/django_tasks/backends/celery/backend.py +++ b/django_tasks/backends/celery/backend.py @@ -1,12 +1,17 @@ from typing import TypeVar -from typing_extensions import ParamSpec - from celery import shared_task +from celery.app import default_app from celery.local import Proxy as CeleryTaskProxy +from typing_extensions import ParamSpec + from django_tasks.backends.base import BaseTaskBackend from django_tasks.task import Task, TaskResult +if not default_app: + from django_tasks.backends.celery.app import app as celery_app + celery_app.set_default() + T = TypeVar("T") P = ParamSpec("P") From eab648270ec952de09d9fdbed56e974c7f847e03 Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Mon, 8 Jul 2024 17:12:36 -0300 Subject: [PATCH 3/5] Lint issues, return task result --- django_tasks/backends/celery/app.py | 8 ++--- django_tasks/backends/celery/backend.py | 40 +++++++++++++++++++------ 2 files changed, 35 insertions(+), 13 deletions(-) diff --git a/django_tasks/backends/celery/app.py b/django_tasks/backends/celery/app.py index c219e81..e069c60 100644 --- a/django_tasks/backends/celery/app.py +++ b/django_tasks/backends/celery/app.py @@ -5,17 +5,17 @@ from django_tasks.task import DEFAULT_QUEUE_NAME # Set the default Django settings module for the 'celery' program. -django_settings = os.environ.get('DJANGO_SETTINGS_MODULE') +django_settings = os.environ.get("DJANGO_SETTINGS_MODULE") if django_settings is None: - raise ValueError('DJANGO_SETTINGS_MODULE environment variable is not set') + raise ValueError("DJANGO_SETTINGS_MODULE environment variable is not set") -app = Celery('django_tasks') +app = Celery("django_tasks") # Using a string here means the worker doesn't have to serialize # the configuration object to child processes. # - namespace='CELERY' means all celery-related configuration keys # should have a `CELERY_` prefix. -app.config_from_object('django.conf:settings', namespace='CELERY') +app.config_from_object("django.conf:settings", namespace="CELERY") app.conf.task_default_queue = DEFAULT_QUEUE_NAME diff --git a/django_tasks/backends/celery/backend.py b/django_tasks/backends/celery/backend.py index 8b99a62..c99e519 100644 --- a/django_tasks/backends/celery/backend.py +++ b/django_tasks/backends/celery/backend.py @@ -1,15 +1,20 @@ +from dataclasses import dataclass from typing import TypeVar from celery import shared_task from celery.app import default_app from celery.local import Proxy as CeleryTaskProxy +from django.utils import timezone from typing_extensions import ParamSpec from django_tasks.backends.base import BaseTaskBackend -from django_tasks.task import Task, TaskResult +from django_tasks.task import ResultStatus, TaskResult +from django_tasks.task import Task as BaseTask +from django_tasks.utils import json_normalize if not default_app: from django_tasks.backends.celery.app import app as celery_app + celery_app.set_default() @@ -17,32 +22,49 @@ P = ParamSpec("P") -class CeleryTask(Task): - - celery_task: CeleryTaskProxy +@dataclass +class Task(BaseTask[P, T]): + celery_task: CeleryTaskProxy = None """Celery proxy to the task in the current celery app task registry.""" def __post_init__(self) -> None: - # TODO: allow passing extra celery specific parameters? celery_task = shared_task()(self.func) self.celery_task = celery_task return super().__post_init__() class CeleryBackend(BaseTaskBackend): - task_class = CeleryTask + task_class = Task supports_defer = True def enqueue( - self, task: Task[P, T], args: P.args, kwargs: P.kwargs + self, + task: Task[P, T], # type: ignore[override] + args: P.args, + kwargs: P.kwargs, ) -> TaskResult[T]: self.validate_task(task) - apply_async_kwargs = { + apply_async_kwargs: P.kwargs = { "eta": task.run_after, } if task.queue_name: apply_async_kwargs["queue"] = task.queue_name if task.priority: apply_async_kwargs["priority"] = task.priority - task.celery_task.apply_async(args, kwargs=kwargs, **apply_async_kwargs) + + # TODO: a Celery result backend is required to get additional information + async_result = task.celery_task.apply_async( + args, kwargs=kwargs, **apply_async_kwargs + ) + task_result = TaskResult[T]( + task=task, + id=async_result.id, + status=ResultStatus.NEW, + enqueued_at=timezone.now(), + finished_at=None, + args=json_normalize(args), + kwargs=json_normalize(kwargs), + backend=self.alias, + ) + return task_result From d43fcd395303fffe3e9a64ed514a4674f6e0fa4f Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Mon, 22 Jul 2024 16:47:20 -0300 Subject: [PATCH 4/5] Update priority, add tests --- django_tasks/backends/celery/backend.py | 43 +++++++- tests/settings.py | 1 + tests/tests/test_celery_backend.py | 131 ++++++++++++++++++++++++ 3 files changed, 171 insertions(+), 4 deletions(-) create mode 100644 tests/tests/test_celery_backend.py diff --git a/django_tasks/backends/celery/backend.py b/django_tasks/backends/celery/backend.py index c99e519..5272e39 100644 --- a/django_tasks/backends/celery/backend.py +++ b/django_tasks/backends/celery/backend.py @@ -1,14 +1,16 @@ from dataclasses import dataclass -from typing import TypeVar +from typing import Any, Iterable, TypeVar from celery import shared_task from celery.app import default_app from celery.local import Proxy as CeleryTaskProxy +from django.apps import apps +from django.core.checks import ERROR, CheckMessage from django.utils import timezone from typing_extensions import ParamSpec from django_tasks.backends.base import BaseTaskBackend -from django_tasks.task import ResultStatus, TaskResult +from django_tasks.task import MAX_PRIORITY, MIN_PRIORITY, ResultStatus, TaskResult from django_tasks.task import Task as BaseTask from django_tasks.utils import json_normalize @@ -22,6 +24,26 @@ P = ParamSpec("P") +CELERY_MIN_PRIORITY = 0 +CELERY_MAX_PRIORITY = 9 + + +def _map_priority(value: int) -> int: + # linear scale value to the range 0 to 9 + scaled_value = (value + abs(MIN_PRIORITY)) / ( + (MAX_PRIORITY - MIN_PRIORITY) / (CELERY_MAX_PRIORITY - CELERY_MIN_PRIORITY) + ) + mapped_value = int(scaled_value) + + # ensure the mapped value is within the range 0 to 9 + if mapped_value < CELERY_MIN_PRIORITY: + mapped_value = CELERY_MIN_PRIORITY + elif mapped_value > CELERY_MAX_PRIORITY: + mapped_value = CELERY_MAX_PRIORITY + + return mapped_value + + @dataclass class Task(BaseTask[P, T]): celery_task: CeleryTaskProxy = None @@ -50,8 +72,10 @@ def enqueue( } if task.queue_name: apply_async_kwargs["queue"] = task.queue_name - if task.priority: - apply_async_kwargs["priority"] = task.priority + if task.priority is not None: + # map priority to the range 0 to 9 + priority = _map_priority(task.priority) + apply_async_kwargs["priority"] = priority # TODO: a Celery result backend is required to get additional information async_result = task.celery_task.apply_async( @@ -62,9 +86,20 @@ def enqueue( id=async_result.id, status=ResultStatus.NEW, enqueued_at=timezone.now(), + started_at=None, finished_at=None, args=json_normalize(args), kwargs=json_normalize(kwargs), backend=self.alias, ) return task_result + + def check(self, **kwargs: Any) -> Iterable[CheckMessage]: + backend_name = self.__class__.__name__ + + if not apps.is_installed("django_tasks.backends.celery"): + yield CheckMessage( + ERROR, + f"{backend_name} configured as django_tasks backend, but celery app not installed", + "Insert 'django_tasks.backends.celery' in INSTALLED_APPS", + ) diff --git a/tests/settings.py b/tests/settings.py index 2a1167b..56e133b 100644 --- a/tests/settings.py +++ b/tests/settings.py @@ -18,6 +18,7 @@ "django.contrib.sessions", "django.contrib.staticfiles", "django_tasks", + "django_tasks.backends.celery", "django_tasks.backends.database", "tests", ] diff --git a/tests/tests/test_celery_backend.py b/tests/tests/test_celery_backend.py new file mode 100644 index 0000000..769fbfb --- /dev/null +++ b/tests/tests/test_celery_backend.py @@ -0,0 +1,131 @@ +from datetime import timedelta +from unittest.mock import patch + +from celery import Celery +from celery.result import AsyncResult +from django.test import TestCase, override_settings +from django.utils import timezone + +from django_tasks import ResultStatus, default_task_backend, task, tasks +from django_tasks.backends.celery import CeleryBackend +from django_tasks.backends.celery.backend import _map_priority +from django_tasks.task import DEFAULT_PRIORITY, DEFAULT_QUEUE_NAME + + +def noop_task(*args: tuple, **kwargs: dict) -> None: + return None + + +@override_settings( + TASKS={ + "default": { + "BACKEND": "django_tasks.backends.celery.CeleryBackend", + "QUEUES": [DEFAULT_QUEUE_NAME, "queue-1"], + } + } +) +class CeleryBackendTestCase(TestCase): + def setUp(self) -> None: + # register task during setup so it is registered as a Celery task + self.task = task()(noop_task) + + def test_using_correct_backend(self) -> None: + self.assertEqual(default_task_backend, tasks["default"]) + self.assertIsInstance(tasks["default"], CeleryBackend) + + def test_check(self) -> None: + errors = list(default_task_backend.check()) + + self.assertEqual(len(errors), 0) + + @override_settings(INSTALLED_APPS=[]) + def test_celery_backend_app_missing(self) -> None: + errors = list(default_task_backend.check()) + + self.assertEqual(len(errors), 1) + self.assertIn("django_tasks.backends.celery", errors[0].hint) + + def test_enqueue_task(self) -> None: + task = self.task + assert task.celery_task # type: ignore[attr-defined] + + # import here so that it is not set as default before registering the task + from django_tasks.backends.celery.app import app as celery_app + + self.assertEqual(task.celery_task.app, celery_app) # type: ignore[attr-defined] + with patch("celery.app.task.Task.apply_async") as mock_apply_async: + mock_apply_async.return_value = AsyncResult(id="123") + result = default_task_backend.enqueue(task, (1,), {"two": 3}) + + self.assertEqual(result.id, "123") + self.assertEqual(result.status, ResultStatus.NEW) + self.assertIsNone(result.started_at) + self.assertIsNone(result.finished_at) + with self.assertRaisesMessage(ValueError, "Task has not finished yet"): + result.result # noqa:B018 + self.assertEqual(result.task, task) + self.assertEqual(result.args, [1]) + self.assertEqual(result.kwargs, {"two": 3}) + expected_priority = _map_priority(DEFAULT_PRIORITY) + mock_apply_async.assert_called_once_with( + (1,), + kwargs={"two": 3}, + eta=None, + priority=expected_priority, + queue=DEFAULT_QUEUE_NAME, + ) + + def test_using_additional_params(self) -> None: + with patch("celery.app.task.Task.apply_async") as mock_apply_async: + mock_apply_async.return_value = AsyncResult(id="123") + run_after = timezone.now() + timedelta(hours=10) + result = self.task.using( + run_after=run_after, priority=75, queue_name="queue-1" + ).enqueue() + + self.assertEqual(result.id, "123") + self.assertEqual(result.status, ResultStatus.NEW) + mock_apply_async.assert_called_once_with( + (), kwargs={}, eta=run_after, priority=7, queue="queue-1" + ) + + def test_priority_mapping(self) -> None: + for priority, expected in [(-100, 0), (-50, 2), (0, 4), (75, 7), (100, 9)]: + with patch("celery.app.task.Task.apply_async") as mock_apply_async: + mock_apply_async.return_value = AsyncResult(id="123") + self.task.using(priority=priority).enqueue() + + mock_apply_async.assert_called_with( + (), kwargs={}, eta=None, priority=expected, queue=DEFAULT_QUEUE_NAME + ) + + +@override_settings( + TASKS={ + "default": { + "BACKEND": "django_tasks.backends.celery.CeleryBackend", + "QUEUES": [DEFAULT_QUEUE_NAME, "queue-1"], + } + } +) +class CeleryBackendCustomAppTestCase(TestCase): + def setUp(self) -> None: + self.celery_app = Celery("test_app") + self.task = task()(noop_task) + + def tearDown(self) -> None: + # restore the default Celery app + from django_tasks.backends.celery.app import app as celery_app + + celery_app.set_current() + return super().tearDown() + + def test_enqueue_task(self) -> None: + task = self.task + assert task.celery_task # type: ignore[attr-defined] + + from django_tasks.backends.celery.app import app as celery_app + + self.assertNotEqual(celery_app, self.celery_app) + # it should use the custom Celery app + self.assertEqual(task.celery_task.app, self.celery_app) # type: ignore[attr-defined] From f5ab8019537f682c654e17a51af74bb9a25c8bce Mon Sep 17 00:00:00 2001 From: Matias Bordese Date: Mon, 2 Dec 2024 22:18:21 -0300 Subject: [PATCH 5/5] Support enqueue_on_commit option --- django_tasks/backends/celery/backend.py | 34 +++++-- tests/tests/test_celery_backend.py | 112 +++++++++++++++++++----- 2 files changed, 116 insertions(+), 30 deletions(-) diff --git a/django_tasks/backends/celery/backend.py b/django_tasks/backends/celery/backend.py index 5272e39..52d9061 100644 --- a/django_tasks/backends/celery/backend.py +++ b/django_tasks/backends/celery/backend.py @@ -1,4 +1,4 @@ -from dataclasses import dataclass +from functools import partial from typing import Any, Iterable, TypeVar from celery import shared_task @@ -6,13 +6,14 @@ from celery.local import Proxy as CeleryTaskProxy from django.apps import apps from django.core.checks import ERROR, CheckMessage +from django.db import transaction from django.utils import timezone +from kombu.utils.uuid import uuid from typing_extensions import ParamSpec from django_tasks.backends.base import BaseTaskBackend from django_tasks.task import MAX_PRIORITY, MIN_PRIORITY, ResultStatus, TaskResult from django_tasks.task import Task as BaseTask -from django_tasks.utils import json_normalize if not default_app: from django_tasks.backends.celery.app import app as celery_app @@ -44,7 +45,6 @@ def _map_priority(value: int) -> int: return mapped_value -@dataclass class Task(BaseTask[P, T]): celery_task: CeleryTaskProxy = None """Celery proxy to the task in the current celery app task registry.""" @@ -77,19 +77,35 @@ def enqueue( priority = _map_priority(task.priority) apply_async_kwargs["priority"] = priority + task_id = uuid() + apply_async_kwargs["task_id"] = task_id + + if self._get_enqueue_on_commit_for_task(task): + transaction.on_commit( + partial( + task.celery_task.apply_async, + args, + kwargs=kwargs, + **apply_async_kwargs, + ) + ) + else: + task.celery_task.apply_async(args, kwargs=kwargs, **apply_async_kwargs) + + # TODO: send task_enqueued signal + # TODO: link a task to trigger the task_finished signal? + # TODO: consider using DBTaskResult for results? + # TODO: a Celery result backend is required to get additional information - async_result = task.celery_task.apply_async( - args, kwargs=kwargs, **apply_async_kwargs - ) task_result = TaskResult[T]( task=task, - id=async_result.id, + id=task_id, status=ResultStatus.NEW, enqueued_at=timezone.now(), started_at=None, finished_at=None, - args=json_normalize(args), - kwargs=json_normalize(kwargs), + args=args, + kwargs=kwargs, backend=self.alias, ) return task_result diff --git a/tests/tests/test_celery_backend.py b/tests/tests/test_celery_backend.py index 769fbfb..cb93faa 100644 --- a/tests/tests/test_celery_backend.py +++ b/tests/tests/test_celery_backend.py @@ -3,7 +3,8 @@ from celery import Celery from celery.result import AsyncResult -from django.test import TestCase, override_settings +from django.db import transaction +from django.test import TestCase, TransactionTestCase, override_settings from django.utils import timezone from django_tasks import ResultStatus, default_task_backend, task, tasks @@ -16,6 +17,10 @@ def noop_task(*args: tuple, **kwargs: dict) -> None: return None +def enqueue_on_commit_task(*args: tuple, **kwargs: dict) -> None: + pass + + @override_settings( TASKS={ "default": { @@ -24,10 +29,13 @@ def noop_task(*args: tuple, **kwargs: dict) -> None: } } ) -class CeleryBackendTestCase(TestCase): +class CeleryBackendTestCase(TransactionTestCase): def setUp(self) -> None: # register task during setup so it is registered as a Celery task self.task = task()(noop_task) + self.enqueue_on_commit_task = task(enqueue_on_commit=True)( + enqueue_on_commit_task + ) def test_using_correct_backend(self) -> None: self.assertEqual(default_task_backend, tasks["default"]) @@ -43,7 +51,7 @@ def test_celery_backend_app_missing(self) -> None: errors = list(default_task_backend.check()) self.assertEqual(len(errors), 1) - self.assertIn("django_tasks.backends.celery", errors[0].hint) + self.assertIn("django_tasks.backends.celery", errors[0].hint) # type:ignore[arg-type] def test_enqueue_task(self) -> None: task = self.task @@ -53,52 +61,114 @@ def test_enqueue_task(self) -> None: from django_tasks.backends.celery.app import app as celery_app self.assertEqual(task.celery_task.app, celery_app) # type: ignore[attr-defined] - with patch("celery.app.task.Task.apply_async") as mock_apply_async: - mock_apply_async.return_value = AsyncResult(id="123") - result = default_task_backend.enqueue(task, (1,), {"two": 3}) + task_id = "123" + with patch("django_tasks.backends.celery.backend.uuid", return_value=task_id): + with patch("celery.app.task.Task.apply_async") as mock_apply_async: + mock_apply_async.return_value = AsyncResult(id=task_id) + result = default_task_backend.enqueue(task, (1,), {"two": 3}) - self.assertEqual(result.id, "123") + self.assertEqual(result.id, task_id) self.assertEqual(result.status, ResultStatus.NEW) self.assertIsNone(result.started_at) self.assertIsNone(result.finished_at) with self.assertRaisesMessage(ValueError, "Task has not finished yet"): - result.result # noqa:B018 + result.return_value # noqa:B018 self.assertEqual(result.task, task) - self.assertEqual(result.args, [1]) + self.assertEqual(result.args, (1,)) self.assertEqual(result.kwargs, {"two": 3}) expected_priority = _map_priority(DEFAULT_PRIORITY) mock_apply_async.assert_called_once_with( (1,), kwargs={"two": 3}, + task_id=task_id, eta=None, priority=expected_priority, queue=DEFAULT_QUEUE_NAME, ) def test_using_additional_params(self) -> None: - with patch("celery.app.task.Task.apply_async") as mock_apply_async: - mock_apply_async.return_value = AsyncResult(id="123") - run_after = timezone.now() + timedelta(hours=10) - result = self.task.using( - run_after=run_after, priority=75, queue_name="queue-1" - ).enqueue() + task_id = "123" + with patch("django_tasks.backends.celery.backend.uuid", return_value=task_id): + with patch("celery.app.task.Task.apply_async") as mock_apply_async: + mock_apply_async.return_value = AsyncResult(id=task_id) + run_after = timezone.now() + timedelta(hours=10) + result = self.task.using( + run_after=run_after, priority=75, queue_name="queue-1" + ).enqueue() - self.assertEqual(result.id, "123") + self.assertEqual(result.id, task_id) self.assertEqual(result.status, ResultStatus.NEW) mock_apply_async.assert_called_once_with( - (), kwargs={}, eta=run_after, priority=7, queue="queue-1" + [], kwargs={}, task_id=task_id, eta=run_after, priority=7, queue="queue-1" ) def test_priority_mapping(self) -> None: for priority, expected in [(-100, 0), (-50, 2), (0, 4), (75, 7), (100, 9)]: - with patch("celery.app.task.Task.apply_async") as mock_apply_async: - mock_apply_async.return_value = AsyncResult(id="123") - self.task.using(priority=priority).enqueue() + task_id = "123" + with patch( + "django_tasks.backends.celery.backend.uuid", return_value=task_id + ): + with patch("celery.app.task.Task.apply_async") as mock_apply_async: + mock_apply_async.return_value = AsyncResult(id=task_id) + self.task.using(priority=priority).enqueue() mock_apply_async.assert_called_with( - (), kwargs={}, eta=None, priority=expected, queue=DEFAULT_QUEUE_NAME + [], + kwargs={}, + task_id=task_id, + eta=None, + priority=expected, + queue=DEFAULT_QUEUE_NAME, ) + @override_settings( + TASKS={ + "default": { + "BACKEND": "django_tasks.backends.celery.CeleryBackend", + "ENQUEUE_ON_COMMIT": True, + } + } + ) + def test_wait_until_transaction_commit(self) -> None: + self.assertTrue(default_task_backend.enqueue_on_commit) + self.assertTrue(default_task_backend._get_enqueue_on_commit_for_task(self.task)) + + with patch("celery.app.task.Task.apply_async") as mock_apply_async: + mock_apply_async.return_value = AsyncResult(id="task_id") + with transaction.atomic(): + self.task.enqueue() + assert not mock_apply_async.called + + mock_apply_async.assert_called_once() + + @override_settings( + TASKS={ + "default": { + "BACKEND": "django_tasks.backends.celery.CeleryBackend", + } + } + ) + def test_wait_until_transaction_by_default(self) -> None: + self.assertTrue(default_task_backend.enqueue_on_commit) + self.assertTrue(default_task_backend._get_enqueue_on_commit_for_task(self.task)) + + @override_settings( + TASKS={ + "default": { + "BACKEND": "django_tasks.backends.celery.CeleryBackend", + "ENQUEUE_ON_COMMIT": False, + } + } + ) + def test_task_specific_enqueue_on_commit(self) -> None: + self.assertFalse(default_task_backend.enqueue_on_commit) + self.assertTrue(self.enqueue_on_commit_task.enqueue_on_commit) + self.assertTrue( + default_task_backend._get_enqueue_on_commit_for_task( + self.enqueue_on_commit_task + ) + ) + @override_settings( TASKS={