diff --git a/src/MultiThreadedCaches.jl b/src/MultiThreadedCaches.jl
index c3587b5..7979e2a 100644
--- a/src/MultiThreadedCaches.jl
+++ b/src/MultiThreadedCaches.jl
@@ -45,6 +45,7 @@ julia> get!(cache, 5) do
 """
 struct MultiThreadedCache{K,V}
     thread_caches::Vector{Dict{K,V}}
+    thread_locks::Vector{ReentrantLock}
 
     base_cache::Dict{K,V} # Guarded by: base_cache_lock
     base_cache_lock::ReentrantLock
@@ -58,11 +59,12 @@ struct MultiThreadedCache{K,V}
 
     function MultiThreadedCache{K,V}(base_cache::Dict) where {K,V}
         thread_caches = Dict{K,V}[]
+        thread_locks = ReentrantLock[]
 
         base_cache_lock = ReentrantLock()
         base_cache_futures = Dict{K,Channel{V}}()
 
-        return new(thread_caches, base_cache, base_cache_lock, base_cache_futures)
+        return new(thread_caches, thread_locks, base_cache, base_cache_lock, base_cache_futures)
     end
 end
 
@@ -81,6 +83,7 @@ function init_cache!(cache::MultiThreadedCache{K,V}) where {K,V}
     # requested, so that the object will be allocated on the thread that will consume it.
     # (This follows the guidance from Julia Base.)
     resize!(cache.thread_caches, Threads.nthreads())
+    resize!(cache.thread_locks, Threads.nthreads())
     return cache
 end
 
@@ -91,9 +94,8 @@ end
 # Based upon the thread-safe Global RNG implementation in the Random stdlib:
 # https://github.com/JuliaLang/julia/blob/e4fcdf5b04fd9751ce48b0afc700330475b42443/stdlib/Random/src/RNGs.jl#L369-L385
 # Get or lazily construct the per-thread cache when first requested.
-function _thread_cache(mtcache::MultiThreadedCache)
+function _thread_cache(mtcache::MultiThreadedCache, tid)
     length(mtcache.thread_caches) >= Threads.nthreads() || _thread_cache_length_assert()
-    tid = Threads.threadid()
     if @inbounds isassigned(mtcache.thread_caches, tid)
         @inbounds cache = mtcache.thread_caches[tid]
     else
@@ -105,73 +107,124 @@ function _thread_cache(mtcache::MultiThreadedCache)
     return cache
 end
 @noinline _thread_cache_length_assert() = @assert false "** Must call `init_cache!(cache)` in your Module's __init__()! - length(cache.thread_caches) < Threads.nthreads() "
+function _thread_lock(cache::MultiThreadedCache, tid)
+    length(cache.thread_locks) >= Threads.nthreads() || _thread_lock_length_assert()
+    if @inbounds isassigned(cache.thread_locks, tid)
+        @inbounds lock = cache.thread_locks[tid]
+    else
+        lock = eltype(cache.thread_locks)()
+        @inbounds cache.thread_locks[tid] = lock
+    end
+    return lock
+end
+@noinline _thread_lock_length_assert() = @assert false "** Must call `init_cache!(cache)` in your Module's __init__()! - length(cache.thread_locks) < Threads.nthreads() "
 
 const CACHE_MISS = :__MultiThreadedCaches_key_not_found__
 
 function Base.get!(func::Base.Callable, cache::MultiThreadedCache{K,V}, key) where {K,V}
     # If the thread-local cache has the value, we can return immediately.
+    # We store tcache in a local variable, so that even if the Task migrates threads, we
+    # are still operating on the same initial cache object.
+    tid = Threads.threadid()
+    tcache = _thread_cache(cache, tid)
+    tlock = _thread_lock(cache, tid)
+    # We have to lock during access to the thread-local dict, because it's possible that
+    # the Task may migrate to another thread by the end, and we really might be mutating
+    # the dict in parallel. But most of the time this lock should have zero contention,
+    # since it's only held during get() and setindex!().
+    Base.@lock tlock begin
+        thread_local_cached_value_or_miss = get(tcache, key, CACHE_MISS)
+        if thread_local_cached_value_or_miss !== CACHE_MISS
+            return thread_local_cached_value_or_miss::V
+        end
+    end
     # If not, we need to check the base cache.
-    return get!(_thread_cache(cache), key) do
-        # Even though we're using Thread-local caches, we still need to lock during
-        # construction to prevent multiple tasks redundantly constructing the same object,
-        # and potential thread safety violations due to Tasks migrating threads.
-        # NOTE that we only grab the lock if the key doesn't exist, so the mutex contention
-        # is not on the critical path for most accessses. :)
-        is_first_task = false
-        local future # used only if the base_cache doesn't have the key
-        # We lock the mutex, but for only a short, *constant time* duration, to grab the
-        # future for this key, or to create the future if it doesn't exist.
-        @lock cache.base_cache_lock begin
-            value_or_miss = get(cache.base_cache, key, CACHE_MISS)
-            if value_or_miss !== CACHE_MISS
-                return value_or_miss::V
-            end
-            future = get!(cache.base_cache_futures, key) do
-                is_first_task = true
-                Channel{V}(1)
+    # When we're done, call this function to store the result.
+    @inline function _store_result!(v::V; test_haskey::Bool)
+        # Set the value into the thread-local cache for the supplied key.
+        # Note that the initial get() above and this setindex!() are two separately
+        # locked critical sections, so the dict may have been mutated by another task
+        # in between. The test_haskey flag handles that window: tasks that fetched the
+        # value from the future only store it if no other task already has, rather
+        # than us holding tlock across the entire computation.
+        Base.@lock tlock begin
+            if test_haskey
+                if !haskey(tcache, key)
+                    setindex!(tcache, v, key)
+                end
+            else
+                setindex!(tcache, v, key)
             end
         end
-        if is_first_task
-            v = try
-                func()
-            catch e
-                # In the case of an exception, we abort the current computation of this
-                # key/value pair, and throw the exception onto the future, so that all
-                # pending tasks will see the exeption as well.
-                #
-                # NOTE: we could also cache the exception and throw it from now on, but this
-                # would make interactive development difficult, since once you fix the
-                # error, you'd have to clear out your cache. So instead, we just rethrow the
-                # exception and don't cache anything, so that you can fix the exception and
-                # continue on. (This means that throwing exceptions remains expensive.)
-
-                # close(::Channel, ::Exception) requires an Exception object, so if the user
-                # threw a non-Exception, we convert it to one, here.
-                e isa Exception || (e = ErrorException("Non-exception object thrown during get!(): $e"))
-                close(future, e)
-                # As below, the future isn't needed after this returns (see below).
-                delete!(cache.base_cache_futures, key)
-                rethrow(e)
-            end
-            # Finally, lock again for a *constant time* to insert the computed value into
-            # the shared cache, so that we can free the Channel and future gets can read
-            # from the shared base_cache.
+    end
+
+    # Even though we're using Thread-local caches, we still need to lock during
+    # construction to prevent multiple tasks redundantly constructing the same object,
+    # and potential thread safety violations due to Tasks migrating threads.
+    # NOTE that we only grab the lock if the key doesn't exist, so the mutex contention
+    # is not on the critical path for most accesses. :)
+    is_first_task = false
+    local future # used only if the base_cache doesn't have the key
+    # We lock the mutex, but for only a short, *constant time* duration, to grab the
+    # future for this key, or to create the future if it doesn't exist.
+    @lock cache.base_cache_lock begin
+        value_or_miss = get(cache.base_cache, key, CACHE_MISS)
+        if value_or_miss !== CACHE_MISS
+            return value_or_miss::V
+        end
+        future = get!(cache.base_cache_futures, key) do
+            is_first_task = true
+            Channel{V}(1)
+        end
+    end
+    if is_first_task
+        v = try
+            func()
+        catch e
+            # In the case of an exception, we abort the current computation of this
+            # key/value pair, and throw the exception onto the future, so that all
+            # pending tasks will see the exception as well.
+            #
+            # NOTE: we could also cache the exception and throw it from now on, but this
+            # would make interactive development difficult, since once you fix the
+            # error, you'd have to clear out your cache. So instead, we just rethrow the
+            # exception and don't cache anything, so that you can fix the exception and
+            # continue on. (This means that throwing exceptions remains expensive.)
+
+            # close(::Channel, ::Exception) requires an Exception object, so if the user
+            # threw a non-Exception, we convert it to one, here.
+            e isa Exception || (e = ErrorException("Non-exception object thrown during get!(): $e"))
+            close(future, e)
+            # As in the success path below, the future isn't needed after this returns.
             @lock cache.base_cache_lock begin
-                cache.base_cache[key] = v
-                # We no longer need the Future, since all future requests will see the key
-                # in the base_cache. (Other Tasks may still hold a reference, but it will
-                # be GC'd once they have all completed.)
                 delete!(cache.base_cache_futures, key)
             end
-            # Return v to any other Tasks that were blocking on this key.
-            put!(future, v)
-            return v
-        else
-            # Block on the future until the first task that asked for this key finishes
-            # computing a value for it.
-            return fetch(future)
+            rethrow(e)
         end
+        # Finally, lock again for a *constant time* to insert the computed value into
+        # the shared cache, so that we can free the Channel and future gets can read
+        # from the shared base_cache.
+        @lock cache.base_cache_lock begin
+            cache.base_cache[key] = v
+            # We no longer need the Future, since all future requests will see the key
+            # in the base_cache. (Other Tasks may still hold a reference, but it will
+            # be GC'd once they have all completed.)
+            delete!(cache.base_cache_futures, key)
+        end
+        # Store the result in this thread-local dictionary.
+        _store_result!(v, test_haskey=false)
+        # Return v to any other Tasks that were blocking on this key.
+        put!(future, v)
+        return v
+    else
+        # Block on the future until the first task that asked for this key finishes
+        # computing a value for it.
+        v = fetch(future)
+        # Store the result in our original thread-local cache (if another Task hasn't
+        # set it already).
+        _store_result!(v, test_haskey=true)
+        return v
     end
 end
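
For reviewers trying the change out, here is a minimal usage sketch (not part of the diff). It follows the pattern the assert message above demands, calling init_cache! from the module's __init__; the module name MyLookups and the function cached_lookup are hypothetical:

```julia
module MyLookups

using MultiThreadedCaches

# One shared cache for the whole module; the per-thread vectors inside it are
# sized later, once the final thread count is known.
const CACHE = MultiThreadedCache{Int,String}(Dict{Int,String}())

# Required by the package: sizes thread_caches/thread_locks to Threads.nthreads().
__init__() = init_cache!(CACHE)

function cached_lookup(key::Int)
    return get!(CACHE, key) do
        # Runs only in the first task to miss on `key`; concurrent callers block
        # on the shared Channel future until the value is put!() there.
        string("value-", key)
    end
end

end # module
```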
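The new tlock exists because, since Julia 1.7, a task spawned with Threads.@spawn may migrate to a different OS thread at any yield point, so a dict indexed by a previously captured threadid() can end up mutated from two threads at once. A standalone sketch of that hazard and the per-slot-lock fix, independent of the package (SLOTS, LOCKS, and migration_demo are hypothetical names):

```julia
# Per-thread slots plus per-slot locks, mirroring thread_caches/thread_locks.
const SLOTS = [Dict{Int,Int}() for _ in 1:Threads.nthreads()]
const LOCKS = [ReentrantLock() for _ in 1:Threads.nthreads()]

function migration_demo()
    @sync for i in 1:1000
        Threads.@spawn begin
            tid = Threads.threadid()  # captured once, like `tid` in get! above
            sleep(0.001)              # yield point: the task may resume on another thread
            # Another task whose *current* threadid equals `tid` may be writing
            # SLOTS[tid] at the same moment, so the lock is required.
            Base.@lock LOCKS[tid] begin
                SLOTS[tid][i] = i
            end
        end
    end
    return sum(length, SLOTS)  # == 1000 once every task has stored its entry
end
```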
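The base_cache_futures protocol in the last hunk uses a buffered Channel{V}(1) as a write-once future: the first task put!s the computed value (fetch does not remove it, so every waiter sees it), and on failure close(future, e) propagates the exception to all waiters instead. A standalone sketch of just that protocol (demo_future is a hypothetical name):

```julia
function demo_future(fail::Bool=false)
    future = Channel{Int}(1)
    # Several waiters block on fetch() until the producer resolves the future.
    waiters = [Threads.@spawn fetch(future) for _ in 1:4]
    try
        fail && error("boom")   # stand-in for func() throwing
        put!(future, 42)        # success: every fetch() returns 42
    catch e
        e isa Exception || (e = ErrorException(string(e)))
        close(future, e)        # failure: every pending fetch() rethrows `e`
        rethrow(e)
    end
    return fetch.(waiters)      # [42, 42, 42, 42] on the success path
end
```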