diff --git a/django_tasks/backends/celery/__init__.py b/django_tasks/backends/celery/__init__.py
new file mode 100644
index 0000000..9fbaf65
--- /dev/null
+++ b/django_tasks/backends/celery/__init__.py
@@ -0,0 +1,3 @@
+from .backend import CeleryBackend
+
+__all__ = ["CeleryBackend"]
diff --git a/django_tasks/backends/celery/app.py b/django_tasks/backends/celery/app.py
new file mode 100644
index 0000000..e069c60
--- /dev/null
+++ b/django_tasks/backends/celery/app.py
@@ -0,0 +1,22 @@
+import os
+
+from celery import Celery
+
+from django_tasks.task import DEFAULT_QUEUE_NAME
+
+# Fail fast at import time when DJANGO_SETTINGS_MODULE is not set.
+django_settings = os.environ.get("DJANGO_SETTINGS_MODULE")
+if django_settings is None:
+    raise ValueError("DJANGO_SETTINGS_MODULE environment variable is not set")
+
+app = Celery("django_tasks")
+
+# Using a string here means the worker doesn't have to serialize
+# the configuration object to child processes.
+# - namespace='CELERY' means all celery-related configuration keys
+#   should have a `CELERY_` prefix.
+app.config_from_object("django.conf:settings", namespace="CELERY")
+
+app.conf.task_default_queue = DEFAULT_QUEUE_NAME
+
+app.autodiscover_tasks()
diff --git a/django_tasks/backends/celery/backend.py b/django_tasks/backends/celery/backend.py
new file mode 100644
index 0000000..52d9061
--- /dev/null
+++ b/django_tasks/backends/celery/backend.py
@@ -0,0 +1,133 @@
+from functools import partial
+from typing import Any, Dict, Iterable, TypeVar
+
+from celery import shared_task
+from celery.app import default_app
+from celery.local import Proxy as CeleryTaskProxy
+from django.apps import apps
+from django.core.checks import ERROR, CheckMessage
+from django.db import transaction
+from django.utils import timezone
+from kombu.utils.uuid import uuid
+from typing_extensions import ParamSpec
+
+from django_tasks.backends.base import BaseTaskBackend
+from django_tasks.task import MAX_PRIORITY, MIN_PRIORITY, ResultStatus, TaskResult
+from django_tasks.task import Task as BaseTask
+
+# Fall back to django_tasks' own Celery app when `default_app` is falsy.
+if not default_app:
+    from django_tasks.backends.celery.app import app as celery_app
+
+    celery_app.set_default()
+
+
+T = TypeVar("T")
+P = ParamSpec("P")
+
+
+# Bounds of the priority value handed to Celery by `_map_priority`.
+CELERY_MIN_PRIORITY = 0
+CELERY_MAX_PRIORITY = 9
+
+
+def _map_priority(value: int) -> int:
+    """Scale a django_tasks priority linearly onto the Celery priority range.
+
+    `MIN_PRIORITY` maps to `CELERY_MIN_PRIORITY` and `MAX_PRIORITY` maps to
+    `CELERY_MAX_PRIORITY`; values in between are rounded down and values
+    outside the django_tasks range are clamped.
+    """
+    celery_span = CELERY_MAX_PRIORITY - CELERY_MIN_PRIORITY
+    # Integer arithmetic keeps exact boundaries exact (no float rounding),
+    # and subtracting MIN_PRIORITY is correct whatever its sign.
+    scaled = (value - MIN_PRIORITY) * celery_span // (MAX_PRIORITY - MIN_PRIORITY)
+    mapped_value = CELERY_MIN_PRIORITY + scaled
+
+    return max(CELERY_MIN_PRIORITY, min(CELERY_MAX_PRIORITY, mapped_value))
+
+
+class Task(BaseTask[P, T]):
+    """django_tasks task that also wraps its function as a Celery task."""
+
+    celery_task: CeleryTaskProxy = None
+    """Celery proxy to the task in the current celery app task registry."""
+
+    def __post_init__(self) -> None:
+        # Wrap `func` with Celery's `shared_task`; the resulting proxy is what
+        # `CeleryBackend.enqueue` dispatches through.
+        self.celery_task = shared_task()(self.func)
+        return super().__post_init__()
+
+
+class CeleryBackend(BaseTaskBackend):
+    """Task backend that enqueues tasks through Celery's `apply_async`."""
+
+    task_class = Task
+    supports_defer = True
+
+    def enqueue(
+        self,
+        task: Task[P, T],  # type: ignore[override]
+        args: P.args,
+        kwargs: P.kwargs,
+    ) -> TaskResult[T]:
+        """Send `task` to Celery and return a `TaskResult` in the NEW state.
+
+        When enqueue-on-commit applies to the task, the Celery call is
+        registered with `transaction.on_commit` instead of running right away.
+        """
+        self.validate_task(task)
+
+        apply_async_kwargs: Dict[str, Any] = {
+            "eta": task.run_after,
+        }
+        if task.queue_name:
+            apply_async_kwargs["queue"] = task.queue_name
+        if task.priority is not None:
+            # map priority to the range 0 to 9
+            apply_async_kwargs["priority"] = _map_priority(task.priority)
+
+        # Generate the id up front so it is known even when the actual
+        # Celery call is deferred until the transaction commits.
+        task_id = uuid()
+        apply_async_kwargs["task_id"] = task_id
+
+        send_task = partial(
+            task.celery_task.apply_async,
+            args,
+            kwargs=kwargs,
+            **apply_async_kwargs,
+        )
+        if self._get_enqueue_on_commit_for_task(task):
+            transaction.on_commit(send_task)
+        else:
+            send_task()
+
+        # TODO: send task_enqueued signal
+        # TODO: link a task to trigger the task_finished signal?
+        # TODO: consider using DBTaskResult for results?
+
+        # TODO: a Celery result backend is required to get additional information
+        return TaskResult[T](
+            task=task,
+            id=task_id,
+            status=ResultStatus.NEW,
+            enqueued_at=timezone.now(),
+            started_at=None,
+            finished_at=None,
+            args=args,
+            kwargs=kwargs,
+            backend=self.alias,
+        )
+
+    def check(self, **kwargs: Any) -> Iterable[CheckMessage]:
+        """Yield an error when the Celery backend app is not in INSTALLED_APPS."""
+        backend_name = type(self).__name__
+
+        if not apps.is_installed("django_tasks.backends.celery"):
+            yield CheckMessage(
+                ERROR,
+                f"{backend_name} configured as django_tasks backend, but celery app not installed",
+                "Insert 'django_tasks.backends.celery' in INSTALLED_APPS",
+            )
diff --git a/pyproject.toml b/pyproject.toml
index e445b5a..1ac1202 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -54,6 +54,7 @@ dev = [
     "coverage",
     "django-stubs[compatible-mypy]",
     "dj-database-url",
+    "celery",
 ]
 mysql = [
     "mysqlclient"
diff --git a/tests/settings.py b/tests/settings.py
index 2a1167b..56e133b 100644
--- a/tests/settings.py
+++ b/tests/settings.py
@@ -18,6 +18,7 @@
     "django.contrib.sessions",
     "django.contrib.staticfiles",
     "django_tasks",
+    "django_tasks.backends.celery",
     "django_tasks.backends.database",
     "tests",
 ]
