Skip to content

Commit

Permalink
Vectorize match_trajectory_sets
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremykubica committed Dec 18, 2024
1 parent 38adedf commit 6a8a8bf
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/kbmod/trajectory_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,12 +389,22 @@ def match_trajectory_sets(traj_query, traj_base, threshold, times=[0.0]):
num_query = len(traj_query)
num_base = len(traj_base)

# Compute the matrix of distances between each pair. If this double FOR loop
# becomes a bottleneck, we can vectorize.
# Predict the x and y positions for the base trajectories at each time (using the vectorized functions).
base_info = trajectories_to_dict(traj_base)
base_px = predict_pixel_locations(times, base_info["x"], base_info["vx"], centered=False, as_int=False)
base_py = predict_pixel_locations(times, base_info["y"], base_info["vy"], centered=False, as_int=False)

# Compute the matrix of distances between each pair.
dists = np.zeros((num_query, num_base))
for q_idx in range(num_query):
for b_idx in range(num_base):
dists[q_idx][b_idx] = avg_trajectory_distance(traj_query[q_idx], traj_base[b_idx], times)
for q_idx, q_trj in enumerate(traj_query):
# Compute the query point locations at all times.
q_px = q_trj.x + times * q_trj.vx
q_py = q_trj.y + times * q_trj.vy

# Compute the average distance with each of the base predictions.
dx = q_px[np.newaxis, :] - base_px
dy = q_py[np.newaxis, :] - base_py
dists[q_idx, :] = np.mean(np.sqrt(dx**2 + dy**2), axis=1)

# Use scipy to solve the optimal bipartite matching problem.
row_inds, col_inds = linear_sum_assignment(dists)
Expand Down

0 comments on commit 6a8a8bf

Please sign in to comment.