Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 75 additions & 1 deletion src/kbmod_wf/task_impls/kbmod_search.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import kbmod
from kbmod.work_unit import WorkUnit

import os
import numpy as np
from logging import Logger
import os

from kbmod.analysis.plotting import plot_ic_image_bounds, plot_wcs_on_sky
from kbmod.analysis.visualizer import Visualizer

from kbmod.filters.known_object_filters import KnownObjsMatcher
from kbmod.filters.stamp_filters import filter_stamps_by_cnn

from astropy.table import Table


def kbmod_search(
Expand Down Expand Up @@ -89,7 +98,72 @@ def run_search(self):
self.logger.info("Search complete")
self.logger.info(f"Number of results found: {len(res)}")

# Match to known objects from the results
skybot_table_path = self.runtime_config.get("skybot_table_path", None)
if skybot_table_path is None:
self.logger.warning("No skybot table path provided, skipping filtering by known objects.")
else:
self.logger.info(f"Filtering results by known objects using table at {skybot_table_path}")
skytable = Table.read(skybot_table_path)
self.logger.info(f"Read {skybot_table_path}. There are {len(skytable)} rows.")
known_objs_matcher = KnownObjsMatcher(
skytable,
np.array(wu.get_all_obstimes()),
matcher_name="known_matcher",
sep_thresh=5.0, # Observations must be within 5 arcsecs.
time_thresh_s=30, # Observations must match within 30 seconds.
name_col="Name",
ra_col=f"ra_{wu.barycentric_distance}",
dec_col=f"dec_{wu.barycentric_distance}",
mjd_col="mjd_mid",
)

# Carry out initial matching to known objects and populate the matches column.
known_objs_matcher.match(res, wu.wcs)

# Filter the matches down to results with at least 10 observations.
min_obs = 5
known_objs_matcher.match_on_min_obs(res, min_obs)

# Filter results by CNN
ml_model_path = self.runtime_config.get("ml_model_path", None)
if ml_model_path is None:
self.logger.warning("No ML model path provided, skipping filtering by CNN.")
else:
self.logger.info(f"Filtering results by CNN using model at {ml_model_path}")
orig_res_len = len(res)
filter_stamps_by_cnn(
res,
ml_model_path,
coadd_type="weighted",
)
res.filter_rows(res["cnn_class"])
self.logger.info(
f"Filtered {orig_res_len - len(res)} results using CNN model at {ml_model_path}"
)

self.logger.info(f"Writing results to output file: {self.result_filepath}")
res.write_table(self.result_filepath)

self.logger.info(f"Writing daily_coadds to output file: {self.result_filepath}")
# Now add some convenience plots to the results.
# Plot the daily coadds for each result.
res.table["stamp"] = res.table[
"coadd_weighted"
] # We store the combined stamps in the results table.

viz = Visualizer(
wu.im_stack, res
) # The image data are derived from the reprojected WorkUnit.
viz.generate_all_stamps(radius=config["stamp_radius"])
viz.count_num_days() # This feature enables the daily image co-addition.

# For each result, in the results table, plot the daily coadds.
for res_idx in range(len(res)):
res_uuid = res[res_idx]["uuid"]
daily_coadds_filename = os.path.join(
self.results_directory, f"{res_uuid}_daily_coadds.png"
)
viz.plot_daily_coadds(res_idx, filename=daily_coadds_filename)

return self.result_filepath
Loading