diff --git a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py index adcdd56f..4119c46a 100644 --- a/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py +++ b/src/lsdb/core/crossmatch/abstract_crossmatch_algorithm.py @@ -63,7 +63,7 @@ def crossmatch(self) -> pd.DataFrame: """Perform a crossmatch""" # pylint: disable=unused-argument - def validate(self, *args, **kwargs): + def validate(self): """Validate the metadata and arguments. This method will be called **once**, after the algorithm object has diff --git a/src/lsdb/core/crossmatch/bounded_kdtree_match.py b/src/lsdb/core/crossmatch/bounded_kdtree_match.py index 4c24d701..0feee048 100644 --- a/src/lsdb/core/crossmatch/bounded_kdtree_match.py +++ b/src/lsdb/core/crossmatch/bounded_kdtree_match.py @@ -16,7 +16,6 @@ def validate( radius_arcsec: float = 1, require_right_margin: bool = True, min_radius_arcsec: float = 0, - **kwargs, ): super().validate(n_neighbors, radius_arcsec, require_right_margin) if min_radius_arcsec < 0: @@ -28,8 +27,9 @@ def crossmatch( self, n_neighbors: int = 1, radius_arcsec: float = 1, + # We need it here because the signature is shared with .validate() + require_right_margin: bool = True, # pylint: disable=unused-argument min_radius_arcsec: float = 0, - **kwargs, ) -> pd.DataFrame: """Perform a cross-match between the data from two HEALPix pixels diff --git a/src/lsdb/core/crossmatch/kdtree_match.py b/src/lsdb/core/crossmatch/kdtree_match.py index ddb4ee26..f32d7ea2 100644 --- a/src/lsdb/core/crossmatch/kdtree_match.py +++ b/src/lsdb/core/crossmatch/kdtree_match.py @@ -17,13 +17,11 @@ class KdTreeCrossmatch(AbstractCrossmatchAlgorithm): extra_columns = pd.DataFrame({"_dist_arcsec": pd.Series(dtype=np.dtype("float64"))}) - # pylint: disable=unused-argument,arguments-differ def validate( self, n_neighbors: int = 1, radius_arcsec: float = 1, require_right_margin=True, - **kwargs, ): super().validate() # Validate radius @@ -39,12 +37,12 @@ def validate( if self.right_margin_hc_structure.catalog_info.margin_threshold < radius_arcsec: raise ValueError("Cross match radius is greater than margin threshold") - # pylint: disable=unused-argument def crossmatch( self, n_neighbors: int = 1, radius_arcsec: float = 1, - **kwargs, + # We need it here because the signature is shared with .validate() + require_right_margin=True, # pylint: disable=unused-argument ) -> pd.DataFrame: """Perform a cross-match between the data from two HEALPix pixels diff --git a/tests/lsdb/catalog/test_crossmatch.py b/tests/lsdb/catalog/test_crossmatch.py index 246eaecc..ca60f7f1 100644 --- a/tests/lsdb/catalog/test_crossmatch.py +++ b/tests/lsdb/catalog/test_crossmatch.py @@ -226,6 +226,10 @@ class MockCrossmatchAlgorithm(AbstractCrossmatchAlgorithm): extra_columns = pd.DataFrame({"_DIST": pd.Series(dtype=np.dtype("float64"))}) + # We must have the same signature as the crossmatch method + def validate(self, mock_results: pd.DataFrame = None): # pylint: disable=unused-argument + super().validate() + def crossmatch(self, mock_results: pd.DataFrame = None): left_reset = self.left.reset_index(drop=True) right_reset = self.right.reset_index(drop=True) @@ -293,3 +297,8 @@ def test_append_extra_columns(small_sky_xmatch_catalog): algo.extra_columns = None algo._append_extra_columns(xmatch_df, pd.DataFrame(extra_columns)) assert "_DIST" not in xmatch_df.columns + + +def test_raise_for_unknown_kwargs(small_sky_catalog): + with pytest.raises(TypeError, match="unexpected keyword argument"): + small_sky_catalog.crossmatch(small_sky_catalog, unknown_kwarg="value")