Files
Buffteks-Website/streamlit-venv/lib/python3.10/site-packages/narwhals/_arrow/group_by.py
2025-01-10 21:40:35 +00:00

170 lines
6.0 KiB
Python
Executable File

from __future__ import annotations
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterator
from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals.utils import remove_prefix
if TYPE_CHECKING:
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import IntoArrowExpr
# Translation table from polars/narwhals aggregation names to the names of
# the corresponding pyarrow hash-aggregation functions. Names missing from
# this table are looked up with themselves as the fallback, i.e. they are
# assumed to be spelled identically in both libraries.
POLARS_TO_ARROW_AGGREGATIONS: dict[str, str] = {
    "len": "count",
    "n_unique": "count_distinct",
    "std": "stddev",
    # Per the upstream note, narwhals does not expose `var` yet, so this
    # entry is currently unused.
    "var": "variance",
}
def get_function_name_option(function_name: str) -> Any | None:
    """Map specific pyarrow compute function to respective option to match polars behaviour.

    Parameters
    ----------
    function_name:
        Name of a pyarrow hash-aggregation function (in this module it is
        called with names already translated via
        ``POLARS_TO_ARROW_AGGREGATIONS``).

    Returns
    -------
    A freshly built options object for the functions that need non-default
    settings, or ``None`` when pyarrow's defaults should be used.
    """
    import pyarrow.compute as pc  # ignore-banned-import

    if function_name in ("count", "count_distinct"):
        # mode="all": null entries are counted too instead of being skipped.
        return pc.CountOptions(mode="all")
    if function_name in ("stddev", "variance"):
        # ddof=1: sample statistics (divide by n - 1).
        return pc.VarianceOptions(ddof=1)
    return None
class ArrowGroupBy:
    """Group-by over an ``ArrowDataFrame`` backed by ``pyarrow.TableGroupBy``."""

    def __init__(
        self, df: ArrowDataFrame, keys: list[str], *, drop_null_keys: bool
    ) -> None:
        """Prepare a grouping of *df* by the columns named in *keys*.

        Parameters
        ----------
        df:
            The frame to group.
        keys:
            Names of the grouping columns.
        drop_null_keys:
            When true, the frame is first passed through
            ``df.drop_nulls(keys)`` so rows with null keys are discarded.
        """
        import pyarrow as pa  # ignore-banned-import()

        self._df = df.drop_nulls(keys) if drop_null_keys else df
        self._keys = list(keys)
        self._grouped = pa.TableGroupBy(self._df._native_frame, list(self._keys))

    def agg(
        self,
        *aggs: IntoArrowExpr,
        **named_aggs: IntoArrowExpr,
    ) -> ArrowDataFrame:
        """Aggregate every group with the given expressions.

        The result's columns are the group keys followed by the output names
        of each expression, in the order the expressions were given.

        Raises
        ------
        ValueError
            If an expression has no known output names (e.g. ``nw.all()``),
            or if ``agg_arrow`` rejects a non-elementary aggregation.
        """
        exprs = parse_into_exprs(
            *aggs,
            namespace=self._df.__narwhals_namespace__(),
            **named_aggs,
        )
        output_names: list[str] = copy(self._keys)
        for expr in exprs:
            if expr._output_names is None:
                msg = (
                    "Anonymous expressions are not supported in group_by.agg.\n"
                    "Instead of `nw.all()`, try using a named expression, such as "
                    "`nw.col('a', 'b')`\n"
                )
                raise ValueError(msg)
            output_names.extend(expr._output_names)
        return agg_arrow(
            self._grouped,
            exprs,
            self._keys,
            output_names,
            self._df._from_native_frame,
        )

    def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        """Yield ``(key_values, sub_frame)`` for each distinct key combination.

        ``key_values`` is one row of the distinct key columns (one entry per
        key); ``sub_frame`` holds the rows of the frame matching it. A null key
        value (possible when ``drop_null_keys=False``) is matched with
        ``is_null()``, so the null group yields its rows rather than an empty
        frame.
        """
        key_values = self._df.select(*self._keys).unique(
            subset=self._keys, keep="first"
        )
        plx = self._df.__narwhals_namespace__()

        def _matches(key: str, value: Any) -> ArrowExpr:
            # `col == None` evaluates to null (never true) on every row, so a
            # null key has to be selected with `is_null()` instead.
            column = plx.col(key)
            return column.is_null() if value is None else column == value

        for key_value in key_values.iter_rows():
            yield (
                key_value,
                self._df.filter(
                    *[_matches(k, v) for k, v in zip(self._keys, key_value)]
                ),
            )
