Files
Buffteks-Website/buffteks/lib/python3.12/site-packages/narwhals/_compliant/when_then.py
2025-05-08 21:10:14 -05:00

173 lines
6.0 KiB
Python

from __future__ import annotations
import sys
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Sequence
from typing import TypeVar
from typing import cast
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import CompliantExprAny
from narwhals._compliant.typing import CompliantFrameAny
from narwhals._compliant.typing import CompliantLazyFrameT
from narwhals._compliant.typing import CompliantSeriesOrNativeExprAny
from narwhals._compliant.typing import EagerDataFrameT
from narwhals._compliant.typing import EagerExprT
from narwhals._compliant.typing import EagerSeriesT
from narwhals._compliant.typing import LazyExprAny
from narwhals._compliant.typing import NativeExprT
from narwhals._compliant.typing import NativeSeriesT
if TYPE_CHECKING:
from typing_extensions import Self
from typing_extensions import TypeAlias
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import _FullContext
if not TYPE_CHECKING: # pragma: no cover
if sys.version_info >= (3, 9):
from typing import Protocol as Protocol38
else:
from typing import Generic as Protocol38
else: # pragma: no cover
# TODO @dangotbanned: Remove after dropping `3.8` (#2084)
# - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386
from typing import Protocol as Protocol38
# Public protocols exported by this module.
__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen", "LazyWhen"]
# Type parameters shared by the `when/then/otherwise` protocols below.
ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)
# The `TypeAlias` annotation is never evaluated at runtime because of
# `from __future__ import annotations`, so importing `TypeAlias` only under
# `TYPE_CHECKING` is sufficient.
Scalar: TypeAlias = Any
"""A native or python literal value."""
# Generic alias over `SeriesT` and `ExprT` (used as `IntoExpr[SeriesT, ExprT]`).
# Kept as a string, presumably so the `|` union of TypeVars is never evaluated
# at runtime on older Pythons that this module still supports — TODO confirm.
IntoExpr: TypeAlias = "SeriesT | ExprT | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""
class CompliantWhen(Protocol38[FrameT, SeriesT, ExprT]):
    """Protocol for the `when(...)` part of a `when/then/otherwise` chain.

    Holds the boolean condition together with the `then` and `otherwise`
    values. The `then` value is filled in by `CompliantThen.from_when` and the
    `otherwise` value by `CompliantThen.otherwise`.
    """

    # Expression deciding which rows take the `then` value.
    _condition: ExprT
    # Value used where the condition holds; `None` until `then()` is called.
    _then_value: IntoExpr[SeriesT, ExprT]
    # Fallback value; stays `None` unless `otherwise()` is called.
    _otherwise_value: IntoExpr[SeriesT, ExprT]
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    @property
    def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT]]:
        """Concrete `CompliantThen` class that `then()` instantiates."""
        ...

    def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]:
        """Evaluate the chain against a frame (overridden by concrete subclasses)."""
        ...

    def then(
        self, value: IntoExpr[SeriesT, ExprT], /
    ) -> CompliantThen[FrameT, SeriesT, ExprT]:
        """Attach `value` as the `then` branch and return the resulting expression."""
        then_cls = self._then
        return then_cls.from_when(self, value)

    @classmethod
    def from_expr(cls, condition: ExprT, /, *, context: _FullContext) -> Self:
        """Create an instance for `condition`, copying backend metadata from `context`.

        `__init__` is bypassed through `cls.__new__`. Both the `then` and
        `otherwise` slots start out as `None`.
        """
        instance = cls.__new__(cls)
        instance._condition = condition
        # Chained assignment binds left-to-right: `_then_value`, then `_otherwise_value`.
        instance._then_value = instance._otherwise_value = None
        instance._implementation = context._implementation
        instance._backend_version = context._backend_version
        instance._version = context._version
        return instance
