
Derivative of set_subtensor with leading broadcastable dims fails #266

@AlexanderFengler

Description


Reproducible code example:

import pytensor.tensor as pt

x = pt.vector("x")
b = pt.zeros((1, 2))  # Fine if b is (n, 2), as long as n != 1
b = pt.set_subtensor(b[:], x)
pt.grad(b.sum(), x)
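The sketch below shows what `pt.grad` should return here. It is NumPy code, not PyTensor code, and it assumes that `set_subtensor(b[:], x)` behaves like NumPy slice assignment, which broadcasts `x` across the `n` rows of `b`. Under that assumption `b.sum()` equals `n * x.sum()`, so the gradient with respect to `x` is `n` times a vector of ones. For the failing case `n = 1` the expected answer is therefore `[1., 1.]`. A finite-difference check is included to confirm this. The function names are illustrative only.

```python
import numpy as np


def forward(x, n):
    """NumPy analogue of the reproducer: write x into every row of an (n, 2) zero matrix, then sum."""
    b = np.zeros((n, 2))
    b[:] = x  # x (shape (2,)) broadcasts across the n rows, like set_subtensor(b[:], x)
    return b.sum()


def expected_grad(x, n):
    """Analytic gradient of forward() w.r.t. x.

    The upstream gradient of b.sum() is all ones with shape (n, 2). Because x was
    broadcast along axis 0, its gradient is that upstream gradient summed over axis 0.
    """
    upstream = np.ones((n,) + x.shape)
    return upstream.sum(axis=0)


def numeric_grad(x, n, eps=1e-6):
    """Central finite differences, used only to cross-check expected_grad()."""
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (forward(x + step, n) - forward(x - step, n)) / (2 * eps)
    return g


x = np.array([0.5, -2.0])
for n in (1, 3):  # n = 1 is the case that fails in PyTensor, n = 3 is a case that works
    print(n, expected_grad(x, n), numeric_grad(x, n))
```

This framing shows that the `n = 1` case is not mathematically special. Summing over a length-1 axis is a no-op, so the gradient only needs its leading axis dropped.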

Error message:

Traceback (most recent call last):
  File "/home/ricardo/miniconda3/envs/pytensor/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-d0c2f306f2bb>", line 7, in <module>
    pt.grad(b.sum(), x)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 617, in grad
    _rval: Sequence[Variable] = _populate_grad_dict(
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1420, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1420, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1375, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1205, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/op.py", line 392, in L_op
    return self.grad(inputs, output_grads)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/subtensor.py", line 1869, in grad
    gy = _sum_grad_over_bcasted_dims(y, gy)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/subtensor.py", line 1905, in _sum_grad_over_bcasted_dims
    assert sum(gx.broadcastable) < sum(x_broad)
AssertionError
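The last frame fails on a comparison of how many dimensions are broadcastable. According to the traceback, `grad` calls `_sum_grad_over_bcasted_dims` with the value being set, `y` (the vector `x` in the reproducer), and its gradient `gy`. That gradient presumably has the shape of `b[:]`.

The pure-Python model below assumes that `x_broad` is `x.broadcastable` left-padded with `True` up to `gx.ndim`. This padding step is an assumption, since only the `assert` line appears above. When `b` has shape `(1, 2)`, both patterns are `(True, False)`. The check then evaluates `1 < 1`, which is false, so the assert fires even though nothing needs summing.

`axes_to_sum` is a hypothetical relaxed rule written for illustration. It is not PyTensor's actual code or fix.

```python
def pad_pattern(x_broadcastable, target_ndim):
    """Assumed behaviour: left-pad x's broadcastable flags with True up to target_ndim dims."""
    n_added = target_ndim - len(x_broadcastable)
    return (True,) * n_added + tuple(x_broadcastable)


def original_check(x_broadcastable, gx_broadcastable):
    """Model of the failing line: assert sum(gx.broadcastable) < sum(x_broad)."""
    x_broad = pad_pattern(x_broadcastable, len(gx_broadcastable))
    return sum(gx_broadcastable) < sum(x_broad)


def axes_to_sum(x_broadcastable, gx_broadcastable):
    """Hypothetical relaxed rule: sum over axes where x was broadcast (True in the
    padded pattern) but gx is not known to have length 1. Returns [] when the
    patterns already agree, i.e. only the padded leading axes need dropping."""
    x_broad = pad_pattern(x_broadcastable, len(gx_broadcastable))
    return [i for i, (g, xb) in enumerate(zip(gx_broadcastable, x_broad)) if xb and not g]


vector = (False,)       # broadcastable pattern of x = pt.vector("x")
gy_n1 = (True, False)   # pattern of the gradient when b has shape (1, 2)
gy_n3 = (False, False)  # pattern of the gradient when b has shape (3, 2)

for label, gy in (("n=1", gy_n1), ("n=3", gy_n3)):
    print(label, "assert passes:", original_check(vector, gy), "axes to sum:", axes_to_sum(vector, gy))
```

Under this model the `n = 1` case needs no summation and only has its leading length-1 axis dropped. The `n = 3` case sums over axis 0. On this reading, a strict `<` is too strong whenever the padded patterns already match.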

PyTensor version information:

2.10.1
