diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 95534770ab..b14fa242b2 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -1902,7 +1902,7 @@ def _sum_grad_over_bcasted_dims(x, gx): if gx.broadcastable != x.broadcastable: x_dim_added = gx.ndim - x.ndim x_broad = (True,) * x_dim_added + x.broadcastable - assert sum(gx.broadcastable) < sum(x_broad) + assert sum(gx.broadcastable) <= sum(x_broad) axis_to_sum = [] for i in range(gx.ndim): if gx.broadcastable[i] is False and x_broad[i] is True: diff --git a/tests/tensor/test_subtensor.py b/tests/tensor/test_subtensor.py index 86573805c1..2dec05b63b 100644 --- a/tests/tensor/test_subtensor.py +++ b/tests/tensor/test_subtensor.py @@ -1593,6 +1593,15 @@ def just_numeric_args(a, b): ), ) + # Broadcastable leading dim + utt.verify_grad( + f_slice(slice(None, None), slice(1, 3)), + ( + np.asarray([0, 1, 2, 3, 4, 5.0])[None, ...], + np.asarray([9, 9.0]), + ), + ) + class TestIncSubtensor1: def setup_method(self):