Numba SoftmaxGrad fails when input has constant shape #434

@ricardoV94

Description

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x", shape=(2, 5))  # fine if shape is (None, None)
y = pt.special.softmax(x, axis=-1)
y_grad = pt.grad(y.sum(), wrt=x)

fn = pytensor.function([x], y_grad, mode="NUMBA")
fn(np.ones((2, 5)))
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
Invalid use of type(CPUDispatcher(<function numba_funcify_SoftmaxGrad.<locals>.softmax_grad_py_fn at 0x7f9766c2e0c0>)) with parameters (readonly array(float64, 2d, C), array(float64, 2d, A))
Known signatures:
 * (Array(float64, 2, 'A', False, aligned=True), Array(float64, 2, 'A', False, aligned=True)) -> array(float64, 2d, A)
During: resolving callee type: type(CPUDispatcher(<function numba_funcify_SoftmaxGrad.<locals>.softmax_grad_py_fn at 0x7f9766c2e0c0>))
During: typing of call at /tmp/tmpbsg8wrx8 (5)


File "../../../../../../../tmp/tmpbsg8wrx8", line 5:
def numba_funcified_fgraph(x):
    <source elided>
    # SoftmaxGrad{axis=-1}([[1. 1. 1. ... 1. 1. 1.]], Softmax{axis=-1}.0)
    tensor_variable_1 = softmax_grad_py_fn(tensor_constant, tensor_variable)
    ^

Apply node that caused the error: SoftmaxGrad{axis=-1}([[1. 1. 1. ... 1. 1. 1.]], Softmax{axis=-1}.0)
Toposort index: 1
Inputs types: [TensorType(float64, shape=(2, 5)), TensorType(float64, shape=(2, 5))]
Inputs shapes: [(2, 5)]
Inputs strides: [(40, 8)]
Inputs values: ['not shown']
Outputs clients: [['output']]

Backtrace when the node is created (use PyTensor flag traceback__limit=N to make it longer):
  File "/home/ricardo/miniconda3/envs/colgate-shelf-sow2/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3448, in run_ast_nodes
    if await self.run_code(code, result, async_=asy):
  File "/home/ricardo/miniconda3/envs/colgate-shelf-sow2/lib/python3.11/site-packages/IPython/core/interactiveshell.py", line 3508, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "/tmp/ipykernel_76068/3094158511.py", line 7, in <module>
    y_grad = pt.grad(y.sum(), wrt=x)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 617, in grad
    _rval: Sequence[Variable] = _populate_grad_dict(
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1420, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1420, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1375, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1205, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)

HINT: Use the PyTensor flag `exception_verbosity=high` for a debug print-out and storage map footprint of this Apply node.


Labels

bug (Something isn't working), numba
