Skip to content

Commit de1e859

Browse files
fixup again, moving home to test decamimdiff gen
1 parent 0b88c8a commit de1e859

File tree

4 files changed

+107
-58
lines changed

4 files changed

+107
-58
lines changed

src/kbmod/mocking/fits.py

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -184,13 +184,14 @@ class DECamImdiff:
184184
def from_defaults(
185185
cls,
186186
with_data=False,
187+
override_original=False,
188+
shape=(100, 100),
189+
start_mjd=60310,
190+
step_mjd=0.001,
191+
with_noise=False,
187192
noise="simplistic",
188193
src_cat=None,
189-
obj_cat=None,
190-
editable_images=False,
191-
separate_masks=False,
192-
writeable_masks=False,
193-
editable_masks=False,
194+
obj_cat=None
194195
):
195196
if obj_cat.config.type == "progressive":
196197
raise ValueError(
@@ -202,15 +203,8 @@ def from_defaults(
202203
hdr_factory = ArchivedHeader("headers_archive.tar.bz2", "decam_imdiff_headers.ecsv")
203204

204205
hdu_types = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU]
205-
hdu_types.extend(
206-
[
207-
BinTableHDU,
208-
]
209-
* 12
210-
)
211-
data = [
212-
NoneFactory(),
213-
] * 16
206+
hdu_types.extend([BinTableHDU] * 12)
207+
data = [NoneFactory()] * 16
214208

215209
if with_data:
216210
headers = hdr_factory.get(0)
@@ -236,12 +230,7 @@ def __init__(self, header_factory, data_factories=None, obj_cat=None):
236230
self.hdr_factory = header_factory
237231
self.data_factories = data_factories
238232
self.hdu_layout = [PrimaryHDU, CompImageHDU, CompImageHDU, CompImageHDU]
239-
self.hdu_layout.extend(
240-
[
241-
BinTableHDU,
242-
]
243-
* 12
244-
)
233+
self.hdu_layout.extend([BinTableHDU] * 12)
245234

246235
def mock(self, n=1):
247236
obj_cats = None

src/kbmod/mocking/fits_data.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -464,12 +464,13 @@ class SimpleImage(DataFactory):
464464

465465
default_config = SimpleImageConfig
466466

467-
def __init__(self, image=None, src_cat=None, obj_cat=None, config=None, **kwargs):
467+
def __init__(self, image=None, src_cat=None, obj_cat=None, config=None,
468+
dtype=np.float32, **kwargs):
468469
self.config = self.default_config(config=config, **kwargs)
469470
super().__init__(image, self.config, **kwargs)
470471

471472
if image is None:
472-
image = np.zeros(self.config.shape, dtype=np.float32)
473+
image = np.zeros(self.config.shape, dtype=dtype)
473474
else:
474475
image = image
475476
self.config.shape = image.shape
@@ -794,7 +795,7 @@ def add_noise(cls, images, config):
794795
return images
795796

796797
@classmethod
797-
def gen_base_image(cls, config=None, src_cat=None):
798+
def gen_base_image(cls, config=None, src_cat=None, dtype=np.float32):
798799
"""Generate base image from configuration.
799800
800801
Parameters
@@ -812,7 +813,7 @@ def gen_base_image(cls, config=None, src_cat=None):
812813
config = cls.default_config(config)
813814

814815
# empty image
815-
base = np.zeros(config.shape, dtype=np.float32)
816+
base = np.zeros(config.shape, dtype=dtype)
816817
base += config.bias
817818
base = cls.add_hot_pixels(base, config)
818819
base = cls.add_bad_cols(base, config)
@@ -821,7 +822,8 @@ def gen_base_image(cls, config=None, src_cat=None):
821822

822823
return base
823824

824-
def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, **kwargs):
825+
def __init__(self, image=None, config=None, src_cat=None, obj_cat=None, dtype=np.float32,**kwargs):
825826
conf = self.default_config(config=config, **kwargs)
826827
# static objects are added in SimpleImage init
827-
super().__init__(image=self.gen_base_image(conf), config=conf, src_cat=src_cat, obj_cat=obj_cat)
828+
super().__init__(image=self.gen_base_image(conf, dtype=dtype),
829+
config=conf, src_cat=src_cat, obj_cat=obj_cat)

src/kbmod/mocking/headers.py

Lines changed: 64 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import warnings
22

3-
from astropy.utils.exceptions import AstropyUserWarning
3+
import numpy as np
4+
45
from astropy.wcs import WCS
56
from astropy.io.fits import Header
67

@@ -14,6 +15,54 @@
1415
]
1516

1617

18+
def make_wcs(center_coords=(351., -5.), rotation=0, pixscale=0.2, shape=None):
19+
"""
20+
Create a simple celestial `~astropy.wcs.WCS` object in ICRS
21+
coordinate system.
22+
23+
Parameters
24+
----------
25+
shape : tuple[int]
26+
Two-tuple, dimensions of the WCS footprint
27+
center_coords : tuple[int]
28+
Two-tuple of on-sky coordinates of the center of the WCS in
29+
decimal degrees, in ICRS.
30+
rotation : float, optional
31+
Rotation in degrees, from ICRS equator. In decimal degrees.
32+
pixscale : float
33+
Pixel scale in arcsec/pixel.
34+
35+
Returns
36+
-------
37+
wcs : `astropy.wcs.WCS`
38+
The world coordinate system.
39+
40+
Examples
41+
--------
42+
>>> from kbmod.mocking import make_wcs
43+
>>> shape = (100, 100)
44+
>>> wcs = make_wcs(shape)
45+
>>> wcs = make_wcs(shape, (115, 5), 45, 0.1)
46+
"""
47+
wcs = WCS(naxis=2)
48+
rho = rotation*0.0174533 # deg to rad
49+
scale = 0.1 / 3600.0 # arcsec/pixel to deg/pix
50+
51+
if shape is not None:
52+
wcs.pixel_shape = shape
53+
wcs.wcs.crpix = [shape[1] / 2, shape[0] / 2]
54+
else:
55+
wcs.wcs.crpix = [0, 0]
56+
wcs.wcs.crval = center_coords
57+
wcs.wcs.cunit = ['deg', 'deg']
58+
wcs.wcs.cd = [[-scale * np.cos(rho), scale * np.sin(rho)],
59+
[scale * np.sin(rho), scale * np.cos(rho)]]
60+
wcs.wcs.radesys = 'ICRS'
61+
wcs.wcs.ctype = ['RA---TAN', 'DEC--TAN']
62+
63+
return wcs
64+
65+
1766
class HeaderFactory:
1867
primary_template = {
1968
"EXTNAME": "PRIMARY",
@@ -29,14 +78,6 @@ class HeaderFactory:
2978

3079
ext_template = {"NAXIS": 2, "NAXIS1": 2048, "NAXIS2": 4096, "CRPIX1": 1024, "CPRIX2": 2048, "BITPIX": 32}
3180

32-
wcs_template = {
33-
"ctype": ["RA---TAN", "DEC--TAN"],
34-
"crval": [351, -5],
35-
"cunit": ["deg", "deg"],
36-
"radesys": "ICRS",
37-
"cd": [[-1.44e-07, 7.32e-05], [7.32e-05, 1.44e-05]],
38-
}
39-
4081
def __validate_mutables(self):
4182
# !xor
4283
if bool(self.mutables) != bool(self.callbacks):
@@ -87,23 +128,16 @@ def mock(self, n=1):
87128
return headers
88129

89130
@classmethod
90-
def gen_wcs(cls, metadata):
91-
wcs = WCS(naxis=2)
92-
for k, v in metadata.items():
93-
setattr(wcs.wcs, k, v)
94-
return wcs.to_header()
95-
96-
@classmethod
97-
def gen_header(cls, base, overrides, wcs_base=None):
131+
def gen_header(cls, base, overrides, wcs=None):
98132
header = Header(base)
99133
header.update(overrides)
100134

101-
if wcs_base is not None:
102-
naxis1 = header.get("NAXIS1", False)
103-
naxis2 = header.get("NAXIS2", False)
104-
if not all((naxis1, naxis2)):
105-
raise ValueError("Adding a WCS to the header requires " "NAXIS1 and NAXIS2 keys.")
106-
header.update(cls.gen_wcs(wcs_base))
135+
if wcs is not None:
136+
# Sync WCS with header + overwrites
137+
wcs_header = wcs.to_header()
138+
wcs_header.update(header)
139+
# then merge back to mocked header template
140+
header.update(wcs_header)
107141

108142
return header
109143

@@ -122,9 +156,13 @@ def from_ext_template(cls, overrides=None, mutables=None, callbacks=None, wcs=No
122156
ext_template["CRPIX1"] = shape[0] // 2
123157
ext_template["CRPIX2"] = shape[1] // 2
124158

125-
hdr = cls.gen_header(
126-
base=ext_template, overrides=overrides, wcs_base=cls.wcs_template if wcs is None else wcs
127-
)
159+
if wcs is None:
160+
wcs = make_wcs(
161+
shape=(ext_template["NAXIS1"], ext_template["NAXIS2"]),
162+
163+
)
164+
165+
hdr = cls.gen_header(base=ext_template, overrides=overrides, wcs=wcs)
128166
return cls(hdr, mutables, callbacks)
129167

130168

tests/test_end_to_end.py

Lines changed: 26 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,15 @@ def test_static_objects(self):
6060
self.assertTrue(len(results) == 0)
6161

6262

63-
class TestLinearSearch(unittest.TestCase):
63+
class TestRandomLinearSearch(unittest.TestCase):
6464
def setUp(self):
6565
# Set up shared search values
6666
self.n_imgs = 10
6767
self.repeat_n_times = 10
68-
self.shape = (500, 500)
69-
self.start_pos = (10, 50)
70-
self.vxs = [10, 30]
71-
self.vys = [10, 30]
68+
self.shape = (300, 300)
69+
self.start_pos = (125, 175)
70+
self.vxs = [-10, 10]
71+
self.vys = [-10, 10]
7272

7373
# Set up configs for mocking and search
7474
# These don't change from test to test
@@ -100,7 +100,7 @@ def setUp(self):
100100
}
101101
)
102102

103-
def test_search(self):
103+
def test_simple_search(self):
104104
# Mock the data and repeat tests. The random catalog
105105
# creation guarantees a diverse set of changing test values
106106
for i in range(self.repeat_n_times):
@@ -123,6 +123,26 @@ def test_search(self):
123123
self.assertLessEqual(abs(obj["vx"] - res["vx"]), 5)
124124
self.assertLessEqual(abs(obj["vy"] - res["vy"]), 5)
125125

126+
def test_diffim_mocks(self):
127+
src_cat = kbmock.SourceCatalog.from_defaults()
128+
obj_cat = kbmock.ObjectCatalog.from_defaults(self.param_ranges, n=1)
129+
factory = kbmock.DECamImdiff.from_defaults(with_data=True, src_cat=src_cat, obj_cat=obj_cat)
130+
hduls = factory.mock(n=self.n_imgs)
131+
132+
ic = ImageCollection.fromTargets(hduls, force="TestDataStd")
133+
wu = ic.toWorkUnit(search_config=self.config)
134+
results = SearchRunner().run_search_from_work_unit(wu)
135+
136+
# Run tests
137+
self.assertGreaterEqual(len(results), 1)
138+
for res in results:
139+
diff = abs(obj_cat.table["y_mean"] - res["y"])
140+
obj = obj_cat.table[diff == diff.min()]
141+
self.assertLessEqual(abs(obj["x_mean"] - res["x"]), 5)
142+
self.assertLessEqual(abs(obj["y_mean"] - res["y"]), 5)
143+
self.assertLessEqual(abs(obj["vx"] - res["vx"]), 5)
144+
self.assertLessEqual(abs(obj["vy"] - res["vy"]), 5)
145+
126146

127147
####
128148

0 commit comments

Comments
 (0)