
Conversation

@ricardoV94 (Member) commented Feb 12, 2025

This PR speeds up the Python (default) implementation of multivariate_normal by at least 10x-100x (the latter in the case of batch parameters).

It also allows specifying the method of covariance decomposition, which defaults to cholesky for performance. The method is also used in the JAX dispatch, but not in the Numba implementation (because we don't have those yet, right @jessegrabowski?).
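
For illustration, the decomposition can be requested per call; the snippet below assumes the keyword mirrors numpy's Generator.multivariate_normal naming (method="cholesky" / "svd" / "eigh"):

import pytensor.tensor as pt
import numpy as np

# illustrative: request an SVD-based factorization instead of the cholesky default
rv = pt.random.multivariate_normal(np.zeros(3), cov=np.eye(3), method="svd")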

Also removed the dumb defaults, which closes #833

import pytensor
import pytensor.tensor as pt
import numpy as np

rv = pt.random.multivariate_normal([0, 0, 0], cov=np.eye(3))
rng = rv.owner.inputs[0]
next_rng = rv.owner.outputs[0]

fn = pytensor.function([], rv, updates={rng: next_rng})

%timeit fn()
# Before PR:
# 335 μs ± 78.5 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
# After PR:
# 32 μs ± 3.32 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Compared to numpy:

zeros = np.zeros(3)
eye = np.eye(3)
rng = np.random.default_rng()

# Default method uses SVD, so unsurprisingly slower
%timeit rng.multivariate_normal(zeros, eye)
# 90.7 μs ± 140 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

%timeit rng.multivariate_normal(zeros, eye, method="cholesky")
# 19.1 μs ± 240 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

With batch mean (numpy doesn't support it):

rv = pt.random.multivariate_normal(np.random.normal(size=(100, 3)), cov=np.eye(3))
rng = rv.owner.inputs[0]
next_rng = rv.owner.outputs[0]

fn = pytensor.function([], rv, updates={rng: next_rng})

%timeit fn()
# Before PR:
# 54.1 ms ± 3.88 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# After PR:
# 42.5 μs ± 3.29 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

With a batch covariance, albeit a trivial one (numpy doesn't support it):

rv = pt.random.multivariate_normal(np.zeros(3), cov=np.broadcast_to(np.eye(3), (100, 3, 3)))
rng = rv.owner.inputs[0]
next_rng = rv.owner.outputs[0]

fn = pytensor.function([], rv, updates={rng: next_rng})

%timeit fn()
# Before PR:
# 30.1 ms ± 4.39 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
# After PR:
# 60.2 μs ± 763 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
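
The batched cases benefit from the same idea, since np.linalg.cholesky and matmul broadcast over leading dimensions, so no Python loop over the batch is needed. A rough sketch of that kind of vectorization (illustrative, not the exact implementation in this PR):

import numpy as np

rng = np.random.default_rng()
means = rng.normal(size=(100, 3))               # batch of means
covs = np.broadcast_to(np.eye(3), (100, 3, 3))  # batch of covariances

L = np.linalg.cholesky(covs)                    # batched factorization, shape (100, 3, 3)
z = rng.standard_normal(means.shape)            # shape (100, 3)
draws = means + (L @ z[..., None])[..., 0]      # one broadcasted matmul for all draws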

In the long term we should still go with a symbolic graph, as discussed in #1115, so that other rewrites can happen on top of it, such as avoiding a useless Cholesky when the covariance is built symbolically from a Cholesky factor to begin with (as in most PyMC models).

However, that requires some coordination with PyMC, as that object could no longer be a RandomVariable Op. Also, some nice rewrites we have now may not work with the symbolic representation.
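
The "useless cholesky" situation looks roughly like this (hypothetical user code, shown only to illustrate the rewrite opportunity):

import pytensor.tensor as pt

chol = pt.dmatrix("chol")                  # user already parametrizes with a Cholesky factor
cov = chol @ chol.T                        # ...and builds the covariance from it
rv = pt.random.multivariate_normal(pt.zeros(3), cov=cov)
# the Op re-factorizes cov internally even though chol is already available;
# a symbolic graph as in #1115 would let a rewrite sample with chol directly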


📚 Documentation preview 📚: https://pytensor--1203.org.readthedocs.build/en/1203/

Comment on lines -899 to -902
if mean is None:
mean = np.array([0.0], dtype=dtype)
if cov is None:
cov = np.array([[1.0]], dtype=dtype)
@ricardoV94 (Member, Author) Feb 12, 2025

These were dumb defaults, just removed them: #833

It made sense when both were None, perhaps, but not just one of them. Anyway, numpy doesn't provide defaults either. In PyMC we do, because we are not trying to mimic the numpy API there.
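
Concretely, the removed defaults meant that a bare call silently produced a 1-dimensional standard normal:

# before this PR: equivalent to mean=[0.0], cov=[[1.0]] (defaults now removed)
rv = pt.random.multivariate_normal()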

@jessegrabowski (Member)

Numba supports SVD via np.linalg.svd, but you're not allowed to set compute_uv=False. We have to set it to True for this application, so I think it should work. Cholesky and eig are also supported.

@ricardoV94 (Member, Author)

> Numba supports SVD via np.linalg.svd, but you're not allowed to set compute_uv=False. We have to set it to True for this application, so I think it should work. Cholesky and eig are also supported.

I thought it didn't, my bad. Numba now also supports the different modes.
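
For reference, a sketch of what the SVD route could look like on the Numba side (illustrative only; numba's np.linalg.svd overload returns the full decomposition, so U and s are available even though compute_uv cannot be turned off):

import numba
import numpy as np

@numba.njit
def svd_factor(cov):
    # compute_uv=False is not accepted here, so take the full decomposition
    u, s, _vt = np.linalg.svd(cov)
    # for a symmetric PSD covariance, u * sqrt(s) is a valid factor:
    # (u * np.sqrt(s)) @ (u * np.sqrt(s)).T == u @ diag(s) @ u.T == cov
    return u * np.sqrt(s)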

@ricardoV94 ricardoV94 changed the title Faster python implementation of multivariate_normal Speedup implementation of multivariate_normal and allow method of covariance decomposition Feb 12, 2025
@ricardoV94 ricardoV94 force-pushed the speedup_mvnormal branch 2 times, most recently from e295c66 to ceecfb0 on February 12, 2025 16:58

codecov bot commented Feb 12, 2025

Codecov Report

Attention: Patch coverage is 74.00000% with 13 lines in your changes missing coverage. Please review.

Project coverage is 82.25%. Comparing base (7411a08) to head (e2fb8d1).
Report is 178 commits behind head on main.

Files with missing lines                  Patch %   Lines
pytensor/link/numba/dispatch/random.py    8.33%     11 Missing ⚠️
pytensor/tensor/random/basic.py           93.10%    1 Missing and 1 partial ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1203      +/-   ##
==========================================
- Coverage   82.26%   82.25%   -0.02%     
==========================================
  Files         186      186              
  Lines       47962    47981      +19     
  Branches     8630     8630              
==========================================
+ Hits        39456    39465       +9     
- Misses       6347     6356       +9     
- Partials     2159     2160       +1     
Files with missing lines                  Coverage Δ
pytensor/link/jax/dispatch/random.py      93.70% <100.00%> (+0.20%) ⬆️
pytensor/tensor/random/basic.py           98.84% <93.10%> (-0.39%) ⬇️
pytensor/link/numba/dispatch/random.py    57.20% <8.33%> (-1.78%) ⬇️

@jessegrabowski (Member) left a comment

Looks great, just one random docstring looked wrong.

@ricardoV94 ricardoV94 merged commit 2aecb95 into pymc-devs:main Feb 13, 2025
63 of 64 checks passed

Successfully merging this pull request may close these issues:

Default MvNormal covariance doesn't make sense