Skip to content

Commit c86467a

Browse files
authored
refactor: rework TracedRNG to be similar to other types (#448)
1 parent 0343a39 commit c86467a

File tree

5 files changed

+66
-29
lines changed

5 files changed

+66
-29
lines changed

docs/src/api/internal.md

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,4 @@ These functions are not part of the public API and are subject to change at any
88

99
```@docs
1010
Reactant.REDUB_ARGUMENTS_NAME
11-
Reactant.within_reactant_interpreter
1211
```

src/Overlay.jl

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,6 @@
33
# correctly. Once that (https://github.com/timholy/Revise.jl/issues/646) is resolved
44
# we should move all the reactant_overrides to relevant files.
55

6-
# Helper Function to determine if we are inside the ReactantInterpreter
7-
"""
8-
within_reactant_interpreter()
9-
10-
Returns `true` if we are currently inside the ReactantInterpreter.
11-
"""
12-
@noinline within_reactant_interpreter() = false
13-
@reactant_overlay @noinline within_reactant_interpreter() = true
14-
156
# Compiling within a compile should return simply the original function
167
@reactant_overlay function Compiler.compile(
178
f, args; client=nothing, optimize=true, sync=false
@@ -37,6 +28,12 @@ end
3728
return call_with_reactant(TracedRandom.default_rng)
3829
end
3930

31+
@reactant_overlay @noinline function TracedRandom.default_rng()
32+
return TracedRNG(
33+
TracedUtils.promote_to(TracedRArray{UInt64,1}, TracedRandom.make_seed()), "DEFAULT"
34+
)
35+
end
36+
4037
## Only problematic edge case here is the direct `<randfun!>(rng, A::AbstractArray)` call
4138
## We can't directly overlay that call without breaking the semantics of inplace update
4239
for randfun in (:rand, :randn, :randexp)

src/Reactant.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,13 @@ include("TracedRArray.jl")
175175

176176
include("ConcreteRArray.jl")
177177

178+
mutable struct ConcreteRNG <: Random.AbstractRNG
179+
seed::ConcreteRArray{UInt64,1}
180+
const algorithm::String
181+
end
182+
178183
mutable struct TracedRNG <: Random.AbstractRNG
179-
seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
184+
seed::TracedRArray{UInt64,1}
180185
const algorithm::String
181186
end
182187

src/Tracing.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,16 @@ function traced_type(
227227
end
228228
end
229229

230+
function traced_type(::Type{<:ConcreteRNG}, seen, ::Val{mode}, track_numbers) where {mode}
231+
if mode == ConcreteToTraced
232+
return TracedRNG
233+
elseif mode == TracedToConcrete
234+
return ConcreteRNG
235+
else
236+
throw("Unsupported mode: $mode")
237+
end
238+
end
239+
230240
function traced_type(
231241
::Type{T}, seen::ST, ::Val{mode}, track_numbers
232242
) where {ST,T<:TracedType,mode}
@@ -246,6 +256,18 @@ function traced_type(
246256
end
247257
end
248258

259+
function traced_type(::Type{T}, seen, ::Val{mode}, track_numbers) where {T<:TracedRNG,mode}
260+
if mode == ConcreteToTraced
261+
throw("TracedRNG cannot be traced")
262+
elseif mode == TracedToConcrete
263+
return ConcreteRNG
264+
elseif mode == TracedTrack || mode == TracedSetPath
265+
return T
266+
else
267+
throw("Unsupported mode: $mode")
268+
end
269+
end
270+
249271
function traced_type(::Type{T}, seen, mode, track_numbers) where {T<:XLAArray}
250272
throw("XLA $T array cannot be traced")
251273
end

src/stdlibs/Random.jl

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,15 @@ using ..Reactant:
88
Reactant,
99
TracedRArray,
1010
TracedRNumber,
11+
ConcreteRNG,
1112
TracedRNG,
1213
AnyTracedRArray,
1314
Reactant,
1415
TracedUtils,
1516
Ops,
16-
ConcreteRArray
17+
ConcreteRArray,
18+
ConcreteRNumber,
19+
unwrapped_eltype
1720
using Random: Random, AbstractRNG
1821

1922
@noinline function make_seed(rng::AbstractRNG=Random.RandomDevice())
@@ -25,44 +28,55 @@ using Random: Random, AbstractRNG
2528
return seed
2629
end
2730

28-
function Random.seed!(rng::TracedRNG, seed::Number)
31+
@noinline function Random.seed!(rng::TracedRNG, seed::Number)
2932
if seed isa TracedRNumber
3033
error("Passing in `TracedRNumber` as a seed is not supported. Please pass in a \
3134
`TracedRArray` of the appropriate size instead.")
3235
end
3336

3437
seed = reinterpret(UInt64, Random.hash_seed(seed))
35-
seed = if Reactant.within_reactant_interpreter()
36-
TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)])
37-
else
38-
ConcreteRArray(seed[1:length(rng.seed)])
39-
end
40-
return Random.seed!(rng, seed)
38+
return Random.seed!(
39+
rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed[1:length(rng.seed)])
40+
)
4141
end
4242

43-
function Random.seed!(rng::TracedRNG, seed::AbstractArray{<:Integer,1})
43+
@noinline function Random.seed!(rng::TracedRNG, seed::AbstractVector{<:Integer})
4444
return Random.seed!(rng, UInt64.(seed))
4545
end
4646

47-
function Random.seed!(rng::TracedRNG, seed::AbstractArray{UInt64,1})
47+
@noinline function Random.seed!(rng::TracedRNG, seed::AbstractVector{UInt64})
4848
return Random.seed!(rng, TracedUtils.promote_to(TracedRArray{UInt64,1}, seed))
4949
end
5050

51-
function Random.seed!(
52-
rng::TracedRNG, seed::Union{ConcreteRArray{UInt64,1},TracedRArray{UInt64,1}}
53-
)
51+
@noinline function Random.seed!(rng::TracedRNG, seed::TracedRArray{UInt64,1})
5452
rng.seed = seed
5553
return rng
5654
end
5755

58-
@noinline TracedRNG() = TracedRNG(ConcreteRArray(make_seed()))
59-
@noinline TracedRNG(seed::ConcreteRArray{UInt64,1}) = TracedRNG(seed, "DEFAULT")
56+
@noinline function Random.seed!(rng::ConcreteRNG, seed::Number)
57+
seed isa ConcreteRNumber && (seed = unwrapped_eltype(seed)(seed))
58+
seed = reinterpret(UInt64, Random.hash_seed(seed))
59+
return Random.seed!(rng, ConcreteRArray(seed))
60+
end
61+
62+
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractVector{<:Integer})
63+
return Random.seed!(rng, seed)
64+
end
6065

61-
@noinline function default_rng()
62-
Reactant.within_reactant_interpreter() || return TracedRNG()
63-
return TracedRNG(TracedUtils.promote_to(TracedRArray{UInt64,1}, make_seed()), "DEFAULT")
66+
@noinline function Random.seed!(rng::ConcreteRNG, seed::AbstractVector{UInt64})
67+
return Random.seed!(rng, ConcreteRArray(seed))
6468
end
6569

70+
@noinline function Random.seed!(rng::ConcreteRNG, seed::ConcreteRArray{UInt64,1})
71+
rng.seed = seed
72+
return rng
73+
end
74+
75+
@noinline ConcreteRNG() = ConcreteRNG(ConcreteRArray(make_seed()))
76+
@noinline ConcreteRNG(seed::ConcreteRArray{UInt64,1}) = ConcreteRNG(seed, "DEFAULT")
77+
78+
@noinline default_rng() = ConcreteRNG()
79+
6680
@noinline rng_algorithm(rng::TracedRNG) = rng.algorithm
6781
@noinline rng_algorithm(::AbstractRNG) = "DEFAULT"
6882

0 commit comments

Comments
 (0)