def agg_arrow(
    grouped: Any,
    exprs: list[ArrowExpr],
    keys: list[str],
    output_names: list[str],
    from_dataframe: Callable[[Any], ArrowDataFrame],
) -> ArrowDataFrame:
    """Evaluate elementary group-by aggregations with pyarrow.

    Supported expressions are ``nw.len()`` (depth 0) and a single
    aggregation applied directly to columns, e.g. ``nw.col('a').mean()``
    (depth 1).

    Parameters
    ----------
    grouped:
        The ``pyarrow.TableGroupBy`` whose ``aggregate`` is called.
    exprs:
        The aggregation expressions.
    keys:
        Group-by key column names; ``keys[0]`` is counted for ``nw.len()``.
    output_names:
        Final column selection/order: keys followed by every output name.
    from_dataframe:
        Wraps the resulting native table back into an ``ArrowDataFrame``.

    Raises
    ------
    ValueError
        If any expression is not a simple aggregation.
    AssertionError
        If an expression breaks the invariants of a simple aggregation
        (an internal narwhals bug).
    """
    import pyarrow.compute as pc  # ignore-banned-import()

    if not all(is_simple_aggregation(expr) for expr in exprs):
        msg = (
            "Non-trivial complex aggregation found.\n\n"
            "Hint: you were probably trying to apply a non-elementary aggregation with a "
            "pyarrow table.\n"
            "Please rewrite your query such that group-by aggregations "
            "are elementary. For example, instead of:\n\n"
            " df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
            "use:\n\n"
            " df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
        )
        raise ValueError(msg)

    safety_msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"

    # Mapping from output name to (aggregation_args, pyarrow_output_name).
    simple_aggregations: dict[str, tuple[tuple[Any, ...], str]] = {}
    for expr in exprs:
        if expr._depth == 0:
            # Depth 0 is only valid for `nw.len()`: count every row, nulls
            # included, of the first key column.
            if (
                expr._output_names is None or expr._function_name != "len"
            ):  # pragma: no cover
                raise AssertionError(safety_msg)
            simple_aggregations[expr._output_names[0]] = (
                (keys[0], "count", pc.CountOptions(mode="all")),
                f"{keys[0]}_count",
            )
            continue
        # Depth 1: one aggregation applied directly to named columns.
        if (
            expr._depth != 1 or expr._root_names is None or expr._output_names is None
        ):  # pragma: no cover
            raise AssertionError(safety_msg)
        polars_name = remove_prefix(expr._function_name, "col->")
        function_name = POLARS_TO_ARROW_AGGREGATIONS.get(polars_name, polars_name)
        if polars_name == "count":
            # polars' `count` excludes nulls; pyarrow's "count" kernel is
            # shared with `len`, whose options count nulls, so override here.
            option = pc.CountOptions(mode="only_valid")
        else:
            option = get_function_name_option(function_name)
        for root_name, output_name in zip(expr._root_names, expr._output_names):
            simple_aggregations[output_name] = (
                (root_name, function_name, option),
                f"{root_name}_{function_name}",
            )

    # pyarrow names each aggregation column "<column>_<function>", so two
    # aggregations of the same column with the same function (e.g. one mean
    # under two aliases) share a name. Record the intended output names per
    # pyarrow name in aggregation order and hand them out one at a time, so
    # each duplicate receives its own alias.
    aggs: list[Any] = []
    pending_names: dict[str, list[str]] = {}
    for output_name, (aggregation_args, pyarrow_output_name) in (
        simple_aggregations.items()
    ):
        aggs.append(aggregation_args)
        pending_names.setdefault(pyarrow_output_name, []).append(output_name)

    result_simple = grouped.aggregate(aggs)
    name_queues = {name: iter(targets) for name, targets in pending_names.items()}
    new_column_names = [
        next(name_queues[col], col) if col in name_queues else col
        for col in result_simple.column_names
    ]
    result_simple = result_simple.rename_columns(new_column_names).select(
        output_names
    )
    return from_dataframe(result_simple)