167 changes: 110 additions & 57 deletions src/MultiThreadedCaches.jl
@@ -45,6 +45,7 @@ julia> get!(cache, 5) do
"""
struct MultiThreadedCache{K,V}
thread_caches::Vector{Dict{K,V}}
thread_locks::Vector{ReentrantLock}

base_cache::Dict{K,V} # Guarded by: base_cache_lock
base_cache_lock::ReentrantLock
@@ -58,11 +59,12 @@ struct MultiThreadedCache{K,V}

function MultiThreadedCache{K,V}(base_cache::Dict) where {K,V}
thread_caches = Dict{K,V}[]
thread_locks = ReentrantLock[]

base_cache_lock = ReentrantLock()
base_cache_futures = Dict{K,Channel{V}}()

return new(thread_caches, base_cache, base_cache_lock, base_cache_futures)
return new(thread_caches, thread_locks, base_cache, base_cache_lock, base_cache_futures)
end
end

@@ -81,6 +83,7 @@ function init_cache!(cache::MultiThreadedCache{K,V}) where {K,V}
# requested, so that the object will be allocated on the thread that will consume it.
# (This follows the guidance from Julia Base.)
resize!(cache.thread_caches, Threads.nthreads())
resize!(cache.thread_locks, Threads.nthreads())
return cache
end
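
For reference, a hedged sketch of how a consuming package would be expected to wire this up (the module name `MyPkg` and the Int/String type parameters are illustrative, and it assumes the package exports `MultiThreadedCache` and `init_cache!`; the constructor signature is the one shown above):

module MyPkg
using MultiThreadedCaches

# Constructed at top level; the per-thread vectors start empty.
const CACHE = MultiThreadedCache{Int,String}(Dict{Int,String}())

# Sized at load time, against the actual runtime thread count.
__init__() = init_cache!(CACHE)
end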

@@ -91,9 +94,8 @@ end
# Based upon the thread-safe Global RNG implementation in the Random stdlib:
# https://github.com/JuliaLang/julia/blob/e4fcdf5b04fd9751ce48b0afc700330475b42443/stdlib/Random/src/RNGs.jl#L369-L385
# Get or lazily construct the per-thread cache when first requested.
function _thread_cache(mtcache::MultiThreadedCache)
function _thread_cache(mtcache::MultiThreadedCache, tid)
length(mtcache.thread_caches) >= Threads.nthreads() || _thread_cache_length_assert()
tid = Threads.threadid()
if @inbounds isassigned(mtcache.thread_caches, tid)
@inbounds cache = mtcache.thread_caches[tid]
else
@@ -105,73 +107,124 @@ function _thread_cache(mtcache::MultiThreadedCache)
return cache
end
@noinline _thread_cache_length_assert() = @assert false "** Must call `init_cache!(cache)` in your Module's __init__()! - length(cache.thread_caches) < Threads.nthreads() "
function _thread_lock(cache::MultiThreadedCache, tid)
length(cache.thread_locks) >= Threads.nthreads() || _thread_cache_length_assert()
if @inbounds isassigned(cache.thread_locks, tid)
@inbounds lock = cache.thread_locks[tid]
else
lock = eltype(cache.thread_locks)()
@inbounds cache.thread_locks[tid] = lock
end
return lock
end
@noinline _thread_cache_length_assert() = @assert false "** Must call `init_cache!(cache)` in your Module's __init__()! - length(cache.thread_caches) < Threads.nthreads() "
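
A standalone sketch of the lazy per-slot pattern shared by `_thread_cache` and `_thread_lock` above (element type illustrative): `resize!` on a vector of mutable elements leaves the new slots `#undef`, so each slot is allocated on first use, by the thread that asks for it.

locks = ReentrantLock[]
resize!(locks, Threads.nthreads())  # slots are #undef; no locks allocated yet
function lock_for(tid)
    if isassigned(locks, tid)       # true only after slot tid is first filled
        return @inbounds locks[tid]
    else
        lk = ReentrantLock()        # allocated lazily, on the requesting thread
        @inbounds locks[tid] = lk
        return lk
    end
end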


const CACHE_MISS = :__MultiThreadedCaches_key_not_found__
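
This sentinel lets the lookups below distinguish "missing" from any real value with a single get() call, instead of a haskey()-then-getindex() double lookup. Separately, the get! method below coordinates tasks racing on the same key with a `Channel{V}(1)` used as a one-shot future; a minimal standalone sketch of that pattern (the value and the sleep are illustrative):

future = Channel{Int}(1)
Threads.@spawn begin
    sleep(0.1)        # stand-in for an expensive computation
    put!(future, 42)  # exactly one producer stores the result
end
# fetch blocks until a value is available, but does not remove it from the
# channel, so any number of waiting tasks all observe the same 42.
@assert fetch(future) == 42
@assert fetch(future) == 42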

function Base.get!(func::Base.Callable, cache::MultiThreadedCache{K,V}, key) where {K,V}
Review comment from the PR author: "Ugh, the indentation change from here down is annoying. I'd recommend hiding whitespace in the GitHub UI." [Screenshot: Screen Shot 2022-02-15 at 6 03 11 PM]
# If the thread-local cache has the value, we can return immediately.
# We store tcache in a local variable, so that even if the Task migrates threads, we are
# still operating on the same initial cache object.
tid = Threads.threadid()
tcache = _thread_cache(cache, tid)
tlock = _thread_lock(cache, tid)
# We have to lock during access to the thread-local dict, because it's possible that the
# Task may migrate to another thread by the end, and we really might be mutating the
# dict in parallel. But most of the time this lock should have 0 contention, since it's
# only held during get() and setindex!().
Base.@lock tlock begin
thread_local_cached_value_or_miss = get(tcache, key, CACHE_MISS)
if thread_local_cached_value_or_miss !== CACHE_MISS
return thread_local_cached_value_or_miss::V
end
end
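
The lock above is doing real work: on recent Julia versions a Task may resume on a different thread after any yield point, so the `tid` captured earlier can stop matching `Threads.threadid()`. A small hedged illustration of that migration (results vary; needs a session started with multiple threads):

t = Threads.@spawn begin
    before = Threads.threadid()
    sleep(0.1)                    # yield point: the scheduler may move this task
    (before, Threads.threadid())  # the two ids can differ
end
@show fetch(t)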
# If not, we need to check the base cache.
return get!(_thread_cache(cache), key) do
# Even though we're using Thread-local caches, we still need to lock during
# construction to prevent multiple tasks redundantly constructing the same object,
# and potential thread safety violations due to Tasks migrating threads.
# NOTE that we only grab the lock if the key doesn't exist, so the mutex contention
# is not on the critical path for most accesses. :)
is_first_task = false
local future # used only if the base_cache doesn't have the key
# We lock the mutex, but for only a short, *constant time* duration, to grab the
# future for this key, or to create the future if it doesn't exist.
@lock cache.base_cache_lock begin
value_or_miss = get(cache.base_cache, key, CACHE_MISS)
if value_or_miss !== CACHE_MISS
return value_or_miss::V
end
future = get!(cache.base_cache_futures, key) do
is_first_task = true
Channel{V}(1)
# When we're done, call this function to set the result
@inline function _store_result!(v::V; test_haskey::Bool)
# Set the value into the thread-local cache for the supplied key.
# Note that we must perform two separate get() and setindex!() calls, for
# concurrency-safety, in case the dict has been mutated by another task in between.
# TODO: For 100% concurrency-safety, we may want to lock around the get() above
# and the setindex!() here... it's probably fine without it, but needs considering.
# Currently this is relying on get() and setindex!() having no yields.
Base.@lock tlock begin
if test_haskey
if !haskey(tcache, key)
setindex!(tcache, v, key)
end
else
setindex!(tcache, v, key)
end
end
if is_first_task
v = try
func()
catch e
# In the case of an exception, we abort the current computation of this
# key/value pair, and throw the exception onto the future, so that all
# pending tasks will see the exception as well.
#
# NOTE: we could also cache the exception and throw it from now on, but this
# would make interactive development difficult, since once you fix the
# error, you'd have to clear out your cache. So instead, we just rethrow the
# exception and don't cache anything, so that you can fix the exception and
# continue on. (This means that throwing exceptions remains expensive.)

# close(::Channel, ::Exception) requires an Exception object, so if the user
# threw a non-Exception, we convert it to one here.
e isa Exception || (e = ErrorException("Non-exception object thrown during get!(): $e"))
close(future, e)
# The future isn't needed after this returns (see below).
delete!(cache.base_cache_futures, key)
rethrow(e)
end
# Finally, lock again for a *constant time* to insert the computed value into
# the shared cache, so that we can free the Channel and future gets can read
# from the shared base_cache.
end

# Even though we're using Thread-local caches, we still need to lock during
# construction to prevent multiple tasks redundantly constructing the same object,
# and potential thread safety violations due to Tasks migrating threads.
# NOTE that we only grab the lock if the key doesn't exist, so the mutex contention
# is not on the critical path for most accesses. :)
is_first_task = false
local future # used only if the base_cache doesn't have the key
# We lock the mutex, but for only a short, *constant time* duration, to grab the
# future for this key, or to create the future if it doesn't exist.
@lock cache.base_cache_lock begin
value_or_miss = get(cache.base_cache, key, CACHE_MISS)
if value_or_miss !== CACHE_MISS
return value_or_miss::V
end
future = get!(cache.base_cache_futures, key) do
is_first_task = true
Channel{V}(1)
end
end
if is_first_task
v = try
func()
catch e
# In the case of an exception, we abort the current computation of this
# key/value pair, and throw the exception onto the future, so that all
# pending tasks will see the exception as well.
#
# NOTE: we could also cache the exception and throw it from now on, but this
# would make interactive development difficult, since once you fix the
# error, you'd have to clear out your cache. So instead, we just rethrow the
# exception and don't cache anything, so that you can fix the exception and
# continue on. (This means that throwing exceptions remains expensive.)

# close(::Channel, ::Exception) requires an Exception object, so if the user
# threw a non-Exception, we convert it to one here.
e isa Exception || (e = ErrorException("Non-exception object thrown during get!(): $e"))
close(future, e)
# The future isn't needed after this returns (see below).
@lock cache.base_cache_lock begin
cache.base_cache[key] = v
# We no longer need the Future, since all future requests will see the key
# in the base_cache. (Other Tasks may still hold a reference, but it will
# be GC'd once they have all completed.)
delete!(cache.base_cache_futures, key)
end
# Return v to any other Tasks that were blocking on this key.
put!(future, v)
return v
else
# Block on the future until the first task that asked for this key finishes
# computing a value for it.
return fetch(future)
rethrow(e)
end
# Finally, lock again for a *constant time* to insert the computed value into
# the shared cache, so that we can free the Channel and future gets can read
# from the shared base_cache.
@lock cache.base_cache_lock begin
cache.base_cache[key] = v
# We no longer need the Future, since all future requests will see the key
# in the base_cache. (Other Tasks may still hold a reference, but it will
# be GC'd once they have all completed.)
delete!(cache.base_cache_futures, key)
end
# Store the result in this thread-local dictionary.
_store_result!(v, test_haskey=false)
# Return v to any other Tasks that were blocking on this key.
put!(future, v)
return v
else
# Block on the future until the first task that asked for this key finishes
# computing a value for it.
v = fetch(future)
# Store the result in our original thread-local cache (if another Task hasn't
# set it already).
_store_result!(v, test_haskey=true)
return v
end
end
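
A hedged end-to-end sketch of the intended behavior (the atomic counter is illustrative): many concurrent get! calls on one key should invoke the expensive function exactly once, while every other task blocks on the future and then stores the shared result in its own thread-local dict.

calls = Threads.Atomic{Int}(0)
cache = MultiThreadedCache{Int,Int}(Dict{Int,Int}())
init_cache!(cache)
@sync for _ in 1:100
    Threads.@spawn begin
        get!(cache, 1) do
            Threads.atomic_add!(calls, 1)  # should run exactly once for key 1
            sleep(0.01)
            42
        end
    end
end
@assert calls[] == 1
@assert get!(() -> 0, cache, 1) == 42      # later lookups hit the caches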
