Skip to content

Commit

Permalink
Merge pull request #216 from astronomy-commons/fix-margins
Browse files Browse the repository at this point in the history
Patch margin cache generation
  • Loading branch information
camposandro authored Mar 11, 2024
2 parents 49e53b5 + bf34c81 commit 9af85cf
Show file tree
Hide file tree
Showing 93 changed files with 377 additions and 78 deletions.
6 changes: 3 additions & 3 deletions docs/notebooks/import_catalogs.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -293,14 +293,14 @@
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 2
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython2",
"version": "2.7.6"
"pygments_lexer": "ipython3",
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
6 changes: 4 additions & 2 deletions docs/notebooks/ztf_bts-ngc.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,9 @@
" matched_df = matched.compute()\n",
"\n",
"# Let's output transient name, NGC name and angular distance between them\n",
"matched_df = matched_df[[\"IAUID_ztf\", \"Name_ngc\", \"_DIST\", \"RA_ztf\", \"Dec_ztf\"]].sort_values(by=[\"_DIST\"])\n",
"matched_df = matched_df[[\"IAUID_ztf\", \"Name_ngc\", \"_dist_arcsec\", \"RA_ztf\", \"Dec_ztf\"]].sort_values(\n",
" by=[\"_dist_arcsec\"]\n",
")\n",
"matched_df"
]
},
Expand Down Expand Up @@ -288,7 +290,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion src/lsdb/loaders/dataframe/from_dataframe_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def _generate_dask_dataframe(
Returns:
The catalog's Dask Dataframe and its total number of rows.
"""
schema = pixel_dfs[0].iloc[:0, :].copy()
schema = pixel_dfs[0].iloc[:0, :].copy() if len(pixels) > 0 else []
divisions = get_pixels_divisions(pixels)
delayed_dfs = [delayed(df) for df in pixel_dfs]
ddf = dd.from_delayed(delayed_dfs, meta=schema, divisions=divisions)
Expand Down
160 changes: 100 additions & 60 deletions src/lsdb/loaders/dataframe/margin_catalog_generator.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations

from typing import List
from typing import Dict, List, Tuple

import dask.dataframe as dd
import healpy as hp
import hipscat as hc
import numpy as np
import pandas as pd
from hipscat import pixel_math
from hipscat.catalog import CatalogType
from hipscat.catalog.margin_cache import MarginCacheCatalogInfo
from hipscat.pixel_math import HealpixPixel
from hipscat.pixel_math.healpix_pixel_function import get_pixel_argsort

from lsdb import Catalog
from lsdb.catalog.margin_catalog import MarginCatalog
Expand All @@ -27,7 +30,7 @@ def __init__(
margin_order: int | None = -1,
margin_threshold: float = 5.0,
) -> None:
"""Initializes a MarginCatalogGenerator
"""Initialize a MarginCatalogGenerator
Args:
catalog (Catalog): The LSDB catalog to generate margins for
Expand All @@ -40,11 +43,19 @@ def __init__(
self.margin_order = self._set_margin_order(margin_order)

def _set_margin_order(self, margin_order: int | None) -> int:
"""Set the order of the margin cache to be generated.
If not provided, the margin will be of an order that
is higher than that of the original catalog by 1"""
highest_order = self.hc_structure.partition_info.get_highest_order()
margin_pixel_k = highest_order + 1
"""Calculate the order of the margin cache to be generated. If not provided,
the margin order will be one higher than the highest order of the original catalog.
Args:
margin_order (int): The order to generate the margin cache with
Returns:
The validated order of the margin catalog.
Raises:
ValueError: If the provided margin order is not higher than the highest order of the catalog.
"""
margin_pixel_k = self.hc_structure.partition_info.get_highest_order() + 1
if margin_order is None or margin_order == -1:
margin_order = margin_pixel_k
elif margin_order < margin_pixel_k:
Expand All @@ -53,73 +64,57 @@ def _set_margin_order(self, margin_order: int | None) -> int:
)
return margin_order

def create_catalog(self) -> MarginCatalog:
def create_catalog(self) -> MarginCatalog | None:
"""Create a margin catalog for another pre-computed catalog
Returns:
Margin catalog object for the provided catalog
Margin catalog object, or None if the margin is empty.
"""
ddf, ddf_pixel_map, total_rows = self._generate_dask_df_and_map()
margin_catalog_info = self._create_catalog_info(total_rows)
margin_pixels = list(ddf_pixel_map.keys())
if total_rows == 0:
return None
margin_catalog_info = self._create_catalog_info(total_rows)
margin_structure = hc.catalog.MarginCatalog(margin_catalog_info, margin_pixels)
return MarginCatalog(ddf, ddf_pixel_map, margin_structure)

def _generate_dask_df_and_map(self):
def _generate_dask_df_and_map(self) -> Tuple[dd.DataFrame, Dict[HealpixPixel, int], int]:
"""Create the Dask Dataframe containing the data points in the margins
for the catalog, as well as the mapping of those HEALPix pixels to
HEALPix Dataframes.
for the catalog as well as the mapping of those HEALPix to Dataframes
Returns:
Tuple containing the Dask Dataframe, the mapping of HEALPix pixels
to the respective Pandas Dataframes and the total number of rows.
Tuple containing the Dask Dataframe, the mapping of margin HEALPix
to the respective partitions and the total number of rows.
"""
# Find the margin pairs of pixels for the catalog
healpix_pixels = self.hc_structure.get_healpix_pixels()
negative_pixels = self.hc_structure.generate_negative_tree_pixels()
combined_pixels = healpix_pixels + negative_pixels
margin_pairs_df = self._find_margin_pixel_pairs(combined_pixels)

# Find in which pixels the data is located in the margin catalog
self.dataframe["margin_pixel"] = hp.ang2pix(
2**self.margin_order,
self.dataframe[self.hc_structure.catalog_info.ra_column].values,
self.dataframe[self.hc_structure.catalog_info.dec_column].values,
lonlat=True,
nest=True,
)
constrained_data = self.dataframe.reset_index().merge(margin_pairs_df, on="margin_pixel")

pixel_dfs = []
ddf_pixel_map = {}

# For each partition, filter the data according to the threshold
partition_dfs = constrained_data.groupby(["partition_order", "partition_pixel"])
# Compute the points for each margin pixel
margins_pixel_df = self._create_margins(margin_pairs_df)
pixels, partitions = list(margins_pixel_df.keys()), list(margins_pixel_df.values())

for i, (_, partition) in enumerate(partition_dfs):
order = partition["partition_order"].iloc[0]
pix = partition["partition_pixel"].iloc[0]
pixel = HealpixPixel(order, pix)
df = self._get_partition_data_in_margin(partition, pixel)
pixel_dfs.append(_format_margin_partition_dataframe(df))
ddf_pixel_map[pixel] = i
# Generate pixel map ordered by _hipscat_index
pixel_order = get_pixel_argsort(pixels)
ordered_pixels = np.asarray(pixels)[pixel_order]
ordered_partitions = [partitions[i] for i in pixel_order]
ddf_pixel_map = {pixel: index for index, pixel in enumerate(ordered_pixels)}

# Generate Dask Dataframe with original schema
pixel_list = list(ddf_pixel_map.keys())
ddf, total_rows = _generate_dask_dataframe(pixel_dfs, pixel_list)
# Generate the dask dataframe with the pixels and partitions
ddf, total_rows = _generate_dask_dataframe(ordered_partitions, ordered_pixels)
return ddf, ddf_pixel_map, total_rows

def _find_margin_pixel_pairs(self, pixels: List[HealpixPixel]) -> pd.DataFrame:
"""Calculate the pairs of catalog pixels and their margin pixels
Args:
pixels (List[HealpixPixel]): The list of HEALPix pixels to
compute margin pixels for. These include the catalog
pixels as well as the negative pixels.
pixels (List[HealpixPixel]): The list of HEALPix to compute margin pixels for.
These include the catalog pixels as well as the negative pixels.
Returns:
A Pandas Dataframe with the many-to-many mapping between the
partitions and the respective margin pixels.
A Pandas Dataframe with the many-to-many mapping between each catalog HEALPix
and the respective margin pixels.
"""
n_orders = []
part_pix = []
Expand All @@ -140,40 +135,85 @@ def _find_margin_pixel_pairs(self, pixels: List[HealpixPixel]) -> pd.DataFrame:
columns=["partition_order", "partition_pixel", "margin_pixel"],
)

def _get_partition_data_in_margin(self, partition_df: pd.DataFrame, pixel: HealpixPixel) -> pd.DataFrame:
def _create_margins(self, margin_pairs_df: pd.DataFrame) -> Dict[HealpixPixel, pd.DataFrame]:
    """Compute the margins for all the pixels in the catalog

    Every data point is assigned the HEALPix pixel it falls into at the margin
    order, then joined with the partition/margin pixel pairs. Each resulting
    (partition order, partition pixel) group is handed to `_append_margin_df`,
    which records the points it keeps in the returned dictionary.

    Args:
        margin_pairs_df (pd.DataFrame): A DataFrame containing all the combinations
            of catalog pixels and respective margin pixels

    Returns:
        A dictionary mapping each margin pixel to the respective DataFrame.
        It is empty if no data point falls into any of the margin pixels.
    """
    catalog_info = self.hc_structure.catalog_info
    # NOTE: the "margin_pixel" column is written onto self.dataframe itself,
    # so it remains visible on that dataframe after this call.
    self.dataframe["margin_pixel"] = hp.ang2pix(
        2**self.margin_order,
        self.dataframe[catalog_info.ra_column].values,
        self.dataframe[catalog_info.dec_column].values,
        lonlat=True,
        nest=True,
    )
    # Inner join: only the points located in some margin pixel are kept, and a
    # point is repeated once for every partition that margin pixel is paired with.
    constrained_data = self.dataframe.reset_index().merge(margin_pairs_df, on="margin_pixel")

    margin_pixel_df_map: Dict[HealpixPixel, pd.DataFrame] = {}
    # Iterate over the groups directly instead of calling `groupby(...).apply`
    # for its side effects: `apply` is meant to combine returned values, and
    # `_append_margin_df` reads the grouping columns, which `apply` is
    # deprecated from passing in recent pandas versions. An empty frame
    # simply yields no groups, so no explicit emptiness check is needed.
    for _, partition_df in constrained_data.groupby(["partition_order", "partition_pixel"]):
        self._append_margin_df(partition_df, margin_pixel_df_map)
    return margin_pixel_df_map

def _append_margin_df(
    self, partition_df: pd.DataFrame, margin_pixel_df_map: Dict[HealpixPixel, pd.DataFrame]
):
    """Keep the points of a partition group that lie within its margin

    The group's pixel is read from the first row of the "partition_order" and
    "partition_pixel" columns. If any point passes the margin filter, the
    formatted result is stored in the dictionary under that pixel; otherwise
    the dictionary is left untouched.

    Args:
        partition_df (pd.DataFrame): Catalog data points for the margin pixel
        margin_pixel_df_map (Dict[HealpixPixel, pd.DataFrame]): A dictionary mapping
            each margin pixel to the respective DataFrame. This dictionary is updated
            in place on each call to this method.
    """
    # Read each column separately so the values keep their column dtype
    order, pixel = (partition_df[column].iloc[0] for column in ("partition_order", "partition_pixel"))
    pixel_key = HealpixPixel(order, pixel)
    points_in_margin = self._get_data_in_margin(partition_df, pixel_key)
    if len(points_in_margin) == 0:
        # Nothing within the threshold: do not register an empty partition
        return
    margin_pixel_df_map[pixel_key] = _format_margin_partition_dataframe(points_in_margin)

def _get_data_in_margin(self, partition_df: pd.DataFrame, margin_pixel: HealpixPixel) -> pd.DataFrame:
"""Calculate the margin boundaries for the HEALPix and include the points
on the margins according to the specified threshold.
on the margin according to the specified threshold
Args:
partition_df (pd.DataFrame): The partition dataframe
pixel (HealpixPixel): The HEALPix pixel to get the margin points for
partition_df (pd.DataFrame): The margin pixel data
margin_pixel (HealpixPixel): The margin HEALPix
Returns:
A Pandas Dataframe with the points of the partition that
are within the specified margin.
A Pandas Dataframe with the points of the partition that are within
the specified threshold in the margin.
"""
margin_mask = pixel_math.check_margin_bounds(
partition_df[self.hc_structure.catalog_info.ra_column].values,
partition_df[self.hc_structure.catalog_info.dec_column].values,
pixel.order,
pixel.pixel,
margin_pixel.order,
margin_pixel.pixel,
self.margin_threshold,
)
return partition_df.loc[margin_mask]
return partition_df.iloc[margin_mask]

def _create_catalog_info(self, total_rows: int) -> MarginCacheCatalogInfo:
"""Creates the margin catalog info object
"""Create the margin catalog info object
Args:
total_rows: The number of elements in the margin catalog
total_rows (int): The number of elements in the margin catalog
Returns:
The margin catalog info object
The margin catalog info object.
"""
catalog_name = self.hc_structure.catalog_info.catalog_name
return MarginCacheCatalogInfo(
catalog_name=f"{self.hc_structure.catalog_info.catalog_name}_margin",
catalog_name=f"{catalog_name}_margin",
catalog_type=CatalogType.MARGIN,
total_rows=total_rows,
primary_catalog=self.hc_structure.catalog_info.catalog_name,
primary_catalog=catalog_name,
margin_threshold=self.margin_threshold,
)
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
SMALL_SKY_DIR_NAME = "small_sky"
SMALL_SKY_LEFT_XMATCH_NAME = "small_sky_left_xmatch"
SMALL_SKY_SOURCE_MARGIN_NAME = "small_sky_source_margin"
SMALL_SKY_ORDER3_SOURCE_MARGIN_NAME = "small_sky_order3_source_margin"
SMALL_SKY_XMATCH_NAME = "small_sky_xmatch"
SMALL_SKY_XMATCH_MARGIN_NAME = "small_sky_xmatch_margin"
SMALL_SKY_TO_XMATCH_NAME = "small_sky_to_xmatch"
Expand Down Expand Up @@ -181,6 +182,11 @@ def small_sky_source_margin_catalog(test_data_dir):
return lsdb.read_hipscat(os.path.join(test_data_dir, SMALL_SKY_SOURCE_MARGIN_NAME))


@pytest.fixture
def small_sky_order3_source_margin_catalog(test_data_dir):
    """Read the catalog stored under SMALL_SKY_ORDER3_SOURCE_MARGIN_NAME in the test data directory."""
    catalog_path = os.path.join(test_data_dir, SMALL_SKY_ORDER3_SOURCE_MARGIN_NAME)
    return lsdb.read_hipscat(catalog_path)


@pytest.fixture
def xmatch_expected_dir(test_data_dir):
return os.path.join(test_data_dir, "raw", "xmatch_expected")
Expand Down
50 changes: 42 additions & 8 deletions tests/data/generate_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,47 @@
"runner.pipeline_with_client(args, client)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### small_sky_order3_source_margin\n",
"\n",
"This one is similar to the previous margin catalogs but it is generated from a source catalog of order 3."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"args = ImportArguments(\n",
" input_file_list=[\"raw/small_sky_source/small_sky_source.csv\"],\n",
" output_path=\".\",\n",
" file_reader=\"csv\",\n",
" ra_column=\"source_ra\",\n",
" dec_column=\"source_dec\",\n",
" catalog_type=\"source\",\n",
" output_artifact_name=\"small_sky_order3_source\",\n",
" constant_healpix_order=3,\n",
" overwrite=True,\n",
" tmp_dir=tmp_dir,\n",
")\n",
"runner.pipeline_with_client(args, client)\n",
"\n",
"args = MarginCacheArguments(\n",
" input_catalog_path=\"small_sky_order3_source\",\n",
" output_path=\".\",\n",
" output_artifact_name=\"small_sky_order3_source_margin\",\n",
" margin_threshold=300,\n",
" margin_order=7,\n",
" overwrite=True,\n",
" tmp_dir=tmp_dir,\n",
")\n",
"runner.pipeline_with_client(args, client)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down Expand Up @@ -434,13 +475,6 @@
"tmp_path.cleanup()\n",
"client.close()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
Expand All @@ -459,7 +493,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.13"
"version": "3.10.12"
}
},
"nbformat": 4,
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file added tests/data/small_sky_order3_source/_metadata
Binary file not shown.
8 changes: 8 additions & 0 deletions tests/data/small_sky_order3_source/catalog_info.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
{
"catalog_name": "small_sky_order3_source",
"catalog_type": "source",
"total_rows": 17161,
"epoch": "J2000",
"ra_column": "source_ra",
"dec_column": "source_dec"
}
Loading

0 comments on commit 9af85cf

Please sign in to comment.