Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorize match_trajectory_sets #764

Merged
merged 2 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from astropy.table import Table, vstack

from kbmod.trajectory_utils import trajectory_from_np_object
from kbmod.trajectory_utils import trajectory_from_np_object, trajectories_to_dict
from kbmod.search import Trajectory
from kbmod.wcs_utils import deserialize_wcs, serialize_wcs

Expand Down Expand Up @@ -148,20 +148,8 @@ def from_trajectories(cls, trajectories, track_filtered=False):
track_filtered : `bool`
Indicates whether to track future filtered points.
"""
# Create dictionaries for the required columns.
input_d = {}
for col in cls.required_cols:
input_d[col[0]] = []

# Add the trajectories to the table.
for trj in trajectories:
input_d["x"].append(trj.x)
input_d["y"].append(trj.y)
input_d["vx"].append(trj.vx)
input_d["vy"].append(trj.vy)
input_d["likelihood"].append(trj.lh)
input_d["flux"].append(trj.flux)
input_d["obs_count"].append(trj.obs_count)
# Create dictionaries from the Trajectories.
input_d = trajectories_to_dict(trajectories)

# Check for any missing columns and fill in the default value.
for col in cls.required_cols:
Expand Down
67 changes: 62 additions & 5 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,53 @@ def trajectory_from_np_object(result):
return trj


def trajectories_to_dict(trj_list):
"""Create a dictionary of trajectory related information
from a list of Trajectory objects.

Parameters
----------
trj_list : `list`
The list of Trajectory objects.

Returns
-------
trj_dict : `Trajectory`
The corresponding trajectory object.
"""
# Create the lists to fill.
num_trjs = len(trj_list)
x0 = [0] * num_trjs
y0 = [0] * num_trjs
vx = [0.0] * num_trjs
vy = [0.0] * num_trjs
lh = [0.0] * num_trjs
flux = [0.0] * num_trjs
obs_count = [0] * num_trjs

# Extract the values from each Trajectory object.
for idx, trj in enumerate(trj_list):
x0[idx] = trj.x
y0[idx] = trj.y
vx[idx] = trj.vx
vy[idx] = trj.vy
lh[idx] = trj.lh
flux[idx] = trj.flux
obs_count[idx] = trj.obs_count

# Store the lists in a dictionary and return that.
trj_dict = {
"x": x0,
"y": y0,
"vx": vx,
"vy": vy,
"likelihood": lh,
"flux": flux,
"obs_count": obs_count,
}
return trj_dict


def trajectory_from_dict(trj_dict):
"""Create a trajectory from a dictionary of the parameters.

Expand Down Expand Up @@ -342,12 +389,22 @@ def match_trajectory_sets(traj_query, traj_base, threshold, times=[0.0]):
num_query = len(traj_query)
num_base = len(traj_base)

# Compute the matrix of distances between each pair. If this double FOR loop
# becomes a bottleneck, we can vectorize.
# Predict the x and y positions for the base trajectories at each time (using the vectorized functions).
base_info = trajectories_to_dict(traj_base)
base_px = predict_pixel_locations(times, base_info["x"], base_info["vx"], centered=False, as_int=False)
base_py = predict_pixel_locations(times, base_info["y"], base_info["vy"], centered=False, as_int=False)

# Compute the matrix of distances between each pair.
dists = np.zeros((num_query, num_base))
for q_idx in range(num_query):
for b_idx in range(num_base):
dists[q_idx][b_idx] = avg_trajectory_distance(traj_query[q_idx], traj_base[b_idx], times)
for q_idx, q_trj in enumerate(traj_query):
# Compute the query point locations at all times.
q_px = q_trj.x + times * q_trj.vx
q_py = q_trj.y + times * q_trj.vy

# Compute the average distance with each of the base predictions.
dx = q_px[np.newaxis, :] - base_px
dy = q_py[np.newaxis, :] - base_py
dists[q_idx, :] = np.mean(np.sqrt(dx**2 + dy**2), axis=1)

# Use scipy to solve the optimal bipartite matching problem.
row_inds, col_inds = linear_sum_assignment(dists)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,22 @@ def test_fit_trajectory_from_pixels(self):
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0, 2.0], [1.0])
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0], [1.0, 2.0])

def test_trajectories_to_dict(self):
trj_list = [
Trajectory(x=0, y=1, vx=2.0, vy=3.0, lh=4.0, flux=5.0, obs_count=6),
Trajectory(x=10, y=11, vx=12.0, vy=13.0, lh=14.0, flux=15.0, obs_count=16),
Trajectory(x=20, y=21, vx=22.0, vy=23.0, lh=24.0, flux=25.0, obs_count=26),
]

trj_dict = trajectories_to_dict(trj_list)
self.assertTrue(np.array_equal(trj_dict["x"], [0, 10, 20]))
self.assertTrue(np.array_equal(trj_dict["y"], [1, 11, 21]))
self.assertTrue(np.array_equal(trj_dict["vx"], [2.0, 12.0, 22.0]))
self.assertTrue(np.array_equal(trj_dict["vy"], [3.0, 13.0, 23.0]))
self.assertTrue(np.array_equal(trj_dict["likelihood"], [4.0, 14.0, 24.0]))
self.assertTrue(np.array_equal(trj_dict["flux"], [5.0, 15.0, 25.0]))
self.assertTrue(np.array_equal(trj_dict["obs_count"], [6, 16, 26]))

def test_evaluate_trajectory_mse(self):
trj = Trajectory(x=5, y=4, vx=2.0, vy=-1.0)
x_vals = np.array([5.5, 7.5, 9.7, 11.5])
Expand Down