diff --git a/base/task.jl b/base/task.jl index ffe8e5665b041..e407cbd62bbd6 100644 --- a/base/task.jl +++ b/base/task.jl @@ -253,7 +253,7 @@ istaskfailed(t::Task) = (load_state_acquire(t) === task_state_failed) Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1) function Threads.threadpool(t::Task) tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), t) - return tpid == 0 ? :default : :interactive + return Threads._tpid_to_sym(tpid) end task_result(t::Task) = t.result @@ -786,7 +786,7 @@ function enq_work(t::Task) if Threads.threadpoolsize(tp) == 1 # There's only one thread in the task's assigned thread pool; # use its work queue. - tid = (tp === :default) ? 1 : Threads.threadpoolsize(:default)+1 + tid = (tp === :interactive) ? 1 : Threads.threadpoolsize(:interactive)+1 ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid-1) push!(workqueue_for(tid), t) else diff --git a/base/threadingconstructs.jl b/base/threadingconstructs.jl index e7257759b15a9..f6e7ea4480305 100644 --- a/base/threadingconstructs.jl +++ b/base/threadingconstructs.jl @@ -39,6 +39,14 @@ function _nthreads_in_pool(tpid::Int8) return Int(unsafe_load(p, tpid + 1)) end +function _tpid_to_sym(tpid::Int8) + return tpid == 0 ? :interactive : :default +end + +function _sym_to_tpid(tp::Symbol) + return tp === :interactive ? Int8(0) : Int8(1) +end + """ Threads.threadpool(tid = threadid()) -> Symbol @@ -46,7 +54,7 @@ Returns the specified thread's threadpool; either `:default` or `:interactive`. """ function threadpool(tid = threadid()) tpid = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) - return tpid == 0 ? :default : :interactive + return _tpid_to_sym(tpid) end """ @@ -67,24 +75,39 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the [`Distributed`](@ref man-distributed) standard library. """ function threadpoolsize(pool::Symbol = :default) - if pool === :default - tpid = Int8(0) - elseif pool === :interactive - tpid = Int8(1) + if pool === :default || pool === :interactive + tpid = _sym_to_tpid(pool) else error("invalid threadpool specified") end return _nthreads_in_pool(tpid) end +""" + threadpooltids(pool::Symbol) + +Returns a vector of IDs of threads in the given pool. +""" +function threadpooltids(pool::Symbol) + ni = _nthreads_in_pool(Int8(0)) + if pool === :interactive + return collect(1:ni) + elseif pool === :default + return collect(ni+1:ni+_nthreads_in_pool(Int8(1))) + else + error("invalid threadpool specified") + end +end + function threading_run(fun, static) ccall(:jl_enter_threaded_region, Cvoid, ()) n = threadpoolsize() + tid_offset = threadpoolsize(:interactive) tasks = Vector{Task}(undef, n) for i = 1:n t = Task(() -> fun(i)) # pass in tid t.sticky = static - static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, i-1) + static && ccall(:jl_set_task_tid, Cint, (Any, Cint), t, tid_offset + i-1) tasks[i] = t schedule(t) end @@ -287,6 +310,15 @@ macro threads(args...) return _threadsfor(ex.args[1], ex.args[2], sched) end +function _spawn_set_thrpool(t::Task, tp::Symbol) + tpid = _sym_to_tpid(tp) + if _nthreads_in_pool(tpid) == 0 + tpid = _sym_to_tpid(:default) + end + ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), t, tpid) + nothing +end + """ Threads.@spawn [:default|:interactive] expr @@ -315,7 +347,7 @@ the variable's value in the current task. A threadpool may be specified as of Julia 1.9. """ macro spawn(args...) - tpid = Int8(0) + tp = :default na = length(args) if na == 2 ttype, ex = args @@ -325,9 +357,9 @@ macro spawn(args...) # TODO: allow unquoted symbols ttype = nothing end - if ttype === :interactive - tpid = Int8(1) - elseif ttype !== :default + if ttype === :interactive || ttype === :default + tp = ttype + else throw(ArgumentError("unsupported threadpool in @spawn: $ttype")) end elseif na == 1 @@ -344,11 +376,7 @@ macro spawn(args...) let $(letargs...) local task = Task($thunk) task.sticky = false - local tpid_actual = $tpid - if _nthreads_in_pool(tpid_actual) == 0 - tpid_actual = Int8(0) - end - ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, tpid_actual) + _spawn_set_thrpool(task, $(QuoteNode(tp))) if $(Expr(:islocal, var)) put!($var, task) end diff --git a/src/threading.c b/src/threading.c index db9df0bad0dde..f909f41ac5c64 100644 --- a/src/threading.c +++ b/src/threading.c @@ -600,17 +600,16 @@ void jl_init_threading(void) // specified on the command line (and so are in `jl_options`) or by the // environment variable. Set the globals `jl_n_threadpools`, `jl_n_threads` // and `jl_n_threads_per_pool`. - jl_n_threadpools = 1; + jl_n_threadpools = 2; int16_t nthreads = JULIA_NUM_THREADS; int16_t nthreadsi = 0; char *endptr, *endptri; if (jl_options.nthreads != 0) { // --threads specified - jl_n_threadpools = jl_options.nthreadpools; nthreads = jl_options.nthreads_per_pool[0]; if (nthreads < 0) nthreads = jl_effective_threads(); - if (jl_n_threadpools == 2) + if (jl_options.nthreadpools == 2) nthreadsi = jl_options.nthreads_per_pool[1]; } else if ((cp = getenv(NUM_THREADS_NAME))) { // ENV[NUM_THREADS_NAME] specified @@ -635,15 +634,13 @@ void jl_init_threading(void) if (errno != 0 || endptri == cp || nthreadsi < 0) nthreadsi = 0; } - if (nthreadsi > 0) - jl_n_threadpools++; } } jl_all_tls_states_size = nthreads + nthreadsi; jl_n_threads_per_pool = (int*)malloc_s(2 * sizeof(int)); - jl_n_threads_per_pool[0] = nthreads; - jl_n_threads_per_pool[1] = nthreadsi; + jl_n_threads_per_pool[0] = nthreadsi; + jl_n_threads_per_pool[1] = nthreads; jl_atomic_store_release(&jl_all_tls_states, (jl_ptls_t*)calloc(jl_all_tls_states_size, sizeof(jl_ptls_t))); jl_atomic_store_release(&jl_n_threads, jl_all_tls_states_size); diff --git a/test/threadpool_use.jl b/test/threadpool_use.jl index 64227c8a8110b..e5ea5f95cf4ff 100644 --- a/test/threadpool_use.jl +++ b/test/threadpool_use.jl @@ -4,8 +4,10 @@ using Test using Base.Threads @test nthreadpools() == 2 -@test threadpool() === :default -@test threadpool(2) === :interactive +@test threadpool() === :interactive +@test threadpool(2) === :default @test fetch(Threads.@spawn Threads.threadpool()) === :default @test fetch(Threads.@spawn :default Threads.threadpool()) === :default @test fetch(Threads.@spawn :interactive Threads.threadpool()) === :interactive +@test Threads.threadpooltids(:interactive) == [1] +@test Threads.threadpooltids(:default) == [2]