@@ -402,7 +402,9 @@ def {careduce_fn_name}({input_name}):
402402 return careduce_fn
403403
404404
405- def jit_compile_reducer (node , fn , * , reduce_to_scalar = False , ** kwds ):
405+ def jit_compile_reducer (
406+ node , fn , * , reduce_to_scalar = False , infer_signature = True , ** kwds
407+ ):
406408 """Compile Python source for reduction loops using additional optimizations.
407409
408410 Parameters
@@ -411,6 +413,10 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
411413 An node from which the signature can be derived.
412414 fn
413415 The Python function object to compile.
416+ reduce_to_scalar: bool, default False
417+ Whether to reduce output to a scalar (instead of 0d array)
418+ infer_signature: bool: default True
419+ Whether to try and infer the function signature from the Apply node.
414420 kwds
415421 Extra keywords to be added to the :func:`numba.njit` function.
416422
@@ -419,13 +425,17 @@ def jit_compile_reducer(node, fn, *, reduce_to_scalar=False, **kwds):
419425 A :func:`numba.njit`-compiled function.
420426
421427 """
422- signature = create_numba_signature (node , reduce_to_scalar = reduce_to_scalar )
428+ if infer_signature :
429+ signature = create_numba_signature (node , reduce_to_scalar = reduce_to_scalar )
430+ args = (signature ,)
431+ else :
432+ args = ()
423433
424434 # Eagerly compile the function using increased optimizations. This should
425435 # help improve nested loop reductions.
426436 with use_optimized_cheap_pass ():
427437 res = numba_basic .numba_njit (
428- signature ,
438+ * args ,
429439 boundscheck = False ,
430440 fastmath = config .numba__fastmath ,
431441 ** kwds ,
@@ -926,11 +936,7 @@ def softmax_grad_py_fn(dy, sm):
926936 return dx
927937
928938 # The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
929- # softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
930- softmax_grad = numba_njit (
931- boundscheck = False ,
932- fastmath = config .numba__fastmath ,
933- )(softmax_grad_py_fn )
939+ softmax_grad = jit_compile_reducer (node , softmax_grad_py_fn , infer_signature = False )
934940
935941 return softmax_grad
936942
0 commit comments