# Utilities for expression parsing
# Useful for backends which don't have any concept of expressions, such
# as pandas or PyArrow.
from __future__ import annotations

from enum import Enum
from enum import auto
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import Sequence
from typing import TypeVar
from typing import cast

from narwhals.dependencies import is_narwhals_series
from narwhals.dependencies import is_numpy_array
from narwhals.exceptions import LengthChangingExprError
from narwhals.exceptions import MultiOutputExpressionError
from narwhals.exceptions import ShapeError
from narwhals.utils import is_compliant_expr

if TYPE_CHECKING:
    from typing_extensions import Never
    from typing_extensions import TypeIs

    from narwhals._compliant import CompliantExpr
    from narwhals._compliant import CompliantFrameT
    from narwhals._compliant.typing import AliasNames
    from narwhals._compliant.typing import CompliantExprAny
    from narwhals._compliant.typing import CompliantFrameAny
    from narwhals._compliant.typing import CompliantNamespaceAny
    from narwhals._compliant.typing import EagerNamespaceAny
    from narwhals._compliant.typing import EvalNames
    from narwhals.expr import Expr
    from narwhals.typing import IntoExpr
    from narwhals.typing import NonNestedLiteral
    from narwhals.typing import _1DArray

T = TypeVar("T")


def is_expr(obj: Any) -> TypeIs[Expr]:
    """Check whether `obj` is a Narwhals Expr."""
    # Local import, presumably to avoid an import cycle with narwhals.expr.
    from narwhals.expr import Expr

    return isinstance(obj, Expr)


def combine_evaluate_output_names(
    *exprs: CompliantExpr[CompliantFrameT, Any],
) -> EvalNames[CompliantFrameT]:
    """Build an output-name evaluator that keeps only the first name of `exprs[0]`.

    Arguments:
        exprs: Compliant expressions being combined; only the first one is used.

    Returns:
        A function which, given a frame, returns (at most) the first output
        name of the left-most expression.

    Raises:
        AssertionError: If the left-most argument is not a compliant expression.
    """
    # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the
    # first name of `expr1`.
    if not is_compliant_expr(exprs[0]):  # pragma: no cover
        msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug."
        raise AssertionError(msg)

    def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
        return exprs[0]._evaluate_output_names(df)[:1]

    return evaluate_output_names


def combine_alias_output_names(*exprs: CompliantExprAny) -> AliasNames | None:
    """Build an aliasing function taken from the left-most expression.

    Returns:
        `None` if `exprs[0]` has no aliasing function; otherwise a function
        which applies `exprs[0]`'s aliasing function and keeps only the first
        resulting name.
    """
    # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1.alias(alias), expr2)` takes the
    # aliasing function of `expr1` and apply it to the first output name of `expr1`.
    if exprs[0]._alias_output_names is None:
        return None

    def alias_output_names(names: Sequence[str]) -> Sequence[str]:
        return exprs[0]._alias_output_names(names)[:1]  # type: ignore[misc]

    return alias_output_names


def extract_compliant(
    plx: CompliantNamespaceAny,
    other: IntoExpr | NonNestedLiteral | _1DArray,
    *,
    str_as_lit: bool,
) -> CompliantExprAny | NonNestedLiteral:
    """Convert `other` into something the compliant namespace `plx` understands.

    - Narwhals Expr -> its compliant expression for `plx`.
    - `str` (when `str_as_lit` is False) -> `plx.col(other)`.
    - Narwhals Series -> its compliant series converted to an expression.
    - NumPy array -> a series built via the (eager) namespace, as an expression.
    - Anything else is returned unchanged (treated as a literal).
    """
    if is_expr(other):
        return other._to_compliant_expr(plx)
    if isinstance(other, str) and not str_as_lit:
        return plx.col(other)
    if is_narwhals_series(other):
        return other._compliant_series._to_expr()
    if is_numpy_array(other):
        # NOTE(review): assumes numpy input only reaches eager namespaces,
        # which expose `_series.from_numpy` — confirm against callers.
        ns = cast("EagerNamespaceAny", plx)
        return ns._series.from_numpy(other, context=ns)._to_expr()
    return other


