diff --git a/redis_cache/backends/base.py b/redis_cache/backends/base.py index d473a635..6adb3a2d 100644 --- a/redis_cache/backends/base.py +++ b/redis_cache/backends/base.py @@ -1,4 +1,6 @@ from functools import wraps +import time +import uuid from django.core.cache.backends.base import ( BaseCache, DEFAULT_TIMEOUT, InvalidCacheBackendError, @@ -435,34 +437,42 @@ def get_or_set( lock_key = "__lock__" + key fresh_key = "__fresh__" + key - is_fresh = self._get(client, fresh_key) value = self._get(client, key) - - if is_fresh: + is_fresh = self._get(client, fresh_key) + if value is not None and is_fresh: return value - timeout = self.get_timeout(timeout) - lock = self.lock(lock_key, timeout=lock_timeout) - - acquired = lock.acquire(blocking=False) + fresh_timeout = self.get_timeout(timeout) + key_timeout = None if stale_cache_timeout is None else fresh_timeout + stale_cache_timeout - if acquired: - try: - value = func() - except Exception: - raise + token = uuid.uuid1().hex + lock = self.lock(lock_key, timeout=lock_timeout) + acquired = lock.acquire(blocking=False, token=token) + + while True: + if acquired: + try: + value = func() + except Exception: + raise + else: + pipeline = client.pipeline() + pipeline.set(key, self.prep_value(value), key_timeout) + pipeline.set(fresh_key, 1, fresh_timeout) + pipeline.execute() + return value + finally: + lock.release() + elif value is None: + time.sleep(lock.sleep) + value = self._get(client, key) + if value is None: + # If there is no value present yet, try to acquire the + # lock again (maybe the other thread died for some reason + # and we should try to compute the value instead). + acquired = lock.acquire(blocking=False, token=token) else: - key_timeout = ( - None if stale_cache_timeout is None else timeout + stale_cache_timeout - ) - pipeline = client.pipeline() - pipeline.set(key, self.prep_value(value), key_timeout) - pipeline.set(fresh_key, 1, timeout) - pipeline.execute() - finally: - lock.release() - - return value + return value def _reinsert_keys(self, client): keys = list(client.scan_iter(match='*')) diff --git a/tests/testapp/tests/base_tests.py b/tests/testapp/tests/base_tests.py index 719ec779..f9ba0304 100644 --- a/tests/testapp/tests/base_tests.py +++ b/tests/testapp/tests/base_tests.py @@ -530,45 +530,101 @@ def thread_worker(thread_id, return_value, timeout, lock_timeout, stale_cache_ti lock_timeout, stale_cache_timeout ) - results[thread_id] = value - - thread_0 = threading.Thread(target=thread_worker, args=(0, 'a', 1, None, 1)) - thread_1 = threading.Thread(target=thread_worker, args=(1, 'b', 1, None, 1)) - thread_2 = threading.Thread(target=thread_worker, args=(2, 'c', 1, None, 1)) - thread_3 = threading.Thread(target=thread_worker, args=(3, 'd', 1, None, 1)) - thread_4 = threading.Thread(target=thread_worker, args=(4, 'e', 1, None, 1)) + results[thread_id] = (value, expensive_function.num_calls) # First thread should complete and return its value + thread_0 = threading.Thread(target=thread_worker, args=(0, 'a', 1, None, 1)) thread_0.start() # t = 0, valid from t = .5 - 1.5, stale from t = 1.5 - 2.5 - - # Second thread will start while the first thread is still working and return None. + # Second thread will start while the first thread is still working, wait for the first thread + # to complete, and return the value computed by the first thread. time.sleep(.25) # t = .25 + thread_1 = threading.Thread(target=thread_worker, args=(1, 'b', 1, None, 1)) thread_1.start() # Third thread will start after the first value is computed, but before it expires. # its value. time.sleep(.5) # t = .75 + thread_2 = threading.Thread(target=thread_worker, args=(2, 'c', 1, None, 1)) thread_2.start() # Fourth thread will start after the first value has expired and will re-compute its value. # valid from t = 2.25 - 3.25, stale from t = 3.75 - 4.75. time.sleep(1) # t = 1.75 + thread_3 = threading.Thread(target=thread_worker, args=(3, 'd', 1, None, 1)) thread_3.start() # Fifth thread will start after the fourth thread has started to compute its value, but # before the first thread's stale cache has expired. time.sleep(.25) # t = 2 + thread_4 = threading.Thread(target=thread_worker, args=(4, 'e', 1, None, 1)) thread_4.start() + # Sixth thread will start after the fourth thread has finished to compute its value. + time.sleep(.5) # t = 2.5 + thread_5 = threading.Thread(target=thread_worker, args=(5, 'f', 1, None, 1)) + thread_5.start() thread_0.join() thread_1.join() thread_2.join() thread_3.join() thread_4.join() + thread_5.join() + + self.assertEqual(results, { + 0: ('a', 1), + 1: ('a', 1), + 2: ('a', 1), + 3: ('d', 2), + 4: ('a', 1), + 5: ('d', 2), + }) + + def test_get_or_set_take_over_after_lock_holder_fails(self): + results = {} + + def get_or_set(func): + return self.cache.get_or_set('key', func, timeout=1, lock_timeout=None, stale_cache_timeout=1) + + def worker_0(): + def func(): + worker_0.entered.set() + time.sleep(.5) + worker_0.completed.set() + raise ValueError() + worker_1.started.wait() + try: + results[0] = get_or_set(func) + except ValueError: + results[0] = ValueError + worker_0.entered = threading.Event() + worker_0.completed = threading.Event() + + def worker_1(): + def func(): + worker_1.entered.set() + time.sleep(.5) + return 'b' + worker_1.started.set() + worker_0.entered.wait() + results[1] = get_or_set(func) + worker_1.started = threading.Event() + worker_1.entered = threading.Event() + + def worker_2(): + worker_1.entered.wait() + results[2] = get_or_set(lambda: 'c') + + threads = [ + threading.Thread(target=worker_0), + threading.Thread(target=worker_1), + threading.Thread(target=worker_2), + ] + for thread in threads: + thread.start() + for thread in threads: + thread.join() self.assertEqual(results, { - 0: 'a', - 1: None, - 2: 'a', - 3: 'd', - 4: 'a' + 0: ValueError, + 1: 'b', + 2: 'b', }) def assertMaxConnection(self, cache, max_num):