diff --git a/.gitignore b/.gitignore index 73b5b4eb..6e96e8fa 100644 --- a/.gitignore +++ b/.gitignore @@ -151,3 +151,6 @@ _html/ # Parsl log files run_logs/ + +# Emacs +*~ \ No newline at end of file diff --git a/src/kbmod_wf/measure_uncertainties_workflow.py b/src/kbmod_wf/measure_uncertainties_workflow.py new file mode 100644 index 00000000..2b7694c5 --- /dev/null +++ b/src/kbmod_wf/measure_uncertainties_workflow.py @@ -0,0 +1,471 @@ +import logging + +from kbmod_wf.utilities import ( + LOGGING_CONFIG, + apply_runtime_updates, + get_resource_config, + get_executors, + get_configured_logger, + ErrorLogger, + parse_logdir, + plot_campaign +) + +logging.config.dictConfig(LOGGING_CONFIG) + +import argparse +import os +import glob + +import toml +import parsl +from parsl import python_app, File +import parsl.executors + +from astropy.table import Table + + +# "esci_48_8cpus" "astro_48_8cpus" +@python_app( + cache=True, + executors=get_executors(["local_dev_testing", "ckpt_96gb_8cpus"]), + ignore_for_cache=["logging_file"], +) +def step1(inputs=(), outputs=(), runtime_config={}, logging_file=None): + """Create WorkUnit out of an ImageCollection and resample it. + + Parameters + ---------- + inputs : `tuple` or `list` + Order sensitive input to the Python App. + outputs : `tuple` or `list` + Order sensitive output of the Python App. + runtime_config : `dict`, optional + Runtime configuration values. Keys ``butler_config_filepath``, + ``search_config_filepath`` and ``n_workers`` will be consumed + if they exist. + logging_file : `File` or `None`, optional + Parsl File object poiting to the output logging file. + + Returns + ------- + outputs : `tuple` or `list` + Order sensitive output of the Python App. + + Inputs + ---------- + ic_file : `File` + Parsl File object pointing to the ImageCollection. + res_file : `File` + Parsl File object pointing to the Results file associated with the image collection. + + Outputs + ------- + workunit_path : `File` + Parsl File object poiting to the resampled WorkUnit. + """ + import numpy as np + from astropy.table import Table + from astropy.wcs import WCS + + from kbmod import ImageCollection + from kbmod.configuration import SearchConfiguration + import kbmod.reprojection as reprojection + + from lsst.daf.butler import Butler + + from kbmod_wf.utilities.logger_utilities import get_configured_logger, ErrorLogger + logger = get_configured_logger("workflow.step1", logging_file.filepath) + + with ErrorLogger(logger): + logger.info("Starting step 1.") + + ic_filename = inputs[0].filename + ic_file = inputs[0].filepath + + pg_name = ic_filename.split(".collection")[0] + meas_path = f"uncert_meas/{pg_name}.meas" + if os.path.exists(meas_path): + if not os.path.exists(outputs[0].filepath): + # touch resampled.wu file so that Parsl + # understands its' been cached. + open(outputs[0].filepath, 'a').close() + logger.info("Finished step 1. Measurements exist.") + return outputs + + if os.path.exists(outputs[0].filepath): + logger.info("Finished step 1. 
Resampled WU exists.") + return outputs + + # Unravell inputs + repo_root = runtime_config["butler_config_filepath"] + search_conf_path = runtime_config.get("search_config_filepath", None) + ic_file = inputs[0].filepath + + #### + # Run core tasks + ### + ic = ImageCollection.read(ic_file) + ic.data.sort("mjd_mid") + search_conf = SearchConfiguration.from_file(search_conf_path) + + # The "optimal" WCS is the one we used in the initial search + # So pick that up from the results: + results = Table.read(inputs[1].filepath) + opt_wcs = WCS(json.loads(results.meta["wcs"])) + + butler = Butler(repo_root) + wu = ic.toWorkUnit(search_config=search_conf, butler=butler) + del ic # we're done with IC and results + del results # clean them up for memory + + resampled_wu = reprojection.reproject_work_unit( + wu, + opt_wcs, + parallelize=True, + max_parallel_processes=runtime_config.get("n_workers", 8), + ) + resampled_wu.to_fits(outputs[0].filepath, overwrite=True) + + logger.info("Finished step 1.") + return outputs + + +# "esci_48_8cpus" "astro_48_8cpus" +@python_app( + cache=True, + executors=get_executors(["local_dev_testing", "esci_32gb_2cpu_1gpu"]), + ignore_for_cache=["logging_file"], +) +def step2(inputs=(), outputs=(), runtime_config={}, logging_file=None): + """Create WorkUnit out of an ImageCollection and resample it. + + Parameters + ---------- + inputs : `tuple` or `list` + Order sensitive input to the Python App. + outputs : `tuple` or `list` + Order sensitive output of the Python App. + runtime_config : `dict`, optional + Runtime configuration values. Keys ``butler_config_filepath``, + ``search_config_filepath`` and ``n_workers`` will be consumed + if they exist. + logging_file : `File` or `None`, optional + Parsl File object poiting to the output logging file. + + Returns + ------- + outputs : `tuple` or `list` + Order sensitive output of the Python App. + + Inputs + ---------- + wu_file : `File` + Parsl File object pointing to the WorkUnit. + ic_file : `File` + Parsl File object poiting to the associated ImageCollection. + res_file : `File` + Parsl File object poiting to the associated ImageCollection. + uuids : `list` + List of UUID hex representations corresponding to results we want to + measure uncertainties for. + + Outputs + ------- + workunit_path : `File` + Parsl File object poiting to the resampled WorkUnit. + """ + import json + + import numpy as np + import astropy.units as u + from astropy.table import Table + + import lsst.daf.butler as dafButler + + from kbmod.work_unit import WorkUnit + from kbmod.trajectory_explorer import TrajectoryExplorer + + from kbmod_wf.task_impls import calc_skypos_uncerts + from kbmod_wf.utilities.logger_utilities import get_configured_logger, ErrorLogger + logger = get_configured_logger("workflow.step2", logging_file.filepath) + + with ErrorLogger(logger): + logger.info("Starting step 2.") + + if os.path.exists(outputs[0].filepath): + logger.info("Finished step 2. 
Measurements exist.") + return outputs + + wu_path = inputs[0][0].filepath + coll_path = inputs[1].filepath + res_path = inputs[2].filepath + uuids = inputs[3] + + # Run the search + wu = WorkUnit.from_fits(wu_path) + results = Table.read(res_path) + explorer = TrajectoryExplorer(wu.im_stack) + + mjds = results.meta["mjd_mid"] + mjd_start = np.min(mjds) + mjd_end = np.max(mjds) + + wcs = wu.wcs + + uuids2, pgs, startt, endt = [], [], [], [] + p1ra, p1dec, sigma_p1ra, sigma_p1dec = [], [], [], [] + p2ra, p2dec, sigma_p2ra, sigma_p2dec = [], [], [], [] + likelihoods, uncerts = [], [] + for uuid in uuids: + r = results[results["uuid"] == uuid] + samples = explorer.evaluate_around_linear_trajectory( + r["x"][0], + r["y"][0], + r["vx"][0], + r["vy"][0], + pixel_radius=10, + max_ang_offset=0.785397999997775, # np.pi/4 + ang_step=1.5*0.0174533, # deg2rad + max_vel_offset=45, + vel_step=0.55, + ) + + maxl = samples["likelihood"].max() + bestfit = samples[samples["likelihood"] == maxl] + # happens when oversampling + if len(bestfit) > 1: + bestfit = bestfit[:1] + + start_coord, end_coord, uncert = calc_skypos_uncerts( + samples, + mjd_start, + mjd_end, + wcs + ) + + uuids2.append(uuid) + startt.append(mjd_start) + endt.append(mjd_end) + likelihoods.append(maxl) + p1ra.append(start_coord.ra.deg) + p1dec.append(start_coord.dec.deg) + p2ra.append(end_coord.ra.deg) + p2dec.append(end_coord.dec.deg) + sigma_p1ra.append(uncert[0,0]) + sigma_p1dec.append(uncert[1,1]) + sigma_p2ra.append(uncert[2,2]) + sigma_p2dec.append(uncert[3,3]) + uncerts.append(uncert) + + t = Table({ + "likelihood": likelihoods, + "p1ra": p1ra, + "p1dec": p1dec, + "p2ra": p2ra, + "p2dec": p2dec, + "sigma_p1ra": np.sqrt(sigma_p1ra), + "sigma_p1dec": np.sqrt(sigma_p1dec), + "sigma_p2ra": np.sqrt(sigma_p2ra), + "sigma_p2dec": np.sqrt(sigma_p2dec), + "uncertainty": uncerts, + "uuid": uuids2, + "t0": startt, + "t1": endt + }) + t.write(outputs[0].filepath, format="ascii.ecsv", overwrite=True) + logger.info("Finished step 2.") + + return outputs + + +def workflow_runner(env=None, runtime_config={}): + """Find all image collections in the given directory and run KBMOD + search on them. + + Running the Workflow is a multi-step process which includes + additional preparation and cleanup work that executes at the + submit location: + - Run prep + - Load runtime config + - find all files in ``staging_directory`` that match ``pattern`` + - filter out unwanted files + - Run KBMOD Search for each remaining collection + - Create a workflow Gantt chart. 
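+
+    In addition to the runtime configuration, this workflow expects a
+    lookup table at ``resources/uuid-pg-lookup.ecsv`` containing (at
+    least) ``pg`` and ``uuid`` columns; it is grouped by ``pg`` to map
+    each results/collection file to the result UUIDs whose positional
+    uncertainties should be measured.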
+ + Running a KBMOD search is a 3 step process: + - step 1, executed on CPUs + - load ImageCollection + - filter unwanted rows of data from it + - load SearchConfiguration + - update search config values based on the IC metadata + - materialize a WorkUnit, requires the Rubin Data Butler + - resample a WorkUnit, targets the largest common footprint WCS + - writes the WorkUnit to file + - step 2, executed on GPUs + - loads the WorkUnit + - runs KBMOD search + - adds relevant metadata to the Results Table + - writes Results to file + - step 3, executed on CPUs + - loads Results file + - makes an analysis plot + + Parameters + ---------- + env : str, optional + Environment string used to define which resource configuration to use, + by default None + runtime_config : dict, optional + Dictionary of assorted runtime configuration parameters, by default {} + """ + resource_config = get_resource_config(env=env) + resource_config = apply_runtime_updates(resource_config, runtime_config) + workflow_config = runtime_config.get("workflow", {}) + app_configs = runtime_config.get("apps", {}) + + dfk = parsl.load(resource_config) + logger = get_configured_logger("workflow.workflow_runner") + + if dfk: + if runtime_config is not None: + logger.info(f"Using runtime configuration definition:\n{toml.dumps(runtime_config)}") + + logger.info("Starting workflow") + + directory_path = workflow_config.get("staging_directory", "collections") + file_pattern = workflow_config.get("ic_filename_pattern", "*.collection") + pattern = os.path.join(directory_path, file_pattern) + entries = glob.glob(pattern) + logger.info(f"Found {len(entries)} files in {directory_path}") + + result_ic_lookup = Table.read("resources/uuid-pg-lookup.ecsv") + result_ic_lookup = result_ic_lookup.group_by("pg") + + # bookeping, used to build future output filenames + resfiles, collfiles, uuids_per_pg, collnames, resampled_wus = [], [], [], [], [] + for g in result_ic_lookup.groups: + resfname = g["pg"][0] + results_file = File(f"results/{resfname}") + resfiles.append(results_file) + + collname = resfname.replace(".results.ecsv", "") + collnames.append(collname) + + collection = f"{collname}.collection" + collection_file = File(os.path.join(directory_path, collection)) + collfiles.append(collection_file) + uuids_per_pg.append(list(g["uuid"])) + + logger.info(f"Registering {collname} for step1 of {collection}") + logging_file = File(f"logs/{collname}.resample.log") + + resampled_wus.append( + step1( + inputs=[collection_file, results_file], + outputs=[File(f"resampled_wus/{collname}.resampled.wu")], + runtime_config=app_configs.get("step1", {}), + logging_file=logging_file, + ) + ) + + results = [] + for resampledwu, collname, collfile, resfile, uuids in zip(resampled_wus, collnames, collfiles, resfiles, uuids_per_pg): + logger.info(f"Registering {collname} for step2 of {collfile.filepath}") + logging_file = File(f"logs/{collname}.search.log") + + results.append( + step2( + inputs=[resampledwu, collfile, resfile, uuids], + outputs=[File(f"uncert_meas/{collname}.meas"),], + runtime_config=app_configs.get("step2", {}), + logging_file=logging_file, + ) + ) + + [f.result() for f in results] + dfk.wait_for_current_tasks() + logger.info("Workflow complete") + + + # Create the Workflow Gantt chart + logs = parse_logdir("logs") + + success, fail = [], [] + for l in logs: + successfull_steps = l.stepnames[l.success] + if not all([ + "resample" in successfull_steps, + "search" in successfull_steps, + "analysis" in successfull_steps + ]): + fail.append(l) + 
else: + success.append(l) + + print(f"N success: {len(success)}") + print(f"N fail: {len(fail)}") + + with open("failed_runs.list", "w") as f: + for l in fail: + f.write(l.name) + f.write("\n") + + with open("success_runs.list", "w") as f: + for l in success: + f.write(l.name) + f.write("\n") + + try: + import matplotlib.pyplot as plt + except ImportError: + logger.warning("Matplotlib not installed, skipping creating " + "workflow Gantt chart") + else: + fig, ax = plt.subplots(figsize=(15, 15)) + ax = plot_campaign( + ax, + logs, + relative_to_launch=True, + units="hour", + name_pos="right+column" + ) + plt.tight_layout() + plt.savefig("exec_gantt.png") + finally: + parsl.clear() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--env", + type=str, + choices=["dev", "klone"], + help="The environment to run the workflow in.", + ) + + parser.add_argument( + "--runtime-config", + type=str, + help="The complete runtime configuration filepath to use for the workflow.", + ) + + args = parser.parse_args() + + # if a runtime_config file was provided and exists, load the toml as a dict. + runtime_config = {} + if args.runtime_config is not None and os.path.exists(args.runtime_config): + with open(args.runtime_config, "r") as toml_runtime_config: + runtime_config = toml.load(toml_runtime_config) + + workflow_runner(env=args.env, runtime_config=runtime_config) + + + + + + + + + diff --git a/src/kbmod_wf/resource_configs/klone_configuration.py b/src/kbmod_wf/resource_configs/klone_configuration.py index 953622b7..6fbad604 100644 --- a/src/kbmod_wf/resource_configs/klone_configuration.py +++ b/src/kbmod_wf/resource_configs/klone_configuration.py @@ -18,81 +18,235 @@ def klone_resource_config(): app_cache=True, checkpoint_mode="task_exit", checkpoint_files=get_all_checkpoints( - os.path.join("/gscratch/dirac/kbmod/workflow/run_logs", datetime.date.today().isoformat()) + os.path.join(os.path.abspath(os.curdir), "parsl_rundir") ), - run_dir=os.path.join("/gscratch/dirac/kbmod/workflow/run_logs", datetime.date.today().isoformat()), - retries=1, + run_dir=os.path.join(os.path.abspath(os.curdir), "parsl_rundir"), + retries=4, executors=[ + #################### + # Resample resources + #################### HighThroughputExecutor( - label="small_cpu", + label="astro_96gb_8cpus", max_workers=1, provider=SlurmProvider( - partition="ckpt-g2", + partition="compute-bigmem", + account="astro", + min_blocks=0, + max_blocks=4, # Low block count for shared resource + init_blocks=0, + parallelism=1, + nodes_per_block=1, + mem_per_node=96, # 96 GB for >100, 48 for < 100 + cores_per_node=8, + exclusive=False, + walltime=walltimes["sharded_reproject"], + worker_init="", + ), + ), + HighThroughputExecutor( + label="astro_48gb_8cpus", + max_workers=1, + provider=SlurmProvider( + partition="compute-bigmem", account="astro", min_blocks=0, - max_blocks=4, + max_blocks=4, # Low block count for shared resource init_blocks=0, parallelism=1, nodes_per_block=1, - cores_per_node=1, # perhaps should be 8??? - mem_per_node=256, # In GB + mem_per_node=48, # 96 GB for >100, 48 for < 100 + cores_per_node=8, exclusive=False, - walltime=walltimes["compute_bigmem"], - # Command to run before starting worker - i.e. 
conda activate + walltime=walltimes["sharded_reproject"], worker_init="", ), ), HighThroughputExecutor( - label="large_mem", + label="esci_96gb_8cpus", max_workers=1, provider=SlurmProvider( - partition="ckpt-g2", + partition="gpu-a40", + account="escience", + min_blocks=0, + max_blocks=4, # low block count for shared resources + init_blocks=0, + parallelism=1, + nodes_per_block=1, + mem_per_node=96, # 96 GB for >100, 48 for < 100 + cores_per_node=8, + exclusive=False, + walltime=walltimes["sharded_reproject"], + worker_init="", + ), + ), + HighThroughputExecutor( + label="esci_48gb_8cpus", + max_workers=1, + provider=SlurmProvider( + partition="gpu-a40", + account="escience", + min_blocks=0, + max_blocks=4, # low block count for shared resources + init_blocks=0, + parallelism=1, + nodes_per_block=1, + mem_per_node=48, # 96 GB for >100, 48 for < 100 + cores_per_node=8, + exclusive=False, + walltime=walltimes["sharded_reproject"], + worker_init="", + ), + ), + HighThroughputExecutor( + label="ckpt_96gb_8cpus", + max_workers=1, + provider=SlurmProvider( + partition="ckpt-all", account="astro", min_blocks=0, - max_blocks=2, + max_blocks=50, # scale to the size of the GPU blocks, big number for low memory init_blocks=0, parallelism=1, nodes_per_block=1, - cores_per_node=32, - mem_per_node=512, + mem_per_node=96, # 96 GB for >100, 48 for < 100 + cores_per_node=8, exclusive=False, - walltime=walltimes["large_mem"], - # Command to run before starting worker - i.e. conda activate + walltime=walltimes["sharded_reproject"], worker_init="", ), ), HighThroughputExecutor( - label="sharded_reproject", + label="ckpt_48gb_8cpus", max_workers=1, provider=SlurmProvider( - partition="ckpt-g2", + partition="ckpt-all", account="astro", min_blocks=0, - max_blocks=2, + max_blocks=50, # scale to the size of the GPU blocks, big number for low memory init_blocks=0, parallelism=1, nodes_per_block=1, - cores_per_node=32, - mem_per_node=128, # ~2-4 GB per core + mem_per_node=48, # 96 GB for >100, 48 for < 100 + cores_per_node=8, exclusive=False, walltime=walltimes["sharded_reproject"], + worker_init="", + ), + ), + #################### + # Search resources + #################### + HighThroughputExecutor( + label="esci_96gb_2cpu_1gpu", + max_workers=1, + provider=SlurmProvider( + partition="gpu-a40", + account="escience", + min_blocks=0, + max_blocks=4, # low block count for shared resource + init_blocks=0, + parallelism=1, + nodes_per_block=1, + cores_per_node=2, # perhaps should be 8??? + mem_per_node=96, # 96 GB for >100, 48 for < 100 + exclusive=False, + walltime=walltimes["gpu_max"], + worker_init="", + scheduler_options="#SBATCH --gpus=1", + ), + ), + HighThroughputExecutor( + label="esci_48gb_2cpu_1gpu", + max_workers=1, + provider=SlurmProvider( + partition="gpu-a40", + account="escience", + min_blocks=0, + max_blocks=4, # low block count for shared resource + init_blocks=0, + parallelism=1, + nodes_per_block=1, + cores_per_node=2, # perhaps should be 8??? + mem_per_node=48, # 96 GB for >100, 48 for < 100 + exclusive=False, + walltime=walltimes["gpu_max"], + worker_init="", + scheduler_options="#SBATCH --gpus=1", + ), + ), + HighThroughputExecutor( + label="esci_32gb_2cpu_1gpu", + max_workers=1, + provider=SlurmProvider( + partition="gpu-a40", + account="escience", + min_blocks=0, + max_blocks=6, # low block count for shared resource + init_blocks=0, + parallelism=1, + nodes_per_block=1, + cores_per_node=2, # perhaps should be 8??? 
+ mem_per_node=32, # 96 GB for >100, 48 for < 100 + exclusive=False, + walltime=walltimes["gpu_max"], + worker_init="", + scheduler_options="#SBATCH --gpus=1", + ), + ), + HighThroughputExecutor( + label="ckpt_96gb_2cpu_1gpu", + max_workers=1, + provider=SlurmProvider( + partition="ckpt-g2", + account="escience", + min_blocks=0, + max_blocks=50, # 20 for 96, 50 for 48 + init_blocks=0, + parallelism=1, + nodes_per_block=1, + cores_per_node=2, # perhaps should be 8??? + mem_per_node=96, # 96 GB for >100, 48 for < 100 + exclusive=False, + walltime=walltimes["gpu_max"], + # Command to run before starting worker - i.e. conda activate + worker_init="", + scheduler_options="#SBATCH --gpus=1", + ), + ), + HighThroughputExecutor( + label="ckpt_48gb_2cpu_1gpu", + max_workers=1, + provider=SlurmProvider( + partition="ckpt-g2", + account="escience", + min_blocks=0, + max_blocks=50, # 20 for 96, 50 for 48 + init_blocks=0, + parallelism=1, + nodes_per_block=1, + cores_per_node=2, # perhaps should be 8??? + mem_per_node=48, # 96 GB for >100, 48 for < 100 + exclusive=False, + walltime=walltimes["gpu_max"], # Command to run before starting worker - i.e. conda activate worker_init="", + scheduler_options="#SBATCH --gpus=1", ), ), HighThroughputExecutor( - label="gpu", + label="ckpt_32gb_2cpu_1gpu", max_workers=1, provider=SlurmProvider( partition="ckpt-g2", account="escience", min_blocks=0, - max_blocks=2, + max_blocks=50, # 20 for 96, 50 for 48 init_blocks=0, parallelism=1, nodes_per_block=1, cores_per_node=2, # perhaps should be 8??? - mem_per_node=512, # In GB + mem_per_node=32, # 96 GB for >100, 48 for < 100 exclusive=False, walltime=walltimes["gpu_max"], # Command to run before starting worker - i.e. conda activate @@ -100,11 +254,65 @@ def klone_resource_config(): scheduler_options="#SBATCH --gpus=1", ), ), + + #################### + # Analysis resource + #################### HighThroughputExecutor( - label="local_thread", - provider=LocalProvider( + label="astro_4gb_2cpus", + max_workers=1, # Do we mean max_workers_per_node here? + provider=SlurmProvider( + partition="compute-bigmem", # ckpt-all + account="astro", # astro + min_blocks=0, + max_blocks=12, # low block count for shared resource init_blocks=0, - max_blocks=1, + parallelism=1, + nodes_per_block=1, + mem_per_node=4, + cores_per_node=2, + exclusive=False, + walltime=walltimes["sharded_reproject"], + # Command to run before starting worker - i.e. conda activate + worker_init="", + ), + ), + HighThroughputExecutor( + label="esci_4gb_2cpus", + max_workers=1, # Do we mean max_workers_per_node here? + provider=SlurmProvider( + partition="gpu-a40", # ckpt-all + account="escience", # astro + min_blocks=0, + max_blocks=12, # low block count for shared resource + init_blocks=0, + parallelism=1, + nodes_per_block=1, + mem_per_node=4, + cores_per_node=2, + exclusive=False, + walltime=walltimes["sharded_reproject"], + # Command to run before starting worker - i.e. conda activate + worker_init="", + ), + ), + HighThroughputExecutor( + label="ckpt_4gb_2cpus", + max_workers=1, # Do we mean max_workers_per_node here? + provider=SlurmProvider( + partition="ckpt-all", # ckpt-all + account="astro", # astro + min_blocks=0, + max_blocks=100, # can leave large at all times + init_blocks=0, + parallelism=1, + nodes_per_block=1, + mem_per_node=4, + cores_per_node=2, + exclusive=False, + walltime=walltimes["sharded_reproject"], + # Command to run before starting worker - i.e. 
conda activate + worker_init="", ), ), ], diff --git a/src/kbmod_wf/single_chip_analysis.py b/src/kbmod_wf/single_chip_analysis.py new file mode 100644 index 00000000..611249ba --- /dev/null +++ b/src/kbmod_wf/single_chip_analysis.py @@ -0,0 +1,217 @@ +import logging + +from kbmod_wf.utilities import ( + LOGGING_CONFIG, + apply_runtime_updates, + get_resource_config, + get_executors, + get_configured_logger, + ErrorLogger, + parse_logdir, + plot_campaign +) + +logging.config.dictConfig(LOGGING_CONFIG) + +import argparse +import os +import glob + +import toml +import parsl +from parsl import python_app, File +import parsl.executors +import time + + +#"ckpt_2gb_2cpus", "ckpt_2gb_2cpus", "astro_2gb_2cpus"]), +@python_app( + cache=True, + executors=get_executors(["local_dev_testing", "ckpt_4gb_2cpus"]), + ignore_for_cache=["logging_file"], +) +def postscript(inputs=(), outputs=(), runtime_config={}, logging_file=None): + """Run postscript actions after each individual task. + + Generally consists of creating analysis plots for each result. + + Parameters + ---------- + inputs : `tuple` or `list` + Order sensitive input to the Python App. + outputs : `tuple` or `list` + Order sensitive output of the Python App. + runtime_config : `dict`, optional + Runtime configuration values. No keys are consumed. + logging_file : `File` or `None`, optional + Parsl File object poiting to the output logging file. + + Returns + ------- + outputs : `tuple` or `list` + Order sensitive output of the Python App. + + Inputs + ---------- + result_file : `File` + Parsl File object poiting to the associated Results. + + Outputs + ------- + results : `File` + Parsl File object poiting to the results. + """ + import tempfile + import tarfile + import json + + from astropy.table import Table + from astropy.io import fits as fitsio + from astropy.wcs import WCS + import matplotlib.pyplot as plt + + from kbmod_wf.task_impls.deep_plots import ( + Figure, + configure_plot, + plot_result, + select_known_objects + ) + + from kbmod_wf.utilities.logger_utilities import get_configured_logger, ErrorLogger + logger = get_configured_logger("kbmod", logging_file.filepath) + + with ErrorLogger(logger): + # Grab some names from the input so we know how to name + # our output plots etc. 
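+        # In this standalone analysis workflow the Results file is passed
+        # in directly as a Parsl File, so inputs[0] is the File itself
+        # (contrast with the end-to-end workflow, where it arrives as the
+        # output list of the step2 app and must be indexed as inputs[0][0]).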
+ results_path = inputs[0].filepath + collname = os.path.basename(results_path).split(".")[0] + results = Table.read(results_path) + + # Grab external resources required + # - wcs, times, visitids, fakes, known objects and so on + obstimes = results.meta["mjd_mid"] + wcs = WCS(json.loads(results.meta["wcs"])) + + fakes = fitsio.open(runtime_config.get( + "fake_object_catalog", "fakes_catalog.fits" + )) + allknowns = Table.read(runtime_config.get( + "known_object_catalog", "known_objects_catalog.fits" + )) + + fakes, knowns = select_known_objects(fakes, allknowns, results) + fakes = fakes.group_by("ORBITID") + knowns = knowns.group_by("Name") + + # Make the plots, write them to tmpdir and tar them up + allplots = [] + tmpdir = tempfile.mkdtemp() + logger.info(f"Creating analysis plots for results of length: {len(results)}") + for i, res in enumerate(results): + figure = configure_plot(wcs, fig_kwargs={"figsize": (24, 12)}) + figure.fig.suptitle(f"{collname}, {res['uuid']}") + figure = plot_result(figure, res, fakes, knowns, wcs, obstimes) + + pltname = f"{collname}_L{int(res['likelihood']):0>4}_idx{i:0>4}.jpg" + pltpath = os.path.join(tmpdir, pltname) + allplots.append(pltpath) + logger.info(f"Saving {pltpath}") + plt.savefig(pltpath) + plt.close(figure.fig) + + with tarfile.open(outputs[0].filepath, "w|bz2") as tar: + for f in allplots: + tar.add(f) + + return outputs + + +def workflow_runner(env=None, runtime_config={}): + """This function will load and configure Parsl, and run the workflow. + + Parameters + ---------- + env : str, optional + Environment string used to define which resource configuration to use, + by default None + runtime_config : dict, optional + Dictionary of assorted runtime configuration parameters, by default {} + """ + resource_config = get_resource_config(env=env) + resource_config = apply_runtime_updates(resource_config, runtime_config) + app_configs = runtime_config.get("apps", {}) + + dfk = parsl.load(resource_config) + logger = get_configured_logger("workflow.workflow_runner") + + if dfk: + if runtime_config is not None: + logger.info(f"Using runtime configuration definition:\n{toml.dumps(runtime_config)}") + + results_dirpath = "results" + pattern = os.path.join(results_dirpath, "*results*") + results = glob.glob(pattern) + + resampledwus_dirpath = "resampled_wus" + imgcolls_dirpath = "collections" + + collnames, collfiles, wufiles = [], [], [] + for respth in results: + resname = os.path.basename(respth).split(".")[0] + collnames.append(resname) + + pattern = os.path.join(imgcolls_dirpath, resname) + "*" + collfiles.extend(glob.glob(pattern)) + + pattern = os.path.join(resampledwus_dirpath, resname) + "*" + wufiles.extend(glob.glob(pattern)) + + logger.info("Starting workflow") + logger.info(f"Found {len(results)} files in {results_dirpath}") + + # Register postscript for each output of step 2 + analysis = [] + for result, collname in zip(results, collnames): + logger.info(f"Registering {collname} for postscript") + logging_file = File(f"analysis_logs/{collname}.analysis.log") + plots_archive = File(f"analysis_plots/{collname}.plots.tar.bz2") + analysis.append( + postscript( + inputs=[File(result)], + outputs=[plots_archive, ], + runtime_config=app_configs.get("postscript", {}), + logging_file=logging_file + ) + ) + + [f.result() for f in analysis] + dfk.wait_for_current_tasks() + logger.info("Workflow complete") + parsl.clear() + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--env", + type=str, + 
choices=["dev", "klone"], + help="The environment to run the workflow in.", + ) + + parser.add_argument( + "--runtime-config", + type=str, + help="The complete runtime configuration filepath to use for the workflow.", + ) + + args = parser.parse_args() + + # if a runtime_config file was provided and exists, load the toml as a dict. + runtime_config = {} + if args.runtime_config is not None and os.path.exists(args.runtime_config): + with open(args.runtime_config, "r") as toml_runtime_config: + runtime_config = toml.load(toml_runtime_config) + + workflow_runner(env=args.env, runtime_config=runtime_config) diff --git a/src/kbmod_wf/single_chip_step2.py b/src/kbmod_wf/single_chip_step2.py new file mode 100644 index 00000000..c1a9052f --- /dev/null +++ b/src/kbmod_wf/single_chip_step2.py @@ -0,0 +1,183 @@ +import logging + +from kbmod_wf.utilities import ( + LOGGING_CONFIG, + apply_runtime_updates, + get_resource_config, + get_executors, + get_configured_logger, + ErrorLogger +) + +logging.config.dictConfig(LOGGING_CONFIG) + +import argparse +import os +import glob + +import toml +import parsl +from parsl import python_app, File +import parsl.executors +import time + + +@python_app( + cache=True, + executors=get_executors(["local_dev_testing", "ckpt_32gb_2cpu_1gpu"]), + ignore_for_cache=["logging_file"], +) +def step2(inputs=(), outputs=(), runtime_config={}, logging_file=None): + """Load an resampled WorkUnit and search through it. + + Parameters + ---------- + inputs : `tuple` or `list` + Order sensitive input to the Python App. + outputs : `tuple` or `list` + Order sensitive output of the Python App. + runtime_config : `dict`, optional + Runtime configuration values. No values are consumed. + logging_file : `File` or `None`, optional + Parsl File object poiting to the output logging file. + + Returns + ------- + outputs : `tuple` or `list` + Order sensitive output of the Python App. + + Inputs + ---------- + wu_file : `File` + Parsl File object pointing to the WorkUnit. + ic_file : `File` + Parsl File object poiting to the associated ImageCollection. + + Outputs + ------- + results : `File` + Parsl File object poiting to the results. + """ + from kbmod_wf.utilities.logger_utilities import get_configured_logger, ErrorLogger + logger = get_configured_logger("workflow.step2", logging_file.filepath) + + import json + + from kbmod import ImageCollection + from kbmod.work_unit import WorkUnit + from kbmod.run_search import SearchRunner + + with ErrorLogger(logger): + wu_path = inputs[0][0].filepath + coll_path = inputs[1].filepath + + # Run the search + ic = ImageCollection.read(coll_path) + ic.data.sort("mjd_mid") + wu = WorkUnit.from_fits(wu_path) + res = SearchRunner().run_search_from_work_unit(wu) + + # add useful metadata to the results + header = wu.wcs.to_header(relax=True) + header["NAXIS1"], header["NAXIS2"] = wu.wcs.pixel_shape + res.table.meta["wcs"] = json.dumps(dict(header)) + res.table.meta["visits"] = list(ic["visit"].data) + res.table.meta["detector"] = ic["detector"][0] + res.table.meta["mjd_mid"] = list(ic["mjd_mid"].data) + res.table["uuid"] = [uuid.uuid4().hex for i in range(len(res.table))] + + # write results + res.write_table(outputs[0].filepath, overwrite=True) + + return outputs + + +def workflow_runner(env=None, runtime_config={}): + """Find all WorkUnits in the given directory and run KBMOD + search on them. + + Requires matching image collections directory path. 
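+
+    The expected on-disk layout, relative to the launch directory, is
+    roughly the following (names illustrative, matching the resample
+    step's outputs)::
+
+        resampled_wus/<collname>.resampled.wu   # WorkUnits searched here
+        collections/<collname>.collection       # matching ImageCollections
+        results/<collname>.results.ecsv         # written by this workflow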
+ + Parameters + ---------- + env : str, optional + Environment string used to define which resource configuration to use, + by default None + runtime_config : dict, optional + Dictionary of assorted runtime configuration parameters, by default {} + """ + resource_config = get_resource_config(env=env) + resource_config = apply_runtime_updates(resource_config, runtime_config) + workflow_config = runtime_config.get("workflow", {}) + app_configs = runtime_config.get("apps", {}) + + dfk = parsl.load(resource_config) + logger = get_configured_logger("workflow.workflow_runner") + + if dfk: + if runtime_config is not None: + logger.info(f"Using runtime configuration definition:\n{toml.dumps(runtime_config)}") + + logger.info("Starting workflow") + + resampledwus_dirpath = "resampled_wus" + imgcolls_dirpath = "collections" + + wufile_pattern = "*.wu" + pattern = os.path.join(resampledwus_dirpath, wufile_pattern) + wus = glob.glob(pattern) + + collnames, collfiles = [], [] + for wupth in wus: + wuname = os.path.basename(wupth) + wuname = wuname.split(".")[0] + collnames.append(wuname) + pattern = os.path.join(imgcolls_dirpath, wuname) + "*" + collfiles.extend(glob.glob(pattern)) + + logger.info(f"Found {len(wus)} WorkUnits in {resampledwus_path}") + + # Register step 2 for each output of step 1 + results = [] + for resample, collname, collfile in zip(wus, collnames, collfiles): + logger.info(f"Registering {collname} for step2 of {collfile.filepath}") + logging_file = File(f"logs/{collname}.search.log") + results.append( + step2( + inputs=[resample, collfile], + outputs=[File(f"results/{collname}.results.ecsv"),], + runtime_config=app_configs.get("step2", {}), + logging_file=logging_file, + ) + ) + + [f.result() for f in search_futures] + logger.info("Workflow complete") + + parsl.clear() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--env", + type=str, + choices=["dev", "klone"], + help="The environment to run the workflow in.", + ) + + parser.add_argument( + "--runtime-config", + type=str, + help="The complete runtime configuration filepath to use for the workflow.", + ) + + args = parser.parse_args() + + # if a runtime_config file was provided and exists, load the toml as a dict. + runtime_config = {} + if args.runtime_config is not None and os.path.exists(args.runtime_config): + with open(args.runtime_config, "r") as toml_runtime_config: + runtime_config = toml.load(toml_runtime_config) + + workflow_runner(env=args.env, runtime_config=runtime_config) diff --git a/src/kbmod_wf/single_chip_workflow2.py b/src/kbmod_wf/single_chip_workflow2.py new file mode 100644 index 00000000..b25b3e11 --- /dev/null +++ b/src/kbmod_wf/single_chip_workflow2.py @@ -0,0 +1,482 @@ +import logging + +from kbmod_wf.utilities import ( + LOGGING_CONFIG, + apply_runtime_updates, + get_resource_config, + get_executors, + get_configured_logger, + ErrorLogger, + parse_logdir, + plot_campaign +) + +logging.config.dictConfig(LOGGING_CONFIG) + +import argparse +import os +import glob + +import toml +import parsl +from parsl import python_app, File +import parsl.executors + + +# "esci_48_8cpus" "astro_48_8cpus" +@python_app( + cache=True, + executors=get_executors(["local_dev_testing", "ckpt_48gb_8cpus"]), + ignore_for_cache=["logging_file"], +) +def step1(inputs=(), outputs=(), runtime_config={}, logging_file=None): + """Create WorkUnit out of an ImageCollection and resample it. + + Parameters + ---------- + inputs : `tuple` or `list` + Order sensitive input to the Python App. 
+ outputs : `tuple` or `list` + Order sensitive output of the Python App. + runtime_config : `dict`, optional + Runtime configuration values. Keys ``butler_config_filepath``, + ``search_config_filepath`` and ``n_workers`` will be consumed + if they exist. + logging_file : `File` or `None`, optional + Parsl File object poiting to the output logging file. + + Returns + ------- + outputs : `tuple` or `list` + Order sensitive output of the Python App. + + Inputs + ---------- + ic_file : `File` + Parsl File object pointing to the ImageCollection. + + Outputs + ------- + workunit_path : `File` + Parsl File object poiting to the resampled WorkUnit. + """ + import numpy as np + from reproject.mosaicking import find_optimal_celestial_wcs + + from kbmod import ImageCollection + from kbmod.configuration import SearchConfiguration + import kbmod.reprojection as reprojection + + from lsst.daf.butler import Butler + + from kbmod_wf.utilities.logger_utilities import get_configured_logger, ErrorLogger + logger = get_configured_logger("workflow.step1", logging_file.filepath) + + with ErrorLogger(logger): + # Unravell inputs + repo_root = runtime_config["butler_config_filepath"] + search_conf_path = runtime_config.get("search_config_filepath", None) + ic_file = inputs[0].filepath + + #### + # Run core tasks + ### + ic = ImageCollection.read(ic_file) + + ### Mask out images that we don't want or can not search through + # mask out poor weather images + #mask_zp = np.logical_and(ic["zeroPoint"] > 27 , ic["zeroPoint"] < 32) + #ic = ic[np.logical_and(mask_zp, mask_wcs_err)] + + # mask out images with WCS error more than 0.1 arcseconds because we + # can't trust their resampling can be correct + mask_good_wcs_err = ic["wcs_err"] < 1e-04 + if not all(mask_good_wcs_err): + logger.warning("Image collection contains large WCS errors!") + #ic = ic[mask_good_wcs_err] + #ic.reset_lazy_loading_indices() + ic.data.sort("mjd_mid") + + ### Adjust the search parameters based on remaining metadata + search_conf = SearchConfiguration.from_file(search_conf_path) + if len(ic)//2 < 25: + n_obs = 15 + else: + n_obs = len(ic)//2 + search_conf._params["n_obs"] = n_obs + + ### Resampling + # Fit the optimal WCS + opt_wcs, shape = find_optimal_celestial_wcs(list(ic.wcs)) + opt_wcs.array_shape = shape + + butler = Butler(repo_root) + wu = ic.toWorkUnit(search_config=search_conf, butler=butler) + del ic # we're done with IC, clean it up for memory + + resampled_wu = reprojection.reproject_work_unit( + wu, + opt_wcs, + parallelize=True, + max_parallel_processes=runtime_config.get("n_workers", 8), + ) + resampled_wu.to_fits(outputs[0].filepath, overwrite=True) + + return outputs + + +# "esci_48_2cpu_1gpu", "esci_48_2cpu_1gpu" +@python_app( + cache=True, + executors=get_executors(["local_dev_testing", "ckpt_32gb_2cpu_1gpu"]), + ignore_for_cache=["logging_file"], +) +def step2(inputs=(), outputs=(), runtime_config={}, logging_file=None): + """Load an resampled WorkUnit and search through it. + + Parameters + ---------- + inputs : `tuple` or `list` + Order sensitive input to the Python App. + outputs : `tuple` or `list` + Order sensitive output of the Python App. + runtime_config : `dict`, optional + Runtime configuration values. No values are consumed. + logging_file : `File` or `None`, optional + Parsl File object poiting to the output logging file. + + Returns + ------- + outputs : `tuple` or `list` + Order sensitive output of the Python App. + + Inputs + ---------- + wu_file : `File` + Parsl File object pointing to the WorkUnit. 
+ ic_file : `File` + Parsl File object poiting to the associated ImageCollection. + + Outputs + ------- + results : `File` + Parsl File object poiting to the results. + """ + import json + + from kbmod import ImageCollection + from kbmod.work_unit import WorkUnit + from kbmod.run_search import SearchRunner + + from kbmod_wf.utilities.logger_utilities import get_configured_logger, ErrorLogger + logger = get_configured_logger("workflow.step2", logging_file.filepath) + + with ErrorLogger(logger): + wu_path = inputs[0][0].filepath + coll_path = inputs[1].filepath + + # Run the search + ic = ImageCollection.read(coll_path) + ic.data.sort("mjd_mid") + wu = WorkUnit.from_fits(wu_path) + res = SearchRunner().run_search_from_work_unit(wu) + + # add useful metadata to the results + header = wu.wcs.to_header(relax=True) + header["NAXIS1"], header["NAXIS2"] = wu.wcs.pixel_shape + res.table.meta["wcs"] = json.dumps(dict(header)) + res.table.meta["visits"] = list(ic["visit"].data) + res.table.meta["detector"] = ic["detector"][0] + res.table.meta["mjd_mid"] = list(ic["mjd_mid"].data) + res.table["uuid"] = [uuid.uuid4().hex for i in range(len(res.table))] + + # write results + res.write_table(outputs[0].filepath, overwrite=True) + + return outputs + + +#"ckpt_2gb_2cpus", "ckpt_2gb_2cpus", "astro_2gb_2cpus"]), +@python_app( + cache=True, + executors=get_executors(["local_dev_testing", "ckpt_4gb_2cpus"]), + ignore_for_cache=["logging_file"], +) +def postscript(inputs=(), outputs=(), runtime_config={}, logging_file=None): + """Run postscript actions after each individual task. + + Generally consists of creating analysis plots for each result. + + Parameters + ---------- + inputs : `tuple` or `list` + Order sensitive input to the Python App. + outputs : `tuple` or `list` + Order sensitive output of the Python App. + runtime_config : `dict`, optional + Runtime configuration values. No keys are consumed. + logging_file : `File` or `None`, optional + Parsl File object poiting to the output logging file. + + Returns + ------- + outputs : `tuple` or `list` + Order sensitive output of the Python App. + + Inputs + ---------- + result_file : `File` + Parsl File object poiting to the associated Results. + + Outputs + ------- + results : `File` + Parsl File object poiting to the results. + """ + import tempfile + import tarfile + import json + + from astropy.table import Table + from astropy.io import fits as fitsio + from astropy.wcs import WCS + import matplotlib.pyplot as plt + + from kbmod_wf.task_impls.deep_plots import ( + Figure, + configure_plot, + plot_result, + select_known_objects + ) + + from kbmod_wf.utilities.logger_utilities import get_configured_logger, ErrorLogger + logger = get_configured_logger("workflow.postscript", logging_file.filepath) + + with ErrorLogger(logger): + # Grab some names from the input so we know how to name + # our output plots etc. 
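+        # Here the result arrives as the output list of the step2 app
+        # future, so inputs[0] is a list of Parsl Files and the Results
+        # file itself is inputs[0][0].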
+ results_path = inputs[0][0].filepath + collname = os.path.basename(results_path).split(".")[0] + results = Table.read(results_path) + + # Grab external resources required + # - wcs, times, visitids, fakes, known objects and so on + obstimes = results.meta["mjd_mid"] + wcs = WCS(json.loads(results.meta["wcs"])) + + fakes = fitsio.open(runtime_config.get( + "fake_object_catalog", "fakes_catalog.fits" + )) + allknowns = Table.read(runtime_config.get( + "known_object_catalog", "known_objects_catalog.fits" + )) + + fakes, knowns = select_known_objects(fakes, allknowns, results) + fakes = fakes.group_by("ORBITID") + knowns = knowns.group_by("Name") + + # Make the plots, write them to tmpdir and tar them up + allplots = [] + tmpdir = tempfile.mkdtemp() + logger.info(f"Creating analysis plots for results of length: {len(results)}") + for i, res in enumerate(results): + figure = configure_plot(wcs, fig_kwargs={"figsize": (24, 12)}) + figure.fig.suptitle(f"{collname}, {res['uuid']}") + figure = plot_result(figure, res, fakes, knowns, wcs, obstimes) + + pltname = f"{collname}_L{int(res['likelihood']):0>4}_idx{i:0>4}.jpg" + pltpath = os.path.join(tmpdir, pltname) + allplots.append(pltpath) + logger.info(f"Saving {pltpath}") + plt.savefig(pltpath) + plt.close(figure.fig) + + with tarfile.open(outputs[0].filepath, "w|bz2") as tar: + for f in allplots: + tar.add(f) + + return outputs + + +def workflow_runner(env=None, runtime_config={}): + """Find all image collections in the given directory and run KBMOD + search on them. + + Running the Workflow is a multi-step process which includes + additional preparation and cleanup work that executes at the + submit location: + - Run prep + - Load runtime config + - find all files in ``staging_directory`` that match ``pattern`` + - filter out unwanted files + - Run KBMOD Search for each remaining collection + - Create a workflow Gantt chart. 
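+
+    The runtime configuration is a TOML file loaded into a plain
+    dictionary. A minimal sketch of the keys read by this workflow loop
+    and its apps, shown as the loaded dictionary, is given below; the
+    two ``*_filepath`` values are placeholders, the remaining values
+    mirror the in-code defaults, and other sections (e.g. resource
+    overrides) may also be honoured::
+
+        {
+            "workflow": {
+                "staging_directory": "collections",
+                "ic_filename_pattern": "*.collection",
+                "skip_ccds": ["002", "031", "061"],
+            },
+            "apps": {
+                "step1": {
+                    "butler_config_filepath": "/path/to/butler/repo",
+                    "search_config_filepath": "/path/to/search_config",
+                    "n_workers": 8,
+                },
+                "postscript": {
+                    "fake_object_catalog": "fakes_catalog.fits",
+                    "known_object_catalog": "known_objects_catalog.fits",
+                },
+            },
+        }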
+ + Running a KBMOD search is a 3 step process: + - step 1, executed on CPUs + - load ImageCollection + - filter unwanted rows of data from it + - load SearchConfiguration + - update search config values based on the IC metadata + - materialize a WorkUnit, requires the Rubin Data Butler + - resample a WorkUnit, targets the largest common footprint WCS + - writes the WorkUnit to file + - step 2, executed on GPUs + - loads the WorkUnit + - runs KBMOD search + - adds relevant metadata to the Results Table + - writes Results to file + - step 3, executed on CPUs + - loads Results file + - makes an analysis plot + + Parameters + ---------- + env : str, optional + Environment string used to define which resource configuration to use, + by default None + runtime_config : dict, optional + Dictionary of assorted runtime configuration parameters, by default {} + """ + resource_config = get_resource_config(env=env) + resource_config = apply_runtime_updates(resource_config, runtime_config) + workflow_config = runtime_config.get("workflow", {}) + app_configs = runtime_config.get("apps", {}) + + dfk = parsl.load(resource_config) + logger = get_configured_logger("workflow.workflow_runner") + + if dfk: + if runtime_config is not None: + logger.info(f"Using runtime configuration definition:\n{toml.dumps(runtime_config)}") + + logger.info("Starting workflow") + + directory_path = workflow_config.get("staging_directory", "collections") + file_pattern = workflow_config.get("ic_filename_pattern", "*.collection") + pattern = os.path.join(directory_path, file_pattern) + entries = glob.glob(pattern) + logger.info(f"Found {len(entries)} files in {directory_path}") + + skip_ccds = workflow_config.get("skip_ccds", ["002", "031", "061"]) + + # bookeping, used to build future output filenames + collfiles, collnames, resampled_wus = [], [], [] + for collection in entries: + if any([ccd in collection for ccd in skip_ccds]): + logger.warning(f"Skipping {collection} bad detector.") + continue + + # bookeeping for future tasks + collname = os.path.basename(collection).split(".")[0] + collnames.append(collname) + + # Register step 1 for each of the collection file + logger.info(f"Registering {collname} for step1 of {collection}") + logging_file = File(f"logs/{collname}.resample.log") + collection_file = File(collection) + collfiles.append(collection_file) + resampled_wus.append( + step1( + inputs=[collection_file], + outputs=[File(f"resampled_wus/{collname}.resampled.wu")], + runtime_config=app_configs.get("step1", {}), + logging_file=logging_file, + ) + ) + + # Register step 2 for each output of step 1 + results = [] + for resample, collname, collfile in zip(resampled_wus, collnames, collfiles): + logger.info(f"Registering {collname} for step2 of {collfile.filepath}") + logging_file = File(f"logs/{collname}.search.log") + results.append( + step2( + inputs=[resample, collfile], + outputs=[File(f"results/{collname}.results.ecsv"),], + runtime_config=app_configs.get("step2", {}), + logging_file=logging_file, + ) + ) + + # Register postscript for each output of step 2 + analysis = [] + for result, collname in zip(results, collnames): + logger.info(f"Registering {collname} for postscript") + logging_file = File(f"logs/{collname}.analysis.log") + plots_archive = File(f"plots/{collname}.plots.tar.bz2") + analysis.append( + postscript( + inputs=[result], + outputs=[plots_archive, ], + runtime_config=app_configs.get("postscript", {}), + logging_file=logging_file + ) + ) + + [f.result() for f in analysis] + 
dfk.wait_for_current_tasks() + logger.info("Workflow complete") + + # Create the Workflow Gantt chart + logs = parse_logdir("logs") + + success = [l for l in logs if l.success] + failed = [l for l in logs if not l.success] + print(f"N success: {len(success)}") + print(f"N fail: {len(fail)}") + + with open("failed_runs.list", "w") as f: + for l in fail: + f.write(l.name) + f.write("\n") + + with open("success_runs.list", "w") as f: + for l in success: + f.write(l.name) + f.write("\n") + + try: + import matplotlib.pyplot as plt + except ImportError: + logger.warning("Matplotlib not installed, skipping creating " + "workflow Gantt chart") + else: + fig, ax = plt.subplots(figsize=(15, 15)) + ax = plot_campaign( + ax, + logs, + relative_to_launch=True, + units="hour", + name_pos="right+column" + ) + plt.tight_layout() + plt.savefig("exec_gantt.png") + finally: + parsl.clear() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--env", + type=str, + choices=["dev", "klone"], + help="The environment to run the workflow in.", + ) + + parser.add_argument( + "--runtime-config", + type=str, + help="The complete runtime configuration filepath to use for the workflow.", + ) + + args = parser.parse_args() + + # if a runtime_config file was provided and exists, load the toml as a dict. + runtime_config = {} + if args.runtime_config is not None and os.path.exists(args.runtime_config): + with open(args.runtime_config, "r") as toml_runtime_config: + runtime_config = toml.load(toml_runtime_config) + + workflow_runner(env=args.env, runtime_config=runtime_config) diff --git a/src/kbmod_wf/task_impls/__init__.py b/src/kbmod_wf/task_impls/__init__.py index eef1ad26..ed66aefc 100644 --- a/src/kbmod_wf/task_impls/__init__.py +++ b/src/kbmod_wf/task_impls/__init__.py @@ -2,6 +2,10 @@ from .kbmod_search import kbmod_search from .uri_to_ic import uri_to_ic +from .deep_plots import * +from .uncertainty_propagation import * + + __all__ = [ "ic_to_wu", "kbmod_search", diff --git a/src/kbmod_wf/task_impls/deep_plots.py b/src/kbmod_wf/task_impls/deep_plots.py new file mode 100644 index 00000000..e763c7ae --- /dev/null +++ b/src/kbmod_wf/task_impls/deep_plots.py @@ -0,0 +1,418 @@ +import dataclasses +import json + +import matplotlib.pyplot as plt +from matplotlib import gridspec +from matplotlib.gridspec import GridSpec + +import numpy as np +from astropy.table import Table +from astropy.time import Time +from astropy.coordinates import SkyCoord +from astropy.wcs import WCS + + +__all__ = [ + "Figure", + "configure_plot", + "plot_result", + "result_to_skycoord", + "select_known_objects", + "plot_objects", + "plot_result" +] + + +KNOWN_OBJECTS_PLTSTYLE = { + "fakes": { + "tno": { + "color": "purple", + "label": "Fake TNO", + "linewidth": 1, + "markersize": 2, + "marker": "o", + "start_marker": "^", + "start_color": "green" + }, + "asteroid": { + "color": "red", + "label": "Fake Asteroid", + "linewidth": 1, + "markersize": 2, + "marker": "o", + "start_marker": "^", + "start_color": "green" + } + }, + "knowns": { + "KBO": { + "color": "darkorange", + "label": "Known KBO", + "linewidth": 1, + "markersize": 2, + "marker": "o", + "start_marker": "^", + "start_color": "green" + }, + "*": { + "color": "chocolate", + "label": "Known object", + "linewidth": 1, + "markersize": 2, + "marker": "o", + "start_marker": "^", + "start_color": "green" + } + } +} +"""Default plot style for known objects.""" + + +@dataclasses.dataclass +class Figure: + """Figure area containing Axes named 
``likelihood``, ``sky``, + ``stamps`` and ``normed_stamps`` and ``psiphi`` axis, twinned to + ``likelihood``. + + The class does not define a layout, nor data, for these axes, + just their content. + """ + fig: plt.Figure + stamps: list[plt.Axes] + normed_stamps: plt.Axes + likelihood: plt.Axes + psiphi: plt.Axes + sky: plt.Axes + + +def configure_plot( + wcs, + fig_kwargs=None, + gs_kwargs=None, + layout="tight" +): + """Configure a `Figure` and place `Axes` within that figure. + + The returned plot area is a 2x2 layout, with axes named + ``likelihood``, ``sky``, ``stamps`` and ``normed_stamps``, going + in clockwise direction. The top left axis has a twinned y axis + named ``psiphi``. The stamps are 1x4 Axes with no axis labels or + ticks. + + This function only provides this layout and it does not plot data. + + Parameters + ---------- + wcs : `WCS` + WCS class to added to ``sky`` axis. + fig_kwargs : `dict` or `None`, optional + Keyword arguments passed forwards to `plt.figure`. + gs_kwargs : `dict` or `None`, optional + Keyword arguments passed forwards to `GridSpec`. + layout: `str`, optional + Figure layout is by default ``tight``. + + Returns + ------- + Figure : `obj` + Dataclass containing all of the created Axes, see `Figure` + """ + fig_kwargs = {} if fig_kwargs is None else fig_kwargs + lk = fig_kwargs.pop("layout", None) + layout = "tight" if lk is None else lk + gs_kwargs = {} if gs_kwargs is None else gs_kwargs + + fig = plt.figure(layout=layout, **fig_kwargs) + + fig_gs = GridSpec(2, 2, figure=fig, **gs_kwargs) + stamp_gs = gridspec.GridSpecFromSubplotSpec(1, 4, hspace=0.01, wspace=0.01, subplot_spec=fig_gs[1, 0]) + stamp_gs2 = gridspec.GridSpecFromSubplotSpec(1, 4, hspace=0.01, wspace=0.01, subplot_spec=fig_gs[1, 1]) + + ax_left = fig.add_subplot(stamp_gs[:]) + ax_left.axis('off') + ax_left.set_title('Coadded cutouts') + + ax_right = fig.add_subplot(stamp_gs2[:]) + ax_right.axis('off') + ax_right.set_title('Coadded cutouts normalized to mean values.') + + stamps = np.array([fig.add_subplot(stamp_gs[i]) for i in range(4)]) + + for ax in stamps[1:]: + ax.sharey(stamps[0]) + plt.setp(ax.get_yticklabels(), visible=False) + + normed = np.array([fig.add_subplot(stamp_gs2[i]) for i in range(4)]) + for ax in normed[1:]: + ax.sharey(normed[0]) + plt.setp(ax.get_yticklabels(), visible=False) + + likelihood = fig.add_subplot(fig_gs[0, 0]) + psiphi = likelihood.twinx() + likelihood.set_ylabel("Likelihood") + psiphi.set_ylabel("Psi, Phi value") + likelihood.set_xlabel("i-th image in stack") + + sky = fig.add_subplot(fig_gs[0, 1], projection=wcs) + overlay = sky.get_coords_overlay('geocentricmeanecliptic') + overlay.grid(color='black', ls='dotted') + sky.coords[0].set_major_formatter('d.dd') + sky.coords[1].set_major_formatter('d.dd') + + return Figure(fig, stamps, normed, likelihood, psiphi, sky) + + +def result_to_skycoord(result, times, obs_valid, wcs): + """Return a collection of on-sky coordinates that match the result. + + Take a result entry and return its SkyCoord positions on the sky. + + Parameters + ---------- + results : `Row` + Result + times : `np.array` + Array of MJD timestamps as floats. + obs_valid : `list[bool]` + A list of which observations are valid. + wcs : `WCS` + WCS + + Returns + ------- + coords : `SkyCoord` + World coordinates. + pos_valid : `list[bool]` + List of valid observations. 
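+
+    Notes
+    -----
+    Positions are propagated linearly from the first epoch, i.e.
+    ``x + dt * vx`` and ``y + dt * vy`` with ``dt`` in days since
+    ``times[0]``, then converted to world coordinates with
+    ``wcs.pixel_to_world``; ``obs_valid`` is returned unchanged.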
+ """ + pos, pos_valid = [], [] + times = Time(times, format="mjd") + dt = (times - times[0]).value + + newx = result["x"]+dt*result["vx"] + newy = result["y"]+dt*result["vy"] + coord = wcs.pixel_to_world(newx, newy) + #pos.append(list(zip(coord.ra.deg, coord.dec.deg))) + #pos_valid.append(obs_valid) # NOTE TO SELF: FIX LATER + + return coord, obs_valid #SkyCoord(pos), pos_valid + + +def select_known_objects(fakes, known_objs, results): + """Select known objects and known inserted fake objects. + + Parameters + ---------- + fakes : `Table` + Table containing visit and detector columns to match on. + known_objs : `Table` + SkyBot results containing ephemeris of all known objects at + the same timestamps. + results: `Results` + Results + + Returns + ------- + fakes : `Table` + Filtered ingoing table of fakes. + knowns : `Table` + Filetered ingoing table of knwon objects. + """ + visitids = results.meta["visits"] + detector = results.meta["detector"] + obstimes = results.meta["mjd_mid"] + wcs = WCS(json.loads(results.meta["wcs"])) + + mask = fakes[1].data["CCDNUM"] == detector + visitmask = fakes[1].data["EXPNUM"][mask] == visitids[0] + for vid in visitids[1:]: + visitmask = np.logical_or( + visitmask, + fakes[1].data["EXPNUM"][mask] == vid + ) + fakes = Table(fakes[1].data[mask][visitmask]) + fakes = fakes.group_by("ORBITID") + + (blra, bldec), (tlra, tldec), (trra, trdec), (brra, brdec) = wcs.calc_footprint() + padding = 0.005 + mask = ( + (known_objs["RA"] > tlra-padding) & + (known_objs["RA"] < blra+padding) & + (known_objs["DEC"] > bldec-padding) & + (known_objs["DEC"] < trdec+padding) + ) + knowns = known_objs[mask].group_by("Name") + + return fakes, knowns + + +def plot_objects(ax, objs, type_key, plot_kwargs, sort_on="mjd_mid"): + """Plots objects onto the given WCSAxes. + + Objects are a table containing ``RA``, ``DEC``, `type_key` and + `sort_on` columns, grouped by the individual object. The object + positions are plotted as a scattered plot, with each object getting + a different visual formatting based on the object type value. + The `type_key` names a column that selects the visual formatting + via `plot_kwargs`. The plot keyword arguments must match the name + of the type of the object, f.e. "tno", "kbo", "asteroid" etc. or + the ``"*"`` literal to match any not-specified object type. + + The `plot_kwargs` is a dictionary with keys matching the desired + object type. Value of each key is a dictionary with appropriate + key-value pairs passed onto `ax.plot_coord`. Additionally, the + dictionary may contain keys ``start_marker`` and ``start_color`` + keys that will be used to differently visualize the first position + of the object on the plot, to visualize the direction of motion. + By default this is a green triangle. + + Parameters + ---------- + ax : `WCSAxes` + Axis + objs : `Table` + Catalog of object positions, grouped by individual object. + type_key : `str` + Name of the column that contains the object type, f.e. ``tno``, + ``kbo``, ``asteroid`` etc. + plot_kwargs : `dict` + Dictionary containing the names of the object types, or ``*``, + and their formatting parameters. Optionally ``start_marker`` + and ``start_color`` may be provided for each object class to + mark the first object position. + sort_on : `str`, optional + Name of the column on which to sort each object on. By default + ``mjd_mid``. + + Returns + ------- + ax : `WCSAxes` + Axis containing the artists. 
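+
+    Examples
+    --------
+    A sketch of how `plot_result` below invokes this, using the module
+    default styling (``fakes`` and ``knowns`` are grouped by object
+    beforehand)::
+
+        ax = plot_objects(ax, fakes, "type", KNOWN_OBJECTS_PLTSTYLE["fakes"])
+        ax = plot_objects(ax, knowns, "Type", KNOWN_OBJECTS_PLTSTYLE["knowns"])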
+    """
+    plt_kwargs = plot_kwargs.copy()
+    for group in objs.groups:
+        if sort_on is not None:
+            group.sort(sort_on)
+
+        kind = np.unique(group[type_key])
+        if len(kind) > 1:
+            raise ValueError(
+                "Object can only be classified into a single type. "
+                f"Got {kind} instead"
+            )
+
+        obj_type = group[type_key][0]
+        if obj_type not in plt_kwargs.keys():
+            if "*" in plt_kwargs.keys():
+                obj_type = "*"
+            else:
+                raise ValueError(
+                    f"Object type `{obj_type}` not found in the plot "
+                    f"arguments `{plt_kwargs.keys()}`"
+                )
+
+        # Copy the per-type style so popping the start-marker keys does
+        # not modify the caller's dictionary.
+        style = dict(plt_kwargs[obj_type])
+        sm = style.pop("start_marker", "^")
+        sms = style.pop("start_markersize", 1)  # accepted, currently unused
+        sc = style.pop("start_color", "green")
+        pos = SkyCoord(group["RA"], group["DEC"], unit="degree",
+                       frame="icrs")
+
+        ax.plot_coord(pos, **style)
+        ax.scatter_coord(pos[0], marker=sm, color=sc)
+
+    # Legend labels are taken from the per-type ``label`` kwargs, if given.
+    ax.legend()
+    return ax
+
+
+def plot_result(figure, res, fakes, knowns, wcs, obstimes):
+    """Plot a single result onto the `Figure`.
+
+    Four panels are drawn for each result: the per-image likelihood
+    together with the psi and phi curves, the WCS footprint with the
+    positions of the result and of the known objects within it, and two
+    rows of postage stamp cutouts centered on the result positions, one
+    scaled per stamp and one sharing the normalization range of the
+    mean coadd.
+
+    Parameters
+    ----------
+    figure : `Figure`
+        Dataclass containing all the axes of the plot.
+    res : `Row`
+        Result.
+    fakes : `Table`
+        Catalog of positions of all simulated objects. Must contain
+        ``RA``, ``DEC``, ``mjd_mid`` and ``type`` columns and be
+        grouped by individual object.
+    knowns : `Table`
+        Catalog of all known real objects. Must contain ``RA``, ``DEC``,
+        ``Type`` and ``mjd_mid`` columns and be grouped by individual
+        object.
+    wcs : `WCS`
+        WCS used for the footprint and to convert the result trajectory
+        to sky coordinates.
+    obstimes : `np.array`
+        MJD timestamps of the observations.
+
+    Returns
+    -------
+    figure : `Figure`
+        Figure containing all the axes and their artists.
+    """
+    # Top left plot
+    # - psi, phi and per-image likelihood values
+    figure.psiphi.plot(res["psi_curve"], alpha=0.25, marker="o", label="psi")
+    figure.psiphi.plot(res["phi_curve"], alpha=0.25, marker="o", label="phi")
+    figure.psiphi.legend(loc="upper right")
+
+    figure.likelihood.plot(res["psi_curve"]/res["phi_curve"], marker="o", label="L", color="red")
+    figure.likelihood.set_title(
+        f"Likelihood: {res['likelihood']:.5}, obs_count: {res['obs_count']}, \n "
+        f"(x, y): ({res['x']}, {res['y']}), (vx, vy): ({res['vx']:.6}, {res['vy']:.6})"
+    )
+    figure.likelihood.legend(loc="upper left")
+
+    # Top right
+    # - footprint of the CCD
+    # - known fake objects
+    # - known real objects
+    # - trajectory of the result
+    # Order is important because of the z-level of the plotted artists
+    (blra, bldec), (tlra, tldec), (trra, trdec), (brra, brdec) = wcs.calc_footprint()
+    figure.sky.plot(
+        [blra, tlra, trra, brra, blra],
+        [bldec, tldec, trdec, brdec, bldec],
+        transform=figure.sky.get_transform("world"),
+        color="black", label="Footprint"
+    )
+
+    if len(fakes) > 0:
+        figure.sky = plot_objects(figure.sky, fakes, "type", KNOWN_OBJECTS_PLTSTYLE["fakes"])
+
+    if len(knowns) > 0:
+        figure.sky = plot_objects(figure.sky, knowns, "Type", KNOWN_OBJECTS_PLTSTYLE["knowns"])
+
+    pos, pos_valid = result_to_skycoord(res, obstimes, res["obs_valid"], wcs)
+    figure.sky.plot_coord(pos, marker="o", markersize=1, linewidth=1, label="Search trj.", color="C0")
+    if sum(pos_valid) > 0:
+        figure.sky.scatter_coord(pos[pos_valid], marker="+", alpha=0.25, label="Obs. valid", color="C0")
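+    # Mark the first predicted position separately so the direction of
+    # motion along the searched trajectory is visible at a glance.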
+    figure.sky.scatter_coord(pos[0], marker="^", color="green", label="Starting point")
+
+    # De-duplicate the legend entries; plot_objects may repeat labels.
+    bb = {name: handle for handle, name in zip(*figure.sky.get_legend_handles_labels())}
+    handles = list(bb.values())
+    names = list(bb.keys())
+    figure.sky.legend(handles, names, loc="upper left", ncols=7)
+
+    # Bottom left
+    # - individually scaled coadd stamps
+    stamp_types = ("coadd_mean", "coadd_median",
+                   "coadd_weighted", "coadd_sum")
+    for ax, kind in zip(figure.stamps.ravel(), stamp_types):
+        ax.imshow(res[kind], interpolation="none")
+        ax.set_title(kind)
+
+    # Bottom right
+    # - postage stamps scaled to the range of the mean stamp
+    ntype = stamp_types[0]
+    for ax, kind in zip(figure.normed_stamps.ravel(), stamp_types):
+        ax.imshow(res[kind], vmin=res[ntype].min(),
+                  vmax=res[ntype].max(), interpolation="none")
+        ax.set_title(kind)
+
+    return figure
diff --git a/src/kbmod_wf/task_impls/uncertainty_propagation.py b/src/kbmod_wf/task_impls/uncertainty_propagation.py
new file mode 100644
index 00000000..22b30729
--- /dev/null
+++ b/src/kbmod_wf/task_impls/uncertainty_propagation.py
@@ -0,0 +1,283 @@
+import numpy as np
+import astropy.units as u
+
+
+__all__ = [
+    "calc_means_covariance",
+    "kbmod2pix",
+    "pix2sky",
+    "jac_deproject_rad",
+    "calc_wcs_jacobian",
+    "calc_skypos_uncerts"
+]
+
+
+# The state vector and covariance are ordered (x, y, vx, vy).
+def calc_means_covariance(likelihood, x, y, vx, vy):
+    """Compute the likelihood-weighted means and covariance of the
+    (x, y, vx, vy) trajectory parameters.
+
+    The weights are ``exp(likelihood)``, so the returned values are the
+    weighted first and second central moments of the given samples.
+    """
+    lexp = np.exp(likelihood)
+    lexp_sum = lexp.sum()
+
+    x_hat = (x * lexp).sum() / lexp_sum
+    y_hat = (y * lexp).sum() / lexp_sum
+    vx_hat = (vx * lexp).sum() / lexp_sum
+    vy_hat = (vy * lexp).sum() / lexp_sum
+
+    # diagonals
+    xx_hat = (x**2 * lexp).sum() / lexp_sum - x_hat**2
+    yy_hat = (y**2 * lexp).sum() / lexp_sum - y_hat**2
+    vxvx_hat = (vx**2 * lexp).sum() / lexp_sum - vx_hat**2
+    vyvy_hat = (vy**2 * lexp).sum() / lexp_sum - vy_hat**2
+
+    # mixed elements
+    xy_hat = (x*y * lexp).sum() / lexp_sum - x_hat*y_hat
+    xvx_hat = (x*vx * lexp).sum() / lexp_sum - x_hat*vx_hat
+    xvy_hat = (x*vy * lexp).sum() / lexp_sum - x_hat*vy_hat
+
+    yvx_hat = (y*vx * lexp).sum() / lexp_sum - y_hat*vx_hat
+    yvy_hat = (y*vy * lexp).sum() / lexp_sum - y_hat*vy_hat
+
+    vxvy_hat = (vx*vy * lexp).sum() / lexp_sum - vx_hat*vy_hat
+
+    cov = np.array([
+        [ xx_hat,  xy_hat,  xvx_hat,  xvy_hat ],
+        [ xy_hat,  yy_hat,  yvx_hat,  yvy_hat ],
+        [ xvx_hat, yvx_hat, vxvx_hat, vxvy_hat ],
+        [ xvy_hat, yvy_hat, vxvy_hat, vyvy_hat ]
+    ])
+
+    return (x_hat, y_hat, vx_hat, vy_hat), cov
+
+
+# Times must be given as MJD; the velocities are implicitly in pixels
+# per day via the search configuration.
+def kbmod2pix(eigenv, cov, t1, t2, t0="start"):
+    """Propagate the mean state and covariance to pixel positions at
+    times ``t1`` and ``t2`` using the linear-motion Jacobian.
+
+    ``eigenv`` is the mean (x, y, vx, vy) state vector as returned by
+    `calc_means_covariance`.
+    """
+    t0 = t1 if t0 == "start" else t0
+
+    dt1 = t1 - t0
+    xinit = eigenv[0] + eigenv[2]*dt1
+    yinit = eigenv[1] + eigenv[3]*dt1
+
+    dt2 = t2 - t0
+    xend = eigenv[0] + eigenv[2]*dt2
+    yend = eigenv[1] + eigenv[3]*dt2
+
+    jac = np.array([
+        [1, 0, dt1, 0],
+        [0, 1, 0, dt1],
+        [1, 0, dt2, 0],
+        [0, 1, 0, dt2]
+    ])
+    uncert = jac @ cov @ jac.T
+
+    return (xinit, yinit), (xend, yend), uncert
+
+
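+# A minimal usage sketch (placeholder variable names, not part of the
+# pipeline): given per-trajectory arrays from a clustered search result,
+# the mean state and its covariance are propagated to pixel positions at
+# the first and last exposure times in two steps:
+#
+#     means, cov = calc_means_covariance(lh, x, y, vx, vy)
+#     p1, p2, pix_cov = kbmod2pix(means, cov, mjd_start, mjd_end)
+#
+# ``lh``, ``x``, ``y``, ``vx``, ``vy``, ``mjd_start`` and ``mjd_end`` are
+# assumed inputs; `calc_skypos_uncerts` at the end of this module performs
+# exactly these two calls before mapping the result to sky coordinates
+# with `pix2sky`.
+
+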
+def jac_deproject_rad(center_coord, u, v, projection):
+    """Jacobian of the deprojection from tangent-plane coordinates
+    (u, v), in radians, to (ra, dec) around ``center_coord``.
+    """
+    # sin(dec) = cos(c) sin(dec0) + v sin(c)/r cos(dec0)
+    # tan(ra-ra0) = u sin(c)/r / (cos(dec0) cos(c) - v sin(dec0) sin(c)/r)
+    #
+    # d(sin(dec)) = cos(dec) ddec = s0 dc + (v ds + s dv) c0
+    # dtan(ra-ra0) = sec^2(ra-ra0) dra
+    #             = ( (u ds + s du) A - u s (dc c0 - (v ds + s dv) s0 ) )/A^2
+    # where s = sin(c) / r
+    #       c = cos(c)
+    #       s0 = sin(dec0)
+    #       c0 = cos(dec0)
+    #       A = c c0 - v s s0
+
+    rsq = u*u + v*v
+    rsq1 = (u+1.e-4)**2 + v**2
+    rsq2 = u**2 + (v+1.e-4)**2
+    if projection is None or projection[0] == 'g':
+        c = s = 1./np.sqrt(1.+rsq)
+        s3 = s*s*s
+        dcdu = dsdu = -u*s3
+        dcdv = dsdv = -v*s3
+    elif projection[0] == 's':
+        s = 4. / (4.+rsq)
+        c = 2.*s-1.
+        ssq = s*s
+        dcdu = -u * ssq
+        dcdv = -v * ssq
+        dsdu = 0.5*dcdu
+        dsdv = 0.5*dcdv
+    elif projection[0] == 'l':
+        c = 1. - rsq/2.
+        s = np.sqrt(4.-rsq) / 2.
+        dcdu = -u
+        dcdv = -v
+        dsdu = -u/(4.*s)
+        dsdv = -v/(4.*s)
+    else:
+        r = np.sqrt(rsq)
+        if r == 0.:
+            c = s = 1
+            dcdu = -u
+            dcdv = -v
+            dsdu = dsdv = 0
+        else:
+            c = np.cos(r)
+            s = np.sin(r)/r
+            dcdu = -s*u
+            dcdv = -s*v
+            dsdu = (c-s)*u/rsq
+            dsdv = (c-s)*v/rsq
+
+    # Deproject (u, v) back to celestial coordinates around the center
+    # coordinate.
+    ra, dec = center_coord
+
+    _sinra, _cosra = np.sin(ra), np.cos(ra)
+    _sindec, _cosdec = np.sin(dec), np.cos(dec)
+
+    _x = _cosdec * _cosra
+    _y = _cosdec * _sinra
+    _z = _sindec
+
+    s0 = _sindec
+    c0 = _cosdec
+    sindec = c * s0 + v * s * c0
+    cosdec = np.sqrt(1.-sindec*sindec)
+    dddu = ( s0 * dcdu + v * dsdu * c0 ) / cosdec
+    dddv = ( s0 * dcdv + (v * dsdv + s) * c0 ) / cosdec
+
+    tandra_num = u * s
+    tandra_denom = c * c0 - v * s * s0
+    # Note: A^2 sec^2(dra) = denom^2 (1 + tan^2(dra)) = denom^2 + num^2
+    A2sec2dra = tandra_denom**2 + tandra_num**2
+    drdu = ((u * dsdu + s) * tandra_denom - u * s * ( dcdu * c0 - v * dsdu * s0 ))/A2sec2dra
+    drdv = (u * dsdv * tandra_denom - u * s * ( dcdv * c0 - (v * dsdv + s) * s0 ))/A2sec2dra
+
+    drdu *= cosdec
+    drdv *= cosdec
+
+    return np.array([[drdu, drdv], [dddu, dddv]])
+
+
+def calc_wcs_jacobian(wcs, x0, y0):
+    """Compute the local Jacobian of the pixel-to-sky transform at pixel
+    position (x0, y0), in arcseconds per pixel, including the SIP
+    distortion terms when they are present.
+    """
+    if wcs.wcs.has_cd():
+        cd = wcs.wcs.cd
+    elif wcs.wcs.has_pc():
+        cdelt1, cdelt2 = wcs.wcs.cdelt
+        cd11 = wcs.wcs.pc[0, 0]*cdelt1
+        cd12 = wcs.wcs.pc[0, 1]*cdelt1
+        cd21 = wcs.wcs.pc[1, 0]*cdelt2
+        cd22 = wcs.wcs.pc[1, 1]*cdelt2
+        cd = np.array([[cd11, cd12], [cd21, cd22]])
+    else:
+        raise AttributeError("WCS defines neither a CD nor a PC matrix.")
+
+    ctype = wcs.wcs.ctype[0]
+    ctype = ctype.replace("RA", "")
+    ctype = ctype.replace("---", "")
+    if ctype in ('TAN', 'TPV', 'TNX', 'TAN-SIP'):
+        projection = 'gnomonic'
+    elif ctype in ('STG', 'STG-SIP'):
+        projection = 'stereographic'
+    elif ctype in ('ZEA', 'ZEA-SIP'):
+        projection = 'lambert'
+    elif ctype in ('ARC', 'ARC-SIP'):
+        projection = 'postel'
+    else:
+        raise AttributeError("unsupported projection")
+
+    pixc = np.array([x0, y0])
+    p1 = pixc - wcs.wcs.crpix
+
+    jac = np.diag([1, 1])
+
+    if wcs.sip is not None:
+        a_order = wcs.sip.a_order
+        b_order = wcs.sip.b_order
+        # Use the same order for both
+        order = max(a_order, b_order)
+        # Copy the SIP coefficient matrices so the input WCS is not
+        # modified in place below.
+        a, b = wcs.sip.a.copy(), wcs.sip.b.copy()
+
+        # The calculation in the SIP standard
+        # (https://fits.gsfc.nasa.gov/registry/sip/SIP_distortion_v1_0.pdf)
+        # is differential relative to CRVAL, but it is easier to make the
+        # first-order elements identities instead and not worry about the
+        # translation later.
+        a[1,0] += 1
+        b[0,1] += 1
+        ab = np.array([a, b])
+
+        x = p1[0]
+        y = p1[1]
+        xpow = x ** np.arange(order+1)
+        ypow = y ** np.arange(order+1)
+        p1 = np.dot(np.dot(ab, ypow), xpow)
+
+        dxpow = np.zeros(order+1)
+        dypow = np.zeros(order+1)
+        dxpow[1:] = (np.arange(order)+1.) * xpow[:-1]
+        dypow[1:] = (np.arange(order)+1.)
* ypow[:-1] + j1 = np.transpose([ np.dot(np.dot(ab, ypow), dxpow) , + np.dot(np.dot(ab, dypow), xpow) ]) + jac = np.dot(j1, jac) + + # With no distorsion the jacobian is just the + # affine part of the WCS transform, evaluated at center + # then shifted to the given new center coordinate and + # scaled by units + p2 = np.dot(cd, p1) + jac = np.dot(cd, jac) + + unit_convert = [ -u.degree.to(u.radian), u.degree.to(u.radian) ] + p2 *= unit_convert + + jac = jac * np.transpose( [ unit_convert ] ) + + # Convert from (u,v) to (ra, dec) + center_coord = wcs.wcs.crval * u.degree.to(u.rad) + j2 = jac_deproject_rad(center_coord, p2[0], p2[1], projection=projection) + + # rad/pix --> arcsec/pixel. + jac = np.dot(j2, jac) + jac *= u.radian.to(u.arcsec) + + return jac + + + +def pix2sky(p1, p2, uncertainties, wcs): + # The mean values are just directly convertable + s1 = wcs.pixel_to_world(*p1) + s2 = wcs.pixel_to_world(*p2) + + # the uncertainties need to be transformed + # Fit an affine transform to a small neighborhood centered + # on the tracklet. Purposfully overestimate the box size + midx = (p1[0] + p2[0])/2. + midy = (p1[1] + p2[1])/2. + jac = calc_wcs_jacobian(wcs, midx, midy) + + J = np.array([ + [jac[0, 0], jac[0, 1], 0, 0], + [jac[0, 1], jac[1, 1], 0, 0], + [ 0, 0, jac[0, 0], jac[0, 1]], + [ 0, 0, jac[0, 1], jac[1, 1]] + ]) + uncert = J @ uncertainties @ J.T + + return s1, s2, uncert + + +def calc_skypos_uncerts(trajectories, mjd_start, mjd_end, wcs): + eigenv, cov = calc_means_covariance( + trajectories["likelihood"], + trajectories["x"], + trajectories["y"], + trajectories["vx"], + trajectories["vy"] + ) + p1, p2, uncert = kbmod2pix(eigenv, cov, mjd_start, mjd_end) + s1, s2, uncert = pix2sky(p1, p2, uncert, wcs) + return s1, s2, uncert diff --git a/src/kbmod_wf/utilities/logger_utilities.py b/src/kbmod_wf/utilities/logger_utilities.py index 85780969..a041f40d 100644 --- a/src/kbmod_wf/utilities/logger_utilities.py +++ b/src/kbmod_wf/utilities/logger_utilities.py @@ -1,8 +1,27 @@ +import time import traceback import logging from logging import config -__all__ = ["LOGGING_CONFIG", "get_configured_logger", "ErrorLogger"] +import os +import re +import glob + +import numpy as np +import matplotlib.pyplot as plt +from astropy.table import Table, vstack +from astropy.time import Time, TimeDelta + + +__all__ = [ + "LOGGING_CONFIG", + "get_configured_logger", + "ErrorLogger", + "Log", + "parse_logfile", + "parse_logdir", + "plot_campaign" +] LOGGING_CONFIG = { @@ -17,31 +36,28 @@ }, "handlers": { "stdout": { - "level": "INFO", + "level": "DEBUG", "formatter": "standard", "class": "logging.StreamHandler", "stream": "ext://sys.stdout", }, "stderr": { - "level": "INFO", + "level": "DEBUG", "formatter": "standard", "class": "logging.StreamHandler", "stream": "ext://sys.stderr", }, "file": { - "level": "INFO", + "level": "DEBUG", "formatter": "standard", "class": "logging.FileHandler", - "filename": "parsl.log", + "filename": "kbmod.log", }, }, "loggers": { - "task": {"level": "INFO", "handlers": ["file", "stdout"], "propagate": False}, - "task.create_manifest": {}, - "task.ic_to_wu": {}, - "task.reproject_wu": {}, - "task.kbmod_search": {}, - "kbmod": {"level": "INFO", "handlers": ["file", "stdout"], "propagate": False}, + "parsl": {"level": "INFO", "handlers": ["file", "stdout"], "propagate": False}, + "workflow": {"level": "INFO", "handlers": ["file", "stdout"], "propagate": False}, + "kbmod": {"level": "DEBUG", "handlers": ["file", "stdout"], "propagate": False}, }, } """Default logging 
configuration for Parsl.""" @@ -61,13 +77,14 @@ def get_configured_logger(logger_name, file_path=None): if file_path is not None: logconf["handlers"]["file"]["filename"] = file_path config.dictConfig(logconf) + logging.Formatter.converter = time.gmtime logger = logging.getLogger() - return logging.getLogger(logger_name) class ErrorLogger: - """Logs received errors before re-raising them. + """Context manager that logs received errors before re-raising + them. Parameters ---------- @@ -91,3 +108,676 @@ def __exit__(self, exc, value, tb): msg = "".join(msg) self.logger.error(msg) return self.silence_errors + + +class Log: + """A log of a task. + + Note that each Log can consist of multiple log files. For example + the complete log of a Task ``20210101_A`` can consist of multiple + steps called ``20210101_A_step1``, ``20210101_A_step2`` etc. thus + making the full Log of the event a union of every step. + + Individual steps are parsed, extracting individual log entries and + context; assigning to each log step a success, error and completion + status. + A log is succesfull if all individual steps were succesfull. + A step is not completed if the log does not contain the expected + final log entry. + A step is marked with an error if the log produces a traceback or + an error report. A not completed log step does not indicate a error. + Not-completed, but not-errored, steps may be safely re-run. + A log is a failure if all steps were not completed or any of the + steps has an error. A log may be safely re-run if no steps have an + error, but not all steps have completed. + + Each step may contain repeated sequences of messages. When + running a task with Parsl, and it is pre-empted or it fails, + it can be resubmitted - leading to repeated collections of + messages. + + To overwrite the default contextualization of the parsed logs + override the `_parse_single` method. + + Parameters + ---------- + logfiles : `list` + List of log files belonging to the same Log event. + name : `str` or `None`, optional + Name of the log event, f.e. ``20210101_A``. When not provided + it will be determined as the longest shared common prefix of + each step. + stepnames : `list` or `None`, optional + Name of the each step, f.e. ``20210101_A_step1`` etc. When not + provided, it will be determined from the log itself. + + Attributes + ---------- + logfiles : `list` + Paths to log files. + name : `str` + Log group name + stepnames : `list[str]` + Names of steps involved in the log. + nsteps : `int` + Number of steps in the log. + started : `list` + List of step start times. + completed : `list` + List of step end times. + success : `list[bool]` + List of whether or not a step was determined to be succesfull. + error : `bool` + `True` when log contains a step with an error. + data : `Table` + Table of all log entries. Table is grouped by stepname and + step index. + """ + + fmt: str = r"\[(?P[\w|-]*) (?P[\w|-]*) (?P