112 changes: 72 additions & 40 deletions pytensor/link/jax/dispatch/scalar.py
@@ -20,7 +20,17 @@
Second,
Sub,
)
from pytensor.scalar.math import Erf, Erfc, Erfcinv, Erfcx, Erfinv, Iv, Log1mexp, Psi
from pytensor.scalar.math import (
Erf,
Erfc,
Erfcinv,
Erfcx,
Erfinv,
Iv,
Log1mexp,
Psi,
TriGamma,
)


def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
@@ -37,7 +47,7 @@ def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Callable:
return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))


def check_if_inputs_scalars(node):
def all_inputs_are_scalar(node):
"""Check whether all the inputs of an `Elemwise` are scalar values.

`jax.lax` or `jax.numpy` functions systematically return `TracedArrays`,
@@ -62,54 +72,68 @@ def check_if_inputs_scalars(node):

@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op, node, **kwargs):
"""Return JAX function that implements the same computation as the Scalar Op.

This dispatch is expected to return a JAX function that works on Array inputs as Elemwise does,
even though it's dispatched on the Scalar Op.
"""

# We dispatch some PyTensor operators to Python operators
# whenever the inputs are all scalars.
are_inputs_scalars = check_if_inputs_scalars(node)
if are_inputs_scalars:
elemwise = elemwise_scalar(op)
if elemwise is not None:
return elemwise
func_name = op.nfunc_spec[0]
if all_inputs_are_scalar(node):
jax_func = jax_funcify_scalar_op_via_py_operators(op)
if jax_func is not None:
return jax_func

nfunc_spec = getattr(op, "nfunc_spec", None)
if nfunc_spec is None:
raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")

func_name = nfunc_spec[0]
if "." in func_name:
jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
jax_func = functools.reduce(getattr, [jax] + func_name.split("."))
else:
jnp_func = getattr(jnp, func_name)

if hasattr(op, "nfunc_variadic"):
# These are special cases that handle invalid arities due to the broken
# PyTensor `Op` type contract (e.g. binary `Op`s that also function as
# their own variadic counterparts--even when those counterparts already
# exist as independent `Op`s).
jax_variadic_func = getattr(jnp, op.nfunc_variadic)

def elemwise(*args):
if len(args) > op.nfunc_spec[1]:
return jax_variadic_func(
jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
)
else:
return jnp_func(*args)

return elemwise
else:
return jnp_func
jax_func = getattr(jnp, func_name)

if len(node.inputs) > op.nfunc_spec[1]:
        # Some Scalar Ops accept a variable number of inputs, behaving as variadic functions,
# even though the base Op from `func_name` is specified as a binary Op.
# This happens with `Add`, which can work as a `Sum` for multiple scalars.
jax_variadic_func = getattr(jnp, op.nfunc_variadic, None)
if not jax_variadic_func:
raise NotImplementedError(
f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
)

def jax_func(*args):
return jax_variadic_func(
jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
)

return jax_func
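
For readers unfamiliar with the convention used above: `nfunc_spec` is a tuple like `("add", 2, 1)` naming the NumPy/JAX function, its arity, and its number of outputs, and `nfunc_variadic` names a reduction (e.g. `"sum"`) used when the Op receives more inputs than that arity. A standalone sketch of what the variadic branch computes (the names below are illustrative, not part of the PR):

```python
# Standalone sketch (not from the PR) of the variadic lowering above:
# broadcast all operands, stack them on a new leading axis, and reduce
# with the variadic function (here jnp.sum, the counterpart of a variadic Add).
import jax.numpy as jnp
import numpy as np


def variadic_add(*args):
    return jnp.sum(jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0)


x = jnp.ones(3)
y = jnp.arange(3.0)
z = 2.0  # scalars broadcast like everything else
np.testing.assert_allclose(variadic_add(x, y, z), x + y + z)
```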


@functools.singledispatch
def elemwise_scalar(op):
def jax_funcify_scalar_op_via_py_operators(op):
"""Specialized JAX dispatch for Elemwise operations where all inputs are Scalar arrays.

    Scalar (constant) arrays in the JAX backend get lowered to native Python types (int, float),
    which can perform better with Python operators and, more importantly, avoid upcasting to
    array dtypes that some JAX functions do not support.
"""
return None
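
For context, the `functools.singledispatch` pattern used here works as sketched below: the unregistered base case returns `None`, meaning "no Python-operator shortcut exists for this Op", and per-Op registrations return a plain-Python callable so scalar constants never have to be promoted to JAX arrays. `FakeAdd` and `py_operator_impl` are hypothetical stand-ins, not PyTensor names.

```python
# Self-contained illustration of the single-dispatch pattern; not PR code.
import functools


class FakeAdd:  # hypothetical stand-in for a scalar Add Op
    pass


@functools.singledispatch
def py_operator_impl(op):
    # Base case: no specialized Python-operator implementation.
    return None


@py_operator_impl.register(FakeAdd)
def _(op):
    def elemwise(*inputs):
        return sum(inputs)  # plain Python arithmetic, no jnp arrays created

    return elemwise


impl = py_operator_impl(FakeAdd())
assert impl(1, 2, 3) == 6                  # result stays a native Python int
assert py_operator_impl(object()) is None  # unregistered Ops fall through
```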


@elemwise_scalar.register(Add)
def elemwise_scalar_add(op):
@jax_funcify_scalar_op_via_py_operators.register(Add)
def jax_funcify_scalar_Add(op):
def elemwise(*inputs):
return sum(inputs)

return elemwise


@elemwise_scalar.register(Mul)
def elemwise_scalar_mul(op):
@jax_funcify_scalar_op_via_py_operators.register(Mul)
def jax_funcify_scalar_Mul(op):
import operator
from functools import reduce

@@ -119,24 +143,24 @@ def elemwise(*inputs):
return elemwise


@elemwise_scalar.register(Sub)
def elemwise_scalar_sub(op):
@jax_funcify_scalar_op_via_py_operators.register(Sub)
def jax_funcify_scalar_Sub(op):
def elemwise(x, y):
return x - y

return elemwise


@elemwise_scalar.register(IntDiv)
def elemwise_scalar_intdiv(op):
@jax_funcify_scalar_op_via_py_operators.register(IntDiv)
def jax_funcify_scalar_IntDiv(op):
def elemwise(x, y):
return x // y

return elemwise


@elemwise_scalar.register(Mod)
def elemwise_scalar_mod(op):
@jax_funcify_scalar_op_via_py_operators.register(Mod)
def jax_funcify_scalar_Mod(op):
def elemwise(x, y):
return x % y

@@ -261,6 +285,14 @@ def psi(x):
return psi


@jax_funcify.register(TriGamma)
def jax_funcify_TriGamma(op, node, **kwargs):
def tri_gamma(x):
return jax.scipy.special.polygamma(1, x)

return tri_gamma
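
A quick sanity check, outside the test suite, of the identity this dispatch relies on: the trigamma function is the first polygamma function, so `jax.scipy.special.polygamma(1, x)` should agree with SciPy's reference implementation. A standalone sketch, not part of the PR:

```python
import numpy as np
import scipy.special
import jax
import jax.scipy.special

jax.config.update("jax_enable_x64", True)  # compare at double precision

x = np.array([0.5, 3.0, 5.0])
np.testing.assert_allclose(
    jax.scipy.special.polygamma(1, x),
    scipy.special.polygamma(1, x),
    rtol=1e-6,
)
```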


@jax_funcify.register(Softplus)
def jax_funcify_Softplus(op, **kwargs):
def softplus(x):
8 changes: 8 additions & 0 deletions tests/link/jax/test_scalar.py
@@ -23,6 +23,7 @@
psi,
sigmoid,
softplus,
tri_gamma,
)
from pytensor.tensor.type import matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
@@ -170,6 +171,13 @@ def test_psi():
compare_jax_and_py(fg, [3.0])


def test_tri_gamma():
x = vector("x", dtype="float64")
out = tri_gamma(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [np.array([3.0, 5.0])])


def test_log1mexp():
x = vector("x")
out = log1mexp(x)