diff --git a/src/kbmod/results.py b/src/kbmod/results.py index dd13df1f..1ac11d45 100644 --- a/src/kbmod/results.py +++ b/src/kbmod/results.py @@ -11,7 +11,7 @@ from astropy.table import Table, vstack -from kbmod.trajectory_utils import trajectory_from_np_object +from kbmod.trajectory_utils import trajectory_from_np_object, trajectories_to_dict from kbmod.search import Trajectory from kbmod.wcs_utils import deserialize_wcs, serialize_wcs @@ -148,20 +148,8 @@ def from_trajectories(cls, trajectories, track_filtered=False): track_filtered : `bool` Indicates whether to track future filtered points. """ - # Create dictionaries for the required columns. - input_d = {} - for col in cls.required_cols: - input_d[col[0]] = [] - - # Add the trajectories to the table. - for trj in trajectories: - input_d["x"].append(trj.x) - input_d["y"].append(trj.y) - input_d["vx"].append(trj.vx) - input_d["vy"].append(trj.vy) - input_d["likelihood"].append(trj.lh) - input_d["flux"].append(trj.flux) - input_d["obs_count"].append(trj.obs_count) + # Create dictionaries from the Trajectories. + input_d = trajectories_to_dict(trajectories) # Check for any missing columns and fill in the default value. for col in cls.required_cols: diff --git a/src/kbmod/trajectory_utils.py b/src/kbmod/trajectory_utils.py index 3b0265fc..195e153b 100644 --- a/src/kbmod/trajectory_utils.py +++ b/src/kbmod/trajectory_utils.py @@ -159,6 +159,53 @@ def trajectory_from_np_object(result): return trj +def trajectories_to_dict(trj_list): + """Create a dictionary of trajectory related information + from a list of Trajectory objects. + + Parameters + ---------- + trj_list : `list` + The list of Trajectory objects. + + Returns + ------- + trj_dict : `Trajectory` + The corresponding trajectory object. + """ + # Create the lists to fill. + num_trjs = len(trj_list) + x0 = [0] * num_trjs + y0 = [0] * num_trjs + vx = [0.0] * num_trjs + vy = [0.0] * num_trjs + lh = [0.0] * num_trjs + flux = [0.0] * num_trjs + obs_count = [0] * num_trjs + + # Extract the values from each Trajectory object. + for idx, trj in enumerate(trj_list): + x0[idx] = trj.x + y0[idx] = trj.y + vx[idx] = trj.vx + vy[idx] = trj.vy + lh[idx] = trj.lh + flux[idx] = trj.flux + obs_count[idx] = trj.obs_count + + # Store the lists in a dictionary and return that. + trj_dict = { + "x": x0, + "y": y0, + "vx": vx, + "vy": vy, + "likelihood": lh, + "flux": flux, + "obs_count": obs_count, + } + return trj_dict + + def trajectory_from_dict(trj_dict): """Create a trajectory from a dictionary of the parameters. @@ -342,12 +389,22 @@ def match_trajectory_sets(traj_query, traj_base, threshold, times=[0.0]): num_query = len(traj_query) num_base = len(traj_base) - # Compute the matrix of distances between each pair. If this double FOR loop - # becomes a bottleneck, we can vectorize. + # Predict the x and y positions for the base trajectories at each time (using the vectorized functions). + base_info = trajectories_to_dict(traj_base) + base_px = predict_pixel_locations(times, base_info["x"], base_info["vx"], centered=False, as_int=False) + base_py = predict_pixel_locations(times, base_info["y"], base_info["vy"], centered=False, as_int=False) + + # Compute the matrix of distances between each pair. dists = np.zeros((num_query, num_base)) - for q_idx in range(num_query): - for b_idx in range(num_base): - dists[q_idx][b_idx] = avg_trajectory_distance(traj_query[q_idx], traj_base[b_idx], times) + for q_idx, q_trj in enumerate(traj_query): + # Compute the query point locations at all times. + q_px = q_trj.x + times * q_trj.vx + q_py = q_trj.y + times * q_trj.vy + + # Compute the average distance with each of the base predictions. + dx = q_px[np.newaxis, :] - base_px + dy = q_py[np.newaxis, :] - base_py + dists[q_idx, :] = np.mean(np.sqrt(dx**2 + dy**2), axis=1) # Use scipy to solve the optimal bipartite matching problem. row_inds, col_inds = linear_sum_assignment(dists) diff --git a/tests/test_trajectory_utils.py b/tests/test_trajectory_utils.py index 88410e2c..9a1d6f91 100644 --- a/tests/test_trajectory_utils.py +++ b/tests/test_trajectory_utils.py @@ -139,6 +139,22 @@ def test_fit_trajectory_from_pixels(self): self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0, 2.0], [1.0]) self.assertRaises(ValueError, fit_trajectory_from_pixels, [1.0, 2.0], [1.0], [1.0, 2.0]) + def test_trajectories_to_dict(self): + trj_list = [ + Trajectory(x=0, y=1, vx=2.0, vy=3.0, lh=4.0, flux=5.0, obs_count=6), + Trajectory(x=10, y=11, vx=12.0, vy=13.0, lh=14.0, flux=15.0, obs_count=16), + Trajectory(x=20, y=21, vx=22.0, vy=23.0, lh=24.0, flux=25.0, obs_count=26), + ] + + trj_dict = trajectories_to_dict(trj_list) + self.assertTrue(np.array_equal(trj_dict["x"], [0, 10, 20])) + self.assertTrue(np.array_equal(trj_dict["y"], [1, 11, 21])) + self.assertTrue(np.array_equal(trj_dict["vx"], [2.0, 12.0, 22.0])) + self.assertTrue(np.array_equal(trj_dict["vy"], [3.0, 13.0, 23.0])) + self.assertTrue(np.array_equal(trj_dict["likelihood"], [4.0, 14.0, 24.0])) + self.assertTrue(np.array_equal(trj_dict["flux"], [5.0, 15.0, 25.0])) + self.assertTrue(np.array_equal(trj_dict["obs_count"], [6, 16, 26])) + def test_evaluate_trajectory_mse(self): trj = Trajectory(x=5, y=4, vx=2.0, vy=-1.0) x_vals = np.array([5.5, 7.5, 9.7, 11.5])