173 lines
6.0 KiB
Python
173 lines
6.0 KiB
Python
from __future__ import annotations
|
|
|
|
import sys
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Callable
|
|
from typing import Sequence
|
|
from typing import TypeVar
|
|
from typing import cast
|
|
|
|
from narwhals._compliant.expr import CompliantExpr
|
|
from narwhals._compliant.typing import CompliantExprAny
|
|
from narwhals._compliant.typing import CompliantFrameAny
|
|
from narwhals._compliant.typing import CompliantLazyFrameT
|
|
from narwhals._compliant.typing import CompliantSeriesOrNativeExprAny
|
|
from narwhals._compliant.typing import EagerDataFrameT
|
|
from narwhals._compliant.typing import EagerExprT
|
|
from narwhals._compliant.typing import EagerSeriesT
|
|
from narwhals._compliant.typing import LazyExprAny
|
|
from narwhals._compliant.typing import NativeExprT
|
|
from narwhals._compliant.typing import NativeSeriesT
|
|
|
|
if TYPE_CHECKING:
|
|
from typing_extensions import Self
|
|
from typing_extensions import TypeAlias
|
|
|
|
from narwhals.utils import Implementation
|
|
from narwhals.utils import Version
|
|
from narwhals.utils import _FullContext
|
|
|
|
if not TYPE_CHECKING: # pragma: no cover
|
|
if sys.version_info >= (3, 9):
|
|
from typing import Protocol as Protocol38
|
|
else:
|
|
from typing import Generic as Protocol38
|
|
else: # pragma: no cover
|
|
# TODO @dangotbanned: Remove after dropping `3.8` (#2084)
|
|
# - https://github.com/narwhals-dev/narwhals/pull/2064#discussion_r1965921386
|
|
from typing import Protocol as Protocol38
|
|
|
|
# Public protocols exported by this module.
__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen", "LazyWhen"]

# Generic parameters shared by the protocols below; each bound restricts the
# parameter to the matching compliant-layer type.
ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)

# `TypeAlias` is only imported under `TYPE_CHECKING`; this annotation is safe at
# runtime because `from __future__ import annotations` leaves it unevaluated.
Scalar: TypeAlias = Any
"""A native or python literal value."""

# Spelled as a string so the union is only interpreted by type checkers; the
# free TypeVars let it be subscripted as `IntoExpr[SeriesT, ExprT]` in annotations.
IntoExpr: TypeAlias = "SeriesT | ExprT | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""
|
|
|
|
|
|
class CompliantWhen(Protocol38[FrameT, SeriesT, ExprT]):
    """Protocol for the ``when`` step of a ``when/then/otherwise`` chain.

    Stores the condition expression, the ``then`` and ``otherwise`` values, and
    the backend metadata copied from the originating context. Implementations
    are called with a compliant frame and return the resulting series.
    """

    _condition: ExprT
    _then_value: IntoExpr[SeriesT, ExprT]
    _otherwise_value: IntoExpr[SeriesT, ExprT]
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    @property
    def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT]]:
        """The ``CompliantThen`` class that :meth:`then` delegates to."""
        ...

    def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]:
        """Evaluate the chain against ``compliant_frame``."""
        ...

    def then(
        self, value: IntoExpr[SeriesT, ExprT], /
    ) -> CompliantThen[FrameT, SeriesT, ExprT]:
        """Build the ``then`` expression for ``value`` via ``self._then.from_when``."""
        then_cls = self._then
        return then_cls.from_when(self, value)

    @classmethod
    def from_expr(cls, condition: ExprT, /, *, context: _FullContext) -> Self:
        """Create a ``when`` for ``condition`` without running ``__init__``.

        Implementation, backend version and version are copied from
        ``context``; both branch values start out as ``None``.
        """
        when = cls.__new__(cls)
        when._condition = condition
        # Filled in later by `then` / `otherwise`; assigned then-first, as before.
        when._then_value = when._otherwise_value = None
        when._implementation = context._implementation
        when._backend_version = context._backend_version
        when._version = context._version
        return when
