Skip to content

Commit 7223ddf

Browse files
committed
Implement optimal uniform random number generator using the method proposed in swiftlang/swift#39143 based on OpenSSL's implementation of it in https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
1 parent d84d3ad commit 7223ddf

File tree

2 files changed

+90
-1
lines changed

2 files changed

+90
-1
lines changed

base/partr.jl

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,95 @@ const heap_d = UInt32(8)
1919
const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
2020
const heaps_lock = [SpinLock(), SpinLock()]
2121

22+
"""
    cong(max::UInt32)

Draw a random `UInt32` from the range `1:max`. When `max` is 0 there is no
valid range, so 0 is returned instead.
"""
function cong(max::UInt32)
    # TODO: make sure users don't use 0 and remove this check
    max == UInt32(0) && return UInt32(0)
    return jl_rand_ptls(max) + UInt32(1)
end
28+
29+
30+
"""
    jl_rand_ptls(max::UInt32)

Draw a random `UInt32` from the range `0:max-1`, advancing the RNG seed kept
in the current thread's local state. `max` must be greater than 0.
"""
function jl_rand_ptls(max::UInt32)
    # View the thread-local state as 64-bit words; the seed is read from, and
    # written back to, the second word.
    words = Base.unsafe_convert(Ptr{UInt64}, Core.getptls())
    result, nextseed = rand_uniform_max_int32(max, Base.unsafe_load(words, 2))
    Base.unsafe_store!(words, nextseed, 2)
    return result % UInt32
end
43+
44+
# This implementation is based on OpenSSL's implementation of rand_uniform
45+
# https://github.com/openssl/openssl/blob/1d2cbd9b5a126189d5e9bc78a3bdb9709427d02b/crypto/rand/rand_uniform.c#L13-L99
46+
# Comments are vendored from their implementation as well.
47+
# For the original developer check the PR to swift https://github.com/apple/swift/pull/39143.
48+
49+
# Essentially it boils down to incrementally generating a fixed point
50+
# number on the interval [0, 1) and multiplying this number by the upper
51+
# range limit. Once it is certain what the fractional part contributes to
52+
# the integral part of the product, the algorithm has produced a definitive
53+
# result.
54+
"""
    rand_uniform_max_int32(max::UInt32, seed::UInt64)

Return a tuple `(val, newseed)` where `val` is a random `UInt32` in the range
`0:max-1` derived from `seed`, and `newseed` is the advanced generator state
that the caller should store for the next draw.

`max` must be greater than 0. For `max == 1` the result is always 0 and the
seed is returned unchanged.
"""
function rand_uniform_max_int32(max::UInt32, seed::UInt64)
    if max == UInt32(1)
        return UInt32(0), seed
    end

    # We are generating a fixed point number on the interval [0, 1).
    # Multiplying this by the range gives us a number on [0, upper).
    # The high word of the multiplication result represents the integral
    # part we want. The lower word is the fractional part. We can early exit
    # if the fractional part is small enough that no carry from the next lower
    # word can cause an overflow and carry into the integer part. This
    # happens when the fractional part is bounded by 2^32 - upper which
    # can be simplified to just -upper (as an unsigned integer).
    seed = UInt64(69069) * seed + UInt64(362437)
    prod = UInt64(max) * (seed % UInt32) # 64 bit product
    i = (prod >> 32) % UInt32 # integral part
    f = prod % UInt32 # fractional part
    if f <= UInt32(1) + ~max # likely
        return i, seed
    end

    # We're in the position where the carry from the next word *might* cause
    # a carry to the integral part. The process here is to generate the next
    # word, multiply it by the range and add that to the current word. If
    # it overflows, the carry propagates to the integer part (return i+1).
    # If it can no longer overflow regardless of further lower order bits,
    # we are done (return i). If there is still a chance of overflow, we
    # repeat the process with the next lower word.
    #
    # Each *bit* of randomness has a probability of one half of terminating
    # this process, so each word beyond the first has a probability
    # of 2^-32 of not terminating the process. That is, we're extremely
    # likely to stop very rapidly.
    for _ in 1:10
        seed = UInt64(69069) * seed + UInt64(362437)
        prod = UInt64(max) * (seed % UInt32)
        f2 = (prod >> 32) % UInt32 # extra fractional part
        # Accumulate the next word's contribution. This must be a (wrapping)
        # addition: the `f < f2` test below detects the carry out of it.
        f += f2
        # On overflow, add the carry to our result.
        if f < f2
            return i + UInt32(1), seed
        end
        # For not all 1 bits, no later word can produce a carry, so the
        # integral part is final.
        if f != 0xffffffff # likely
            return i, seed
        end
        # Set up for the next word of randomness.
        f = prod % UInt32
    end
    # If we get here, we've consumed 32 * max_followup_iterations + 32 bits
    # (max_followup_iterations is the loop count above) with no firm decision;
    # this gives a bias with probability < 2^-(32*n), which is likely
    # acceptable.
    return i, seed
end
25111

26112
function multiq_sift_up(heap::taskheap, idx::Int32)
27113
while idx > Int32(1)

src/julia_internal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1306,6 +1306,9 @@ JL_DLLEXPORT size_t jl_maxrss(void);
13061306
// congruential random number generator
13071307
// for a small amount of thread-local randomness
13081308

1309+
//TODO: utilize https://github.com/openssl/openssl/blob/master/crypto/rand/rand_uniform.c#L13-L99
1310+
// for better performance, it does however require making users expect a 32bit random number.
1311+
13091312
STATIC_INLINE uint64_t cong(uint64_t max, uint64_t *seed) JL_NOTSAFEPOINT
13101313
{
13111314
if (max < 2)

0 commit comments

Comments
 (0)