165 lines
5.5 KiB
Python
Executable File
165 lines
5.5 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
from copy import copy
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Callable
|
|
|
|
from narwhals._expression_parsing import is_simple_aggregation
|
|
from narwhals._expression_parsing import parse_into_exprs
|
|
from narwhals.utils import remove_prefix
|
|
|
|
if TYPE_CHECKING:
|
|
import dask.dataframe as dd
|
|
import pandas as pd
|
|
|
|
from narwhals._dask.dataframe import DaskLazyFrame
|
|
from narwhals._dask.expr import DaskExpr
|
|
from narwhals._dask.typing import IntoDaskExpr
|
|
|
|
|
|
def n_unique() -> dd.Aggregation:
|
|
import dask.dataframe as dd # ignore-banned-import
|
|
|
|
def chunk(s: pd.core.groupby.generic.SeriesGroupBy) -> int:
|
|
return s.nunique(dropna=False) # type: ignore[no-any-return]
|
|
|
|
def agg(s0: pd.core.groupby.generic.SeriesGroupBy) -> int:
|
|
return s0.sum() # type: ignore[no-any-return]
|
|
|
|
return dd.Aggregation(
|
|
name="nunique",
|
|
chunk=chunk,
|
|
agg=agg,
|
|
)
|
|
|
|
|
|
POLARS_TO_DASK_AGGREGATIONS = {
|
|
"len": "size",
|
|
"n_unique": n_unique,
|
|
}
|
|
|
|
|
|
class DaskLazyGroupBy:
|
|
def __init__(
|
|
self, df: DaskLazyFrame, keys: list[str], *, drop_null_keys: bool
|
|
) -> None:
|
|
self._df = df
|
|
self._keys = keys
|
|
self._grouped = self._df._native_frame.groupby(
|
|
list(self._keys),
|
|
dropna=drop_null_keys,
|
|
observed=True,
|
|
)
|
|
|
|
def agg(
|
|
self,
|
|
*aggs: IntoDaskExpr,
|
|
**named_aggs: IntoDaskExpr,
|
|
) -> DaskLazyFrame:
|
|
exprs = parse_into_exprs(
|
|
*aggs,
|
|
namespace=self._df.__narwhals_namespace__(),
|
|
**named_aggs,
|
|
)
|
|
output_names: list[str] = copy(self._keys)
|
|
for expr in exprs:
|
|
if expr._output_names is None:
|
|
msg = (
|
|
"Anonymous expressions are not supported in group_by.agg.\n"
|
|
"Instead of `nw.all()`, try using a named expression, such as "
|
|
"`nw.col('a', 'b')`\n"
|
|
)
|
|
raise ValueError(msg)
|
|
|
|
output_names.extend(expr._output_names)
|
|
|
|
return agg_dask(
|
|
self._df,
|
|
self._grouped,
|
|
exprs,
|
|
self._keys,
|
|
self._from_native_frame,
|
|
)
|
|
|
|
def _from_native_frame(self, df: DaskLazyFrame) -> DaskLazyFrame:
|
|
from narwhals._dask.dataframe import DaskLazyFrame
|
|
|
|
return DaskLazyFrame(
|
|
df, backend_version=self._df._backend_version, dtypes=self._df._dtypes
|
|
)
|
|
|
|
|
|
def agg_dask(
|
|
df: DaskLazyFrame,
|
|
grouped: Any,
|
|
exprs: list[DaskExpr],
|
|
keys: list[str],
|
|
from_dataframe: Callable[[Any], DaskLazyFrame],
|
|
) -> DaskLazyFrame:
|
|
"""
|
|
This should be the fastpath, but cuDF is too far behind to use it.
|
|
|
|
- https://github.com/rapidsai/cudf/issues/15118
|
|
- https://github.com/rapidsai/cudf/issues/15084
|
|
"""
|
|
if not exprs:
|
|
# No aggregation provided
|
|
return df.select(*keys).unique(subset=keys)
|
|
|
|
all_simple_aggs = True
|
|
for expr in exprs:
|
|
if not is_simple_aggregation(expr):
|
|
all_simple_aggs = False
|
|
break
|
|
|
|
if all_simple_aggs:
|
|
simple_aggregations: dict[str, tuple[str, str | dd.Aggregation]] = {}
|
|
for expr in exprs:
|
|
if expr._depth == 0:
|
|
# e.g. agg(nw.len()) # noqa: ERA001
|
|
if expr._output_names is None: # pragma: no cover
|
|
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
|
|
raise AssertionError(msg)
|
|
|
|
function_name = POLARS_TO_DASK_AGGREGATIONS.get(
|
|
expr._function_name, expr._function_name
|
|
)
|
|
for output_name in expr._output_names:
|
|
simple_aggregations[output_name] = (keys[0], function_name)
|
|
continue
|
|
|
|
# e.g. agg(nw.mean('a')) # noqa: ERA001
|
|
if (
|
|
expr._depth != 1 or expr._root_names is None or expr._output_names is None
|
|
): # pragma: no cover
|
|
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
|
|
raise AssertionError(msg)
|
|
|
|
function_name = remove_prefix(expr._function_name, "col->")
|
|
function_name = POLARS_TO_DASK_AGGREGATIONS.get(function_name, function_name)
|
|
|
|
# deal with n_unique case in a "lazy" mode to not depend on dask globally
|
|
function_name = function_name() if callable(function_name) else function_name
|
|
|
|
for root_name, output_name in zip(expr._root_names, expr._output_names):
|
|
simple_aggregations[output_name] = (root_name, function_name)
|
|
try:
|
|
result_simple = grouped.agg(**simple_aggregations)
|
|
except ValueError as exc:
|
|
msg = "Failed to aggregated - does your aggregation function return a scalar?"
|
|
raise RuntimeError(msg) from exc
|
|
return from_dataframe(result_simple.reset_index())
|
|
|
|
msg = (
|
|
"Non-trivial complex found.\n\n"
|
|
"Hint: you were probably trying to apply a non-elementary aggregation with a "
|
|
"dask dataframe.\n"
|
|
"Please rewrite your query such that group-by aggregations "
|
|
"are elementary. For example, instead of:\n\n"
|
|
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
|
|
"use:\n\n"
|
|
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
|
|
)
|
|
raise ValueError(msg)
|