Files
Buffteks-Website/buffteks/lib/python3.12/site-packages/narwhals/_dask/group_by.py
2025-05-08 21:10:14 -05:00

120 lines
4.0 KiB
Python

from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import ClassVar
from typing import Mapping
from typing import Sequence
import dask.dataframe as dd
from narwhals._compliant import DepthTrackingGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
# Prefer the `dask_expr` module found under `dask.dataframe`; when that module
# cannot be found, fall back to the standalone `dask_expr` package.
try:
    import dask.dataframe.dask_expr as dx
except ModuleNotFoundError: # pragma: no cover
    import dask_expr as dx

if TYPE_CHECKING:
    # Everything in this branch is only seen by static type checkers; the
    # names are used in annotations and type aliases further down the module.
    import pandas as pd
    from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy
    from typing_extensions import Self
    from typing_extensions import TypeAlias

    from narwhals._compliant.group_by import NarwhalsAggregation
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.expr import DaskExpr

    # Alias for the grouped-series type received by custom aggregation callbacks.
    PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any]
    # Any callable that produces (or is) an aggregation.
    _AggFn: TypeAlias = Callable[..., Any]
    from dask_expr._groupby import GroupBy as _DaskGroupBy
else:
    # At runtime, resolve the GroupBy class through whichever `dx` module was
    # successfully imported above.
    _DaskGroupBy = dx._groupby.GroupBy

Aggregation: TypeAlias = "str | _AggFn"
"""The name of an aggregation function, or the function itself."""
def n_unique() -> dd.Aggregation:
    """Build the custom Dask aggregation registered under the name ``"nunique"``.

    Two callbacks are handed to ``dd.Aggregation``:

    * ``chunk`` calls ``nunique(dropna=False)`` on the grouped series it
      receives, so nulls are included in the distinct count.
    * ``agg`` calls ``sum()`` on the grouped series it receives.

    NOTE(review): if the same value can occur for one group in more than one
    chunk, summing the per-chunk distinct counts would count it once per
    chunk -- confirm against how Dask feeds chunks to these callbacks.

    Returns:
        The ``dd.Aggregation`` object wiring the two callbacks together.
    """

    def _distinct_per_chunk(grouped: PandasSeriesGroupBy) -> pd.Series[Any]:
        return grouped.nunique(dropna=False)

    def _sum_chunk_counts(grouped: PandasSeriesGroupBy) -> pd.Series[Any]:
        return grouped.sum()

    return dd.Aggregation(
        name="nunique", chunk=_distinct_per_chunk, agg=_sum_chunk_counts
    )
def var(ddof: int) -> _AggFn:
    """Return ``_DaskGroupBy.var`` with ``ddof`` pre-bound as a keyword argument.

    Args:
        ddof: Value forwarded as the ``ddof`` keyword of ``_DaskGroupBy.var``.

    Returns:
        A ``functools.partial`` wrapping ``_DaskGroupBy.var``.
    """
    grouped_var = _DaskGroupBy.var
    return partial(grouped_var, ddof=ddof)
def std(ddof: int) -> _AggFn:
    """Return ``_DaskGroupBy.std`` with ``ddof`` pre-bound as a keyword argument.

    Args:
        ddof: Value forwarded as the ``ddof`` keyword of ``_DaskGroupBy.std``.

    Returns:
        A ``functools.partial`` wrapping ``_DaskGroupBy.std``.
    """
    grouped_std = _DaskGroupBy.std
    return partial(grouped_std, ddof=ddof)
class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]):
    """Group-by over a ``DaskLazyFrame``.

    The native frame is grouped once at construction time; ``agg`` then
    translates each expression into an entry of a named-aggregation spec and
    hands the whole spec to the grouped object in a single call.

    Helpers such as ``compliant``, ``_ensure_all_simple``, ``_remap_expr_name``
    and ``_leaf_name`` are not defined in this class -- presumably they come
    from ``DepthTrackingGroupBy``.
    """

    # Narwhals aggregation name -> what is handed to Dask: either an
    # aggregation name (string) or a factory that is called later, in `agg`,
    # to build the actual aggregation.
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "median",
        "max": "max",
        "min": "min",
        "std": std,
        "var": var,
        "len": "size",
        "n_unique": n_unique,
        "count": "count",
    }

    def __init__(
        self: Self, df: DaskLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool
    ) -> None:
        """Group the native frame of ``df`` by ``keys``.

        Args:
            df: Frame to group; stored as the compliant frame of this object.
            keys: Names of the grouping columns.
            drop_null_keys: Forwarded as ``dropna`` to the native ``groupby``.
        """
        self._compliant_frame = df
        self._keys: list[str] = list(keys)
        native_frame = self.compliant.native
        # A fresh list is passed so the native groupby never shares `_keys`.
        self._grouped = native_frame.groupby(
            list(self._keys), dropna=drop_null_keys, observed=True
        )

    def agg(self: Self, *exprs: DaskExpr) -> DaskLazyFrame:
        """Aggregate the grouped frame with the given expressions.

        With no expressions, the result is the distinct combinations of the
        key columns. Otherwise every expression is converted into
        ``alias -> (column, aggregation)`` entries which are passed as keyword
        arguments to the native ``agg``; the index is reset afterwards.

        Args:
            *exprs: Expressions to aggregate with; checked by
                ``_ensure_all_simple`` before any spec is built.

        Returns:
            A new ``DaskLazyFrame`` wrapping the aggregated native frame.
        """
        from narwhals._dask.dataframe import DaskLazyFrame

        if not exprs:
            # No aggregation provided
            keys_only = self.compliant.simple_select(*self._keys)
            return keys_only.unique(self._keys, keep="any")

        self._ensure_all_simple(exprs)

        # This should be the fastpath, but cuDF is too far behind to use it.
        # - https://github.com/rapidsai/cudf/issues/15118
        # - https://github.com/rapidsai/cudf/issues/15084
        named_aggs: dict[str, tuple[str, Aggregation]] = {}
        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, self._keys
            )
            if expr._depth != 0:
                # e.g. `agg(nw.mean('a'))`
                resolved = self._remap_expr_name(self._leaf_name(expr))
                if callable(resolved):
                    # Factories (e.g. `n_unique`) are only invoked here, with
                    # the expression's call kwargs, so the Dask aggregation is
                    # built lazily rather than at import time.
                    resolved = resolved(**expr._call_kwargs)
                for alias, output_name in zip(aliases, output_names):
                    named_aggs[alias] = (output_name, resolved)
            else:
                # e.g. `agg(nw.len())` -- there is no input column, so the
                # first key column is used as the column to aggregate.
                first_key = self._keys[0]
                resolved = self._remap_expr_name(expr._function_name)
                for alias in aliases:
                    named_aggs[alias] = (first_key, resolved)

        aggregated = self._grouped.agg(**named_aggs)
        return DaskLazyFrame(
            aggregated.reset_index(),
            backend_version=self.compliant._backend_version,
            version=self.compliant._version,
        )