Skip to content

Commit 46838f8

Browse files
committed
Random: better handling of the "global seed" (using TLS)
We maintain a "global seed" for this feature of `@testset`:

> Before the execution of the body of a @testset, there is an implicit call to Random.seed!(seed) where seed is the current seed of the global RNG. Moreover, after the execution of the body, the state of the global RNG is restored to what it was before the @testset. This is meant to ease reproducibility in case of failure, and to allow seamless re-arrangements of @testsets regardless of their side-effect on the global RNG state.

But since we don't use `MersenneTwister` as the "global RNG" anymore, we need to maintain a separate "global seed" object. So far we literally used a global object `Random.GLOBAL_SEED` storing the original seed, but it's not robust when multi-tasking is involved: e.g.

```
seed!(0)
x = rand()
seed!(0)
@sync begin
    @async @testset "A" begin
        seed!(1) # reset GLOBAL_SEED to V2
        sleep(2)
    end # reset GLOBAL_SEED to its original value V1
    sleep(0.5)
    @async @testset "B" begin
        # here seed!(1) above has already been called
        # so @testset B recorded value V2 as the "original" value of GLOBAL_SEED
        seed!(2)
        sleep(2)
        # here @testset A already finished
    end # reset GLOBAL_SEED to the wrong original value V2
end
@testset "main task" begin
    # async tests didn't mutate this task's global seed
    @test x == rand() # fails!
end
```

So we store here a "global seed" in `task_local_storage()`, which is set when `seed!()` is invoked without an explicit RNG, and defaults to `Random.GLOBAL_SEED`, which is set only once when `Random` is loaded. And instead of actually storing a seed, we store a copy of the RNG state.

This is still not ideal, in that at the beginning of `@testset "A"` or `@testset "B"`, we can't do `@test x == rand()`, because these are in separate tasks, so the global seed defaults to `Random.GLOBAL_SEED`, and not to the global seed of the parent task; there might be a nice way to handle that, but at least different tasks don't corrupt each other's seeds.
1 parent ab992b9 commit 46838f8

File tree

3 files changed

+90
-37
lines changed

3 files changed

+90
-37
lines changed

stdlib/Random/src/RNGs.jl

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -387,23 +387,28 @@ const GLOBAL_RNG = default_rng()
387387
# the following feature of `@testset`:
388388
# > Before the execution of the body of a `@testset`, there is an implicit
389389
# > call to `Random.seed!(seed)` where `seed` is the current seed of the global RNG.
390-
# But the global RNG is now TaskLocalRNG() and doesn't store its seed; in order to not break `@testset`, we now
391-
# store the seed used in a call like `seed!(seed)` *without* an explicit RNG in `GLOBAL_SEED`; the wording of the
392-
# feature above was sufficiently unprecise (e.g. what exactly is the "global RNG"?) that this solution seems fine
393-
GLOBAL_SEED = 0
394-
# only the Test module is allowed to use this function!
395-
set_global_seed!(seed) = global GLOBAL_SEED = seed
396-
397-
# seed the "global" RNG
390+
# But the global RNG is now `TaskLocalRNG()` and doesn't store its seed; in order to not break `@testset`,
391+
# in a call like `seed!(seed)` *without* an explicit RNG, we now store the state of `TaskLocalRNG()` in
392+
# `task_local_storage()`
393+
394+
# GLOBAL_SEED is used as a fall-back when no tls seed is found
395+
# only `Random.__init__` is allowed to set it
396+
const GLOBAL_SEED = Xoshiro(0, 0, 0, 0, 0)
397+
398+
get_tls_seed() = get!(() -> copy(GLOBAL_SEED), task_local_storage(),
399+
:__RANDOM_GLOBAL_RNG_SEED_uBlmfA8ZS__)::Xoshiro
400+
401+
# seed the default RNG
398402
function seed!(seed=nothing)
399-
# the seed is not left as `nothing`, as storing `nothing` as the global seed wouldn't lead to reproducible streams
400-
seed = @something seed rand(RandomDevice(), UInt128)
401-
set_global_seed!(seed)
402403
seed!(default_rng(), seed)
404+
copy!(get_tls_seed(), default_rng())
405+
default_rng()
403406
end
404407

