diff --git a/pymc/logprob/transforms.py b/pymc/logprob/transforms.py index 82aac0fa88..46b3a163c4 100644 --- a/pymc/logprob/transforms.py +++ b/pymc/logprob/transforms.py @@ -374,12 +374,15 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa else: input_logprob = logprob(measurable_input, backward_value) - if input_logprob.ndim < value.ndim: - # Do we just need to sum the jacobian terms across the support dims? - raise NotImplementedError("Transform of multivariate RVs not implemented") - jacobian = op.transform_elemwise.log_jac_det(value, *other_inputs) + if input_logprob.ndim < value.ndim: + # For multivariate variables, the Jacobian is diagonal. + # We can get the right result by summing the last dimensions + # of `transform_elemwise.log_jac_det` + ndim_supp = value.ndim - input_logprob.ndim + jacobian = jacobian.sum(axis=tuple(range(-ndim_supp, 0))) + # The jacobian is used to ensure a value in the supported domain was provided return pt.switch(pt.isnan(jacobian), -np.inf, input_logprob + jacobian) @@ -674,7 +677,7 @@ def backward(self, value, *inputs): def log_jac_det(self, value, *inputs): scale = self.transform_args_fn(*inputs) - return -pt.log(pt.abs(scale)) + return -pt.log(pt.abs(pt.broadcast_to(scale, value.shape))) class LogTransform(RVTransform): @@ -892,7 +895,8 @@ def log_jac_det(self, value, *inputs): det = 0.0 for det_ in det_list: if det_.ndim > ndim0: - det += det_.sum(axis=-1) + ndim_diff = det_.ndim - ndim0 + det += det_.sum(axis=tuple(range(-ndim_diff, 0))) else: det += det_ return det diff --git a/tests/logprob/test_transforms.py b/tests/logprob/test_transforms.py index 29c8dc0ea7..5e224afabe 100644 --- a/tests/logprob/test_transforms.py +++ b/tests/logprob/test_transforms.py @@ -626,8 +626,8 @@ def test_chained_transform(): log_jac_det = ch.log_jac_det(x_val_forward, *x.owner.inputs, scale, loc) assert np.isclose( - log_jac_det.eval(), - -np.log(scale) - np.sum(np.log(x_val_forward - loc)), + pt.sum(log_jac_det).eval(), + np.sum(-np.log(scale) - np.log(x_val_forward - loc)), ) @@ -964,3 +964,28 @@ def scan_step(prev_innov): "innov": np.full((4,), -0.5), } np.testing.assert_allclose(logp_fn(**test_point), ref_logp_fn(**test_point)) + + +@pytest.mark.parametrize("shift", [1.5, np.array([-0.5, 1, 0.3])]) +@pytest.mark.parametrize("scale", [2.0, np.array([1.5, 3.3, 1.0])]) +def test_multivariate_transform(shift, scale): + mu = np.array([0, 0.9, -2.1]) + cov = np.array([[1, 0, 0.9], [0, 1, 0], [0.9, 0, 1]]) + x_rv_raw = pt.random.multivariate_normal(mu, cov=cov) + x_rv = shift + x_rv_raw * scale + x_rv.name = "x" + + x_vv = x_rv.clone() + logp = factorized_joint_logprob({x_rv: x_vv})[x_vv] + assert_no_rvs(logp) + + x_vv_test = np.array([5.0, 4.9, -6.3]) + scale_mat = scale * np.eye(x_vv_test.shape[0]) + np.testing.assert_almost_equal( + logp.eval({x_vv: x_vv_test}), + sp.stats.multivariate_normal.logpdf( + x_vv_test, + shift + mu * scale, + scale_mat @ cov @ scale_mat.T, + ), + )