Skip to content

Commit

Permalink
Add plotting to search filters (#524)
Browse files Browse the repository at this point in the history
* add search plot to cone search

* unit test

* lint
  • Loading branch information
smcguire-cmu authored Dec 9, 2024
1 parent 892e631 commit a65912e
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
69 changes: 68 additions & 1 deletion src/lsdb/core/search/abstract_search.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Tuple, Type

import astropy
import nested_pandas as npd
from astropy.coordinates import SkyCoord
from astropy.units import Quantity
from astropy.visualization.wcsaxes import WCSAxes
from astropy.visualization.wcsaxes.frame import BaseFrame
from hats.catalog import TableProperties
from hats.inspection.visualize_catalog import initialize_wcs_axes
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from mocpy import MOC

if TYPE_CHECKING:
Expand Down Expand Up @@ -32,3 +40,62 @@ def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> MOC:
@abstractmethod
def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame:
"""Determine the search results within a data frame"""

def plot(
self,
projection: str = "MOL",
title: str = "",
fov: Quantity | Tuple[Quantity, Quantity] | None = None,
center: SkyCoord | None = None,
wcs: astropy.wcs.WCS | None = None,
frame_class: Type[BaseFrame] | None = None,
ax: WCSAxes | None = None,
fig: Figure | None = None,
**kwargs,
):
"""Plot the search region
Args:
projection (str): The projection to use in the WCS. Available projections listed at
https://docs.astropy.org/en/stable/wcs/supported_projections.html
title (str): The title of the plot
fov (Quantity or Sequence[Quantity, Quantity] | None): The Field of View of the WCS. Must be an
astropy Quantity with an angular unit, or a tuple of quantities for different longitude and
latitude FOVs (Default covers the full sky)
center (SkyCoord | None): The center of the projection in the WCS (Default: SkyCoord(0, 0))
wcs (WCS | None): The WCS to specify the projection of the plot. If used, all other WCS parameters
are ignored and the parameters from the WCS object is used.
frame_class (Type[BaseFrame] | None): The class of the frame for the WCSAxes to be initialized
with. if the `ax` kwarg is used, this value is ignored (By Default uses EllipticalFrame for
full sky projection. If FOV is set, RectangularFrame is used)
ax (WCSAxes | None): The matplotlib axes to plot onto. If None, an axes will be created to be
used. If specified, the axes must be an astropy WCSAxes, and the `wcs` parameter must be set
with the WCS object used in the axes. (Default: None)
fig (Figure | None): The matplotlib figure to add the axes to. If None, one will be created,
unless ax is specified (Default: None)
**kwargs: Additional kwargs to pass to creating the matplotlib patch object for the search region
Returns:
Tuple[Figure, WCSAxes] - The figure and axes used for the plot
"""
fig, ax, wcs = initialize_wcs_axes(
projection=projection,
fov=fov,
center=center,
wcs=wcs,
frame_class=frame_class,
ax=ax,
fig=fig,
figsize=(9, 5),
)
self._perform_plot(ax, **kwargs)

plt.grid()
plt.ylabel("Dec")
plt.xlabel("RA")
plt.title(title)
return fig, ax

def _perform_plot(self, ax: WCSAxes, **kwargs):
"""Perform the plot of the search region on an initialized WCSAxes"""
raise NotImplementedError("Plotting has not been implemented for this search")
11 changes: 11 additions & 0 deletions src/lsdb/core/search/cone_search.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import astropy.units as u
import nested_pandas as npd
from astropy.coordinates import SkyCoord
from astropy.visualization.wcsaxes import SphericalCircle, WCSAxes
from hats.catalog import TableProperties
from hats.pixel_math.validators import validate_declination_values, validate_radius
from mocpy import MOC
Expand Down Expand Up @@ -31,6 +33,15 @@ def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> np
"""Determine the search results within a data frame"""
return cone_filter(frame, self.ra, self.dec, self.radius_arcsec, metadata)

def _perform_plot(self, ax: WCSAxes, **kwargs):
circle = SphericalCircle(
(self.ra * u.deg, self.dec * u.deg),
self.radius_arcsec * u.arcsec,
transform=ax.get_transform("icrs"),
**kwargs,
)
ax.add_patch(circle)


def cone_filter(data_frame: npd.NestedFrame, ra, dec, radius_arcsec, metadata: TableProperties):
"""Filters a dataframe to only include points within the specified cone
Expand Down
13 changes: 13 additions & 0 deletions tests/lsdb/catalog/test_cone_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
import pandas as pd
import pytest
from astropy.coordinates import SkyCoord
from astropy.visualization.wcsaxes import SphericalCircle
from hats.pixel_math.validators import ValidatorsErrors

from lsdb import ConeSearch


def test_cone_search_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct):
ra = 0
Expand Down Expand Up @@ -124,3 +127,13 @@ def test_empty_cone_search_with_margin(small_sky_order1_source_with_margin):
cone = small_sky_order1_source_with_margin.cone_search(ra, dec, radius, fine=False)
assert len(cone._ddf_pixel_map) == 0
assert len(cone.margin._ddf_pixel_map) == 0


def test_cone_search_plot():
ra = 100
dec = 80
radius = 60
search = ConeSearch(ra, dec, radius)
_, ax = search.plot()
assert len(ax.patches) == 1
assert isinstance(ax.patches[0], SphericalCircle)

0 comments on commit a65912e

Please sign in to comment.