2020
2121from abc import ABC , abstractmethod
2222from enum import IntEnum , unique
23- from typing import Dict , List , Sequence , Tuple , Union
23+ from typing import Any , Dict , List , Mapping , Sequence , Tuple , Union
2424
2525import numpy as np
2626
@@ -181,14 +181,14 @@ def flatten_steps(step: Union[BlockedStep, CompoundStep]) -> List[BlockedStep]:
181181class StatsBijection :
182182 """Map between a `list` of stats to `dict` of stats."""
183183
184- def __init__ (self , sampler_stats_dtypes : Sequence [Dict [str , type ]]) -> None :
184+ def __init__ (self , sampler_stats_dtypes : Sequence [Mapping [str , type ]]) -> None :
185185 # Keep a list of flat vs. original stat names
186186 self ._stat_groups : List [List [Tuple [str , str ]]] = [
187187 [(f"sampler_{ s } __{ statname } " , statname ) for statname , _ in names_dtypes .items ()]
188188 for s , names_dtypes in enumerate (sampler_stats_dtypes )
189189 ]
190190
191- def map (self , stats_list : StatsType ) -> StatsDict :
191+ def map (self , stats_list : Sequence [ Mapping [ str , Any ]] ) -> StatsDict :
192192 """Combine stats dicts of multiple samplers into one dict."""
193193 stats_dict = {}
194194 for s , sts in enumerate (stats_list ):
@@ -197,7 +197,7 @@ def map(self, stats_list: StatsType) -> StatsDict:
197197 stats_dict [sname ] = sval
198198 return stats_dict
199199
200- def rmap (self , stats_dict : StatsDict ) -> StatsType :
200+ def rmap (self , stats_dict : Mapping [ str , Any ] ) -> StatsType :
201201 """Split a global stats dict into a list of sampler-wise stats dicts."""
202202 stats_list = []
203203 for namemap in self ._stat_groups :
0 commit comments