| 
37 | 37 | import warnings  | 
38 | 38 | 
 
  | 
39 | 39 | from copy import copy  | 
40 |  | -from typing import Callable, Dict, Generator, Iterable, List, Optional, Set, Tuple  | 
 | 40 | +from typing import (  | 
 | 41 | +    Callable,  | 
 | 42 | +    Dict,  | 
 | 43 | +    Generator,  | 
 | 44 | +    Iterable,  | 
 | 45 | +    List,  | 
 | 46 | +    Optional,  | 
 | 47 | +    Sequence,  | 
 | 48 | +    Set,  | 
 | 49 | +    Tuple,  | 
 | 50 | +)  | 
41 | 51 | 
 
  | 
42 | 52 | import numpy as np  | 
43 | 53 | 
 
  | 
@@ -265,7 +275,7 @@ def diracdelta_logprob(op, values, *inputs, **kwargs):  | 
265 | 275 | def ignore_logprob(rv: TensorVariable) -> TensorVariable:  | 
266 | 276 |     """Return a duplicated variable that is ignored when creating logprob graphs  | 
267 | 277 | 
  | 
268 |  | -    This is used in SymbolicDistributions that use other RVs as inputs but account  | 
 | 278 | +    This is used in by MeasurableRVs that use other RVs as inputs but account  | 
269 | 279 |     for their logp terms explicitly.  | 
270 | 280 | 
  | 
271 | 281 |     If the variable is already ignored, it is returned directly.  | 
@@ -298,3 +308,32 @@ def reconsider_logprob(rv: TensorVariable) -> TensorVariable:  | 
298 | 308 |     new_node.op = copy(new_node.op)  | 
299 | 309 |     new_node.op.__class__ = original_op_type  | 
300 | 310 |     return new_node.outputs[node.outputs.index(rv)]  | 
 | 311 | + | 
 | 312 | + | 
 | 313 | +def ignore_logprob_multiple_vars(  | 
 | 314 | +    vars: Sequence[TensorVariable], rvs_to_values: Dict[TensorVariable, TensorVariable]  | 
 | 315 | +) -> List[TensorVariable]:  | 
 | 316 | +    """Return duplicated variables that are ignored when creating logprob graphs.  | 
 | 317 | +
  | 
 | 318 | +    This function keeps any interdependencies between variables intact, after  | 
 | 319 | +    making each "unmeasurable", whereas a sequential call to `ignore_logprob`  | 
 | 320 | +    would not do this correctly.  | 
 | 321 | +    """  | 
 | 322 | +    from pymc.pytensorf import _replace_rvs_in_graphs  | 
 | 323 | + | 
 | 324 | +    measurable_vars_to_unmeasurable_vars = {  | 
 | 325 | +        measurable_var: ignore_logprob(measurable_var) for measurable_var in vars  | 
 | 326 | +    }  | 
 | 327 | + | 
 | 328 | +    def replacement_fn(var, replacements):  | 
 | 329 | +        if var in measurable_vars_to_unmeasurable_vars:  | 
 | 330 | +            replacements[var] = measurable_vars_to_unmeasurable_vars[var]  | 
 | 331 | +        # We don't want to clone valued nodes. Assigning a var to itself in the  | 
 | 332 | +        # replacements prevents this  | 
 | 333 | +        elif var in rvs_to_values:  | 
 | 334 | +            replacements[var] = var  | 
 | 335 | + | 
 | 336 | +        return []  | 
 | 337 | + | 
 | 338 | +    unmeasurable_vars, _ = _replace_rvs_in_graphs(graphs=vars, replacement_fn=replacement_fn)  | 
 | 339 | +    return unmeasurable_vars  | 
0 commit comments