Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions pytensor/compile/mode.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,13 +452,11 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
JAXLinker(),
RewriteDatabaseQuery(
include=["fast_run", "jax"],
# TODO: "local_uint_constant_indices" can be reintroduced once https://github.com/google/jax/issues/16836 is fixed.
exclude=[
"cxx_only",
"BlasOpt",
"fusion",
"inplace",
"local_uint_constant_indices",
],
),
)
Expand Down
48 changes: 34 additions & 14 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
assert_size_argument_jax_compatible(node)

def sample_fn(rng, size, dtype, *parameters):
# PyTensor uses empty size to represent size = None
if jax.numpy.asarray(size).shape == (0,):
size = None
return jax_sample_fn(op)(rng, size, out_dtype, *parameters)

else:
Expand Down Expand Up @@ -161,6 +164,8 @@ def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters
if size is None:
size = jax.numpy.broadcast_arrays(loc, scale)[0].shape
sample = loc + jax_op(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
Expand All @@ -169,16 +174,31 @@ def sample_fn(rng, size, dtype, *parameters):


@jax_sample_fn.register(ptr.BernoulliRV)
def jax_sample_fn_bernoulli(op):
"""JAX implementation of `BernoulliRV`."""

# We need a separate dispatch, because there is no dtype argument for Bernoulli in JAX
def sample_fn(rng, size, dtype, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax.random.bernoulli(sampling_key, p, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)

return sample_fn


@jax_sample_fn.register(ptr.CategoricalRV)
def jax_sample_fn_no_dtype(op):
"""Generic JAX implementation of random variables."""
name = op.name
jax_op = getattr(jax.random, name)
def jax_sample_fn_categorical(op):
"""JAX implementation of `CategoricalRV`."""

def sample_fn(rng, size, dtype, *parameters):
# We need a separate dispatch because Categorical expects logits in JAX
def sample_fn(rng, size, dtype, p):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
sample = jax_op(sampling_key, *parameters, shape=size)

logits = jax.scipy.special.logit(p)
sample = jax.random.categorical(sampling_key, logits=logits, shape=size)
rng["jax_state"] = rng_key
return (rng, sample)

Expand Down Expand Up @@ -229,6 +249,8 @@ def jax_sample_fn_shape_scale(op):
def sample_fn(rng, size, dtype, shape, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
if size is None:
size = jax.numpy.broadcast_arrays(shape, scale)[0].shape
sample = jax_op(sampling_key, shape, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
Expand All @@ -240,10 +262,11 @@ def sample_fn(rng, size, dtype, shape, scale):
def jax_sample_fn_exponential(op):
"""JAX implementation of `ExponentialRV`."""

def sample_fn(rng, size, dtype, *parameters):
def sample_fn(rng, size, dtype, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(scale,) = parameters
if size is None:
size = jax.numpy.asarray(scale).shape
sample = jax.random.exponential(sampling_key, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
Expand All @@ -255,14 +278,11 @@ def sample_fn(rng, size, dtype, *parameters):
def jax_sample_fn_t(op):
"""JAX implementation of `StudentTRV`."""

def sample_fn(rng, size, dtype, *parameters):
def sample_fn(rng, size, dtype, df, loc, scale):
rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(
df,
loc,
scale,
) = parameters
if size is None:
size = jax.numpy.broadcast_arrays(df, loc, scale)[0].shape
sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
rng["jax_state"] = rng_key
return (rng, sample)
Expand Down
78 changes: 68 additions & 10 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,34 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
assert test_res.pvalue > 0.01


@pytest.mark.parametrize(
"rv_fn",
[
lambda param_that_implies_size: ptr.normal(
loc=0, scale=pt.exp(param_that_implies_size)
),
lambda param_that_implies_size: ptr.exponential(
scale=pt.exp(param_that_implies_size)
),
lambda param_that_implies_size: ptr.gamma(
shape=1, scale=pt.exp(param_that_implies_size)
),
lambda param_that_implies_size: ptr.t(
df=3, loc=param_that_implies_size, scale=1
),
],
)
def test_size_implied_by_broadcasted_parameters(rv_fn):
# We need a parameter with untyped shapes to test broadcasting does not result in identical draws
param_that_implies_size = pt.matrix("param_that_implies_size", shape=(None, None))

rv = rv_fn(param_that_implies_size)
draws = rv.eval({param_that_implies_size: np.zeros((2, 2))}, mode=jax_mode)

assert draws.shape == (2, 2)
assert np.unique(draws).size == 4


@pytest.mark.parametrize("size", [(), (4,)])
def test_random_bernoulli(size):
rng = shared(np.random.RandomState(123))
Expand Down Expand Up @@ -545,36 +573,66 @@ def test_random_dirichlet(parameter, size):


def test_random_choice():
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was working fine, but not all branches were being tested so I added them

# Elements are picked at equal frequency
num_samples = 10000
# `replace=True` and `p is None`
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(4), size=num_samples, rng=rng)
g = pt.random.choice(np.arange(4), size=10_000, rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert samples.shape == (10_000,)
# Elements are picked at equal frequency
np.testing.assert_allclose(np.mean(samples == 3), 0.25, 2)

# `replace=True` and `p is not None`
rng = shared(np.random.default_rng(123))
g = pt.random.choice(4, p=np.array([0.0, 0.5, 0.0, 0.5]), size=(5, 2), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(np.sum(samples == 3) / num_samples, 0.25, 2)
assert samples.shape == (5, 2)
# Only odd numbers are picked
assert np.all(samples % 2 == 1)

# `replace=False` produces unique results
# `replace=False` and `p is None`
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(100), replace=False, size=99, rng=rng)
g = pt.random.choice(np.arange(100), replace=False, size=(2, 49), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert len(np.unique(samples)) == 99
assert samples.shape == (2, 49)
# Elements are unique
assert len(np.unique(samples)) == 98

# We can pass an array with probabilities
# `replace=False` and `p is not None`
rng = shared(np.random.RandomState(123))
g = pt.random.choice(np.arange(3), p=np.array([1.0, 0.0, 0.0]), size=10, rng=rng)
g = pt.random.choice(
8,
p=np.array([0.25, 0, 0.25, 0, 0.25, 0, 0.25, 0]),
size=3,
rng=rng,
replace=False,
)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
np.testing.assert_allclose(samples, np.zeros(10))
assert samples.shape == (3,)
# Elements are unique
assert len(np.unique(samples)) == 3
# Only even numbers are picked
assert np.all(samples % 2 == 0)


def test_random_categorical():
rng = shared(np.random.RandomState(123))
g = pt.random.categorical(0.25 * np.ones(4), size=(10000, 4), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert samples.shape == (10000, 4)
np.testing.assert_allclose(samples.mean(axis=0), 6 / 4, 1)

# Test zero probabilities
g = pt.random.categorical([0, 0.5, 0, 0.5], size=(1000,), rng=rng)
g_fn = compile_random_function([], g, mode=jax_mode)
samples = g_fn()
assert samples.shape == (1000,)
assert np.all(samples % 2 == 1)


def test_random_permutation():
array = np.arange(4)
Expand Down