Files
Buffteks-Website/buffteks/lib/python3.12/site-packages/narwhals/_polars/group_by.py
2025-05-08 21:10:14 -05:00

64 lines
2.1 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Iterator
from typing import Sequence
from typing import cast
from narwhals._polars.utils import extract_native
if TYPE_CHECKING:
from polars.dataframe.group_by import GroupBy as NativeGroupBy
from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy
from typing_extensions import Self
from narwhals._polars.dataframe import PolarsDataFrame
from narwhals._polars.dataframe import PolarsLazyFrame
from narwhals._polars.expr import PolarsExpr
class PolarsGroupBy:
    """Eager group-by built from a compliant frame and a sequence of key names.

    The grouping itself is delegated to the native frame's ``group_by``;
    results are wrapped again through the original frame's ``_with_native``.
    """

    # Frame passed to the constructor (kept as-is, even when nulls are dropped).
    _compliant_frame: PolarsDataFrame
    # List copy of the grouping key names.
    _keys: Sequence[str]

    def __init__(
        self, df: PolarsDataFrame, keys: Sequence[str], /, *, drop_null_keys: bool
    ) -> None:
        """Store the frame and keys, then create the native group-by object.

        Args:
            df: Frame to group (positional-only).
            keys: Names of the grouping columns (positional-only).
            drop_null_keys: When truthy, ``df.drop_nulls(keys)`` is applied
                first and the grouping is built from that result instead.
        """
        self._compliant_frame = df
        self._keys = list(keys)
        # Only the grouping source is swapped out; `_compliant_frame` above
        # keeps pointing at the frame the caller handed in.
        source = df
        if drop_null_keys:
            source = df.drop_nulls(keys)
        self._grouped: NativeGroupBy = source._native_frame.group_by(keys)

    @property
    def compliant(self) -> PolarsDataFrame:
        """Return the frame this group-by was constructed from."""
        return self._compliant_frame

    def agg(self: Self, *aggs: PolarsExpr) -> PolarsDataFrame:
        """Aggregate the groups and wrap the native result.

        Each expression in ``aggs`` is passed through ``extract_native`` and
        the resulting lazy generator is handed to the native ``agg``.
        """
        native_exprs = (extract_native(expr) for expr in aggs)
        aggregated = self._grouped.agg(native_exprs)
        return self.compliant._with_native(aggregated)

    def __iter__(self: Self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
        """Yield ``(key_tuple, wrapped_frame)`` for every native group."""
        wrap = self.compliant._with_native
        for group_key, group_frame in self._grouped:
            # `cast` is a no-op at runtime; it only informs the type checker.
            key_tuple = tuple(cast("str", group_key))
            yield key_tuple, wrap(group_frame)
class PolarsLazyGroupBy:
    """Lazy group-by built from a compliant lazy frame and a sequence of key names.

    The grouping itself is delegated to the native frame's ``group_by``;
    aggregation results are wrapped through the original frame's
    ``_with_native``.
    """

    # Frame passed to the constructor (kept as-is, even when nulls are dropped).
    _compliant_frame: PolarsLazyFrame
    # List copy of the grouping key names.
    _keys: Sequence[str]

    def __init__(
        self, df: PolarsLazyFrame, keys: Sequence[str], /, *, drop_null_keys: bool
    ) -> None:
        """Store the frame and keys, then create the native lazy group-by.

        Args:
            df: Lazy frame to group (positional-only).
            keys: Names of the grouping columns (positional-only).
            drop_null_keys: When truthy, ``df.drop_nulls(keys)`` is applied
                first and the grouping is built from that result instead.
        """
        self._compliant_frame = df
        self._keys = list(keys)
        # Only the grouping source is swapped out; `_compliant_frame` above
        # keeps pointing at the frame the caller handed in.
        source = df
        if drop_null_keys:
            source = df.drop_nulls(keys)
        self._grouped: NativeLazyGroupBy = source._native_frame.group_by(keys)

    @property
    def compliant(self) -> PolarsLazyFrame:
        """Return the lazy frame this group-by was constructed from."""
        return self._compliant_frame

    def agg(self: Self, *aggs: PolarsExpr) -> PolarsLazyFrame:
        """Aggregate the groups and wrap the native result.

        Each expression in ``aggs`` is passed through ``extract_native`` and
        the resulting lazy generator is handed to the native ``agg``.
        """
        native_exprs = (extract_native(expr) for expr in aggs)
        aggregated = self._grouped.agg(native_exprs)
        return self.compliant._with_native(aggregated)