Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

functionalize and remove the StampCreator class #761

Merged
merged 4 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions benchmarks/bench_filter_stamps.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
StampParameters,
StampType,
Trajectory,
StampCreator,
)


Expand Down Expand Up @@ -55,7 +54,6 @@ def run_search_benchmark(params):

# Create an empty search stack.
# im_stack = ImageStack([])
sc = StampCreator()

# Do the timing runs.
tmr = timeit.Timer(stmt="sc.filter_stamp(stamp, params)", globals=locals())
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
2 changes: 1 addition & 1 deletion notebooks/kbmod_search_results_for_fakes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"import os\n",
"\n",
"from kbmod.analysis.plotting import *\n",
"from kbmod.search import ImageStack, PSF, StampCreator, Trajectory\n",
"from kbmod.search import ImageStack, PSF, Trajectory\n",
"from kbmod.results import Results\n",
"from kbmod.work_unit import WorkUnit\n",
"\n",
Expand Down
20 changes: 14 additions & 6 deletions notebooks/kbmod_visualize.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@
"\n",
"from kbmod.analysis.plotting import *\n",
"from kbmod.util_functions import load_deccam_layered_image\n",
"from kbmod.search import ImageStack, PSF, StampCreator, Trajectory\n",
"from kbmod.search import (\n",
" ImageStack,\n",
" PSF,\n",
" Trajectory,\n",
" get_stamps,\n",
" get_median_stamp,\n",
" get_mean_stamp,\n",
" get_summed_stamp,\n",
")\n",
"from kbmod.results import Results\n",
"\n",
"# Data paths\n",
Expand Down Expand Up @@ -113,7 +121,7 @@
"source": [
"# Create and Visualize Stamps\n",
"\n",
"Stamps are a critical tool for analyzing and debugging proposed detections. They can be created automatically using the `StampCreator` class. It requires a few pieces of data:\n",
"Stamps are a critical tool for analyzing and debugging proposed detections. They can be created automatically using the stamp creation utilities. It requires a few pieces of data:\n",
"* search_stack - provides the machinery for making predictions on the image (needed to handle the various corrections).\n",
"* trajectory - Contains the information about where to place the stamps (the underlying trajectory).\n",
"* stamp_radius - The radius in pixels."
Expand All @@ -133,7 +141,7 @@
"trj.vy = 3.3\n",
"\n",
"# Create the stamps around this trajectory.\n",
"stamps = StampCreator.get_stamps(stack, trj, 20)"
"stamps = get_stamps(stack, trj, 20)"
]
},
{
Expand Down Expand Up @@ -174,9 +182,9 @@
"\n",
"plot_multiple_images(\n",
" [\n",
" StampCreator.get_summed_stamp(stack, trj, 10, inds),\n",
" StampCreator.get_mean_stamp(stack, trj, 10, inds),\n",
" StampCreator.get_median_stamp(stack, trj, 10, inds),\n",
" get_summed_stamp(stack, trj, 10, inds),\n",
" get_mean_stamp(stack, trj, 10, inds),\n",
" get_median_stamp(stack, trj, 10, inds),\n",
" ],\n",
" labels=[\"Summed\", \"Mean\", \"Median\"],\n",
" norm=True,\n",
Expand Down
4 changes: 2 additions & 2 deletions src/kbmod/analysis/visualizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from kbmod.analysis.plotting import plot_multiple_images
from kbmod.search import StampCreator
from kbmod.search import get_stamps
from kbmod.util_functions import mjd_to_day

import numpy as np
Expand Down Expand Up @@ -33,7 +33,7 @@ def generate_all_stamps(self, radius=10):
radius of the stamp.
"""
self.results.table["all_stamps"] = [
StampCreator.get_stamps(self.im_stack, trj, radius) for trj in self.trajectories
get_stamps(self.im_stack, trj, radius) for trj in self.trajectories
]

def count_num_days(self):
Expand Down
7 changes: 4 additions & 3 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
DebugTimer,
ImageStack,
RawImage,
StampCreator,
StampParameters,
StampType,
Logging,
get_stamps,
get_coadded_stamps,
)


Expand Down Expand Up @@ -141,7 +142,7 @@ def get_coadds_and_filter_results(result_data, im_stack, stamp_params, chunk_siz

# Create and filter the results, using the GPU if there is one and enough
# trajectories to make it worthwhile.
stamps_slice = StampCreator.get_coadded_stamps(
stamps_slice = get_coadded_stamps(
im_stack,
trj_slice,
bool_slice,
Expand Down Expand Up @@ -240,7 +241,7 @@ def append_all_stamps(result_data, im_stack, stamp_radius):

all_stamps = []
for trj in result_data.make_trajectory_list():
stamps = StampCreator.get_stamps(im_stack, trj, stamp_radius)
stamps = get_stamps(im_stack, trj, stamp_radius)
all_stamps.append(np.array([stamp.image for stamp in stamps]))

# We add the column even if it is empty so we can have consistent
Expand Down
5 changes: 0 additions & 5 deletions src/kbmod/search/pydocs/stamp_creator_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
#define STAMP_CREATOR_DOCS

namespace pydocs {
static const auto DOC_StampCreator = R"doc(
A class for creating a set of stamps or a co-added stamp
from an ImageStack and Trajectory.
)doc";

static const auto DOC_StampCreator_create_stamps = R"doc(
Create a vector of stamps centered on the predicted position
of an Trajectory at different times.
Expand Down
52 changes: 23 additions & 29 deletions src/kbmod/search/stamp_creator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ void deviceGetCoadds(const uint64_t num_images, const uint64_t width, const uint
GPUArray<int>& use_index_vect, GPUArray<float>& results);
#endif

StampCreator::StampCreator() {}

std::vector<RawImage> StampCreator::create_stamps(ImageStack& stack, const Trajectory& trj, int radius,
std::vector<RawImage> create_stamps(ImageStack& stack, const Trajectory& trj, int radius,
bool keep_no_data, const std::vector<bool>& use_index) {
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
if (use_index.size() > 0)
assert_sizes_equal(use_index.size(), stack.img_count(), "create_stamps() use_index");
Expand All @@ -35,33 +33,33 @@ std::vector<RawImage> StampCreator::create_stamps(ImageStack& stack, const Traje
// For stamps used for visualization we replace invalid pixels with zeros
// and return all the stamps (regardless of whether individual timesteps
// have been filtered).
std::vector<RawImage> StampCreator::get_stamps(ImageStack& stack, const Trajectory& t, int radius) {
std::vector<RawImage> get_stamps(ImageStack& stack, const Trajectory& t, int radius) {
std::vector<bool> empty_vect;
return create_stamps(stack, t, radius, false /*=keep_no_data*/, empty_vect);
}

// For creating coadded stamps, we do not interpolate the pixel values and keep
// invalid pixels tagged (so we can filter it out of mean/median).
RawImage StampCreator::get_median_stamp(ImageStack& stack, const Trajectory& trj, int radius,
RawImage get_median_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
return create_median_image(create_stamps(stack, trj, radius, true /*=keep_no_data*/, use_index));
}

// For creating coadded stamps, we do not interpolate the pixel values and keep
// invalid pixels tagged (so we can filter it out of mean/median).
RawImage StampCreator::get_mean_stamp(ImageStack& stack, const Trajectory& trj, int radius,
RawImage get_mean_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
return create_mean_image(create_stamps(stack, trj, radius, true /*=keep_no_data*/, use_index));
}

// For creating summed stamps, we do not interpolate the pixel values and replace
// invalid pixels with zero (which is the same as filtering it out for the sum).
RawImage StampCreator::get_summed_stamp(ImageStack& stack, const Trajectory& trj, int radius,
RawImage get_summed_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
return create_summed_image(create_stamps(stack, trj, radius, false /*=keep_no_data*/, use_index));
}

std::vector<RawImage> StampCreator::get_coadded_stamps(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<RawImage> get_coadded_stamps(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
const StampParameters& params, bool use_gpu) {
logging::Logger* rs_logger = logging::getLogger("kbmod.search.stamp_creator");
Expand All @@ -80,7 +78,7 @@ std::vector<RawImage> StampCreator::get_coadded_stamps(ImageStack& stack, std::v
return get_coadded_stamps_cpu(stack, t_array, use_index_vect, params);
}

std::vector<RawImage> StampCreator::get_coadded_stamps_cpu(ImageStack& stack,
std::vector<RawImage> get_coadded_stamps_cpu(ImageStack& stack,
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
const StampParameters& params) {
Expand Down Expand Up @@ -117,7 +115,7 @@ std::vector<RawImage> StampCreator::get_coadded_stamps_cpu(ImageStack& stack,
return results;
}

bool StampCreator::filter_stamp(const RawImage& img, const StampParameters& params) {
bool filter_stamp(const RawImage& img, const StampParameters& params) {
if (params.radius <= 0) throw std::runtime_error("Invalid stamp radius=" + std::to_string(params.radius));

// Allocate space for the coadd information and initialize to zero.
Expand Down Expand Up @@ -159,7 +157,7 @@ bool StampCreator::filter_stamp(const RawImage& img, const StampParameters& para
return false;
}

std::vector<RawImage> StampCreator::get_coadded_stamps_gpu(ImageStack& stack,
std::vector<RawImage> get_coadded_stamps_gpu(ImageStack& stack,
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
const StampParameters& params) {
Expand Down Expand Up @@ -273,7 +271,7 @@ std::vector<RawImage> StampCreator::get_coadded_stamps_gpu(ImageStack& stack,
return results;
}

std::vector<RawImage> StampCreator::create_variance_stamps(ImageStack& stack, const Trajectory& trj,
std::vector<RawImage> create_variance_stamps(ImageStack& stack, const Trajectory& trj,
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
int radius, const std::vector<bool>& use_index) {
if (use_index.size() > 0)
assert_sizes_equal(use_index.size(), stack.img_count(), "create_stamps() use_index");
Expand All @@ -295,7 +293,7 @@ std::vector<RawImage> StampCreator::create_variance_stamps(ImageStack& stack, co
return stamps;
}

RawImage StampCreator::get_variance_weighted_stamp(ImageStack& stack, const Trajectory& trj, int radius,
RawImage get_variance_weighted_stamp(ImageStack& stack, const Trajectory& trj, int radius,
maxwest-uw marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<bool>& use_index) {
if (radius < 0) throw std::runtime_error("Invalid stamp radius. Must be >= 0.");
unsigned int num_images = stack.img_count();
Expand Down Expand Up @@ -337,22 +335,18 @@ RawImage StampCreator::get_variance_weighted_stamp(ImageStack& stack, const Traj

#ifdef Py_PYTHON_H
static void stamp_creator_bindings(py::module& m) {
using sc = search::StampCreator;

py::class_<sc>(m, "StampCreator", pydocs::DOC_StampCreator)
.def(py::init<>())
.def_static("get_stamps", &sc::get_stamps, pydocs::DOC_StampCreator_get_stamps)
.def_static("get_median_stamp", &sc::get_median_stamp, pydocs::DOC_StampCreator_get_median_stamp)
.def_static("get_mean_stamp", &sc::get_mean_stamp, pydocs::DOC_StampCreator_get_mean_stamp)
.def_static("get_summed_stamp", &sc::get_summed_stamp, pydocs::DOC_StampCreator_get_summed_stamp)
.def_static("get_coadded_stamps", &sc::get_coadded_stamps,
pydocs::DOC_StampCreator_get_coadded_stamps)
.def_static("get_variance_weighted_stamp", &sc::get_variance_weighted_stamp,
pydocs::DOC_StampCreator_get_variance_weighted_stamp)
.def_static("create_stamps", &sc::create_stamps, pydocs::DOC_StampCreator_create_stamps)
.def_static("create_variance_stamps", &sc::create_variance_stamps,
pydocs::DOC_StampCreator_create_variance_stamps)
.def_static("filter_stamp", &sc::filter_stamp, pydocs::DOC_StampCreator_filter_stamp);
m.def("get_stamps", &search::get_stamps, pydocs::DOC_StampCreator_get_stamps);
m.def("get_median_stamp", &search::get_median_stamp, pydocs::DOC_StampCreator_get_median_stamp);
m.def("get_mean_stamp", &search::get_mean_stamp, pydocs::DOC_StampCreator_get_mean_stamp);
m.def("get_summed_stamp", &search::get_summed_stamp, pydocs::DOC_StampCreator_get_summed_stamp);
m.def("get_coadded_stamps", &search::get_coadded_stamps,
pydocs::DOC_StampCreator_get_coadded_stamps);
m.def("get_variance_weighted_stamp", &search::get_variance_weighted_stamp,
pydocs::DOC_StampCreator_get_variance_weighted_stamp);
m.def("create_stamps", &search::create_stamps, pydocs::DOC_StampCreator_create_stamps);
m.def("create_variance_stamps", &search::create_variance_stamps,
pydocs::DOC_StampCreator_create_variance_stamps);
m.def("filter_stamp", &search::filter_stamp, pydocs::DOC_StampCreator_filter_stamp);
}
#endif /* Py_PYTHON_H */

Expand Down
87 changes: 40 additions & 47 deletions src/kbmod/search/stamp_creator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,58 +11,51 @@ namespace search {
* Utility class for functions used for creating science stamps for
* filtering, visualization, etc.
*/
class StampCreator {
public:
StampCreator();

// Functions science stamps for filtering, visualization, etc. User can specify
// the radius of the stamp, whether to keep no data values (e.g. NaN) or replace
// them with zero, and what indices to use.
// The indices to use are indicated by use_index: a vector<bool> indicating whether to use
// each time step. An empty (size=0) vector will use all time steps.
static std::vector<RawImage> create_stamps(ImageStack& stack, const Trajectory& trj, int radius,
bool keep_no_data, const std::vector<bool>& use_index);

static std::vector<RawImage> get_stamps(ImageStack& stack, const Trajectory& t, int radius);

static RawImage get_median_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);

static RawImage get_mean_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);

static RawImage get_summed_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);

// Compute a mean or summed stamp for each trajectory on the GPU or CPU.
// The GPU implementation is slower for small numbers of trajectories (< 500), but performs
// relatively better as the number of trajectories increases. If filtering is applied then
// the code will return a 1x1 image with NO_DATA to represent each filtered image.
static std::vector<RawImage> get_coadded_stamps(ImageStack& stack, std::vector<Trajectory>& t_array,
// Functions science stamps for filtering, visualization, etc. User can specify
// the radius of the stamp, whether to keep no data values (e.g. NaN) or replace
// them with zero, and what indices to use.
// The indices to use are indicated by use_index: a vector<bool> indicating whether to use
// each time step. An empty (size=0) vector will use all time steps.
static std::vector<RawImage> create_stamps(ImageStack& stack, const Trajectory& trj, int radius,
bool keep_no_data, const std::vector<bool>& use_index);

static std::vector<RawImage> get_stamps(ImageStack& stack, const Trajectory& t, int radius);

static RawImage get_median_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);

static RawImage get_mean_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);

static RawImage get_summed_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);

// Compute a mean or summed stamp for each trajectory on the GPU or CPU.
// The GPU implementation is slower for small numbers of trajectories (< 500), but performs
// relatively better as the number of trajectories increases. If filtering is applied then
// the code will return a 1x1 image with NO_DATA to represent each filtered image.
static std::vector<RawImage> get_coadded_stamps(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<std::vector<bool> >& use_index_vect,
const StampParameters& params, bool use_gpu);

static std::vector<RawImage> get_coadded_stamps_gpu(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<std::vector<bool> >& use_index_vect,
const StampParameters& params, bool use_gpu);

static std::vector<RawImage> get_coadded_stamps_gpu(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<std::vector<bool> >& use_index_vect,
const StampParameters& params);
const StampParameters& params);

static std::vector<RawImage> get_coadded_stamps_cpu(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<std::vector<bool> >& use_index_vect,
const StampParameters& params);

// Function to do the actual stamp filtering.
static bool filter_stamp(const RawImage& img, const StampParameters& params);
static std::vector<RawImage> get_coadded_stamps_cpu(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<std::vector<bool> >& use_index_vect,
const StampParameters& params);

// Function for generating variance stamps. All times are returned and NO_DATA values are preserved.
static std::vector<RawImage> create_variance_stamps(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);
// Function to do the actual stamp filtering.
static bool filter_stamp(const RawImage& img, const StampParameters& params);

// Function for generating variance weighted stamps. All times are used and NO_DATA values are skipped.
static RawImage get_variance_weighted_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);
// Function for generating variance stamps. All times are returned and NO_DATA values are preserved.
static std::vector<RawImage> create_variance_stamps(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);

virtual ~StampCreator(){};
};
// Function for generating variance weighted stamps. All times are used and NO_DATA values are skipped.
static RawImage get_variance_weighted_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index);

} /* namespace search */

Expand Down
2 changes: 1 addition & 1 deletion src/kbmod/trajectory_explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from kbmod.configuration import SearchConfiguration
from kbmod.filters.sigma_g_filter import apply_clipped_sigma_g, SigmaGClipping
from kbmod.results import Results
from kbmod.search import StackSearch, StampCreator, Logging
from kbmod.search import StackSearch, Logging
from kbmod.filters.stamp_filters import append_all_stamps, append_coadds
from kbmod.trajectory_utils import make_trajectory_from_ra_dec

Expand Down
Loading