diff --git a/narwhals/pandas_like/group_by.py b/narwhals/pandas_like/group_by.py index b2bbe37bc..0170f5cde 100644 --- a/narwhals/pandas_like/group_by.py +++ b/narwhals/pandas_like/group_by.py @@ -71,7 +71,7 @@ def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame: ) -def agg_pandas( # noqa: PLR0913,PLR0915 +def agg_pandas( # noqa: PLR0913 grouped: Any, exprs: list[PandasExpr], keys: list[str], @@ -89,32 +89,31 @@ def agg_pandas( # noqa: PLR0913,PLR0915 from narwhals.pandas_like.namespace import PandasNamespace - simple_aggs = [] - complex_aggs = [] + all_simple_aggs = True for expr in exprs: - if is_simple_aggregation(expr): - simple_aggs.append(expr) - else: - complex_aggs.append(expr) - simple_aggregations = {} - for expr in simple_aggs: - if expr._depth == 0: - # e.g. agg(pl.len()) - assert expr._output_names is not None - for output_name in expr._output_names: - simple_aggregations[output_name] = ( - keys[0], - expr._function_name.replace("len", "size"), - ) - continue + if not is_simple_aggregation(expr): + all_simple_aggs = False + break - assert expr._root_names is not None - assert expr._output_names is not None - for root_name, output_name in zip(expr._root_names, expr._output_names): - name = remove_prefix(expr._function_name, "col->") - simple_aggregations[output_name] = (root_name, name) + if all_simple_aggs: + simple_aggregations: dict[str, tuple[str, str]] = {} + for expr in exprs: + if expr._depth == 0: + # e.g. agg(pl.len()) + assert expr._output_names is not None + for output_name in expr._output_names: + simple_aggregations[output_name] = ( + keys[0], + expr._function_name.replace("len", "size"), + ) + continue + + assert expr._root_names is not None + assert expr._output_names is not None + for root_name, output_name in zip(expr._root_names, expr._output_names): + name = remove_prefix(expr._function_name, "col->") + simple_aggregations[output_name] = (root_name, name) - if simple_aggregations: aggs = collections.defaultdict(list) name_mapping = {} for output_name, named_agg in simple_aggregations.items(): @@ -128,49 +127,37 @@ def agg_pandas( # noqa: PLR0913,PLR0915 ) from exc result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns] result_simple = result_simple.rename(columns=name_mapping).reset_index() - else: - result_simple = None + return from_dataframe(result_simple.loc[:, output_names]) + + warnings.warn( + "Found complex group-by expression, which can't be expressed efficiently with the " + "pandas API. If you can, please rewrite your query such that group-by aggregations " + "are simple (e.g. mean, std, min, max, ...).", + UserWarning, + stacklevel=2, + ) plx = PandasNamespace(implementation=implementation) def func(df: Any) -> Any: out_group = [] out_names = [] - for expr in complex_aggs: + for expr in exprs: results_keys = expr._call(from_dataframe(df)) for result_keys in results_keys: out_group.append(item(result_keys._series)) out_names.append(result_keys.name) return plx.make_native_series(name="", data=out_group, index=out_names) - if complex_aggs: - warnings.warn( - "Found complex group-by expression, which can't be expressed efficiently with the " - "pandas API. If you can, please rewrite your query such that group-by aggregations " - "are simple (e.g. mean, std, min, max, ...).", - UserWarning, - stacklevel=2, - ) - if implementation == "pandas": - import pandas as pd - - if parse_version(pd.__version__) < parse_version("2.2.0"): # pragma: no cover - result_complex = grouped.apply(func) - else: - result_complex = grouped.apply(func, include_groups=False) - else: # pragma: no cover + if implementation == "pandas": + import pandas as pd + + if parse_version(pd.__version__) < parse_version("2.2.0"): # pragma: no cover result_complex = grouped.apply(func) + else: + result_complex = grouped.apply(func, include_groups=False) + else: # pragma: no cover + result_complex = grouped.apply(func) - if result_simple is not None and not complex_aggs: - result = result_simple - elif result_simple is not None and complex_aggs: - result = pd.concat( - [result_simple, result_complex.reset_index(drop=True)], - axis=1, - copy=False, - ) - elif complex_aggs: - result = result_complex.reset_index() - else: - raise AssertionError("At least one aggregation should have been passed") + result = result_complex.reset_index() return from_dataframe(result.loc[:, output_names])