pymc-devs/pytensor #1057 (Closed)
Description
Describe the issue:
Someone opened an issue in the Bambi repo showing that trying to fit a certain model resulted in an AssertionError. I reproduced the model in PyMC and found the trigger: when we pass something to dims and use transform=pm.distributions.transforms.ordered, it causes the error.
Reproducible code example:
import numpy as np
import pymc as pm
import pytensor.tensor as pt

coords = {
    "threshold_dim": [0, 1],
    "to_predict_dim": [0, 1, 2],
    "__obs__": [0, 1, 2],
}
predictor = np.array([1, 0, 1])
observed = np.array([0, 1, 2])

with pm.Model(coords=coords) as model:
    b_predictor = pm.Normal("b_predictor")
    threshold = pm.Normal(
        "threshold",
        mu=[-2, 2],
        sigma=1,
        transform=pm.distributions.transforms.ordered,
        # dims="threshold_dim",  # Uncommenting this line triggers the assertion error
    )
    eta = b_predictor * predictor
    eta_shifted = threshold - pt.shape_padright(eta)
    p = pm.math.sigmoid(eta_shifted)
    p = pt.concatenate(
        [
            pt.shape_padright(p[..., 0]),
            p[..., 1:] - p[..., :-1],
            pt.shape_padright(1 - p[..., -1]),
        ],
        axis=-1,
    )
    p = pm.Deterministic("p", p, dims=("__obs__", "to_predict_dim"))
    pm.Categorical("to_predict", p=p, observed=observed, dims="__obs__")

with model:
    idata = pm.sample()
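Since the traceback below shows the failure happens while assign_step_methods builds gradients, sampling shouldn't be needed to trigger it. A quicker check (an untested sketch, assuming model.logp() and model.value_vars from the public PyMC 5.x API) is to request the gradient directly:

import pytensor.gradient as tg

with model:
    # Joint log-probability, expressed in terms of the (transformed) value variables
    model_logp = model.logp()

# Requesting the gradient with respect to each value variable should raise the
# same AssertionError for the ordered-transformed threshold variable.
for value_var in model.value_vars:
    tg.grad(model_logp, value_var)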
Error message:
When you uncomment the dims line highlighted above, you'll see the following error message:
AssertionError Traceback (most recent call last)
Cell In[7], line 2
1 with model:
----> 2 idata = pm.sample()
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:718, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, **kwargs)
715 auto_nuts_init = False
717 initial_points = None
--> 718 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
720 if nuts_sampler != "pymc":
721 if not isinstance(step, NUTS):
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pymc/sampling/mcmc.py:223, in assign_step_methods(model, step, methods, step_kwargs)
221 if has_gradient:
222 try:
--> 223 tg.grad(model_logp, var) # type: ignore
224 except (NotImplementedError, tg.NullTypeGradError):
225 has_gradient = False
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:633, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
630 if hasattr(g.type, "dtype"):
631 assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 633 _rval: Sequence[Variable] = _populate_grad_dict(
634 var_to_app_to_idx, grad_dict, _wrt, cost_name
635 )
637 rval: MutableSequence[Variable | None] = list(_rval)
639 for i in range(len(_rval)):
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1425, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
1422 # end if cache miss
1423 return grad_dict[var]
-> 1425 rval = [access_grad_cache(elem) for elem in wrt]
1427 return rval
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1425, in <listcomp>(.0)
1422 # end if cache miss
1423 return grad_dict[var]
-> 1425 rval = [access_grad_cache(elem) for elem in wrt]
1427 return rval
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
1378 for node in node_to_idx:
1379 for idx in node_to_idx[node]:
-> 1380 term = access_term_cache(node)[idx]
1382 if not isinstance(term, Variable):
1383 raise TypeError(
1384 f"{node.op}.grad returned {type(term)}, expected"
1385 " Variable instance."
1386 )
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in _populate_grad_dict.<locals>.access_term_cache(node)
1054 if node not in term_dict:
1055 inputs = node.inputs
-> 1057 output_grads = [access_grad_cache(var) for var in node.outputs]
1059 # list of bools indicating if each output is connected to the cost
1060 outputs_connected = [
1061 not isinstance(g.type, DisconnectedType) for g in output_grads
1062 ]
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in <listcomp>(.0)
1054 if node not in term_dict:
1055 inputs = node.inputs
-> 1057 output_grads = [access_grad_cache(var) for var in node.outputs]
1059 # list of bools indicating if each output is connected to the cost
1060 outputs_connected = [
1061 not isinstance(g.type, DisconnectedType) for g in output_grads
1062 ]
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
1378 for node in node_to_idx:
1379 for idx in node_to_idx[node]:
-> 1380 term = access_term_cache(node)[idx]
1382 if not isinstance(term, Variable):
1383 raise TypeError(
1384 f"{node.op}.grad returned {type(term)}, expected"
1385 " Variable instance."
1386 )
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in _populate_grad_dict.<locals>.access_term_cache(node)
1054 if node not in term_dict:
1055 inputs = node.inputs
-> 1057 output_grads = [access_grad_cache(var) for var in node.outputs]
1059 # list of bools indicating if each output is connected to the cost
1060 outputs_connected = [
1061 not isinstance(g.type, DisconnectedType) for g in output_grads
1062 ]
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1057, in <listcomp>(.0)
1054 if node not in term_dict:
1055 inputs = node.inputs
-> 1057 output_grads = [access_grad_cache(var) for var in node.outputs]
1059 # list of bools indicating if each output is connected to the cost
1060 outputs_connected = [
1061 not isinstance(g.type, DisconnectedType) for g in output_grads
1062 ]
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1380, in _populate_grad_dict.<locals>.access_grad_cache(var)
1378 for node in node_to_idx:
1379 for idx in node_to_idx[node]:
-> 1380 term = access_term_cache(node)[idx]
1382 if not isinstance(term, Variable):
1383 raise TypeError(
1384 f"{node.op}.grad returned {type(term)}, expected"
1385 " Variable instance."
1386 )
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/gradient.py:1210, in _populate_grad_dict.<locals>.access_term_cache(node)
1202 if o_shape != g_shape:
1203 raise ValueError(
1204 "Got a gradient of shape "
1205 + str(o_shape)
1206 + " on an output of shape "
1207 + str(g_shape)
1208 )
-> 1210 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
1212 if input_grads is None:
1213 raise TypeError(
1214 f"{node.op}.grad returned NoneType, expected iterable."
1215 )
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/graph/op.py:398, in Op.L_op(self, inputs, outputs, output_grads)
371 def L_op(
372 self,
373 inputs: Sequence[Variable],
374 outputs: Sequence[Variable],
375 output_grads: Sequence[Variable],
376 ) -> list[Variable]:
377 r"""Construct a graph for the L-operator.
378
379 The L-operator computes a row vector times the Jacobian.
(...)
396
397 """
--> 398 return self.grad(inputs, output_grads)
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/subtensor.py:1995, in IncSubtensor.grad(self, inputs, grads)
1993 gx = g_output
1994 gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
-> 1995 gy = _sum_grad_over_bcasted_dims(y, gy)
1997 return [gx, gy] + [DisconnectedType()()] * len(idx_list)
File ~/miniconda3/envs/bambi-dev/lib/python3.11/site-packages/pytensor/tensor/subtensor.py:2031, in _sum_grad_over_bcasted_dims(x, gx)
2029 x_dim_added = gx.ndim - x.ndim
2030 x_broad = (True,) * x_dim_added + x.broadcastable
-> 2031 assert sum(gx.broadcastable) <= sum(x_broad)
2032 axis_to_sum = []
2033 for i in range(gx.ndim):
AssertionError:
But if you leave it commented out, it works.
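If it helps triage: my guess (not verified) is that the static shape of mu=[-2, 2] leaks into the gradient graph, while the value variable's shape stays symbolic because dims make the RV size symbolic, so the sliced gradient ends up broadcastable where the corresponding slice of the value variable is not. Under that assumption, a PyTensor-only sketch (hypothetical, the variable names are mine) that should hit the same assertion:

import pytensor
import pytensor.tensor as pt

# v plays the role of the transformed value variable; its length is unknown
# because dims make the RV shape symbolic.
v = pt.vector("v")

# Mimic (my understanding of) the ordered transform's backward pass: write
# exp(v[1:]) into a zeros tensor with the same symbolic shape.
x = pt.zeros(v.shape)
x = pt.set_subtensor(x[1:], pt.exp(v[1:]))

# Multiplying by a length-2 constant (standing in for mu=[-2, 2]) gives the
# downstream graph a static shape of (2,) even though v's shape is unknown.
cost = (x * pt.as_tensor([1.0, 2.0])).sum()

# Under this guess, the gradient slice g_output[1:] gets static shape (1,)
# (broadcastable) while y = exp(v[1:]) has unknown, non-broadcastable shape,
# tripping the assertion in _sum_grad_over_bcasted_dims.
g = pytensor.grad(cost, v)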
PyMC version information:
PyMC 5.17.0
PyTensor 2.25.5
Context for the issue:
No response