Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
def __set_compiler_flags():
# Workarounds for Theano compiler problems on various platforms
import theano

current = theano.config.gcc.cxxflags
theano.config.gcc.cxxflags = f"{current} -Wno-c++11-narrowing"

Expand Down
11 changes: 5 additions & 6 deletions pymc3/backends/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,8 @@
from ..backends.sqlite import SQLite
from ..backends.hdf5 import HDF5

_shortcuts = {'text': {'backend': Text,
'name': 'mcmc'},
'sqlite': {'backend': SQLite,
'name': 'mcmc.sqlite'},
'hdf5': {'backend': HDF5,
'name': 'mcmc.hdf5'}}
_shortcuts = {
"text": {"backend": Text, "name": "mcmc"},
"sqlite": {"backend": SQLite, "name": "mcmc.sqlite"},
"hdf5": {"backend": HDF5, "name": "mcmc.hdf5"},
}
78 changes: 39 additions & 39 deletions pymc3/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from .report import SamplerReport, merge_reports
from ..util import get_var_name

logger = logging.getLogger('pymc3')
logger = logging.getLogger("pymc3")


class BackendError(Exception):
Expand Down Expand Up @@ -75,10 +75,8 @@ def __init__(self, name, model=None, vars=None, test_point=None):
test_point_.update(test_point)
test_point = test_point_
var_values = list(zip(self.varnames, self.fn(test_point)))
self.var_shapes = {var: value.shape
for var, value in var_values}
self.var_dtypes = {var: value.dtype
for var, value in var_values}
self.var_shapes = {var: value.shape for var, value in var_values}
self.var_dtypes = {var: value.dtype for var, value in var_values}
self.chain = None
self._is_base_setup = False
self.sampler_vars = None
Expand All @@ -104,8 +102,7 @@ def _set_sampler_vars(self, sampler_vars):
for stats in sampler_vars:
for key, dtype in stats.items():
if dtypes.setdefault(key, dtype) != dtype:
raise ValueError("Sampler statistic %s appears with "
"different types." % key)
raise ValueError("Sampler statistic %s appears with " "different types." % key)

self.sampler_vars = sampler_vars

Expand Down Expand Up @@ -155,7 +152,7 @@ def __getitem__(self, idx):
try:
return self.point(int(idx))
except (ValueError, TypeError): # Passed variable or variable name.
raise ValueError('Can only index with slice or integer')
raise ValueError("Can only index with slice or integer")

def __len__(self):
raise NotImplementedError
Expand Down Expand Up @@ -199,13 +196,13 @@ def get_sampler_stats(self, stat_name, sampler_idx=None, burn=0, thin=1):
if sampler_idx is not None:
return self._get_sampler_stats(stat_name, sampler_idx, burn, thin)

sampler_idxs = [i for i, s in enumerate(self.sampler_vars)
if stat_name in s]
sampler_idxs = [i for i, s in enumerate(self.sampler_vars) if stat_name in s]
if not sampler_idxs:
raise KeyError("Unknown sampler stat %s" % stat_name)

vals = np.stack([self._get_sampler_stats(stat_name, i, burn, thin)
for i in sampler_idxs], axis=-1)
vals = np.stack(
[self._get_sampler_stats(stat_name, i, burn, thin) for i in sampler_idxs], axis=-1
)
if vals.shape[-1] == 1:
return vals[..., 0]
else:
Expand Down Expand Up @@ -296,13 +293,12 @@ def __init__(self, straces):

self._report = SamplerReport()
for strace in straces:
if hasattr(strace, '_warnings'):
if hasattr(strace, "_warnings"):
self._report._add_warnings(strace._warnings, strace.chain)

def __repr__(self):
template = '<{}: {} chains, {} iterations, {} variables>'
return template.format(self.__class__.__name__,
self.nchains, len(self), len(self.varnames))
template = "<{}: {} chains, {} iterations, {} variables>"
return template.format(self.__class__.__name__, self.nchains, len(self), len(self.varnames))

