Skip to content

Commit

Permalink
only ever do a single groupby
Browse files: browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Mar 27, 2024
1 parent 66594aa commit 95661c2
Showing 1 changed file with 42 additions and 55 deletions.
97 changes: 42 additions & 55 deletions narwhals/pandas_like/group_by.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -71,7 +71,7 @@ def _from_dataframe(self, df: PandasDataFrame) -> PandasDataFrame:
)


def agg_pandas( # noqa: PLR0913,PLR0915
def agg_pandas( # noqa: PLR0913
grouped: Any,
exprs: list[PandasExpr],
keys: list[str],
Expand All @@ -89,32 +89,31 @@ def agg_pandas( # noqa: PLR0913,PLR0915

from narwhals.pandas_like.namespace import PandasNamespace

simple_aggs = []
complex_aggs = []
all_simple_aggs = True
for expr in exprs:
if is_simple_aggregation(expr):
simple_aggs.append(expr)
else:
complex_aggs.append(expr)
simple_aggregations = {}
for expr in simple_aggs:
if expr._depth == 0:
# e.g. agg(pl.len())
assert expr._output_names is not None
for output_name in expr._output_names:
simple_aggregations[output_name] = (
keys[0],
expr._function_name.replace("len", "size"),
)
continue
if not is_simple_aggregation(expr):
all_simple_aggs = False
break

assert expr._root_names is not None
assert expr._output_names is not None
for root_name, output_name in zip(expr._root_names, expr._output_names):
name = remove_prefix(expr._function_name, "col->")
simple_aggregations[output_name] = (root_name, name)
if all_simple_aggs:
simple_aggregations: dict[str, tuple[str, str]] = {}
for expr in exprs:
if expr._depth == 0:
# e.g. agg(pl.len())
assert expr._output_names is not None
for output_name in expr._output_names:
simple_aggregations[output_name] = (
keys[0],
expr._function_name.replace("len", "size"),
)
continue

assert expr._root_names is not None
assert expr._output_names is not None
for root_name, output_name in zip(expr._root_names, expr._output_names):
name = remove_prefix(expr._function_name, "col->")
simple_aggregations[output_name] = (root_name, name)

if simple_aggregations:
aggs = collections.defaultdict(list)
name_mapping = {}
for output_name, named_agg in simple_aggregations.items():
Expand All @@ -128,49 +127,37 @@ def agg_pandas( # noqa: PLR0913,PLR0915
) from exc
result_simple.columns = [f"{a}_{b}" for a, b in result_simple.columns]
result_simple = result_simple.rename(columns=name_mapping).reset_index()
else:
result_simple = None
return from_dataframe(result_simple.loc[:, output_names])

warnings.warn(
"Found complex group-by expression, which can't be expressed efficiently with the "
"pandas API. If you can, please rewrite your query such that group-by aggregations "
"are simple (e.g. mean, std, min, max, ...).",
UserWarning,
stacklevel=2,
)

plx = PandasNamespace(implementation=implementation)

def func(df: Any) -> Any:
out_group = []
out_names = []
for expr in complex_aggs:
for expr in exprs:
results_keys = expr._call(from_dataframe(df))
for result_keys in results_keys:
out_group.append(item(result_keys._series))
out_names.append(result_keys.name)
return plx.make_native_series(name="", data=out_group, index=out_names)

if complex_aggs:
warnings.warn(
"Found complex group-by expression, which can't be expressed efficiently with the "
"pandas API. If you can, please rewrite your query such that group-by aggregations "
"are simple (e.g. mean, std, min, max, ...).",
UserWarning,
stacklevel=2,
)
if implementation == "pandas":
import pandas as pd

if parse_version(pd.__version__) < parse_version("2.2.0"): # pragma: no cover
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)
else: # pragma: no cover
if implementation == "pandas":
import pandas as pd

if parse_version(pd.__version__) < parse_version("2.2.0"): # pragma: no cover
result_complex = grouped.apply(func)
else:
result_complex = grouped.apply(func, include_groups=False)
else: # pragma: no cover
result_complex = grouped.apply(func)

if result_simple is not None and not complex_aggs:
result = result_simple
elif result_simple is not None and complex_aggs:
result = pd.concat(
[result_simple, result_complex.reset_index(drop=True)],
axis=1,
copy=False,
)
elif complex_aggs:
result = result_complex.reset_index()
else:
raise AssertionError("At least one aggregation should have been passed")
result = result_complex.reset_index()
return from_dataframe(result.loc[:, output_names])

0 comments on commit 95661c2

Please sign in to comment.