Skip to content

Commit

Permalink
Bug: __getitem__ check for edge cases such as subsetting 0 rows or …
Browse files Browse the repository at this point in the history
…0 columns (#994)
  • Loading branch information
raisadz authored Sep 20, 2024
1 parent a52d470 commit f0b31ee
Show file tree
Hide file tree
Showing 11 changed files with 201 additions and 77 deletions.
60 changes: 31 additions & 29 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from typing import overload

from narwhals._arrow.utils import broadcast_series
from narwhals._arrow.utils import convert_slice_to_nparray
from narwhals._arrow.utils import convert_str_slice_to_int_slice
from narwhals._arrow.utils import select_rows
from narwhals._arrow.utils import translate_dtype
from narwhals._arrow.utils import validate_dataframe_comparand
from narwhals._expression_parsing import evaluate_into_exprs
Expand All @@ -18,6 +19,7 @@
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import generate_unique_token
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
Expand Down Expand Up @@ -121,14 +123,18 @@ def __getitem__(self, item: str) -> ArrowSeries: ...
@overload
def __getitem__(self, item: slice) -> ArrowDataFrame: ...

@overload
def __getitem__(self, item: tuple[slice, slice]) -> ArrowDataFrame: ...

def __getitem__(
self,
item: str
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int]
| tuple[slice, str | int],
| tuple[slice, str | int]
| tuple[slice, slice],
) -> ArrowSeries | ArrowDataFrame:
if isinstance(item, str):
from narwhals._arrow.series import ArrowSeries
Expand All @@ -141,33 +147,19 @@ def __getitem__(
elif (
isinstance(item, tuple)
and len(item) == 2
and isinstance(item[1], (list, tuple))
and is_sequence_but_not_str(item[1])
):
if item[0] == slice(None):
selected_rows = self._native_frame
else:
range_ = convert_slice_to_nparray(
num_rows=len(self._native_frame), rows_slice=item[0]
)
selected_rows = self._native_frame.take(range_)

if len(item[1]) == 0:
# Return empty dataframe
return self._from_native_frame(self._native_frame.slice(0, 0).select([]))
selected_rows = select_rows(self._native_frame, item[0])
return self._from_native_frame(selected_rows.select(item[1]))

elif isinstance(item, tuple) and len(item) == 2:
if isinstance(item[1], slice):
columns = self.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.index(item[1].start)
if item[1].start is not None
else None
)
stop = (
columns.index(item[1].stop) + 1
if item[1].stop is not None
else None
)
step = item[1].step
start, stop, step = convert_str_slice_to_int_slice(item[1], columns)
return self._from_native_frame(
self._native_frame.take(item[0]).select(columns[start:stop:step])
)
Expand All @@ -192,11 +184,9 @@ def __getitem__(
name=col_name,
backend_version=self._backend_version,
)
range_ = convert_slice_to_nparray(
num_rows=len(self._native_frame), rows_slice=item[0]
)
selected_rows = select_rows(self._native_frame, item[0])
return ArrowSeries(
self._native_frame[col_name].take(range_),
selected_rows[col_name],
name=col_name,
backend_version=self._backend_version,
)
Expand All @@ -205,15 +195,27 @@ def __getitem__(
if item.step is not None and item.step != 1:
msg = "Slicing with step is not supported on PyArrow tables"
raise NotImplementedError(msg)
columns = self.columns
if isinstance(item.start, str) or isinstance(item.stop, str):
start, stop, step = convert_str_slice_to_int_slice(item, columns)
return self._from_native_frame(
self._native_frame.select(columns[start:stop:step])
)
start = item.start or 0
stop = item.stop or len(self._native_frame)
stop = item.stop if item.stop is not None else len(self._native_frame)
return self._from_native_frame(
self._native_frame.slice(item.start, stop - start),
self._native_frame.slice(start, stop - start),
)

elif isinstance(item, Sequence) or (is_numpy_array(item) and item.ndim == 1):
if isinstance(item, Sequence) and all(isinstance(x, str) for x in item):
if (
isinstance(item, Sequence)
and all(isinstance(x, str) for x in item)
and len(item) > 0
):
return self._from_native_frame(self._native_frame.select(item))
if isinstance(item, Sequence) and len(item) == 0:
return self._from_native_frame(self._native_frame.slice(0, 0))
return self._from_native_frame(self._native_frame.take(item))

else: # pragma: no cover
Expand Down
22 changes: 22 additions & 0 deletions narwhals/_arrow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from narwhals.utils import isinstance_or_issubclass

if TYPE_CHECKING:
import pyarrow as pa

from narwhals._arrow.series import ArrowSeries


Expand Down Expand Up @@ -286,3 +288,23 @@ def convert_slice_to_nparray(
return np.arange(num_rows)[rows_slice]
else:
return rows_slice


def select_rows(table: pa.Table, rows: Any) -> pa.Table:
if isinstance(rows, slice) and rows == slice(None):
selected_rows = table
elif isinstance(rows, Sequence) and not rows:
selected_rows = table.slice(0, 0)
else:
range_ = convert_slice_to_nparray(num_rows=len(table), rows_slice=rows)
selected_rows = table.take(range_)
return selected_rows


def convert_str_slice_to_int_slice(
str_slice: slice, columns: list[str]
) -> tuple[int | None, int | None, int | None]:
start = columns.index(str_slice.start) if str_slice.start is not None else None
stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None
step = str_slice.step
return (start, stop, step)
43 changes: 26 additions & 17 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from narwhals._expression_parsing import evaluate_into_exprs
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.utils import broadcast_series
from narwhals._pandas_like.utils import convert_str_slice_to_int_slice
from narwhals._pandas_like.utils import create_native_series
from narwhals._pandas_like.utils import horizontal_concat
from narwhals._pandas_like.utils import translate_dtype
Expand All @@ -22,6 +23,7 @@
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import generate_unique_token
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
Expand Down Expand Up @@ -119,14 +121,18 @@ def __getitem__(self, item: Sequence[str]) -> PandasLikeDataFrame: ...
@overload
def __getitem__(self, item: slice) -> PandasLikeDataFrame: ...

@overload
def __getitem__(self, item: tuple[slice, slice]) -> Self: ...

def __getitem__(
self,
item: str
| int
| slice
| Sequence[int]
| Sequence[str]
| tuple[Sequence[int], str | int],
| tuple[Sequence[int], str | int]
| tuple[slice, slice],
) -> PandasLikeSeries | PandasLikeDataFrame:
if isinstance(item, str):
from narwhals._pandas_like.series import PandasLikeSeries
Expand All @@ -140,16 +146,19 @@ def __getitem__(
elif (
isinstance(item, tuple)
and len(item) == 2
and isinstance(item[1], (tuple, list))
and is_sequence_but_not_str(item[1])
):
if len(item[1]) == 0:
# Return empty dataframe
return self._from_native_frame(self._native_frame.__class__())
if all(isinstance(x, int) for x in item[1]):
return self._from_native_frame(self._native_frame.iloc[item])
if all(isinstance(x, str) for x in item[1]):
item = (
indexer = (
item[0],
self._native_frame.columns.get_indexer(item[1]),
)
return self._from_native_frame(self._native_frame.iloc[item])
return self._from_native_frame(self._native_frame.iloc[indexer])
msg = (
f"Expected sequence str or int, got: {type(item[1])}" # pragma: no cover
)
Expand All @@ -158,15 +167,7 @@ def __getitem__(
elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
columns = self._native_frame.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.get_loc(item[1].start) if item[1].start is not None else None
)
stop = (
columns.get_loc(item[1].stop) + 1
if item[1].stop is not None
else None
)
step = item[1].step
start, stop, step = convert_str_slice_to_int_slice(item[1], columns)
return self._from_native_frame(
self._native_frame.iloc[item[0], slice(start, stop, step)]
)
Expand Down Expand Up @@ -197,13 +198,21 @@ def __getitem__(
backend_version=self._backend_version,
)

elif isinstance(item, (slice, Sequence)) or (
is_numpy_array(item) and item.ndim == 1
):
if isinstance(item, Sequence) and all(isinstance(x, str) for x in item):
elif is_sequence_but_not_str(item) or (is_numpy_array(item) and item.ndim == 1):
if all(isinstance(x, str) for x in item) and len(item) > 0:
return self._from_native_frame(self._native_frame.loc[:, item])
return self._from_native_frame(self._native_frame.iloc[item])

elif isinstance(item, slice):
if isinstance(item.start, str) or isinstance(item.stop, str):
start, stop, step = convert_str_slice_to_int_slice(
item, self._native_frame.columns
)
return self._from_native_frame(
self._native_frame.iloc[:, slice(start, stop, step)]
)
return self._from_native_frame(self._native_frame.iloc[item])

else: # pragma: no cover
msg = f"Expected str or slice, got: {type(item)}"
raise TypeError(msg)
Expand Down
10 changes: 10 additions & 0 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from narwhals.dtypes import DType

ExprT = TypeVar("ExprT", bound=PandasLikeExpr)
import pandas as pd


def validate_column_comparand(index: Any, other: Any) -> Any:
Expand Down Expand Up @@ -497,3 +498,12 @@ def int_dtype_mapper(dtype: Any) -> str:
if str(dtype).lower() != str(dtype): # pragma: no cover
return "Int64"
return "int64"


def convert_str_slice_to_int_slice(
str_slice: slice, columns: pd.Index
) -> tuple[int | None, int | None, int | None]:
start = columns.get_loc(str_slice.start) if str_slice.start is not None else None
stop = columns.get_loc(str_slice.stop) + 1 if str_slice.stop is not None else None
step = str_slice.step
return (start, stop, step)
68 changes: 42 additions & 26 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
from typing import Any

from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import convert_str_slice_to_int_slice
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import translate_dtype
from narwhals.dependencies import get_polars
from narwhals.utils import Implementation
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
Expand Down Expand Up @@ -84,38 +86,52 @@ def shape(self) -> tuple[int, int]:
return self._native_frame.shape # type: ignore[no-any-return]

def __getitem__(self, item: Any) -> Any:
if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
if self._backend_version >= (0, 20, 30):
return self._from_native_object(self._native_frame.__getitem__(item))
else: # pragma: no cover
# TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
# Polars version we support
columns = self.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.index(item[1].start) if item[1].start is not None else None
)
stop = (
columns.index(item[1].stop) + 1 if item[1].stop is not None else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.select(columns[start:stop:step]).__getitem__(
item[0]
if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start, stop, step = convert_str_slice_to_int_slice(item[1], columns)
return self._from_native_frame(
self._native_frame.select(columns[start:stop:step]).__getitem__(
item[0]
)
)
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.select(
columns[item[1].start : item[1].stop : item[1].step]
).__getitem__(item[0])
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover
pl = get_polars()
if (
isinstance(item, tuple)
and (len(item) == 2)
and is_sequence_but_not_str(item[1])
and (len(item[1]) == 0)
):
result = self._native_frame.select(item[1])
elif isinstance(item, slice) and (
isinstance(item.start, str) or isinstance(item.stop, str)
):
start, stop, step = convert_str_slice_to_int_slice(item, columns)
return self._from_native_frame(
self._native_frame.select(
columns[item[1].start : item[1].stop : item[1].step]
).__getitem__(item[0])
self._native_frame.select(columns[start:stop:step])
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover
pl = get_polars()
result = self._native_frame.__getitem__(item)
if isinstance(result, pl.Series):
from narwhals._polars.series import PolarsSeries

return PolarsSeries(result, backend_version=self._backend_version)
return self._from_native_object(result)
elif is_sequence_but_not_str(item) and (len(item) == 0):
result = self._native_frame.slice(0, 0)
else:
result = self._native_frame.__getitem__(item)
if isinstance(result, pl.Series):
from narwhals._polars.series import PolarsSeries

return PolarsSeries(result, backend_version=self._backend_version)
return self._from_native_object(result)

def get_column(self, name: str) -> Any:
from narwhals._polars.series import PolarsSeries
Expand Down
9 changes: 9 additions & 0 deletions narwhals/_polars/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,12 @@ def narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any:
if dtype == dtypes.Date:
return pl.Date()
return pl.Unknown() # pragma: no cover


def convert_str_slice_to_int_slice(
str_slice: slice, columns: list[str]
) -> tuple[int | None, int | None, int | None]: # pragma: no cover
start = columns.index(str_slice.start) if str_slice.start is not None else None
stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None
step = str_slice.step
return (start, stop, step)
Loading

0 comments on commit f0b31ee

Please sign in to comment.