1
1
import abc
2
2
3
3
import numpy as np
4
- from astropy .time import Time
5
- from astropy . table import QTable , vstack
4
+ from astropy .table import QTable
5
+ from . config import Config
6
6
7
7
8
8
__all__ = [
9
9
"gen_catalog" ,
10
10
"CatalogFactory" ,
11
- "SimpleSourceCatalog" ,
12
- "SimpleObjectCatalog" ,
11
+ "SimpleCatalog" ,
12
+ "SourceCatalogConfig" ,
13
+ "SourceCatalog" ,
14
+ "ObjectCatalogConfig" ,
15
+ "ObjectCatalog" ,
13
16
]
14
17
15
18
@@ -26,84 +29,161 @@ def gen_catalog(n, param_ranges, seed=None):
26
29
27
30
# conversion assumes a gaussian
28
31
if "flux" in param_ranges and "amplitude" not in param_ranges :
29
- xstd = cat ["x_stddev" ] if "x_stddev" in cat .colnames else 1
30
- ystd = cat ["y_stddev" ] if "y_stddev" in cat .colnames else 1
32
+ xstd = cat ["x_stddev" ] if "x_stddev" in cat .colnames else 1.0
33
+ ystd = cat ["y_stddev" ] if "y_stddev" in cat .colnames else 1.0
31
34
32
35
cat ["amplitude" ] = cat ["flux" ] / (2.0 * np .pi * xstd * ystd )
33
36
34
37
return cat
35
38
36
39
37
-
38
40
class CatalogFactory (abc .ABC ):
39
41
@abc .abstractmethod
40
- def gen_realization (self , * args , t = None , dt = None , ** kwargs ):
42
+ def mock (self , * args , ** kwargs ):
41
43
raise NotImplementedError ()
42
44
43
- def mock (self , * args , ** kwargs ):
44
- return self .gen_realization (self , * args , ** kwargs )
45
45
46
+ class SimpleCatalogConfig (Config ):
47
+ return_copy = False
48
+ seed = None
49
+ n = 100
50
+ param_ranges = {}
51
+
52
+
53
+ class SimpleCatalog (CatalogFactory ):
54
+ default_config = SimpleCatalogConfig
55
+
56
+ def __init_from_table (self , table , config = None , ** kwargs ):
57
+ config = self .default_config (config = config , ** kwargs )
58
+ config .n = len (table )
59
+ params = {}
60
+ for col in table .keys ():
61
+ params [col ] = (table [col ].min (), table [col ].max ())
62
+ config .param_ranges .update (params )
63
+ return config , table
64
+
65
+ def __init_from_config (self , config , ** kwargs ):
66
+ config = self .default_config (config = config , method = "subset" , ** kwargs )
67
+ table = gen_catalog (config .n , config .param_ranges , config .seed )
68
+ return config , table
69
+
70
+ def __init_from_ranges (self , ** kwargs ):
71
+ param_ranges = kwargs .pop ("param_ranges" , None )
72
+ if param_ranges is None :
73
+ param_ranges = {k : v for k , v in kwargs .items () if k in self .default_config .param_ranges }
74
+ kwargs = {k : v for k , v in kwargs .items () if k not in self .default_config .param_ranges }
75
+
76
+ config = self .default_config (** kwargs , method = "subset" )
77
+ config .param_ranges .update (param_ranges )
78
+ return self .__init_from_config (config = config )
79
+
80
+ def __init__ (self , table = None , config = None , ** kwargs ):
81
+ if table is not None :
82
+ config , table = self .__init_from_table (table , config = config , ** kwargs )
83
+ elif isinstance (config , Config ):
84
+ config , table = self .__init_from_config (config = config , ** kwargs )
85
+ elif isinstance (config , dict ) or kwargs :
86
+ config = {} if config is None else config
87
+ config , table = self .__init_from_ranges (** {** config , ** kwargs })
88
+ else :
89
+ raise ValueError (
90
+ "Expected table or config, or keyword arguments of expected "
91
+ f"catalog value ranges, got:\n table={ table } \n config={ config } "
92
+ f"\n kwargs={ kwargs } "
93
+ )
94
+
95
+ self .config = config
96
+ self .table = table
97
+ self .current = 0
46
98
47
- class SimpleSourceCatalog (CatalogFactory ):
48
- base_param_ranges = {
49
- "amplitude" : [500 , 2000 ],
50
- "x_mean" : [0 , 4096 ],
51
- "y_mean" : [0 , 2048 ],
52
- "x_stddev" : [1 , 7 ],
53
- "y_stddev" : [1 , 7 ],
54
- "theta" : [0 , np .pi ],
55
- }
99
+ @classmethod
100
+ def from_config (cls , config , ** kwargs ):
101
+ config = cls .default_config (config = config , method = "subset" , ** kwargs )
102
+ return cls (gen_catalog (config .n , config .param_ranges , config .seed ), config = config )
56
103
57
- def __init__ (self , table , return_copy = False ):
58
- self .table = table
59
- self .return_copy = return_copy
104
+ @classmethod
105
+ def from_ranges (cls , n = None , config = None , ** kwargs ):
106
+ config = cls .default_config (n = n , config = config , method = "subset" )
107
+ config .param_ranges .update (** kwargs )
108
+ return cls .from_config (config )
60
109
61
110
@classmethod
62
- def from_params (cls , n = 100 , param_ranges = None ):
63
- param_ranges = {} if param_ranges is None else param_ranges
64
- tmp = cls .base_param_ranges .copy ()
65
- tmp .update (param_ranges )
66
- return cls (gen_catalog (n , tmp ))
67
-
68
- def gen_realization (self , * args , t = None , dt = None , ** kwargs ):
69
- if self .return_copy :
111
+ def from_table (cls , table ):
112
+ config = cls .default_config ()
113
+ config .n = len (table )
114
+ params = {}
115
+ for col in table .keys ():
116
+ params [col ] = (table [col ].min (), table [col ].max ())
117
+ config ["param_ranges" ] = params
118
+ return cls (table , config = config )
119
+
120
+ def mock (self ):
121
+ self .current += 1
122
+ if self .config .return_copy :
70
123
return self .table .copy ()
71
124
return self .table
72
125
73
126
74
- class SimpleObjectCatalog (CatalogFactory ):
75
- base_param_ranges = {
76
- "amplitude" : [1 , 100 ],
77
- "x_mean" : [0 , 4096 ],
78
- "y_mean" : [0 , 2048 ],
79
- "vx" : [500 , 1000 ],
80
- "vy" : [500 , 1000 ],
81
- "stddev" : [1 , 1.8 ],
82
- "theta" : [0 , np .pi ],
127
+ class SourceCatalogConfig (SimpleCatalogConfig ):
128
+ param_ranges = {
129
+ "amplitude" : [1. , 10. ],
130
+ "x_mean" : [0. , 4096. ],
131
+ "y_mean" : [0. , 2048. ],
132
+ "x_stddev" : [1. , 3. ],
133
+ "y_stddev" : [1. , 3. ],
134
+ "theta" : [0. , np .pi ],
83
135
}
84
136
85
- def __init__ (self , table , obstime = None ):
86
- self .table = table
87
- self ._realization = table .copy ()
137
+
138
+ class SourceCatalog (SimpleCatalog ):
139
+ default_config = SourceCatalogConfig
140
+
141
+
142
+ class ObjectCatalogConfig (SimpleCatalogConfig ):
143
+ param_ranges = {
144
+ "amplitude" : [0.1 , 3.0 ],
145
+ "x_mean" : [0. , 4096. ],
146
+ "y_mean" : [0. , 2048. ],
147
+ "vx" : [500. , 1000. ],
148
+ "vy" : [500. , 1000. ],
149
+ "stddev" : [0.25 , 1.5 ],
150
+ "theta" : [0. , np .pi ],
151
+ }
152
+
153
+
154
+ class ObjectCatalog (SimpleCatalog ):
155
+ default_config = ObjectCatalogConfig
156
+
157
+ def __init__ (self , table = None , obstime = None , config = None , ** kwargs ):
158
+ # put return_copy into kwargs to override whatever user might have
159
+ # supplied, and to guarantee the default is overriden
160
+ kwargs ["return_copy" ] = True
161
+ super ().__init__ (table = table , config = config , ** kwargs )
162
+ self ._realization = self .table .copy ()
88
163
self .obstime = 0 if obstime is None else obstime
89
164
90
- @classmethod
91
- def from_params (cls , n = 100 , param_ranges = None ):
92
- param_ranges = {} if param_ranges is None else param_ranges
93
- tmp = cls .base_param_ranges .copy ()
94
- tmp .update (param_ranges )
95
- return cls (gen_catalog (n , tmp ))
165
+ def reset (self ):
166
+ self .current = 0
167
+ self ._realization = self .table .copy ()
96
168
97
169
def gen_realization (self , t = None , dt = None , ** kwargs ):
98
170
if t is None and dt is None :
99
171
return self ._realization
100
172
101
173
dt = dt if t is None else t - self .obstime
102
- self ._realization ["x_mean" ] += self ._realization ["vx" ] * dt
103
- self ._realization ["y_mean" ] += self ._realization ["vy" ] * dt
174
+ self ._realization ["x_mean" ] += self .table ["vx" ] * dt
175
+ self ._realization ["y_mean" ] += self .table ["vy" ] * dt
104
176
return self ._realization
105
177
106
178
def mock (self , n = 1 , ** kwargs ):
179
+ breakpoint ()
107
180
if n == 1 :
108
- return self .gen_realization (** kwargs )
109
- return [self .gen_realization (** kwargs ).copy () for i in range (n )]
181
+ data = self .gen_realization (** kwargs )
182
+ self .current += 1
183
+ else :
184
+ data = []
185
+ for i in range (n ):
186
+ data .append (self .gen_realization (** kwargs ).copy ())
187
+ self .current += 1
188
+
189
+ return data
0 commit comments