Files
Buffteks-Website/streamlit-venv/lib/python3.10/site-packages/narwhals/_dask/group_by.py
2025-01-10 21:40:35 +00:00

165 lines
5.5 KiB
Python
Executable File

from __future__ import annotations
from copy import copy
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals.utils import remove_prefix
if TYPE_CHECKING:
import dask.dataframe as dd
import pandas as pd
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.typing import IntoDaskExpr
def n_unique() -> dd.Aggregation:
import dask.dataframe as dd # ignore-banned-import
def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int:
return s.nunique(dropna=False) # type: ignore[no-any-return]
def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:
return s0.sum() # type: ignore[no-any-return]
return dd.Aggregation(
name="nunique",
chunk=chunk,
agg=agg,
)
POLARS_TO_DASK_AGGREGATIONS = {
"len": "size",
"n_unique": n_unique,
}
class DaskLazyGroupBy:
def __init__(
self, df: DaskLazyFrame, keys: list[str], *, drop_null_keys: bool
) -> None:
self._df = df
self._keys = keys
self._grouped = self._df._native_frame.groupby(
list(self._keys),
dropna=drop_null_keys,
observed=True,
)
def agg(
self,
*aggs: IntoDaskExpr,
**named_aggs: IntoDaskExpr,
) -> DaskLazyFrame:
exprs = parse_into_exprs(
*aggs,
namespace=self._df.__narwhals_namespace__(),
**named_aggs,
)
output_names: list[str] = copy(self._keys)
for expr in exprs:
if expr._output_names is None:
msg = (
"Anonymous expressions are not supported in group_by.agg.\n"
"Instead of `nw.all()`, try using a named expression, such as "
"`nw.col('a', 'b')`\n"
)
raise ValueError(msg)
output_names.extend(expr._output_names)
return agg_dask(
self._df,
self._grouped,
exprs,
self._keys,
self._from_native_frame,
)
def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame:
from narwhals._dask.dataframe import DaskLazyFrame
return DaskLazyFrame(
df, backend_version=self._df._backend_version, dtypes=self._df._dtypes
)
def agg_dask(
df: DaskLazyFrame,
grouped: Any,
exprs: list[DaskExpr],
keys: list[str],
from_dataframe: Callable[[Any], DaskLazyFrame],
) -> DaskLazyFrame:
"""
This should be the fastpath, but cuDF is too far behind to use it.
- https://github.com/rapidsai/cudf/issues/15118
- https://github.com/rapidsai/cudf/issues/15084
"""
if not exprs:
# No aggregation provided
return df.select(*keys).unique(subset=keys)
all_simple_aggs = True
for expr in exprs:
if not is_simple_aggregation(expr):
all_simple_aggs = False
break
if all_simple_aggs:
simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {}
for expr in exprs:
if expr._depth == 0:
# e.g. agg(nw.len()) # noqa: ERA001
if expr._output_names is None: # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)
function_name = POLARS_TO_DASK_AGGREGATIONS.get(
expr._function_name, expr._function_name
)
for output_name in expr._output_names:
simple_aggregations[output_name] = (keys[0], function_name)
continue
# e.g. agg(nw.mean('a')) # noqa: ERA001
if (
expr._depth != 1 or expr._root_names is None or expr._output_names is None
): # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)
function_name = remove_prefix(expr._function_name, "col->")
function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
# deal with n_unique case in a "lazy" mode to not depend on dask globally
function_name = function_name() if callable(function_name) else function_name
for root_name, output_name in zip(expr._root_names, expr._output_names):
simple_aggregations[output_name] = (root_name, function_name)
try:
result_simple = grouped.agg(**simple_aggregations)
except ValueError as exc:
msg = "Failed to aggregated - does your aggregation function return a scalar?"
raise RuntimeError(msg) from exc
return from_dataframe(result_simple.reset_index())
msg = (
"Non-trivial complex found.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"dask dataframe.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
)
raise ValueError(msg)