diff --git a/src/kbmod_wf/task_impls/kbmod_search.py b/src/kbmod_wf/task_impls/kbmod_search.py index c2552e19..ca71227a 100644 --- a/src/kbmod_wf/task_impls/kbmod_search.py +++ b/src/kbmod_wf/task_impls/kbmod_search.py @@ -1,8 +1,17 @@ import kbmod from kbmod.work_unit import WorkUnit -import os +import numpy as np from logging import Logger +import os + +from kbmod.analysis.plotting import plot_ic_image_bounds, plot_wcs_on_sky +from kbmod.analysis.visualizer import Visualizer + +from kbmod.filters.known_object_filters import KnownObjsMatcher +from kbmod.filters.stamp_filters import filter_stamps_by_cnn + +from astropy.table import Table def kbmod_search( @@ -89,7 +98,72 @@ def run_search(self): self.logger.info("Search complete") self.logger.info(f"Number of results found: {len(res)}") + # Match to known objects from the results + skybot_table_path = self.runtime_config.get("skybot_table_path", None) + if skybot_table_path is None: + self.logger.warning("No skybot table path provided, skipping filtering by known objects.") + else: + self.logger.info(f"Filtering results by known objects using table at {skybot_table_path}") + skytable = Table.read(skybot_table_path) + self.logger.info(f"Read {skybot_table_path}. There are {len(skytable)} rows.") + known_objs_matcher = KnownObjsMatcher( + skytable, + np.array(wu.get_all_obstimes()), + matcher_name="known_matcher", + sep_thresh=5.0, # Observations must be within 5 arcsecs. + time_thresh_s=30, # Observations must match within 30 seconds. + name_col="Name", + ra_col=f"ra_{wu.barycentric_distance}", + dec_col=f"dec_{wu.barycentric_distance}", + mjd_col="mjd_mid", + ) + + # Carry out initial matching to known objects and populate the matches column. + known_objs_matcher.match(res, wu.wcs) + + # Filter the matches down to results with at least 10 observations. + min_obs = 5 + known_objs_matcher.match_on_min_obs(res, min_obs) + + # Filter results by CNN + ml_model_path = self.runtime_config.get("ml_model_path", None) + if ml_model_path is None: + self.logger.warning("No ML model path provided, skipping filtering by CNN.") + else: + self.logger.info(f"Filtering results by CNN using model at {ml_model_path}") + orig_res_len = len(res) + filter_stamps_by_cnn( + res, + ml_model_path, + coadd_type="weighted", + ) + res.filter_rows(res["cnn_class"]) + self.logger.info( + f"Filtered {orig_res_len - len(res)} results using CNN model at {ml_model_path}" + ) + self.logger.info(f"Writing results to output file: {self.result_filepath}") res.write_table(self.result_filepath) + self.logger.info(f"Writing daily_coadds to output file: {self.result_filepath}") + # Now add some convenience plots to the results. + # Plot the daily coadds for each result. + res.table["stamp"] = res.table[ + "coadd_weighted" + ] # We store the combined stamps in the results table. + + viz = Visualizer( + wu.im_stack, res + ) # The image data are derived from the reprojected WorkUnit. + viz.generate_all_stamps(radius=config["stamp_radius"]) + viz.count_num_days() # This feature enables the daily image co-addition. + + # For each result, in the results table, plot the daily coadds. + for res_idx in range(len(res)): + res_uuid = res[res_idx]["uuid"] + daily_coadds_filename = os.path.join( + self.results_directory, f"{res_uuid}_daily_coadds.png" + ) + viz.plot_daily_coadds(res_idx, filename=daily_coadds_filename) + return self.result_filepath