diff --git a/benchmarks/benchmarks.py b/benchmarks/benchmarks.py index 382f8f7a..2a23c223 100644 --- a/benchmarks/benchmarks.py +++ b/benchmarks/benchmarks.py @@ -5,6 +5,7 @@ import os import lsdb +from lsdb.core.crossmatch.kdtree_gnomonic_match import KdTreeGnomonicCrossmatch TEST_DIR = os.path.join(os.path.dirname(__file__), "..", "tests") DATA_DIR_NAME = "data" @@ -22,8 +23,15 @@ def load_small_sky_xmatch(): return lsdb.read_hipscat(path, catalog_type=lsdb.Catalog) -def time_crossmatch(): +def time_kdtree_crossmatch(): """Time computations are prefixed with 'time'.""" small_sky = load_small_sky() small_sky_xmatch = load_small_sky_xmatch() small_sky.crossmatch(small_sky_xmatch).compute() + + +def time_kdtree_gnomonic_crossmatch(): + """Time computations are prefixed with 'time'.""" + small_sky = load_small_sky() + small_sky_xmatch = load_small_sky_xmatch() + small_sky.crossmatch(small_sky_xmatch, algorithm=KdTreeGnomonicCrossmatch).compute() diff --git a/pyproject.toml b/pyproject.toml index f801e035..e69d3104 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,6 +21,7 @@ dependencies = [ "deprecated", "ipykernel", # Support for Jupyter notebooks "scikit-learn", + "scipy", # kdtree ] # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes) diff --git a/src/lsdb/core/crossmatch/kdtree_gnomonic_match.py b/src/lsdb/core/crossmatch/kdtree_gnomonic_match.py new file mode 100644 index 00000000..69e8497e --- /dev/null +++ b/src/lsdb/core/crossmatch/kdtree_gnomonic_match.py @@ -0,0 +1,162 @@ +import healpy as hp +import numpy as np +import pandas as pd +from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN +from sklearn.neighbors import KDTree + +from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm + + +class KdTreeGnomonicCrossmatch(AbstractCrossmatchAlgorithm): + """Nearest neighbor crossmatch using a K-D Tree""" + + def crossmatch( + self, + n_neighbors: int = 1, + d_thresh: float = 0.01, + ) -> pd.DataFrame: + """Perform a cross-match between the data from two HEALPix pixels + + Finds the n closest neighbors in the right catalog for each point in the left catalog that + are within a threshold distance by using a K-D Tree. + + Args: + n_neighbors (int): The number of neighbors to find within each point + d_thresh (float): The threshold distance in degrees beyond which neighbors are not added + + Returns: + A DataFrame from the left and right tables merged with one row for each pair of + neighbors found from cross-matching. The resulting table contains the columns from the + left table with the first suffix appended, the right columns with the second suffix, and + a column with the name {AbstractCrossmatchAlgorithm.DISTANCE_COLUMN_NAME} with the + great circle separation between the points. + """ + + # get matching indices for cross-matched rows + left_idx, right_idx = self._find_crossmatch_indices(n_neighbors) + + # filter indexes to only include rows with points within the distance threshold + ( + distances, + left_ids_filtered, + right_ids_filtered, + ) = self._filter_indexes_to_threshold(left_idx, right_idx, d_thresh) + + # rename columns so no same names during merging + self._rename_columns_with_suffix(self.left, self.suffixes[0]) + self._rename_columns_with_suffix(self.right, self.suffixes[1]) + + # concat dataframes together + self.left.index.name = HIPSCAT_ID_COLUMN + left_join_part = self.left.iloc[left_ids_filtered].reset_index() + right_join_part = self.right.iloc[right_ids_filtered].reset_index(drop=True) + out = pd.concat( + [ + left_join_part, + right_join_part, + ], + axis=1, + ) + out.set_index(HIPSCAT_ID_COLUMN, inplace=True) + out[self.DISTANCE_COLUMN_NAME] = distances + + return out + + def _find_crossmatch_indices(self, n_neighbors): + # calculate the gnomic distances to use with the KDTree + clon, clat = hp.pix2ang(hp.order2nside(self.left_order), self.left_pixel, nest=True, lonlat=True) + xy1 = _frame_gnomonic(self.left, self.left_metadata.catalog_info, clon, clat) + xy2 = _frame_gnomonic(self.right, self.right_metadata.catalog_info, clon, clat) + # construct the KDTree from the right catalog + tree = KDTree(xy2, leaf_size=2) + # find the indices for the nearest neighbors + # this is the cross-match calculation + n_neighbors = min(n_neighbors, len(xy2)) + _, inds = tree.query(xy1, k=n_neighbors) + # numpy indexing to join the two catalogs + # index of each row in the output table # (0... number of output rows) + out_idx = np.arange(len(self.left) * n_neighbors) + # index of the corresponding row in the left table (0, 0, 0, 1, 1, 1, 2, 2, 2, ...) + left_idx = out_idx // n_neighbors + # index of the corresponding row in the right table (22, 33, 44, 55, 66, ...) + right_idx = inds.ravel() + return left_idx, right_idx + + def _filter_indexes_to_threshold(self, left_idx, right_idx, d_thresh): + """ + Filters indexes to merge dataframes to the points separated by distances within the + threshold + + Returns: + A tuple of (distances, filtered_left_indices, filtered_right_indices) + """ + left_catalog_info = self.left_metadata.catalog_info + right_catalog_info = self.right_metadata.catalog_info + # align radec to indices + left_radec = self.left[[left_catalog_info.ra_column, left_catalog_info.dec_column]] + left_radec_aligned = left_radec.iloc[left_idx] + right_radec = self.right[[right_catalog_info.ra_column, right_catalog_info.dec_column]] + right_radec_aligned = right_radec.iloc[right_idx] + + # store the indices from each row + distances_df = pd.DataFrame.from_dict({"_left_idx": left_idx, "_right_idx": right_idx}) + + # calculate distances of each pair + distances_df[self.DISTANCE_COLUMN_NAME] = _great_circle_dist( + left_radec_aligned[left_catalog_info.ra_column].values, + left_radec_aligned[left_catalog_info.dec_column].values, + right_radec_aligned[right_catalog_info.ra_column].values, + right_radec_aligned[right_catalog_info.dec_column].values, + ) + # cull based on the distance threshold + distances_df = distances_df.loc[distances_df[self.DISTANCE_COLUMN_NAME] < d_thresh] + left_ids_filtered = distances_df["_left_idx"] + right_ids_filtered = distances_df["_right_idx"] + distances = distances_df[self.DISTANCE_COLUMN_NAME].to_numpy() + return distances, left_ids_filtered, right_ids_filtered + + +def _great_circle_dist(lon1, lat1, lon2, lat2): + """ + function that calculates the distance between two points + p1 (lon1, lat1) or (ra1, dec1) + p2 (lon2, lat2) or (ra2, dec2) + + can be np.array() + returns np.array() + """ + lon1 = np.radians(lon1) + lat1 = np.radians(lat1) + lon2 = np.radians(lon2) + lat2 = np.radians(lat2) + + return np.degrees( + 2 + * np.arcsin( + np.sqrt( + (np.sin((lat1 - lat2) * 0.5)) ** 2 + + np.cos(lat1) * np.cos(lat2) * (np.sin((lon1 - lon2) * 0.5)) ** 2 + ) + ) + ) + + +def _frame_gnomonic(data_frame, catalog_info, clon, clat): + """ + method taken from lsd1: + creates a np.array of gnomonic distances for each source in the dataframe + from the center of the ordered pixel. These values are passed into + the kdtree NN query during the xmach routine. + """ + phi = np.radians(data_frame[catalog_info.dec_column].values) + lam = np.radians(data_frame[catalog_info.ra_column].values) + phi1 = np.radians(clat) + lam0 = np.radians(clon) + + cosc = np.sin(phi1) * np.sin(phi) + np.cos(phi1) * np.cos(phi) * np.cos(lam - lam0) + x_projected = np.cos(phi) * np.sin(lam - lam0) / cosc + y_projected = (np.cos(phi1) * np.sin(phi) - np.sin(phi1) * np.cos(phi) * np.cos(lam - lam0)) / cosc + + ret = np.column_stack((np.degrees(x_projected), np.degrees(y_projected))) + del phi, lam, phi1, lam0, cosc, x_projected, y_projected + return ret diff --git a/src/lsdb/core/crossmatch/kdtree_match.py b/src/lsdb/core/crossmatch/kdtree_match.py index aef15e55..2afad980 100644 --- a/src/lsdb/core/crossmatch/kdtree_match.py +++ b/src/lsdb/core/crossmatch/kdtree_match.py @@ -1,14 +1,17 @@ -import healpy as hp +import math +from typing import Tuple + import numpy as np +import numpy.typing as npt import pandas as pd from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN -from sklearn.neighbors import KDTree +from scipy.spatial import KDTree from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm class KdTreeCrossmatch(AbstractCrossmatchAlgorithm): - """Nearest neighbor crossmatch using a K-D Tree""" + """Nearest neighbor crossmatch using a 3D k-D tree""" def crossmatch( self, @@ -28,19 +31,17 @@ def crossmatch( A DataFrame from the left and right tables merged with one row for each pair of neighbors found from cross-matching. The resulting table contains the columns from the left table with the first suffix appended, the right columns with the second suffix, and - a column with the name {AbstractCrossmatchAlgorithm.DISTANCE_COLUMN_NAME} with the - great circle separation between the points. + a column with the name {AbstractCrossmatchAlgorithm.DISTANCE_COLUMN_NAME} with the great + circle separation between the points. """ + # Distance in 3-D space for unit sphere + d_chord = 2.0 * math.sin(math.radians(0.5 * d_thresh)) # get matching indices for cross-matched rows - left_idx, right_idx = self._find_crossmatch_indices(n_neighbors) - - # filter indexes to only include rows with points within the distance threshold - ( - distances, - left_ids_filtered, - right_ids_filtered, - ) = self._filter_indexes_to_threshold(left_idx, right_idx, d_thresh) + chord_distances, left_idx, right_idx = self._find_crossmatch_indices( + n_neighbors=n_neighbors, max_distance=d_chord + ) + arc_distances = np.degrees(2.0 * np.arcsin(0.5 * chord_distances)) # rename columns so no same names during merging self._rename_columns_with_suffix(self.left, self.suffixes[0]) @@ -48,8 +49,8 @@ def crossmatch( # concat dataframes together self.left.index.name = HIPSCAT_ID_COLUMN - left_join_part = self.left.iloc[left_ids_filtered].reset_index() - right_join_part = self.right.iloc[right_ids_filtered].reset_index(drop=True) + left_join_part = self.left.iloc[left_idx].reset_index() + right_join_part = self.right.iloc[right_idx].reset_index(drop=True) out = pd.concat( [ left_join_part, @@ -58,104 +59,62 @@ def crossmatch( axis=1, ) out.set_index(HIPSCAT_ID_COLUMN, inplace=True) - out["_DIST"] = distances + out[self.DISTANCE_COLUMN_NAME] = pd.Series(arc_distances, index=out.index) return out - def _find_crossmatch_indices(self, n_neighbors): - # calculate the gnomic distances to use with the KDTree - clon, clat = hp.pix2ang(hp.order2nside(self.left_order), self.left_pixel, nest=True, lonlat=True) - xy1 = _frame_gnomonic(self.left, self.left_metadata.catalog_info, clon, clat) - xy2 = _frame_gnomonic(self.right, self.right_metadata.catalog_info, clon, clat) + def _find_crossmatch_indices( + self, n_neighbors: int, max_distance: float + ) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.int64], npt.NDArray[np.int64]]: + # calculate the cartesian coordinates of the points + left_xyz = _lon_lat_to_xyz( + lon=self.left[self.left_metadata.catalog_info.ra_column].values, + lat=self.left[self.left_metadata.catalog_info.dec_column].values, + ) + right_xyz = _lon_lat_to_xyz( + lon=self.right[self.right_metadata.catalog_info.ra_column].values, + lat=self.right[self.right_metadata.catalog_info.dec_column].values, + ) + + # Make sure we don't ask for more neighbors than there are points + n_neighbors = min(n_neighbors, len(right_xyz)) + # construct the KDTree from the right catalog - tree = KDTree(xy2, leaf_size=2) + tree = KDTree( + right_xyz, + leafsize=n_neighbors, + compact_nodes=True, + balanced_tree=True, + copy_data=False, + ) + # find the indices for the nearest neighbors # this is the cross-match calculation - _, inds = tree.query(xy1, k=min([n_neighbors, len(xy2)])) - # numpy indexing to join the two catalogs - # index of each row in the output table # (0... number of output rows) - out_idx = np.arange(len(self.left) * n_neighbors) - # index of the corresponding row in the left table (0, 0, 0, 1, 1, 1, 2, 2, 2, ...) - left_idx = out_idx // n_neighbors - # index of the corresponding row in the right table (22, 33, 44, 55, 66, ...) - right_idx = inds.ravel() - return left_idx, right_idx - - def _filter_indexes_to_threshold(self, left_idx, right_idx, d_thresh): - """ - Filters indexes to merge dataframes to the points separated by distances within the - threshold + distances, right_index = tree.query(left_xyz, k=n_neighbors, distance_upper_bound=max_distance) - Returns: - A tuple of (distances, filtered_left_indices, filtered_right_indices) - """ - left_catalog_info = self.left_metadata.catalog_info - right_catalog_info = self.right_metadata.catalog_info - # align radec to indices - left_radec = self.left[[left_catalog_info.ra_column, left_catalog_info.dec_column]] - left_radec_aligned = left_radec.iloc[left_idx] - right_radec = self.right[[right_catalog_info.ra_column, right_catalog_info.dec_column]] - right_radec_aligned = right_radec.iloc[right_idx] - - # store the indices from each row - distances_df = pd.DataFrame.from_dict({"_left_idx": left_idx, "_right_idx": right_idx}) - - # calculate distances of each pair - distances_df[self.DISTANCE_COLUMN_NAME] = _great_circle_dist( - left_radec_aligned[left_catalog_info.ra_column].values, - left_radec_aligned[left_catalog_info.dec_column].values, - right_radec_aligned[right_catalog_info.ra_column].values, - right_radec_aligned[right_catalog_info.dec_column].values, - ) - # cull based on the distance threshold - distances_df = distances_df.loc[distances_df[self.DISTANCE_COLUMN_NAME] < d_thresh] - left_ids_filtered = distances_df["_left_idx"] - right_ids_filtered = distances_df["_right_idx"] - distances = distances_df[self.DISTANCE_COLUMN_NAME].to_numpy() - return distances, left_ids_filtered, right_ids_filtered + # index of the corresponding row in the left table [[0, 0, 0], [1, 1, 1], [2, 2, 2], ...] + left_index = np.arange(left_xyz.shape[0]) + # We need make the shape the same as for right_index + if n_neighbors > 1: + left_index = np.stack([left_index] * n_neighbors, axis=1) + # Infinite distance means no match + match_mask = np.isfinite(distances) + return distances[match_mask], left_index[match_mask], right_index[match_mask] -def _great_circle_dist(lon1, lat1, lon2, lat2): - """ - function that calculates the distance between two points - p1 (lon1, lat1) or (ra1, dec1) - p2 (lon2, lat2) or (ra2, dec2) - can be np.array() - returns np.array() +def _lon_lat_to_xyz(lon: npt.NDArray[np.float64], lat: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]: + """Converts longitude and latitude to cartesian coordinates on the unit sphere + + Args: + lon (np.ndarray[np.float64]): longitude in radians + lat (np.ndarray[np.float64]): latitude in radians """ - lon1 = np.radians(lon1) - lat1 = np.radians(lat1) - lon2 = np.radians(lon2) - lat2 = np.radians(lat2) - - return np.degrees( - 2 - * np.arcsin( - np.sqrt( - (np.sin((lat1 - lat2) * 0.5)) ** 2 - + np.cos(lat1) * np.cos(lat2) * (np.sin((lon1 - lon2) * 0.5)) ** 2 - ) - ) - ) + lon = np.radians(lon) + lat = np.radians(lat) + x = np.cos(lat) * np.cos(lon) + y = np.cos(lat) * np.sin(lon) + z = np.sin(lat) -def _frame_gnomonic(data_frame, catalog_info, clon, clat): - """ - method taken from lsd1: - creates a np.array of gnomonic distances for each source in the dataframe - from the center of the ordered pixel. These values are passed into - the kdtree NN query during the xmach routine. - """ - phi = np.radians(data_frame[catalog_info.dec_column].values) - lam = np.radians(data_frame[catalog_info.ra_column].values) - phi1 = np.radians(clat) - lam0 = np.radians(clon) - - cosc = np.sin(phi1) * np.sin(phi) + np.cos(phi1) * np.cos(phi) * np.cos(lam - lam0) - x_projected = np.cos(phi) * np.sin(lam - lam0) / cosc - y_projected = (np.cos(phi1) * np.sin(phi) - np.sin(phi1) * np.cos(phi) * np.cos(lam - lam0)) / cosc - - ret = np.column_stack((np.degrees(x_projected), np.degrees(y_projected))) - del phi, lam, phi1, lam0, cosc, x_projected, y_projected - return ret + return np.stack([x, y, z], axis=1) diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index 0f848a98..68beab63 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -3,41 +3,55 @@ from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm +from lsdb.core.crossmatch.kdtree_gnomonic_match import KdTreeGnomonicCrossmatch +from lsdb.core.crossmatch.kdtree_match import KdTreeCrossmatch -def test_kdtree_crossmatch(small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct): - xmatched = small_sky_catalog.crossmatch(small_sky_xmatch_catalog).compute() - assert len(xmatched) == len(xmatch_correct) - for _, correct_row in xmatch_correct.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values - xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] - assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] - assert xmatch_row["_DIST"].values == pytest.approx(correct_row["dist"]) +@pytest.mark.parametrize("algo", [KdTreeCrossmatch, KdTreeGnomonicCrossmatch]) +class TestCrossmatch: + @staticmethod + def test_kdtree_crossmatch(algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct): + xmatched = small_sky_catalog.crossmatch(small_sky_xmatch_catalog, algorithm=algo).compute() + assert len(xmatched) == len(xmatch_correct) + for _, correct_row in xmatch_correct.iterrows(): + assert correct_row["ss_id"] in xmatched["id_small_sky"].values + xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] + assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] + assert xmatch_row["_DIST"].values == pytest.approx(correct_row["dist"]) + @staticmethod + def test_kdtree_crossmatch_thresh(algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct_005): + xmatched = small_sky_catalog.crossmatch( + small_sky_xmatch_catalog, d_thresh=0.005, algorithm=algo + ).compute() + assert len(xmatched) == len(xmatch_correct_005) + for _, correct_row in xmatch_correct_005.iterrows(): + assert correct_row["ss_id"] in xmatched["id_small_sky"].values + xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] + assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] + assert xmatch_row["_DIST"].values == pytest.approx(correct_row["dist"]) -def test_kdtree_crossmatch_thresh(small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct_005): - xmatched = small_sky_catalog.crossmatch(small_sky_xmatch_catalog, d_thresh=0.005).compute() - assert len(xmatched) == len(xmatch_correct_005) - for _, correct_row in xmatch_correct_005.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values - xmatch_row = xmatched[xmatched["id_small_sky"] == correct_row["ss_id"]] - assert xmatch_row["id_small_sky_xmatch"].values == correct_row["xmatch_id"] - assert xmatch_row["_DIST"].values == pytest.approx(correct_row["dist"]) - + @staticmethod + def test_kdtree_crossmatch_multiple_neighbors( + algo, small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct_3n_2t_no_margin + ): + xmatched = small_sky_catalog.crossmatch( + small_sky_xmatch_catalog, n_neighbors=3, d_thresh=2, algorithm=algo + ).compute() + assert len(xmatched) == len(xmatch_correct_3n_2t_no_margin) + for _, correct_row in xmatch_correct_3n_2t_no_margin.iterrows(): + assert correct_row["ss_id"] in xmatched["id_small_sky"].values + xmatch_row = xmatched[ + (xmatched["id_small_sky"] == correct_row["ss_id"]) + & (xmatched["id_small_sky_xmatch"] == correct_row["xmatch_id"]) + ] + assert len(xmatch_row) == 1 + assert xmatch_row["_DIST"].values == pytest.approx(correct_row["dist"]) -def test_kdtree_crossmatch_multiple_neighbors( - small_sky_catalog, small_sky_xmatch_catalog, xmatch_correct_3n_2t_no_margin -): - xmatched = small_sky_catalog.crossmatch(small_sky_xmatch_catalog, n_neighbors=3, d_thresh=2).compute() - assert len(xmatched) == len(xmatch_correct_3n_2t_no_margin) - for _, correct_row in xmatch_correct_3n_2t_no_margin.iterrows(): - assert correct_row["ss_id"] in xmatched["id_small_sky"].values - xmatch_row = xmatched[ - (xmatched["id_small_sky"] == correct_row["ss_id"]) - & (xmatched["id_small_sky_xmatch"] == correct_row["xmatch_id"]) - ] - assert len(xmatch_row) == 1 - assert xmatch_row["_DIST"].values == pytest.approx(correct_row["dist"]) + @staticmethod + def test_wrong_suffixes(algo, small_sky_catalog, small_sky_xmatch_catalog): + with pytest.raises(ValueError): + small_sky_catalog.crossmatch(small_sky_xmatch_catalog, suffixes=("wrong",), algorithm=algo) def test_custom_crossmatch_algorithm(small_sky_catalog, small_sky_xmatch_catalog, xmatch_mock): @@ -52,11 +66,6 @@ def test_custom_crossmatch_algorithm(small_sky_catalog, small_sky_xmatch_catalog assert xmatch_row["_DIST"].values == pytest.approx(correct_row["dist"]) -def test_wrong_suffixes(small_sky_catalog, small_sky_xmatch_catalog): - with pytest.raises(ValueError): - small_sky_catalog.crossmatch(small_sky_xmatch_catalog, suffixes=("wrong",)) - - # pylint: disable=too-few-public-methods class MockCrossmatchAlgorithm(AbstractCrossmatchAlgorithm): """Mock class used to test a crossmatch algorithm"""