def evaluate_output_names_and_aliases(
    expr: CompliantExprAny, df: CompliantFrameAny, exclude: Sequence[str]
) -> tuple[Sequence[str], Sequence[str]]:
    """Evaluate `expr`'s output names against `df` along with their aliases.

    Arguments:
        expr: Compliant expression to evaluate names for.
        df: Compliant frame the names are resolved against.
        exclude: Names to drop from the result. Only applied when `expr` is a
            multi-output expression whose names are not explicitly given
            (e.g. `nw.all()`).

    Returns:
        A pair `(output_names, aliases)` of equal length. If every name is
        excluded, both are empty.
    """
    output_names = expr._evaluate_output_names(df)
    aliases = (
        output_names
        if expr._alias_output_names is None
        else expr._alias_output_names(output_names)
    )
    if exclude:
        assert expr._metadata is not None  # noqa: S101
        if expr._metadata.expansion_kind.is_multi_unnamed():
            excluded = set(exclude)
            kept = [
                (name, alias)
                for name, alias in zip(output_names, aliases)
                if name not in excluded
            ]
            # Build both tuples from `kept` rather than `zip(*kept)`: when every
            # name is excluded, `zip(*[])` is empty and unpacking it into two
            # names would raise `ValueError`.
            output_names = tuple(name for name, _ in kept)
            aliases = tuple(alias for _, alias in kept)
    return output_names, aliases


class ExprKind(Enum):
    """Describe which kind of expression we are dealing with.

    Commutative composition rules are:
    - LITERAL vs LITERAL -> LITERAL
    - FILTRATION vs (LITERAL | AGGREGATION) -> FILTRATION
    - FILTRATION vs (FILTRATION | TRANSFORM | WINDOW) -> raise
    - (TRANSFORM | WINDOW) vs (...) -> TRANSFORM
    - AGGREGATION vs (LITERAL | AGGREGATION) -> AGGREGATION
    """

    LITERAL = auto()
    """e.g. `nw.lit(1)`"""

    AGGREGATION = auto()
    """e.g. `nw.col('a').mean()`"""

    TRANSFORM = auto()
    """preserves length, e.g. `nw.col('a').round()`"""

    WINDOW = auto()
    """transform in which last node is order-dependent

    examples:
    - `nw.col('a').cum_sum()`
    - `(nw.col('a')+1).cum_sum()`

    non-examples:
    - `nw.col('a').cum_sum()+1`
    - `nw.col('a').cum_sum().mean()`
    """

    FILTRATION = auto()
    """e.g. `nw.col('a').drop_nulls()`"""

    def preserves_length(self) -> bool:
        """Return True for TRANSFORM and WINDOW kinds."""
        return self in {ExprKind.TRANSFORM, ExprKind.WINDOW}

    def is_window(self) -> bool:
        return self is ExprKind.WINDOW

    def is_filtration(self) -> bool:
        return self is ExprKind.FILTRATION

    def is_scalar_like(self) -> bool:
        """Return True for AGGREGATION and LITERAL kinds."""
        # Resolves to the module-level `is_scalar_like` function below.
        return is_scalar_like(self)

    @classmethod
    def from_into_expr(
        cls, obj: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool
    ) -> ExprKind:
        """Infer the kind of `obj` as it would be interpreted in an expression.

        Exprs report their own kind; Series, NumPy arrays and column-name
        strings are TRANSFORM; everything else is LITERAL.
        """
        if is_expr(obj):
            return obj._metadata.kind
        if (
            is_narwhals_series(obj)
            or is_numpy_array(obj)
            or (isinstance(obj, str) and not str_as_lit)
        ):
            return ExprKind.TRANSFORM
        return ExprKind.LITERAL


def is_scalar_like(
    kind: ExprKind,
) -> TypeIs[Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]]:
    # Like ExprKind.is_scalar_like, but uses TypeIs for better type checking.
    return kind in {ExprKind.AGGREGATION, ExprKind.LITERAL}