405408
function __init__()
406-
seed!()
409+
# do not call no-arg `seed!()` here, so as not to update `task_local_storage()` unnecessarily at startup
410+
seed!(default_rng())
411+
copy!(GLOBAL_SEED, TaskLocalRNG())
407412
ccall(:jl_gc_init_finalizer_rng_state, Cvoid, ())
408413
end
409414

stdlib/Test/src/Test.jl

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1589,11 +1589,11 @@ function testset_beginend_call(args, tests, source)
15891589
# we reproduce the logic of guardseed, but this function
15901590
# cannot be used as it changes slightly the semantic of @testset,
15911591
# by wrapping the body in a function
1592-
local oldrng = copy(default_rng())
1593-
local oldseed = Random.GLOBAL_SEED
1592+
local default_rng_orig = copy(default_rng())
1593+
local tls_seed_orig = copy(Random.get_tls_seed())
15941594
try
1595-
# default RNG is re-seeded with its own seed to ease reproduce a failed test
1596-
Random.seed!(Random.GLOBAL_SEED)
1595+
# default RNG is reset to its state from the last `seed!()` call, to make a failed test easier to reproduce
1596+
copy!(Random.default_rng(), tls_seed_orig)
15971597
let
15981598
$(esc(tests))
15991599
end
@@ -1608,8 +1608,8 @@ function testset_beginend_call(args, tests, source)
16081608
record(ts, Error(:nontest_error, Expr(:tuple), err, Base.current_exceptions(), $(QuoteNode(source))))
16091609
end
16101610
finally
1611-
copy!(default_rng(), oldrng)
1612-
Random.set_global_seed!(oldseed)
1611+
copy!(default_rng(), default_rng_orig)
1612+
copy!(Random.get_tls_seed(), tls_seed_orig)
16131613
pop_testset()
16141614
ret = finish(ts)
16151615
end
@@ -1674,10 +1674,7 @@ function testset_forloop(args, testloop, source)
16741674
finish_errored = true
16751675
push!(arr, finish(ts))
16761676
finish_errored = false
1677-
1678-
# it's 1000 times faster to copy from tmprng rather than calling Random.seed!
1679-
copy!(default_rng(), tmprng)
1680-
1677+
copy!(default_rng(), tls_seed_orig)
16811678
end
16821679
ts = if ($testsettype === $DefaultTestSet) && $(isa(source, LineNumberNode))
16831680
$(testsettype)($desc; source=$(QuoteNode(source.file)), $options...)
@@ -1703,10 +1700,9 @@ function testset_forloop(args, testloop, source)
17031700
local first_iteration = true
17041701
local ts
17051702
local finish_errored = false
1706-
local oldrng = copy(default_rng())
1707-
local oldseed = Random.GLOBAL_SEED
1708-
Random.seed!(Random.GLOBAL_SEED)
1709-
local tmprng = copy(default_rng())
1703+
local default_rng_orig = copy(default_rng())
1704+
local tls_seed_orig = copy(Random.get_tls_seed())
1705+
copy!(Random.default_rng(), tls_seed_orig)
17101706
try
17111707
let
17121708
$(Expr(:for, Expr(:block, [esc(v) for v in loopvars]...), blk))
@@ -1717,8 +1713,8 @@ function testset_forloop(args, testloop, source)
17171713
pop_testset()
17181714
push!(arr, finish(ts))
17191715
end
1720-
copy!(default_rng(), oldrng)
1721-
Random.set_global_seed!(oldseed)
1716+
copy!(default_rng(), default_rng_orig)
1717+
copy!(Random.get_tls_seed(), tls_seed_orig)
17221718
end
17231719
arr
17241720
end

stdlib/Test/test/runtests.jl

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,6 +1032,7 @@ end
10321032
# i.e. it behaves as if it was wrapped in a `guardseed(GLOBAL_SEED)` block
10331033
seed = rand(UInt128)
10341034
Random.seed!(seed)
1035+
seeded_state = copy(Random.default_rng())
10351036
a = rand()
10361037
@testset begin
10371038
# global RNG must be re-seeded at the beginning of @testset
@@ -1043,31 +1044,82 @@ end
10431044
# the @testset's above must have no consequence for rand() below
10441045
b = rand()
10451046
Random.seed!(seed)
1047+
@test Random.default_rng() == seeded_state
10461048
@test a == rand()
10471049
@test b == rand()
10481050

