@@ -1756,6 +1756,7 @@ class TestMatrixNormal(BaseTestDistributionRandom):
17561756 "check_pymc_params_match_rv_op" ,
17571757 "check_draws" ,
17581758 "check_errors" ,
1759+ "check_random_variable_prior" ,
17591760 ]
17601761
17611762 def check_draws (self ):
@@ -1824,6 +1825,28 @@ def check_errors(self):
18241825 shape = 15 ,
18251826 )
18261827
1828+ def check_random_variable_prior (self ):
1829+ """
1830+ This test checks for shape correctness when using MatrixNormal distribution
1831+ with parameters as random variables.
1832+ Originally reported - https://github.com/pymc-devs/pymc/issues/3585
1833+ """
1834+ K = 3
1835+ D = 15
1836+ mu_0 = np .zeros ((D , K ))
1837+ lambd = 1.0
1838+ with pm .Model () as model :
1839+ sd_dist = pm .HalfCauchy .dist (beta = 2.5 , size = D )
1840+ packedL = pm .LKJCholeskyCov ("packedL" , eta = 2 , n = D , sd_dist = sd_dist , compute_corr = False )
1841+ L = pm .expand_packed_triangular (D , packedL , lower = True )
1842+ Sigma = pm .Deterministic ("Sigma" , L .dot (L .T )) # D x D covariance
1843+ mu = pm .MatrixNormal (
1844+ "mu" , mu = mu_0 , rowcov = (1 / lambd ) * Sigma , colcov = np .eye (K ), shape = (D , K )
1845+ )
1846+ prior = pm .sample_prior_predictive (2 , return_inferencedata = False )
1847+
1848+ assert prior ["mu" ].shape == (2 , D , K )
1849+
18271850
18281851class TestInterpolated (BaseTestDistributionRandom ):
18291852 def interpolated_rng_fn (self , size , mu , sigma , rng ):
@@ -2435,30 +2458,6 @@ def generate_shapes(include_params=False):
24352458 return data
24362459
24372460
2438- @pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
2439- def test_matrix_normal_random_with_random_variables ():
2440- """
2441- This test checks for shape correctness when using MatrixNormal distribution
2442- with parameters as random variables.
2443- Originally reported - https://github.com/pymc-devs/pymc/issues/3585
2444- """
2445- K = 3
2446- D = 15
2447- mu_0 = np .zeros ((D , K ))
2448- lambd = 1.0
2449- with pm .Model () as model :
2450- sd_dist = pm .HalfCauchy .dist (beta = 2.5 )
2451- packedL = pm .LKJCholeskyCov ("packedL" , eta = 2 , n = D , sd_dist = sd_dist )
2452- L = pm .expand_packed_triangular (D , packedL , lower = True )
2453- Sigma = pm .Deterministic ("Sigma" , L .dot (L .T )) # D x D covariance
2454- mu = pm .MatrixNormal (
2455- "mu" , mu = mu_0 , rowcov = (1 / lambd ) * Sigma , colcov = np .eye (K ), shape = (D , K )
2456- )
2457- prior = pm .sample_prior_predictive (2 )
2458-
2459- assert prior ["mu" ].shape == (2 , D , K )
2460-
2461-
24622461@pytest .mark .xfail (reason = "This distribution has not been refactored for v4" )
24632462class TestMvGaussianRandomWalk (SeededTest ):
24642463 @pytest .mark .parametrize (
0 commit comments