Skip to content

Commit 496a3dd

Browse files
committed
MersenneTwister: hash seeds like for Xoshiro
This addresses part of #37165: "It's common that sequential seeds for RNGs are not as independent as one might like." This fixes that problem for `MersenneTwister`, and makes it easy to add the same feature to other RNGs via a new `hash_seed` function, which replaces `make_seed`. This is an alternative to #37766.
1 parent c476d84 commit 496a3dd

File tree

5 files changed

+67
-42
lines changed

5 files changed

+67
-42
lines changed

stdlib/Random/src/DSFMT.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ function dsfmt_init_gen_rand(s::DSFMT_state, seed::UInt32)
6565
s.val, seed)
6666
end
6767

68-
function dsfmt_init_by_array(s::DSFMT_state, seed::Vector{UInt32})
68+
function dsfmt_init_by_array(s::DSFMT_state, seed::StridedVector{UInt32})
69+
strides(seed) == (1,) || throw(ArgumentError("seed must have its stride equal to 1"))
6970
ccall((:dsfmt_init_by_array,:libdSFMT),
7071
Cvoid,
7172
(Ptr{Cvoid}, Ptr{UInt32}, Int32),

stdlib/Random/src/RNGs.jl

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The entropy is obtained from the operating system.
1212
"""
1313
struct RandomDevice <: AbstractRNG; end
1414
RandomDevice(seed::Nothing) = RandomDevice()
15-
seed!(rng::RandomDevice) = rng
15+
seed!(rng::RandomDevice, ::Nothing) = rng
1616

1717
rand(rd::RandomDevice, sp::SamplerBoolBitInteger) = Libc.getrandom!(Ref{sp[]}())[]
1818
rand(rd::RandomDevice, ::SamplerType{Bool}) = rand(rd, UInt8) % Bool
@@ -44,7 +44,7 @@ const MT_CACHE_I = 501 << 4 # number of bytes in the UInt128 cache
4444
@assert dsfmt_get_min_array_size() <= MT_CACHE_F
4545

4646
mutable struct MersenneTwister <: AbstractRNG
47-
seed::Vector{UInt32}
47+
seed::Any
4848
state::DSFMT_state
4949
vals::Vector{Float64}
5050
ints::Vector{UInt128}
@@ -70,7 +70,7 @@ mutable struct MersenneTwister <: AbstractRNG
7070
end
7171
end
7272

73-
MersenneTwister(seed::Vector{UInt32}, state::DSFMT_state) =
73+
MersenneTwister(seed, state::DSFMT_state) =
7474
MersenneTwister(seed, state,
7575
Vector{Float64}(undef, MT_CACHE_F),
7676
Vector{UInt128}(undef, MT_CACHE_I >> 4),
@@ -115,7 +115,7 @@ MersenneTwister(seed=nothing) =
115115

116116

117117
function copy!(dst::MersenneTwister, src::MersenneTwister)
118-
copyto!(resize!(dst.seed, length(src.seed)), src.seed)
118+
dst.seed = src.seed
119119
copy!(dst.state, src.state)
120120
copyto!(dst.vals, src.vals)
121121
copyto!(dst.ints, src.ints)
@@ -129,7 +129,7 @@ function copy!(dst::MersenneTwister, src::MersenneTwister)
129129
end
130130

131131
copy(src::MersenneTwister) =
132-
MersenneTwister(copy(src.seed), copy(src.state), copy(src.vals), copy(src.ints),
132+
MersenneTwister(src.seed, copy(src.state), copy(src.vals), copy(src.ints),
133133
src.idxF, src.idxI, src.adv, src.adv_jump, src.adv_vals, src.adv_ints)
134134

135135

@@ -144,12 +144,10 @@ hash(r::MersenneTwister, h::UInt) =
144144

145145
function show(io::IO, rng::MersenneTwister)
146146
# seed
147-
seed = from_seed(rng.seed)
148-
seed_str = seed <= typemax(Int) ? string(seed) : "0x" * string(seed, base=16) # DWIM
149147
if rng.adv_jump == 0 && rng.adv == 0
150-
return print(io, MersenneTwister, "(", seed_str, ")")
148+
return print(io, MersenneTwister, "(", repr(rng.seed), ")")
151149
end
152-
print(io, MersenneTwister, "(", seed_str, ", (")
150+
print(io, MersenneTwister, "(", repr(rng.seed), ", (")
153151
# state
154152
adv = Integer[rng.adv_jump, rng.adv]
155153
if rng.adv_vals != -1 || rng.adv_ints != -1
@@ -277,48 +275,72 @@ end
277275

278276
### seeding
279277

280-
#### make_seed()
278+
#### random_seed() & hash_seed()
281279

282-
# make_seed produces values of type Vector{UInt32}, suitable for MersenneTwister seeding
283-
function make_seed()
280+
# random_seed tries to produce a random seed of type UInt128 from system entropy
281+
function random_seed()
284282
try
285-
return rand(RandomDevice(), UInt32, 4)
283+
# as MersenneTwister prints its seed when `show`ed, 128 bits is a good compromise for
284+
# almost surely always getting distinct seeds, while having them printed reasonably tersely
285+
return rand(RandomDevice(), UInt128)
286286
catch ex
287287
ex isa IOError || rethrow()
288288
@warn "Entropy pool not available to seed RNG; using ad-hoc entropy sources."
289-
return make_seed(Libc.rand())
289+
return Libc.rand()
290290
end
291291
end
292292

293-
function make_seed(n::Integer)
294-
n < 0 && throw(DomainError(n, "`n` must be non-negative."))
295-
seed = UInt32[]
293+
function hash_seed(seed::Integer)
294+
seed < 0 && throw(DomainError(seed, "`seed` must be non-negative."))
295+
ctx = SHA.SHA2_256_CTX()
296296
while true
297-
push!(seed, n & 0xffffffff)
298-
n >>= 32
299-
if n == 0
300-
return seed
297+
SHA.update!(ctx, reinterpret(NTuple{4, UInt8}, (seed % UInt32) & 0xffffffff))
298+
seed >>= 32
299+
if seed == 0
300+
break
301301
end
302302
end
303+
SHA.digest!(ctx)
303304
end
304305

305-
# inverse of make_seed(::Integer)
306-
from_seed(a::Vector{UInt32})::BigInt = sum(a[i] * big(2)^(32*(i-1)) for i in 1:length(a))
306+
function hash_seed(seed::Union{DenseArray{UInt32}, DenseArray{UInt64}})
307+
ctx = SHA.SHA2_256_CTX()
308+
SHA.update!(ctx, reinterpret(UInt8, seed))
309+
SHA.digest!(ctx)
310+
end
311+
312+
313+
"""
314+
hash_seed(seed) -> AbstractVector{UInt8}
315+
316+
Return a cryptographic hash of `seed` of size 256 bits.
317+
`seed` can currently be of type `Union{Integer, DenseArray{UInt32}, DenseArray{UInt64}}`,
318+
but modules can extend this function for types they own.
307319
320+
This is an internal function subject to change.
321+
"""
322+
hash_seed # the preceding docstring documents the `hash_seed` generic function
308323

309324
#### seed!()
310325

311-
function seed!(r::MersenneTwister, seed::Vector{UInt32})
312-
copyto!(resize!(r.seed, length(seed)), seed)
313-
dsfmt_init_by_array(r.state, r.seed)
326+
function initstate!(r::MersenneTwister, data::StridedVector, seed)
327+
# we deepcopy `seed` because the caller might mutate it, and it's useful
328+
# to keep it constant inside `MersenneTwister`; but multiple instances
329+
# can share the same seed without any problem (e.g. in `copy`)
330+
r.seed = deepcopy(seed)
331+
dsfmt_init_by_array(r.state, reinterpret(UInt32, data))
314332
reset_caches!(r)
315333
r.adv = 0
316334
r.adv_jump = 0
317335
return r
318336
end
319337

320-
seed!(r::MersenneTwister) = seed!(r, make_seed())
321-
seed!(r::MersenneTwister, n::Integer) = seed!(r, make_seed(n))
338+
# when a seed is not provided, we generate one via `RandomDevice()` in `random_seed()` rather
339+
# than directly calling `initstate!` with `rand(RandomDevice(), UInt32, whatever)` because the
340+
# seed is printed in `show(::MersenneTwister)`, so we need one; the cost of `hash_seed` is a
341+
# small overhead compared to `initstate!`, so this simple solution is fine
342+
seed!(r::MersenneTwister, ::Nothing) = seed!(r, random_seed())
343+
seed!(r::MersenneTwister, seed) = initstate!(r, hash_seed(seed), seed)
322344

323345

324346
### Global RNG
@@ -701,7 +723,7 @@ end
701723
function _randjump(r::MersenneTwister, jumppoly::DSFMT.GF2X)
702724
adv = r.adv
703725
adv_jump = r.adv_jump
704-
s = MersenneTwister(copy(r.seed), DSFMT.dsfmt_jump(r.state, jumppoly))
726+
s = MersenneTwister(r.seed, DSFMT.dsfmt_jump(r.state, jumppoly))
705727
reset_caches!(s)
706728
s.adv = adv
707729
s.adv_jump = adv_jump

stdlib/Random/src/Random.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,11 @@ julia> rand(Xoshiro(), Bool) # not reproducible either
433433
true
434434
```
435435
"""
436-
seed!(rng::AbstractRNG, ::Nothing) = seed!(rng)
436+
seed!(rng::AbstractRNG) = seed!(rng, nothing)
437+
#=
438+
We have this generic definition instead of the alternative option
439+
`seed!(rng::AbstractRNG, ::Nothing) = seed!(rng)`
440+
because that alternative would too easily lead to ambiguities, e.g. when we define `seed!(::Xoshiro, seed)`.
441+
=#
437442

438443
end # module

stdlib/Random/src/Xoshiro.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -116,14 +116,18 @@ rng_native_52(::TaskLocalRNG) = UInt64
116116
## Shared implementation between Xoshiro and TaskLocalRNG
117117

118118
# this variant of setstate! initializes the internal splitmix state, a.k.a. `s4`
119-
@inline initstate!(x::Union{TaskLocalRNG, Xoshiro}, (s0, s1, s2, s3)::NTuple{4, UInt64}) =
119+
@inline function initstate!(x::Union{TaskLocalRNG, Xoshiro}, state)
120+
length(state) == 4 && eltype(state) == UInt64 ||
121+
throw(ArgumentError("initstate! expects a list of 4 `UInt64` values"))
122+
s0, s1, s2, s3 = state
120123
setstate!(x, (s0, s1, s2, s3, 1s0 + 3s1 + 5s2 + 7s3))
124+
end
121125

122126
copy(rng::Union{TaskLocalRNG, Xoshiro}) = Xoshiro(getstate(rng)...)
123127
copy!(dst::Union{TaskLocalRNG, Xoshiro}, src::Union{TaskLocalRNG, Xoshiro}) = setstate!(dst, getstate(src))
124128
==(x::Union{TaskLocalRNG, Xoshiro}, y::Union{TaskLocalRNG, Xoshiro}) = getstate(x) == getstate(y)
125129

126-
function seed!(rng::Union{TaskLocalRNG, Xoshiro})
130+
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, ::Nothing)
127131
# as we get good randomness from RandomDevice, we can skip hashing
128132
rd = RandomDevice()
129133
s0 = rand(rd, UInt64)
@@ -133,14 +137,9 @@ function seed!(rng::Union{TaskLocalRNG, Xoshiro})
133137
initstate!(rng, (s0, s1, s2, s3))
134138
end
135139

