diff --git a/pytensor/link/jax/dispatch/random.py b/pytensor/link/jax/dispatch/random.py
index f3cde4f943..824d728faf 100644
--- a/pytensor/link/jax/dispatch/random.py
+++ b/pytensor/link/jax/dispatch/random.py
@@ -8,6 +8,7 @@
 )
 
 import pytensor.tensor.random.basic as ptr
+from pytensor.graph import Constant
 from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify
 from pytensor.link.jax.dispatch.shape import JAXShapeTuple
 from pytensor.tensor.shape import Shape, Shape_i
@@ -91,15 +92,26 @@
 def jax_funcify_RandomVariable(op, node, **kwargs):
     """JAX implementation of random variables."""
     rv = node.outputs[1]
     out_dtype = rv.type.dtype
-    out_size = rv.type.shape
+    static_shape = rv.type.shape
     batch_ndim = op.batch_ndim(node)
-    out_size = node.default_output().type.shape[:batch_ndim]
+
+    # Try to pass static size directly to JAX
+    static_size = static_shape[:batch_ndim]
+    if None in static_size:
+        # Sometimes size can be constant folded during rewrites,
+        # without the RandomVariable node being updated with the new static types
+        size_param = node.inputs[1]
+        if isinstance(size_param, Constant):
+            size_tuple = tuple(size_param.data)
+            # PyTensor uses an empty size to represent size = None
+            if len(size_tuple):
+                static_size = size_tuple
 
     # If one dimension has unknown size, either the size is determined
     # by a `Shape` operator in which case JAX will compile, or it is
     # not and we fail gracefully.
-    if None in out_size:
+    if None in static_size:
         assert_size_argument_jax_compatible(node)
 
         def sample_fn(rng, size, dtype, *parameters):
@@ -111,7 +123,9 @@ def sample_fn(rng, size, dtype, *parameters):
     else:
 
         def sample_fn(rng, size, dtype, *parameters):
-            return jax_sample_fn(op, node=node)(rng, out_size, out_dtype, *parameters)
+            return jax_sample_fn(op, node=node)(
+                rng, static_size, out_dtype, *parameters
+            )
 
     return sample_fn
 
diff --git a/tests/link/jax/test_random.py b/tests/link/jax/test_random.py
index 22c3af7d83..c4d0dacdf7 100644
--- a/tests/link/jax/test_random.py
+++ b/tests/link/jax/test_random.py
@@ -5,6 +5,7 @@
 import pytensor
 import pytensor.tensor as pt
 import pytensor.tensor.random.basic as ptr
+from pytensor import clone_replace
 from pytensor.compile.function import function
 from pytensor.compile.sharedvalue import SharedVariable, shared
 from pytensor.graph.basic import Constant
@@ -26,11 +27,11 @@
 from pytensor.link.jax.dispatch.random import numpyro_available  # noqa: E402
 
 
-def compile_random_function(*args, **kwargs):
+def compile_random_function(*args, mode="JAX", **kwargs):
     with pytest.warns(
         UserWarning, match=r"The RandomType SharedVariables \[.+\] will not be used"
     ):
-        return function(*args, **kwargs)
+        return function(*args, mode=mode, **kwargs)
 
 
 def test_random_RandomStream():
@@ -896,3 +897,24 @@ def test_random_concrete_shape_graph_input():
     out = pt.random.normal(0, 1, size=size_pt, rng=rng)
     jax_fn = compile_random_function([size_pt], out, mode=jax_mode)
     assert jax_fn(10).shape == (10,)
+
+
+def test_constant_shape_after_graph_rewriting():
+    size = pt.vector("size", shape=(2,), dtype=int)
+    x = pt.random.normal(size=size)
+    assert x.type.shape == (None, None)
+
+    with pytest.raises(TypeError):
+        compile_random_function([size], x)([2, 5])
+
+    # Rebuild with rebuild_strict=True, so the output type is not updated
+    # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
+    new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
+    assert new_x.type.shape == (None, None)
+    assert compile_random_function([], new_x)().shape == (2, 5)
+
+    # Rebuild with rebuild_strict=False, so the output type is updated
+    # This uses a different path in the dispatch implementation
+    new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
+    assert new_x.type.shape == (2, 5)
+    assert compile_random_function([], new_x)().shape == (2, 5)
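For reference, the scenario the new test pins down can be reproduced outside pytest. The following is a minimal sketch derived from the test above, not part of the patch; it assumes JAX is installed, and compiling with `mode="JAX"` will emit the `RandomType SharedVariables` UserWarning that the test helper filters with `pytest.warns`:

```python
import pytensor
import pytensor.tensor as pt
from pytensor import clone_replace

# The RV is built with a symbolic size, so its static output shape is unknown.
size = pt.vector("size", shape=(2,), dtype=int)
x = pt.random.normal(size=size)
assert x.type.shape == (None, None)

# Swap in a constant size without rebuilding the node (rebuild_strict=True),
# mimicking a rewrite that constant-folds size but leaves the stale output type.
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
assert new_x.type.shape == (None, None)

# With the patched dispatcher, the JAX backend recovers the concrete size from
# the Constant size input instead of failing on the unknown static shape.
fn = pytensor.function([], new_x, mode="JAX")
print(fn().shape)  # (2, 5)
```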