Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 32 additions & 20 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,9 @@ def jax_sample_fn_generic(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
Expand All @@ -151,9 +152,10 @@ def jax_sample_fn_loc_scale(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters
sample = loc + jax_op(rng_key, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
sample = loc + jax_op(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
Expand All @@ -168,8 +170,9 @@ def jax_sample_fn_no_dtype(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
sample = jax_op(rng_key, *parameters, shape=size)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax_op(sampling_key, *parameters, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
Expand All @@ -189,9 +192,12 @@ def jax_sample_fn_uniform(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
minval, maxval = parameters
sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
sample = jax_op(
sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
)
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
Expand All @@ -211,9 +217,10 @@ def jax_sample_fn_shape_rate(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(shape, rate) = parameters
sample = jax_op(rng_key, shape, size, dtype) / rate
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
sample = jax_op(sampling_key, shape, size, dtype) / rate
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
Expand All @@ -225,9 +232,10 @@ def jax_sample_fn_exponential(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(scale,) = parameters
sample = jax.random.exponential(rng_key, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
sample = jax.random.exponential(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
Expand All @@ -239,13 +247,14 @@ def jax_sample_fn_t(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(
df,
loc,
scale,
) = parameters
sample = loc + jax.random.t(rng_key, df, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
Expand All @@ -257,9 +266,10 @@ def jax_funcify_choice(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(a, p, replace) = parameters
smpl_value = jax.random.choice(rng_key, a, size, replace, p)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
rng["jax_state"] = rng_key
return (rng, smpl_value)

return sample_fn
Expand All @@ -271,9 +281,10 @@ def jax_sample_fn_permutation(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(x,) = parameters
sample = jax.random.permutation(rng_key, x)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
sample = jax.random.permutation(sampling_key, x)
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn
Expand All @@ -285,10 +296,11 @@ def jax_sample_fn_lognormal(op):

def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters
sample = loc + jax.random.normal(rng_key, size, dtype) * scale
sample = loc + jax.random.normal(sampling_key, size, dtype) * scale
sample_exp = jax.numpy.exp(sample)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0]
rng["jax_state"] = rng_key
return (rng, sample_exp)

return sample_fn
2 changes: 1 addition & 1 deletion tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ def test_random_updates(rng_ctor):
[
set_test_value(
at.dvector(),
np.array([1000.0, 2000.0], dtype=np.float64),
np.array([100000.0, 200000.0], dtype=np.float64),
),
],
(2,),
Expand Down