@@ -37,7 +37,7 @@ def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Ca
3737 return typing .cast (Callable , getattr (tfp_jax_math , jax_op_name ))
3838
3939
40- def check_if_inputs_scalars (node ):
40+ def all_inputs_are_scalar (node ):
4141 """Check whether all the inputs of an `Elemwise` are scalar values.
4242
4343 `jax.lax` or `jax.numpy` functions systematically return `TracedArrays`,
@@ -62,54 +62,68 @@ def check_if_inputs_scalars(node):
6262
6363@jax_funcify .register (ScalarOp )
6464def jax_funcify_ScalarOp (op , node , ** kwargs ):
65+ """Return JAX function that implements the same computation as the Scalar Op.
66+
67+ This dispatch is expected to return a JAX function that works on Array inputs as Elemwise does,
68+ even though it's dispatched on the Scalar Op.
69+ """
70+
6571 # We dispatch some PyTensor operators to Python operators
6672 # whenever the inputs are all scalars.
67- are_inputs_scalars = check_if_inputs_scalars (node )
68- if are_inputs_scalars :
69- elemwise = elemwise_scalar (op )
70- if elemwise is not None :
71- return elemwise
72- func_name = op .nfunc_spec [0 ]
73+ if all_inputs_are_scalar (node ):
74+ jax_func = jax_funcify_scalar_op_via_py_operators (op )
75+ if jax_func is not None :
76+ return jax_func
77+
78+ nfunc_spec = getattr (op , "nfunc_spec" , None )
79+ if nfunc_spec is None :
80+ raise NotImplementedError (f"Dispatch not implemented for Scalar Op { op } " )
81+
82+ func_name = nfunc_spec [0 ]
7383 if "." in func_name :
74- jnp_func = functools .reduce (getattr , [jax ] + func_name .split ("." ))
75- else :
76- jnp_func = getattr (jnp , func_name )
77-
78- if hasattr (op , "nfunc_variadic" ):
79- # These are special cases that handle invalid arities due to the broken
80- # PyTensor `Op` type contract (e.g. binary `Op`s that also function as
81- # their own variadic counterparts--even when those counterparts already
82- # exist as independent `Op`s).
83- jax_variadic_func = getattr (jnp , op .nfunc_variadic )
84-
85- def elemwise (* args ):
86- if len (args ) > op .nfunc_spec [1 ]:
87- return jax_variadic_func (
88- jnp .stack (jnp .broadcast_arrays (* args ), axis = 0 ), axis = 0
89- )
90- else :
91- return jnp_func (* args )
92-
93- return elemwise
84+ jax_func = functools .reduce (getattr , [jax ] + func_name .split ("." ))
9485 else :
95- return jnp_func
86+ jax_func = getattr (jnp , func_name )
87+
88+ if len (node .inputs ) > op .nfunc_spec [1 ]:
89+ # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
90+ # even though the base Op from `func_name` is specified as a binary Op.
91+ # This happens with `Add`, which can work as a `Sum` for multiple scalars.
92+ jax_variadic_func = getattr (jnp , op .nfunc_variadic , None )
93+ if not jax_variadic_func :
94+ raise NotImplementedError (
95+ f"Dispatch not implemented for Scalar Op { op } with { len (node .inputs )} inputs"
96+ )
97+
98+ def jax_func (* args ):
99+ return jax_variadic_func (
100+ jnp .stack (jnp .broadcast_arrays (* args ), axis = 0 ), axis = 0
101+ )
102+
103+ return jax_func
96104
97105
98106@functools .singledispatch
99- def elemwise_scalar (op ):
107+ def jax_funcify_scalar_op_via_py_operators (op ):
108+ """Specialized JAX dispatch for Elemwise operations where all inputs are Scalar arrays.
109+
110+ Scalar (constant) arrays in the JAX backend get lowered to the native types (int, floats),
111+ which can perform better with Python operators, and more importantly, avoid upcasting to array types
112+ not supported by some JAX functions.
113+ """
100114 return None
101115
102116
103- @elemwise_scalar .register (Add )
104- def elemwise_scalar_add (op ):
117+ @jax_funcify_scalar_op_via_py_operators .register (Add )
118+ def jax_funcify_scalar_Add (op ):
105119 def elemwise (* inputs ):
106120 return sum (inputs )
107121
108122 return elemwise
109123
110124
111- @elemwise_scalar .register (Mul )
112- def elemwise_scalar_mul (op ):
125+ @jax_funcify_scalar_op_via_py_operators .register (Mul )
126+ def jax_funcify_scalar_Mul (op ):
113127 import operator
114128 from functools import reduce
115129
@@ -119,24 +133,24 @@ def elemwise(*inputs):
119133 return elemwise
120134
121135
122- @elemwise_scalar .register (Sub )
123- def elemwise_scalar_sub (op ):
136+ @jax_funcify_scalar_op_via_py_operators .register (Sub )
137+ def jax_funcify_scalar_Sub (op ):
124138 def elemwise (x , y ):
125139 return x - y
126140
127141 return elemwise
128142
129143
130- @elemwise_scalar .register (IntDiv )
131- def elemwise_scalar_intdiv (op ):
144+ @jax_funcify_scalar_op_via_py_operators .register (IntDiv )
145+ def jax_funcify_scalar_IntDiv (op ):
132146 def elemwise (x , y ):
133147 return x // y
134148
135149 return elemwise
136150
137151
138- @elemwise_scalar .register (Mod )
139- def elemwise_scalar_mod (op ):
152+ @jax_funcify_scalar_op_via_py_operators .register (Mod )
153+ def jax_funcify_scalar_Mod (op ):
140154 def elemwise (x , y ):
141155 return x % y
142156
0 commit comments