diff --git a/tests/tests/test_celery_backend.py b/tests/tests/test_celery_backend.py
new file mode 100644
index 0000000..cb93faa
--- /dev/null
+++ b/tests/tests/test_celery_backend.py
@@ -0,0 +1,211 @@
+from datetime import timedelta
+from unittest.mock import patch
+
+from celery import Celery
+from celery.result import AsyncResult
+from django.db import transaction
+from django.test import TestCase, TransactionTestCase, override_settings
+from django.utils import timezone
+
+from django_tasks import ResultStatus, default_task_backend, task, tasks
+from django_tasks.backends.celery import CeleryBackend
+from django_tasks.backends.celery.backend import _map_priority
+from django_tasks.task import DEFAULT_PRIORITY, DEFAULT_QUEUE_NAME
+
+
+def noop_task(*args: tuple, **kwargs: dict) -> None:
+    return None
+
+
+def enqueue_on_commit_task(*args: tuple, **kwargs: dict) -> None:
+    pass
+
+
+@override_settings(
+    TASKS={
+        "default": {
+            "BACKEND": "django_tasks.backends.celery.CeleryBackend",
+            "QUEUES": [DEFAULT_QUEUE_NAME, "queue-1"],
+        }
+    }
+)
+class CeleryBackendTestCase(TransactionTestCase):
+    """Tests for CeleryBackend; enqueue tests patch Celery's `apply_async`."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        # register task during setup so it is registered as a Celery task
+        self.task = task()(noop_task)
+        self.enqueue_on_commit_task = task(enqueue_on_commit=True)(
+            enqueue_on_commit_task
+        )
+
+    def test_using_correct_backend(self) -> None:
+        self.assertEqual(default_task_backend, tasks["default"])
+        self.assertIsInstance(tasks["default"], CeleryBackend)
+
+    def test_check(self) -> None:
+        errors = list(default_task_backend.check())
+
+        self.assertEqual(len(errors), 0)
+
+    @override_settings(INSTALLED_APPS=[])
+    def test_celery_backend_app_missing(self) -> None:
+        errors = list(default_task_backend.check())
+
+        self.assertEqual(len(errors), 1)
+        self.assertIn("django_tasks.backends.celery", errors[0].hint)  # type:ignore[arg-type]
+
+    def test_enqueue_task(self) -> None:
+        task = self.task
+        # unittest assertion: a bare `assert` is stripped under `python -O`
+        self.assertTrue(task.celery_task)  # type: ignore[attr-defined]
+
+        # import here so that it is not set as default before registering the task
+        from django_tasks.backends.celery.app import app as celery_app
+
+        self.assertEqual(task.celery_task.app, celery_app)  # type: ignore[attr-defined]
+        task_id = "123"
+        with patch("django_tasks.backends.celery.backend.uuid", return_value=task_id):
+            with patch("celery.app.task.Task.apply_async") as mock_apply_async:
+                mock_apply_async.return_value = AsyncResult(id=task_id)
+                result = default_task_backend.enqueue(task, (1,), {"two": 3})
+
+        self.assertEqual(result.id, task_id)
+        self.assertEqual(result.status, ResultStatus.NEW)
+        self.assertIsNone(result.started_at)
+        self.assertIsNone(result.finished_at)
+        with self.assertRaisesMessage(ValueError, "Task has not finished yet"):
+            result.return_value  # noqa:B018
+        self.assertEqual(result.task, task)
+        self.assertEqual(result.args, (1,))
+        self.assertEqual(result.kwargs, {"two": 3})
+        expected_priority = _map_priority(DEFAULT_PRIORITY)
+        mock_apply_async.assert_called_once_with(
+            (1,),
+            kwargs={"two": 3},
+            task_id=task_id,
+            eta=None,
+            priority=expected_priority,
+            queue=DEFAULT_QUEUE_NAME,
+        )
+
+    def test_using_additional_params(self) -> None:
+        task_id = "123"
+        with patch("django_tasks.backends.celery.backend.uuid", return_value=task_id):
+            with patch("celery.app.task.Task.apply_async") as mock_apply_async:
+                mock_apply_async.return_value = AsyncResult(id=task_id)
+                run_after = timezone.now() + timedelta(hours=10)
+                result = self.task.using(
+                    run_after=run_after, priority=75, queue_name="queue-1"
+                ).enqueue()
+
+        self.assertEqual(result.id, task_id)
+        self.assertEqual(result.status, ResultStatus.NEW)
+        mock_apply_async.assert_called_once_with(
+            [], kwargs={}, task_id=task_id, eta=run_after, priority=7, queue="queue-1"
+        )
+
+    def test_priority_mapping(self) -> None:
+        task_id = "123"
+        for priority, expected in [(-100, 0), (-50, 2), (0, 4), (75, 7), (100, 9)]:
+            # subTest reports every failing priority instead of stopping at the first
+            with self.subTest(priority=priority):
+                with patch(
+                    "django_tasks.backends.celery.backend.uuid", return_value=task_id
+                ):
+                    with patch("celery.app.task.Task.apply_async") as mock_apply_async:
+                        mock_apply_async.return_value = AsyncResult(id=task_id)
+                        self.task.using(priority=priority).enqueue()
+
+                mock_apply_async.assert_called_once_with(
+                    [],
+                    kwargs={},
+                    task_id=task_id,
+                    eta=None,
+                    priority=expected,
+                    queue=DEFAULT_QUEUE_NAME,
+                )
+
+    @override_settings(
+        TASKS={
+            "default": {
+                "BACKEND": "django_tasks.backends.celery.CeleryBackend",
+                "ENQUEUE_ON_COMMIT": True,
+            }
+        }
+    )
+    def test_wait_until_transaction_commit(self) -> None:
+        self.assertTrue(default_task_backend.enqueue_on_commit)
+        self.assertTrue(default_task_backend._get_enqueue_on_commit_for_task(self.task))
+
+        with patch("celery.app.task.Task.apply_async") as mock_apply_async:
+            mock_apply_async.return_value = AsyncResult(id="task_id")
+            with transaction.atomic():
+                self.task.enqueue()
+                # nothing may reach Celery before the transaction commits
+                mock_apply_async.assert_not_called()
+
+        mock_apply_async.assert_called_once()
+
+    @override_settings(
+        TASKS={
+            "default": {
+                "BACKEND": "django_tasks.backends.celery.CeleryBackend",
+            }
+        }
+    )
+    def test_wait_until_transaction_by_default(self) -> None:
+        self.assertTrue(default_task_backend.enqueue_on_commit)
+        self.assertTrue(default_task_backend._get_enqueue_on_commit_for_task(self.task))
+
+    @override_settings(
+        TASKS={
+            "default": {
+                "BACKEND": "django_tasks.backends.celery.CeleryBackend",
+                "ENQUEUE_ON_COMMIT": False,
+            }
+        }
+    )
+    def test_task_specific_enqueue_on_commit(self) -> None:
+        self.assertFalse(default_task_backend.enqueue_on_commit)
+        self.assertTrue(self.enqueue_on_commit_task.enqueue_on_commit)
+        self.assertTrue(
+            default_task_backend._get_enqueue_on_commit_for_task(
+                self.enqueue_on_commit_task
+            )
+        )
+
+
+@override_settings(
+    TASKS={
+        "default": {
+            "BACKEND": "django_tasks.backends.celery.CeleryBackend",
+            "QUEUES": [DEFAULT_QUEUE_NAME, "queue-1"],
+        }
+    }
+)
+class CeleryBackendCustomAppTestCase(TestCase):
+    """Check that tasks bind to a Celery app created by the test itself."""
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.celery_app = Celery("test_app")
+        self.task = task()(noop_task)
+
+    def tearDown(self) -> None:
+        # restore the default Celery app
+        from django_tasks.backends.celery.app import app as celery_app
+
+        celery_app.set_current()
+        return super().tearDown()
+
+    def test_enqueue_task(self) -> None:
+        task = self.task
+        self.assertTrue(task.celery_task)  # type: ignore[attr-defined]
+
+        from django_tasks.backends.celery.app import app as celery_app
+
+        self.assertNotEqual(celery_app, self.celery_app)
+        # it should use the custom Celery app
+        self.assertEqual(task.celery_task.app, self.celery_app)  # type: ignore[attr-defined]