|
|
|
|
|
|
class CompliantThen(CompliantExpr[FrameT, SeriesT], Protocol38[FrameT, SeriesT, ExprT]):
    """Protocol for the expression returned by ``when(...).then(...)``.

    It is a ``CompliantExpr`` whose ``_call`` is the originating
    ``CompliantWhen``. ``otherwise`` stores the fallback branch on that same
    ``CompliantWhen`` and hands this object back as an expression.
    """

    _call: Callable[[FrameT], Sequence[SeriesT]]
    _when_value: CompliantWhen[FrameT, SeriesT, ExprT]
    _function_name: str
    _depth: int
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version
    _call_kwargs: dict[str, Any]

    @classmethod
    def from_when(
        cls,
        when: CompliantWhen[FrameT, SeriesT, ExprT],
        then: IntoExpr[SeriesT, ExprT],
        /,
    ) -> Self:
        """Wrap ``when`` in a new ``then`` expression.

        Side effect: ``then`` is stored on ``when`` as ``_then_value``. Output
        names are taken from ``then`` when it provides them, otherwise the
        single output is named ``"literal"``.
        """

        def literal_names(_df: Any) -> list[str]:
            # Fallback used when `then` has no `_evaluate_output_names` (e.g. a scalar).
            return ["literal"]

        # `when` itself becomes the callable evaluated later, so it must hold
        # the `then` value before being wrapped.
        when._then_value = then
        expr = cls.__new__(cls)
        expr._call = when
        expr._when_value = when
        expr._depth = 0
        expr._function_name = "whenthen"
        expr._evaluate_output_names = getattr(then, "_evaluate_output_names", literal_names)
        expr._alias_output_names = getattr(then, "_alias_output_names", None)
        # Backend metadata is inherited from the `when` object.
        expr._implementation = when._implementation
        expr._backend_version = when._backend_version
        expr._version = when._version
        expr._call_kwargs = {}
        return expr

    def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
        """Set the fallback branch on the wrapped ``when``; return ``self`` as an expression.

        ``self`` is mutated in place (its function name becomes ``"whenotherwise"``).
        """
        when = self._when_value
        when._otherwise_value = otherwise
        self._function_name = "whenotherwise"
        return cast("ExprT", self)
|
|
|
|
|
|
class EagerWhen(
    CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
    Protocol38[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
    """``CompliantWhen`` for eager backends, evaluated on concrete series.

    Backends provide :meth:`_if_then_else`, which works on native series.
    """

    def _if_then_else(
        self,
        when: NativeSeriesT,
        then: NativeSeriesT,
        otherwise: NativeSeriesT | Scalar | None,
        /,
    ) -> NativeSeriesT:
        """Backend hook combining the native condition, ``then`` and ``otherwise``.

        NOTE(review): presumably an element-wise select (``then`` where ``when``
        holds, ``otherwise`` elsewhere) -- confirm against each backend.
        """
        ...

    def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
        """Evaluate the chain on ``df``; returns a one-element list of series.

        A non-expression ``then`` value becomes a broadcast series named
        ``"literal"``. A non-expression ``otherwise`` value (including ``None``)
        is forwarded to ``_if_then_else`` as-is.
        """
        condition = self._condition
        is_expr = condition._is_expr
        mask: EagerSeriesT = condition(df)[0]

        then_value = self._then_value
        then_series: EagerSeriesT
        if not is_expr(then_value):
            # Scalar `then`: derive a series from the condition series and flag
            # it as broadcast.
            then_series = mask.alias("literal")._from_scalar(then_value)
            then_series._broadcast = True
        else:
            then_series = then_value(df)[0]

        other_value = self._otherwise_value
        other = (
            df._extract_comparand(other_value(df)[0])
            if is_expr(other_value)
            else other_value
        )

        native = self._if_then_else(mask.native, df._extract_comparand(then_series), other)
        # The result keeps the `then` series' wrapper (and therefore its name).
        return [then_series._with_native(native)]
|
|
|
|
|
|
class LazyWhen(
    CompliantWhen[CompliantLazyFrameT, NativeExprT, LazyExprT],
    Protocol38[CompliantLazyFrameT, NativeExprT, LazyExprT],
):
    """``CompliantWhen`` for lazy backends, assembled from native expressions.

    Backends provide ``when`` (a native when/case constructor whose result
    supports ``.otherwise(...)``) and ``lit`` (wraps a non-expression value as a
    native expression).
    """

    when: Callable[..., NativeExprT]
    lit: Callable[..., NativeExprT]

    def __call__(self: Self, df: CompliantLazyFrameT) -> Sequence[NativeExprT]:
        """Translate the chain into a single native expression for ``df``.

        Expression values go through ``df._evaluate_expr``; anything else is
        wrapped with ``lit``. When the ``otherwise`` value is ``None`` no
        ``.otherwise(...)`` is applied, leaving the backend's default for
        unmatched rows (presumably null -- confirm per backend).
        """
        condition = self._condition
        is_expr = condition._is_expr
        make_when, make_lit = self.when, self.lit

        def to_native(value: Any) -> NativeExprT:
            # Compile expressions against `df`; treat everything else as a literal.
            return df._evaluate_expr(value) if is_expr(value) else make_lit(value)

        predicate = df._evaluate_expr(condition)
        then_native = to_native(self._then_value)
        fallback = self._otherwise_value
        if fallback is None:
            return [make_when(predicate, then_native)]
        fallback_native = to_native(fallback)
        return [make_when(predicate, then_native).otherwise(fallback_native)]  # type: ignore # noqa: PGH003
|