@@ -8,12 +8,15 @@ using ..Reactant:
8
8
Reactant,
9
9
TracedRArray,
10
10
TracedRNumber,
11
+ ConcreteRNG,
11
12
TracedRNG,
12
13
AnyTracedRArray,
13
14
Reactant,
14
15
TracedUtils,
15
16
Ops,
16
- ConcreteRArray
17
+ ConcreteRArray,
18
+ ConcreteRNumber,
19
+ unwrapped_eltype
17
20
using Random: Random, AbstractRNG
18
21
19
22
@noinline function make_seed (rng:: AbstractRNG = Random. RandomDevice ())
@@ -25,44 +28,55 @@ using Random: Random, AbstractRNG
25
28
return seed
26
29
end
27
30
28
- function Random. seed! (rng:: TracedRNG , seed:: Number )
31
+ @noinline function Random. seed! (rng:: TracedRNG , seed:: Number )
29
32
if seed isa TracedRNumber
30
33
error (" Passing in `TracedRNumber` as a seed is not supported. Please pass in a \
31
34
`TracedRArray` of the appropriate size instead." )
32
35
end
33
36
34
37
seed = reinterpret (UInt64, Random. hash_seed (seed))
35
- seed = if Reactant. within_reactant_interpreter ()
36
- TracedUtils. promote_to (TracedRArray{UInt64,1 }, seed[1 : length (rng. seed)])
37
- else
38
- ConcreteRArray (seed[1 : length (rng. seed)])
39
- end
40
- return Random. seed! (rng, seed)
38
+ return Random. seed! (
39
+ rng, TracedUtils. promote_to (TracedRArray{UInt64,1 }, seed[1 : length (rng. seed)])
40
+ )
41
41
end
42
42
43
- function Random. seed! (rng:: TracedRNG , seed:: AbstractArray {<:Integer,1 } )
43
+ @noinline function Random. seed! (rng:: TracedRNG , seed:: AbstractVector {<:Integer} )
44
44
return Random. seed! (rng, UInt64 .(seed))
45
45
end
46
46
47
- function Random. seed! (rng:: TracedRNG , seed:: AbstractArray {UInt64,1 } )
47
+ @noinline function Random. seed! (rng:: TracedRNG , seed:: AbstractVector {UInt64} )
48
48
return Random. seed! (rng, TracedUtils. promote_to (TracedRArray{UInt64,1 }, seed))
49
49
end
50
50
51
- function Random. seed! (
52
- rng:: TracedRNG , seed:: Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
53
- )
51
+ @noinline function Random. seed! (rng:: TracedRNG , seed:: TracedRArray{UInt64,1} )
54
52
rng. seed = seed
55
53
return rng
56
54
end
57
55
58
- @noinline TracedRNG () = TracedRNG (ConcreteRArray (make_seed ()))
59
- @noinline TracedRNG (seed:: ConcreteRArray{UInt64,1} ) = TracedRNG (seed, " DEFAULT" )
56
+ @noinline function Random. seed! (rng:: ConcreteRNG , seed:: Number )
57
+ seed isa ConcreteRNumber && (seed = unwrapped_eltype (seed)(seed))
58
+ seed = reinterpret (UInt64, Random. hash_seed (seed))
59
+ return Random. seed! (rng, ConcreteRArray (seed))
60
+ end
61
+
62
+ @noinline function Random. seed! (rng:: ConcreteRNG , seed:: AbstractVector{<:Integer} )
63
+ return Random. seed! (rng, seed)
64
+ end
60
65
61
- @noinline function default_rng ()
62
- Reactant. within_reactant_interpreter () || return TracedRNG ()
63
- return TracedRNG (TracedUtils. promote_to (TracedRArray{UInt64,1 }, make_seed ()), " DEFAULT" )
66
+ @noinline function Random. seed! (rng:: ConcreteRNG , seed:: AbstractVector{UInt64} )
67
+ return Random. seed! (rng, ConcreteRArray (seed))
64
68
end
65
69
70
+ @noinline function Random. seed! (rng:: ConcreteRNG , seed:: ConcreteRArray{UInt64,1} )
71
+ rng. seed = seed
72
+ return rng
73
+ end
74
+
75
+ @noinline ConcreteRNG () = ConcreteRNG (ConcreteRArray (make_seed ()))
76
+ @noinline ConcreteRNG (seed:: ConcreteRArray{UInt64,1} ) = ConcreteRNG (seed, " DEFAULT" )
77
+
78
+ @noinline default_rng () = ConcreteRNG ()
79
+
66
80
@noinline rng_algorithm (rng:: TracedRNG ) = rng. algorithm
67
81
@noinline rng_algorithm (:: AbstractRNG ) = " DEFAULT"
68
82
0 commit comments