Skip to content

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Jan 24, 2023

This PR extends the number of RandomVariables supported in the JAX backend, with two strategies:

  1. Rewrites that convert RVS to equivalent expressions that use RVs supported natively by JAX
  2. Optional dependency on NumPyro routines for more complex expressions

The rewrites from 1. may be reused with other backends later (e.g, Numba). This seems like a cleaner solution than implementing the same logic in the different dispatched functions of each backend.

)
assert test_res.pvalue > 0.1
assert not np.isnan(test_res.statistic)
assert test_res.pvalue > 0.01
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The old criteria was a bit too loose. With two tests per RV it had a 81% under the null hypothesis

@ricardoV94 ricardoV94 force-pushed the jax_rvs branch 3 times, most recently from 8853a41 to ad3b16a Compare January 24, 2023 17:37
@codecov-commenter
Copy link

Codecov Report

Merging #200 (caebeff) into main (8051ffb) will increase coverage by 0.11%.
The diff coverage is 93.33%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #200      +/-   ##
==========================================
+ Coverage   79.95%   80.06%   +0.11%     
==========================================
  Files         170      170              
  Lines       44856    45185     +329     
  Branches     9498     9608     +110     
==========================================
+ Hits        35863    36176     +313     
- Misses       6780     6794      +14     
- Partials     2213     2215       +2     
Impacted Files Coverage Δ
pytensor/link/jax/dispatch/random.py 96.07% <81.81%> (-3.93%) ⬇️
pytensor/link/jax/dispatch/extra_ops.py 95.77% <100.00%> (+0.25%) ⬆️
pytensor/tensor/random/rewriting/jax.py 100.00% <100.00%> (ø)
pytensor/tensor/nlinalg.py 97.84% <0.00%> (-0.73%) ⬇️
pytensor/link/c/cmodule.py 51.54% <0.00%> (-0.34%) ⬇️
pytensor/scalar/math.py 85.00% <0.00%> (-0.30%) ⬇️
pytensor/graph/rewriting/basic.py 64.36% <0.00%> (-0.14%) ⬇️
pytensor/tensor/inplace.py 100.00% <0.00%> (ø)
pytensor/compile/compiledir.py 0.00% <0.00%> (ø)
pytensor/link/numba/dispatch/nlinalg.py 100.00% <0.00%> (ø)
... and 33 more

Copy link
Member

@michaelosthege michaelosthege left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, but don't trust my judgment on these distribution rewrites that I'm unfamiliar with:

  • GeometricRV
  • NegBinomialRV
  • ChiSquaredRV
  • GenGammaRV
  • WaldRV
  • BetaBinomialRV

@ricardoV94
Copy link
Member Author

LGTM, but don't trust my judgment on these distribution rewrites that I'm unfamiliar with:

* GeometricRV

* NegBinomialRV

* ChiSquaredRV

* GenGammaRV

* WaldRV

* BetaBinomialRV

Fair enough. Those are just straightforward equivalences one can find on Wikipedia / direct translations from the Scipy and Numpy implementations.

The samples are tested against the scipy CDF for the continuous distributions and I tested the first two moments (meand and std) of the discrete distributions against the moment references found on Wikipedia

@ricardoV94 ricardoV94 merged commit 51210c3 into pymc-devs:main Jan 26, 2023
@ricardoV94 ricardoV94 changed the title Extend support of RandomVariables in JAX backend Extend support of RandomVariables in JAX backend Feb 15, 2023
@ricardoV94 ricardoV94 deleted the jax_rvs branch June 21, 2023 08:56
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants