Files
Buffteks-Website/buffteks/lib/python3.12/site-packages/narwhals/_compliant/selectors.py
2025-05-08 21:10:14 -05:00

311 lines
11 KiB
Python

"""Almost entirely complete, generic `selectors` implementation."""
from __future__ import annotations
import re
from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Collection
from typing import Iterable
from typing import Iterator
from typing import Sequence
from typing import TypeVar
from typing import overload
from narwhals._compliant.expr import CompliantExpr
from narwhals.utils import _parse_time_unit_and_time_zone
from narwhals.utils import dtype_matches_time_unit_and_time_zone
from narwhals.utils import get_column_names
from narwhals.utils import import_dtypes_module
from narwhals.utils import is_compliant_dataframe
if not TYPE_CHECKING: # pragma: no cover
# TODO @dangotbanned: Remove after dropping `3.8` (#2084)
# - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386
import sys
if sys.version_info >= (3, 9):
from typing import Protocol
else:
from typing import Generic
Protocol = Generic
else: # pragma: no cover
from typing import Protocol
if TYPE_CHECKING:
from datetime import timezone
from typing_extensions import Self
from typing_extensions import TypeAlias
from typing_extensions import TypeIs
from narwhals._compliant.dataframe import CompliantDataFrame
from narwhals._compliant.dataframe import CompliantLazyFrame
from narwhals._compliant.expr import NativeExpr
from narwhals._compliant.series import CompliantSeries
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
from narwhals.utils import Implementation
from narwhals.utils import Version
# Public surface of this module: the selector protocols plus the two
# evaluator type aliases that backends use to construct selectors.
__all__ = [
    "CompliantSelector",
    "CompliantSelectorNamespace",
    "EagerSelectorNamespace",
    "EvalNames",
    "EvalSeries",
    "LazySelectorNamespace",
]
# Type variables shared by the selector protocols in this module.  Their
# bounds are spelled as strings because the bounding classes are only
# imported under `TYPE_CHECKING`.
SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeries[Any] | NativeExpr")
SeriesT = TypeVar("SeriesT", bound="CompliantSeries[Any]")
ExprT = TypeVar("ExprT", bound="NativeExpr")
FrameT = TypeVar(
    "FrameT",
    bound="CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any]",
)
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrame[Any, Any, Any]")
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrame[Any, Any]")

# Either a selector or a plain compliant expression over the same frame/column types.
SelectorOrExpr: TypeAlias = (
    "CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
)

# Given a frame, produce the columns a selector picks (series or native expressions).
EvalSeries: TypeAlias = Callable[[FrameT], Sequence[SeriesOrExprT]]

