This is a proof of concept only.
rng-jax wraps JAX's stateless random number generation in a class implementing
the numpy.random.Generator interface.
>>> import rng_jax
>>> rng = rng_jax.Generator(42)  # same arguments as jax.random.key()
>>> rng.standard_normal(3)
Array([-0.5675502 ,  0.28439185, -0.9320608 ], dtype=float32)
>>> rng.standard_normal(3)
Array([ 0.67903334, -1.220606  ,  0.94670606], dtype=float32)

The Array API makes it possible to write array-agnostic Python
libraries. The rng-jax package makes it easy to extend this to random number
generation in NumPy and JAX. End users only need to provide a rng object, as
usual, which can be either a numpy.random.Generator or a rng_jax.Generator
instance wrapping JAX's stateless random number generation.
The rng_jax.Generator class works in the obvious way: it keeps track of the
JAX key and calls jax.random.split() before every random operation.
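The mechanism can be sketched in a few lines. This is a minimal illustration
under stated assumptions, not the package's actual source: the class name
SketchGenerator and its internals are hypothetical, and only standard_normal
is shown.

```python
import jax


class SketchGenerator:
    """Hypothetical stand-in for rng_jax.Generator, for illustration only."""

    def __init__(self, seed):
        # Assumption: the seed is passed straight to jax.random.key().
        self.key = jax.random.key(seed)

    def split(self):
        # Advance the stored key and hand out a fresh subkey.
        self.key, subkey = jax.random.split(self.key)
        return subkey

    def standard_normal(self, size=None):
        # Translate a NumPy-style size argument into a JAX shape tuple.
        if size is None:
            shape = ()
        elif isinstance(size, int):
            shape = (size,)
        else:
            shape = tuple(size)
        # Every draw consumes a new subkey, so repeated calls differ.
        return jax.random.normal(self.split(), shape)


rng = SketchGenerator(42)
first = rng.standard_normal(3)
second = rng.standard_normal(3)
```

Because the stored key is replaced on every call, first and second are
different draws, while a new SketchGenerator(42) reproduces the same sequence.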
The problem with a stateful RNG is that it cannot be passed into a compiled JAX
function. In practice, this is not usually an issue, since the goal of this
package is to work in tandem with the Array API: array-agnostic code is not
usually compiled at low level. Conversely, native JAX code usually expects a
key, anyway, not a rng_jax.Generator instance.
To interface with a native JAX function expecting a key, use the .split()
method to obtain a new random key and advance the internal state of the
generator:
>>> import jax
>>> rng = rng_jax.Generator(42)
>>> key = rng.split()
>>> jax.random.normal(key, (3,))
Array([-0.5675502 ,  0.28439185, -0.9320608 ], dtype=float32)
>>> key = rng.split()
>>> jax.random.normal(key, (3,))
Array([ 0.67903334, -1.220606  ,  0.94670606], dtype=float32)

Using the rng_jax.Generator class fully within a compiled JAX function
works without issue.
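As a rough sketch of what "fully within" means, the generator below is
created, used, and discarded inside the jitted function, so only arrays cross
the compilation boundary. SketchGenerator is again a hypothetical stand-in
written for this example, not the real rng_jax.Generator, and the function
name two_draws is made up.

```python
import jax


class SketchGenerator:
    """Hypothetical stand-in for rng_jax.Generator, for illustration only."""

    def __init__(self, seed):
        self.key = jax.random.key(seed)

    def standard_normal(self, n):
        # Split before the draw, keeping the new key as internal state.
        self.key, subkey = jax.random.split(self.key)
        return jax.random.normal(subkey, (n,))


@jax.jit
def two_draws(seed):
    # The stateful object lives only during tracing; its key updates are
    # ordinary Python attribute assignments on traced values.
    rng = SketchGenerator(seed)
    return rng.standard_normal(3), rng.standard_normal(3)


a, b = two_draws(42)
```

The function takes a plain seed and returns plain arrays, which is why the
stateful wrapper causes no trouble here.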