class CompliantThen(CompliantExpr[FrameT, SeriesT], Protocol38[FrameT, SeriesT, ExprT]):
    """Protocol for the expression returned by `CompliantWhen.then()`.

    The originating `CompliantWhen` is stored both as `_call` (so evaluating
    this expression evaluates the `when` object) and as `_when_value` (so
    `otherwise()` can update it).
    """

    _call: Callable[[FrameT], Sequence[SeriesT]]
    _when_value: CompliantWhen[FrameT, SeriesT, ExprT]
    _function_name: str
    _depth: int
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version
    _call_kwargs: dict[str, Any]

    @classmethod
    def from_when(
        cls,
        when: CompliantWhen[FrameT, SeriesT, ExprT],
        then: IntoExpr[SeriesT, ExprT],
        /,
    ) -> Self:
        """Build the `then` expression for `when`, recording `then` on `when` itself.

        Output-name callbacks are taken from `then` when it provides them.
        Otherwise (e.g. a scalar `then`) the output is named ``"literal"`` and
        no alias callback is set.
        """

        def _literal_names(_df: Any) -> list[str]:
            return ["literal"]

        # `when` is mutated in place: evaluation reads the `then` value from it.
        when._then_value = then
        expr = cls.__new__(cls)
        # Chained assignment binds left-to-right: `_call`, then `_when_value`.
        expr._call = expr._when_value = when
        expr._depth = 0
        expr._function_name = "whenthen"
        expr._evaluate_output_names = getattr(
            then, "_evaluate_output_names", _literal_names
        )
        expr._alias_output_names = getattr(then, "_alias_output_names", None)
        expr._implementation = when._implementation
        expr._backend_version = when._backend_version
        expr._version = when._version
        expr._call_kwargs = {}
        return expr

    def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
        """Set the fallback value on the underlying `when` and return `self`.

        This mutates the existing expression (renaming it to ``"whenotherwise"``)
        rather than building a new one.
        """
        parent = self._when_value
        parent._otherwise_value = otherwise
        self._function_name = "whenotherwise"
        return cast("ExprT", self)
class EagerWhen(
    CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
    Protocol38[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
    """`CompliantWhen` for eager backends; evaluation yields a single series."""

    def _if_then_else(
        self,
        when: NativeSeriesT,
        then: NativeSeriesT,
        otherwise: NativeSeriesT | Scalar | None,
        /,
    ) -> NativeSeriesT:
        """Backend hook combining native `when`/`then`/`otherwise` values.

        Presumably an element-wise "then where `when`, else `otherwise`" — the
        implementation lives in concrete backends; TODO confirm per backend.
        """
        ...

    def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
        """Evaluate the chain against `df` and return a one-element list."""
        condition = self._condition
        is_expr = condition._is_expr
        mask: EagerSeriesT = condition(df)[0]

        then_value = self._then_value
        then: EagerSeriesT
        if not is_expr(then_value):
            # Scalar `then`: materialise it as a series named "literal", built
            # from the condition series, and mark it for broadcasting.
            then = mask.alias("literal")._from_scalar(then_value)
            then._broadcast = True
        else:
            then = then_value(df)[0]

        fallback = self._otherwise_value
        # Expression fallbacks are evaluated and converted; scalars/None pass through.
        otherwise = (
            df._extract_comparand(fallback(df)[0]) if is_expr(fallback) else fallback
        )
        native = self._if_then_else(mask.native, df._extract_comparand(then), otherwise)
        return [then._with_native(native)]
class LazyWhen(
    CompliantWhen[CompliantLazyFrameT, NativeExprT, LazyExprT],
    Protocol38[CompliantLazyFrameT, NativeExprT, LazyExprT],
):
    """`CompliantWhen` for lazy backends; evaluation yields a native expression."""

    # Backend-supplied builder, called as `when(condition, then)`; its result
    # must support `.otherwise(...)` when a fallback is present.
    when: Callable[..., NativeExprT]
    # Backend-supplied builder wrapping a non-expression value as a native literal.
    lit: Callable[..., NativeExprT]

    def __call__(self: Self, df: CompliantLazyFrameT) -> Sequence[NativeExprT]:
        """Translate the chain into a single native expression for `df`."""
        is_expr = self._condition._is_expr
        make_when = self.when
        make_lit = self.lit

        def to_native(value: Any) -> NativeExprT:
            # Compliant expressions are evaluated on the frame; anything else is a literal.
            return df._evaluate_expr(value) if is_expr(value) else make_lit(value)

        condition = df._evaluate_expr(self._condition)
        then = to_native(self._then_value)
        fallback = self._otherwise_value
        if fallback is None:
            return [make_when(condition, then)]
        # The fallback is converted before `when(...)` is built.
        otherwise = to_native(fallback)
        return [make_when(condition, then).otherwise(otherwise)]  # type: ignore  # noqa: PGH003