From 6a8a8bf7a247af08ba8d5af99cbb320a04d64642 Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Wed, 18 Dec 2024 16:32:38 -0500 Subject: [PATCH] Vectorize match_trajectory_sets --- src/kbmod/trajectory_utils.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/kbmod/trajectory_utils.py b/src/kbmod/trajectory_utils.py index 7ead6f60..195e153b 100644 --- a/src/kbmod/trajectory_utils.py +++ b/src/kbmod/trajectory_utils.py @@ -389,12 +389,22 @@ def match_trajectory_sets(traj_query, traj_base, threshold, times=[0.0]): num_query = len(traj_query) num_base = len(traj_base) - # Compute the matrix of distances between each pair. If this double FOR loop - # becomes a bottleneck, we can vectorize. + # Predict the x and y positions for the base trajectories at each time (using the vectorized functions). + base_info = trajectories_to_dict(traj_base) + base_px = predict_pixel_locations(times, base_info["x"], base_info["vx"], centered=False, as_int=False) + base_py = predict_pixel_locations(times, base_info["y"], base_info["vy"], centered=False, as_int=False) + + # Compute the matrix of distances between each pair. dists = np.zeros((num_query, num_base)) - for q_idx in range(num_query): - for b_idx in range(num_base): - dists[q_idx][b_idx] = avg_trajectory_distance(traj_query[q_idx], traj_base[b_idx], times) + for q_idx, q_trj in enumerate(traj_query): + # Compute the query point locations at all times. + q_px = q_trj.x + times * q_trj.vx + q_py = q_trj.y + times * q_trj.vy + + # Compute the average distance with each of the base predictions. + dx = q_px[np.newaxis, :] - base_px + dy = q_py[np.newaxis, :] - base_py + dists[q_idx, :] = np.mean(np.sqrt(dx**2 + dy**2), axis=1) # Use scipy to solve the optimal bipartite matching problem. row_inds, col_inds = linear_sum_assignment(dists)