diff --git a/src/lsdb/core/search/abstract_search.py b/src/lsdb/core/search/abstract_search.py index 0e57f8a0..4ed5269e 100644 --- a/src/lsdb/core/search/abstract_search.py +++ b/src/lsdb/core/search/abstract_search.py @@ -1,10 +1,18 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Tuple, Type +import astropy import nested_pandas as npd +from astropy.coordinates import SkyCoord +from astropy.units import Quantity +from astropy.visualization.wcsaxes import WCSAxes +from astropy.visualization.wcsaxes.frame import BaseFrame from hats.catalog import TableProperties +from hats.inspection.visualize_catalog import initialize_wcs_axes +from matplotlib import pyplot as plt +from matplotlib.figure import Figure from mocpy import MOC if TYPE_CHECKING: @@ -32,3 +40,62 @@ def filter_hc_catalog(self, hc_structure: HCCatalogTypeVar) -> MOC: @abstractmethod def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> npd.NestedFrame: """Determine the search results within a data frame""" + + def plot( + self, + projection: str = "MOL", + title: str = "", + fov: Quantity | Tuple[Quantity, Quantity] | None = None, + center: SkyCoord | None = None, + wcs: astropy.wcs.WCS | None = None, + frame_class: Type[BaseFrame] | None = None, + ax: WCSAxes | None = None, + fig: Figure | None = None, + **kwargs, + ): + """Plot the search region + + Args: + projection (str): The projection to use in the WCS. Available projections listed at + https://docs.astropy.org/en/stable/wcs/supported_projections.html + title (str): The title of the plot + fov (Quantity or Sequence[Quantity, Quantity] | None): The Field of View of the WCS. Must be an + astropy Quantity with an angular unit, or a tuple of quantities for different longitude and + latitude FOVs (Default covers the full sky) + center (SkyCoord | None): The center of the projection in the WCS (Default: SkyCoord(0, 0)) + wcs (WCS | None): The WCS to specify the projection of the plot. If used, all other WCS parameters + are ignored and the parameters from the WCS object is used. + frame_class (Type[BaseFrame] | None): The class of the frame for the WCSAxes to be initialized + with. if the `ax` kwarg is used, this value is ignored (By Default uses EllipticalFrame for + full sky projection. If FOV is set, RectangularFrame is used) + ax (WCSAxes | None): The matplotlib axes to plot onto. If None, an axes will be created to be + used. If specified, the axes must be an astropy WCSAxes, and the `wcs` parameter must be set + with the WCS object used in the axes. (Default: None) + fig (Figure | None): The matplotlib figure to add the axes to. If None, one will be created, + unless ax is specified (Default: None) + **kwargs: Additional kwargs to pass to creating the matplotlib patch object for the search region + + Returns: + Tuple[Figure, WCSAxes] - The figure and axes used for the plot + """ + fig, ax, wcs = initialize_wcs_axes( + projection=projection, + fov=fov, + center=center, + wcs=wcs, + frame_class=frame_class, + ax=ax, + fig=fig, + figsize=(9, 5), + ) + self._perform_plot(ax, **kwargs) + + plt.grid() + plt.ylabel("Dec") + plt.xlabel("RA") + plt.title(title) + return fig, ax + + def _perform_plot(self, ax: WCSAxes, **kwargs): + """Perform the plot of the search region on an initialized WCSAxes""" + raise NotImplementedError("Plotting has not been implemented for this search") diff --git a/src/lsdb/core/search/cone_search.py b/src/lsdb/core/search/cone_search.py index 46e5afc6..381e814a 100644 --- a/src/lsdb/core/search/cone_search.py +++ b/src/lsdb/core/search/cone_search.py @@ -1,5 +1,7 @@ +import astropy.units as u import nested_pandas as npd from astropy.coordinates import SkyCoord +from astropy.visualization.wcsaxes import SphericalCircle, WCSAxes from hats.catalog import TableProperties from hats.pixel_math.validators import validate_declination_values, validate_radius from mocpy import MOC @@ -31,6 +33,15 @@ def search_points(self, frame: npd.NestedFrame, metadata: TableProperties) -> np """Determine the search results within a data frame""" return cone_filter(frame, self.ra, self.dec, self.radius_arcsec, metadata) + def _perform_plot(self, ax: WCSAxes, **kwargs): + circle = SphericalCircle( + (self.ra * u.deg, self.dec * u.deg), + self.radius_arcsec * u.arcsec, + transform=ax.get_transform("icrs"), + **kwargs, + ) + ax.add_patch(circle) + def cone_filter(data_frame: npd.NestedFrame, ra, dec, radius_arcsec, metadata: TableProperties): """Filters a dataframe to only include points within the specified cone diff --git a/tests/lsdb/catalog/test_cone_search.py b/tests/lsdb/catalog/test_cone_search.py index 2be0e6ee..f43d601f 100644 --- a/tests/lsdb/catalog/test_cone_search.py +++ b/tests/lsdb/catalog/test_cone_search.py @@ -3,8 +3,11 @@ import pandas as pd import pytest from astropy.coordinates import SkyCoord +from astropy.visualization.wcsaxes import SphericalCircle from hats.pixel_math.validators import ValidatorsErrors +from lsdb import ConeSearch + def test_cone_search_filters_correct_points(small_sky_order1_catalog, assert_divisions_are_correct): ra = 0 @@ -124,3 +127,13 @@ def test_empty_cone_search_with_margin(small_sky_order1_source_with_margin): cone = small_sky_order1_source_with_margin.cone_search(ra, dec, radius, fine=False) assert len(cone._ddf_pixel_map) == 0 assert len(cone.margin._ddf_pixel_map) == 0 + + +def test_cone_search_plot(): + ra = 100 + dec = 80 + radius = 60 + search = ConeSearch(ra, dec, radius) + _, ax = search.plot() + assert len(ax.patches) == 1 + assert isinstance(ax.patches[0], SphericalCircle)