Files
Buffteks-Website/streamlit-venv/lib/python3.10/site-packages/narwhals/group_by.py
2025-01-10 21:40:35 +00:00

138 lines
4.6 KiB
Python
Executable File

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from typing import Generic
from typing import Iterable
from typing import Iterator
from typing import TypeVar
from typing import cast
from narwhals.dataframe import DataFrame
from narwhals.dataframe import LazyFrame
from narwhals.utils import tupleify
if TYPE_CHECKING:
from narwhals.typing import IntoExpr
# Type parameter of `GroupBy`: the type of the frame passed to its constructor,
# which is also the declared return type of `GroupBy.agg`.
DataFrameT = TypeVar("DataFrameT")
# Type parameter of `LazyGroupBy`: the type of the frame passed to its
# constructor, which is also the declared return type of `LazyGroupBy.agg`.
LazyFrameT = TypeVar("LazyFrameT")
class GroupBy(Generic[DataFrameT]):
    """
    Group-by handle built from a frame and one or more key columns.

    The grouping itself is delegated to the frame's underlying compliant
    implementation; this class forwards calls to it and wraps whatever comes
    back using the original frame.
    """

    def __init__(self, df: DataFrameT, *keys: str, drop_null_keys: bool) -> None:
        """
        Store the frame and keys and build the underlying grouped object.

        Arguments:
            df: Frame to group. It is treated as a `DataFrame` internally.
            keys: Names of the columns to group by.
            drop_null_keys: Forwarded unchanged to the compliant frame's
                `group_by`.
        """
        frame = cast(DataFrame[Any], df)
        self._df = frame
        self._keys = keys
        self._grouped = frame._compliant_frame.group_by(
            *keys, drop_null_keys=drop_null_keys
        )

    def agg(
        self, *aggs: IntoExpr | Iterable[IntoExpr], **named_aggs: IntoExpr
    ) -> DataFrameT:
        """
        Evaluate aggregations on each group of the group by operation.

        Arguments:
            aggs: Aggregations to evaluate per group, given positionally. Each
                argument may be a single expression or an iterable of them.
            named_aggs: Further aggregations, given as keyword arguments.

        Returns:
            The aggregated result, wrapped via the frame that was grouped.

        Examples:
            Group by a single column, or by several columns, and use `agg` to
            compute the per-group sum of another column.

            >>> import pandas as pd
            >>> import polars as pl
            >>> import narwhals as nw
            >>> df_pd = pd.DataFrame(
            ...     {
            ...         "a": ["a", "b", "a", "b", "c"],
            ...         "b": [1, 2, 1, 3, 3],
            ...         "c": [5, 4, 3, 2, 1],
            ...     }
            ... )
            >>> df_pl = pl.DataFrame(
            ...     {
            ...         "a": ["a", "b", "a", "b", "c"],
            ...         "b": [1, 2, 1, 3, 3],
            ...         "c": [5, 4, 3, 2, 1],
            ...     }
            ... )

            Define library-agnostic functions:

            >>> @nw.narwhalify
            ... def func(df):
            ...     return df.group_by("a").agg(nw.col("b").sum()).sort("a")

            >>> @nw.narwhalify
            ... def func_mult_col(df):
            ...     return df.group_by("a", "b").agg(nw.sum("c")).sort("a", "b")

            Either pandas or Polars input can then be passed to `func` and
            `func_mult_col`:

            >>> func(df_pd)
               a  b
            0  a  2
            1  b  5
            2  c  3
            >>> func(df_pl)
            shape: (3, 2)
            ┌─────┬─────┐
            │ a   ┆ b   │
            │ --- ┆ --- │
            │ str ┆ i64 │
            ╞═════╪═════╡
            │ a   ┆ 2   │
            │ b   ┆ 5   │
            │ c   ┆ 3   │
            └─────┴─────┘
            >>> func_mult_col(df_pd)
               a  b  c
            0  a  1  8
            1  b  2  4
            2  b  3  2
            3  c  3  1
            >>> func_mult_col(df_pl)
            shape: (4, 3)
            ┌─────┬─────┬─────┐
            │ a   ┆ b   ┆ c   │
            │ --- ┆ --- ┆ --- │
            │ str ┆ i64 ┆ i64 │
            ╞═════╪═════╪═════╡
            │ a   ┆ 1   ┆ 8   │
            │ b   ┆ 2   ┆ 4   │
            │ b   ┆ 3   ┆ 2   │
            │ c   ┆ 3   ┆ 1   │
            └─────┴─────┴─────┘
        """
        # Normalise the user-facing arguments through the frame's own helper
        # before handing them to the compliant group-by object.
        flat_aggs, flat_named_aggs = self._df._flatten_and_extract(
            *aggs, **named_aggs
        )
        compliant_result = self._grouped.agg(*flat_aggs, **flat_named_aggs)
        return self._df._from_compliant_dataframe(compliant_result)  # type: ignore[return-value]

    def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]:
        """
        Iterate over the groups.

        Yields:
            `(key, frame)` pairs: the key produced by the underlying grouped
            object after passing through `tupleify`, and that group's data
            wrapped via the frame that was grouped.
        """
        for group_key, group_frame in self._grouped.__iter__():
            yield (  # type: ignore[misc]
                tupleify(group_key),
                self._df._from_compliant_dataframe(group_frame),
            )
class LazyGroupBy(Generic[LazyFrameT]):
    """
    Group-by handle built from a lazy frame and one or more key columns.

    The grouping itself is delegated to the frame's underlying compliant
    implementation; this class forwards calls to it and wraps whatever comes
    back using the original frame.
    """

    def __init__(self, df: LazyFrameT, *keys: str, drop_null_keys: bool) -> None:
        """
        Store the frame and keys and build the underlying grouped object.

        Arguments:
            df: Frame to group. It is treated as a `LazyFrame` internally.
            keys: Names of the columns to group by.
            drop_null_keys: Forwarded unchanged to the compliant frame's
                `group_by`.
        """
        frame = cast(LazyFrame[Any], df)
        self._df = frame
        self._keys = keys
        self._grouped = frame._compliant_frame.group_by(
            *keys, drop_null_keys=drop_null_keys
        )

    def agg(
        self, *aggs: IntoExpr | Iterable[IntoExpr], **named_aggs: IntoExpr
    ) -> LazyFrameT:
        """
        Evaluate aggregations on each group of the group by operation.

        Arguments:
            aggs: Aggregations to evaluate per group, given positionally. Each
                argument may be a single expression or an iterable of them.
            named_aggs: Further aggregations, given as keyword arguments.

        Returns:
            The aggregated result, wrapped via the frame that was grouped.
        """
        # Normalise the user-facing arguments through the frame's own helper
        # before handing them to the compliant group-by object.
        flat_aggs, flat_named_aggs = self._df._flatten_and_extract(
            *aggs, **named_aggs
        )
        compliant_result = self._grouped.agg(*flat_aggs, **flat_named_aggs)
        return self._df._from_compliant_dataframe(compliant_result)  # type: ignore[return-value]