From f630822c238c979bed24d8abac56f06a77f925b9 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Sat, 11 Nov 2023 17:06:00 -0700 Subject: [PATCH 1/7] add catalog joining --- src/lsdb/catalog/catalog.py | 27 +++++++++++ src/lsdb/dask/join_catalog_data.py | 78 ++++++++++++++++++++++++++++++ 2 files changed, 105 insertions(+) create mode 100644 src/lsdb/dask/join_catalog_data.py diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index 2aa86be5..57f2f6cd 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -13,6 +13,9 @@ from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data +from lsdb.dask.join_catalog_data import join_catalog_data_on + +DaskDFPixelMap = Dict[HealpixPixel, int] from lsdb.types import DaskDFPixelMap @@ -317,3 +320,27 @@ def to_hipscat( **kwargs: Arguments to pass to the parquet write operations """ io.to_hipscat(self, base_catalog_path, catalog_name, storage_options, **kwargs) + + def join( + self, + other: Catalog, + left_on: str = None, + right_on: str = None, + suffixes: Tuple[str, str] | None = None, + output_catalog_name: str | None = None + ) -> Catalog: + if suffixes is None: + suffixes = ("", "") + + ddf, ddf_map, alignment = join_catalog_data_on(self, other, left_on, right_on, suffixes=suffixes) + + if output_catalog_name is None: + output_catalog_name = self.hc_structure.catalog_info.catalog_name + new_catalog_info = dataclasses.replace( + self.hc_structure.catalog_info, + catalog_name=output_catalog_name, + ra_column=self.hc_structure.catalog_info.ra_column + suffixes[0], + dec_column=self.hc_structure.catalog_info.dec_column + suffixes[0], + ) + hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree) + return Catalog(ddf, ddf_map, hc_catalog) diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py new file mode 100644 index 00000000..8254d914 --- /dev/null +++ b/src/lsdb/dask/join_catalog_data.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Tuple, cast + +import dask +import dask.dataframe as dd +import pandas as pd +from hipscat.pixel_math import HealpixPixel +from hipscat.pixel_tree import PixelAlignmentType, PixelAlignment, align_trees + +if TYPE_CHECKING: + from lsdb.catalog.catalog import Catalog, DaskDFPixelMap + + +def align_catalog_to_partitions( + catalog: Catalog, + pixels: pd.DataFrame, + order_col: str = "Norder", + pixel_col: str = "Npix" +) -> dd.core.DataFrame: + dfs = catalog._ddf.to_delayed() + partitions = pixels.apply(lambda row: dfs[ + catalog.get_partition_index(row[order_col], row[pixel_col])], axis=1) + partitions_list = partitions.to_list() + return partitions_list + +@dask.delayed +def perform_join_on(left: pd.DataFrame, right: pd.DataFrame, left_on: str, right_on: str, suffixes: Tuple[str, str]): + left_columns_renamed = {name: name + suffixes[0] for name in left.columns} + left = left.rename(columns=left_columns_renamed) + right_columns_renamed = {name: name + suffixes[1] for name in right.columns} + right = right.rename(columns=right_columns_renamed) + merged = left.reset_index().merge(right, left_on=left_on + suffixes[0], right_on=right_on + suffixes[1]) + merged.set_index("_hipscat_index", inplace=True) + return merged + + +def join_catalog_data_on( + left: Catalog, + right: Catalog, + left_on: str = None, + right_on: str = None, + suffixes: Tuple[str, str] | None = None +) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]: + alignment = align_trees( + left.hc_structure.pixel_tree, + right.hc_structure.pixel_tree, + alignment_type=PixelAlignmentType.INNER + ) + join_pixels = alignment.pixel_mapping + left_aligned_to_join_partitions = align_catalog_to_partitions( + left, + join_pixels, + order_col=PixelAlignment.PRIMARY_ORDER_COLUMN_NAME, + pixel_col=PixelAlignment.PRIMARY_PIXEL_COLUMN_NAME, + ) + right_aligned_to_join_partitions = align_catalog_to_partitions( + right, + join_pixels, + order_col=PixelAlignment.JOIN_ORDER_COLUMN_NAME, + pixel_col=PixelAlignment.JOIN_PIXEL_COLUMN_NAME, + ) + joined_partitions = [perform_join_on(left_df, right_df, left_on, right_on, suffixes) for left_df, right_df in zip(left_aligned_to_join_partitions, right_aligned_to_join_partitions)] + partition_map = {} + for i, (_, row) in enumerate(join_pixels.iterrows()): + pixel = HealpixPixel(order=row[PixelAlignment.ALIGNED_ORDER_COLUMN_NAME], + pixel=row[PixelAlignment.ALIGNED_PIXEL_COLUMN_NAME]) + partition_map[pixel] = i + meta = {} + for name, t in left._ddf.dtypes.items(): + meta[name + suffixes[0]] = pd.Series(dtype=t) + for name, t in right._ddf.dtypes.items(): + meta[name + suffixes[1]] = pd.Series(dtype=t) + meta_df = pd.DataFrame(meta) + meta_df.index.name = "_hipscat_index" + ddf = dd.from_delayed(joined_partitions, meta=meta_df) + ddf = cast(dd.DataFrame, ddf) + return ddf, partition_map, alignment From 6fcd84ef0a7a29e4146f15906a876091009e8b61 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Thu, 16 Nov 2023 16:19:57 -0500 Subject: [PATCH 2/7] wip --- src/lsdb/dask/join_catalog_data.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py index 8254d914..12bc3748 100644 --- a/src/lsdb/dask/join_catalog_data.py +++ b/src/lsdb/dask/join_catalog_data.py @@ -6,26 +6,21 @@ import dask.dataframe as dd import pandas as pd from hipscat.pixel_math import HealpixPixel +from hipscat.pixel_math.hipscat_id import healpix_to_hipscat_id, HIPSCAT_ID_COLUMN from hipscat.pixel_tree import PixelAlignmentType, PixelAlignment, align_trees +from lsdb.dask.crossmatch_catalog_data import align_catalog_to_partitions + if TYPE_CHECKING: from lsdb.catalog.catalog import Catalog, DaskDFPixelMap -def align_catalog_to_partitions( - catalog: Catalog, - pixels: pd.DataFrame, - order_col: str = "Norder", - pixel_col: str = "Npix" -) -> dd.core.DataFrame: - dfs = catalog._ddf.to_delayed() - partitions = pixels.apply(lambda row: dfs[ - catalog.get_partition_index(row[order_col], row[pixel_col])], axis=1) - partitions_list = partitions.to_list() - return partitions_list - @dask.delayed -def perform_join_on(left: pd.DataFrame, right: pd.DataFrame, left_on: str, right_on: str, suffixes: Tuple[str, str]): +def perform_join_on(left: pd.DataFrame, right: pd.DataFrame, left_on: str, right_on: str, left_pixel: HealpixPixel, right_pixel: HealpixPixel, suffixes: Tuple[str, str]): + if right_pixel.order > left_pixel.order: + lower_bound = healpix_to_hipscat_id(right_pixel.order, right_pixel.pixel) + upper_bound = healpix_to_hipscat_id(right_pixel.order, right_pixel.pixel + 1) + left = left[(left.index >= lower_bound) & (left.index < upper_bound)] left_columns_renamed = {name: name + suffixes[0] for name in left.columns} left = left.rename(columns=left_columns_renamed) right_columns_renamed = {name: name + suffixes[1] for name in right.columns} @@ -72,7 +67,7 @@ def join_catalog_data_on( for name, t in right._ddf.dtypes.items(): meta[name + suffixes[1]] = pd.Series(dtype=t) meta_df = pd.DataFrame(meta) - meta_df.index.name = "_hipscat_index" + meta_df.index.name = HIPSCAT_ID_COLUMN ddf = dd.from_delayed(joined_partitions, meta=meta_df) ddf = cast(dd.DataFrame, ddf) return ddf, partition_map, alignment From dc906a677ecc18f452a4189b469c57051c2d8c3d Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Mon, 20 Nov 2023 18:19:18 -0500 Subject: [PATCH 3/7] refactor join and crossmatch shared code --- src/lsdb/dask/crossmatch_catalog_data.py | 63 +++++++++++++++++------- src/lsdb/dask/join_catalog_data.py | 57 ++++++++++++++------- 2 files changed, 82 insertions(+), 38 deletions(-) diff --git a/src/lsdb/dask/crossmatch_catalog_data.py b/src/lsdb/dask/crossmatch_catalog_data.py index fbf4b069..b491ca91 100644 --- a/src/lsdb/dask/crossmatch_catalog_data.py +++ b/src/lsdb/dask/crossmatch_catalog_data.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Tuple, Type, cast +from typing import TYPE_CHECKING, List, Tuple, Type, cast, Dict, Sequence import dask import dask.dataframe as dd @@ -43,9 +43,7 @@ def perform_crossmatch( the result. """ if right_order > left_order: - lower_bound = healpix_to_hipscat_id(right_order, right_pixel) - upper_bound = healpix_to_hipscat_id(right_order, right_pixel + 1) - left_df = left_df[(left_df.index >= lower_bound) & (left_df.index < upper_bound)] + left_df = filter_by_hipscat_index_to_pixel(left_df, right_order, right_pixel) return algorithm( left_df, right_df, @@ -59,6 +57,13 @@ def perform_crossmatch( ).crossmatch(**kwargs) +def filter_by_hipscat_index_to_pixel(dataframe: pd.DataFrame, order: int, pixel: int) -> pd.DataFrame: + lower_bound = healpix_to_hipscat_id(order, pixel) + upper_bound = healpix_to_hipscat_id(order, pixel + 1) + filtered_df = dataframe[(dataframe.index >= lower_bound) & (dataframe.index < upper_bound)] + return filtered_df + + # pylint: disable=too-many-locals def crossmatch_catalog_data( left: Catalog, @@ -141,23 +146,13 @@ def crossmatch_catalog_data( ] # generate dask df partition map from alignment - partition_map = {} - for i, (_, row) in enumerate(join_pixels.iterrows()): - pixel = HealpixPixel( - order=row[PixelAlignment.ALIGNED_ORDER_COLUMN_NAME], - pixel=row[PixelAlignment.ALIGNED_PIXEL_COLUMN_NAME], - ) - partition_map[pixel] = i + partition_map = get_partition_map_from_alignment_pixels(join_pixels) # generate meta table structure for dask df - meta = {} - for name, col_type in left.dtypes.items(): - meta[name + suffixes[0]] = pd.Series(dtype=col_type) - for name, col_type in right.dtypes.items(): - meta[name + suffixes[1]] = pd.Series(dtype=col_type) - meta[crossmatch_algorithm.DISTANCE_COLUMN_NAME] = pd.Series(dtype=np.dtype("float64")) - meta_df = pd.DataFrame(meta) - meta_df.index.name = HIPSCAT_ID_COLUMN + extra_columns = { + crossmatch_algorithm.DISTANCE_COLUMN_NAME: pd.Series(dtype=np.dtype("float64")) + } + meta_df = generate_meta_df_for_joined_tables([left, right], suffixes, extra_columns=extra_columns) # create dask df from delayed partitions ddf = dd.from_delayed(joined_partitions, meta=meta_df) @@ -166,6 +161,36 @@ def crossmatch_catalog_data( return ddf, partition_map, alignment +def generate_meta_df_for_joined_tables( + tables: Sequence[Catalog], + suffixes: Sequence[str], + extra_columns: Dict[str, pd.Series] | None = None, + index_name: str = HIPSCAT_ID_COLUMN, +): + if len(tables) != len(suffixes): + raise ValueError("tables and suffixes must have the same length") + meta = {} + for table, suffix in zip(tables, suffixes): + for name, col_type in table.dtypes.items(): + meta[name + suffix] = pd.Series(dtype=col_type) + if extra_columns is not None: + meta.update(extra_columns) + meta_df = pd.DataFrame(meta) + meta_df.index.name = index_name + return meta_df + + +def get_partition_map_from_alignment_pixels(join_pixels): + partition_map = {} + for i, (_, row) in enumerate(join_pixels.iterrows()): + pixel = HealpixPixel( + order=row[PixelAlignment.ALIGNED_ORDER_COLUMN_NAME], + pixel=row[PixelAlignment.ALIGNED_PIXEL_COLUMN_NAME], + ) + partition_map[pixel] = i + return partition_map + + def get_crossmatch_algorithm( algorithm: Type[AbstractCrossmatchAlgorithm] | BuiltInCrossmatchAlgorithm, ) -> Type[AbstractCrossmatchAlgorithm]: diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py index 12bc3748..b256f34b 100644 --- a/src/lsdb/dask/join_catalog_data.py +++ b/src/lsdb/dask/join_catalog_data.py @@ -9,24 +9,31 @@ from hipscat.pixel_math.hipscat_id import healpix_to_hipscat_id, HIPSCAT_ID_COLUMN from hipscat.pixel_tree import PixelAlignmentType, PixelAlignment, align_trees -from lsdb.dask.crossmatch_catalog_data import align_catalog_to_partitions +from lsdb.dask.crossmatch_catalog_data import align_catalog_to_partitions, filter_by_hipscat_index_to_pixel, \ + get_partition_map_from_alignment_pixels, generate_meta_df_for_joined_tables if TYPE_CHECKING: from lsdb.catalog.catalog import Catalog, DaskDFPixelMap @dask.delayed -def perform_join_on(left: pd.DataFrame, right: pd.DataFrame, left_on: str, right_on: str, left_pixel: HealpixPixel, right_pixel: HealpixPixel, suffixes: Tuple[str, str]): +def perform_join_on( + left: pd.DataFrame, + right: pd.DataFrame, + left_on: str, + right_on: str, + left_pixel: HealpixPixel, + right_pixel: HealpixPixel, + suffixes: Tuple[str, str] +): if right_pixel.order > left_pixel.order: - lower_bound = healpix_to_hipscat_id(right_pixel.order, right_pixel.pixel) - upper_bound = healpix_to_hipscat_id(right_pixel.order, right_pixel.pixel + 1) - left = left[(left.index >= lower_bound) & (left.index < upper_bound)] + left = filter_by_hipscat_index_to_pixel(left, right_pixel.order, right_pixel.pixel) left_columns_renamed = {name: name + suffixes[0] for name in left.columns} left = left.rename(columns=left_columns_renamed) right_columns_renamed = {name: name + suffixes[1] for name in right.columns} right = right.rename(columns=right_columns_renamed) merged = left.reset_index().merge(right, left_on=left_on + suffixes[0], right_on=right_on + suffixes[1]) - merged.set_index("_hipscat_index", inplace=True) + merged.set_index(HIPSCAT_ID_COLUMN, inplace=True) return merged @@ -55,19 +62,31 @@ def join_catalog_data_on( order_col=PixelAlignment.JOIN_ORDER_COLUMN_NAME, pixel_col=PixelAlignment.JOIN_PIXEL_COLUMN_NAME, ) - joined_partitions = [perform_join_on(left_df, right_df, left_on, right_on, suffixes) for left_df, right_df in zip(left_aligned_to_join_partitions, right_aligned_to_join_partitions)] - partition_map = {} - for i, (_, row) in enumerate(join_pixels.iterrows()): - pixel = HealpixPixel(order=row[PixelAlignment.ALIGNED_ORDER_COLUMN_NAME], - pixel=row[PixelAlignment.ALIGNED_PIXEL_COLUMN_NAME]) - partition_map[pixel] = i - meta = {} - for name, t in left._ddf.dtypes.items(): - meta[name + suffixes[0]] = pd.Series(dtype=t) - for name, t in right._ddf.dtypes.items(): - meta[name + suffixes[1]] = pd.Series(dtype=t) - meta_df = pd.DataFrame(meta) - meta_df.index.name = HIPSCAT_ID_COLUMN + + left_pixels = [ + HealpixPixel( + row[PixelAlignment.PRIMARY_ORDER_COLUMN_NAME], + row[PixelAlignment.PRIMARY_PIXEL_COLUMN_NAME] + ) + for _, row in join_pixels.iterrows() + ] + + right_pixels = [ + HealpixPixel( + row[PixelAlignment.JOIN_ORDER_COLUMN_NAME], + row[PixelAlignment.JOIN_PIXEL_COLUMN_NAME] + ) + for _, row in join_pixels.iterrows() + ] + + joined_partitions = [ + perform_join_on(left_df, right_df, left_on, right_on, left_pixel, right_pixel, suffixes) + for left_df, right_df, left_pixel, right_pixel + in zip(left_aligned_to_join_partitions, right_aligned_to_join_partitions, left_pixels, right_pixels) + ] + + partition_map = get_partition_map_from_alignment_pixels(join_pixels) + meta_df = generate_meta_df_for_joined_tables([left, right], suffixes) ddf = dd.from_delayed(joined_partitions, meta=meta_df) ddf = cast(dd.DataFrame, ddf) return ddf, partition_map, alignment From 1eb74d6eb595128e036049762f225338cd3ae365 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Mon, 20 Nov 2023 19:07:30 -0500 Subject: [PATCH 4/7] add unit tests --- src/lsdb/catalog/catalog.py | 16 +++++++++++++--- tests/conftest.py | 3 ++- tests/lsdb/catalog/test_join.py | 31 +++++++++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 4 deletions(-) create mode 100644 tests/lsdb/catalog/test_join.py diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index 57f2f6cd..cd287594 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -324,18 +324,28 @@ def to_hipscat( def join( self, other: Catalog, - left_on: str = None, - right_on: str = None, + left_on: str, + right_on: str, suffixes: Tuple[str, str] | None = None, output_catalog_name: str | None = None ) -> Catalog: if suffixes is None: - suffixes = ("", "") + suffixes = (f"_{self.name}", f"_{other.name}") + + if len(suffixes) != 2: + raise ValueError("`suffixes` must be a tuple with two strings") + + if left_on not in self._ddf.columns: + raise ValueError("left_on must be a column in the left catalog") + + if right_on not in other._ddf.columns: + raise ValueError("right_on must be a column in the right catalog") ddf, ddf_map, alignment = join_catalog_data_on(self, other, left_on, right_on, suffixes=suffixes) if output_catalog_name is None: output_catalog_name = self.hc_structure.catalog_info.catalog_name + new_catalog_info = dataclasses.replace( self.hc_structure.catalog_info, catalog_name=output_catalog_name, diff --git a/tests/conftest.py b/tests/conftest.py index 07c8c270..8898952f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,6 +5,7 @@ import pytest import lsdb +from lsdb import Catalog DATA_DIR_NAME = "data" SMALL_SKY_DIR_NAME = "small_sky" @@ -45,7 +46,7 @@ def small_sky_hipscat_catalog(small_sky_dir): @pytest.fixture def small_sky_catalog(small_sky_dir): - return lsdb.read_hipscat(small_sky_dir) + return lsdb.read_hipscat(small_sky_dir, catalog_type=Catalog) @pytest.fixture diff --git a/tests/lsdb/catalog/test_join.py b/tests/lsdb/catalog/test_join.py new file mode 100644 index 00000000..54d0da9d --- /dev/null +++ b/tests/lsdb/catalog/test_join.py @@ -0,0 +1,31 @@ +import pytest + + +def test_small_sky_join_small_sky_order1(small_sky_catalog, small_sky_order1_catalog): + suffixes = ("_a", "_b") + joined = small_sky_catalog.join(small_sky_order1_catalog, left_on="id", right_on="id", suffixes=suffixes) + for col_name, dtype in small_sky_catalog.dtypes.items(): + assert (col_name + suffixes[0], dtype) in joined.dtypes.items() + for col_name, dtype in small_sky_order1_catalog.dtypes.items(): + assert (col_name + suffixes[1], dtype) in joined.dtypes.items() + joined_compute = joined.compute() + small_sky_compute = small_sky_catalog.compute() + small_sky_order1_compute = small_sky_order1_catalog.compute() + assert len(joined_compute) == len(small_sky_compute) + assert len(joined_compute) == len(small_sky_order1_compute) + for index, row in small_sky_compute.iterrows(): + joined_row = joined_compute.query(f"id{suffixes[0]} == {row['id']}") + assert joined_row.index.values[0] == index + assert joined_row[f"id{suffixes[1]}"].values[0] == row["id"] + + +def test_join_wrong_columns(small_sky_catalog, small_sky_order1_catalog): + with pytest.raises(ValueError): + small_sky_catalog.join(small_sky_order1_catalog, left_on="bad", right_on="id") + with pytest.raises(ValueError): + small_sky_catalog.join(small_sky_order1_catalog, left_on="id", right_on="bad") + + +def test_join_wrong_suffixes(small_sky_catalog, small_sky_order1_catalog): + with pytest.raises(ValueError): + small_sky_catalog.join(small_sky_order1_catalog, left_on="id", right_on="id", suffixes=("wrong",)) From 6c1f1684c12f0e6986e34e89fcb98a90683b2230 Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Mon, 20 Nov 2023 19:44:44 -0500 Subject: [PATCH 5/7] add docstrings --- src/lsdb/catalog/catalog.py | 17 +++++++++ src/lsdb/dask/crossmatch_catalog_data.py | 45 ++++++++++++++++++---- src/lsdb/dask/join_catalog_data.py | 48 ++++++++++++++++++++---- 3 files changed, 96 insertions(+), 14 deletions(-) diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index cd287594..31ffcc0c 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -329,6 +329,23 @@ def join( suffixes: Tuple[str, str] | None = None, output_catalog_name: str | None = None ) -> Catalog: + """Perform a spatial join to another catalog + + Joins two catalogs together on a shared column value, merging rows where they match. The operation + only joins data from matching partitions, and does not join rows that have a matching column value but + are in separate partitions in the sky. For a more general join, see the `merge` function. + + Args: + other (Catalog): the right catalog to join to + left_on (str): the name of the column in the left catalog to join on + right_on (str): the name of the column in the right catalog to join on + suffixes (Tuple[str,str]): suffixes to apply to the columns of each table + output_catalog_name (str): The name of the resulting catalog to be stored in metadata + + Returns: + A new catalog with the columns from each of the input catalogs with their respective suffixes + added, and the rows merged on the specified columns. + """ if suffixes is None: suffixes = (f"_{self.name}", f"_{other.name}") diff --git a/src/lsdb/dask/crossmatch_catalog_data.py b/src/lsdb/dask/crossmatch_catalog_data.py index b491ca91..991b7f26 100644 --- a/src/lsdb/dask/crossmatch_catalog_data.py +++ b/src/lsdb/dask/crossmatch_catalog_data.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TYPE_CHECKING, List, Tuple, Type, cast, Dict, Sequence +from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple, Type, cast import dask import dask.dataframe as dd @@ -58,6 +58,16 @@ def perform_crossmatch( def filter_by_hipscat_index_to_pixel(dataframe: pd.DataFrame, order: int, pixel: int) -> pd.DataFrame: + """Filters a catalog dataframe to the points within a specified HEALPix pixel using the hipscat index + + Args: + dataframe (pd.DataFrame): The dataframe to filter + order (int): The order of the HEALPix pixel to filter to + pixel (int): The pixel number in NESTED numbering of the HEALPix pixel to filter to + + Returns: + The filtered dataframe with only the rows that are within the specified HEALPix pixel + """ lower_bound = healpix_to_hipscat_id(order, pixel) upper_bound = healpix_to_hipscat_id(order, pixel + 1) filtered_df = dataframe[(dataframe.index >= lower_bound) & (dataframe.index < upper_bound)] @@ -162,15 +172,28 @@ def crossmatch_catalog_data( def generate_meta_df_for_joined_tables( - tables: Sequence[Catalog], + catalogs: Sequence[Catalog], suffixes: Sequence[str], extra_columns: Dict[str, pd.Series] | None = None, index_name: str = HIPSCAT_ID_COLUMN, -): - if len(tables) != len(suffixes): - raise ValueError("tables and suffixes must have the same length") +) -> pd.DataFrame: + """Generates a Dask meta DataFrame that would result from joining two catalogs + + Creates an empty dataframe with the columns of each catalog appended with a suffix. Allows specifying + extra columns that should also be added, and the name of the index of the resulting dataframe. + + Args: + catalogs (Sequence[Catalog]): The catalogs to merge together + suffixes (Sequence[Str]): The column suffixes to apply each catalog + extra_columns (Dict[str, pd.Series]): Any additional columns to the merged catalogs + index_name: The name of the index in the resulting DataFrame + + Returns: + An empty dataframe with the columns of each catalog with their respective suffix, and any extra columns + specified, with the index name set. + """ meta = {} - for table, suffix in zip(tables, suffixes): + for table, suffix in zip(catalogs, suffixes): for name, col_type in table.dtypes.items(): meta[name + suffix] = pd.Series(dtype=col_type) if extra_columns is not None: @@ -180,7 +203,15 @@ def generate_meta_df_for_joined_tables( return meta_df -def get_partition_map_from_alignment_pixels(join_pixels): +def get_partition_map_from_alignment_pixels(join_pixels: pd.DataFrame) -> DaskDFPixelMap: + """Gets a dictionary mapping HEALPix pixel to index of pixel in the pixel_mapping of a `PixelAlignment` + + Args: + join_pixels (pd.DataFrame): The pixel_mapping from a `PixelAlignment` object + + Returns: + A dictionary mapping HEALPix pixel to the index that the pixel occurs in the pixel_mapping table + """ partition_map = {} for i, (_, row) in enumerate(join_pixels.iterrows()): pixel = HealpixPixel( diff --git a/src/lsdb/dask/join_catalog_data.py b/src/lsdb/dask/join_catalog_data.py index b256f34b..4b81bce0 100644 --- a/src/lsdb/dask/join_catalog_data.py +++ b/src/lsdb/dask/join_catalog_data.py @@ -1,3 +1,5 @@ +# pylint: disable=duplicate-code + from __future__ import annotations from typing import TYPE_CHECKING, Tuple, cast @@ -6,11 +8,15 @@ import dask.dataframe as dd import pandas as pd from hipscat.pixel_math import HealpixPixel -from hipscat.pixel_math.hipscat_id import healpix_to_hipscat_id, HIPSCAT_ID_COLUMN -from hipscat.pixel_tree import PixelAlignmentType, PixelAlignment, align_trees +from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN +from hipscat.pixel_tree import PixelAlignment, PixelAlignmentType, align_trees -from lsdb.dask.crossmatch_catalog_data import align_catalog_to_partitions, filter_by_hipscat_index_to_pixel, \ - get_partition_map_from_alignment_pixels, generate_meta_df_for_joined_tables +from lsdb.dask.crossmatch_catalog_data import ( + align_catalog_to_partitions, + filter_by_hipscat_index_to_pixel, + generate_meta_df_for_joined_tables, + get_partition_map_from_alignment_pixels, +) if TYPE_CHECKING: from lsdb.catalog.catalog import Catalog, DaskDFPixelMap @@ -26,6 +32,20 @@ def perform_join_on( right_pixel: HealpixPixel, suffixes: Tuple[str, str] ): + """Performs a join on two catalog partitions + + Args: + left (pd.DataFrame): the left partition to merge + right (pd.DataFrame): the right partition to merge + left_on (str): the column to join on from the left partition + right_on (str): the column to join on from the right partition + left_pixel (HealpixPixel): the HEALPix pixel of the left partition + right_pixel (HealpixPixel): the HEALPix pixel of the right partition + suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names + + Returns: + A dataframe with the result of merging the left and right partitions on the specified columns + """ if right_pixel.order > left_pixel.order: left = filter_by_hipscat_index_to_pixel(left, right_pixel.order, right_pixel.pixel) left_columns_renamed = {name: name + suffixes[0] for name in left.columns} @@ -40,10 +60,24 @@ def perform_join_on( def join_catalog_data_on( left: Catalog, right: Catalog, - left_on: str = None, - right_on: str = None, - suffixes: Tuple[str, str] | None = None + left_on: str, + right_on: str, + suffixes: Tuple[str, str] ) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]: + """Joins two catalogs spatially on a specified column + + Args: + left (Catalog): the left catalog to join + right (Catalog): the right catalog to join + left_on (str): the column to join on from the left partition + right_on (str): the column to join on from the right partition + suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names + + Returns: + A tuple of the dask dataframe with the result of the join, the pixel map from HEALPix + pixel to partition index within the dataframe, and the PixelAlignment of the two input + catalogs. + """ alignment = align_trees( left.hc_structure.pixel_tree, right.hc_structure.pixel_tree, From da65ccd0892982c1ae919bcd4ce3dafd8218673d Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Mon, 20 Nov 2023 19:54:51 -0500 Subject: [PATCH 6/7] fix import --- src/lsdb/catalog/catalog.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index 31ffcc0c..c5b588f0 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -15,7 +15,6 @@ from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data from lsdb.dask.join_catalog_data import join_catalog_data_on -DaskDFPixelMap = Dict[HealpixPixel, int] from lsdb.types import DaskDFPixelMap From 5de2bc07feeb164738bd99a7c0e808c390348dcf Mon Sep 17 00:00:00 2001 From: Sean McGuire Date: Mon, 20 Nov 2023 19:57:55 -0500 Subject: [PATCH 7/7] fix isort --- src/lsdb/catalog/catalog.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lsdb/catalog/catalog.py b/src/lsdb/catalog/catalog.py index c5b588f0..7946f00f 100644 --- a/src/lsdb/catalog/catalog.py +++ b/src/lsdb/catalog/catalog.py @@ -14,7 +14,6 @@ from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data from lsdb.dask.join_catalog_data import join_catalog_data_on - from lsdb.types import DaskDFPixelMap