136-
function seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Union{Vector{UInt32}, Vector{UInt64}})
137-
c = SHA.SHA2_256_CTX()
138-
SHA.update!(c, reinterpret(UInt8, seed))
139-
s0, s1, s2, s3 = reinterpret(UInt64, SHA.digest!(c))
140-
initstate!(rng, (s0, s1, s2, s3))
141-
end
140+
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed) =
141+
initstate!(rng, reinterpret(UInt64, hash_seed(seed)))
142142

143-
seed!(rng::Union{TaskLocalRNG, Xoshiro}, seed::Integer) = seed!(rng, make_seed(seed))
144143

145144
@inline function rand(x::Union{TaskLocalRNG, Xoshiro}, ::SamplerType{UInt64})
146145
s0, s1, s2, s3 = getstate(x)

stdlib/Random/test/runtests.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -633,9 +633,7 @@ end
633633
let seed = rand(UInt32, 10)
634634
r = MersenneTwister(seed)
635635
@test r.seed == seed && r.seed !== seed
636-
# RNGs do not share their seed in randjump
637636
let r2 = Future.randjump(r, big(10)^20)
638-
@test r.seed !== r2.seed
639637
Random.seed!(r2)
640638
@test seed == r.seed != r2.seed
641639
end

0 commit comments

Comments
 (0)