Closed
Description
Reproducible code example:
import pytensor.tensor as pt
x = pt.vector("x")
b = pt.zeros((1, 2)) # Fine if b is (n, 2), as long as n != 1
b = pt.set_subtensor(b[:], x)
pt.grad(b.sum(), x)
Error message:
Traceback (most recent call last):
  File "/home/ricardo/miniconda3/envs/pytensor/lib/python3.10/site-packages/IPython/core/interactiveshell.py", line 3433, in run_code
    exec(code_obj, self.user_global_ns, self.user_ns)
  File "<ipython-input-8-d0c2f306f2bb>", line 7, in <module>
    pt.grad(b.sum(), x)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 617, in grad
    _rval: Sequence[Variable] = _populate_grad_dict(
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1420, in _populate_grad_dict
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1420, in <listcomp>
    rval = [access_grad_cache(elem) for elem in wrt]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1375, in access_grad_cache
    term = access_term_cache(node)[idx]
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/gradient.py", line 1205, in access_term_cache
    input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/graph/op.py", line 392, in L_op
    return self.grad(inputs, output_grads)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/subtensor.py", line 1869, in grad
    gy = _sum_grad_over_bcasted_dims(y, gy)
  File "/home/ricardo/Documents/Projects/pytensor/pytensor/tensor/subtensor.py", line 1905, in _sum_grad_over_bcasted_dims
    assert sum(gx.broadcastable) < sum(x_broad)
AssertionError
PyTensor version information:
2.10.1
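Expected result (sketch):
For reference, the gradient that pt.grad would be expected to return can be checked without PyTensor. The snippet below is only a NumPy finite-difference sketch, not PyTensor code: the helpers `loss` and `numeric_grad` are hypothetical names introduced here, and it assumes `x` has length 2 so that it broadcasts into `b` of shape `(n, 2)`. Under that assumption the gradient should come out as approximately `[1, 1]` for `n = 1` and `[3, 3]` for `n = 3`, so the `n = 1` case has a well-defined answer and the assertion is presumably hit only on the way to computing it.

```python
import numpy as np


def loss(x, n):
    # NumPy analogue of: b = zeros((n, 2)); b = set_subtensor(b[:], x); b.sum()
    b = np.zeros((n, 2))
    b[:] = x  # x (shape (2,)) is broadcast along the leading axis of b
    return b.sum()


def numeric_grad(x, n, eps=1e-6):
    # Central finite differences of loss with respect to each entry of x
    g = np.zeros_like(x)
    for i in range(x.size):
        step = np.zeros_like(x)
        step[i] = eps
        g[i] = (loss(x + step, n) - loss(x - step, n)) / (2 * eps)
    return g


x = np.array([0.5, -1.5])
grad_n1 = numeric_grad(x, n=1)  # the shape that fails in the report
grad_n3 = numeric_grad(x, n=3)  # a shape reported to work
print(grad_n1)
print(grad_n3)
```

Each entry of `x` is copied into `n` rows of `b`, so its gradient is `n`; nothing in that reasoning singles out `n = 1`.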