diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 0b20e6e7ce..96200dff99 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -402,7 +402,9 @@ def {careduce_fn_name}({input_name}):
     return careduce_fn
 
 
-def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
+def jit_compile_reducer(
+    node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds
+):
     """Compile Python source for reduction loops using additional optimizations.
 
     Parameters
@@ -411,6 +413,10 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
         A node from which the signature can be derived.
     fn
         The Python function object to compile.
+    reduce_to_scalar: bool, default False
+        Whether to reduce the output to a scalar (instead of a 0d array).
+    infer_signature: bool, default True
+        Whether to try to infer the function signature from the Apply node.
     kwds
         Extra keywords to be added to the :func:`numba.njit` function.
 
@@ -419,13 +425,17 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
         A :func:`numba.njit`-compiled function.
 
     """
-    signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
+    if infer_signature:
+        signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
+        args = (signature,)
+    else:
+        args = ()
 
     # Eagerly compile the function using increased optimizations. This should
     # help improve nested loop reductions.
     with use_optimized_cheap_pass():
         res = numba_basic.numba_njit(
-            signature,
+            *args,
             boundscheck=False,
             fastmath=config.numba__fastmath,
             **kwds,
@@ -926,11 +936,7 @@ def softmax_grad_py_fn(dy, sm):
         return dx
 
     # The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
-    # softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
-    softmax_grad = numba_njit(
-        boundscheck=False,
-        fastmath=config.numba__fastmath,
-    )(softmax_grad_py_fn)
+    softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False)
 
     return softmax_grad
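
The `if infer_signature` branch above passes the signature to `numba_njit` via argument unpacking, so the signature is supplied positionally only when one was actually inferred. A minimal sketch of that pattern, using a hypothetical `jit_maybe_eager` helper (only `numba.njit` is the real API here):

```python
from numba import njit


def jit_maybe_eager(fn, signature=None, **kwds):
    # Hypothetical helper mirroring the jit_compile_reducer pattern:
    # pass the signature positionally only when one is available, so a
    # single call site supports both eager and lazy compilation.
    args = (signature,) if signature is not None else ()
    return njit(*args, **kwds)(fn)


# With no signature this falls back to lazy dispatch, compiled on first call.
add = jit_maybe_eager(lambda x, y: x + y)
```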
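For context on the SoftmaxGrad hunk: when `dy` is a constant, Numba receives a readonly NumPy array, and an eagerly compiled signature derived from the node's types declares *writable* arrays, which readonly inputs do not match. Lazy dispatch avoids the problem because Numba specializes on the actual argument types, readonly flag included. A minimal sketch of the failure mode, assuming only that Numba is installed (the function and variable names are illustrative, not from the PR):

```python
import numpy as np
from numba import njit, types


def grad_like(dy, sm):
    # Reads its inputs but never writes to them.
    return dy * sm - np.sum(dy * sm) * sm


# Eager compilation: float64[:] declares a *writable* 1d array.
eager = njit(types.float64[:](types.float64[:], types.float64[:]))(grad_like)
# Lazy compilation: argument types are resolved at call time.
lazy = njit(grad_like)

dy = np.ones(3)
dy.flags.writeable = False  # mimic a constant (readonly) input
sm = np.full(3, 1.0 / 3.0)

lazy(dy, sm)  # works: a readonly-array overload is compiled on demand
try:
    eager(dy, sm)
except TypeError as exc:
    print(exc)  # no matching definition: readonly arrays don't satisfy float64[:]
```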