diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index 78b241c9b..6c7b20485 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -16,12 +16,26 @@ from narwhals._arrow.typing import IntoArrowExpr POLARS_TO_ARROW_AGGREGATIONS = { + "len": "count", "n_unique": "count_distinct", "std": "stddev", "var": "variance", # currently unused, we don't have `var` yet } +def get_function_name_option(function_name: str) -> Any | None: + """Map specific pyarrow compute function to respective option to match polars behaviour.""" + import pyarrow.compute as pc # ignore-banned-import + + function_name_to_options = { + "count": pc.CountOptions(mode="all"), + "count_distinct": pc.CountOptions(mode="all"), + "stddev": pc.VarianceOptions(ddof=1), + "variance": pc.VarianceOptions(ddof=1), + } + return function_name_to_options.get(function_name) + + class ArrowGroupBy: def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None: import pyarrow as pa # ignore-banned-import() @@ -119,27 +133,13 @@ def agg_arrow( function_name = remove_prefix(expr._function_name, "col->") function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name) + + option = get_function_name_option(function_name) for root_name, output_name in zip(expr._root_names, expr._output_names): - if function_name == "len": - simple_aggregations[output_name] = ( - (root_name, "count", pc.CountOptions(mode="all")), - f"{root_name}_count", - ) - elif function_name == "count_distinct": - simple_aggregations[output_name] = ( - (root_name, "count_distinct", pc.CountOptions(mode="all")), - f"{root_name}_count_distinct", - ) - elif function_name == "stddev": - simple_aggregations[output_name] = ( - (root_name, "stddev", pc.VarianceOptions(ddof=1)), - f"{root_name}_stddev", - ) - else: - simple_aggregations[output_name] = ( - (root_name, function_name), - f"{root_name}_{function_name}", - ) + simple_aggregations[output_name] = ( + (root_name, function_name, option), + f"{root_name}_{function_name}", + ) aggs: list[Any] = [] name_mapping = {}