From ea65855ceae6c06e14e29c7e1e02db4c5f9e9a47 Mon Sep 17 00:00:00 2001
From: Ricardo Vieira
Date: Mon, 11 Sep 2023 18:21:17 +0200
Subject: [PATCH] Fix SoftmaxGrad failure with constant dy in numba backend

---
 pytensor/link/numba/dispatch/elemwise.py | 7 ++++++-
 tests/link/numba/test_elemwise.py        | 10 ++++++++++
 2 files changed, 16 insertions(+), 1 deletion(-)

diff --git a/pytensor/link/numba/dispatch/elemwise.py b/pytensor/link/numba/dispatch/elemwise.py
index 49f7d6e8a2..0b20e6e7ce 100644
--- a/pytensor/link/numba/dispatch/elemwise.py
+++ b/pytensor/link/numba/dispatch/elemwise.py
@@ -925,7 +925,12 @@ def softmax_grad_py_fn(dy, sm):
         dx = dy_times_sm - sum_dy_times_sm * sm
         return dx
 
-    softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
+    # The signature inferred by jit_compile_reducer is wrong when dy is a constant (readonly=True)
+    # softmax_grad = jit_compile_reducer(node, softmax_grad_py_fn)
+    softmax_grad = numba_njit(
+        boundscheck=False,
+        fastmath=config.numba__fastmath,
+    )(softmax_grad_py_fn)
 
     return softmax_grad
 
diff --git a/tests/link/numba/test_elemwise.py b/tests/link/numba/test_elemwise.py
index ffc7967c83..35b0ae9857 100644
--- a/tests/link/numba/test_elemwise.py
+++ b/tests/link/numba/test_elemwise.py
@@ -445,6 +445,16 @@ def test_SoftmaxGrad(dy, sm, axis, exc):
     )
 
 
+def test_SoftmaxGrad_constant_dy():
+    dy = at.constant(np.zeros((3,), dtype=config.floatX))
+    sm = at.vector(shape=(3,))
+
+    g = SoftmaxGrad(axis=None)(dy, sm)
+    g_fg = FunctionGraph(outputs=[g])
+
+    compare_numba_and_py(g_fg, [np.ones((3,), dtype=config.floatX)])
+
+
 @pytest.mark.parametrize(
     "x, axis, exc",
     [
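
For context, a minimal standalone sketch of the Numba behavior the fix works around
(illustrative only, not part of the patch; it assumes only numba and numpy, and the
function names are hypothetical):

    # Sketch: an njit function compiled with an explicit, writable-array signature
    # rejects readonly inputs, while an njit function without a signature specializes
    # lazily and accepts them. PyTensor passes graph constants as readonly arrays,
    # which is how the signature built by jit_compile_reducer ends up mismatched.
    import numba
    import numpy as np

    @numba.njit("float64[:](float64[:], float64[:])")  # eager, writable signature
    def grad_fixed_sig(dy, sm):
        dy_times_sm = dy * sm
        return dy_times_sm - np.sum(dy_times_sm) * sm

    @numba.njit  # no signature: types are inferred at the first call
    def grad_lazy(dy, sm):
        dy_times_sm = dy * sm
        return dy_times_sm - np.sum(dy_times_sm) * sm

    dy = np.zeros(3)
    dy.setflags(write=False)  # mimic a constant input: a readonly array
    sm = np.ones(3)

    grad_lazy(dy, sm)         # fine: a readonly specialization is compiled
    # grad_fixed_sig(dy, sm)  # TypeError: no matching definition, because the
    #                         # readonly array does not satisfy the writable signature

Hence the patch drops the precomputed signature and calls numba_njit directly, letting
Numba infer the argument types when the compiled function is first invoked.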