Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions django_tasks/backends/celery/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .backend import CeleryBackend

__all__ = ["CeleryBackend"]
22 changes: 22 additions & 0 deletions django_tasks/backends/celery/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os

from celery import Celery

from django_tasks.task import DEFAULT_QUEUE_NAME

# Set the default Django settings module for the 'celery' program.
django_settings = os.environ.get("DJANGO_SETTINGS_MODULE")
if django_settings is None:
raise ValueError("DJANGO_SETTINGS_MODULE environment variable is not set")

app = Celery("django_tasks")

# Using a string here means the worker doesn't have to serialize
# the configuration object to child processes.
# - namespace='CELERY' means all celery-related configuration keys
# should have a `CELERY_` prefix.
app.config_from_object("django.conf:settings", namespace="CELERY")

app.conf.task_default_queue = DEFAULT_QUEUE_NAME

app.autodiscover_tasks()
121 changes: 121 additions & 0 deletions django_tasks/backends/celery/backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
from functools import partial
from typing import Any, Iterable, TypeVar

from celery import shared_task
from celery.app import default_app
from celery.local import Proxy as CeleryTaskProxy
from django.apps import apps
from django.core.checks import ERROR, CheckMessage
from django.db import transaction
from django.utils import timezone
from kombu.utils.uuid import uuid
from typing_extensions import ParamSpec

from django_tasks.backends.base import BaseTaskBackend
from django_tasks.task import MAX_PRIORITY, MIN_PRIORITY, ResultStatus, TaskResult
from django_tasks.task import Task as BaseTask

if not default_app:
from django_tasks.backends.celery.app import app as celery_app

celery_app.set_default()


T = TypeVar("T")
P = ParamSpec("P")


CELERY_MIN_PRIORITY = 0
CELERY_MAX_PRIORITY = 9


def _map_priority(value: int) -> int:
# linear scale value to the range 0 to 9
scaled_value = (value + abs(MIN_PRIORITY)) / (
(MAX_PRIORITY - MIN_PRIORITY) / (CELERY_MAX_PRIORITY - CELERY_MIN_PRIORITY)
)
mapped_value = int(scaled_value)

# ensure the mapped value is within the range 0 to 9
if mapped_value < CELERY_MIN_PRIORITY:
mapped_value = CELERY_MIN_PRIORITY
elif mapped_value > CELERY_MAX_PRIORITY:
mapped_value = CELERY_MAX_PRIORITY

return mapped_value


class Task(BaseTask[P, T]):
celery_task: CeleryTaskProxy = None
"""Celery proxy to the task in the current celery app task registry."""

def __post_init__(self) -> None:
celery_task = shared_task()(self.func)
self.celery_task = celery_task
return super().__post_init__()


class CeleryBackend(BaseTaskBackend):
task_class = Task
supports_defer = True

def enqueue(
self,
task: Task[P, T], # type: ignore[override]
args: P.args,
kwargs: P.kwargs,
) -> TaskResult[T]:
self.validate_task(task)

apply_async_kwargs: P.kwargs = {
"eta": task.run_after,
}
if task.queue_name:
apply_async_kwargs["queue"] = task.queue_name
if task.priority is not None:
# map priority to the range 0 to 9
priority = _map_priority(task.priority)
apply_async_kwargs["priority"] = priority

task_id = uuid()
apply_async_kwargs["task_id"] = task_id

if self._get_enqueue_on_commit_for_task(task):
transaction.on_commit(
partial(
task.celery_task.apply_async,
args,
kwargs=kwargs,
**apply_async_kwargs,
)
)
else:
task.celery_task.apply_async(args, kwargs=kwargs, **apply_async_kwargs)

# TODO: send task_enqueued signal
# TODO: link a task to trigger the task_finished signal?
# TODO: consider using DBTaskResult for results?

# TODO: a Celery result backend is required to get additional information
task_result = TaskResult[T](
task=task,
id=task_id,
status=ResultStatus.NEW,
enqueued_at=timezone.now(),
started_at=None,
finished_at=None,
args=args,
kwargs=kwargs,
backend=self.alias,
)
return task_result

def check(self, **kwargs: Any) -> Iterable[CheckMessage]:
backend_name = self.__class__.__name__

if not apps.is_installed("django_tasks.backends.celery"):
yield CheckMessage(
ERROR,
f"{backend_name} configured as django_tasks backend, but celery app not installed",
"Insert 'django_tasks.backends.celery' in INSTALLED_APPS",
)
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ dev = [
"coverage",
"django-stubs[compatible-mypy]",
"dj-database-url",
"celery",
]
mysql = [
"mysqlclient"
Expand Down
1 change: 1 addition & 0 deletions tests/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"django.contrib.sessions",
"django.contrib.staticfiles",
"django_tasks",
"django_tasks.backends.celery",
"django_tasks.backends.database",
"tests",
]
Expand Down
201 changes: 201 additions & 0 deletions tests/tests/test_celery_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
from datetime import timedelta
from unittest.mock import patch

from celery import Celery
from celery.result import AsyncResult
from django.db import transaction
from django.test import TestCase, TransactionTestCase, override_settings
from django.utils import timezone

from django_tasks import ResultStatus, default_task_backend, task, tasks
from django_tasks.backends.celery import CeleryBackend
from django_tasks.backends.celery.backend import _map_priority
from django_tasks.task import DEFAULT_PRIORITY, DEFAULT_QUEUE_NAME


def noop_task(*args: tuple, **kwargs: dict) -> None:
return None


def enqueue_on_commit_task(*args: tuple, **kwargs: dict) -> None:
pass


