@@ -58,7 +58,7 @@ class IBaseTrace(ABC, Sized):
5858 varnames : List [str ]
5959 """Names of tracked variables."""
6060
61- sampler_vars : List [Dict [str , type ]]
61+ sampler_vars : List [Dict [str , Union [ type , np . dtype ] ]]
6262 """Sampler stats for each sampler."""
6363
6464 def __len__ (self ):
@@ -79,23 +79,27 @@ def get_values(self, varname: str, burn=0, thin=1) -> np.ndarray:
7979 """
8080 raise NotImplementedError ()
8181
82- def get_sampler_stats (self , stat_name : str , sampler_idx : Optional [int ] = None , burn = 0 , thin = 1 ):
82+ def get_sampler_stats (
83+ self , stat_name : str , sampler_idx : Optional [int ] = None , burn = 0 , thin = 1
84+ ) -> np .ndarray :
8385 """Get sampler statistics from the trace.
8486
8587 Parameters
8688 ----------
87- stat_name: str
88- sampler_idx: int or None
89- burn: int
90- thin: int
89+ stat_name : str
90+ Name of the stat to fetch.
91+ sampler_idx : int or None
92+ Index of the sampler to get the stat from.
93+ burn : int
94+ Draws to skip from the start.
95+ thin : int
96+ Stepsize for the slice.
9197
9298 Returns
9399 -------
94- If the `sampler_idx` is specified, return the statistic with
95- the given name in a numpy array. If it is not specified and there
96- is more than one sampler that provides this statistic, return
97- a numpy array of shape (m, n), where `m` is the number of
98- such samplers, and `n` is the number of samples.
100+ stats : np.ndarray
101+ If `sampler_idx` was specified, the shape should be `(draws,)`.
102+ Otherwise, the shape should be `(draws, samplers)`.
99103 """
100104 raise NotImplementedError ()
101105
@@ -220,23 +224,31 @@ def __getitem__(self, idx):
220224 except (ValueError , TypeError ): # Passed variable or variable name.
221225 raise ValueError ("Can only index with slice or integer" )
222226
223- def get_sampler_stats (self , stat_name , sampler_idx = None , burn = 0 , thin = 1 ):
227+ def get_sampler_stats (
228+ self , stat_name : str , sampler_idx : Optional [int ] = None , burn = 0 , thin = 1
229+ ) -> np .ndarray :
224230 """Get sampler statistics from the trace.
225231
232+ Note: This implementation attempts to squeeze object arrays into a consistent dtype,
233+ which can change their shape in hard-to-predict ways.
234+ See https://github.com/pymc-devs/pymc/issues/6207
235+
226236 Parameters
227237 ----------
228- stat_name: str
229- sampler_idx: int or None
230- burn: int
231- thin: int
238+ stat_name : str
239+ Name of the stat to fetch.
240+ sampler_idx : int or None
241+ Index of the sampler to get the stat from.
242+ burn : int
243+ Draws to skip from the start.
244+ thin : int
245+ Stepsize for the slice.
232246
233247 Returns
234248 -------
235- If the `sampler_idx` is specified, return the statistic with
236- the given name in a numpy array. If it is not specified and there
237- is more than one sampler that provides this statistic, return
238- a numpy array of shape (m, n), where `m` is the number of
239- such samplers, and `n` is the number of samples.
249+ stats : np.ndarray
250+ If `sampler_idx` was specified, the shape should be `(draws,)`.
251+ Otherwise, the shape should be `(draws, samplers)`.
240252 """
241253 if sampler_idx is not None :
242254 return self ._get_sampler_stats (stat_name , sampler_idx , burn , thin )
@@ -254,14 +266,16 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
254266
255267 if vals .dtype == np .dtype (object ):
256268 try :
257- vals = np .vstack (vals )
269+ vals = np .vstack (list ( vals ) )
258270 except ValueError :
259271 # Most likely due to non-identical shapes. Just stick with the object-array.
260272 pass
261273
262274 return vals
263275
264- def _get_sampler_stats (self , stat_name , sampler_idx , burn , thin ):
276+ def _get_sampler_stats (
277+ self , stat_name : str , sampler_idx : int , burn : int , thin : int
278+ ) -> np .ndarray :
265279 """Get sampler statistics."""
266280 raise NotImplementedError ()
267281
@@ -476,23 +490,34 @@ def get_sampler_stats(
476490 combine : bool = True ,
477491 chains : Optional [Union [int , Sequence [int ]]] = None ,
478492 squeeze : bool = True ,
479- ):
493+ ) -> Union [ List [ np . ndarray ], np . ndarray ] :
480494 """Get sampler statistics from the trace.
481495
496+ Note: This implementation attempts to squeeze object arrays into a consistent dtype,
497+ which can change their shape in hard-to-predict ways.
498+ See https://github.com/pymc-devs/pymc/issues/6207
499+
482500 Parameters
483501 ----------
484- stat_name: str
485- sampler_idx: int or None
486- burn: int
487- thin: int
502+ stat_name : str
503+ Name of the stat to fetch.
504+ sampler_idx : int or None
505+ Index of the sampler to get the stat from.
506+ burn : int
507+ Draws to skip from the start.
508+ thin : int
509+ Stepsize for the slice.
510+ combine : bool
511+ If True, results from `chains` will be concatenated.
512+ squeeze : bool
513+ Return a single array element if the resulting list of
514+ values only has one element. If False, the result will
515+ always be a list of arrays, even if `combine` is True.
488516
489517 Returns
490518 -------
491- If the `sampler_idx` is specified, return the statistic with
492- the given name in a numpy array. If it is not specified and there
493- is more than one sampler that provides this statistic, return
494- a numpy array of shape (m, n), where `m` is the number of
495- such samplers, and `n` is the number of samples.
519+ stats : list of np.ndarray or np.ndarray
520+ A single array if `squeeze` is True and the resulting list has only one element; otherwise a list of arrays.
496521 """
497522 if stat_name not in self .stat_names :
498523 raise KeyError ("Unknown sampler statistic %s" % stat_name )
@@ -543,7 +568,7 @@ def points(self, chains=None):
543568 return itl .chain .from_iterable (self ._straces [chain ] for chain in chains )
544569
545570
546- def _squeeze_cat (results , combine , squeeze ):
571+ def _squeeze_cat (results , combine : bool , squeeze : bool ):
547572 """Squeeze and concatenate the results depending on values of
548573 `combine` and `squeeze`."""
549574 if combine :
0 commit comments