|
28 | 28 | from pytensor.tensor.random.type import RandomType |
29 | 29 |
|
30 | 30 | from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform |
31 | | -from pymc.distributions.custom import CustomSymbolicDistRV |
32 | 31 | from pymc.distributions.dist_math import check_parameters |
33 | 32 | from pymc.distributions.distribution import ( |
34 | 33 | Distribution, |
@@ -302,17 +301,24 @@ class Truncated(Distribution): |
302 | 301 | def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs): |
303 | 302 | if not ( |
304 | 303 | isinstance(dist, TensorVariable) |
305 | | - and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV) |
| 304 | + and dist.owner is not None |
| 305 | + and isinstance(dist.owner.op, RandomVariable | SymbolicRandomVariable) |
306 | 306 | ): |
307 | | - if isinstance(dist.owner.op, SymbolicRandomVariable): |
308 | | - raise NotImplementedError( |
309 | | - f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n" |
310 | | - f"You can try wrapping the distribution inside a CustomDist instead." |
311 | | - ) |
312 | 307 | raise ValueError( |
313 | 308 | f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}" |
314 | 309 | ) |
315 | 310 |
|
| 311 | + if ( |
| 312 | + isinstance(dist.owner.op, SymbolicRandomVariable) |
| 313 | + and "[size]" not in dist.owner.op.extended_signature |
| 314 | + ): |
| 315 | + # Truncation needs to wrap the underlying dist, but not all SymbolicRandomVariables encapsulate the whole |
| 316 | + # random graph and as such we don't know where the actual inputs begin. This happens mostly for |
| 317 | + # distribution factories like `Censored` and `Mixture` which would have a very complex signature if they |
| 318 | + # encapsulated the random components instead of taking them as inputs like they do now. |
| 319 | + # SymbolicRandomVariables that encapsulate the whole random graph can be identified for having a size parameter. |
| 320 | + raise NotImplementedError(f"Truncation not implemented for {dist.owner.op}") |
| 321 | + |
316 | 322 | if dist.owner.op.ndim_supp > 0: |
317 | 323 | raise NotImplementedError("Truncation not implemented for multivariate distributions") |
318 | 324 |
|
|
0 commit comments