Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions pytensor/link/numba/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,9 @@ def {careduce_fn_name}({input_name}):
return careduce_fn


def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
def jit_compile_reducer(
node, fn, *, reduce_to_scalar=False, infer_signature=True, **kwds
):
"""Compile Python source for reduction loops using additional optimizations.

Parameters
Expand All @@ -411,6 +413,10 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
An node from which the signature can be derived.
fn
The Python function object to compile.
reduce_to_scalar: bool, default False
Whether to reduce output to a scalar (instead of 0d array)
infer_signature: bool: default True
Whether to try and infer the function signature from the Apply node.
kwds
Extra keywords to be added to the :func:`numba.njit` function.

Expand All @@ -419,13 +425,17 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
A :func:`numba.njit`-compiled function.

"""
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
if infer_signature:
signature = create_numba_signature(node, reduce_to_scalar=reduce_to_scalar)
args = (signature,)
else:
args = ()

# Eagerly compile the function using increased optimizations. This should
# help improve nested loop reductions.
with use_optimized_cheap_pass():
res = numba_basic.numba_njit(
signature,
*args,
boundscheck=False,
fastmath=config.numba__fastmath,
**kwds,
Expand Down Expand Up @@ -926,11 +936,7 @@ def softmax_grad_py_fn(dy, sm):
return dx

# The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
# softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
softmax_grad = numba_njit(
boundscheck=False,
fastmath=config.numba__fastmath,
)(softmax_grad_py_fn)
softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn, infer_signature=False)

return softmax_grad

Expand Down