Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions pytensor/link/jax/dispatch/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
)

import pytensor.tensor.random.basic as ptr
from pytensor.graph import Constant
from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
from pytensor.link.jax.dispatch.shape import JAXShapeTuple
from pytensor.tensor.shape import Shape, Shape_i
Expand Down Expand Up @@ -91,15 +92,26 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
"""JAX implementation of random variables."""
rv = node.outputs[1]
out_dtype = rv.type.dtype
out_size = rv.type.shape
static_shape = rv.type.shape

batch_ndim = op.batch_ndim(node)
out_size = node.default_output().type.shape[:batch_ndim]

# Try to pass static size directly to JAX
static_size = static_shape[:batch_ndim]
if None in static_size:
# Sometimes size can be constant folded during rewrites,
# without the RandomVariable node being updated with new static types
size_param = node.inputs[1]
if isinstance(size_param, Constant):
size_tuple = tuple(size_param.data)
# PyTensor uses empty size to represent size = None
if len(size_tuple):
static_size = tuple(size_param.data)

# If one dimension has unknown size, either the size is determined
# by a `Shape` operator in which case JAX will compile, or it is
# not and we fail gracefully.
if None in out_size:
if None in static_size:
assert_size_argument_jax_compatible(node)

def sample_fn(rng, size, dtype, *parameters):
Expand All @@ -111,7 +123,9 @@ def sample_fn(rng, size, dtype, *parameters):
else:

def sample_fn(rng, size, dtype, *parameters):
return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)
return jax_sample_fn(op, node=node)(
rng, static_size, out_dtype, *parameters
)

return sample_fn

Expand Down
26 changes: 24 additions & 2 deletions tests/link/jax/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytensor
import pytensor.tensor as pt
import pytensor.tensor.random.basic as ptr
from pytensor import clone_replace
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.graph.basic import Constant
Expand All @@ -26,11 +27,11 @@
from pytensor.link.jax.dispatch.random import numpyro_available # noqa: E402


def compile_random_function(*args, **kwargs):
def compile_random_function(*args, mode="JAX", **kwargs):
with pytest.warns(
UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
):
return function(*args, **kwargs)
return function(*args, mode=mode, **kwargs)


def test_random_RandomStream():
Expand Down Expand Up @@ -896,3 +897,24 @@ def test_random_concrete_shape_graph_input():
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
assert jax_fn(10).shape == (10,)


def test_constant_shape_after_graph_rewriting():
size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)

with pytest.raises(TypeError):
compile_random_function([size], x)([2, 5])

# Rebuild with strict=False so output type is not updated
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None)
assert compile_random_function([], new_x)().shape == (2, 5)

# Rebuild with strict=True, so output type is updated
# This uses a different path in the dispatch implementation
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
assert new_x.type.shape == (2, 5)
assert compile_random_function([], new_x)().shape == (2, 5)