Skip to content

Commit

Permalink
functionalize and remove the StampCreator class (#761)
Browse files Browse the repository at this point in the history
* functionalize and remove the StampCreator class

* black formatting

* address comments from #761
  • Loading branch information
maxwest-uw authored Dec 18, 2024
1 parent 438766f commit 3361d2a
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 177 deletions.
4 changes: 1 addition & 3 deletions benchmarks/bench_filter_stamps.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
StampParameters,
StampType,
Trajectory,
StampCreator,
)


Expand Down Expand Up @@ -55,10 +54,9 @@ def run_search_benchmark(params):

# Create an empty search stack.
# im_stack = ImageStack([])
sc = StampCreator()

# Do the timing runs.
tmr = timeit.Timer(stmt="sc.filter_stamp(stamp, params)", globals=locals())
tmr = timeit.Timer(stmt="filter_stamp(stamp, params)", globals=locals())
res_time = np.mean(tmr.repeat(repeat=10, number=20))
return res_time

Expand Down
2 changes: 1 addition & 1 deletion notebooks/kbmod_search_results_for_fakes.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
"import os\n",
"\n",
"from kbmod.analysis.plotting import *\n",
"from kbmod.search import ImageStack, PSF, StampCreator, Trajectory\n",
"from kbmod.search import ImageStack, PSF, Trajectory\n",
"from kbmod.results import Results\n",
"from kbmod.work_unit import WorkUnit\n",
"\n",
Expand Down
20 changes: 14 additions & 6 deletions notebooks/kbmod_visualize.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,15 @@
"\n",
"from kbmod.analysis.plotting import *\n",
"from kbmod.util_functions import load_deccam_layered_image\n",
"from kbmod.search import ImageStack, PSF, StampCreator, Trajectory\n",
"from kbmod.search import (\n",
" ImageStack,\n",
" PSF,\n",
" Trajectory,\n",
" get_stamps,\n",
" get_median_stamp,\n",
" get_mean_stamp,\n",
" get_summed_stamp,\n",
")\n",
"from kbmod.results import Results\n",
"\n",
"# Data paths\n",
Expand Down Expand Up @@ -113,7 +121,7 @@
"source": [
"# Create and Visualize Stamps\n",
"\n",
"Stamps are a critical tool for analyzing and debugging proposed detections. They can be created automatically using the `StampCreator` class. It requires a few pieces of data:\n",
"Stamps are a critical tool for analyzing and debugging proposed detections. They can be created automatically using the stamp creation utilities. It requires a few pieces of data:\n",
"* search_stack - provides the machinery for making predictions on the image (needed to handle the various corrections).\n",
"* trajectory - Contains the information about where to place the stamps (the underlying trajectory).\n",
"* stamp_radius - The radius in pixels."
Expand All @@ -133,7 +141,7 @@
"trj.vy = 3.3\n",
"\n",
"# Create the stamps around this trajectory.\n",
"stamps = StampCreator.get_stamps(stack, trj, 20)"
"stamps = get_stamps(stack, trj, 20)"
]
},
{
Expand Down Expand Up @@ -174,9 +182,9 @@
"\n",
"plot_multiple_images(\n",
" [\n",
" StampCreator.get_summed_stamp(stack, trj, 10, inds),\n",
" StampCreator.get_mean_stamp(stack, trj, 10, inds),\n",
" StampCreator.get_median_stamp(stack, trj, 10, inds),\n",
" get_summed_stamp(stack, trj, 10, inds),\n",
" get_mean_stamp(stack, trj, 10, inds),\n",
" get_median_stamp(stack, trj, 10, inds),\n",
" ],\n",
" labels=[\"Summed\", \"Mean\", \"Median\"],\n",
" norm=True,\n",
Expand Down
4 changes: 2 additions & 2 deletions src/kbmod/analysis/visualizer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from kbmod.analysis.plotting import plot_multiple_images
from kbmod.search import StampCreator
from kbmod.search import get_stamps
from kbmod.util_functions import mjd_to_day

import numpy as np
Expand Down Expand Up @@ -33,7 +33,7 @@ def generate_all_stamps(self, radius=10):
radius of the stamp.
"""
self.results.table["all_stamps"] = [
StampCreator.get_stamps(self.im_stack, trj, radius) for trj in self.trajectories
get_stamps(self.im_stack, trj, radius) for trj in self.trajectories
]

def count_num_days(self):
Expand Down
7 changes: 4 additions & 3 deletions src/kbmod/filters/stamp_filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@
DebugTimer,
ImageStack,
RawImage,
StampCreator,
StampParameters,
StampType,
Logging,
get_stamps,
get_coadded_stamps,
)


Expand Down Expand Up @@ -141,7 +142,7 @@ def get_coadds_and_filter_results(result_data, im_stack, stamp_params, chunk_siz

# Create and filter the results, using the GPU if there is one and enough
# trajectories to make it worthwhile.
stamps_slice = StampCreator.get_coadded_stamps(
stamps_slice = get_coadded_stamps(
im_stack,
trj_slice,
bool_slice,
Expand Down Expand Up @@ -240,7 +241,7 @@ def append_all_stamps(result_data, im_stack, stamp_radius):

all_stamps = []
for trj in result_data.make_trajectory_list():
stamps = StampCreator.get_stamps(im_stack, trj, stamp_radius)
stamps = get_stamps(im_stack, trj, stamp_radius)
all_stamps.append(np.array([stamp.image for stamp in stamps]))

# We add the column even if it is empty so we can have consistent
Expand Down
5 changes: 0 additions & 5 deletions src/kbmod/search/pydocs/stamp_creator_docs.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,6 @@
#define STAMP_CREATOR_DOCS

namespace pydocs {
static const auto DOC_StampCreator = R"doc(
A class for creating a set of stamps or a co-added stamp
from an ImageStack and Trajectory.
)doc";

static const auto DOC_StampCreator_create_stamps = R"doc(
Create a vector of stamps centered on the predicted position
of an Trajectory at different times.
Expand Down
80 changes: 37 additions & 43 deletions src/kbmod/search/stamp_creator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@ void deviceGetCoadds(const uint64_t num_images, const uint64_t width, const uint
GPUArray<int>& use_index_vect, GPUArray<float>& results);
#endif

StampCreator::StampCreator() {}

std::vector<RawImage> StampCreator::create_stamps(ImageStack& stack, const Trajectory& trj, int radius,
bool keep_no_data, const std::vector<bool>& use_index) {
std::vector<RawImage> create_stamps(ImageStack& stack, const Trajectory& trj, int radius,
bool keep_no_data, const std::vector<bool>& use_index) {
if (use_index.size() > 0)
assert_sizes_equal(use_index.size(), stack.img_count(), "create_stamps() use_index");
bool use_all_stamps = (use_index.size() == 0);
Expand All @@ -35,33 +33,33 @@ std::vector<RawImage> StampCreator::create_stamps(ImageStack& stack, const Traje
// For stamps used for visualization we replace invalid pixels with zeros
// and return all the stamps (regardless of whether individual timesteps
// have been filtered).
std::vector<RawImage> StampCreator::get_stamps(ImageStack& stack, const Trajectory& t, int radius) {
std::vector<RawImage> get_stamps(ImageStack& stack, const Trajectory& t, int radius) {
std::vector<bool> empty_vect;
return create_stamps(stack, t, radius, false /*=keep_no_data*/, empty_vect);
}

// For creating coadded stamps, we keep invalid pixels tagged (so we can filter it out of mean/median).
RawImage StampCreator::get_median_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
RawImage get_median_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
return create_median_image(create_stamps(stack, trj, radius, true /*=keep_no_data*/, use_index));
}

// For creating coadded stamps, we keep invalid pixels tagged (so we can filter it out of mean/median).
RawImage StampCreator::get_mean_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
RawImage get_mean_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
return create_mean_image(create_stamps(stack, trj, radius, true /*=keep_no_data*/, use_index));
}

// For creating summed stamps, we replace invalid pixels with zero (which is the same as
// filtering it out for the sum).
RawImage StampCreator::get_summed_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
RawImage get_summed_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
return create_summed_image(create_stamps(stack, trj, radius, false /*=keep_no_data*/, use_index));
}

std::vector<RawImage> StampCreator::get_coadded_stamps(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
const StampParameters& params, bool use_gpu) {
std::vector<RawImage> get_coadded_stamps(ImageStack& stack, std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
const StampParameters& params, bool use_gpu) {
logging::Logger* rs_logger = logging::getLogger("kbmod.search.stamp_creator");
rs_logger->info("Generating co_added stamps on " + std::to_string(t_array.size()) + " trajectories.");
DebugTimer timer = DebugTimer("coadd generating", rs_logger);
Expand All @@ -78,10 +76,10 @@ std::vector<RawImage> StampCreator::get_coadded_stamps(ImageStack& stack, std::v
return get_coadded_stamps_cpu(stack, t_array, use_index_vect, params);
}

std::vector<RawImage> StampCreator::get_coadded_stamps_cpu(ImageStack& stack,
std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
const StampParameters& params) {
std::vector<RawImage> get_coadded_stamps_cpu(ImageStack& stack,
std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
const StampParameters& params) {
const uint64_t num_trajectories = t_array.size();
std::vector<RawImage> results(num_trajectories);

Expand Down Expand Up @@ -115,7 +113,7 @@ std::vector<RawImage> StampCreator::get_coadded_stamps_cpu(ImageStack& stack,
return results;
}

bool StampCreator::filter_stamp(const RawImage& img, const StampParameters& params) {
bool filter_stamp(const RawImage& img, const StampParameters& params) {
if (params.radius <= 0) throw std::runtime_error("Invalid stamp radius=" + std::to_string(params.radius));

// Allocate space for the coadd information and initialize to zero.
Expand Down Expand Up @@ -157,10 +155,10 @@ bool StampCreator::filter_stamp(const RawImage& img, const StampParameters& para
return false;
}

std::vector<RawImage> StampCreator::get_coadded_stamps_gpu(ImageStack& stack,
std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
const StampParameters& params) {
std::vector<RawImage> get_coadded_stamps_gpu(ImageStack& stack,
std::vector<Trajectory>& t_array,
std::vector<std::vector<bool>>& use_index_vect,
const StampParameters& params) {
logging::Logger* rs_logger = logging::getLogger("kbmod.search.stamp_creator");

// Right now only limited stamp sizes are allowed.
Expand Down Expand Up @@ -271,8 +269,8 @@ std::vector<RawImage> StampCreator::get_coadded_stamps_gpu(ImageStack& stack,
return results;
}

std::vector<RawImage> StampCreator::create_variance_stamps(ImageStack& stack, const Trajectory& trj,
int radius, const std::vector<bool>& use_index) {
std::vector<RawImage> create_variance_stamps(ImageStack& stack, const Trajectory& trj,
int radius, const std::vector<bool>& use_index) {
if (use_index.size() > 0)
assert_sizes_equal(use_index.size(), stack.img_count(), "create_stamps() use_index");
bool use_all_stamps = (use_index.size() == 0);
Expand All @@ -293,8 +291,8 @@ std::vector<RawImage> StampCreator::create_variance_stamps(ImageStack& stack, co
return stamps;
}

RawImage StampCreator::get_variance_weighted_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
RawImage get_variance_weighted_stamp(ImageStack& stack, const Trajectory& trj, int radius,
const std::vector<bool>& use_index) {
if (radius < 0) throw std::runtime_error("Invalid stamp radius. Must be >= 0.");
unsigned int num_images = stack.img_count();
if (num_images == 0) throw std::runtime_error("Unable to create mean image given 0 images.");
Expand Down Expand Up @@ -335,22 +333,18 @@ RawImage StampCreator::get_variance_weighted_stamp(ImageStack& stack, const Traj

#ifdef Py_PYTHON_H
static void stamp_creator_bindings(py::module& m) {
using sc = search::StampCreator;

py::class_<sc>(m, "StampCreator", pydocs::DOC_StampCreator)
.def(py::init<>())
.def_static("get_stamps", &sc::get_stamps, pydocs::DOC_StampCreator_get_stamps)
.def_static("get_median_stamp", &sc::get_median_stamp, pydocs::DOC_StampCreator_get_median_stamp)
.def_static("get_mean_stamp", &sc::get_mean_stamp, pydocs::DOC_StampCreator_get_mean_stamp)
.def_static("get_summed_stamp", &sc::get_summed_stamp, pydocs::DOC_StampCreator_get_summed_stamp)
.def_static("get_coadded_stamps", &sc::get_coadded_stamps,
pydocs::DOC_StampCreator_get_coadded_stamps)
.def_static("get_variance_weighted_stamp", &sc::get_variance_weighted_stamp,
pydocs::DOC_StampCreator_get_variance_weighted_stamp)
.def_static("create_stamps", &sc::create_stamps, pydocs::DOC_StampCreator_create_stamps)
.def_static("create_variance_stamps", &sc::create_variance_stamps,
pydocs::DOC_StampCreator_create_variance_stamps)
.def_static("filter_stamp", &sc::filter_stamp, pydocs::DOC_StampCreator_filter_stamp);
m.def("get_stamps", &search::get_stamps, pydocs::DOC_StampCreator_get_stamps);
m.def("get_median_stamp", &search::get_median_stamp, pydocs::DOC_StampCreator_get_median_stamp);
m.def("get_mean_stamp", &search::get_mean_stamp, pydocs::DOC_StampCreator_get_mean_stamp);
m.def("get_summed_stamp", &search::get_summed_stamp, pydocs::DOC_StampCreator_get_summed_stamp);
m.def("get_coadded_stamps", &search::get_coadded_stamps,
pydocs::DOC_StampCreator_get_coadded_stamps);
m.def("get_variance_weighted_stamp", &search::get_variance_weighted_stamp,
pydocs::DOC_StampCreator_get_variance_weighted_stamp);
m.def("create_stamps", &search::create_stamps, pydocs::DOC_StampCreator_create_stamps);
m.def("create_variance_stamps", &search::create_variance_stamps,
pydocs::DOC_StampCreator_create_variance_stamps);
m.def("filter_stamp", &search::filter_stamp, pydocs::DOC_StampCreator_filter_stamp);
}
#endif /* Py_PYTHON_H */

Expand Down
Loading

0 comments on commit 3361d2a

Please sign in to comment.