|
62 | 62 | joint_logp, |
63 | 63 | ) |
64 | 64 | from pymc.logprob.utils import rvs_to_value_vars, walk_model |
65 | | -from pymc.tests.helpers import assert_no_rvs, select_by_precision |
| 65 | +from pymc.tests.helpers import assert_no_rvs |
66 | 66 | from pymc.tests.logprob.utils import joint_logprob |
67 | 67 |
|
68 | 68 |
|
@@ -409,64 +409,6 @@ def test_joint_logp_incsubtensor(indices, size): |
409 | 409 | np.testing.assert_almost_equal(logp_vals, exp_obs_logps) |
410 | 410 |
|
411 | 411 |
|
412 | | -def test_joint_logp_subtensor(): |
413 | | - """Make sure we can compute a log-likelihood for ``Y[I]`` where ``Y`` and ``I`` are random variables.""" |
414 | | - |
415 | | - size = 5 |
416 | | - |
417 | | - mu_base = pm.floatX(np.power(10, np.arange(np.prod(size)))).reshape(size) |
418 | | - mu = np.stack([mu_base, -mu_base]) |
419 | | - sigma = 0.001 |
420 | | - rng = pytensor.shared(np.random.RandomState(232), borrow=True) |
421 | | - |
422 | | - A_rv = pm.Normal.dist(mu, sigma, rng=rng) |
423 | | - A_rv.name = "A" |
424 | | - |
425 | | - p = 0.5 |
426 | | - |
427 | | - I_rv = pm.Bernoulli.dist(p, size=size, rng=rng) |
428 | | - I_rv.name = "I" |
429 | | - |
430 | | - A_idx = A_rv[I_rv, at.ogrid[A_rv.shape[-1] :]] |
431 | | - |
432 | | - assert isinstance(A_idx.owner.op, (Subtensor, AdvancedSubtensor, AdvancedSubtensor1)) |
433 | | - |
434 | | - A_idx_value_var = A_idx.type() |
435 | | - A_idx_value_var.name = "A_idx_value" |
436 | | - |
437 | | - I_value_var = I_rv.type() |
438 | | - I_value_var.name = "I_value" |
439 | | - |
440 | | - A_idx_logps = joint_logp( |
441 | | - (A_idx, I_rv), |
442 | | - rvs_to_values={A_idx: A_idx_value_var, I_rv: I_value_var}, |
443 | | - rvs_to_transforms={}, |
444 | | - rvs_to_total_sizes={}, |
445 | | - ) |
446 | | - A_idx_logp = at.add(*A_idx_logps) |
447 | | - |
448 | | - logp_vals_fn = pytensor.function([A_idx_value_var, I_value_var], A_idx_logp) |
449 | | - |
450 | | - # The compiled graph should not contain any `RandomVariables` |
451 | | - assert_no_rvs(logp_vals_fn.maker.fgraph.outputs[0]) |
452 | | - |
453 | | - decimals = select_by_precision(float64=6, float32=4) |
454 | | - |
455 | | - for i in range(10): |
456 | | - bern_sp = sp.bernoulli(p) |
457 | | - I_value = bern_sp.rvs(size=size).astype(I_rv.dtype) |
458 | | - |
459 | | - norm_sp = sp.norm(mu[I_value, np.ogrid[mu.shape[1] :]], sigma) |
460 | | - A_idx_value = norm_sp.rvs().astype(A_idx.dtype) |
461 | | - |
462 | | - exp_obs_logps = norm_sp.logpdf(A_idx_value) |
463 | | - exp_obs_logps += bern_sp.logpmf(I_value) |
464 | | - |
465 | | - logp_vals = logp_vals_fn(A_idx_value, I_value) |
466 | | - |
467 | | - np.testing.assert_almost_equal(logp_vals, exp_obs_logps, decimal=decimals) |
468 | | - |
469 | | - |
470 | 412 | def test_logp_helper(): |
471 | 413 | value = at.vector("value") |
472 | 414 | x = pm.Normal.dist(0, 1) |
|
0 commit comments