 from collections.abc import Callable, Sequence
 from datetime import datetime
 from functools import partial
+from types import ModuleType
 from typing import Any, Literal

 import arviz as az
...

 from arviz.data.base import make_attrs
 from jax.lax import scan
+from numpy.typing import ArrayLike
 from pytensor.compile import SharedVariable, Supervisor, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
@@ -120,7 +122,7 @@ def _replace_shared_variables(graph: list[TensorVariable]) -> list[TensorVariable]:
 def get_jaxified_graph(
     inputs: list[TensorVariable] | None = None,
     outputs: list[TensorVariable] | None = None,
-) -> list[TensorVariable]:
+) -> Callable[[list[TensorVariable]], list[TensorVariable]]:
     """Compile a PyTensor graph into an optimized JAX function."""
     graph = _replace_shared_variables(outputs) if outputs is not None else None

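The corrected return annotation makes the contract explicit: `get_jaxified_graph` hands back a compiled JAX callable, not a list of variables. A minimal sketch of what that looks like at a call site (the one-variable model is hypothetical; the printed value is just the standard-normal log-density at 0):

```python
import numpy as np
import pymc as pm
from pymc.sampling.jax import get_jaxified_graph

with pm.Model() as m:
    pm.Normal("x")

# Returns a JAX-callable taking the graph inputs positionally and
# returning a sequence with one array per requested output.
fn = get_jaxified_graph(inputs=m.value_vars, outputs=[m.logp()])
print(fn(np.array(0.0)))  # ~ [-0.9189] for a standard normal at 0
```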
@@ -143,13 +145,13 @@ def get_jaxified_graph(
     return jax_funcify(fgraph)


-def get_jaxified_logp(model: Model, negative_logp=True) -> Callable:
+def get_jaxified_logp(model: Model, negative_logp: bool = True) -> Callable[[ArrayLike], jax.Array]:
     model_logp = model.logp()
     if not negative_logp:
         model_logp = -model_logp
     logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

-    def logp_fn_wrap(x):
+    def logp_fn_wrap(x: ArrayLike) -> jax.Array:
         return logp_fn(*x)[0]

     return logp_fn_wrap
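A usage sketch for the annotated wrapper (hypothetical two-variable model): it takes a single sequence with one array per value variable and returns a scalar `jax.Array`. Per the dispatch added later in this diff, the default sign gives blackjax its log-density, while `negative_logp=False` yields the negated version that numpyro treats as potential energy:

```python
import numpy as np
import pymc as pm
from pymc.sampling.jax import get_jaxified_logp

with pm.Model() as m:
    mu = pm.Normal("mu")
    pm.Normal("y", mu=mu, sigma=1.0, shape=3)

logp_fn = get_jaxified_logp(m)
# One array per value variable, in model.value_vars order.
print(logp_fn([np.array(0.0), np.zeros(3)]))  # scalar joint log-density
```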
@@ -210,23 +212,43 @@ def _get_batched_jittered_initial_points(
     chains: int,
     initvals: StartDict | Sequence[StartDict | None] | None,
     random_seed: RandomSeed,
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
     jitter: bool = True,
     jitter_max_retries: int = 10,
 ) -> np.ndarray | list[np.ndarray]:
-    """Get jittered initial point in format expected by NumPyro MCMC kernel.
+    """Get jittered initial point in the format expected by JAX MCMC kernels.
+
+    Parameters
+    ----------
+    logp_fn : Callable[[Sequence[np.ndarray]], np.ndarray], optional
+        Jaxified logp function.

     Returns
     -------
     out: list of ndarrays
         list with one item per variable and number of chains as batch dimension.
         Each item has shape `(chains, *var.shape)`
     """
+    if logp_fn is None:
+        eval_logp_initial_point = None
+
+    else:
+
+        def eval_logp_initial_point(point: dict[str, np.ndarray]) -> jax.Array:
+            """Wrap logp_fn to conform to _init_jitter logic.
+
+            Wraps the jaxified logp function to accept a dict of
+            {model_variable: np.array} key-value pairs.
+            """
+            return logp_fn(point.values())
+
     initial_points = _init_jitter(
         model,
         initvals,
         seeds=_get_seeds_per_chain(random_seed, chains),
         jitter=jitter,
         jitter_max_retries=jitter_max_retries,
+        logp_fn=eval_logp_initial_point,
     )
     initial_points_values = [list(initial_point.values()) for initial_point in initial_points]
     if chains == 1:
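A self-contained sketch (hypothetical one-variable model) of why the shim above works: `_init_jitter` hands candidate points around as `{value_var: array}` dicts, the jaxified function wants positional arrays, and a `dict_values` view unpacks like a sequence:

```python
import numpy as np
import pymc as pm
from pymc.sampling.jax import get_jaxified_logp

with pm.Model() as m:
    pm.Normal("mu")

logp_fn = get_jaxified_logp(m)
point = {"mu": np.array(0.1)}   # shape of what _init_jitter passes around
print(logp_fn(point.values()))  # dict_values unpacks as a sequence
```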
@@ -235,7 +257,7 @@ def _get_batched_jittered_initial_points(


 def _blackjax_inference_loop(
-    seed, init_position, logprob_fn, draws, tune, target_accept, **adaptation_kwargs
+    seed, init_position, logp_fn, draws, tune, target_accept, **adaptation_kwargs
 ):
     import blackjax

@@ -251,13 +273,13 @@ def _blackjax_inference_loop(

     adapt = blackjax.window_adaptation(
         algorithm=algorithm,
-        logdensity_fn=logprob_fn,
+        logdensity_fn=logp_fn,
         target_acceptance_rate=target_accept,
         adaptation_info_fn=get_filter_adapt_info_fn(),
         **adaptation_kwargs,
     )
     (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
-    kernel = algorithm(logprob_fn, **tuned_params).step
+    kernel = algorithm(logp_fn, **tuned_params).step

     def _one_step(state, xs):
         _, rng_key = xs
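For context on the renamed parameter, here is a hedged, simplified sketch of the full loop this hunk belongs to (the real function also threads an adaptation-info filter and a progress bar): window adaptation tunes a NUTS kernel, then `jax.lax.scan` drives the post-tuning draws:

```python
import blackjax
import jax

def inference_loop(seed, init_position, logp_fn, draws, tune, target_accept):
    # Stage 1: window adaptation tunes step size and mass matrix.
    adapt = blackjax.window_adaptation(
        algorithm=blackjax.nuts,
        logdensity_fn=logp_fn,
        target_acceptance_rate=target_accept,
    )
    (last_state, tuned_params), _ = adapt.run(seed, init_position, num_steps=tune)
    kernel = blackjax.nuts(logp_fn, **tuned_params).step

    # Stage 2: a scan over per-draw RNG keys collects positions and stats.
    def one_step(state, rng_key):
        state, info = kernel(rng_key, state)
        return state, (state.position, info)

    keys = jax.random.split(seed, draws)
    _, (samples, infos) = jax.lax.scan(one_step, last_state, keys)
    return samples, infos
```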
@@ -288,67 +310,51 @@ def _sample_blackjax_nuts(
     tune: int,
     draws: int,
     chains: int,
-    chain_method: str | None,
+    chain_method: Literal["parallel", "vectorized"],
     progressbar: bool,
     random_seed: int,
-    initial_points,
-    nuts_kwargs,
-) -> az.InferenceData:
+    initial_points: np.ndarray | list[np.ndarray],
+    nuts_kwargs: dict[str, Any],
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
     """
     Draw samples from the posterior using the NUTS method from the ``blackjax`` library.

     Parameters
     ----------
-    draws : int, default 1000
-        The number of samples to draw. The number of tuned samples are discarded by
-        default.
-    tune : int, default 1000
+    model : Model
+        Model to sample from. The model needs to have free random variables.
+    target_accept : float in [0, 1]
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    tune : int
         Number of iterations to tune. Samplers adjust the step sizes, scalings or
         similar during tuning. Tuning samples will be drawn in addition to the number
         specified in the ``draws`` argument.
-    chains : int, default 4
+    draws : int
+        The number of samples to draw. Tuning samples are discarded by default.
+    chains : int
         The number of chains to sample.
-    target_accept : float in [0, 1].
-        The step size is tuned such that we approximate this acceptance rate. Higher
-        values like 0.9 or 0.95 often work better for problematic posteriors.
-    random_seed : int, RandomState or Generator, optional
+    chain_method : "parallel" or "vectorized"
+        Specify how samples should be drawn.
+    progressbar : bool
+        Whether to show a progress bar during sampling.
+    random_seed : int, RandomState or Generator
         Random seed used by the sampling steps.
-    initvals: StartDict or Sequence[Optional[StartDict]], optional
-        Initial values for random variables provided as a dictionary (or sequence of
-        dictionaries) mapping the random variable (by name or reference) to desired
-        starting values.
-    jitter: bool, default True
-        If True, add jitter to initial points.
-    model : Model, optional
-        Model to sample from. The model needs to have free random variables. When inside
-        a ``with`` model context, it defaults to that model, otherwise the model must be
-        passed explicitly.
-    var_names : sequence of str, optional
-        Names of variables for which to compute the posterior samples. Defaults to all
-        variables in the posterior.
-    keep_untransformed : bool, default False
-        Include untransformed variables in the posterior samples. Defaults to False.
-    chain_method : str, default "parallel"
-        Specify how samples should be drawn. The choices include "parallel", and
-        "vectorized".
-    postprocessing_backend: Optional[Literal["cpu", "gpu"]], default None,
-        Specify how postprocessing should be computed. gpu or cpu
-    postprocessing_vectorize: Literal["vmap", "scan"], default "scan"
-        How to vectorize the postprocessing: vmap or sequential scan
-    idata_kwargs : dict, optional
-        Keyword arguments for :func:`arviz.from_dict`. It also accepts a boolean as
-        value for the ``log_likelihood`` key to indicate that the pointwise log
-        likelihood should not be included in the returned object. Values for
-        ``observed_data``, ``constant_data``, ``coords``, and ``dims`` are inferred from
-        the ``model`` argument if not provided in ``idata_kwargs``. If ``coords`` and
-        ``dims`` are provided, they are used to update the inferred dictionaries.
+    initial_points : np.ndarray or list[np.ndarray]
+        Initial point(s) from which the sampler starts sampling.
+    nuts_kwargs : dict
+        Keyword arguments for the blackjax NUTS sampler.
+    logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
+        Jaxified logp function. If not passed, it is created anew.

     Returns
     -------
-    InferenceData
-        ArviZ ``InferenceData`` object that contains the posterior samples, together
-        with their respective sample stats and pointwise log likeihood values (unless
-        skipped with ``idata_kwargs``).
+    raw_mcmc_samples
+        Data structure containing raw MCMC samples.
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats.
+    blackjax : ModuleType
+        The imported ``blackjax`` module.
     """
     import blackjax

@@ -365,15 +371,16 @@ def _sample_blackjax_nuts(
     if chains == 1:
         initial_points = [np.stack(init_state) for init_state in zip(initial_points)]

-    logprob_fn = get_jaxified_logp(model)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model)

     seed = jax.random.PRNGKey(random_seed)
     keys = jax.random.split(seed, chains)

     nuts_kwargs["progress_bar"] = progressbar
     get_posterior_samples = partial(
         _blackjax_inference_loop,
-        logprob_fn=logprob_fn,
+        logp_fn=logp_fn,
         tune=tune,
         draws=draws,
         target_accept=target_accept,
@@ -385,7 +392,7 @@ def _sample_blackjax_nuts(


 # Adopted from arviz numpyro extractor
-def _numpyro_stats_to_dict(posterior):
+def _numpyro_stats_to_dict(posterior) -> dict[str, Any]:
     """Extract sample_stats from NumPyro posterior."""
     rename_key = {
         "potential_energy": "lp",
@@ -411,17 +418,58 @@ def _sample_numpyro_nuts(
     tune: int,
     draws: int,
     chains: int,
-    chain_method: str | None,
+    chain_method: Literal["parallel", "vectorized"],
     progressbar: bool,
     random_seed: int,
-    initial_points,
+    initial_points: np.ndarray | list[np.ndarray],
     nuts_kwargs: dict[str, Any],
-):
+    logp_fn: Callable[[ArrayLike], jax.Array] | None = None,
+) -> tuple[Any, dict[str, Any], ModuleType]:
+    """
+    Draw samples from the posterior using the NUTS method from the ``numpyro`` library.
+
+    Parameters
+    ----------
+    model : Model
+        Model to sample from. The model needs to have free random variables.
+    target_accept : float in [0, 1]
+        The step size is tuned such that we approximate this acceptance rate. Higher
+        values like 0.9 or 0.95 often work better for problematic posteriors.
+    tune : int
+        Number of iterations to tune. Samplers adjust the step sizes, scalings or
+        similar during tuning. Tuning samples will be drawn in addition to the number
+        specified in the ``draws`` argument.
+    draws : int
+        The number of samples to draw. Tuning samples are discarded by default.
+    chains : int
+        The number of chains to sample.
+    chain_method : "parallel" or "vectorized"
+        Specify how samples should be drawn.
+    progressbar : bool
+        Whether to show a progress bar during sampling.
+    random_seed : int, RandomState or Generator
+        Random seed used by the sampling steps.
+    initial_points : np.ndarray or list[np.ndarray]
+        Initial point(s) from which the sampler starts sampling.
+    nuts_kwargs : dict
+        Keyword arguments for the underlying numpyro NUTS sampler.
+    logp_fn : Callable[[ArrayLike], jax.Array], optional, default None
+        Jaxified logp function. If not passed, it is created anew.
+
+    Returns
+    -------
+    raw_mcmc_samples
+        Data structure containing raw MCMC samples.
+    sample_stats : dict[str, Any]
+        Dictionary containing sample stats.
+    numpyro : ModuleType
+        The imported ``numpyro`` module.
+    """
     import numpyro

     from numpyro.infer import MCMC, NUTS

-    logp_fn = get_jaxified_logp(model, negative_logp=False)
+    if logp_fn is None:
+        logp_fn = get_jaxified_logp(model, negative_logp=False)

     nuts_kwargs.setdefault("adapt_step_size", True)
     nuts_kwargs.setdefault("adapt_mass_matrix", True)
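As a reviewing aid, a hedged sketch of what the remainder of the real function does with these pieces (the bare names come from the enclosing signature, so this is a fragment, not a standalone script): NUTS is built from the potential function (`negative_logp=False` above flips the sign because numpyro expects a potential energy rather than a log-density), and MCMC runs from the batched initial points:

```python
from numpyro.infer import MCMC, NUTS

nuts_kernel = NUTS(
    potential_fn=logp_fn,              # potential = negative log-density
    target_accept_prob=target_accept,
    **nuts_kwargs,
)
mcmc = MCMC(
    nuts_kernel,
    num_warmup=tune,
    num_samples=draws,
    num_chains=chains,
    chain_method=chain_method,
    progress_bar=progressbar,
)
mcmc.run(jax.random.PRNGKey(random_seed), init_params=initial_points)
raw_mcmc_samples = mcmc.get_samples(group_by_chain=True)
```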
@@ -479,7 +527,7 @@ def sample_jax_nuts(
     nuts_kwargs: dict | None = None,
     progressbar: bool = True,
     keep_untransformed: bool = False,
-    chain_method: str = "parallel",
+    chain_method: Literal["parallel", "vectorized"] = "parallel",
     postprocessing_backend: Literal["cpu", "gpu"] | None = None,
     postprocessing_vectorize: Literal["vmap", "scan"] | None = None,
     postprocessing_chunks=None,
@@ -525,7 +573,7 @@ def sample_jax_nuts(
         If True, display a progressbar while sampling
     keep_untransformed : bool, default False
         Include untransformed variables in the posterior samples.
-    chain_method : str, default "parallel"
+    chain_method : Literal["parallel", "vectorized"], default "parallel"
         Specify how samples should be drawn. The choices include "parallel", and
         "vectorized".
     postprocessing_backend : Optional[Literal["cpu", "gpu"]], default None,
@@ -589,6 +637,15 @@ def sample_jax_nuts(
         get_default_varnames(filtered_var_names, include_transformed=keep_untransformed)
     )

+    if nuts_sampler == "numpyro":
+        sampler_fn = _sample_numpyro_nuts
+        logp_fn = get_jaxified_logp(model, negative_logp=False)
+    elif nuts_sampler == "blackjax":
+        sampler_fn = _sample_blackjax_nuts
+        logp_fn = get_jaxified_logp(model)
+    else:
+        raise ValueError(f"{nuts_sampler=} not recognized")
+
     (random_seed,) = _get_seeds_per_chain(random_seed, 1)

     initial_points = _get_batched_jittered_initial_points(
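Reviewer note: the dispatch block moves above `_get_batched_jittered_initial_points` so that the jaxified `logp_fn` is compiled once, with the sign convention each backend expects, and then reused both for scoring jittered initial points and inside the chosen sampler (via the new `logp_fn` parameter), instead of being compiled a second time.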
@@ -597,15 +654,9 @@ def sample_jax_nuts(
         initvals=initvals,
         random_seed=random_seed,
         jitter=jitter,
+        logp_fn=logp_fn,
     )

-    if nuts_sampler == "numpyro":
-        sampler_fn = _sample_numpyro_nuts
-    elif nuts_sampler == "blackjax":
-        sampler_fn = _sample_blackjax_nuts
-    else:
-        raise ValueError(f"{nuts_sampler=} not recognized")
-
     tic1 = datetime.now()
     raw_mcmc_samples, sample_stats, library = sampler_fn(
         model=model,
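Finally, an end-to-end usage sketch of the shared entry point with a hypothetical toy model (`sample_jax_nuts` is the function both backends dispatch through; in recent PyMC versions it also backs `pm.sample(nuts_sampler="numpyro")`):

```python
import pymc as pm
from pymc.sampling.jax import sample_jax_nuts

with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.Normal("obs", mu=mu, sigma=1.0, observed=[0.1, -0.3, 0.2])
    # One compiled logp_fn now serves both the jittered initial points
    # and the chosen sampler.
    idata = sample_jax_nuts(draws=500, tune=500, chains=2, nuts_sampler="numpyro")
```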