|
39 | 39 | from collections import deque
|
40 | 40 | from typing import Dict, List, Optional, Sequence, Union
|
41 | 41 |
|
| 42 | +import numpy as np |
42 | 43 | import pytensor
|
43 | 44 | import pytensor.tensor as pt
|
44 | 45 |
|
45 | 46 | from pytensor import config
|
46 |
| -from pytensor.graph.basic import graph_inputs, io_toposort |
| 47 | +from pytensor.graph.basic import Variable, graph_inputs, io_toposort |
47 | 48 | from pytensor.graph.op import compute_test_value
|
48 | 49 | from pytensor.graph.rewriting.basic import GraphRewriter, NodeRewriter
|
49 | 50 | from pytensor.tensor.random.op import RandomVariable
|
50 | 51 | from pytensor.tensor.var import TensorVariable
|
51 |
| - |
52 |
| -from pymc.logprob.abstract import _logprob, get_measurable_outputs |
53 |
| -from pymc.logprob.abstract import logprob as logp_logprob |
| 52 | +from typing_extensions import TypeAlias |
| 53 | + |
| 54 | +from pymc.logprob.abstract import ( |
| 55 | + _icdf_helper, |
| 56 | + _logcdf_helper, |
| 57 | + _logprob, |
| 58 | + _logprob_helper, |
| 59 | + get_measurable_outputs, |
| 60 | +) |
54 | 61 | from pymc.logprob.rewriting import construct_ir_fgraph
|
55 | 62 | from pymc.logprob.transforms import RVTransform, TransformValuesRewrite
|
56 | 63 | from pymc.logprob.utils import rvs_to_value_vars
|
57 | 64 |
|
| 65 | +TensorLike: TypeAlias = Union[Variable, float, np.ndarray] |
| 66 | + |
58 | 67 |
|
59 |
| -def logp(rv: TensorVariable, value: TensorVariable, **kwargs) -> TensorVariable: |
| 68 | +def logp(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: |
60 | 69 | """Return the log-probability graph of a Random Variable"""
|
61 | 70 |
|
62 | 71 | value = pt.as_tensor_variable(value, dtype=rv.dtype)
|
63 | 72 | try:
|
64 |
| - return logp_logprob(rv, value, **kwargs) |
| 73 | + return _logprob_helper(rv, value, **kwargs) |
65 | 74 | except NotImplementedError:
|
66 | 75 | fgraph, _, _ = construct_ir_fgraph({rv: value})
|
67 | 76 | [(ir_rv, ir_value)] = fgraph.preserve_rv_mappings.rv_values.items()
|
68 |
| - return logp_logprob(ir_rv, ir_value, **kwargs) |
| 77 | + return _logprob_helper(ir_rv, ir_value, **kwargs) |
| 78 | + |
| 79 | + |
| 80 | +def logcdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: |
| 81 | + """Create a graph for the log-CDF of a Random Variable.""" |
| 82 | + value = pt.as_tensor_variable(value, dtype=rv.dtype) |
| 83 | + return _logcdf_helper(rv, value, **kwargs) |
| 84 | + |
| 85 | + |
| 86 | +def icdf(rv: TensorVariable, value: TensorLike, **kwargs) -> TensorVariable: |
| 87 | + """Create a graph for the inverse CDF of a Random Variable.""" |
| 88 | + value = pt.as_tensor_variable(value) |
| 89 | + return _icdf_helper(rv, value, **kwargs) |
69 | 90 |
|
70 | 91 |
|
71 | 92 | def factorized_joint_logprob(
|
|
0 commit comments