33
44from celery import Celery
55from celery .result import AsyncResult
6- from django .test import TestCase , override_settings
6+ from django .db import transaction
7+ from django .test import TestCase , TransactionTestCase , override_settings
78from django .utils import timezone
89
910from django_tasks import ResultStatus , default_task_backend , task , tasks
@@ -16,6 +17,10 @@ def noop_task(*args: tuple, **kwargs: dict) -> None:
1617 return None
1718
1819
20+ def enqueue_on_commit_task (* args : tuple , ** kwargs : dict ) -> None :
21+ pass
22+
23+
1924@override_settings (
2025 TASKS = {
2126 "default" : {
@@ -24,10 +29,13 @@ def noop_task(*args: tuple, **kwargs: dict) -> None:
2429 }
2530 }
2631)
27- class CeleryBackendTestCase (TestCase ):
32+ class CeleryBackendTestCase (TransactionTestCase ):
2833 def setUp (self ) -> None :
2934 # register task during setup so it is registered as a Celery task
3035 self .task = task ()(noop_task )
36+ self .enqueue_on_commit_task = task (enqueue_on_commit = True )(
37+ enqueue_on_commit_task
38+ )
3139
3240 def test_using_correct_backend (self ) -> None :
3341 self .assertEqual (default_task_backend , tasks ["default" ])
@@ -43,7 +51,7 @@ def test_celery_backend_app_missing(self) -> None:
4351 errors = list (default_task_backend .check ())
4452
4553 self .assertEqual (len (errors ), 1 )
46- self .assertIn ("django_tasks.backends.celery" , errors [0 ].hint )
54+ self .assertIn ("django_tasks.backends.celery" , errors [0 ].hint ) # type:ignore[arg-type]
4755
4856 def test_enqueue_task (self ) -> None :
4957 task = self .task
@@ -53,52 +61,114 @@ def test_enqueue_task(self) -> None:
5361 from django_tasks .backends .celery .app import app as celery_app
5462
5563 self .assertEqual (task .celery_task .app , celery_app ) # type: ignore[attr-defined]
56- with patch ("celery.app.task.Task.apply_async" ) as mock_apply_async :
57- mock_apply_async .return_value = AsyncResult (id = "123" )
58- result = default_task_backend .enqueue (task , (1 ,), {"two" : 3 })
64+ task_id = "123"
65+ with patch ("django_tasks.backends.celery.backend.uuid" , return_value = task_id ):
66+ with patch ("celery.app.task.Task.apply_async" ) as mock_apply_async :
67+ mock_apply_async .return_value = AsyncResult (id = task_id )
68+ result = default_task_backend .enqueue (task , (1 ,), {"two" : 3 })
5969
60- self .assertEqual (result .id , "123" )
70+ self .assertEqual (result .id , task_id )
6171 self .assertEqual (result .status , ResultStatus .NEW )
6272 self .assertIsNone (result .started_at )
6373 self .assertIsNone (result .finished_at )
6474 with self .assertRaisesMessage (ValueError , "Task has not finished yet" ):
65- result .result # noqa:B018
75+ result .return_value # noqa:B018
6676 self .assertEqual (result .task , task )
67- self .assertEqual (result .args , [ 1 ] )
77+ self .assertEqual (result .args , ( 1 ,) )
6878 self .assertEqual (result .kwargs , {"two" : 3 })
6979 expected_priority = _map_priority (DEFAULT_PRIORITY )
7080 mock_apply_async .assert_called_once_with (
7181 (1 ,),
7282 kwargs = {"two" : 3 },
83+ task_id = task_id ,
7384 eta = None ,
7485 priority = expected_priority ,
7586 queue = DEFAULT_QUEUE_NAME ,
7687 )
7788
7889 def test_using_additional_params (self ) -> None :
79- with patch ("celery.app.task.Task.apply_async" ) as mock_apply_async :
80- mock_apply_async .return_value = AsyncResult (id = "123" )
81- run_after = timezone .now () + timedelta (hours = 10 )
82- result = self .task .using (
83- run_after = run_after , priority = 75 , queue_name = "queue-1"
84- ).enqueue ()
90+ task_id = "123"
91+ with patch ("django_tasks.backends.celery.backend.uuid" , return_value = task_id ):
92+ with patch ("celery.app.task.Task.apply_async" ) as mock_apply_async :
93+ mock_apply_async .return_value = AsyncResult (id = task_id )
94+ run_after = timezone .now () + timedelta (hours = 10 )
95+ result = self .task .using (
96+ run_after = run_after , priority = 75 , queue_name = "queue-1"
97+ ).enqueue ()
8598
86- self .assertEqual (result .id , "123" )
99+ self .assertEqual (result .id , task_id )
87100 self .assertEqual (result .status , ResultStatus .NEW )
88101 mock_apply_async .assert_called_once_with (
89- () , kwargs = {}, eta = run_after , priority = 7 , queue = "queue-1"
102+ [] , kwargs = {}, task_id = task_id , eta = run_after , priority = 7 , queue = "queue-1"
90103 )
91104
92105 def test_priority_mapping (self ) -> None :
93106 for priority , expected in [(- 100 , 0 ), (- 50 , 2 ), (0 , 4 ), (75 , 7 ), (100 , 9 )]:
94- with patch ("celery.app.task.Task.apply_async" ) as mock_apply_async :
95- mock_apply_async .return_value = AsyncResult (id = "123" )
96- self .task .using (priority = priority ).enqueue ()
107+ task_id = "123"
108+ with patch (
109+ "django_tasks.backends.celery.backend.uuid" , return_value = task_id
110+ ):
111+ with patch ("celery.app.task.Task.apply_async" ) as mock_apply_async :
112+ mock_apply_async .return_value = AsyncResult (id = task_id )
113+ self .task .using (priority = priority ).enqueue ()
97114
98115 mock_apply_async .assert_called_with (
99- (), kwargs = {}, eta = None , priority = expected , queue = DEFAULT_QUEUE_NAME
116+ [],
117+ kwargs = {},
118+ task_id = task_id ,
119+ eta = None ,
120+ priority = expected ,
121+ queue = DEFAULT_QUEUE_NAME ,
100122 )
101123
124+ @override_settings (
125+ TASKS = {
126+ "default" : {
127+ "BACKEND" : "django_tasks.backends.celery.CeleryBackend" ,
128+ "ENQUEUE_ON_COMMIT" : True ,
129+ }
130+ }
131+ )
132+ def test_wait_until_transaction_commit (self ) -> None :
133+ self .assertTrue (default_task_backend .enqueue_on_commit )
134+ self .assertTrue (default_task_backend ._get_enqueue_on_commit_for_task (self .task ))
135+
136+ with patch ("celery.app.task.Task.apply_async" ) as mock_apply_async :
137+ mock_apply_async .return_value = AsyncResult (id = "task_id" )
138+ with transaction .atomic ():
139+ self .task .enqueue ()
140+ assert not mock_apply_async .called
141+
142+ mock_apply_async .assert_called_once ()
143+
144+ @override_settings (
145+ TASKS = {
146+ "default" : {
147+ "BACKEND" : "django_tasks.backends.celery.CeleryBackend" ,
148+ }
149+ }
150+ )
151+ def test_wait_until_transaction_by_default (self ) -> None :
152+ self .assertTrue (default_task_backend .enqueue_on_commit )
153+ self .assertTrue (default_task_backend ._get_enqueue_on_commit_for_task (self .task ))
154+
155+ @override_settings (
156+ TASKS = {
157+ "default" : {
158+ "BACKEND" : "django_tasks.backends.celery.CeleryBackend" ,
159+ "ENQUEUE_ON_COMMIT" : False ,
160+ }
161+ }
162+ )
163+ def test_task_specific_enqueue_on_commit (self ) -> None :
164+ self .assertFalse (default_task_backend .enqueue_on_commit )
165+ self .assertTrue (self .enqueue_on_commit_task .enqueue_on_commit )
166+ self .assertTrue (
167+ default_task_backend ._get_enqueue_on_commit_for_task (
168+ self .enqueue_on_commit_task
169+ )
170+ )
171+
102172
103173@override_settings (
104174 TASKS = {
0 commit comments