diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py
index 3c82e6d6f4..8709633da1 100644
--- a/pytensor/link/jax/dispatch/scalar.py
+++ b/pytensor/link/jax/dispatch/scalar.py
@@ -20,7 +20,17 @@
     Second,
     Sub,
 )
-from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi
+from pytensor.scalar.math import (
+    Erf,
+    Erfc,
+    Erfcinv,
+    Erfcx,
+    Erfinv,
+    Iv,
+    Log1mexp,
+    Psi,
+    TriGamma,
+)


 def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
@@ -37,7 +47,7 @@ def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Ca
     return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))


-def check_if_inputs_scalars(node):
+def all_inputs_are_scalar(node):
     """Check whether all the inputs of an `Elemwise` are scalar values.

     `jax.lax` or `jax.numpy` functions systematically return `TracedArrays`,
@@ -62,54 +72,68 @@ def check_if_inputs_scalars(node):

 @jax_funcify.register(ScalarOp)
 def jax_funcify_ScalarOp(op, node, **kwargs):
+    """Return a JAX function that implements the same computation as the Scalar Op.
+
+    This dispatch is expected to return a JAX function that works on Array inputs as Elemwise does,
+    even though it's dispatched on the Scalar Op.
+    """
+
     # We dispatch some PyTensor operators to Python operators
     # whenever the inputs are all scalars.
-    are_inputs_scalars = check_if_inputs_scalars(node)
-    if are_inputs_scalars:
-        elemwise = elemwise_scalar(op)
-        if elemwise is not None:
-            return elemwise
-    func_name = op.nfunc_spec[0]
+    if all_inputs_are_scalar(node):
+        jax_func = jax_funcify_scalar_op_via_py_operators(op)
+        if jax_func is not None:
+            return jax_func
+
+    nfunc_spec = getattr(op, "nfunc_spec", None)
+    if nfunc_spec is None:
+        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
+
+    func_name = nfunc_spec[0]
     if "." in func_name:
-        jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
+        jax_func = functools.reduce(getattr, [jax] + func_name.split("."))
     else:
-        jnp_func = getattr(jnp, func_name)
-
-    if hasattr(op, "nfunc_variadic"):
-        # These are special cases that handle invalid arities due to the broken
-        # PyTensor `Op` type contract (e.g. binary `Op`s that also function as
-        # their own variadic counterparts--even when those counterparts already
-        # exist as independent `Op`s).
-        jax_variadic_func = getattr(jnp, op.nfunc_variadic)
-
-        def elemwise(*args):
-            if len(args) > op.nfunc_spec[1]:
-                return jax_variadic_func(
-                    jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
-                )
-            else:
-                return jnp_func(*args)
-
-        return elemwise
-    else:
-        return jnp_func
+        jax_func = getattr(jnp, func_name)
+
+    if len(node.inputs) > op.nfunc_spec[1]:
+        # Some Scalar Ops accept a variable number of inputs, behaving as variadic functions,
+        # even though the base Op from `func_name` is specified as a binary Op.
+        # This happens with `Add`, which can work as a `Sum` for multiple scalars.
+        jax_variadic_func = getattr(jnp, op.nfunc_variadic, None)
+        if not jax_variadic_func:
+            raise NotImplementedError(
+                f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
+            )
+
+        def jax_func(*args):
+            return jax_variadic_func(
+                jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
+            )
+
+    return jax_func


 @functools.singledispatch
-def elemwise_scalar(op):
+def jax_funcify_scalar_op_via_py_operators(op):
+    """Specialized JAX dispatch for Elemwise operations where all inputs are Scalar arrays.
+
+    Scalar (constant) arrays in the JAX backend get lowered to native types (int, float),
+    which can perform better with Python operators and, more importantly, avoid upcasting
+    to array types not supported by some JAX functions.
+    """
     return None


-@elemwise_scalar.register(Add)
-def elemwise_scalar_add(op):
+@jax_funcify_scalar_op_via_py_operators.register(Add)
+def jax_funcify_scalar_Add(op):
     def elemwise(*inputs):
         return sum(inputs)

     return elemwise


-@elemwise_scalar.register(Mul)
-def elemwise_scalar_mul(op):
+@jax_funcify_scalar_op_via_py_operators.register(Mul)
+def jax_funcify_scalar_Mul(op):
     import operator
     from functools import reduce

@@ -119,24 +143,24 @@ def elemwise(*inputs):
     return elemwise


-@elemwise_scalar.register(Sub)
-def elemwise_scalar_sub(op):
+@jax_funcify_scalar_op_via_py_operators.register(Sub)
+def jax_funcify_scalar_Sub(op):
     def elemwise(x, y):
         return x - y

     return elemwise


-@elemwise_scalar.register(IntDiv)
-def elemwise_scalar_intdiv(op):
+@jax_funcify_scalar_op_via_py_operators.register(IntDiv)
+def jax_funcify_scalar_IntDiv(op):
     def elemwise(x, y):
         return x // y

     return elemwise


-@elemwise_scalar.register(Mod)
-def elemwise_scalar_mod(op):
+@jax_funcify_scalar_op_via_py_operators.register(Mod)
+def jax_funcify_scalar_Mod(op):
     def elemwise(x, y):
         return x % y

@@ -261,6 +285,14 @@ def psi(x):
     return psi


+@jax_funcify.register(TriGamma)
+def jax_funcify_TriGamma(op, node, **kwargs):
+    def tri_gamma(x):
+        return jax.scipy.special.polygamma(1, x)
+
+    return tri_gamma
+
+
 @jax_funcify.register(Softplus)
 def jax_funcify_Softplus(op, **kwargs):
     def softplus(x):
diff --git a/tests/link/jax/test_scalar.py b/tests/link/jax/test_scalar.py
index 9691dd535c..837c07085f 100644
--- a/tests/link/jax/test_scalar.py
+++ b/tests/link/jax/test_scalar.py
@@ -23,6 +23,7 @@
     psi,
     sigmoid,
     softplus,
+    tri_gamma,
 )
 from pytensor.tensor.type import matrix, scalar, vector
 from tests.link.jax.test_basic import compare_jax_and_py
@@ -170,6 +171,13 @@ def test_psi():
     compare_jax_and_py(fg, [3.0])


+def test_tri_gamma():
+    x = vector("x", dtype="float64")
+    out = tri_gamma(x)
+    fg = FunctionGraph([x], [out])
+    compare_jax_and_py(fg, [np.array([3.0, 5.0])])
+
+
 def test_log1mexp():
     x = vector("x")
     out = log1mexp(x)
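
Note (not part of the patch): a minimal usage sketch of the new TriGamma dispatch through the JAX backend. It assumes JAX is installed and that `tri_gamma` is the tensor-level helper imported in the test file above (taken here from `pytensor.tensor.math`; the import path is an assumption):

    import numpy as np
    import pytensor
    from pytensor.tensor.math import tri_gamma
    from pytensor.tensor.type import vector

    x = vector("x", dtype="float64")
    # Compiling with mode="JAX" routes TriGamma through the
    # jax.scipy.special.polygamma(1, x) dispatch added above.
    fn = pytensor.function([x], tri_gamma(x), mode="JAX")
    print(fn(np.array([3.0, 5.0])))  # matches scipy.special.polygamma(1, [3.0, 5.0])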