class ExpansionKind(Enum):
    """Describe what kind of expansion the expression performs."""

    SINGLE = auto()
    """e.g. `nw.col('a'), nw.sum_horizontal(nw.all())`"""

    MULTI_NAMED = auto()
    """e.g. `nw.col('a', 'b')`"""

    MULTI_UNNAMED = auto()
    """e.g. `nw.all()`, nw.nth(0, 1)"""

    def is_multi_unnamed(self) -> bool:
        return self is ExpansionKind.MULTI_UNNAMED

    def is_multi_output(self) -> bool:
        return self in {ExpansionKind.MULTI_NAMED, ExpansionKind.MULTI_UNNAMED}

    def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]:
        """Combine two expansion kinds; only MULTI_UNNAMED & MULTI_UNNAMED is supported."""
        if self is ExpansionKind.MULTI_UNNAMED and other is ExpansionKind.MULTI_UNNAMED:
            # e.g. nw.selectors.all() - nw.selectors.numeric().
            return ExpansionKind.MULTI_UNNAMED
        # Don't attempt anything more complex, keep it simple and raise in the face of ambiguity.
        msg = f"Unsupported ExpansionKind combination, got {self} and {other}, please report a bug."  # pragma: no cover
        raise AssertionError(msg)  # pragma: no cover


class WindowKind(Enum):
    """Describe what kind of window the expression contains."""

    NONE = auto()
    """e.g. `nw.col('a').abs()`, no windows."""

    CLOSEABLE = auto()
    """e.g. `nw.col('a').cum_sum()` - can be closed if immediately followed by `over(order_by=...)`."""

    UNCLOSEABLE = auto()
    """e.g. `nw.col('a').cum_sum().abs()` - the window function (`cum_sum`) wasn't immediately followed by
    `over(order_by=...)`, and so the window is uncloseable.

    Uncloseable windows can be used freely in `nw.DataFrame`, but not in `nw.LazyFrame` where
    row-order is undefined."""

    CLOSED = auto()
    """e.g. `nw.col('a').cum_sum().over(order_by='i')`."""

    def is_open(self) -> bool:
        return self in {WindowKind.UNCLOSEABLE, WindowKind.CLOSEABLE}

    def is_closed(self) -> bool:
        return self is WindowKind.CLOSED

    def is_uncloseable(self) -> bool:
        return self is WindowKind.UNCLOSEABLE


