64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
from typing import Iterator
|
|
from typing import Sequence
|
|
from typing import cast
|
|
|
|
from narwhals._polars.utils import extract_native
|
|
|
|
if TYPE_CHECKING:
|
|
from polars.dataframe.group_by import GroupBy as NativeGroupBy
|
|
from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy
|
|
from typing_extensions import Self
|
|
|
|
from narwhals._polars.dataframe import PolarsDataFrame
|
|
from narwhals._polars.dataframe import PolarsLazyFrame
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
|
|
class PolarsGroupBy:
|
|
_compliant_frame: PolarsDataFrame
|
|
_keys: Sequence[str]
|
|
|
|
@property
|
|
def compliant(self) -> PolarsDataFrame:
|
|
return self._compliant_frame
|
|
|
|
def __init__(
|
|
self, df: PolarsDataFrame, keys: Sequence[str], /, *, drop_null_keys: bool
|
|
) -> None:
|
|
self._compliant_frame = df
|
|
self._keys = list(keys)
|
|
df = df.drop_nulls(keys) if drop_null_keys else df
|
|
self._grouped: NativeGroupBy = df._native_frame.group_by(keys)
|
|
|
|
def agg(self: Self, *aggs: PolarsExpr) -> PolarsDataFrame:
|
|
from_native = self.compliant._with_native
|
|
return from_native(self._grouped.agg(extract_native(arg) for arg in aggs))
|
|
|
|
def __iter__(self: Self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
|
|
for key, df in self._grouped:
|
|
yield tuple(cast("str", key)), self.compliant._with_native(df)
|
|
|
|
|
|
class PolarsLazyGroupBy:
|
|
_compliant_frame: PolarsLazyFrame
|
|
_keys: Sequence[str]
|
|
|
|
@property
|
|
def compliant(self) -> PolarsLazyFrame:
|
|
return self._compliant_frame
|
|
|
|
def __init__(
|
|
self, df: PolarsLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool
|
|
) -> None:
|
|
self._compliant_frame = df
|
|
self._keys = list(keys)
|
|
df = df.drop_nulls(keys) if drop_null_keys else df
|
|
self._grouped: NativeLazyGroupBy = df._native_frame.group_by(keys)
|
|
|
|
def agg(self: Self, *aggs: PolarsExpr) -> PolarsLazyFrame:
|
|
from_native = self.compliant._with_native
|
|
return from_native(self._grouped.agg(extract_native(arg) for arg in aggs))
|