2222
2323from abc import ABC
2424from typing import (
25+ Any ,
2526 Dict ,
2627 List ,
28+ Mapping ,
2729 Optional ,
2830 Sequence ,
2931 Set ,
@@ -47,7 +49,87 @@ class BackendError(Exception):
4749 pass
4850
4951
50- class BaseTrace (ABC ):
52+ class IBaseTrace (ABC , Sized ):
53+ """Minimal interface needed to record and access draws and stats for one MCMC chain."""
54+
55+ chain : int
56+ """Chain number."""
57+
58+ varnames : List [str ]
59+ """Names of tracked variables."""
60+
61+ sampler_vars : List [Dict [str , type ]]
62+ """Sampler stats for each sampler."""
63+
64+ def __len__ (self ):
65+ raise NotImplementedError ()
66+
67+ def get_values (self , varname : str , burn = 0 , thin = 1 ) -> np .ndarray :
68+ """Get values from trace.
69+
70+ Parameters
71+ ----------
72+ varname: str
73+ burn: int
74+ thin: int
75+
76+ Returns
77+ -------
78+ A NumPy array
79+ """
80+ raise NotImplementedError ()
81+
82+ def get_sampler_stats (self , stat_name : str , sampler_idx : Optional [int ] = None , burn = 0 , thin = 1 ):
83+ """Get sampler statistics from the trace.
84+
85+ Parameters
86+ ----------
87+ stat_name: str
88+ sampler_idx: int or None
89+ burn: int
90+ thin: int
91+
92+ Returns
93+ -------
94+ If the `sampler_idx` is specified, return the statistic with
95+ the given name in a numpy array. If it is not specified and there
96+ is more than one sampler that provides this statistic, return
97+ a numpy array of shape (m, n), where `m` is the number of
98+ such samplers, and `n` is the number of samples.
99+ """
100+ raise NotImplementedError ()
101+
102+ def _slice (self , idx : slice ) -> "IBaseTrace" :
103+ """Slice trace object."""
104+ raise NotImplementedError ()
105+
106+ def point (self , idx : int ) -> Dict [str , np .ndarray ]:
107+ """Return dictionary of point values at `idx` for current chain
108+ with variables names as keys.
109+ """
110+ raise NotImplementedError ()
111+
112+ def record (self , draw : Mapping [str , np .ndarray ], stats : Sequence [Mapping [str , Any ]]):
113+ """Record results of a sampling iteration.
114+
115+ Parameters
116+ ----------
117+ draw: dict
118+ Values mapped to variable names
119+ stats: list of dicts
120+ The diagnostic values for each sampler
121+ """
122+ raise NotImplementedError ()
123+
124+ def close (self ):
125+ """Close the backend.
126+
127+ This is called after sampling has finished.
128+ """
129+ pass
130+
131+
132+ class BaseTrace (IBaseTrace ):
51133 """Base trace object
52134
53135 Parameters
@@ -127,25 +209,6 @@ def setup(self, draws, chain, sampler_vars=None) -> None:
127209 self ._set_sampler_vars (sampler_vars )
128210 self ._is_base_setup = True
129211
130- def record (self , point , sampler_states = None ):
131- """Record results of a sampling iteration.
132-
133- Parameters
134- ----------
135- point: dict
136- Values mapped to variable names
137- sampler_states: list of dicts
138- The diagnostic values for each sampler
139- """
140- raise NotImplementedError
141-
142- def close (self ):
143- """Close the database backend.
144-
145- This is called after sampling has finished.
146- """
147- pass
148-
149212 # Selection methods
150213
151214 def __getitem__ (self , idx ):
@@ -157,24 +220,6 @@ def __getitem__(self, idx):
157220 except (ValueError , TypeError ): # Passed variable or variable name.
158221 raise ValueError ("Can only index with slice or integer" )
159222
160- def __len__ (self ):
161- raise NotImplementedError
162-
163- def get_values (self , varname , burn = 0 , thin = 1 ):
164- """Get values from trace.
165-
166- Parameters
167- ----------
168- varname: str
169- burn: int
170- thin: int
171-
172- Returns
173- -------
174- A NumPy array
175- """
176- raise NotImplementedError
177-
178223 def get_sampler_stats (self , stat_name , sampler_idx = None , burn = 0 , thin = 1 ):
179224 """Get sampler statistics from the trace.
180225
@@ -220,19 +265,9 @@ def _get_sampler_stats(self, stat_name, sampler_idx, burn, thin):
220265 """Get sampler statistics."""
221266 raise NotImplementedError ()
222267
223- def _slice (self , idx : Union [int , slice ]):
224- """Slice trace object."""
225- raise NotImplementedError ()
226-
227- def point (self , idx : int ) -> Dict [str , np .ndarray ]:
228- """Return dictionary of point values at `idx` for current chain
229- with variables names as keys.
230- """
231- raise NotImplementedError ()
232-
233268 @property
234269 def stat_names (self ) -> Set [str ]:
235- names = set ()
270+ names : Set [ str ] = set ()
236271 for vars in self .sampler_vars or []:
237272 names .update (vars .keys ())
238273
@@ -290,7 +325,7 @@ class MultiTrace:
290325 List of variable names in the trace(s)
291326 """
292327
293- def __init__ (self , straces : Sequence [BaseTrace ]):
328+ def __init__ (self , straces : Sequence [IBaseTrace ]):
294329 if len ({t .chain for t in straces }) != len (straces ):
295330 raise ValueError ("Chains are not unique." )
296331 self ._straces = {t .chain : t for t in straces }
@@ -386,7 +421,7 @@ def stat_names(self) -> Set[str]:
386421 sampler_vars = [s .sampler_vars for s in self ._straces .values ()]
387422 if not all (svars == sampler_vars [0 ] for svars in sampler_vars ):
388423 raise ValueError ("Inividual chains contain different sampler stats" )
389- names = set ()
424+ names : Set [ str ] = set ()
390425 for trace in self ._straces .values ():
391426 if trace .sampler_vars is None :
392427 continue
@@ -472,7 +507,7 @@ def get_sampler_stats(
472507 ]
473508 return _squeeze_cat (results , combine , squeeze )
474509
475- def _slice (self , slice ):
510+ def _slice (self , slice : slice ):
476511 """Return a new MultiTrace object sliced according to `slice`."""
477512 new_traces = [trace ._slice (slice ) for trace in self ._straces .values ()]
478513 trace = MultiTrace (new_traces )
0 commit comments