Skip to content

Commit

Permalink
allow slicing by sequence (#380)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jul 2, 2024
1 parent 7b9eff6 commit f17530b
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 8 deletions.
13 changes: 10 additions & 3 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
from typing import overload

from narwhals._arrow.utils import translate_dtype
from narwhals._pandas_like.utils import evaluate_into_exprs
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pyarrow

if TYPE_CHECKING:
Expand Down Expand Up @@ -65,17 +67,22 @@ def __getitem__(self, item: str | slice) -> ArrowSeries | ArrowDataFrame:
return ArrowSeries(self._dataframe[item], name=item)

elif isinstance(item, slice):
from narwhals._arrow.dataframe import ArrowDataFrame

if item.step is not None and item.step != 1:
msg = "Slicing with step is not supported on PyArrow tables"
raise NotImplementedError(msg)
start = item.start or 0
stop = item.stop or len(self._dataframe)
return ArrowDataFrame(
return self._from_dataframe(
self._dataframe.slice(item.start, stop - start),
)

elif isinstance(item, Sequence) or (
(np := get_numpy()) is not None
and isinstance(item, np.ndarray)
and item.ndim == 1
):
return self._from_dataframe(self._dataframe.take(item))

else: # pragma: no cover
msg = f"Expected str or slice, got: {type(item)}"
raise TypeError(msg)
Expand Down
12 changes: 9 additions & 3 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Iterable
from typing import Iterator
from typing import Literal
from typing import Sequence
from typing import overload

from narwhals._pandas_like.expr import PandasExpr
Expand All @@ -17,13 +18,12 @@
from narwhals._pandas_like.utils import validate_indices
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
from narwhals.utils import flatten
from narwhals.utils import parse_version

if TYPE_CHECKING:
from collections.abc import Sequence

from typing_extensions import Self

from narwhals._pandas_like.group_by import PandasGroupBy
Expand Down Expand Up @@ -101,12 +101,18 @@ def __getitem__(self, item: str | slice) -> PandasSeries | PandasDataFrame:
implementation=self._implementation,
)

elif isinstance(item, slice):
elif isinstance(item, (slice, Sequence)):
from narwhals._pandas_like.dataframe import PandasDataFrame

return PandasDataFrame(
self._dataframe.iloc[item], implementation=self._implementation
)
elif (
(np := get_numpy()) is not None
and isinstance(item, np.ndarray)
and item.ndim == 1
):
return self._from_dataframe(self._dataframe.iloc[item])

else: # pragma: no cover
msg = f"Expected str or slice, got: {type(item)}"
Expand Down
12 changes: 10 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from narwhals._pandas_like.dataframe import PandasDataFrame
from narwhals.dependencies import get_cudf
from narwhals.dependencies import get_modin
from narwhals.dependencies import get_numpy
from narwhals.dependencies import get_pandas
from narwhals.dependencies import get_polars
from narwhals.dependencies import get_pyarrow
Expand Down Expand Up @@ -388,19 +389,26 @@ def shape(self) -> tuple[int, int]:
"""
return self._dataframe.shape # type: ignore[no-any-return]

@overload
def __getitem__(self, item: Sequence[int]) -> Series: ...

@overload
def __getitem__(self, item: str) -> Series: ...

@overload
def __getitem__(self, item: slice) -> DataFrame: ...

def __getitem__(self, item: str | slice) -> Series | DataFrame:
def __getitem__(self, item: str | slice | Sequence[int]) -> Series | DataFrame:
if isinstance(item, str):
from narwhals.series import Series

return Series(self._dataframe[item])

elif isinstance(item, (range, slice)):
elif isinstance(item, (Sequence, slice)) or (
(np := get_numpy()) is not None
and isinstance(item, np.ndarray)
and item.ndim == 1
):
return self._from_dataframe(self._dataframe[item])

else:
Expand Down
22 changes: 22 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from typing import Any

import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest
Expand Down Expand Up @@ -47,3 +49,23 @@ def test_slice_lazy_fails() -> None:
def test_slice_int_fails(constructor_with_pyarrow: Any) -> None:
with pytest.raises(TypeError, match="Expected str or slice, got: <class 'int'>"):
_ = nw.from_native(constructor_with_pyarrow(data))[1] # type: ignore[call-overload,index]


def test_gather(constructor_with_pyarrow: Any) -> None:
df = nw.from_native(constructor_with_pyarrow(data), eager_only=True)
result = df[[0, 3, 1]]
expected = {
"a": [1.0, 4.0, 2.0],
"b": [11, 14, 12],
}
compare_dicts(result, expected)
result = df[np.array([0, 3, 1])]
compare_dicts(result, expected)


def test_gather_pandas_index() -> None:
# check that we're slicing positionally, and not on the pandas index
df = pd.DataFrame({"a": [4, 1, 2], "b": [1, 4, 2]}, index=[2, 1, 3])
result = nw.from_native(df, eager_only=True)[[1, 2]]
expected = {"a": [1, 2], "b": [4, 2]}
compare_dicts(result, expected)

0 comments on commit f17530b

Please sign in to comment.