Skip to content

Commit d797271

Browse files
Fixup the mocking code.
1 parent 2d99408 commit d797271

File tree

6 files changed

+906
-402
lines changed

6 files changed

+906
-402
lines changed

src/kbmod/mocking/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,3 @@
33
from .headers import *
44
from .fits_data import *
55
from .fits import *
6-
#from . import test_mocking

src/kbmod/mocking/catalogs.py

Lines changed: 132 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,18 @@
11
import abc
22

33
import numpy as np
4-
from astropy.time import Time
5-
from astropy.table import QTable, vstack
4+
from astropy.table import QTable
5+
from .config import Config
66

77

88
__all__ = [
99
"gen_catalog",
1010
"CatalogFactory",
11-
"SimpleSourceCatalog",
12-
"SimpleObjectCatalog",
11+
"SimpleCatalog",
12+
"SourceCatalogConfig",
13+
"SourceCatalog",
14+
"ObjectCatalogConfig",
15+
"ObjectCatalog",
1316
]
1417

1518

@@ -26,84 +29,161 @@ def gen_catalog(n, param_ranges, seed=None):
2629

2730
# conversion assumes a gaussian
2831
if "flux" in param_ranges and "amplitude" not in param_ranges:
29-
xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1
30-
ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1
32+
xstd = cat["x_stddev"] if "x_stddev" in cat.colnames else 1.0
33+
ystd = cat["y_stddev"] if "y_stddev" in cat.colnames else 1.0
3134

3235
cat["amplitude"] = cat["flux"] / (2.0 * np.pi * xstd * ystd)
3336

3437
return cat
3538

3639

37-
3840
class CatalogFactory(abc.ABC):
3941
@abc.abstractmethod
40-
def gen_realization(self, *args, t=None, dt=None, **kwargs):
42+
def mock(self, *args, **kwargs):
4143
raise NotImplementedError()
4244

43-
def mock(self, *args, **kwargs):
44-
return self.gen_realization(self, *args, **kwargs)
4545

46+
class SimpleCatalogConfig(Config):
47+
return_copy = False
48+
seed = None
49+
n = 100
50+
param_ranges = {}
51+
52+
53+
class SimpleCatalog(CatalogFactory):
54+
default_config = SimpleCatalogConfig
55+
56+
def __init_from_table(self, table, config=None, **kwargs):
57+
config = self.default_config(config=config, **kwargs)
58+
config.n = len(table)
59+
params = {}
60+
for col in table.keys():
61+
params[col] = (table[col].min(), table[col].max())
62+
config.param_ranges.update(params)
63+
return config, table
64+
65+
def __init_from_config(self, config, **kwargs):
66+
config = self.default_config(config=config, method="subset", **kwargs)
67+
table = gen_catalog(config.n, config.param_ranges, config.seed)
68+
return config, table
69+
70+
def __init_from_ranges(self, **kwargs):
71+
param_ranges = kwargs.pop("param_ranges", None)
72+
if param_ranges is None:
73+
param_ranges = {k: v for k, v in kwargs.items() if k in self.default_config.param_ranges}
74+
kwargs = {k: v for k, v in kwargs.items() if k not in self.default_config.param_ranges}
75+
76+
config = self.default_config(**kwargs, method="subset")
77+
config.param_ranges.update(param_ranges)
78+
return self.__init_from_config(config=config)
79+
80+
def __init__(self, table=None, config=None, **kwargs):
81+
if table is not None:
82+
config, table = self.__init_from_table(table, config=config, **kwargs)
83+
elif isinstance(config, Config):
84+
config, table = self.__init_from_config(config=config, **kwargs)
85+
elif isinstance(config, dict) or kwargs:
86+
config = {} if config is None else config
87+
config, table = self.__init_from_ranges(**{**config, **kwargs})
88+
else:
89+
raise ValueError(
90+
"Expected table or config, or keyword arguments of expected "
91+
f"catalog value ranges, got:\n table={table}\n config={config} "
92+
f"\n kwargs={kwargs}"
93+
)
94+
95+
self.config = config
96+
self.table = table
97+
self.current = 0
4698

47-
class SimpleSourceCatalog(CatalogFactory):
48-
base_param_ranges = {
49-
"amplitude": [500, 2000],
50-
"x_mean": [0, 4096],
51-
"y_mean": [0, 2048],
52-
"x_stddev": [1, 7],
53-
"y_stddev": [1, 7],
54-
"theta": [0, np.pi],
55-
}
99+
@classmethod
100+
def from_config(cls, config, **kwargs):
101+
config = cls.default_config(config=config, method="subset", **kwargs)
102+
return cls(gen_catalog(config.n, config.param_ranges, config.seed), config=config)
56103

57-
def __init__(self, table, return_copy=False):
58-
self.table = table
59-
self.return_copy = return_copy
104+
@classmethod
105+
def from_ranges(cls, n=None, config=None, **kwargs):
106+
config = cls.default_config(n=n, config=config, method="subset")
107+
config.param_ranges.update(**kwargs)
108+
return cls.from_config(config)
60109

