From a5159bef18c223242f0e917382ab57216420bd2c Mon Sep 17 00:00:00 2001
From: Florian Zimmermann
Date: Thu, 28 Nov 2019 17:10:25 +0100
Subject: [PATCH 1/2] attempt to solve the concurrency issues with get_or_set

This should fix issue #149. When there is no stale value to return, one thread
tries to compute func()'s result while all other threads wait for it to
complete.
---
 redis_cache/backends/base.py      | 56 ++++++++++++---------
 tests/testapp/tests/base_tests.py | 84 +++++++++++++++++++++++++------
 2 files changed, 103 insertions(+), 37 deletions(-)

diff --git a/redis_cache/backends/base.py b/redis_cache/backends/base.py
index d473a635..6adb3a2d 100644
--- a/redis_cache/backends/base.py
+++ b/redis_cache/backends/base.py
@@ -1,4 +1,6 @@
 from functools import wraps
+import time
+import uuid
 
 from django.core.cache.backends.base import (
     BaseCache, DEFAULT_TIMEOUT, InvalidCacheBackendError,
@@ -435,34 +437,42 @@ def get_or_set(
         lock_key = "__lock__" + key
         fresh_key = "__fresh__" + key
 
-        is_fresh = self._get(client, fresh_key)
         value = self._get(client, key)
-
-        if is_fresh:
+        is_fresh = self._get(client, fresh_key)
+        if value is not None and is_fresh:
             return value
 
-        timeout = self.get_timeout(timeout)
-        lock = self.lock(lock_key, timeout=lock_timeout)
-
-        acquired = lock.acquire(blocking=False)
+        fresh_timeout = self.get_timeout(timeout)
+        key_timeout = None if stale_cache_timeout is None else fresh_timeout + stale_cache_timeout
 
-        if acquired:
-            try:
-                value = func()
-            except Exception:
-                raise
+        token = uuid.uuid1().hex
+        lock = self.lock(lock_key, timeout=lock_timeout)
+        acquired = lock.acquire(blocking=False, token=token)
+
+        while True:
+            if acquired:
+                try:
+                    value = func()
+                except Exception:
+                    raise
+                else:
+                    pipeline = client.pipeline()
+                    pipeline.set(key, self.prep_value(value), key_timeout)
+                    pipeline.set(fresh_key, 1, fresh_timeout)
+                    pipeline.execute()
+                    return value
+                finally:
+                    lock.release()
+            elif value is None:
+                time.sleep(lock.sleep)
+                value = self._get(client, key)
+                if value is None:
+                    # If there is no value present yet, try to acquire the
+                    # lock again (maybe the other thread died for some reason
+                    # and we should try to compute the value instead).
+                    acquired = lock.acquire(blocking=False, token=token)
             else:
-                key_timeout = (
-                    None if stale_cache_timeout is None else timeout + stale_cache_timeout
-                )
-                pipeline = client.pipeline()
-                pipeline.set(key, self.prep_value(value), key_timeout)
-                pipeline.set(fresh_key, 1, timeout)
-                pipeline.execute()
-            finally:
-                lock.release()
-
-        return value
+                return value
 
     def _reinsert_keys(self, client):
         keys = list(client.scan_iter(match='*'))
diff --git a/tests/testapp/tests/base_tests.py b/tests/testapp/tests/base_tests.py
index 719ec779..09b635fa 100644
--- a/tests/testapp/tests/base_tests.py
+++ b/tests/testapp/tests/base_tests.py
@@ -530,45 +530,101 @@ def thread_worker(thread_id, return_value, timeout, lock_timeout, stale_cache_ti
                     lock_timeout,
                     stale_cache_timeout
                 )
-            results[thread_id] = value
-
-        thread_0 = threading.Thread(target=thread_worker, args=(0, 'a', 1, None, 1))
-        thread_1 = threading.Thread(target=thread_worker, args=(1, 'b', 1, None, 1))
-        thread_2 = threading.Thread(target=thread_worker, args=(2, 'c', 1, None, 1))
-        thread_3 = threading.Thread(target=thread_worker, args=(3, 'd', 1, None, 1))
-        thread_4 = threading.Thread(target=thread_worker, args=(4, 'e', 1, None, 1))
+            results[thread_id] = (value, expensive_function.num_calls)
 
         # First thread should complete and return its value
+        thread_0 = threading.Thread(target=thread_worker, args=(0, 'a', 1, None, 1))
         thread_0.start()  # t = 0, valid from t = .5 - 1.5, stale from t = 1.5 - 2.5
-
-        # Second thread will start while the first thread is still working and return None.
+        # Second thread will start while the first thread is still working, wait for the first thread
+        # to complete, and return the value computed by the first thread.
         time.sleep(.25)  # t = .25
+        thread_1 = threading.Thread(target=thread_worker, args=(1, 'b', 1, None, 1))
         thread_1.start()
         # Third thread will start after the first value is computed, but before it expires.
        # its value.
         time.sleep(.5)  # t = .75
+        thread_2 = threading.Thread(target=thread_worker, args=(2, 'c', 1, None, 1))
         thread_2.start()
         # Fourth thread will start after the first value has expired and will re-compute its value.
         # valid from t = 2.25 - 3.25, stale from t = 3.75 - 4.75.
         time.sleep(1)  # t = 1.75
+        thread_3 = threading.Thread(target=thread_worker, args=(3, 'd', 1, None, 1))
         thread_3.start()
         # Fifth thread will start after the fourth thread has started to compute its value, but
         # before the first thread's stale cache has expired.
         time.sleep(.25)  # t = 2
+        thread_4 = threading.Thread(target=thread_worker, args=(4, 'e', 1, None, 1))
         thread_4.start()
+        # Sixth thread will start after the fourth thread has finished computing its value.
+        time.sleep(.5)  # t = 2.5
+        thread_5 = threading.Thread(target=thread_worker, args=(5, 'f', 1, None, 1))
+        thread_5.start()
 
         thread_0.join()
         thread_1.join()
         thread_2.join()
         thread_3.join()
         thread_4.join()
+        thread_5.join()
+
+        self.assertEqual(results, {
+            0: ('a', 1),
+            1: ('a', 1),
+            2: ('a', 1),
+            3: ('d', 2),
+            4: ('a', 1),
+            5: ('d', 2),
+        })
+
+    def test_get_or_set_takeover_after_lock_holder_fails(self):
+
+        def expensive_function(result):
+            time.sleep(.5)
+            expensive_function.num_calls += 1
+            if result is ValueError:
+                raise result()
+            return result
+
+        expensive_function.num_calls = 0
+        self.assertEqual(expensive_function.num_calls, 0)
+        results = {}
+
+        def thread_worker(thread_id, return_value, timeout, lock_timeout, stale_cache_timeout):
+            try:
+                value = self.cache.get_or_set(
+                    'key',
+                    lambda: expensive_function(return_value),
+                    timeout,
+                    lock_timeout,
+                    stale_cache_timeout
+                )
+            except ValueError:
+                results[thread_id] = (ValueError, expensive_function.num_calls)
+            else:
+                results[thread_id] = (value, expensive_function.num_calls)
+
+        # First thread should fail
+        thread_0 = threading.Thread(target=thread_worker, args=(0, ValueError, 1, None, 1))
+        thread_0.start()  # t = 0
+        # Second thread will start while the first thread is still working, wait for the first thread to die,
+        # then take over computation, and return its value.
+        time.sleep(.25)  # t = .25
+        thread_1 = threading.Thread(target=thread_worker, args=(1, 'b', 1, None, 1))
+        thread_1.start()
+        # Third thread will start after the first thread has died but while the second thread is still working,
+        # wait for the second thread to complete its work, and return that value.
+        time.sleep(.5)  # t = .75
+        thread_2 = threading.Thread(target=thread_worker, args=(2, 'c', 1, None, 1))
+        thread_2.start()
+
+        thread_0.join()
+        thread_1.join()
+        thread_2.join()
 
         self.assertEqual(results, {
-            0: 'a',
-            1: None,
-            2: 'a',
-            3: 'd',
-            4: 'a'
+            0: (ValueError, 1),
+            1: ('b', 2),
+            2: ('b', 2),
         })
 
     def assertMaxConnection(self, cache, max_num):
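A note on the change in PATCH 1/2 above: the retry loop added to get_or_set() boils down to "serve a fresh hit immediately; otherwise let exactly one caller compute under the lock while the others either return the stale value or poll for the first result, taking over the lock if its holder died". Below is a minimal, self-contained sketch of that pattern using plain Python primitives in place of the Redis lock and the "__fresh__" marker key; the FakeCache name and its attributes are illustrative only and are not part of the patch.

import threading
import time


class FakeCache:
    """Single-process stand-in for the Redis-backed cache in PATCH 1/2."""

    def __init__(self):
        self._lock = threading.Lock()   # stands in for self.lock(lock_key, ...)
        self._value = None              # stands in for the cached key
        self._fresh_until = 0.0         # stands in for the "__fresh__" marker
        self.sleep = 0.1                # stands in for lock.sleep

    def get_or_set(self, func, timeout):
        value = self._value
        if value is not None and time.monotonic() < self._fresh_until:
            return value                        # fresh hit: nobody waits

        acquired = self._lock.acquire(blocking=False)
        while True:
            if acquired:
                try:
                    value = func()              # only one caller computes
                    self._value = value
                    self._fresh_until = time.monotonic() + timeout
                    return value
                finally:
                    self._lock.release()
            elif value is None:
                # Nothing cached yet: wait a bit for the lock holder, then
                # either pick up its result or try to take over the lock
                # (it may have died without producing a value).
                time.sleep(self.sleep)
                value = self._value
                if value is None:
                    acquired = self._lock.acquire(blocking=False)
            else:
                return value                    # stale hit: serve it, don't block

The real implementation additionally writes the value and its freshness marker in one Redis pipeline and keeps the stale copy alive for stale_cache_timeout extra seconds (key_timeout = fresh_timeout + stale_cache_timeout).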
From 98902a2ffdd6ec971cfec315d32c5ae638ef7a7c Mon Sep 17 00:00:00 2001
From: Florian Zimmermann
Date: Thu, 28 Nov 2019 18:57:58 +0100
Subject: [PATCH 2/2] try to force a consistent order for the worker threads in
 the take-over test case

I *hope* that the build failure at
https://travis-ci.org/sebleier/django-redis-cache/jobs/618267821 resulted from
thread_2 being too slow, so that the key had already expired. Otherwise I can't
explain how *both* thread_1 and thread_2 got the values from their own value
func(). (I would at least have expected that thread_2 came before thread_1, so
that they would both get 'c' instead of the expected 'b'...)

Again, I really hope that this was just a bug in the test, not in the actual
implementation...
---
 tests/testapp/tests/base_tests.py | 86 +++++++++++++++----------------
 1 file changed, 43 insertions(+), 43 deletions(-)

diff --git a/tests/testapp/tests/base_tests.py b/tests/testapp/tests/base_tests.py
index 09b635fa..f9ba0304 100644
--- a/tests/testapp/tests/base_tests.py
+++ b/tests/testapp/tests/base_tests.py
@@ -576,55 +576,55 @@ def thread_worker(thread_id, return_value, timeout, lock_timeout, stale_cache_ti
             5: ('d', 2),
         })
 
-    def test_get_or_set_takeover_after_lock_holder_fails(self):
-
-        def expensive_function(result):
-            time.sleep(.5)
-            expensive_function.num_calls += 1
-            if result is ValueError:
-                raise result()
-            return result
-
-        expensive_function.num_calls = 0
-        self.assertEqual(expensive_function.num_calls, 0)
+    def test_get_or_set_take_over_after_lock_holder_fails(self):
         results = {}
 
-        def thread_worker(thread_id, return_value, timeout, lock_timeout, stale_cache_timeout):
+        def get_or_set(func):
+            return self.cache.get_or_set('key', func, timeout=1, lock_timeout=None, stale_cache_timeout=1)
+
+        def worker_0():
+            def func():
+                worker_0.entered.set()
+                time.sleep(.5)
+                worker_0.completed.set()
+                raise ValueError()
+            worker_1.started.wait()
             try:
-                value = self.cache.get_or_set(
-                    'key',
-                    lambda: expensive_function(return_value),
-                    timeout,
-                    lock_timeout,
-                    stale_cache_timeout
-                )
+                results[0] = get_or_set(func)
             except ValueError:
-                results[thread_id] = (ValueError, expensive_function.num_calls)
-            else:
-                results[thread_id] = (value, expensive_function.num_calls)
-
-        # First thread should fail
-        thread_0 = threading.Thread(target=thread_worker, args=(0, ValueError, 1, None, 1))
-        thread_0.start()  # t = 0
-        # Second thread will start while the first thread is still working, wait for the first thread to die,
-        # then take over computation, and return its value.
-        time.sleep(.25)  # t = .25
-        thread_1 = threading.Thread(target=thread_worker, args=(1, 'b', 1, None, 1))
-        thread_1.start()
-        # Third thread will start after the first thread has died but while the second thread is still working,
-        # wait for the second thread to complete its work, and return that value.
-        time.sleep(.5)  # t = .75
-        thread_2 = threading.Thread(target=thread_worker, args=(2, 'c', 1, None, 1))
-        thread_2.start()
-
-        thread_0.join()
-        thread_1.join()
-        thread_2.join()
+                results[0] = ValueError
+        worker_0.entered = threading.Event()
+        worker_0.completed = threading.Event()
+
+        def worker_1():
+            def func():
+                worker_1.entered.set()
+                time.sleep(.5)
+                return 'b'
+            worker_1.started.set()
+            worker_0.entered.wait()
+            results[1] = get_or_set(func)
+        worker_1.started = threading.Event()
+        worker_1.entered = threading.Event()
+
+        def worker_2():
+            worker_1.entered.wait()
+            results[2] = get_or_set(lambda: 'c')
+
+        threads = [
+            threading.Thread(target=worker_0),
+            threading.Thread(target=worker_1),
+            threading.Thread(target=worker_2),
+        ]
+        for thread in threads:
+            thread.start()
+        for thread in threads:
+            thread.join()
 
         self.assertEqual(results, {
-            0: (ValueError, 1),
-            1: ('b', 2),
-            2: ('b', 2),
+            0: ValueError,
+            1: 'b',
+            2: 'b',
        })
 
     def assertMaxConnection(self, cache, max_num):
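For completeness, a minimal caller-side sketch of the get_or_set() signature these tests exercise. It assumes a configured django-redis-cache backend reached through Django's cache framework; the cache alias, the key name and the expensive_report() helper are made-up examples, not part of the patches.

import time

from django.core.cache import caches

cache = caches['default']   # assumed to point at a django-redis-cache backend


def expensive_report():
    time.sleep(.5)           # stands in for the slow computation
    return 'report'


value = cache.get_or_set(
    'report-key',
    expensive_report,
    timeout=1,               # value counts as fresh for 1 second
    lock_timeout=None,       # how long the computing caller's lock may live
    stale_cache_timeout=1,   # stale copy is kept for 1 more second
)

With these numbers a caller sees the timeline the first test above relies on: the value is fresh for one second after it is computed, is then served stale for one more second while a single caller recomputes it under the lock, and only callers that find neither a fresh nor a stale copy ever wait.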