class ExprMetadata:
    """Immutable-by-convention record of an expression's kind, window and expansion.

    Instances are only created through `__init__` and the `with_*` / factory
    helpers below, each of which returns a new instance. Subclassing is
    forbidden.
    """

    __slots__ = ("_expansion_kind", "_kind", "_window_kind")

    def __init__(
        self,
        kind: ExprKind,
        /,
        *,
        window_kind: WindowKind,
        expansion_kind: ExpansionKind,
    ) -> None:
        self._kind: ExprKind = kind
        self._window_kind = window_kind
        self._expansion_kind = expansion_kind

    def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never:  # pragma: no cover
        msg = f"Cannot subclass {cls.__name__!r}"
        raise TypeError(msg)

    def __repr__(self) -> str:
        return f"ExprMetadata(kind: {self._kind}, window_kind: {self._window_kind}, expansion_kind: {self._expansion_kind})"

    @property
    def kind(self) -> ExprKind:
        return self._kind

    @property
    def window_kind(self) -> WindowKind:
        return self._window_kind

    @property
    def expansion_kind(self) -> ExpansionKind:
        return self._expansion_kind

    def with_kind(self, kind: ExprKind, /) -> ExprMetadata:
        """Change metadata kind, leaving all other attributes the same."""
        return ExprMetadata(
            kind,
            window_kind=self._window_kind,
            expansion_kind=self._expansion_kind,
        )

    def with_uncloseable_window(self) -> ExprMetadata:
        """Add uncloseable window, leaving other attributes the same.

        Raises:
            AssertionError: If the window is already closed (should be unreachable).
        """
        if self._window_kind is WindowKind.CLOSED:  # pragma: no cover
            msg = "Unreachable code, please report a bug."
            raise AssertionError(msg)
        return ExprMetadata(
            self.kind,
            window_kind=WindowKind.UNCLOSEABLE,
            expansion_kind=self._expansion_kind,
        )

    def with_kind_and_closeable_window(self, kind: ExprKind, /) -> ExprMetadata:
        """Change metadata kind and add closeable window.

        If there is no window yet, the result has a closeable window. If a
        window is already present (closeable or uncloseable), the result's
        window is uncloseable.

        Raises:
            AssertionError: If the window is already closed (should be unreachable).
        """
        if self._window_kind is WindowKind.NONE:
            window_kind = WindowKind.CLOSEABLE
        elif self._window_kind is WindowKind.CLOSED:  # pragma: no cover
            msg = "Unreachable code, please report a bug."
            raise AssertionError(msg)
        else:
            window_kind = WindowKind.UNCLOSEABLE
        return ExprMetadata(
            kind,
            window_kind=window_kind,
            expansion_kind=self._expansion_kind,
        )

    def with_kind_and_uncloseable_window(self, kind: ExprKind, /) -> ExprMetadata:
        """Change metadata kind and set window kind to uncloseable."""
        return ExprMetadata(
            kind,
            window_kind=WindowKind.UNCLOSEABLE,
            expansion_kind=self._expansion_kind,
        )

    @staticmethod
    def selector_single() -> ExprMetadata:
        # e.g. `nw.col('a')`, `nw.nth(0)`
        return ExprMetadata(
            ExprKind.TRANSFORM,
            window_kind=WindowKind.NONE,
            expansion_kind=ExpansionKind.SINGLE,
        )

    @staticmethod
    def selector_multi_named() -> ExprMetadata:
        # e.g. `nw.col('a', 'b')`
        return ExprMetadata(
            ExprKind.TRANSFORM,
            window_kind=WindowKind.NONE,
            expansion_kind=ExpansionKind.MULTI_NAMED,
        )

    @staticmethod
    def selector_multi_unnamed() -> ExprMetadata:
        # e.g. `nw.all()`
        return ExprMetadata(
            ExprKind.TRANSFORM,
            window_kind=WindowKind.NONE,
            expansion_kind=ExpansionKind.MULTI_UNNAMED,
        )

    @classmethod
    def from_binary_op(cls, lhs: Expr, rhs: IntoExpr, /) -> ExprMetadata:
        """Metadata for `lhs <op> rhs`; strings in `rhs` are literals, rhs must be single-output."""
        # We may be able to allow multi-output rhs in the future:
        # https://github.com/narwhals-dev/narwhals/issues/2244.
        return combine_metadata(
            lhs, rhs, str_as_lit=True, allow_multi_output=False, to_single_output=False
        )

    @classmethod
    def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata:
        """Metadata for a horizontal op (e.g. `nw.sum_horizontal`); always single-output."""
        return combine_metadata(
            *exprs, str_as_lit=False, allow_multi_output=True, to_single_output=True
        )