@override_settings(
TASKS={
"default": {
"BACKEND": "django_tasks.backends.celery.CeleryBackend",
"QUEUES": [DEFAULT_QUEUE_NAME, "queue-1"],
}
}
)
class CeleryBackendTestCase(TransactionTestCase):
def setUp(self) -> None:
# register task during setup so it is registered as a Celery task
self.task = task()(noop_task)
self.enqueue_on_commit_task = task(enqueue_on_commit=True)(
enqueue_on_commit_task
)

def test_using_correct_backend(self) -> None:
self.assertEqual(default_task_backend, tasks["default"])
self.assertIsInstance(tasks["default"], CeleryBackend)

def test_check(self) -> None:
errors = list(default_task_backend.check())

self.assertEqual(len(errors), 0)

@override_settings(INSTALLED_APPS=[])
def test_celery_backend_app_missing(self) -> None:
errors = list(default_task_backend.check())

self.assertEqual(len(errors), 1)
self.assertIn("django_tasks.backends.celery", errors[0].hint) # type:ignore[arg-type]

def test_enqueue_task(self) -> None:
task = self.task
assert task.celery_task # type: ignore[attr-defined]

# import here so that it is not set as default before registering the task
from django_tasks.backends.celery.app import app as celery_app

self.assertEqual(task.celery_task.app, celery_app) # type: ignore[attr-defined]
task_id = "123"
with patch("django_tasks.backends.celery.backend.uuid", return_value=task_id):
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
mock_apply_async.return_value = AsyncResult(id=task_id)
result = default_task_backend.enqueue(task, (1,), {"two": 3})

self.assertEqual(result.id, task_id)
self.assertEqual(result.status, ResultStatus.NEW)
self.assertIsNone(result.started_at)
self.assertIsNone(result.finished_at)
with self.assertRaisesMessage(ValueError, "Task has not finished yet"):
result.return_value # noqa:B018
self.assertEqual(result.task, task)
self.assertEqual(result.args, (1,))
self.assertEqual(result.kwargs, {"two": 3})
expected_priority = _map_priority(DEFAULT_PRIORITY)
mock_apply_async.assert_called_once_with(
(1,),
kwargs={"two": 3},
task_id=task_id,
eta=None,
priority=expected_priority,
queue=DEFAULT_QUEUE_NAME,
)

def test_using_additional_params(self) -> None:
task_id = "123"
with patch("django_tasks.backends.celery.backend.uuid", return_value=task_id):
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
mock_apply_async.return_value = AsyncResult(id=task_id)
run_after = timezone.now() + timedelta(hours=10)
result = self.task.using(
run_after=run_after, priority=75, queue_name="queue-1"
).enqueue()

self.assertEqual(result.id, task_id)
self.assertEqual(result.status, ResultStatus.NEW)
mock_apply_async.assert_called_once_with(
[], kwargs={}, task_id=task_id, eta=run_after, priority=7, queue="queue-1"
)

def test_priority_mapping(self) -> None:
for priority, expected in [(-100, 0), (-50, 2), (0, 4), (75, 7), (100, 9)]:
task_id = "123"
with patch(
"django_tasks.backends.celery.backend.uuid", return_value=task_id
):
with patch("celery.app.task.Task.apply_async") as mock_apply_async:
mock_apply_async.return_value = AsyncResult(id=task_id)
self.task.using(priority=priority).enqueue()

mock_apply_async.assert_called_with(
[],
kwargs={},
task_id=task_id,
eta=None,
priority=expected,
queue=DEFAULT_QUEUE_NAME,
)

@override_settings(
TASKS={
"default": {
"BACKEND": "django_tasks.backends.celery.CeleryBackend",
"ENQUEUE_ON_COMMIT": True,
}
}
)
def test_wait_until_transaction_commit(self) -> None:
self.assertTrue(default_task_backend.enqueue_on_commit)
self.assertTrue(default_task_backend._get_enqueue_on_commit_for_task(self.task))

with patch("celery.app.task.Task.apply_async") as mock_apply_async:
mock_apply_async.return_value = AsyncResult(id="task_id")
with transaction.atomic():
self.task.enqueue()
assert not mock_apply_async.called

mock_apply_async.assert_called_once()

@override_settings(
TASKS={
"default": {
"BACKEND": "django_tasks.backends.celery.CeleryBackend",
}
}
)
def test_wait_until_transaction_by_default(self) -> None:
self.assertTrue(default_task_backend.enqueue_on_commit)
self.assertTrue(default_task_backend._get_enqueue_on_commit_for_task(self.task))

@override_settings(
TASKS={
"default": {
"BACKEND": "django_tasks.backends.celery.CeleryBackend",
"ENQUEUE_ON_COMMIT": False,
}
}
)
def test_task_specific_enqueue_on_commit(self) -> None:
self.assertFalse(default_task_backend.enqueue_on_commit)
self.assertTrue(self.enqueue_on_commit_task.enqueue_on_commit)
self.assertTrue(
default_task_backend._get_enqueue_on_commit_for_task(
self.enqueue_on_commit_task
)
)


@override_settings(
TASKS={
"default": {
"BACKEND": "django_tasks.backends.celery.CeleryBackend",
"QUEUES": [DEFAULT_QUEUE_NAME, "queue-1"],
}
}
)
class CeleryBackendCustomAppTestCase(TestCase):
def setUp(self) -> None:
self.celery_app = Celery("test_app")
self.task = task()(noop_task)

def tearDown(self) -> None:
# restore the default Celery app
from django_tasks.backends.celery.app import app as celery_app

celery_app.set_current()
return super().tearDown()

def test_enqueue_task(self) -> None:
task = self.task
assert task.celery_task # type: ignore[attr-defined]

from django_tasks.backends.celery.app import app as celery_app

self.assertNotEqual(celery_app, self.celery_app)
# it should use the custom Celery app
self.assertEqual(task.celery_task.app, self.celery_app) # type: ignore[attr-defined]