10491051
# Even when seed!() is called within a testset A, subsequent testsets
10501052
# should start with the same "global RNG state" as what A started with,
10511053
# such that the test `refvalue == rand(Int)` below succeeds.
1052-
# Currently, this means that Random.GLOBAL_SEED has to be restored,
1054+
# Currently, this means that `Random.get_tls_seed()` has to be restored,
10531055
# in addition to the state of Random.default_rng().
1054-
GLOBAL_SEED_orig = Random.GLOBAL_SEED
1056+
tls_seed_orig = copy(Random.get_tls_seed())
10551057
local refvalue
1056-
@testset "GLOBAL_SEED is also preserved (setup)" begin
1057-
@test GLOBAL_SEED_orig == Random.GLOBAL_SEED
1058+
@testset "TLS seed is also preserved (setup)" begin
1059+
@test tls_seed_orig == Random.get_tls_seed()
10581060
refvalue = rand(Int)
10591061
Random.seed!()
1060-
@test GLOBAL_SEED_orig != Random.GLOBAL_SEED
1062+
@test tls_seed_orig != Random.get_tls_seed()
10611063
end
1062-
@test GLOBAL_SEED_orig == Random.GLOBAL_SEED
1063-
@testset "GLOBAL_SEED is also preserved (forloop)" for _=1:3
1064+
@test tls_seed_orig == Random.get_tls_seed()
1065+
@testset "TLS seed is also preserved (forloop)" for _=1:3
10641066
@test refvalue == rand(Int)
10651067
Random.seed!()
10661068
end
1067-
@test GLOBAL_SEED_orig == Random.GLOBAL_SEED
1068-
@testset "GLOBAL_SEED is also preserved (beginend)" begin
1069+
@test tls_seed_orig == Random.get_tls_seed()
1070+
@testset "TLS seed is also preserved (beginend)" begin
10691071
@test refvalue == rand(Int)
10701072
end
1073+
1074+
# @testset below is not compatible with e.g. v1.9, but it still fails there (at "main task")
1075+
# when deleting lines using get_tls_seed() or GLOBAL_SEED
1076+
@testset "TLS seed and concurrency" begin
1077+
# Even with multi-tasking, the TLS seed must stay consistent: the default_rng() state
1078+
# is reset to the "global seed" at the beginning, and the "global seed" is reset to what
1079+
# it was at the end of the testset; make sure that distinct tasks don't see the mutation
1080+
# of this "global seed" (iow, it's task-local)
1081+
seed = rand(UInt128)
1082+
Random.seed!(seed)
1083+
seeded_state = copy(Random.default_rng())
1084+
a = rand()
1085+
1086+
ch = Channel{Nothing}()
1087+
@sync begin
1088+
@async begin
1089+
@testset "task 1" begin
1090+
# tick 1
1091+
# this task didn't call seed! explicitly (yet), so its TaskLocalRNG() should have been
1092+
# reset to `Random.GLOBAL_SEED` at the beginning of `@testset`
1093+
@test Random.GLOBAL_SEED == Random.default_rng()
1094+
seed!()
1095+
put!(ch, nothing) # tick 1 -> tick 2
1096+
take!(ch) # tick 3
1097+
end
1098+
put!(ch, nothing) # tick 3 -> tick 4
1099+
end
1100+
@async begin
1101+
take!(ch) # tick 2
1102+
# @testset below will record the current TLS "seed" and reset default_rng() to
1103+
# this value;
1104+
# it must not be affected by the fact that "task 1" called `seed!()` first
1105+
@test Random.get_tls_seed() == Random.GLOBAL_SEED
1106+
1107+
@testset "task 2" begin
1108+
@test Random.GLOBAL_SEED == Random.default_rng()
1109+
seed!()
1110+
put!(ch, nothing) # tick 2 -> tick 3
1111+
take!(ch) # tick 4
1112+
end
1113+
# when `@testset` of task 2 finishes, which is after `@testset` from task 1,
1114+
# it resets `get_tls_seed()` to what it was before starting:
1115+
@test Random.get_tls_seed() == Random.GLOBAL_SEED
1116+
end
1117+
end
1118+
@testset "main task" begin
1119+
@test Random.default_rng() == seeded_state
1120+
@test a == rand()
1121+
end
1122+
end
10711123
end
10721124

10731125
@testset "InterruptExceptions #21043" begin

0 commit comments

Comments
 (0)