@property
def nchains(self):
Expand Down Expand Up @@ -339,16 +335,17 @@ def __getitem__(self, idx):
var = get_var_name(var)
if var in self.varnames:
if var in self.stat_names:
warnings.warn("Attribute access on a trace object is ambigous. "
"Sampler statistic and model variable share a name. Use "
"trace.get_values or trace.get_sampler_stats.")
warnings.warn(
"Attribute access on a trace object is ambigous. "
"Sampler statistic and model variable share a name. Use "
"trace.get_values or trace.get_sampler_stats."
)
return self.get_values(var, burn=burn, thin=thin)
if var in self.stat_names:
return self.get_sampler_stats(var, burn=burn, thin=thin)
raise KeyError("Unknown variable %s" % var)

_attrs = {'_straces', 'varnames', 'chains', 'stat_names',
'supports_sampler_stats', '_report'}
_attrs = {"_straces", "varnames", "chains", "stat_names", "supports_sampler_stats", "_report"}

def __getattr__(self, name):
# Avoid infinite recursion when called before __init__
Expand All @@ -359,14 +356,15 @@ def __getattr__(self, name):
name = get_var_name(name)
if name in self.varnames:
if name in self.stat_names:
warnings.warn("Attribute access on a trace object is ambigous. "
"Sampler statistic and model variable share a name. Use "
"trace.get_values or trace.get_sampler_stats.")
warnings.warn(
"Attribute access on a trace object is ambigous. "
"Sampler statistic and model variable share a name. Use "
"trace.get_values or trace.get_sampler_stats."
)
return self.get_values(name)
if name in self.stat_names:
return self.get_sampler_stats(name)
raise AttributeError("'{}' object has no attribute '{}'".format(
type(self).__name__, name))
raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, name))

def __len__(self):
chain = self.chains[-1]
Expand Down Expand Up @@ -425,10 +423,12 @@ def add_values(self, vals, overwrite=False) -> None:
l_samples = len(self) * len(self.chains)
l_v = len(v)
if l_v != l_samples:
warnings.warn("The length of the values you are trying to "
"add ({}) does not match the number ({}) of "
"total samples in the trace "
"(chains * iterations)".format(l_v, l_samples))
warnings.warn(
"The length of the values you are trying to "
"add ({}) does not match the number ({}) of "
"total samples in the trace "
"(chains * iterations)".format(l_v, l_samples)
)

v = np.squeeze(v.reshape(len(chains), len(self), -1))

Expand Down Expand Up @@ -457,8 +457,7 @@ def remove_values(self, name):
chain.vars.remove(va)
del chain.samples[name]

def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
squeeze=True):
def get_values(self, varname, burn=0, thin=1, combine=True, chains=None, squeeze=True):
"""Get values from traces.

Parameters
Expand All @@ -485,14 +484,12 @@ def get_values(self, varname, burn=0, thin=1, combine=True, chains=None,
chains = self.chains
varname = get_var_name(varname)
try:
results = [self._straces[chain].get_values(varname, burn, thin)
for chain in chains]
results = [self._straces[chain].get_values(varname, burn, thin) for chain in chains]
except TypeError: # Single chain passed.
results = [self._straces[chains].get_values(varname, burn, thin)]
return _squeeze_cat(results, combine, squeeze)

def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True,
chains=None, squeeze=True):
def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True, chains=None, squeeze=True):
"""Get sampler statistics from the trace.

