Skip to content

Commit

Permalink
refactor pyarrow (#931)
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Sep 9, 2024
1 parent 767fbfb commit 06f7b87
Showing 1 changed file with 20 additions and 20 deletions.
40 changes: 20 additions & 20 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,26 @@
from narwhals._arrow.typing import IntoArrowExpr

POLARS_TO_ARROW_AGGREGATIONS = {
"len": "count",
"n_unique": "count_distinct",
"std": "stddev",
"var": "variance", # currently unused, we don't have `var` yet
}


def get_function_name_option(function_name: str) -> Any | None:
"""Map specific pyarrow compute function to respective option to match polars behaviour."""
import pyarrow.compute as pc # ignore-banned-import

function_name_to_options = {
"count": pc.CountOptions(mode="all"),
"count_distinct": pc.CountOptions(mode="all"),
"stddev": pc.VarianceOptions(ddof=1),
"variance": pc.VarianceOptions(ddof=1),
}
return function_name_to_options.get(function_name)


class ArrowGroupBy:
def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None:
import pyarrow as pa # ignore-banned-import()
Expand Down Expand Up @@ -119,27 +133,13 @@ def agg_arrow(

function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_ARROW_AGGREGATIONS.get(function_name, function_name)

option = get_function_name_option(function_name)
for root_name, output_name in zip(expr._root_names, expr._output_names):
if function_name == "len":
simple_aggregations[output_name] = (
(root_name, "count", pc.CountOptions(mode="all")),
f"{root_name}_count",
)
elif function_name == "count_distinct":
simple_aggregations[output_name] = (
(root_name, "count_distinct", pc.CountOptions(mode="all")),
f"{root_name}_count_distinct",
)
elif function_name == "stddev":
simple_aggregations[output_name] = (
(root_name, "stddev", pc.VarianceOptions(ddof=1)),
f"{root_name}_stddev",
)
else:
simple_aggregations[output_name] = (
(root_name, function_name),
f"{root_name}_{function_name}",
)
simple_aggregations[output_name] = (
(root_name, function_name, option),
f"{root_name}_{function_name}",
)

aggs: list[Any] = []
name_mapping = {}
Expand Down

0 comments on commit 06f7b87

Please sign in to comment.