Skip to content

Commit

Permalink
Merge pull request #764 from dirac-institute/testing_helpers
Browse files Browse the repository at this point in the history
Vectorize match_trajectory_sets
  • Loading branch information
jeremykubica authored Dec 18, 2024
2 parents 3361d2a + 6a8a8bf commit 03395e1
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 20 deletions.
18 changes: 3 additions & 15 deletions src/kbmod/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from astropy.table import Table, vstack

from kbmod.trajectory_utils import trajectory_from_np_object
from kbmod.trajectory_utils import trajectory_from_np_object, trajectories_to_dict
from kbmod.search import Trajectory
from kbmod.wcs_utils import deserialize_wcs, serialize_wcs

Expand Down Expand Up @@ -148,20 +148,8 @@ def from_trajectories(cls, trajectories, track_filtered=False):
track_filtered : `bool`
Indicates whether to track future filtered points.
"""
# Create dictionaries for the required columns.
input_d = {}
for col in cls.required_cols:
input_d[col[0]] = []

# Add the trajectories to the table.
for trj in trajectories:
input_d["x"].append(trj.x)
input_d["y"].append(trj.y)
input_d["vx"].append(trj.vx)
input_d["vy"].append(trj.vy)
input_d["likelihood"].append(trj.lh)
input_d["flux"].append(trj.flux)
input_d["obs_count"].append(trj.obs_count)
# Create dictionaries from the Trajectories.
input_d = trajectories_to_dict(trajectories)

# Check for any missing columns and fill in the default value.
for col in cls.required_cols:
Expand Down
67 changes: 62 additions & 5 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,53 @@ def trajectory_from_np_object(result):
return trj


def trajectories_to_dict(trj_list):
"""Create a dictionary of trajectory related information
from a list of Trajectory objects.
Parameters
----------
trj_list : `list`
The list of Trajectory objects.
Returns
-------
trj_dict : `Trajectory`
The corresponding trajectory object.
"""
# Create the lists to fill.
num_trjs = len(trj_list)
x0 = [0] * num_trjs
y0 = [0] * num_trjs
vx = [0.0] * num_trjs
vy = [0.0] * num_trjs
lh = [0.0] * num_trjs
flux = [0.0] * num_trjs
obs_count = [0] * num_trjs

# Extract the values from each Trajectory object.
for idx, trj in enumerate(trj_list):
x0[idx] = trj.x
y0[idx] = trj.y
vx[idx] = trj.vx
vy[idx] = trj.vy
lh[idx] = trj.lh
flux[idx] = trj.flux
obs_count[idx] = trj.obs_count

# Store the lists in a dictionary and return that.
trj_dict = {
"x": x0,
"y": y0,
"vx": vx,
"vy": vy,
"likelihood": lh,
"flux": flux,
"obs_count": obs_count,
}
return trj_dict


def trajectory_from_dict(trj_dict):
"""Create a trajectory from a dictionary of the parameters.
Expand Down Expand Up @@ -342,12 +389,22 @@ def match_trajectory_sets(traj_query, traj_base, threshold, times=[0.0]):
num_query = len(traj_query)
num_base = len(traj_base)

# Compute the matrix of distances between each pair. If this double FOR loop
# becomes a bottleneck, we can vectorize.
# Predict the x and y positions for the base trajectories at each time (using the vectorized functions).
base_info = trajectories_to_dict(traj_base)
base_px = predict_pixel_locations(times, base_info["x"], base_info["vx"], centered=False, as_int=False)
base_py = predict_pixel_locations(times, base_info["y"], base_info["vy"], centered=False, as_int=False)

# Compute the matrix of distances between each pair.
dists = np.zeros((num_query, num_base))
for q_idx in range(num_query):
for b_idx in range(num_base):
dists[q_idx][b_idx] = avg_trajectory_distance(traj_query[q_idx], traj_base[b_idx], times)
for q_idx, q_trj in enumerate(traj_query):
# Compute the query point locations at all times.
q_px = q_trj.x + times * q_trj.vx
q_py = q_trj.y + times * q_trj.vy

# Compute the average distance with each of the base predictions.
dx = q_px[np.newaxis, :] - base_px
dy = q_py[np.newaxis, :] - base_py
dists[q_idx, :] = np.mean(np.sqrt(dx**2 + dy**2), axis=1)

# Use scipy to solve the optimal bipartite matching problem.
row_inds, col_inds = linear_sum_assignment(dists)
Expand Down
16 changes: 16 additions & 0 deletions tests/test_trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,22 @@ def test_fit_trajectory_from_pixels(self):
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0, 2.0], [1.0])
self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0], [1.0, 2.0])

def test_trajectories_to_dict(self):
trj_list = [
Trajectory(x=0, y=1, vx=2.0, vy=3.0, lh=4.0, flux=5.0, obs_count=6),
Trajectory(x=10, y=11, vx=12.0, vy=13.0, lh=14.0, flux=15.0, obs_count=16),
Trajectory(x=20, y=21, vx=22.0, vy=23.0, lh=24.0, flux=25.0, obs_count=26),
]

trj_dict = trajectories_to_dict(trj_list)
self.assertTrue(np.array_equal(trj_dict["x"], [0, 10, 20]))
self.assertTrue(np.array_equal(trj_dict["y"], [1, 11, 21]))
self.assertTrue(np.array_equal(trj_dict["vx"], [2.0, 12.0, 22.0]))
self.assertTrue(np.array_equal(trj_dict["vy"], [3.0, 13.0, 23.0]))
self.assertTrue(np.array_equal(trj_dict["likelihood"], [4.0, 14.0, 24.0]))
self.assertTrue(np.array_equal(trj_dict["flux"], [5.0, 15.0, 25.0]))
self.assertTrue(np.array_equal(trj_dict["obs_count"], [6, 16, 26]))

def test_evaluate_trajectory_mse(self):
trj = Trajectory(x=5, y=4, vx=2.0, vy=-1.0)
x_vals = np.array([5.5, 7.5, 9.7, 11.5])
Expand Down

0 comments on commit 03395e1

Please sign in to comment.