# Given a frame, produce the names of the columns a selector picks.
EvalNames: TypeAlias = Callable[[FrameT], Sequence[str]]
class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
    """Generic implementation of column selectors for a compliant backend.

    Implementers supply `_selector`, `_iter_columns`, `_iter_schema` and
    `_iter_columns_dtypes`; every public selector here is built from those.
    Each selector is created from a pair of callables: one returning the
    matching columns (series or native expressions) and one returning the
    matching column names.
    """

    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    def _selector(
        self,
        call: EvalSeries[FrameT, SeriesOrExprT],
        evaluate_output_names: EvalNames[FrameT],
        /,
    ) -> CompliantSelector[FrameT, SeriesOrExprT]: ...

    def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...

    def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...

    def _iter_columns_dtypes(
        self, df: FrameT, /
    ) -> Iterator[tuple[SeriesOrExprT, DType]]: ...

    def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
        """Yield each column paired with its name, zipped against `df.columns`."""
        yield from zip(self._iter_columns(df), df.columns)

    def _select_where(
        self, predicate: Callable[[DType], bool], /
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Build a selector keeping every column whose dtype satisfies `predicate`.

        All dtype-based selectors go through here so the column evaluator and
        the name evaluator always apply exactly the same filter.
        """

        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [ser for ser, tp in self._iter_columns_dtypes(df) if predicate(tp)]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if predicate(tp)]

        return self._selector(series, names)

    def _is_dtype(
        self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select columns whose dtype is an instance of the DType class `dtype`."""
        return self._select_where(lambda tp: isinstance(tp, dtype))

    def by_dtype(
        self: Self, dtypes: Collection[DType | type[DType]]
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select columns whose dtype is contained in `dtypes` (via the `in` operator)."""
        return self._select_where(lambda tp: tp in dtypes)

    def matches(self: Self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select columns whose name contains a match for the regex `pattern`.

        Matching uses `re.search`, so the pattern is not anchored.
        """
        regex = re.compile(pattern)

        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            # Eager (non-DuckDB) frames fetch columns by name directly;
            # everything else walks (column, name) pairs.
            if is_compliant_dataframe(df) and not self._implementation.is_duckdb():
                return [df.get_column(col) for col in df.columns if regex.search(col)]
            return [
                ser for ser, name in self._iter_columns_names(df) if regex.search(name)
            ]

        def names(df: FrameT) -> Sequence[str]:
            return [col for col in df.columns if regex.search(col)]

        return self._selector(series, names)

    def numeric(self: Self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select columns whose dtype reports `is_numeric()`."""
        return self._select_where(lambda tp: tp.is_numeric())

    def categorical(self: Self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select columns of this namespace version's `Categorical` dtype."""
        return self._is_dtype(import_dtypes_module(self._version).Categorical)

    def string(self: Self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select columns of this namespace version's `String` dtype."""
        return self._is_dtype(import_dtypes_module(self._version).String)

    def boolean(self: Self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select columns of this namespace version's `Boolean` dtype."""
        return self._is_dtype(import_dtypes_module(self._version).Boolean)

    def all(self: Self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select every column; names come from `get_column_names`."""

        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return list(self._iter_columns(df))

        return self._selector(series, get_column_names)

    def datetime(
        self: Self,
        time_unit: TimeUnit | Iterable[TimeUnit] | None,
        time_zone: str | timezone | Iterable[str | timezone | None] | None,
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Select datetime columns filtered by time unit(s) and time zone(s).

        The arguments are normalised by `_parse_time_unit_and_time_zone`, and
        each dtype is tested with `dtype_matches_time_unit_and_time_zone`.
        """
        time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
        is_match = partial(
            dtype_matches_time_unit_and_time_zone,
            dtypes=import_dtypes_module(version=self._version),
            time_units=time_units,
            time_zones=time_zones,
        )
        return self._select_where(is_match)
class EagerSelectorNamespace(
    CompliantSelectorNamespace[DataFrameT, SeriesT],
    Protocol[DataFrameT, SeriesT],
):
    """Selector namespace for eager frames, whose columns are compliant series.

    Names and dtypes are read straight off each series yielded by the frame's
    `iter_columns()`.
    """

    def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
        # Schema pairs are derived from the materialised series themselves.
        yield from ((column.name, column.dtype) for column in self._iter_columns(df))

    def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
        yield from df.iter_columns()

    def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
        yield from ((column, column.dtype) for column in self._iter_columns(df))
class LazySelectorNamespace(
    CompliantSelectorNamespace[LazyFrameT, ExprT],
    Protocol[LazyFrameT, ExprT],
):
    """Selector namespace for lazy frames, whose columns are native expressions.

    Column names and dtypes come from the frame's `schema` mapping rather
    than from the expressions.
    """

    def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
        for name, dtype in df.schema.items():
            yield name, dtype

    def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
        yield from df._iter_columns()

    def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
        # Columns and schema dtypes are paired positionally.
        # NOTE(review): assumes `df._iter_columns()` follows schema order — confirm per backend.
        exprs = self._iter_columns(df)
        dtypes = df.schema.values()
        yield from zip(exprs, dtypes)
class CompliantSelector(
    CompliantExpr[FrameT, SeriesOrExprT], Protocol[FrameT, SeriesOrExprT]
):
    """A compliant expression that additionally supports selector set algebra.

    Combining two selectors of the same class with `-`, `|` or `&` yields a
    new selector whose columns and names are derived from both operands'
    output names.  Combining with any other expression falls back to the
    ordinary expression operator on `self._to_expr()`.
    """

    @property
    def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
        """The `selectors` attribute of this expression's narwhals namespace."""
        return self.__narwhals_namespace__().selectors

    def _to_expr(self: Self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...

    def _is_selector(
        self: Self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
    ) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
        """Return whether `other` is an instance of this selector's own class."""
        return isinstance(other, type(self))

    @overload
    def __sub__(self: Self, other: Self) -> Self: ...
    @overload
    def __sub__(
        self: Self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __sub__(
        self: Self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        """Difference: keep columns of `self` whose names `other` does not produce."""
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                # Set membership keeps this O(n + m) instead of O(n * m).
                excluded = set(rhs_names)
                return [
                    x for x, name in zip(self(df), lhs_names) if name not in excluded
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                excluded = set(rhs_names)
                return [x for x in lhs_names if x not in excluded]

            return self.selectors._selector(series, names)
        return self._to_expr() - other

    @overload
    def __or__(self: Self, other: Self) -> Self: ...
    @overload
    def __or__(
        self: Self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __or__(
        self: Self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        """Union: columns of `self` not produced by `other`, then all of `other`'s."""
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                rhs_set = set(rhs_names)
                return [
                    *(x for x, name in zip(self(df), lhs_names) if name not in rhs_set),
                    *other(df),
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                rhs_set = set(rhs_names)
                # `rhs_names` (not the set) is appended to preserve its order.
                return [*(x for x in lhs_names if x not in rhs_set), *rhs_names]

            return self.selectors._selector(series, names)
        return self._to_expr() | other

    @overload
    def __and__(self: Self, other: Self) -> Self: ...
    @overload
    def __and__(
        self: Self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __and__(
        self: Self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        """Intersection: keep columns of `self` whose names `other` also produces."""
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                wanted = set(rhs_names)
                return [x for x, name in zip(self(df), lhs_names) if name in wanted]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                wanted = set(rhs_names)
                return [x for x in lhs_names if x in wanted]

            return self.selectors._selector(series, names)
        return self._to_expr() & other

    def __invert__(self: Self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        """Complement: every column except those this selector picks."""
        return self.selectors.all() - self  # type: ignore[no-any-return]
def _eval_lhs_rhs(
    df: CompliantDataFrame[Any, Any, Any] | CompliantLazyFrame[Any, Any],
    lhs: CompliantExpr[Any, Any],
    rhs: CompliantExpr[Any, Any],
) -> tuple[Sequence[str], Sequence[str]]:
    """Resolve the output column names of both operands against `df`.

    The left operand is evaluated before the right one.

    Returns:
        A `(lhs_names, rhs_names)` pair.
    """
    left_names = lhs._evaluate_output_names(df)
    right_names = rhs._evaluate_output_names(df)
    return left_names, right_names