diff --git a/pytensor/link/jax/dispatch/scalar.py b/pytensor/link/jax/dispatch/scalar.py index 6f63474de4..3c82e6d6f4 100644 --- a/pytensor/link/jax/dispatch/scalar.py +++ b/pytensor/link/jax/dispatch/scalar.py @@ -62,8 +62,6 @@ def check_if_inputs_scalars(node): @jax_funcify.register(ScalarOp) def jax_funcify_ScalarOp(op, node, **kwargs): - func_name = op.nfunc_spec[0] - # We dispatch some PyTensor operators to Python operators # whenever the inputs are all scalars. are_inputs_scalars = check_if_inputs_scalars(node) @@ -71,7 +69,7 @@ def jax_funcify_ScalarOp(op, node, **kwargs): elemwise = elemwise_scalar(op) if elemwise is not None: return elemwise - + func_name = op.nfunc_spec[0] if "." in func_name: jnp_func = functools.reduce(getattr, [jax] + func_name.split(".")) else: