3333import pymc as pm
3434
3535from pymc .backends import _init_trace
36- from pymc .backends .base import BaseTrace , MultiTrace , _choose_chains
36+ from pymc .backends .base import BaseTrace , IBaseTrace , MultiTrace , _choose_chains
3737from pymc .blocking import DictToArrayBijection
3838from pymc .exceptions import SamplingError
3939from pymc .initial_point import PointType , StartDict , make_initial_point_fns_per_chain
7171class SamplingIteratorCallback (Protocol ):
7272 """Signature of the callable that may be passed to `pm.sample(callable=...)`."""
7373
74- def __call__ (self , trace : BaseTrace , draw : Draw ):
74+ def __call__ (self , trace : IBaseTrace , draw : Draw ):
7575 pass
7676
7777
@@ -657,7 +657,7 @@ def _sample_many(
657657 * ,
658658 draws : int ,
659659 chains : int ,
660- traces : Sequence [BaseTrace ],
660+ traces : Sequence [IBaseTrace ],
661661 start : Sequence [PointType ],
662662 random_seed : Optional [Sequence [RandomSeed ]],
663663 step : Step ,
@@ -701,7 +701,7 @@ def _sample(
701701 start : PointType ,
702702 draws : int ,
703703 step : Step ,
704- trace : BaseTrace ,
704+ trace : IBaseTrace ,
705705 tune : int ,
706706 model : Optional [Model ] = None ,
707707 callback = None ,
@@ -726,8 +726,8 @@ def _sample(
726726 The number of samples to draw
727727 step : function
728728 Step function
729- trace : backend, optional
730- A backend instance .
729+ trace
730+ A chain backend to record draws and stats .
731731 tune : int
732732 Number of iterations to tune.
733733 model : Model (optional if in ``with`` context)
@@ -767,7 +767,7 @@ def _iter_sample(
767767 draws : int ,
768768 step : Step ,
769769 start : PointType ,
770- trace : BaseTrace ,
770+ trace : IBaseTrace ,
771771 chain : int = 0 ,
772772 tune : int = 0 ,
773773 model : Optional [Model ] = None ,
@@ -785,8 +785,8 @@ def _iter_sample(
785785 start : dict
786786 Starting point in parameter space (or partial point).
787787 Must contain numeric (transformed) initial values for all (transformed) free variables.
788- trace : backend
789- A backend instance .
788+ trace
789+ A chain backend to record draws and stats .
790790 chain : int, optional
791791 Chain number used to store sample in backend.
792792 tune : int, optional
@@ -852,7 +852,7 @@ def _mp_sample(
852852 random_seed : Sequence [RandomSeed ],
853853 start : Sequence [PointType ],
854854 progressbar : bool = True ,
855- traces : Sequence [BaseTrace ],
855+ traces : Sequence [IBaseTrace ],
856856 model : Optional [Model ] = None ,
857857 callback : Optional [SamplingIteratorCallback ] = None ,
858858 mp_ctx = None ,
@@ -879,9 +879,8 @@ def _mp_sample(
879879 Dicts must contain numeric (transformed) initial values for all (transformed) free variables.
880880 progressbar : bool
881881 Whether or not to display a progress bar in the command line.
882- trace : BaseTrace, optional
883- A backend instance, or None.
884- If None, the NDArray backend is used.
882+ traces
883+ Recording backends for each chain.
885884 model : Model (optional if in ``with`` context)
886885 callback
887886 A function which gets called for every sample from the trace of a chain. The function is
0 commit comments