@@ -39,14 +39,22 @@ function _nthreads_in_pool(tpid::Int8)
3939 return Int (unsafe_load (p, tpid + 1 ))
4040end
4141
42+ function _tpid_to_sym (tpid:: Int8 )
43+ return tpid == 0 ? :interactive : :default
44+ end
45+
46+ function _sym_to_tpid (tp:: Symbol )
47+ return tp === :interactive ? Int8 (0 ) : Int8 (1 )
48+ end
49+
4250"""
4351 Threads.threadpool(tid = threadid()) -> Symbol
4452
4553Returns the specified thread's threadpool; either `:default` or `:interactive`.
4654"""
4755function threadpool (tid = threadid ())
4856 tpid = ccall (:jl_threadpoolid , Int8, (Int16,), tid- 1 )
49- return tpid == 0 ? :default : :interactive
57+ return _tpid_to_sym ( tpid)
5058end
5159
5260"""
@@ -67,24 +75,39 @@ See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
6775[`Distributed`](@ref man-distributed) standard library.
6876"""
6977function threadpoolsize (pool:: Symbol = :default )
70- if pool === :default
71- tpid = Int8 (0 )
72- elseif pool === :interactive
73- tpid = Int8 (1 )
78+ if pool === :default || pool === :interactive
79+ tpid = _sym_to_tpid (pool)
7480 else
7581 error (" invalid threadpool specified" )
7682 end
7783 return _nthreads_in_pool (tpid)
7884end
7985
86+ """
87+ threadpooltids(pool::Symbol)
88+
89+ Returns a vector of IDs of threads in the given pool.
90+ """
91+ function threadpooltids (pool:: Symbol )
92+ ni = _nthreads_in_pool (Int8 (0 ))
93+ if pool === :interactive
94+ return collect (1 : ni)
95+ elseif pool === :default
96+ return collect (ni+ 1 : ni+ _nthreads_in_pool (Int8 (1 )))
97+ else
98+ error (" invalid threadpool specified" )
99+ end
100+ end
101+
80102function threading_run (fun, static)
81103 ccall (:jl_enter_threaded_region , Cvoid, ())
82104 n = threadpoolsize ()
105+ tid_offset = threadpoolsize (:interactive )
83106 tasks = Vector {Task} (undef, n)
84107 for i = 1 : n
85108 t = Task (() -> fun (i)) # pass in tid
86109 t. sticky = static
87- static && ccall (:jl_set_task_tid , Cint, (Any, Cint), t, i- 1 )
110+ static && ccall (:jl_set_task_tid , Cint, (Any, Cint), t, tid_offset + i- 1 )
88111 tasks[i] = t
89112 schedule (t)
90113 end
@@ -287,6 +310,15 @@ macro threads(args...)
287310 return _threadsfor (ex. args[1 ], ex. args[2 ], sched)
288311end
289312
313+ function _spawn_set_thrpool (t:: Task , tp:: Symbol )
314+ tpid = _sym_to_tpid (tp)
315+ if _nthreads_in_pool (tpid) == 0
316+ tpid = _sym_to_tpid (:default )
317+ end
318+ ccall (:jl_set_task_threadpoolid , Cint, (Any, Int8), t, tpid)
319+ nothing
320+ end
321+
290322"""
291323 Threads.@spawn [:default|:interactive] expr
292324
@@ -315,7 +347,7 @@ the variable's value in the current task.
315347 A threadpool may be specified as of Julia 1.9.
316348"""
317349macro spawn (args... )
318- tpid = Int8 ( 0 )
350+ tp = :default
319351 na = length (args)
320352 if na == 2
321353 ttype, ex = args
@@ -325,9 +357,9 @@ macro spawn(args...)
325357 # TODO : allow unquoted symbols
326358 ttype = nothing
327359 end
328- if ttype === :interactive
329- tpid = Int8 ( 1 )
330- elseif ttype != = :default
360+ if ttype === :interactive || ttype === :default
361+ tp = ttype
362+ else
331363 throw (ArgumentError (" unsupported threadpool in @spawn: $ttype " ))
332364 end
333365 elseif na == 1
@@ -344,11 +376,7 @@ macro spawn(args...)
344376 let $ (letargs... )
345377 local task = Task ($ thunk)
346378 task. sticky = false
347- local tpid_actual = $ tpid
348- if _nthreads_in_pool (tpid_actual) == 0
349- tpid_actual = Int8 (0 )
350- end
351- ccall (:jl_set_task_threadpoolid , Cint, (Any, Int8), task, tpid_actual)
379+ _spawn_set_thrpool (task, $ (QuoteNode (tp)))
352380 if $ (Expr (:islocal , var))
353381 put! ($ var, task)
354382 end
0 commit comments