From f17530bf307beb27de33572f71fe6dd62fa55df3 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Tue, 2 Jul 2024 10:35:36 +0100 Subject: [PATCH] allow slicing by sequence (#380) --- narwhals/_arrow/dataframe.py | 13 ++++++++++--- narwhals/_pandas_like/dataframe.py | 12 +++++++++--- narwhals/dataframe.py | 12 ++++++++++-- tests/frame/slice_test.py | 22 ++++++++++++++++++++++ 4 files changed, 51 insertions(+), 8 deletions(-) diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 9cf65e722..c3f2c0fbd 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -2,10 +2,12 @@ from typing import TYPE_CHECKING from typing import Any +from typing import Sequence from typing import overload from narwhals._arrow.utils import translate_dtype from narwhals._pandas_like.utils import evaluate_into_exprs +from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pyarrow if TYPE_CHECKING: @@ -65,17 +67,22 @@ def __getitem__(self, item: str | slice) -> ArrowSeries | ArrowDataFrame: return ArrowSeries(self._dataframe[item], name=item) elif isinstance(item, slice): - from narwhals._arrow.dataframe import ArrowDataFrame - if item.step is not None and item.step != 1: msg = "Slicing with step is not supported on PyArrow tables" raise NotImplementedError(msg) start = item.start or 0 stop = item.stop or len(self._dataframe) - return ArrowDataFrame( + return self._from_dataframe( self._dataframe.slice(item.start, stop - start), ) + elif isinstance(item, Sequence) or ( + (np := get_numpy()) is not None + and isinstance(item, np.ndarray) + and item.ndim == 1 + ): + return self._from_dataframe(self._dataframe.take(item)) + else: # pragma: no cover msg = f"Expected str or slice, got: {type(item)}" raise TypeError(msg) diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index c58f08ce4..dab477b07 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -6,6 +6,7 @@ from typing import Iterable from typing import Iterator from typing import Literal +from typing import Sequence from typing import overload from narwhals._pandas_like.expr import PandasExpr @@ -17,13 +18,12 @@ from narwhals._pandas_like.utils import validate_indices from narwhals.dependencies import get_cudf from narwhals.dependencies import get_modin +from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas from narwhals.utils import flatten from narwhals.utils import parse_version if TYPE_CHECKING: - from collections.abc import Sequence - from typing_extensions import Self from narwhals._pandas_like.group_by import PandasGroupBy @@ -101,12 +101,18 @@ def __getitem__(self, item: str | slice) -> PandasSeries | PandasDataFrame: implementation=self._implementation, ) - elif isinstance(item, slice): + elif isinstance(item, (slice, Sequence)): from narwhals._pandas_like.dataframe import PandasDataFrame return PandasDataFrame( self._dataframe.iloc[item], implementation=self._implementation ) + elif ( + (np := get_numpy()) is not None + and isinstance(item, np.ndarray) + and item.ndim == 1 + ): + return self._from_dataframe(self._dataframe.iloc[item]) else: # pragma: no cover msg = f"Expected str or slice, got: {type(item)}" diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 9e770774e..280c85e3f 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -13,6 +13,7 @@ from narwhals._pandas_like.dataframe import PandasDataFrame from narwhals.dependencies import get_cudf from narwhals.dependencies import get_modin +from narwhals.dependencies import get_numpy from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow @@ -388,19 +389,26 @@ def shape(self) -> tuple[int, int]: """ return self._dataframe.shape # type: ignore[no-any-return] + @overload + def __getitem__(self, item: Sequence[int]) -> Series: ... + @overload def __getitem__(self, item: str) -> Series: ... @overload def __getitem__(self, item: slice) -> DataFrame: ... - def __getitem__(self, item: str | slice) -> Series | DataFrame: + def __getitem__(self, item: str | slice | Sequence[int]) -> Series | DataFrame: if isinstance(item, str): from narwhals.series import Series return Series(self._dataframe[item]) - elif isinstance(item, (range, slice)): + elif isinstance(item, (Sequence, slice)) or ( + (np := get_numpy()) is not None + and isinstance(item, np.ndarray) + and item.ndim == 1 + ): return self._from_dataframe(self._dataframe[item]) else: diff --git a/tests/frame/slice_test.py b/tests/frame/slice_test.py index 5c47f2872..83ed31fe5 100644 --- a/tests/frame/slice_test.py +++ b/tests/frame/slice_test.py @@ -1,5 +1,7 @@ from typing import Any +import numpy as np +import pandas as pd import polars as pl import pyarrow as pa import pytest @@ -47,3 +49,23 @@ def test_slice_lazy_fails() -> None: def test_slice_int_fails(constructor_with_pyarrow: Any) -> None: with pytest.raises(TypeError, match="Expected str or slice, got: "): _ = nw.from_native(constructor_with_pyarrow(data))[1] # type: ignore[call-overload,index] + + +def test_gather(constructor_with_pyarrow: Any) -> None: + df = nw.from_native(constructor_with_pyarrow(data), eager_only=True) + result = df[[0, 3, 1]] + expected = { + "a": [1.0, 4.0, 2.0], + "b": [11, 14, 12], + } + compare_dicts(result, expected) + result = df[np.array([0, 3, 1])] + compare_dicts(result, expected) + + +def test_gather_pandas_index() -> None: + # check that we're slicing positionally, and not on the pandas index + df = pd.DataFrame({"a": [4, 1, 2], "b": [1, 4, 2]}, index=[2, 1, 3]) + result = nw.from_native(df, eager_only=True)[[1, 2]] + expected = {"a": [1, 2], "b": [4, 2]} + compare_dicts(result, expected)