From 5bad66101779c020bce2b712626c0d5ada86d1b3 Mon Sep 17 00:00:00 2001
From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com>
Date: Wed, 20 Mar 2024 18:29:38 +0000
Subject: [PATCH] never reset index

---
 narwhals/pandas_like/dataframe.py | 4 ++--
 narwhals/pandas_like/utils.py     | 9 +++------
 tests/test_common.py              | 4 ++--
 3 files changed, 7 insertions(+), 10 deletions(-)

diff --git a/narwhals/pandas_like/dataframe.py b/narwhals/pandas_like/dataframe.py
index 1724a5395..c2a2c2796 100644
--- a/narwhals/pandas_like/dataframe.py
+++ b/narwhals/pandas_like/dataframe.py
@@ -8,9 +8,9 @@
 
 from narwhals.pandas_like.utils import evaluate_into_exprs
 from narwhals.pandas_like.utils import horizontal_concat
-from narwhals.pandas_like.utils import maybe_reset_indices
 from narwhals.pandas_like.utils import translate_dtype
 from narwhals.pandas_like.utils import validate_dataframe_comparand
+from narwhals.pandas_like.utils import validate_indices
 from narwhals.utils import flatten_str
 
 if TYPE_CHECKING:
@@ -87,7 +87,7 @@ def select(
         **named_exprs: IntoPandasExpr,
     ) -> Self:
         new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
-        new_series = maybe_reset_indices(new_series)
+        new_series = validate_indices(new_series)
         df = horizontal_concat(
             [series._series for series in new_series],
             implementation=self._implementation,
diff --git a/narwhals/pandas_like/utils.py b/narwhals/pandas_like/utils.py
index 6c870506d..6c2c507b7 100644
--- a/narwhals/pandas_like/utils.py
+++ b/narwhals/pandas_like/utils.py
@@ -383,13 +383,10 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
     raise TypeError(msg)
 
 
-def maybe_reset_indices(series: list[PandasSeries]) -> list[PandasSeries]:
+def validate_indices(series: list[PandasSeries]) -> list[PandasSeries]:
     idx = series[0]._series.index
-    found_non_matching_index = False
     for s in series[1:]:
         if s._series.index is not idx and not (s._series.index == idx).all():
-            found_non_matching_index = True
-            break
-    if found_non_matching_index:
-        return [s._from_series(s._series.reset_index(drop=True)) for s in series]
+            msg = "Found implicit index alignment, which is not allowed by Narwhals."
+            raise RuntimeError(msg)
     return series
diff --git a/tests/test_common.py b/tests/test_common.py
index 011afaf11..e0f072888 100644
--- a/tests/test_common.py
+++ b/tests/test_common.py
@@ -328,8 +328,8 @@ def test_expr_min_max(df_raw: Any) -> None:
 @pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd, df_lazy])
 def test_expr_sample(df_raw: Any) -> None:
     df = nw.LazyFrame(df_raw)
-    result_shape = nw.to_native(df.select(nw.col("a", "b").sample(n=2)).collect()).shape
-    expected = (2, 2)
+    result_shape = nw.to_native(df.select(nw.col("a").sample(n=2)).collect()).shape
+    expected = (2, 1)
     assert result_shape == expected
 
 