Parameters
Expand Down Expand Up @@ -520,8 +517,9 @@ def get_sampler_stats(self, stat_name, burn=0, thin=1, combine=True,
except TypeError:
chains = [chains]

results = [self._straces[chain].get_sampler_stats(stat_name, None, burn, thin)
for chain in chains]
results = [
self._straces[chain].get_sampler_stats(stat_name, None, burn, thin) for chain in chains
]
return _squeeze_cat(results, combine, squeeze)

def _slice(self, slice):
Expand Down Expand Up @@ -582,7 +580,9 @@ def merge_traces(mtraces: List[MultiTrace]) -> MultiTrace:
base_mtrace = mtraces[0]
chain_len = len(base_mtrace)
# check base trace
if any(len(st) != chain_len for _, st in base_mtrace._straces.items()): # pylint: disable=line-too-long
if any(
len(st) != chain_len for _, st in base_mtrace._straces.items()
): # pylint: disable=line-too-long
raise ValueError("Chains are of different lengths.")
for new_mtrace in mtraces[1:]:
for new_chain, strace in new_mtrace._straces.items():
Expand Down
45 changes: 25 additions & 20 deletions pymc3/backends/hdf5.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
import h5py
from contextlib import contextmanager


@contextmanager
def activator(instance):
if isinstance(instance.hdf5_file, h5py.File):
if instance.hdf5_file.id: # if file is open, keep open
yield
return
# if file is closed/not referenced: open, do job, then close
instance.hdf5_file = h5py.File(instance.name, 'a')
instance.hdf5_file = h5py.File(instance.name, "a")
yield
instance.hdf5_file.close()
return
Expand All @@ -43,7 +44,7 @@ class HDF5(base.BaseTrace):
`model.unobserved_RVs` is used.
test_point: dict
use different test point that might be with changed variables shapes
"""
"""

supports_sampler_stats = True

Expand All @@ -64,21 +65,21 @@ def activate_file(self):
@property
def samples(self):
g = self.hdf5_file.require_group(str(self.chain))
if 'name' not in g.attrs:
g.attrs['name'] = self.chain
return g.require_group('samples')
if "name" not in g.attrs:
g.attrs["name"] = self.chain
return g.require_group("samples")

@property
def stats(self):
g = self.hdf5_file.require_group(str(self.chain))
if 'name' not in g.attrs:
g.attrs['name'] = self.chain
return g.require_group('stats')
if "name" not in g.attrs:
g.attrs["name"] = self.chain
return g.require_group("stats")

@property
def chains(self):
with self.activate_file:
return [v.attrs['name'] for v in self.hdf5_file.values()]
return [v.attrs["name"] for v in self.hdf5_file.values()]

@property
def is_new_file(self):
Expand All @@ -98,19 +99,19 @@ def nchains(self):
@property
def records_stats(self):
with self.activate_file:
return self.hdf5_file.attrs['records_stats']
return self.hdf5_file.attrs["records_stats"]

@records_stats.setter
def records_stats(self, v):
with self.activate_file:
self.hdf5_file.attrs['records_stats'] = bool(v)
self.hdf5_file.attrs["records_stats"] = bool(v)

def _resize(self, n):
for v in self.samples.values():
v.resize(n, axis=0)
for key, group in self.stats.items():
for statds in group.values():
statds.resize((n, ))
statds.resize((n,))

@property
def sampler_vars(self):
Expand All @@ -137,10 +138,13 @@ def sampler_vars(self, values):
if not data.keys(): # no pre-recorded stats
for varname, dtype in sampler.items():
if varname not in data:
data.create_dataset(varname, (self.draws,), dtype=dtype, maxshape=(None,))
data.create_dataset(
varname, (self.draws,), dtype=dtype, maxshape=(None,)
)
elif data.keys() != sampler.keys():
raise ValueError(
f"Sampler vars can't change, names incompatible: {data.keys()} != {sampler.keys()}")
f"Sampler vars can't change, names incompatible: {data.keys()} != {sampler.keys()}"
)
self.records_stats = True

def setup(self, draws, chain, sampler_vars=None):
Expand All @@ -160,16 +164,18 @@ def setup(self, draws, chain, sampler_vars=None):
with self.activate_file:
for varname, shape in self.var_shapes.items():
if varname not in self.samples:
self.samples.create_dataset(name=varname, shape=(draws, ) + shape,
dtype=self.var_dtypes[varname],
maxshape=(None, ) + shape)
self.samples.create_dataset(
name=varname,
shape=(draws,) + shape,
dtype=self.var_dtypes[varname],
maxshape=(None,) + shape,
)
self.draw_idx = len(self)
self.draws = self.draw_idx + draws
self._set_sampler_vars(sampler_vars)
self._is_base_setup = True
self._resize(self.draws)


def close(self):
with self.activate_file:
if self.draw_idx == self.draws:
Expand Down Expand Up @@ -204,8 +210,7 @@ def _slice(self, idx):
start, stop, step = idx.indices(len(self))
sliced = ndarray.NDArray(model=self.model, vars=self.vars)
sliced.chain = self.chain
sliced.samples = {v: self.samples[v][start:stop:step]
for v in self.varnames}
sliced.samples = {v: self.samples[v][start:stop:step] for v in self.varnames}
sliced.draw_idx = (stop - start) // step
return sliced

Expand Down
Loading