BUG: Using transform=pm.distributions.transforms.ordered together with non-null dims results in AssertionError #7554

@tomicapretto

Description

Describe the issue:

Someone opened an issue in the Bambi repo showing that trying to fit a certain model resulted in an AssertionError. I reproduced the model in PyMC and narrowed down the problem: passing a non-null dims together with transform=pm.distributions.transforms.ordered triggers the error.
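For background, the ordered transform maps an unconstrained real vector to a strictly increasing one, which is why it is commonly used for thresholds/cutpoints. A minimal sketch of that behavior, assuming the usual forward/backward transform API in PyMC 5:

import numpy as np
from pymc.distributions.transforms import ordered

# backward maps unconstrained space to ordered space: the first entry is
# kept as-is and the remaining entries are exponentiated and cumulatively
# summed, so the output is strictly increasing
values = ordered.backward(np.array([-2.0, 0.5])).eval()
assert np.all(np.diff(values) > 0)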

Reproducible code example:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

coords = {
    "threshold_dim": [0, 1],
    "to_predict_dim": [0, 1, 2],
    "__obs__": [0, 1, 2],
}

predictor = np.array([1, 0, 1])
observed = np.array([0, 1, 2])

with pm.Model(coords=coords) as model:
    b_predictor = pm.Normal("b_predictor")
    threshold = pm.Normal(
        "threshold",
        mu=[-2, 2],
        sigma=1,
        transform=pm.distributions.transforms.ordered,
        # dims="threshold_dim",  # uncommenting this line triggers the assertion error
    )

    eta = b_predictor * predictor
    eta_shifted = threshold - pt.shape_padright(eta)
    p = pm.math.sigmoid(eta_shifted)
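    # adjacent differences of the cumulative probabilities (padded with the
    # first one and with 1 minus the last one) give the per-category
    # probabilities of the ordinal outcome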
    p = pt.concatenate(
        [
            pt.shape_padright(p[..., 0]),
            p[..., 1:] - p[..., :-1],
            pt.shape_padright(1 - p[..., -1]),
        ],
        axis=-1,
    )

    p = pm.Deterministic("p", p, dims=("__obs__", "to_predict_dim"))

    pm.Categorical("to_predict", p=p, observed=observed, dims="__obs__")

with model:
    idata = pm.sample()
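The failure can also be surfaced without running the sampler: per the traceback below, it originates in the gradient call made while assigning step methods. A sketch of reproducing just that call, assuming model.logp() and model.value_vars behave as in current PyMC:

import pytensor.gradient as tg

with model:
    # mirrors the tg.grad(model_logp, var) call inside assign_step_methods;
    # the same AssertionError should surface here
    model_logp = model.logp()
    for value_var in model.value_vars:
        tg.grad(model_logp, value_var)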

Error message:

When you uncomment the dims line highlighted above, you'll see the following error message:

AssertionError                            Traceback (most recent call last)
Cell In[7], line 2
      1 with model:
----> 2     idata = pm.sample()

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:718, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
    715         auto_nuts_init = False
    717 initial_points = None
--> 718 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    720 if nuts_sampler != "pymc":
    721     if not isinstance(step, NUTS):

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:223, in assign_step_methods(model, step, methods, step_kwargs)
    221 if has_gradient:
    222     try:
--> 223         tg.grad(model_logp, var)  # type: ignore
    224     except (NotImplementedError, tg.NullTypeGradError):
    225         has_gradient = False

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:633, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    630     if hasattr(g.type, "dtype"):
    631         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 633 _rval: Sequence[Variable] = _populate_grad_dict(
    634     var_to_app_to_idx, grad_dict, _wrt, cost_name
    635 )
    637 rval: MutableSequence[Variable | None] = list(_rval)
    639 for i in range(len(_rval)):

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1425, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1422     # end if cache miss
   1423     return grad_dict[var]
-> 1425 rval = [access_grad_cache(elem) for elem in wrt]
   1427 return rval

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1425, in <listcomp>(.0)
   1422     # end if cache miss
   1423     return grad_dict[var]
-> 1425 rval = [access_grad_cache(elem) for elem in wrt]
   1427 return rval

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1378 for node in node_to_idx:
   1379     for idx in node_to_idx[node]:
-> 1380         term = access_term_cache(node)[idx]
   1382         if not isinstance(term, Variable):
   1383             raise TypeError(
   1384                 f"{node.op}.grad returned {type(term)}, expected"
   1385                 " Variable instance."
   1386             )

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in _populate_grad_dict.<locals>.access_term_cache(node)
   1054 if node not in term_dict:
   1055     inputs = node.inputs
-> 1057     output_grads = [access_grad_cache(var) for var in node.outputs]
   1059     # list of bools indicating if each output is connected to the cost
   1060     outputs_connected = [
   1061         not isinstance(g.type, DisconnectedType) for g in output_grads
   1062     ]

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in <listcomp>(.0)
   1054 if node not in term_dict:
   1055     inputs = node.inputs
-> 1057     output_grads = [access_grad_cache(var) for var in node.outputs]
   1059     # list of bools indicating if each output is connected to the cost
   1060     outputs_connected = [
   1061         not isinstance(g.type, DisconnectedType) for g in output_grads
   1062     ]

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1378 for node in node_to_idx:
   1379     for idx in node_to_idx[node]:
-> 1380         term = access_term_cache(node)[idx]
   1382         if not isinstance(term, Variable):
   1383             raise TypeError(
   1384                 f"{node.op}.grad returned {type(term)}, expected"
   1385                 " Variable instance."
   1386             )

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in _populate_grad_dict.<locals>.access_term_cache(node)
   1054 if node not in term_dict:
   1055     inputs = node.inputs
-> 1057     output_grads = [access_grad_cache(var) for var in node.outputs]
   1059     # list of bools indicating if each output is connected to the cost
   1060     outputs_connected = [
   1061         not isinstance(g.type, DisconnectedType) for g in output_grads
   1062     ]

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in <listcomp>(.0)
   1054 if node not in term_dict:
   1055     inputs = node.inputs
-> 1057     output_grads = [access_grad_cache(var) for var in node.outputs]
   1059     # list of bools indicating if each output is connected to the cost
   1060     outputs_connected = [
   1061         not isinstance(g.type, DisconnectedType) for g in output_grads
   1062     ]

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1378 for node in node_to_idx:
   1379     for idx in node_to_idx[node]:
-> 1380         term = access_term_cache(node)[idx]
   1382         if not isinstance(term, Variable):
   1383             raise TypeError(
   1384                 f"{node.op}.grad returned {type(term)}, expected"
   1385                 " Variable instance."
   1386             )

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1210, in _populate_grad_dict.<locals>.access_term_cache(node)
   1202         if o_shape != g_shape:
   1203             raise ValueError(
   1204                 "Got a gradient of shape "
   1205                 + str(o_shape)
   1206                 + " on an output of shape "
   1207                 + str(g_shape)
   1208             )
-> 1210 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1212 if input_grads is None:
   1213     raise TypeError(
   1214         f"{node.op}.grad returned NoneType, expected iterable."
   1215     )

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/op.py:398, in Op.L_op(self, inputs, outputs, output_grads)
    371 def L_op(
    372     self,
    373     inputs: Sequence[Variable],
    374     outputs: Sequence[Variable],
    375     output_grads: Sequence[Variable],
    376 ) -> list[Variable]:
    377     r"""Construct a graph for the L-operator.
    378 
    379     The L-operator computes a row vector times the Jacobian.
   (...)
    396 
    397     """
--> 398     return self.grad(inputs, output_grads)

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/subtensor.py:1995, in IncSubtensor.grad(self, inputs, grads)
   1993         gx = g_output
   1994     gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
-> 1995     gy = _sum_grad_over_bcasted_dims(y, gy)
   1997 return [gx, gy] + [DisconnectedType()()] * len(idx_list)

File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/subtensor.py:2031, in _sum_grad_over_bcasted_dims(x, gx)
   2029 x_dim_added = gx.ndim - x.ndim
   2030 x_broad = (True,) * x_dim_added + x.broadcastable
-> 2031 assert sum(gx.broadcastable) <= sum(x_broad)
   2032 axis_to_sum = []
   2033 for i in range(gx.ndim):

AssertionError: 

But if you leave that line commented out, sampling works.
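For context, the assert that fires in _sum_grad_over_bcasted_dims requires that the gradient not be broadcastable along more axes than the original tensor, once the tensor's pattern is left-padded with broadcastable axes. A minimal illustration of the invariant itself (the tensors here are hypothetical stand-ins, not the actual variables from the failing graph):

import pytensor.tensor as pt

x = pt.tensor("x", shape=(2, 3))       # broadcastable = (False, False)
gx = pt.tensor("gx", shape=(1, 2, 3))  # broadcastable = (True, False, False)

x_dim_added = gx.type.ndim - x.type.ndim
x_broad = (True,) * x_dim_added + x.type.broadcastable
# holds here (1 <= 1); in the failing graph the gradient apparently ends up
# with more broadcastable axes than the padded input, so the assert fires
assert sum(gx.type.broadcastable) <= sum(x_broad)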

PyMC version information:

PyMC 5.17.0
PyTensor 2.25.5

Context for the issue:

No response
