Skip to content

Commit

Permalink
Merge pull request #317 from astronomy-commons/sean/delayed-metadata
Browse files Browse the repository at this point in the history
Only pass catalog_info to delayed tasks instead of full hipscat catalog metadata
  • Loading branch information
smcguire-cmu authored May 13, 2024
2 parents 6b35c87 + b94c911 commit 5ded7d1
Show file tree
Hide file tree
Showing 11 changed files with 83 additions and 79 deletions.
2 changes: 1 addition & 1 deletion src/lsdb/catalog/dataset/healpix_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def _perform_search(
partitions = self._ddf.to_delayed()
targeted_partitions = [partitions[self._ddf_pixel_map[pixel]] for pixel in filtered_pixels]
filtered_partitions = (
[search.search_points(partition, metadata) for partition in targeted_partitions]
[search.search_points(partition, metadata.catalog_info) for partition in targeted_partitions]
if fine
else targeted_partitions
)
Expand Down
39 changes: 20 additions & 19 deletions src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from abc import ABC, abstractmethod
from typing import Tuple

import hipscat as hc
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
from hipscat.catalog.margin_cache import MarginCacheCatalogInfo
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN


Expand All @@ -23,9 +24,9 @@ def __init__(
left_pixel: int,
right_order: int,
right_pixel: int,
left_metadata: hc.catalog.Catalog,
right_metadata: hc.catalog.Catalog,
right_margin_hc_structure: hc.catalog.MarginCatalog | None,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
right_margin_catalog_info: MarginCacheCatalogInfo | None,
suffixes: Tuple[str, str],
):
"""Initializes a crossmatch algorithm
Expand All @@ -37,12 +38,12 @@ def __init__(
left_pixel (int): The HEALPix pixel number in NESTED ordering of the left pixel
right_order (int): The HEALPix order of the right pixel
right_pixel (int): The HEALPix pixel number in NESTED ordering of the right pixel
left_metadata (hipscat.Catalog): The hipscat Catalog object with the metadata of the
left_catalog_info (hipscat.CatalogInfo): The hipscat CatalogInfo object with the metadata of the
left catalog
right_metadata (hipscat.Catalog): The hipscat Catalog object with the metadata of the
right_catalog_info (hipscat.CatalogInfo): The hipscat CatalogInfo object with the metadata of the
right catalog
right_margin_hc_structure (hipscat.MarginCatalog): The hipscat MarginCatalog objects
with the metadata of the right **margin** catalog
right_margin_catalog_info (hipscat.MarginCacheCatalogInfo): The hipscat MarginCacheCatalogInfo
objects with the metadata of the right **margin** catalog
suffixes (Tuple[str,str]): A pair of suffixes to be appended to the end of each column
name, with the first appended to the left columns and the second to the right
columns
Expand All @@ -53,9 +54,9 @@ def __init__(
self.left_pixel = left_pixel
self.right_order = right_order
self.right_pixel = right_pixel
self.left_metadata = left_metadata
self.right_metadata = right_metadata
self.right_margin_hc_structure = right_margin_hc_structure
self.left_catalog_info = left_catalog_info
self.right_catalog_info = right_catalog_info
self.right_margin_catalog_info = right_margin_catalog_info
self.suffixes = suffixes

@abstractmethod
Expand All @@ -76,16 +77,16 @@ def validate(self):
if self.right.index.name != HIPSCAT_ID_COLUMN:
raise ValueError(f"index of right table must be {HIPSCAT_ID_COLUMN}")
column_names = self.left.columns
if self.left_metadata.catalog_info.ra_column not in column_names:
raise ValueError(f"left table must have column {self.left_metadata.catalog_info.ra_column}")
if self.left_metadata.catalog_info.dec_column not in column_names:
raise ValueError(f"left table must have column {self.left_metadata.catalog_info.dec_column}")
if self.left_catalog_info.ra_column not in column_names:
raise ValueError(f"left table must have column {self.left_catalog_info.ra_column}")
if self.left_catalog_info.dec_column not in column_names:
raise ValueError(f"left table must have column {self.left_catalog_info.dec_column}")

column_names = self.right.columns
if self.right_metadata.catalog_info.ra_column not in column_names:
raise ValueError(f"right table must have column {self.right_metadata.catalog_info.ra_column}")
if self.right_metadata.catalog_info.dec_column not in column_names:
raise ValueError(f"right table must have column {self.right_metadata.catalog_info.dec_column}")
if self.right_catalog_info.ra_column not in column_names:
raise ValueError(f"right table must have column {self.right_catalog_info.ra_column}")
if self.right_catalog_info.dec_column not in column_names:
raise ValueError(f"right table must have column {self.right_catalog_info.dec_column}")

@staticmethod
def _rename_columns_with_suffix(dataframe, suffix):
Expand Down
12 changes: 6 additions & 6 deletions src/lsdb/core/crossmatch/kdtree_match.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ def validate(
if n_neighbors < 1:
raise ValueError("n_neighbors must be greater than 1")
# Check that the margin exists and has a compatible radius.
if self.right_margin_hc_structure is None:
if self.right_margin_catalog_info is None:
if require_right_margin:
raise ValueError("Right catalog margin cache is required for cross-match.")
else:
if self.right_margin_hc_structure.catalog_info.margin_threshold < radius_arcsec:
if self.right_margin_catalog_info.margin_threshold < radius_arcsec:
raise ValueError("Cross match radius is greater than margin threshold")

def crossmatch(
Expand Down Expand Up @@ -73,12 +73,12 @@ def crossmatch(

def _get_point_coordinates(self) -> Tuple[npt.NDArray[np.float64], npt.NDArray[np.float64]]:
left_xyz = _lon_lat_to_xyz(
lon=self.left[self.left_metadata.catalog_info.ra_column].to_numpy(),
lat=self.left[self.left_metadata.catalog_info.dec_column].to_numpy(),
lon=self.left[self.left_catalog_info.ra_column].to_numpy(),
lat=self.left[self.left_catalog_info.dec_column].to_numpy(),
)
right_xyz = _lon_lat_to_xyz(
lon=self.right[self.right_metadata.catalog_info.ra_column].to_numpy(),
lat=self.right[self.right_metadata.catalog_info.dec_column].to_numpy(),
lon=self.right[self.right_catalog_info.ra_column].to_numpy(),
lat=self.right[self.right_catalog_info.dec_column].to_numpy(),
)
return left_xyz, right_xyz

Expand Down
4 changes: 2 additions & 2 deletions src/lsdb/core/search/abstract_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
from abc import ABC, abstractmethod
from typing import List

import hipscat as hc
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
from hipscat.pixel_math import HealpixPixel


Expand All @@ -25,5 +25,5 @@ def search_partitions(self, pixels: List[HealpixPixel]) -> List[HealpixPixel]:
"""Determine the target partitions for further filtering."""

@abstractmethod
def search_points(self, frame: pd.DataFrame, metadata: hc.catalog.Catalog) -> pd.DataFrame:
def search_points(self, frame: pd.DataFrame, metadata: CatalogInfo) -> pd.DataFrame:
"""Determine the search results within a data frame"""
10 changes: 5 additions & 5 deletions src/lsdb/core/search/box_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
from typing import List, Tuple

import dask
import hipscat as hc
import numpy as np
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.box_filter import filter_pixels_by_box, wrap_ra_angles
from hipscat.pixel_math.validators import validate_box_search
Expand All @@ -32,7 +32,7 @@ def search_partitions(self, pixels: List[HealpixPixel]) -> List[HealpixPixel]:
pixel_tree = PixelTree.from_healpix(pixels)
return filter_pixels_by_box(pixel_tree, self.ra, self.dec)

def search_points(self, frame: pd.DataFrame, metadata: hc.catalog.Catalog) -> pd.DataFrame:
def search_points(self, frame: pd.DataFrame, metadata: CatalogInfo) -> pd.DataFrame:
"""Determine the search results within a data frame"""
return box_filter(frame, self.ra, self.dec, metadata)

Expand All @@ -42,7 +42,7 @@ def box_filter(
data_frame: pd.DataFrame,
ra: Tuple[float, float] | None,
dec: Tuple[float, float] | None,
metadata: hc.catalog.Catalog,
metadata: CatalogInfo,
):
"""Filters a dataframe to only include points within the specified box region.
Expand All @@ -57,12 +57,12 @@ def box_filter(
"""
mask = np.ones(len(data_frame), dtype=bool)
if ra is not None:
ra_values = data_frame[metadata.catalog_info.ra_column]
ra_values = data_frame[metadata.ra_column]
wrapped_ra = np.asarray(wrap_ra_angles(ra_values))
mask_ra = _create_ra_mask(ra, wrapped_ra)
mask = np.logical_and(mask, mask_ra)
if dec is not None:
dec_values = data_frame[metadata.catalog_info.dec_column].to_numpy()
dec_values = data_frame[metadata.dec_column].to_numpy()
mask_dec = np.logical_and(dec[0] <= dec_values, dec_values <= dec[1])
mask = np.logical_and(mask, mask_dec)
data_frame = data_frame.iloc[mask]
Expand Down
12 changes: 6 additions & 6 deletions src/lsdb/core/search/cone_search.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List

import dask
import hipscat as hc
import pandas as pd
from astropy.coordinates import SkyCoord
from hipscat.catalog.catalog_info import CatalogInfo
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.cone_filter import filter_pixels_by_cone
from hipscat.pixel_math.validators import validate_declination_values, validate_radius
Expand Down Expand Up @@ -32,27 +32,27 @@ def search_partitions(self, pixels: List[HealpixPixel]) -> List[HealpixPixel]:
pixel_tree = PixelTree.from_healpix(pixels)
return filter_pixels_by_cone(pixel_tree, self.ra, self.dec, self.radius_arcsec)

def search_points(self, frame: pd.DataFrame, metadata: hc.catalog.Catalog) -> pd.DataFrame:
def search_points(self, frame: pd.DataFrame, metadata: CatalogInfo) -> pd.DataFrame:
"""Determine the search results within a data frame"""
return cone_filter(frame, self.ra, self.dec, self.radius_arcsec, metadata)


@dask.delayed
def cone_filter(data_frame: pd.DataFrame, ra, dec, radius_arcsec, metadata: hc.catalog.Catalog):
def cone_filter(data_frame: pd.DataFrame, ra, dec, radius_arcsec, metadata: CatalogInfo):
"""Filters a dataframe to only include points within the specified cone
Args:
data_frame (pd.DataFrame): DataFrame containing points in the sky
ra (float): Right Ascension of the center of the cone in degrees
dec (float): Declination of the center of the cone in degrees
radius_arcsec (float): Radius of the cone in arcseconds
metadata (hc.catalog.Catalog): hipscat `Catalog` with catalog_info that matches `data_frame`
metadata (hc.CatalogInfo): hipscat `CatalogInfo` with metadata that matches `data_frame`
Returns:
A new DataFrame with the rows from `data_frame` filtered to only the points inside the cone
"""
df_ras = data_frame[metadata.catalog_info.ra_column].to_numpy()
df_decs = data_frame[metadata.catalog_info.dec_column].to_numpy()
df_ras = data_frame[metadata.ra_column].to_numpy()
df_decs = data_frame[metadata.dec_column].to_numpy()
df_coords = SkyCoord(df_ras, df_decs, unit="deg")
center_coord = SkyCoord(ra, dec, unit="deg")
df_separations_deg = df_coords.separation(center_coord).value
Expand Down
10 changes: 5 additions & 5 deletions src/lsdb/core/search/polygon_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import dask
import healpy as hp
import hipscat as hc
import numpy as np
import pandas as pd
from hipscat.catalog.catalog_info import CatalogInfo
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.polygon_filter import (
CartesianCoordinates,
Expand Down Expand Up @@ -35,13 +35,13 @@ def search_partitions(self, pixels: List[HealpixPixel]) -> List[HealpixPixel]:
pixel_tree = PixelTree.from_healpix(pixels)
return filter_pixels_by_polygon(pixel_tree, self.vertices_xyz)

def search_points(self, frame: pd.DataFrame, metadata: hc.catalog.Catalog) -> pd.DataFrame:
def search_points(self, frame: pd.DataFrame, metadata: CatalogInfo) -> pd.DataFrame:
"""Determine the search results within a data frame"""
return polygon_filter(frame, self.polygon, metadata)


@dask.delayed
def polygon_filter(data_frame: pd.DataFrame, polygon: ConvexPolygon, metadata: hc.catalog.Catalog):
def polygon_filter(data_frame: pd.DataFrame, polygon: ConvexPolygon, metadata: CatalogInfo):
"""Filters a dataframe to only include points within the specified polygon.
Args:
Expand All @@ -52,8 +52,8 @@ def polygon_filter(data_frame: pd.DataFrame, polygon: ConvexPolygon, metadata: h
Returns:
A new DataFrame with the rows from `dataframe` filtered to only the pixels inside the polygon.
"""
ra_values = np.radians(data_frame[metadata.catalog_info.ra_column].to_numpy())
dec_values = np.radians(data_frame[metadata.catalog_info.dec_column].to_numpy())
ra_values = np.radians(data_frame[metadata.ra_column].to_numpy())
dec_values = np.radians(data_frame[metadata.dec_column].to_numpy())
inside_polygon = polygon.contains(ra_values, dec_values)
data_frame = data_frame.iloc[inside_polygon]
return data_frame
Expand Down
18 changes: 9 additions & 9 deletions src/lsdb/dask/crossmatch_catalog_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ def perform_crossmatch(
left_pix,
right_pix,
right_margin_pix,
left_hc_structure,
right_hc_structure,
right_margin_hc_structure,
left_catalog_info,
right_catalog_info,
right_margin_catalog_info,
algorithm,
suffixes,
right_columns,
Expand All @@ -65,9 +65,9 @@ def perform_crossmatch(
left_pix.pixel,
right_pix.order,
right_pix.pixel,
left_hc_structure,
right_hc_structure,
right_margin_hc_structure,
left_catalog_info,
right_catalog_info,
right_margin_catalog_info,
suffixes,
).crossmatch(**kwargs)

Expand Down Expand Up @@ -110,9 +110,9 @@ def crossmatch_catalog_data(
0,
0,
0,
left.hc_structure,
right.hc_structure,
right.margin.hc_structure if right.margin is not None else None,
left.hc_structure.catalog_info,
right.hc_structure.catalog_info,
right.margin.hc_structure.catalog_info if right.margin is not None else None,
suffixes,
)
meta_df_crossmatch.validate(**kwargs)
Expand Down
41 changes: 21 additions & 20 deletions src/lsdb/dask/join_catalog_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@

import dask
import dask.dataframe as dd
import hipscat as hc
import pandas as pd
from hipscat.catalog.association_catalog import AssociationCatalogInfo
from hipscat.catalog.catalog_info import CatalogInfo
from hipscat.catalog.margin_cache import MarginCacheCatalogInfo
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.hipscat_id import HIPSCAT_ID_COLUMN
from hipscat.pixel_tree import PixelAlignment
Expand Down Expand Up @@ -59,9 +61,9 @@ def perform_join_on(
left_pixel: HealpixPixel,
right_pixel: HealpixPixel,
right_margin_pixel: HealpixPixel,
left_structure: hc.catalog.Catalog,
right_structure: hc.catalog.Catalog,
right_margin_structure: hc.catalog.Catalog,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
right_margin_catalog_info: MarginCacheCatalogInfo,
left_on: str,
right_on: str,
suffixes: Tuple[str, str],
Expand All @@ -76,9 +78,9 @@ def perform_join_on(
left_pixel (HealpixPixel): the HEALPix pixel of the left partition
right_pixel (HealpixPixel): the HEALPix pixel of the right partition
right_margin_pixel (HealpixPixel): the HEALPix pixel of the right margin partition
left_structure (hc.Catalog): the hipscat structure of the left catalog
right_structure (hc.Catalog): the hipscat structure of the right catalog
right_margin_structure (hc.Catalog): the hipscat structure of the right margin catalog
left_catalog_info (hc.CatalogInfo): the catalog info of the left catalog
right_catalog_info (hc.CatalogInfo): the catalog info of the right catalog
right_margin_catalog_info (hc.MarginCacheCatalogInfo): the catalog info of the right margin catalog
left_on (str): the column to join on from the left partition
right_on (str): the column to join on from the right partition
suffixes (Tuple[str,str]): the suffixes to apply to each partition's column names
Expand Down Expand Up @@ -111,10 +113,10 @@ def perform_join_through(
right_pixel: HealpixPixel,
right_margin_pixel: HealpixPixel,
through_pixel: HealpixPixel,
left_catalog: hc.catalog.Catalog,
right_catalog: hc.catalog.Catalog,
right_margin_catalog: hc.catalog.Catalog,
assoc_catalog: hc.catalog.AssociationCatalog,
left_catalog_info: CatalogInfo,
right_catalog_info: CatalogInfo,
right_margin_catalog_info: MarginCacheCatalogInfo,
assoc_catalog_info: AssociationCatalogInfo,
suffixes: Tuple[str, str],
right_columns: List[str],
):
Expand All @@ -139,8 +141,7 @@ def perform_join_through(
Returns:
A dataframe with the result of merging the left and right partitions on the specified columns
"""
catalog_info = assoc_catalog.catalog_info
if catalog_info.primary_column is None or catalog_info.join_column is None:
if assoc_catalog_info.primary_column is None or assoc_catalog_info.join_column is None:
raise ValueError("Invalid catalog_info")
if right_pixel.order > left_pixel.order:
left = filter_by_hipscat_index_to_pixel(left, right_pixel.order, right_pixel.pixel)
Expand All @@ -149,23 +150,23 @@ def perform_join_through(

left, right_joined_df = rename_columns_with_suffixes(left, right_joined_df, suffixes)

join_columns = [catalog_info.primary_column_association]
if catalog_info.join_column_association != catalog_info.primary_column_association:
join_columns.append(catalog_info.join_column_association)
join_columns = [assoc_catalog_info.primary_column_association]
if assoc_catalog_info.join_column_association != assoc_catalog_info.primary_column_association:
join_columns.append(assoc_catalog_info.join_column_association)

through = through.drop(NON_JOINING_ASSOCIATION_COLUMNS, axis=1)

merged = (
left.reset_index()
.merge(
through,
left_on=catalog_info.primary_column + suffixes[0],
right_on=catalog_info.primary_column_association,
left_on=assoc_catalog_info.primary_column + suffixes[0],
right_on=assoc_catalog_info.primary_column_association,
)
.merge(
right_joined_df,
left_on=catalog_info.join_column_association,
right_on=catalog_info.join_column + suffixes[1],
left_on=assoc_catalog_info.join_column_association,
right_on=assoc_catalog_info.join_column + suffixes[1],
)
)

Expand Down
Loading

0 comments on commit 5ded7d1

Please sign in to comment.