Merge pull request #70 from astronomy-commons/sean/catalog-join
Add Catalog Joining on a column
smcguire-cmu authored Nov 27, 2023
2 parents ed5bfe9 + 5de2bc0 commit cba6aa5
Showing 5 changed files with 286 additions and 20 deletions.
52 changes: 52 additions & 0 deletions src/lsdb/catalog/catalog.py
@@ -13,6 +13,7 @@
from lsdb.core.crossmatch.abstract_crossmatch_algorithm import AbstractCrossmatchAlgorithm
from lsdb.core.crossmatch.crossmatch_algorithms import BuiltInCrossmatchAlgorithm
from lsdb.dask.crossmatch_catalog_data import crossmatch_catalog_data
from lsdb.dask.join_catalog_data import join_catalog_data_on
from lsdb.types import DaskDFPixelMap


@@ -317,3 +318,54 @@ def to_hipscat(
**kwargs: Arguments to pass to the parquet write operations
"""
io.to_hipscat(self, base_catalog_path, catalog_name, storage_options, **kwargs)

def join(
self,
other: Catalog,
left_on: str,
right_on: str,
suffixes: Tuple[str, str] | None = None,
output_catalog_name: str | None = None
) -> Catalog:
"""Perform a spatial join to another catalog
Joins two catalogs together on a shared column value, merging rows where they match. The operation
only joins data from matching partitions, and does not join rows that have a matching column value but
are in separate partitions in the sky. For a more general join, see the `merge` function.
Args:
other (Catalog): the right catalog to join to
left_on (str): the name of the column in the left catalog to join on
right_on (str): the name of the column in the right catalog to join on
suffixes (Tuple[str,str]): suffixes to apply to the columns of each table
output_catalog_name (str): The name of the resulting catalog to be stored in metadata
Returns:
A new catalog with the columns from each of the input catalogs with their respective suffixes
added, and the rows merged on the specified columns.
"""
if suffixes is None:
suffixes = (f"_{self.name}", f"_{other.name}")

if len(suffixes) != 2:
raise ValueError("`suffixes` must be a tuple with two strings")

if left_on not in self._ddf.columns:
raise ValueError("left_on must be a column in the left catalog")

if right_on not in other._ddf.columns:
raise ValueError("right_on must be a column in the right catalog")

ddf, ddf_map, alignment = join_catalog_data_on(self, other, left_on, right_on, suffixes=suffixes)

if output_catalog_name is None:
output_catalog_name = self.hc_structure.catalog_info.catalog_name

new_catalog_info = dataclasses.replace(
self.hc_structure.catalog_info,
catalog_name=output_catalog_name,
ra_column=self.hc_structure.catalog_info.ra_column + suffixes[0],
dec_column=self.hc_structure.catalog_info.dec_column + suffixes[0],
)
hc_catalog = hc.catalog.Catalog(new_catalog_info, alignment.pixel_tree)
return Catalog(ddf, ddf_map, hc_catalog)
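
For context, a minimal usage sketch of the new `join` method (not part of this diff; the catalog paths and column names are hypothetical):

```python
import lsdb

# Hypothetical catalog paths and column names, for illustration only.
left_cat = lsdb.read_hipscat("path/to/left_catalog")
right_cat = lsdb.read_hipscat("path/to/right_catalog")

# Rows are merged where `id` values match and the partitions overlap on the
# sky; columns from each catalog receive the corresponding suffix.
joined = left_cat.join(
    right_cat, left_on="id", right_on="id", suffixes=("_left", "_right")
)
result = joined.compute()  # triggers the lazy Dask computation
```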
94 changes: 75 additions & 19 deletions src/lsdb/dask/crossmatch_catalog_data.py
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import TYPE_CHECKING, List, Tuple, Type, cast
from typing import TYPE_CHECKING, Dict, List, Sequence, Tuple, Type, cast

import dask
import dask.dataframe as dd
@@ -43,9 +43,7 @@ def perform_crossmatch(
the result.
"""
if right_order > left_order:
lower_bound = healpix_to_hipscat_id(right_order, right_pixel)
upper_bound = healpix_to_hipscat_id(right_order, right_pixel + 1)
left_df = left_df[(left_df.index >= lower_bound) & (left_df.index < upper_bound)]
left_df = filter_by_hipscat_index_to_pixel(left_df, right_order, right_pixel)
return algorithm(
left_df,
right_df,
@@ -59,6 +57,23 @@
).crossmatch(**kwargs)


def filter_by_hipscat_index_to_pixel(dataframe: pd.DataFrame, order: int, pixel: int) -> pd.DataFrame:
"""Filters a catalog dataframe to the points within a specified HEALPix pixel using the hipscat index
Args:
dataframe (pd.DataFrame): The dataframe to filter
order (int): The order of the HEALPix pixel to filter to
pixel (int): The pixel number in NESTED numbering of the HEALPix pixel to filter to
Returns:
The filtered dataframe with only the rows that are within the specified HEALPix pixel
"""
lower_bound = healpix_to_hipscat_id(order, pixel)
upper_bound = healpix_to_hipscat_id(order, pixel + 1)
filtered_df = dataframe[(dataframe.index >= lower_bound) & (dataframe.index < upper_bound)]
return filtered_df
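
A standalone sketch of the range filter (editor's illustration, with toy values; not part of this diff). The hipscat index is sorted so that all rows of a pixel occupy one contiguous index range, which reduces the filter to two comparisons on the index:

```python
import pandas as pd
from hipscat.pixel_math.hipscat_id import healpix_to_hipscat_id

from lsdb.dask.crossmatch_catalog_data import filter_by_hipscat_index_to_pixel

# Toy frame indexed by hipscat id; each row's index encodes its sky position.
index = [healpix_to_hipscat_id(1, 44),
         healpix_to_hipscat_id(1, 45),
         healpix_to_hipscat_id(1, 46)]
df = pd.DataFrame({"id": [700, 701, 702]}, index=index)

# Rows inside pixel (order=1, pixel=45) fall in the half-open index range
# [healpix_to_hipscat_id(1, 45), healpix_to_hipscat_id(1, 46)),
# so only id 701 survives.
filtered = filter_by_hipscat_index_to_pixel(df, order=1, pixel=45)
```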


# pylint: disable=too-many-locals
def crossmatch_catalog_data(
left: Catalog,
@@ -141,23 +156,13 @@ def crossmatch_catalog_data(
]

# generate dask df partition map from alignment
partition_map = {}
for i, (_, row) in enumerate(join_pixels.iterrows()):
pixel = HealpixPixel(
order=row[PixelAlignment.ALIGNED_ORDER_COLUMN_NAME],
pixel=row[PixelAlignment.ALIGNED_PIXEL_COLUMN_NAME],
)
partition_map[pixel] = i
partition_map = get_partition_map_from_alignment_pixels(join_pixels)

# generate meta table structure for dask df
meta = {}
for name, col_type in left.dtypes.items():
meta[name + suffixes[0]] = pd.Series(dtype=col_type)
for name, col_type in right.dtypes.items():
meta[name + suffixes[1]] = pd.Series(dtype=col_type)
meta[crossmatch_algorithm.DISTANCE_COLUMN_NAME] = pd.Series(dtype=np.dtype("float64"))
meta_df = pd.DataFrame(meta)
meta_df.index.name = HIPSCAT_ID_COLUMN
extra_columns = {
crossmatch_algorithm.DISTANCE_COLUMN_NAME: pd.Series(dtype=np.dtype("float64"))
}
meta_df = generate_meta_df_for_joined_tables([left, right], suffixes, extra_columns=extra_columns)

# create dask df from delayed partitions
ddf = dd.from_delayed(joined_partitions, meta=meta_df)
@@ -166,6 +171,57 @@
return ddf, partition_map, alignment


def generate_meta_df_for_joined_tables(
catalogs: Sequence[Catalog],
suffixes: Sequence[str],
extra_columns: Dict[str, pd.Series] | None = None,
index_name: str = HIPSCAT_ID_COLUMN,
) -> pd.DataFrame:
"""Generates a Dask meta DataFrame that would result from joining two catalogs
Creates an empty dataframe with the columns of each catalog appended with a suffix. Allows specifying
extra columns that should also be added, and the name of the index of the resulting dataframe.
Args:
catalogs (Sequence[Catalog]): The catalogs to merge together
suffixes (Sequence[str]): The column suffixes to apply to each catalog
extra_columns (Dict[str, pd.Series]): Any additional columns to add to the merged catalogs
index_name: The name of the index in the resulting DataFrame
Returns:
An empty dataframe with the columns of each catalog with their respective suffix, and any extra columns
specified, with the index name set.
"""
meta = {}
for table, suffix in zip(catalogs, suffixes):
for name, col_type in table.dtypes.items():
meta[name + suffix] = pd.Series(dtype=col_type)
if extra_columns is not None:
meta.update(extra_columns)
meta_df = pd.DataFrame(meta)
meta_df.index.name = index_name
return meta_df
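
To show what this helper builds, here is a hand-built equivalent for two toy catalogs (editor's sketch, not part of the diff; all column names are hypothetical):

```python
import numpy as np
import pandas as pd

# Equivalent meta for two toy catalogs with columns ("ra", "dec"),
# joined with suffixes ("_left", "_right").
meta = {
    "ra_left": pd.Series(dtype=np.dtype("float64")),
    "dec_left": pd.Series(dtype=np.dtype("float64")),
    "ra_right": pd.Series(dtype=np.dtype("float64")),
    "dec_right": pd.Series(dtype=np.dtype("float64")),
}
meta_df = pd.DataFrame(meta)
meta_df.index.name = "_hipscat_index"  # HIPSCAT_ID_COLUMN
# Dask uses this zero-row, fully typed frame to infer the output schema
# without computing any delayed partition.
```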


def get_partition_map_from_alignment_pixels(join_pixels: pd.DataFrame) -> DaskDFPixelMap:
"""Gets a dictionary mapping HEALPix pixel to index of pixel in the pixel_mapping of a `PixelAlignment`
Args:
join_pixels (pd.DataFrame): The pixel_mapping from a `PixelAlignment` object
Returns:
A dictionary mapping HEALPix pixel to the index that the pixel occurs in the pixel_mapping table
"""
partition_map = {}
for i, (_, row) in enumerate(join_pixels.iterrows()):
pixel = HealpixPixel(
order=row[PixelAlignment.ALIGNED_ORDER_COLUMN_NAME],
pixel=row[PixelAlignment.ALIGNED_PIXEL_COLUMN_NAME],
)
partition_map[pixel] = i
return partition_map
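
A usage sketch of the partition map (editor's illustration with toy values, not part of the diff):

```python
import pandas as pd
from hipscat.pixel_tree import PixelAlignment

from lsdb.dask.crossmatch_catalog_data import get_partition_map_from_alignment_pixels

# Toy pixel_mapping with two aligned pixels.
join_pixels = pd.DataFrame({
    PixelAlignment.ALIGNED_ORDER_COLUMN_NAME: [1, 1],
    PixelAlignment.ALIGNED_PIXEL_COLUMN_NAME: [44, 45],
})
partition_map = get_partition_map_from_alignment_pixels(join_pixels)
# Maps HealpixPixel(order=1, pixel=44) -> partition 0 and
# HealpixPixel(order=1, pixel=45) -> partition 1.
```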


def get_crossmatch_algorithm(
algorithm: Type[AbstractCrossmatchAlgorithm] | BuiltInCrossmatchAlgorithm,
) -> Type[AbstractCrossmatchAlgorithm]:
126 changes: 126 additions & 0 deletions src/lsdb/dask/join_catalog_data.py
@@ -0,0 +1,126 @@
# pylint: disable=duplicate-code

from __future__ import annotations

from typing import TYPE_CHECKING, Tuple, cast

import dask
import dask.dataframe as dd
import pandas as pd
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN
from hipscat.pixel_tree import PixelAlignment, PixelAlignmentType, align_trees

from lsdb.dask.crossmatch_catalog_data import (
align_catalog_to_partitions,
filter_by_hipscat_index_to_pixel,
generate_meta_df_for_joined_tables,
get_partition_map_from_alignment_pixels,
)

if TYPE_CHECKING:
from lsdb.catalog.catalog import Catalog, DaskDFPixelMap


@dask.delayed
def perform_join_on(
left: pd.DataFrame,
right: pd.DataFrame,
left_on: str,
right_on: str,
left_pixel: HealpixPixel,
right_pixel: HealpixPixel,
suffixes: Tuple[str, str]
):
"""Performs a join on two catalog partitions
Args:
left (pd.DataFrame): the left partition to merge
right (pd.DataFrame): the right partition to merge
left_on (str): the column to join on from the left partition
right_on (str): the column to join on from the right partition
left_pixel (HealpixPixel): the HEALPix pixel of the left partition
right_pixel (HealpixPixel): the HEALPix pixel of the right partition
suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
Returns:
A dataframe with the result of merging the left and right partitions on the specified columns
"""
if right_pixel.order > left_pixel.order:
left = filter_by_hipscat_index_to_pixel(left, right_pixel.order, right_pixel.pixel)
left_columns_renamed = {name: name + suffixes[0] for name in left.columns}
left = left.rename(columns=left_columns_renamed)
right_columns_renamed = {name: name + suffixes[1] for name in right.columns}
right = right.rename(columns=right_columns_renamed)
merged = left.reset_index().merge(right, left_on=left_on + suffixes[0], right_on=right_on + suffixes[1])
merged.set_index(HIPSCAT_ID_COLUMN, inplace=True)
return merged
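
For illustration, the suffix-and-merge step behaves like this on plain pandas frames (editor's sketch with toy data and no HEALPix filtering; not part of the diff):

```python
import pandas as pd

# Toy partitions; the left frame carries the hipscat index.
left = pd.DataFrame({"id": [1, 2]},
                    index=pd.Index([10, 20], name="_hipscat_index"))
right = pd.DataFrame({"id": [2, 3], "mag": [9.5, 7.1]})

left = left.rename(columns={"id": "id_a"})
right = right.rename(columns={"id": "id_b", "mag": "mag_b"})

# reset_index() carries the hipscat index through the merge as a column,
# and set_index() restores it afterwards.
merged = left.reset_index().merge(right, left_on="id_a", right_on="id_b")
merged = merged.set_index("_hipscat_index")
# Result: one row where id_a == id_b == 2, with mag_b == 9.5.
```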


def join_catalog_data_on(
left: Catalog,
right: Catalog,
left_on: str,
right_on: str,
suffixes: Tuple[str, str]
) -> Tuple[dd.core.DataFrame, DaskDFPixelMap, PixelAlignment]:
"""Joins two catalogs spatially on a specified column
Args:
left (Catalog): the left catalog to join
right (Catalog): the right catalog to join
left_on (str): the column to join on from the left partition
right_on (str): the column to join on from the right partition
suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
Returns:
A tuple of the dask dataframe with the result of the join, the pixel map from HEALPix
pixel to partition index within the dataframe, and the PixelAlignment of the two input
catalogs.
"""
alignment = align_trees(
left.hc_structure.pixel_tree,
right.hc_structure.pixel_tree,
alignment_type=PixelAlignmentType.INNER
)
join_pixels = alignment.pixel_mapping
left_aligned_to_join_partitions = align_catalog_to_partitions(
left,
join_pixels,
order_col=PixelAlignment.PRIMARY_ORDER_COLUMN_NAME,
pixel_col=PixelAlignment.PRIMARY_PIXEL_COLUMN_NAME,
)
right_aligned_to_join_partitions = align_catalog_to_partitions(
right,
join_pixels,
order_col=PixelAlignment.JOIN_ORDER_COLUMN_NAME,
pixel_col=PixelAlignment.JOIN_PIXEL_COLUMN_NAME,
)

left_pixels = [
HealpixPixel(
row[PixelAlignment.PRIMARY_ORDER_COLUMN_NAME],
row[PixelAlignment.PRIMARY_PIXEL_COLUMN_NAME]
)
for _, row in join_pixels.iterrows()
]

right_pixels = [
HealpixPixel(
row[PixelAlignment.JOIN_ORDER_COLUMN_NAME],
row[PixelAlignment.JOIN_PIXEL_COLUMN_NAME]
)
for _, row in join_pixels.iterrows()
]

joined_partitions = [
perform_join_on(left_df, right_df, left_on, right_on, left_pixel, right_pixel, suffixes)
for left_df, right_df, left_pixel, right_pixel
in zip(left_aligned_to_join_partitions, right_aligned_to_join_partitions, left_pixels, right_pixels)
]

partition_map = get_partition_map_from_alignment_pixels(join_pixels)
meta_df = generate_meta_df_for_joined_tables([left, right], suffixes)
ddf = dd.from_delayed(joined_partitions, meta=meta_df)
ddf = cast(dd.DataFrame, ddf)
return ddf, partition_map, alignment
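
For orientation, a sketch of how the returned pieces fit together (editor's illustration; assumes `left_cat` and `right_cat` are lsdb Catalogs, and the names are hypothetical):

```python
ddf, partition_map, alignment = join_catalog_data_on(
    left_cat, right_cat, left_on="id", right_on="id", suffixes=("_l", "_r")
)
# ddf: lazy Dask DataFrame with one partition per aligned pixel pair
# partition_map: HealpixPixel -> partition index within ddf
# alignment.pixel_tree: the INNER alignment used for the result's structure
first_partition = ddf.partitions[0].compute()  # materialize one partition
```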
3 changes: 2 additions & 1 deletion tests/conftest.py
@@ -5,6 +5,7 @@
import pytest

import lsdb
from lsdb import Catalog

DATA_DIR_NAME = "data"
SMALL_SKY_DIR_NAME = "small_sky"
@@ -45,7 +46,7 @@ def small_sky_hipscat_catalog(small_sky_dir):

@pytest.fixture
def small_sky_catalog(small_sky_dir):
return lsdb.read_hipscat(small_sky_dir)
return lsdb.read_hipscat(small_sky_dir, catalog_type=Catalog)


@pytest.fixture
31 changes: 31 additions & 0 deletions tests/lsdb/catalog/test_join.py
@@ -0,0 +1,31 @@
import pytest


def test_small_sky_join_small_sky_order1(small_sky_catalog, small_sky_order1_catalog):
suffixes = ("_a", "_b")
joined = small_sky_catalog.join(small_sky_order1_catalog, left_on="id", right_on="id", suffixes=suffixes)
for col_name, dtype in small_sky_catalog.dtypes.items():
assert (col_name + suffixes[0], dtype) in joined.dtypes.items()
for col_name, dtype in small_sky_order1_catalog.dtypes.items():
assert (col_name + suffixes[1], dtype) in joined.dtypes.items()
joined_compute = joined.compute()
small_sky_compute = small_sky_catalog.compute()
small_sky_order1_compute = small_sky_order1_catalog.compute()
assert len(joined_compute) == len(small_sky_compute)
assert len(joined_compute) == len(small_sky_order1_compute)
for index, row in small_sky_compute.iterrows():
joined_row = joined_compute.query(f"id{suffixes[0]} == {row['id']}")
assert joined_row.index.values[0] == index
assert joined_row[f"id{suffixes[1]}"].values[0] == row["id"]


def test_join_wrong_columns(small_sky_catalog, small_sky_order1_catalog):
with pytest.raises(ValueError):
small_sky_catalog.join(small_sky_order1_catalog, left_on="bad", right_on="id")
with pytest.raises(ValueError):
small_sky_catalog.join(small_sky_order1_catalog, left_on="id", right_on="bad")


def test_join_wrong_suffixes(small_sky_catalog, small_sky_order1_catalog):
with pytest.raises(ValueError):
small_sky_catalog.join(small_sky_order1_catalog, left_on="id", right_on="id", suffixes=("wrong",))
