diff --git a/notebooks/KBMOD_Demo.ipynb b/notebooks/KBMOD_Demo.ipynb index dcfa68cbd..91b0eb04e 100644 --- a/notebooks/KBMOD_Demo.ipynb +++ b/notebooks/KBMOD_Demo.ipynb @@ -1,7 +1,6 @@ { "cells": [ { - "attachments": {}, "cell_type": "markdown", "metadata": {}, "source": [ @@ -185,8 +184,8 @@ " \"average_angle\": 0.0,\n", "}\n", "\n", - "rs = run_search(input_parameters)\n", - "rs.run_search()" + "rs = SearchRunner()\n", + "rs.run_search(input_parameters)" ] }, { diff --git a/src/kbmod/configuration.py b/src/kbmod/configuration.py index 07ca2ae2d..0b2871bc4 100644 --- a/src/kbmod/configuration.py +++ b/src/kbmod/configuration.py @@ -123,6 +123,24 @@ def set(self, param, value, strict=True): else: self._params[param] = value + def set_multiple(self, overrides, strict=True): + """Sets multiple parameters from a dictionary. + + Parameters + ---------- + overrides : `dict` + A dictionary of parameter->value to overwrite. + strict : `bool` + Raise an exception on unknown parameters. + + Raises + ------ + Raises a ``KeyError`` if any parameter is not part on the list of known parameters + and ``strict`` is False. + """ + for key, value in overrides.items(): + self.set(key, value) + def validate(self): """Check that the configuration has the necessary parameters. diff --git a/src/kbmod/run_search.py b/src/kbmod/run_search.py index d9476f40a..8870ccfdb 100644 --- a/src/kbmod/run_search.py +++ b/src/kbmod/run_search.py @@ -24,182 +24,159 @@ apply_mask_operations, ) from .result_list import * +from .work_unit import WorkUnit -class run_search: - """ - Run the KBMOD grid search. - - Parameters - ---------- - input_parameters : ``dict`` - Additional parameters. Merged with (and checked against) the loaded input file and - the defaults provided in the SearchConfiguration class. - config_file : ``str`` (optional) - The name and path of the configuration file. - - Attributes - ---------- - config : ``SearchConfiguration`` - Search parameters. - """ - - def __init__(self, input_parameters, config_file=None): - # Load parameters from a file. - if config_file != None: - self.config = SearchConfiguration.from_file(config_file) - else: - self.config = SearchConfiguration() - - # Load any additional parameters (overwriting what is there). - if len(input_parameters) > 0: - for key, value in input_parameters.items(): - self.config.set(key, value) +class SearchRunner: + """A class to run the KBMOD grid search.""" - # Validate the configuration. - self.config.validate() + def __init__(self): + pass - def do_masking(self, stack): + def do_masking(self, config, stack): """Perform the masking based on the search's configuration parameters. Parameters ---------- - stack : `kbmod.ImageStack` - The stack before the masks have been applied. + config : `SearchConfiguration` + The configuration parameters + stack : `ImageStack` + The stack before the masks have been applied. Modified in-place. + + Returns + ------- + stack : `ImageStack` + The stack after the masks have been applied. """ mask_steps = [] # Prioritize the mask_bit_vector over the dictionary based version. - if self.config["mask_bit_vector"]: - mask_steps.append(BitVectorMasker(self.config["mask_bit_vector"], [0])) - elif self.config["flag_keys"] and len(self.config["flag_keys"]) > 0: - mask_steps.append(DictionaryMasker(self.config["mask_bits_dict"], self.config["flag_keys"])) + if config["mask_bit_vector"]: + mask_steps.append(BitVectorMasker(config["mask_bit_vector"], [0])) + elif config["flag_keys"] and len(config["flag_keys"]) > 0: + mask_steps.append(DictionaryMasker(config["mask_bits_dict"], config["flag_keys"])) # Add the threshold mask if it is set. - if self.config["mask_threshold"]: - mask_steps.append(ThresholdMask(self.config["mask_threshold"])) + if config["mask_threshold"]: + mask_steps.append(ThresholdMask(config["mask_threshold"])) # Add the global masking if it is set. - if self.config["repeated_flag_keys"] and len(self.config["repeated_flag_keys"]) > 0: + if config["repeated_flag_keys"] and len(config["repeated_flag_keys"]) > 0: mask_steps.append( GlobalDictionaryMasker( - self.config["mask_bits_dict"], - self.config["repeated_flag_keys"], - self.config["mask_num_images"], + config["mask_bits_dict"], + config["repeated_flag_keys"], + config["mask_num_images"], ) ) # Grow the mask. - if self.config["mask_grow"] and self.config["mask_grow"] > 0: - mask_steps.append(GrowMask(self.config["mask_grow"])) + if config["mask_grow"] and config["mask_grow"] > 0: + mask_steps.append(GrowMask(config["mask_grow"])) # Apply the masks. stack = apply_mask_operations(stack, mask_steps) return stack - def do_gpu_search(self, search): + def get_angle_limits(self, config): + """Compute the angle limits based on the configuration information. + + Parameters + ---------- + config : `SearchConfiguration` + The configuration parameters + + Returns + ------- + res : `list` + A list with the minimum and maximum angle to search (in pixel space). + """ + ang_min = config["average_angle"] - config["ang_arr"][0] + ang_max = config["average_angle"] + config["ang_arr"][1] + return [ang_min, ang_max] + + def do_gpu_search(self, config, search): """ Performs search on the GPU. Parameters ---------- - search : ``~kbmod.search.Search`` - Search object. + config : `SearchConfiguration` + The configuration parameters + search : `StackSearch` + The C++ object that holds data and does searching. + + Returns + ------- + search : `StackSearch` + The C++ object holding the data and results. """ width = search.get_image_width() height = search.get_image_height() - search_params = {} - - # Run the grid search - # Set min and max values for angle and velocity - ang_min = self.config["average_angle"] - self.config["ang_arr"][0] - ang_max = self.config["average_angle"] + self.config["ang_arr"][1] - vel_min = self.config["v_arr"][0] - vel_max = self.config["v_arr"][1] - search_params["ang_lims"] = [ang_min, ang_max] - search_params["vel_lims"] = [vel_min, vel_max] + ang_lim = self.get_angle_limits(config) # Set the search bounds. - if self.config["x_pixel_bounds"] and len(self.config["x_pixel_bounds"]) == 2: - search.set_start_bounds_x(self.config["x_pixel_bounds"][0], self.config["x_pixel_bounds"][1]) - elif self.config["x_pixel_buffer"] and self.config["x_pixel_buffer"] > 0: - search.set_start_bounds_x(-self.config["x_pixel_buffer"], width + self.config["x_pixel_buffer"]) + if config["x_pixel_bounds"] and len(config["x_pixel_bounds"]) == 2: + search.set_start_bounds_x(config["x_pixel_bounds"][0], config["x_pixel_bounds"][1]) + elif config["x_pixel_buffer"] and config["x_pixel_buffer"] > 0: + search.set_start_bounds_x(-config["x_pixel_buffer"], width + config["x_pixel_buffer"]) - if self.config["y_pixel_bounds"] and len(self.config["y_pixel_bounds"]) == 2: - search.set_start_bounds_y(self.config["y_pixel_bounds"][0], self.config["y_pixel_bounds"][1]) - elif self.config["y_pixel_buffer"] and self.config["y_pixel_buffer"] > 0: - search.set_start_bounds_y(-self.config["y_pixel_buffer"], height + self.config["y_pixel_buffer"]) + if config["y_pixel_bounds"] and len(config["y_pixel_bounds"]) == 2: + search.set_start_bounds_y(config["y_pixel_bounds"][0], config["y_pixel_bounds"][1]) + elif config["y_pixel_buffer"] and config["y_pixel_buffer"] > 0: + search.set_start_bounds_y(-config["y_pixel_buffer"], height + config["y_pixel_buffer"]) search_start = time.time() print("Starting Search") print("---------------------------------------") - print(f"Average Angle = {self.config['average_angle']}") - print(f"Search Angle Limits = {search_params['ang_lims']}") - print(f"Velocity Limits = {search_params['vel_lims']}") + print(f"Average Angle = {config['average_angle']}") + print(f"Search Angle Limits = {ang_lim}") + print(f"Velocity Limits = {config['v_arr']}") # If we are using gpu_filtering, enable it and set the parameters. - if self.config["gpu_filter"]: + if config["gpu_filter"]: print("Using in-line GPU sigmaG filtering methods", flush=True) - coeff = find_sigmaG_coeff(self.config["sigmaG_lims"]) + coeff = find_sigmaG_coeff(config["sigmaG_lims"]) search.enable_gpu_sigmag_filter( - np.array(self.config["sigmaG_lims"]) / 100.0, + np.array(config["sigmaG_lims"]) / 100.0, coeff, - self.config["lh_level"], + config["lh_level"], ) # If we are using an encoded image representation on GPU, enable it and # set the parameters. - if self.config["encode_psi_bytes"] > 0 or self.config["encode_phi_bytes"] > 0: - search.enable_gpu_encoding(self.config["encode_psi_bytes"], self.config["encode_phi_bytes"]) + if config["encode_psi_bytes"] > 0 or config["encode_phi_bytes"] > 0: + search.enable_gpu_encoding(config["encode_psi_bytes"], config["encode_phi_bytes"]) # Enable debugging. - if self.config["debug"]: - search.set_debug(self.config["debug"]) + if config["debug"]: + search.set_debug(config["debug"]) search.search( - int(self.config["ang_arr"][2]), - int(self.config["v_arr"][2]), - *search_params["ang_lims"], - *search_params["vel_lims"], - int(self.config["num_obs"]), + int(config["ang_arr"][2]), + int(config["v_arr"][2]), + ang_lim[0], + ang_lim[1], + config["v_arr"][0], + config["v_arr"][1], + int(config["num_obs"]), ) + print("Search finished in {0:.3f}s".format(time.time() - search_start), flush=True) - return (search, search_params) + return search - def run_search(self): + def run_search(self, config, stack): """This function serves as the highest-level python interface for starting - a KBMOD search. - - The `config` attribute requires the following key value pairs. + a KBMOD search given an ImageStack and SearchConfiguration. Parameters ---------- - self.config.im_filepath : string - Path to the folder containing the images to be ingested into - KBMOD and searched over. - self.config.res_filepath : string - Path to the folder that will contain the results from the search. - If ``None`` the program skips outputting the files. - self.config.out_suffix : string - Suffix to append to the output files. Used to differentiate - between different searches over the same stack of images. - self.config.time_file : string - Path to the file containing the image times (or None to use - values from the FITS files). - self.config.psf_file : string - Path to the file containing the image PSFs (or None to use default). - self.config.lh_level : float - Minimum acceptable likelihood level for a trajectory. - Trajectories with likelihoods below this value will be discarded. - self.config.psf_val : float - The value of the variance of the default PSF to use. - self.config.mjd_lims : numpy array - Limits the search to images taken within the limits input by - mjd_lims (or None for no filtering). - self.config.average_angle : float - Overrides the ecliptic angle calculation and instead centers - the average search around average_angle. + config : `SearchConfiguration` + The configuration parameters + stack : `ImageStack` + The stack before the masks have been applied. Modified in-place. + Returns ------- @@ -208,79 +185,148 @@ def run_search(self): """ start = time.time() - # Load images to search - stack, wcs_list, mjds = load_input_from_config(self.config, verbose=self.config["debug"]) - - # Compute the suggested search angle from the images. This is a 12 arcsecond - # segment parallel to the ecliptic is seen under from the image origin. - if self.config["average_angle"] == None: - center_pixel = (stack.get_width() / 2, stack.get_height() / 2) - self.config.set("average_angle", self._calc_suggested_angle(wcs_list[0], center_pixel)) + # Collect the MJDs. + mjds = [] + for i in range(stack.img_count()): + mjds.append(stack.get_obstime(i)) # Set up the post processing data structure. - kb_post_process = PostProcess(self.config, mjds) + kb_post_process = PostProcess(config, mjds) # Apply the mask to the images. - if self.config["do_mask"]: - stack = self.do_masking(stack) + if config["do_mask"]: + stack = self.do_masking(config, stack) # Perform the actual search. search = kb.StackSearch(stack) - search, search_params = self.do_gpu_search(search) + search = self.do_gpu_search(config, search) # Load the KBMOD results into Python and apply a filter based on - # 'filter_type. - mjds = np.array(mjds) + # 'filter_type'. keep = kb_post_process.load_and_filter_results( search, - self.config["lh_level"], - chunk_size=self.config["chunk_size"], - max_lh=self.config["max_lh"], + config["lh_level"], + chunk_size=config["chunk_size"], + max_lh=config["max_lh"], ) - if self.config["do_stamp_filter"]: + if config["do_stamp_filter"]: kb_post_process.apply_stamp_filter( keep, search, - center_thresh=self.config["center_thresh"], - peak_offset=self.config["peak_offset"], - mom_lims=self.config["mom_lims"], - stamp_type=self.config["stamp_type"], - stamp_radius=self.config["stamp_radius"], + center_thresh=config["center_thresh"], + peak_offset=config["peak_offset"], + mom_lims=config["mom_lims"], + stamp_type=config["stamp_type"], + stamp_radius=config["stamp_radius"], ) - if self.config["do_clustering"]: + if config["do_clustering"]: cluster_params = {} cluster_params["x_size"] = stack.get_width() cluster_params["y_size"] = stack.get_height() - cluster_params["vel_lims"] = search_params["vel_lims"] - cluster_params["ang_lims"] = search_params["ang_lims"] - cluster_params["mjd"] = mjds + cluster_params["vel_lims"] = config["v_arr"] + cluster_params["ang_lims"] = self.get_angle_limits(config) + cluster_params["mjd"] = np.array(mjds) kb_post_process.apply_clustering(keep, cluster_params) # Extract all the stamps. - kb_post_process.get_all_stamps(keep, search, self.config["stamp_radius"]) + kb_post_process.get_all_stamps(keep, search, config["stamp_radius"]) + # TODO - Re-enable the known object counting once we have a way to pass + # A WCS into the WorkUnit. # Count how many known objects we found. - if self.config["known_obj_thresh"]: - self._count_known_matches(keep, search) - - del search + # if config["known_obj_thresh"]: + # _count_known_matches(keep, search) # Save the results and the configuration information used. print(f"Found {keep.num_results()} potential trajectories.") - if self.config["res_filepath"] is not None: - keep.save_to_files(self.config["res_filepath"], self.config["output_suffix"]) + if config["res_filepath"] is not None: + keep.save_to_files(config["res_filepath"], config["output_suffix"]) - config_filename = os.path.join( - self.config["res_filepath"], f"config_{self.config['output_suffix']}.yml" - ) - self.config.to_file(config_filename, overwrite=True) + config_filename = os.path.join(config["res_filepath"], f"config_{config['output_suffix']}.yml") + config.to_file(config_filename, overwrite=True) end = time.time() print("Time taken for patch: ", end - start) return keep + def run_search_from_config(self, config): + """Run a KBMOD search from a SearchConfiguration object. + + Parameters + ---------- + config : `SearchConfiguration` or `dict` + The configuration object with all the information for the run. + + Returns + ------- + keep : ResultList + The results. + """ + if type(config) is dict: + config = SearchConfiguration.from_dict(config) + + # Load the image files. + stack, wcs_list, _ = load_input_from_config(config, verbose=config["debug"]) + + # Compute the suggested search angle from the images. This is a 12 arcsecond + # segment parallel to the ecliptic is seen under from the image origin. + if config["average_angle"] == None: + center_pixel = (stack.get_width() / 2, stack.get_height() / 2) + config.set("average_angle", self._calc_suggested_angle(wcs_list[0], center_pixel)) + + return self.run_search(config, stack) + + def run_search_from_config_file(self, filename, overrides=None): + """Run a KBMOD search from a configuration file. + + Parameters + ---------- + filename : `str` + The name of the configuration file. + overrides : `dict`, optional + A dictionary of configuration parameters to override. + + Returns + ------- + keep : ResultList + The results. + """ + config = SearchConfiguration.from_file(filename) + if overrides is not None: + config.set_multiple(overrides) + + return self.run_search_from_config(config) + + def run_search_from_work_unit_file(self, filename, overrides=None): + """Run a KBMOD search from a WorkUnit file. + + Parameters + ---------- + filename : `str` + The name of the WorkUnit file. + overrides : `dict`, optional + A dictionary of configuration parameters to override. + + Returns + ------- + keep : ResultList + The results. + """ + work = WorkUnit.from_fits(filename) + + if overrides is not None: + work.config.set_multiple(overrides) + + if work.config["average_angle"] == None: + print("WARNING: average_angle is unset. WorkUnit currently uses a default of 0.0") + + # TODO: Support the correct setting of the angle. + work.config.set("average_angle", 0.0) + + return self.run_search(work.config, work.im_stack) + def _count_known_matches(self, result_list, search): """Look up the known objects that overlap the images and count how many are found among the results. @@ -293,7 +339,7 @@ def _count_known_matches(self, result_list, search): A StackSearch object containing information about the search. """ # Get the image metadata - im_filepath = self.config["im_filepath"] + im_filepath = config["im_filepath"] filenames = sorted(os.listdir(im_filepath)) image_list = [os.path.join(im_filepath, im_name) for im_name in filenames] metadata = koffi.ImageMetadataStack(image_list) @@ -310,9 +356,9 @@ def _count_known_matches(self, result_list, search): print("-----------------") matches = {} - known_obj_thresh = self.config["known_obj_thresh"] - min_obs = self.config["known_obj_obs"] - if self.config["known_obj_jpl"]: + known_obj_thresh = config["known_obj_thresh"] + min_obs = config["known_obj_obs"] + if config["known_obj_jpl"]: print("Quering known objects from JPL") matches = koffi.jpl_query_known_objects_stack( potential_sources=ps_list, @@ -335,9 +381,7 @@ def _count_known_matches(self, result_list, search): if len(matches[ps_id]) > 0: num_found += 1 matches_string += f"result id {ps_id}:" + str(matches[ps_id])[1:-1] + "\n" - print( - "Found %i objects with at least %i potential observations." % (num_found, self.config["num_obs"]) - ) + print("Found %i objects with at least %i potential observations." % (num_found, config["num_obs"])) if num_found > 0: print(matches_string) @@ -370,10 +414,6 @@ def _calc_suggested_angle(self, wcs, center_pixel=(1000, 2000), step=12): ---- It is not neccessary to calculate this angle for each image in an image set if they have all been warped to a common WCS. - - See Also - -------- - run_search.do_gpu_search """ # pick a starting pixel approximately near the center of the image # convert it to ecliptic coordinates diff --git a/tests/diff_test.py b/tests/diff_test.py index 98c960bc7..502ad9f99 100644 --- a/tests/diff_test.py +++ b/tests/diff_test.py @@ -6,7 +6,7 @@ import numpy as np -from kbmod.run_search import run_search +from kbmod.run_search import SearchRunner def check_and_create_goldens_dir(): @@ -224,8 +224,8 @@ def perform_search(im_filepath, time_file, psf_file, res_filepath, res_suffix, s "encode_phi_bytes": -1, } - rs = run_search(input_parameters) - rs.run_search() + rs = SearchRunner() + rs.run_search_from_config(input_parameters) if __name__ == "__main__": diff --git a/tests/regression_test.py b/tests/regression_test.py index fa204b3ed..51e15e3d5 100644 --- a/tests/regression_test.py +++ b/tests/regression_test.py @@ -14,7 +14,7 @@ from kbmod.fake_data_creator import add_fake_object from kbmod.file_utils import * -from kbmod.run_search import run_search +from kbmod.run_search import SearchRunner from kbmod.search import * @@ -404,8 +404,8 @@ def perform_search(im_filepath, time_file, psf_file, res_filepath, results_suffi "debug": True, } - rs = run_search(input_parameters) - rs.run_search() + rs = SearchRunner() + rs.run_search_from_config(input_parameters) if __name__ == "__main__": diff --git a/tests/test_configuration.py b/tests/test_configuration.py index ae8f8ee66..52185a700 100644 --- a/tests/test_configuration.py +++ b/tests/test_configuration.py @@ -32,6 +32,16 @@ def test_set(self): # The set should fail when using unknown parameters and strict checking. self.assertRaises(KeyError, config.set, "My_new_param", 100, strict=True) + def set_multiple(self): + config = SearchConfiguration() + self.assertIsNone(config["im_filepath"]) + self.assertEqual(config["encode_psi_bytes"], -1) + + d = {"im_filepath": "Here", "encode_psi_bytes": 2} + config.set_multiple(d) + self.assertEqual(config["im_filepath"], "Here") + self.assertEqual(config["encode_psi_bytes"], 2) + def test_from_dict(self): d = {"im_filepath": "Here2", "num_obs": 5} config = SearchConfiguration.from_dict(d) diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index fe6ad7c10..834400023 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -1,9 +1,12 @@ -import unittest - +import math import numpy as np +import tempfile +import unittest +from kbmod.fake_data_creator import * from kbmod.run_search import * from kbmod.search import * +from kbmod.work_unit import WorkUnit # from .utils_for_tests import get_absolute_demo_data_path # import utils_for_tests @@ -51,8 +54,8 @@ def setUp(self): @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") def test_demo_defaults(self): - rs = run_search(self.input_parameters) - keep = rs.run_search() + rs = SearchRunner() + keep = rs.run_search_from_config(self.input_parameters) self.assertGreaterEqual(keep.num_results(), 1) self.assertEqual(keep.results[0].stamp.size, 441) @@ -60,8 +63,11 @@ def test_demo_defaults(self): def test_demo_config_file(self): im_filepath = get_absolute_demo_data_path("demo") config_file = get_absolute_demo_data_path("demo_config.yml") - rs = run_search({"im_filepath": im_filepath}, config_file=config_file) - keep = rs.run_search() + rs = SearchRunner() + keep = rs.run_search_from_config_file( + config_file, + overrides={"im_filepath": im_filepath}, + ) self.assertGreaterEqual(keep.num_results(), 1) self.assertEqual(keep.results[0].stamp.size, 441) @@ -70,8 +76,8 @@ def test_demo_stamp_size(self): self.input_parameters["stamp_radius"] = 15 self.input_parameters["mom_lims"] = [80.0, 80.0, 50.0, 20.0, 20.0] - rs = run_search(self.input_parameters) - keep = rs.run_search() + rs = SearchRunner() + keep = rs.run_search_from_config(self.input_parameters) self.assertGreaterEqual(keep.num_results(), 1) self.assertIsNotNone(keep.results[0].stamp) @@ -81,6 +87,35 @@ def test_demo_stamp_size(self): for s in keep.results[0].all_stamps: self.assertEqual(s.size, 961) + @unittest.skipIf(not HAS_GPU, "Skipping test (no GPU detected)") + def test_e2e_work_unit(self): + num_images = 10 + + # Create a fake data set with a single bright fake object. + ds = FakeDataSet(128, 128, num_images, obs_per_day=10, use_seed=True) + trj = Trajectory() + trj.x = 50 + trj.y = 60 + trj.vx = 5.0 + trj.vy = 0.0 + trj.flux = 500.0 + ds.insert_object(trj) + + # Set the configuration to pick up the fake object. + config = SearchConfiguration() + config.set("ang_arr", [math.pi, math.pi, 16]) + config.set("v_arr", [0, 10.0, 20]) + + work = WorkUnit(im_stack=ds.stack, config=config) + + with tempfile.TemporaryDirectory() as dir_name: + file_path = f"{dir_name}/test_workunit.fits" + work.to_fits(file_path) + + rs = SearchRunner() + keep = rs.run_search_from_work_unit_file(file_path) + self.assertGreaterEqual(keep.num_results(), 1) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_masking.py b/tests/test_masking.py index 49e22760e..e5c941499 100644 --- a/tests/test_masking.py +++ b/tests/test_masking.py @@ -1,5 +1,6 @@ import unittest +from kbmod.configuration import SearchConfiguration from kbmod.masking import ( BitVectorMasker, DictionaryMasker, @@ -8,7 +9,7 @@ ThresholdMask, apply_mask_operations, ) -from kbmod.run_search import * +from kbmod.run_search import SearchRunner from kbmod.search import * @@ -214,9 +215,12 @@ def test_apply_masks(self): bad_set = set(bad_pixels) + config = SearchConfiguration() + config.set_multiple(overrides) + # Do the actual masking. - rs = run_search(overrides) - self.stack = rs.do_masking(self.stack) + rs = SearchRunner() + self.stack = rs.do_masking(config, self.stack) # Test the the correct pixels have been masked. for i in range(self.img_count):