Skip to content

Commit

Permalink
chore: factor out check_columns_exist (#1792)
Browse files Browse the repository at this point in the history
* adding check column exists

* add test

* remove str from _spark_like
  • Loading branch information
DeaMariaLeon authored Jan 11, 2025
1 parent 346724e commit 1f22a1d
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 20 deletions.
6 changes: 2 additions & 4 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from narwhals._arrow.utils import validate_dataframe_comparand
from narwhals._expression_parsing import evaluate_into_exprs
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import flatten
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import is_sequence_but_not_str
Expand Down Expand Up @@ -642,9 +642,7 @@ def unique(
import pyarrow.compute as pc

df = self._native_frame
if subset is not None and any(x not in self.columns for x in subset):
msg = f"Column(s) {subset} not found in {self.columns}"
raise ColumnNotFoundError(msg)
check_column_exists(self.columns, subset)
subset = subset or self.columns

if keep in {"any", "first", "last"}:
Expand Down
6 changes: 2 additions & 4 deletions narwhals/_dask/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,9 @@
from narwhals._dask.utils import parse_exprs_and_named_exprs
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import CompliantLazyFrame
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import flatten
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import parse_columns_to_drop
Expand Down Expand Up @@ -198,9 +198,7 @@ def unique(
*,
keep: Literal["any", "none"] = "any",
) -> Self:
if subset is not None and any(x not in self.columns for x in subset):
msg = f"Column(s) {subset} not found in {self.columns}"
raise ColumnNotFoundError(msg)
check_column_exists(self.columns, subset)
native_frame = self._native_frame
if keep == "none":
subset = subset or self.columns
Expand Down
6 changes: 2 additions & 4 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._pandas_like.utils import validate_dataframe_comparand
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import flatten
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import import_dtypes_module
Expand Down Expand Up @@ -695,9 +695,7 @@ def unique(
# The param `maintain_order` is only here for compatibility with the Polars API
# and has no effect on the output.
mapped_keep = {"none": False, "any": "first"}.get(keep, keep)
if subset is not None and any(x not in self.columns for x in subset):
msg = f"Column(s) {subset} not found in {self.columns}"
raise ColumnNotFoundError(msg)
check_column_exists(self.columns, subset)
return self._from_native_frame(
self._native_frame.drop_duplicates(subset=subset, keep=mapped_keep)
)
Expand Down
11 changes: 3 additions & 8 deletions narwhals/_spark_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@

from narwhals._spark_like.utils import native_to_narwhals_dtype
from narwhals._spark_like.utils import parse_exprs_and_named_exprs
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import check_column_exists
from narwhals.utils import flatten
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
Expand Down Expand Up @@ -200,19 +200,14 @@ def rename(self: Self, mapping: dict[str, str]) -> Self:

def unique(
self: Self,
subset: str | list[str] | None = None,
subset: list[str] | None = None,
*,
keep: Literal["any", "none"],
) -> Self:
if keep != "any":
msg = "`LazyFrame.unique` with PySpark backend only supports `keep='any'`."
raise ValueError(msg)

if subset is not None and any(x not in self.columns for x in subset):
msg = f"Column(s) {subset} not found in {self.columns}"
raise ColumnNotFoundError(msg)

subset = [subset] if isinstance(subset, str) else subset
check_column_exists(self.columns, subset)
return self._from_native_frame(self._native_frame.dropDuplicates(subset=subset))

def join(
Expand Down
6 changes: 6 additions & 0 deletions narwhals/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1059,3 +1059,9 @@ def generate_repr(header: str, native_repr: str) -> str:
"| Use `.to_native` to see native output |\n└"
f"{'─' * 39}┘"
)


def check_column_exists(columns: list[str], subset: list[str] | None) -> None:
if subset is not None and any(x not in columns for x in subset):
msg = f"Column(s) {subset} not found in {columns}"
raise ColumnNotFoundError(msg)
13 changes: 13 additions & 0 deletions tests/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import re
import string
from typing import TYPE_CHECKING

Expand All @@ -13,6 +14,8 @@
from pandas.testing import assert_series_equal

import narwhals.stable.v1 as nw
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import check_column_exists
from narwhals.utils import parse_version
from tests.utils import PANDAS_VERSION
from tests.utils import get_module_version_as_tuple
Expand Down Expand Up @@ -284,3 +287,13 @@ def test_generate_temporary_column_name_raise() -> None:
)
def test_parse_version(version: str, expected: tuple[int, ...]) -> None:
assert parse_version(version) == expected


def test_check_column_exists() -> None:
columns = ["a", "b", "c"]
subset = ["a", "d"]
with pytest.raises(
ColumnNotFoundError,
match=re.escape("Column(s) ['a', 'd'] not found in ['a', 'b', 'c']"),
):
check_column_exists(columns, subset)

0 comments on commit 1f22a1d

Please sign in to comment.