def combine_metadata(  # noqa: C901, PLR0912, PLR0915
    *args: IntoExpr | object | None,
    str_as_lit: bool,
    allow_multi_output: bool,
    to_single_output: bool,
) -> ExprMetadata:
    """Combine metadata from `args`.

    Arguments:
        args: Arguments, maybe expressions, literals, or Series.
        str_as_lit: Whether to interpret strings as literals or as column names.
        allow_multi_output: Whether to allow multi-output inputs.
        to_single_output: Whether the result is always single-output, regardless
            of the inputs (e.g. `nw.sum_horizontal`).

    Returns:
        The combined metadata, following the composition rules documented on
        `ExprKind`. Any open window among the inputs makes the result's window
        uncloseable.

    Raises:
        MultiOutputExpressionError: If a non-left-most argument is multi-output
            and `allow_multi_output` is False.
        LengthChangingExprError: If more than one argument is a filtration.
        ShapeError: If a filtration is combined with a length-preserving input.
    """
    n_filtrations = 0
    has_transforms_or_windows = False
    has_aggregations = False
    has_literals = False
    result_expansion_kind = ExpansionKind.SINGLE
    has_closeable_windows = False
    has_uncloseable_windows = False
    has_closed_windows = False

    for i, arg in enumerate(args):  # noqa: PLR1702
        if isinstance(arg, str) and not str_as_lit:
            # A column name behaves like a length-preserving expression.
            has_transforms_or_windows = True
        elif is_expr(arg):
            metadata = arg._metadata
            if metadata.expansion_kind.is_multi_output():
                expansion_kind = metadata.expansion_kind
                if i > 0 and not allow_multi_output:
                    # Left-most argument is always allowed to be multi-output.
                    msg = (
                        "Multi-output expressions (e.g. nw.col('a', 'b'), nw.all()) "
                        "are not supported in this context."
                    )
                    raise MultiOutputExpressionError(msg)
                if not to_single_output:
                    if i == 0:
                        result_expansion_kind = expansion_kind
                    else:
                        result_expansion_kind = result_expansion_kind & expansion_kind

            kind = metadata.kind
            if kind is ExprKind.AGGREGATION:
                has_aggregations = True
            elif kind is ExprKind.LITERAL:
                has_literals = True
            elif kind is ExprKind.FILTRATION:
                n_filtrations += 1
            elif kind.preserves_length():
                has_transforms_or_windows = True
            else:  # pragma: no cover
                msg = "unreachable code"
                raise AssertionError(msg)

            window_kind = metadata.window_kind
            if window_kind is WindowKind.UNCLOSEABLE:
                has_uncloseable_windows = True
            elif window_kind is WindowKind.CLOSEABLE:
                has_closeable_windows = True
            elif window_kind is WindowKind.CLOSED:
                has_closed_windows = True

    if (
        has_literals
        and not has_aggregations
        and not has_transforms_or_windows
        and not n_filtrations
    ):
        result_kind = ExprKind.LITERAL
    elif n_filtrations > 1:
        msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
        raise LengthChangingExprError(msg)
    elif n_filtrations and has_transforms_or_windows:
        msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
        raise ShapeError(msg)
    elif n_filtrations:
        result_kind = ExprKind.FILTRATION
    elif has_transforms_or_windows:
        result_kind = ExprKind.TRANSFORM
    else:
        result_kind = ExprKind.AGGREGATION

    # Once combined with anything, an open window can no longer be closed by a
    # directly-following `over(order_by=...)`.
    if has_uncloseable_windows or has_closeable_windows:
        result_window_kind = WindowKind.UNCLOSEABLE
    elif has_closed_windows:
        result_window_kind = WindowKind.CLOSED
    else:
        result_window_kind = WindowKind.NONE

    return ExprMetadata(
        result_kind, window_kind=result_window_kind, expansion_kind=result_expansion_kind
    )


def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None:
    """Raise if any argument in `args` isn't length-preserving.

    Strings and Series are accepted here; for Series input we don't raise (yet),
    we let such checks happen later, as this function works lazily and so can't
    evaluate lengths.

    Raises:
        ShapeError: If any argument is neither a length-preserving expression,
            a string, nor a Series.
    """
    # Local import, presumably to avoid an import cycle with narwhals.series.
    from narwhals.series import Series

    if not all(
        (is_expr(x) and x._metadata.kind.preserves_length())
        or isinstance(x, (str, Series))
        for x in args
    ):
        msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'."
        raise ShapeError(msg)


def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
    """Return whether every positional and keyword argument is a scalar-like Expr.

    Scalar-like means the expression's kind is an aggregation or a literal.
    Any argument which is not a Narwhals Expr (e.g. a Series, a plain literal,
    or a column-name string) makes this return False. Never raises on its own.
    """
    exprs = chain(args, kwargs.values())
    return all(is_expr(x) and x._metadata.kind.is_scalar_like() for x in exprs)


def apply_n_ary_operation(
    plx: CompliantNamespaceAny,
    function: Any,
    *comparands: IntoExpr | NonNestedLiteral | _1DArray,
    str_as_lit: bool,
) -> CompliantExprAny:
    """Convert `comparands` to compliant objects and call `function` on them.

    If at least one comparand is not scalar-like, every scalar-like compliant
    expression is broadcast (via `.broadcast(kind)`) before `function` is
    called. Non-expression literals are passed through unchanged.
    """
    compliant_exprs = (
        extract_compliant(plx, comparand, str_as_lit=str_as_lit)
        for comparand in comparands
    )
    kinds = [
        ExprKind.from_into_expr(comparand, str_as_lit=str_as_lit)
        for comparand in comparands
    ]

    broadcast = any(not kind.is_scalar_like() for kind in kinds)
    compliant_exprs = (
        compliant_expr.broadcast(kind)
        if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind)
        else compliant_expr
        for compliant_expr, kind in zip(compliant_exprs, kinds)
    )
    return function(*compliant_exprs)