5555
5656import pymc as pm
5757
58- from pymc .logprob .basic import factorized_joint_logprob , icdf , joint_logp , logcdf , logp
58+ from pymc .logprob .basic import (
59+ conditional_logp ,
60+ icdf ,
61+ logcdf ,
62+ logp ,
63+ transformed_conditional_logp ,
64+ )
5965from pymc .logprob .transforms import LogTransform
6066from pymc .logprob .utils import rvs_to_value_vars , walk_model
6167from pymc .pytensorf import replace_rvs_by_values
@@ -68,7 +74,7 @@ def test_factorized_joint_logprob_basic():
6874 a .name = "a"
6975 a_value_var = a .clone ()
7076
71- a_logp = factorized_joint_logprob ({a : a_value_var })
77+ a_logp = conditional_logp ({a : a_value_var })
7278 a_logp_comb = tuple (a_logp .values ())[0 ]
7379 a_logp_exp = logp (a , a_value_var )
7480
@@ -81,7 +87,7 @@ def test_factorized_joint_logprob_basic():
8187 sigma_value_var = sigma .clone ()
8288 y_value_var = Y .clone ()
8389
84- total_ll = factorized_joint_logprob ({Y : y_value_var , sigma : sigma_value_var })
90+ total_ll = conditional_logp ({Y : y_value_var , sigma : sigma_value_var })
8591 total_ll_combined = pt .add (* total_ll .values ())
8692
8793 # We need to replace the reference to `sigma` in `Y` with its value
@@ -106,7 +112,7 @@ def test_factorized_joint_logprob_basic():
106112 b_value_var = b .clone ()
107113 c_value_var = c .clone ()
108114
109- b_logp = factorized_joint_logprob ({a : a_value_var , b : b_value_var , c : c_value_var })
115+ b_logp = conditional_logp ({a : a_value_var , b : b_value_var , c : c_value_var })
110116 b_logp_combined = pt .sum ([pt .sum (factor ) for factor in b_logp .values ()])
111117
112118 # There shouldn't be any `RandomVariable`s in the resulting graph
@@ -125,7 +131,7 @@ def test_factorized_joint_logprob_multi_obs():
125131 a_val = a .clone ()
126132 b_val = b .clone ()
127133
128- logp_res = factorized_joint_logprob ({a : a_val , b : b_val })
134+ logp_res = conditional_logp ({a : a_val , b : b_val })
129135 logp_res_combined = pt .add (* logp_res .values ())
130136 logp_exp = logp (a , a_val ) + logp (b , b_val )
131137
@@ -137,8 +143,8 @@ def test_factorized_joint_logprob_multi_obs():
137143 x_val = x .clone ()
138144 y_val = y .clone ()
139145
140- logp_res = factorized_joint_logprob ({x : x_val , y : y_val })
141- exp_logp = factorized_joint_logprob ({x : x_val , y : y_val })
146+ logp_res = conditional_logp ({x : x_val , y : y_val })
147+ exp_logp = conditional_logp ({x : x_val , y : y_val })
142148 logp_res_comb = pt .sum ([pt .sum (factor ) for factor in logp_res .values ()])
143149 exp_logp_comb = pt .sum ([pt .sum (factor ) for factor in exp_logp .values ()])
144150
@@ -155,7 +161,7 @@ def test_factorized_joint_logprob_diff_dims():
155161 y_vv = y .clone ()
156162 y_vv .name = "y"
157163
158- logp = factorized_joint_logprob ({x : x_vv , y : y_vv })
164+ logp = conditional_logp ({x : x_vv , y : y_vv })
159165 logp_combined = pt .sum ([pt .sum (factor ) for factor in logp .values ()])
160166
161167 M_val = np .random .normal (size = (10 , 3 ))
@@ -181,7 +187,7 @@ def test_incsubtensor_original_values_output_dict():
181187 rv = pt .set_subtensor (base_rv [0 ], 5 )
182188 vv = rv .clone ()
183189
184- logp_dict = factorized_joint_logprob ({rv : vv })
190+ logp_dict = conditional_logp ({rv : vv })
185191 assert vv in logp_dict
186192
187193
@@ -194,14 +200,14 @@ def test_persist_inputs():
194200 beta_vv = beta_rv .type ()
195201 y_vv = Y_rv .clone ()
196202
197- logp = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv })
203+ logp = conditional_logp ({beta_rv : beta_vv , Y_rv : y_vv })
198204 logp_combined = pt .sum ([pt .sum (factor ) for factor in logp .values ()])
199205
200206 assert x in ancestors ([logp_combined ])
201207
202208 # Make sure we don't clone value variables when they're graphs.
203209 y_vv_2 = y_vv * 2
204- logp_2 = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv_2 })
210+ logp_2 = conditional_logp ({beta_rv : beta_vv , Y_rv : y_vv_2 })
205211 logp_2_combined = pt .sum ([pt .sum (factor ) for factor in logp_2 .values ()])
206212
207213 assert y_vv in ancestors ([logp_2_combined ])
@@ -210,7 +216,7 @@ def test_persist_inputs():
210216 # Even when they are random
211217 y_vv = pt .random .normal (name = "y_vv2" )
212218 y_vv_2 = y_vv * 2
213- logp_2 = factorized_joint_logprob ({beta_rv : beta_vv , Y_rv : y_vv_2 })
219+ logp_2 = conditional_logp ({beta_rv : beta_vv , Y_rv : y_vv_2 })
214220 logp_2_combined = pt .sum ([pt .sum (factor ) for factor in logp_2 .values ()])
215221
216222 assert y_vv in ancestors ([logp_2_combined ])
@@ -224,11 +230,11 @@ def test_warn_random_found_factorized_joint_logprob():
224230 y_vv = y_rv .clone ()
225231
226232 with pytest .warns (UserWarning , match = "Random variables detected in the logp graph: {x}" ):
227- factorized_joint_logprob ({y_rv : y_vv })
233+ conditional_logp ({y_rv : y_vv })
228234
229235 with warnings .catch_warnings ():
230236 warnings .simplefilter ("error" )
231- factorized_joint_logprob ({y_rv : y_vv }, warn_missing_rvs = False )
237+ conditional_logp ({y_rv : y_vv }, warn_missing_rvs = False )
232238
233239
234240def test_multiple_rvs_to_same_value_raises ():
@@ -237,9 +243,9 @@ def test_multiple_rvs_to_same_value_raises():
237243 x = x_rv1 .type ()
238244 x .name = "x"
239245
240- msg = "More than one logprob factor was assigned to the value var x"
246+ msg = "More than one logprob term was assigned to the value var x"
241247 with pytest .raises (ValueError , match = msg ):
242- factorized_joint_logprob ({x_rv1 : x , x_rv2 : x })
248+ conditional_logp ({x_rv1 : x , x_rv2 : x })
243249
244250
245251def test_joint_logp_basic ():
@@ -259,7 +265,7 @@ def test_joint_logp_basic():
259265
260266 c_value_var = m .rvs_to_values [c ]
261267
262- (b_logp ,) = joint_logp (
268+ (b_logp ,) = transformed_conditional_logp (
263269 (b ,),
264270 rvs_to_values = m .rvs_to_values ,
265271 rvs_to_transforms = m .rvs_to_transforms ,
@@ -304,7 +310,7 @@ def test_joint_logp_incsubtensor(indices, size):
304310 a_idx_value_var = a_idx .type ()
305311 a_idx_value_var .name = "a_idx_value"
306312
307- a_idx_logp = joint_logp (
313+ a_idx_logp = transformed_conditional_logp (
308314 (a_idx ,),
309315 rvs_to_values = {a_idx : a_value_var },
310316 rvs_to_transforms = {},
0 commit comments