61110
@classmethod
62-
def from_params(cls, n=100, param_ranges=None):
63-
param_ranges = {} if param_ranges is None else param_ranges
64-
tmp = cls.base_param_ranges.copy()
65-
tmp.update(param_ranges)
66-
return cls(gen_catalog(n, tmp))
67-
68-
def gen_realization(self, *args, t=None, dt=None, **kwargs):
69-
if self.return_copy:
111+
def from_table(cls, table):
112+
config = cls.default_config()
113+
config.n = len(table)
114+
params = {}
115+
for col in table.keys():
116+
params[col] = (table[col].min(), table[col].max())
117+
config["param_ranges"] = params
118+
return cls(table, config=config)
119+
120+
def mock(self):
121+
self.current += 1
122+
if self.config.return_copy:
70123
return self.table.copy()
71124
return self.table
72125

73126

74-
class SimpleObjectCatalog(CatalogFactory):
75-
base_param_ranges = {
76-
"amplitude": [1, 100],
77-
"x_mean": [0, 4096],
78-
"y_mean": [0, 2048],
79-
"vx": [500, 1000],
80-
"vy": [500, 1000],
81-
"stddev": [1, 1.8],
82-
"theta": [0, np.pi],
127+
class SourceCatalogConfig(SimpleCatalogConfig):
128+
param_ranges = {
129+
"amplitude": [1., 10.],
130+
"x_mean": [0., 4096.],
131+
"y_mean": [0., 2048.],
132+
"x_stddev": [1., 3.],
133+
"y_stddev": [1., 3.],
134+
"theta": [0., np.pi],
83135
}
84136

85-
def __init__(self, table, obstime=None):
86-
self.table = table
87-
self._realization = table.copy()
137+
138+
class SourceCatalog(SimpleCatalog):
139+
default_config = SourceCatalogConfig
140+
141+
142+
class ObjectCatalogConfig(SimpleCatalogConfig):
143+
param_ranges = {
144+
"amplitude": [0.1, 3.0],
145+
"x_mean": [0., 4096.],
146+
"y_mean": [0., 2048.],
147+
"vx": [500., 1000.],
148+
"vy": [500., 1000.],
149+
"stddev": [0.25, 1.5],
150+
"theta": [0., np.pi],
151+
}
152+
153+
154+
class ObjectCatalog(SimpleCatalog):
155+
default_config = ObjectCatalogConfig
156+
157+
def __init__(self, table=None, obstime=None, config=None, **kwargs):
158+
# put return_copy into kwargs to override whatever user might have
159+
# supplied, and to guarantee the default is overriden
160+
kwargs["return_copy"] = True
161+
super().__init__(table=table, config=config, **kwargs)
162+
self._realization = self.table.copy()
88163
self.obstime = 0 if obstime is None else obstime
89164

90-
@classmethod
91-
def from_params(cls, n=100, param_ranges=None):
92-
param_ranges = {} if param_ranges is None else param_ranges
93-
tmp = cls.base_param_ranges.copy()
94-
tmp.update(param_ranges)
95-
return cls(gen_catalog(n, tmp))
165+
def reset(self):
166+
self.current = 0
167+
self._realization = self.table.copy()
96168

97169
def gen_realization(self, t=None, dt=None, **kwargs):
98170
if t is None and dt is None:
99171
return self._realization
100172

101173
dt = dt if t is None else t - self.obstime
102-
self._realization["x_mean"] += self._realization["vx"] * dt
103-
self._realization["y_mean"] += self._realization["vy"] * dt
174+
self._realization["x_mean"] += self.table["vx"] * dt
175+
self._realization["y_mean"] += self.table["vy"] * dt
104176
return self._realization
105177

106178
def mock(self, n=1, **kwargs):
179+
breakpoint()
107180
if n == 1:
108-
return self.gen_realization(**kwargs)
109-
return [self.gen_realization(**kwargs).copy() for i in range(n)]
181+
data = self.gen_realization(**kwargs)
182+
self.current += 1
183+
else:
184+
data = []
185+
for i in range(n):
186+
data.append(self.gen_realization(**kwargs).copy())
187+
self.current += 1
188+
189+
return data

src/kbmod/mocking/config.py

Lines changed: 85 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import copy
2+
13
__all__ = ["Config", "ConfigurationError"]
24

35

@@ -22,37 +24,66 @@ class attributes. Particular attributes can be overriden on an per-instance
2224
Keyword arguments, assigned as configuration key-values.
2325
"""
2426

25-
def __init__(self, config=None, **kwargs):
27+
def __init__(self, config=None, method="default", **kwargs):
2628
# This is a bit hacky, but it makes life a lot easier because it
2729
# enables automatic loading of the default configuration and separation
2830
# of default config from instance bound config
2931
keys = list(set(dir(self.__class__)) - set(dir(Config)))
3032

3133
# First fill out all the defaults by copying cls attrs
32-
self._conf = {k: getattr(self, k) for k in keys}
34+
self._conf = {k: copy.copy(getattr(self, k)) for k in keys}
3335

3436
# Then override with any user-specified values
35-
conf = config
36-
if isinstance(config, Config):
37-
conf = config._conf
37+
self.update(config=config, method=method, **kwargs)
3838

39-
if conf is not None:
40-
self._conf.update(config)
41-
self._conf.update(kwargs)
39+
@classmethod
40+
def from_configs(cls, *args):
41+
config = cls()
42+
for conf in args:
43+
config.update(config=conf, method="extend")
44+
return config
4245

43-
# now just shortcut the most common dict operations
4446
def __getitem__(self, key):
4547
return self._conf[key]
4648

49+
# now just shortcut the most common dict operations
50+
def __getattribute__(self, key):
51+
hasconf = "_conf" in object.__getattribute__(self, "__dict__")
52+
if hasconf:
53+
conf = object.__getattribute__(self, "_conf")
54+
if key in conf:
55+
return conf[key]
56+
return object.__getattribute__(self, key)
57+
4758
def __setitem__(self, key, value):
4859
self._conf[key] = value
4960

61+
def __repr__(self):
62+
res = f"{self.__class__.__name__}("
63+
for k, v in self.items():
64+
res += f"{k}: {v}, "
65+
return res[:-2] + ")"
66+
5067
def __str__(self):
5168
res = f"{self.__class__.__name__}("
5269
for k, v in self.items():
5370
res += f"{k}: {v}, "
5471
return res[:-2] + ")"
5572

73+
def _repr_html_(self):
74+
repr = f"""
75+
<table style='tr:nth-child(even){{background-color: #dddddd;}};'>
76+
<caption>{self.__class__.__name__}</caption>
77+
<tr>
78+
<th>Key</th>
79+
<th>Value</th>
80+
</tr>
81+
"""
82+
for k, v in self.items():
83+
repr += f"<tr><td>{k}</td><td>{v}\n"
84+
repr += "</table>"
85+
return repr
86+
5687
def __len__(self):
5788
return len(self._conf)
5889

@@ -76,7 +107,7 @@ def __or__(self, other):
76107
elif isinstance(other, dict):
77108
return self.__class__(config=self._conf | other)
78109
else:
79-
raise TypeError("unsupported operand type(s) for |: {type(self)} " "and {type(other)}")
110+
raise TypeError("unsupported operand type(s) for |: {type(self)}and {type(other)}")
80111

81112
def keys(self):
82113
"""A set-like object providing a view on config's keys."""
@@ -90,7 +121,10 @@ def items(self):
90121
"""A set-like object providing a view on config's items."""
91122
return self._conf.items()
92123

93-
def update(self, conf=None, **kwargs):
124+
def copy(self):
125+
return self.__class__(config=self._conf.copy())
126+
127+
def update(self, config=None, method="default", **kwargs):
94128
"""Update this config from dict/other config/iterable and
95129
apply any explicit keyword overrides.
96130
@@ -107,9 +141,46 @@ def update(self, conf=None, **kwargs):
107141
108142
for k in kwargs: this[k] = kwargs[k]
109143
"""
110-
if conf is not None:
111-
self._conf.update(conf)
112-
self._conf.update(kwargs)
144+
# Python < 3.9 does not support set operations for dicts
145+
# [fixme]: Update this to: other = conf | kwargs
146+
# and remove current implementation when 3.9 gets too old. Order of
147+
# conf and kwargs matter to correctly apply explicit overrides
148+
149+
# Check if both conf and kwargs are given, just conf or just
150+
# kwargs. If none are given do nothing to comply with default
151+
# dict behavior
152+
if config is not None and kwargs:
153+
other = {**config, **kwargs}
154+
elif config is not None:
155+
other = config
156+
elif kwargs is not None:
157+
other = kwargs
158+
else:
159+
return
160+
161+
# then, see if we the given config and overrides are a subset of this
162+
# config or it's superset. Depending on the selected method then raise
163+
# errors, ignore or extend the current config if the given config is a
164+
# superset (or disjoint) from the current one.
165+
subset = {k: v for k, v in other.items() if k in self._conf}
166+
superset = {k: v for k, v in other.items() if k not in subset}
167+
168+
if method.lower() == "default":
169+
if superset:
170+
raise ConfigurationError(
171+
"Tried setting the following fields, not a part of "
172+
f"this configuration options: {superset}"
173+
)
174+
conf = other # == subset
175+
elif method.lower() == "subset":
176+
conf = subset
177+
elif method.lower() == "extend":
178+
conf = other
179+
else:
180+
raise ValueError("Method expected to be one of 'default', "
181+
f"'subset' or 'extend'. Got {method} instead.")
182+
183+
self._conf.update(conf)
113184

114185
def toDict(self):
115186
"""Return this config as a dict."""

0 commit comments

Comments
 (0)