 import pytensor
 import pytensor.tensor as pt
 
-from pytensor import scan
+from pytensor import config, graph_replace, scan
 from pytensor.graph import Op
 from pytensor.graph.basic import Node
 from pytensor.raise_op import CheckAndRaise
 from pytensor.scan import until
 from pytensor.tensor import TensorConstant, TensorVariable
 from pytensor.tensor.random.basic import NormalRV
 from pytensor.tensor.random.op import RandomVariable
+from pytensor.tensor.random.type import RandomType
 
 from pymc.distributions.continuous import TruncatedNormal, bounded_cont_transform
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import (
+    CustomSymbolicDistRV,
     Distribution,
     SymbolicRandomVariable,
     _support_point,
 from pymc.distributions.transforms import _default_transform
 from pymc.exceptions import TruncationError
 from pymc.logprob.abstract import _logcdf, _logprob
-from pymc.logprob.basic import icdf, logcdf
+from pymc.logprob.basic import icdf, logcdf, logp
 from pymc.math import logdiffexp
+from pymc.pytensorf import collect_default_updates
 from pymc.util import check_dist_not_registered
 
 
@@ -49,11 +52,17 @@ class TruncatedRV(SymbolicRandomVariable):
     that represents a truncated univariate random variable.
     """
 
-    default_output = 1
-    base_rv_op = None
-    max_n_steps = None
-
-    def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
+    default_output: int = 0
+    base_rv_op: Op
+    max_n_steps: int
+
+    def __init__(
+        self,
+        *args,
+        base_rv_op: Op,
+        max_n_steps: int,
+        **kwargs,
+    ):
         self.base_rv_op = base_rv_op
         self.max_n_steps = max_n_steps
         self._print_name = (
@@ -63,8 +72,13 @@ def __init__(self, *args, base_rv_op: Op, max_n_steps: int, **kwargs):
         super().__init__(*args, **kwargs)
 
     def update(self, node: Node):
-        """Return the update mapping for the internal RNG."""
-        return {node.inputs[-1]: node.outputs[0]}
+        """Return the update mapping for the internal RNGs.
+
+        TruncatedRVs are created in a way that the rng updates follow the same order as the input RNGs.
+        """
+        rngs = [inp for inp in node.inputs if isinstance(inp.type, RandomType)]
+        next_rngs = [out for out in node.outputs if isinstance(out.type, RandomType)]
+        return dict(zip(rngs, next_rngs))
 
 
 @singledispatch
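Aside: the mapping returned by `update` is what callers pass as `updates` when compiling, so the shared RNGs advance between calls. A minimal sketch, not part of the diff, assuming `truncated_rv` is the output of a `TruncatedRV` Op:

```python
import pytensor

node = truncated_rv.owner
updates = node.op.update(node)  # {rng: next_rng, ...}, aligned with input order
draw_fn = pytensor.function([], truncated_rv, updates=updates)
draws = [draw_fn() for _ in range(3)]  # distinct draws: RNG state advances per call
```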
@@ -141,10 +155,14 @@ class Truncated(Distribution):
 
     @classmethod
     def dist(cls, dist, lower=None, upper=None, max_n_steps: int = 10_000, **kwargs):
-        if not (isinstance(dist, TensorVariable) and isinstance(dist.owner.op, RandomVariable)):
+        if not (
+            isinstance(dist, TensorVariable)
+            and isinstance(dist.owner.op, RandomVariable | CustomSymbolicDistRV)
+        ):
             if isinstance(dist.owner.op, SymbolicRandomVariable):
                 raise NotImplementedError(
-                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}"
+                    f"Truncation not implemented for SymbolicRandomVariable {dist.owner.op}.\n"
+                    f"You can try wrapping the distribution inside a CustomDist instead."
                 )
             raise ValueError(
                 f"Truncation dist must be a distribution created via the `.dist()` API, got {type(dist)}"
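The workaround suggested by the new error message relies on `CustomSymbolicDistRV` now being accepted by the check above. A hedged sketch of what that could look like (the `exp_normal` graph and the bounds are made up for illustration):

```python
import pymc as pm

def exp_normal(mu, sigma, size):
    # Any symbolic graph built from .dist() variables can go here.
    return pm.math.exp(pm.Normal.dist(mu, sigma, size=size))

with pm.Model():
    # CustomDist(dist=...) yields a CustomSymbolicDistRV, which
    # Truncated.dist now accepts alongside plain RandomVariables.
    x = pm.Truncated("x", pm.CustomDist.dist(0.0, 1.0, dist=exp_normal), lower=0.5, upper=3.0)
```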
@@ -174,46 +192,54 @@ def rv_op(cls, dist, lower, upper, max_n_steps, size=None):
         if size is None:
             size = pt.broadcast_shape(dist, lower, upper)
         dist = change_dist_size(dist, new_size=size)
+        rv_inputs = [
+            inp
+            if not isinstance(inp.type, RandomType)
+            else pytensor.shared(np.random.default_rng())
+            for inp in dist.owner.inputs
+        ]
+        graph_inputs = [*rv_inputs, lower, upper]
 
         # Variables with `_` suffix identify dummy inputs for the OpFromGraph
-        graph_inputs = [*dist.owner.inputs[1:], lower, upper]
-        graph_inputs_ = [inp.type() for inp in graph_inputs]
+        graph_inputs_ = [
+            inp.type() if not isinstance(inp.type, RandomType) else inp for inp in graph_inputs
+        ]
         *rv_inputs_, lower_, upper_ = graph_inputs_
 
-        # We will use a Shared RNG variable because Scan demands it, even though it
-        # would not be necessary for the OpFromGraph inverse cdf.
-        rng = pytensor.shared(np.random.default_rng())
-        rv_ = dist.owner.op.make_node(rng, *rv_inputs_).default_output()
+        rv_ = dist.owner.op.make_node(*rv_inputs_).default_output()
 
         # Try to use inverted cdf sampling
+        # truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))
         try:
-            # For left truncated discrete RVs, we need to include the whole lower bound.
-            # This may result in draws below the truncation range, if any uniform == 0
-            lower_value = lower_ - 1 if dist.owner.op.dtype.startswith("int") else lower_
-            cdf_lower_ = pt.exp(logcdf(rv_, lower_value))
-            cdf_upper_ = pt.exp(logcdf(rv_, upper_))
-            # It's okay to reuse the same rng here, because the rng in rv_ will not be
-            # used by either the logcdf or icdf functions
+            logcdf_lower_, logcdf_upper_ = Truncated._create_logcdf_exprs(rv_, rv_, lower_, upper_)
+            # We use the first RNG from the base RV, so we don't have to introduce a new one
+            # This is not problematic because the RNG won't be used in the RV logcdf graph
+            uniform_rng_ = next(inp_ for inp_ in rv_inputs_ if isinstance(inp_.type, RandomType))
             uniform_next_rng_, uniform_ = pt.random.uniform(
-                cdf_lower_,
-                cdf_upper_,
-                rng=rng,
-                size=rv_inputs_[0],
+                pt.exp(logcdf_lower_),
+                pt.exp(logcdf_upper_),
+                rng=uniform_rng_,
+                size=rv_.shape,
             ).owner.outputs
-            truncated_rv_ = icdf(rv_, uniform_)
+            truncated_rv_ = icdf(rv_, uniform_, warn_rvs=False)
             return TruncatedRV(
                 base_rv_op=dist.owner.op,
-                inputs=[*graph_inputs_, rng],
-                outputs=[uniform_next_rng_, truncated_rv_],
+                inputs=graph_inputs_,
+                outputs=[truncated_rv_, uniform_next_rng_],
                 ndim_supp=0,
                 max_n_steps=max_n_steps,
-            )(*graph_inputs, rng)
+            )(*graph_inputs)
         except NotImplementedError:
             pass
 
         # Fallback to rejection sampling
-        def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
-            next_rng, new_truncated_rv = dist.owner.op.make_node(rng, *rv_inputs).outputs
+        # truncated_rv = zeros(rv.shape)
+        # reject_draws = ones(rv.shape, dtype=bool)
+        # while any(reject_draws):
+        #     truncated_rv[reject_draws] = draw(rv)[reject_draws]
+        #     reject_draws = (truncated_rv < lower) | (truncated_rv > upper)
+        def loop_fn(truncated_rv, reject_draws, lower, upper, *rv_inputs):
+            new_truncated_rv = dist.owner.op.make_node(*rv_inputs_).default_output()
             # Avoid scalar boolean indexing
             if truncated_rv.type.ndim == 0:
                 truncated_rv = new_truncated_rv
@@ -226,7 +252,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
 
             return (
                 (truncated_rv, reject_draws),
-                [(rng, next_rng)],
+                collect_default_updates(new_truncated_rv),
                 until(~pt.any(reject_draws)),
             )
 
@@ -236,7 +262,7 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
                 pt.zeros_like(rv_),
                 pt.ones_like(rv_, dtype=bool),
             ],
-            non_sequences=[lower_, upper_, rng, *rv_inputs_],
+            non_sequences=[lower_, upper_, *rv_inputs_],
            n_steps=max_n_steps,
             strict=True,
         )
@@ -246,24 +272,49 @@ def loop_fn(truncated_rv, reject_draws, lower, upper, rng, *rv_inputs):
         truncated_rv_ = TruncationCheck(f"Truncation did not converge in {max_n_steps} steps")(
             truncated_rv_, convergence_
         )
+        # Sort updates of each RNG so that they show in the same order as the input RNGs
+
+        def sort_updates(update):
+            rng, next_rng = update
+            return graph_inputs.index(rng)
+
+        next_rngs = [next_rng for rng, next_rng in sorted(updates.items(), key=sort_updates)]
 
-        [next_rng] = updates.values()
         return TruncatedRV(
             base_rv_op=dist.owner.op,
-            inputs=[*graph_inputs_, rng],
-            outputs=[next_rng, truncated_rv_],
+            inputs=graph_inputs_,
+            outputs=[truncated_rv_, *next_rngs],
             ndim_supp=0,
             max_n_steps=max_n_steps,
-        )(*graph_inputs, rng)
+        )(*graph_inputs)
+
+    @staticmethod
+    def _create_logcdf_exprs(
+        base_rv: TensorVariable,
+        value: TensorVariable,
+        lower: TensorVariable,
+        upper: TensorVariable,
+    ) -> tuple[TensorVariable, TensorVariable]:
+        """Create lower and upper logcdf expressions for base_rv.
+
+        Uses `value` as a template for broadcasting.
+        """
+        # For left truncated discrete RVs, we need to include the whole lower bound.
+        lower_value = lower - 1 if base_rv.type.dtype.startswith("int") else lower
+        lower_value = pt.full_like(value, lower_value, dtype=config.floatX)
+        upper_value = pt.full_like(value, upper, dtype=config.floatX)
+        lower_logcdf = logcdf(base_rv, lower_value, warn_rvs=False)
+        upper_logcdf = graph_replace(lower_logcdf, {lower_value: upper_value})
+        return lower_logcdf, upper_logcdf
 
 
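The log-CDF pair built by `_create_logcdf_exprs` feeds the inverse-CDF path tried first in `rv_op` (`truncated_rv = icdf(rv, draw(uniform(cdf(lower), cdf(upper))))`). The same idea in SciPy for a normal base distribution (illustration only, not PR code):

```python
import numpy as np
from scipy import stats

def icdf_truncated_normal(mu, sigma, lower, upper, size, seed=None):
    rng = np.random.default_rng(seed)
    base = stats.norm(mu, sigma)
    # Draw uniformly between the CDF values of the bounds, then invert.
    u = rng.uniform(base.cdf(lower), base.cdf(upper), size=size)
    return base.ppf(u)  # ppf is the icdf

samples = icdf_truncated_normal(0.0, 1.0, lower=-1.0, upper=2.0, size=1000)
```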
 @_change_dist_size.register(TruncatedRV)
-def change_truncated_size(op, dist, new_size, expand):
-    *rv_inputs, lower, upper, rng = dist.owner.inputs
-    # Recreate the original untruncated RV
-    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+def change_truncated_size(op: TruncatedRV, truncated_rv, new_size, expand):
+    *rv_inputs, lower, upper = truncated_rv.owner.inputs
+    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
+
     if expand:
-        new_size = to_tuple(new_size) + tuple(dist.shape)
+        new_size = to_tuple(new_size) + tuple(truncated_rv.shape)
 
     return Truncated.rv_op(
         untruncated_rv,
@@ -275,11 +326,11 @@ def change_truncated_size(op, dist, new_size, expand):
 
 
 @_support_point.register(TruncatedRV)
-def truncated_support_point(op, rv, *inputs):
-    *rv_inputs, lower, upper, rng = inputs
+def truncated_support_point(op: TruncatedRV, truncated_rv, *inputs):
+    *rv_inputs, lower, upper = inputs
 
     # recreate untruncated rv and respective support_point
-    untruncated_rv = op.base_rv_op.make_node(rng, *rv_inputs).default_output()
+    untruncated_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
     untruncated_support_point = support_point(untruncated_rv)
 
     fallback_support_point = pt.switch(
@@ -300,31 +351,25 @@ def truncated_support_point(op, rv, *inputs):
 
 
 @_default_transform.register(TruncatedRV)
-def truncated_default_transform(op, rv):
+def truncated_default_transform(op, truncated_rv):
     # Don't transform discrete truncated distributions
-    if op.base_rv_op.dtype.startswith("int"):
+    if truncated_rv.type.dtype.startswith("int"):
         return None
-    # Lower and Upper are the arguments -3 and -2
-    return bounded_cont_transform(op, rv, bound_args_indices=(-3, -2))
+    # Lower and Upper are the arguments -2 and -1
+    return bounded_cont_transform(op, truncated_rv, bound_args_indices=(-2, -1))
 
 
 @_logprob.register(TruncatedRV)
 def truncated_logprob(op, values, *inputs, **kwargs):
     (value,) = values
-
-    *rv_inputs, lower, upper, rng = inputs
-    rv_inputs = [rng, *rv_inputs]
+    *rv_inputs, lower, upper = inputs
 
     base_rv_op = op.base_rv_op
-    logp = _logprob(base_rv_op, (value,), *rv_inputs, **kwargs)
-    # For left truncated RVs, we don't want to include the lower bound in the
-    # normalization term
-    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
-    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
-    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
-
+    base_rv = base_rv_op.make_node(*rv_inputs).default_output()
+    base_logp = logp(base_rv, value)
+    lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
     if base_rv_op.name:
-        logp.name = f"{base_rv_op}_logprob"
+        base_logp.name = f"{base_rv_op}_logprob"
         lower_logcdf.name = f"{base_rv_op}_lower_logcdf"
         upper_logcdf.name = f"{base_rv_op}_upper_logcdf"
 
@@ -339,37 +384,31 @@ def truncated_logprob(op, values, *inputs, **kwargs):
 
     elif is_upper_bounded:
         lognorm = upper_logcdf
-    logp = logp - lognorm
+    truncated_logp = base_logp - lognorm
 
     if is_lower_bounded:
-        logp = pt.switch(value < lower, -np.inf, logp)
+        truncated_logp = pt.switch(value < lower, -np.inf, truncated_logp)
 
     if is_upper_bounded:
-        logp = pt.switch(value <= upper, logp, -np.inf)
+        truncated_logp = pt.switch(value <= upper, truncated_logp, -np.inf)
 
     if is_lower_bounded and is_upper_bounded:
-        logp = check_parameters(
-            logp,
+        truncated_logp = check_parameters(
+            truncated_logp,
             pt.le(lower, upper),
             msg="lower_bound <= upper_bound",
         )
 
-    return logp
+    return truncated_logp
 
 
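For reference, these branches implement the standard truncated log-density, with `lognorm` as the log of the normalization constant (and F(l-1) in place of F(l) for discrete variables, so the lower bound keeps its probability mass):

```latex
\log p_{\mathrm{trunc}}(x) =
\begin{cases}
\log p(x) - \log\bigl(F(u) - F(l)\bigr) & \text{if } l \le x \le u,\\
-\infty & \text{otherwise.}
\end{cases}
```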
 @_logcdf.register(TruncatedRV)
-def truncated_logcdf(op, value, *inputs, **kwargs):
-    *rv_inputs, lower, upper, rng = inputs
-    rv_inputs = [rng, *rv_inputs]
-
-    base_rv_op = op.base_rv_op
-    logcdf = _logcdf(base_rv_op, value, *rv_inputs, **kwargs)
+def truncated_logcdf(op: TruncatedRV, value, *inputs, **kwargs):
+    *rv_inputs, lower, upper = inputs
 
-    # For left truncated discrete RVs, we don't want to include the lower bound in the
-    # normalization term
-    lower_value = lower - 1 if base_rv_op.dtype.startswith("int") else lower
-    lower_logcdf = _logcdf(base_rv_op, lower_value, *rv_inputs, **kwargs)
-    upper_logcdf = _logcdf(base_rv_op, upper, *rv_inputs, **kwargs)
+    base_rv = op.base_rv_op.make_node(*rv_inputs).default_output()
+    base_logcdf = logcdf(base_rv, value)
+    lower_logcdf, upper_logcdf = Truncated._create_logcdf_exprs(base_rv, value, lower, upper)
 
     is_lower_bounded = not (isinstance(lower, TensorConstant) and np.all(np.isneginf(lower.value)))
     is_upper_bounded = not (isinstance(upper, TensorConstant) and np.all(np.isinf(upper.value)))
@@ -382,7 +421,7 @@ def truncated_logcdf(op, value, *inputs, **kwargs):
     elif is_upper_bounded:
         lognorm = upper_logcdf
 
-    logcdf_numerator = logdiffexp(logcdf, lower_logcdf) if is_lower_bounded else logcdf
+    logcdf_numerator = logdiffexp(base_logcdf, lower_logcdf) if is_lower_bounded else base_logcdf
     logcdf_trunc = logcdf_numerator - lognorm
 
     if is_lower_bounded:
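Likewise, `logcdf_numerator` and `lognorm` compute the truncated CDF in log-space, using `logdiffexp(a, b) = log(exp(a) - exp(b))` for numerical stability:

```latex
\log F_{\mathrm{trunc}}(x) = \log\bigl(F(x) - F(l)\bigr) - \log\bigl(F(u) - F(l)\bigr),
\qquad l \le x \le u.
```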