from __future__ import annotations

import operator
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Literal
from typing import Sequence
from typing import cast

from narwhals._compliant import LazyExpr
from narwhals._expression_parsing import ExprKind
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
from narwhals._spark_like.utils import WindowInputs
from narwhals._spark_like.utils import import_functions
from narwhals._spark_like.utils import import_native_dtypes
from narwhals._spark_like.utils import import_window
from narwhals._spark_like.utils import narwhals_to_native_dtype
from narwhals.dependencies import get_pyspark
from narwhals.utils import Implementation
from narwhals.utils import not_implemented
from narwhals.utils import parse_version

if TYPE_CHECKING:
    from sqlframe.base.column import Column
    from sqlframe.base.window import Window
    from typing_extensions import Self

    from narwhals._compliant.typing import AliasNames
    from narwhals._expression_parsing import ExprMetadata
    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.namespace import SparkLikeNamespace
    from narwhals._spark_like.typing import WindowFunction
    from narwhals.dtypes import DType
    from narwhals.utils import Version
    from narwhals.utils import _FullContext


class SparkLikeExpr(LazyExpr["SparkLikeLazyFrame", "Column"]):
    def __init__(
        self: Self,
        call: Callable[[SparkLikeLazyFrame], Sequence[Column]],
        *,
        evaluate_output_names: Callable[[SparkLikeLazyFrame], Sequence[str]],
        alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
        backend_version: tuple[int, ...],
        version: Version,
        implementation: Implementation,
    ) -> None:
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._implementation = implementation
        self._window_function: WindowFunction | None = None
        self._metadata: ExprMetadata | None = None

    def __call__(self: Self, df: SparkLikeLazyFrame) -> Sequence[Column]:
        return self._call(df)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        if kind is ExprKind.LITERAL:
            return self

        def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
            return [
                result.over(df._Window().partitionBy(df._F.lit(1)))
                for result in self(df)
            ]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    @property
    def _F(self: Self):  # type: ignore[no-untyped-def] # noqa: ANN202, N802
        if TYPE_CHECKING:
            from sqlframe.base import functions

            return functions
        else:
            return import_functions(self._implementation)

    @property
    def _native_dtypes(self: Self):  # type: ignore[no-untyped-def] # noqa: ANN202
        if TYPE_CHECKING:
            from sqlframe.base import types

            return types
        else:
            return import_native_dtypes(self._implementation)

    @property
    def _Window(self: Self) -> type[Window]:  # noqa: N802
        if TYPE_CHECKING:
            from sqlframe.base.window import Window

            return Window
        else:
            return import_window(self._implementation)

    def __narwhals_expr__(self: Self) -> None: ...
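    # A `SparkLikeExpr` is a deferred `df -> list[Column]` callable plus
    # output-name bookkeeping. A minimal usage sketch (hypothetical context
    # object `ctx`, i.e. anything carrying `_backend_version`, `_version`, and
    # `_implementation`; illustrative only):
    #
    #     expr = SparkLikeExpr.from_column_names(lambda df: ["a"], context=ctx)
    #     columns = expr(df)  # -> [F.col("a")], resolved lazily against `df`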
    def __narwhals_namespace__(self: Self) -> SparkLikeNamespace:  # pragma: no cover
        # Unused, just for compatibility with PandasLikeExpr
        from narwhals._spark_like.namespace import SparkLikeNamespace

        return SparkLikeNamespace(
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def _with_window_function(
        self: Self,
        window_function: WindowFunction,
    ) -> Self:
        result = self.__class__(
            self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )
        result._window_function = window_function
        return result

    def _cum_window_func(
        self: Self,
        *,
        reverse: bool,
        func_name: Literal["sum", "max", "min", "count", "product"],
    ) -> WindowFunction:
        def func(window_inputs: WindowInputs) -> Column:
            if reverse:
                order_by_cols = [
                    self._F.col(x).desc_nulls_last() for x in window_inputs.order_by
                ]
            else:
                order_by_cols = [
                    self._F.col(x).asc_nulls_first() for x in window_inputs.order_by
                ]
            window = (
                self._Window()
                .partitionBy(list(window_inputs.partition_by))
                .orderBy(order_by_cols)
                .rowsBetween(self._Window().unboundedPreceding, 0)
            )
            return getattr(self._F, func_name)(window_inputs.expr).over(window)

        return func
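    # The cumulative frame built above corresponds to the SQL window
    #
    #     <func>(expr) OVER (
    #         PARTITION BY <partition_by>
    #         ORDER BY <order_by> ASC NULLS FIRST  -- DESC NULLS LAST when reverse=True
    #         ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW
    #     )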
    def _rolling_window_func(
        self,
        *,
        func_name: Literal["sum", "mean", "std", "var"],
        center: bool,
        window_size: int,
        min_samples: int,
        ddof: int | None = None,
    ) -> WindowFunction:
        supported_funcs = ["sum", "mean", "std", "var"]
        if center:
            # Spread the frame around the current row; for even window sizes the
            # extra element goes to the preceding side, e.g. window_size=4 gives
            # ROWS BETWEEN 2 PRECEDING AND 1 FOLLOWING.
            half = (window_size - 1) // 2
            remainder = (window_size - 1) % 2
            start = self._Window().currentRow - half - remainder
            end = self._Window().currentRow + half
        else:
            start = self._Window().currentRow - window_size + 1
            end = self._Window().currentRow

        def func(window_inputs: WindowInputs) -> Column:
            window = (
                self._Window()
                .partitionBy(list(window_inputs.partition_by))
                .orderBy(
                    [self._F.col(x).asc_nulls_first() for x in window_inputs.order_by]
                )
                .rowsBetween(start, end)
            )
            if func_name in {"sum", "mean"}:
                func_: str = func_name
            elif func_name == "var" and ddof == 0:
                func_ = "var_pop"
            elif func_name == "var" and ddof == 1:
                func_ = "var_samp"
            elif func_name == "std" and ddof == 0:
                func_ = "stddev_pop"
            elif func_name == "std" and ddof == 1:
                func_ = "stddev_samp"
            elif func_name in {"var", "std"}:  # pragma: no cover
                msg = f"Only ddof=0 and ddof=1 are currently supported for rolling_{func_name}."
                raise ValueError(msg)
            else:  # pragma: no cover
                msg = f"Only the following functions are supported: {supported_funcs}.\nGot: {func_name}."
                raise ValueError(msg)
            # Only emit a result once at least `min_samples` rows are in the frame.
            return self._F.when(
                self._F.count(window_inputs.expr).over(window) >= min_samples,
                getattr(self._F, func_)(window_inputs.expr).over(window),
            )

        return func

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: Callable[[SparkLikeLazyFrame], Sequence[str]],
        /,
        *,
        context: _FullContext,
    ) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            return [df._F.col(col_name) for col_name in evaluate_column_names(df)]

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
            implementation=context._implementation,
        )

    @classmethod
    def from_column_indices(
        cls: type[Self], *column_indices: int, context: _FullContext
    ) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            columns = df.columns
            return [df._F.col(columns[i]) for i in column_indices]

        return cls(
            func,
            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
            implementation=context._implementation,
        )

    def _with_callable(
        self: Self,
        call: Callable[..., Column],
        /,
        **expressifiable_args: Self | Any,
    ) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            native_series_list = self(df)
            lit = df._F.lit
            other_native_series = {
                key: df._evaluate_expr(value) if self._is_expr(value) else lit(value)
                for key, value in expressifiable_args.items()
            }
            return [
                call(native_series, **other_native_series)
                for native_series in native_series_list
            ]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        return type(self)(
            call=self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=func,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )
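    # `_with_callable` above lifts a `Column -> Column` function over every output
    # of this expression; keyword arguments that are themselves expressions are
    # evaluated against the frame, while plain Python scalars are wrapped with
    # `F.lit`. E.g. `expr + 1` is evaluated roughly as `col + F.lit(1)` for each
    # output column. The operator overloads below all follow this pattern.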
    def __eq__(self: Self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda _input, other: _input.__eq__(other), other=other
        )

    def __ne__(self: Self, other: SparkLikeExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda _input, other: _input.__ne__(other), other=other
        )

    def __add__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__add__(other), other=other
        )

    def __sub__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__sub__(other), other=other
        )

    def __rsub__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: other.__sub__(_input), other=other
        ).alias("literal")

    def __mul__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__mul__(other), other=other
        )

    def __truediv__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__truediv__(other), other=other
        )

    def __rtruediv__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: other.__truediv__(_input), other=other
        ).alias("literal")

    def __floordiv__(self: Self, other: SparkLikeExpr) -> Self:
        def _floordiv(_input: Column, other: Column) -> Column:
            return self._F.floor(_input / other)

        return self._with_callable(_floordiv, other=other)

    def __rfloordiv__(self: Self, other: SparkLikeExpr) -> Self:
        def _rfloordiv(_input: Column, other: Column) -> Column:
            return self._F.floor(other / _input)

        return self._with_callable(_rfloordiv, other=other).alias("literal")

    def __pow__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__pow__(other), other=other
        )

    def __rpow__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: other.__pow__(_input), other=other
        ).alias("literal")

    def __mod__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__mod__(other), other=other
        )

    def __rmod__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: other.__mod__(_input), other=other
        ).alias("literal")

    def __ge__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__ge__(other), other=other
        )

    def __gt__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(lambda _input, other: _input > other, other=other)

    def __le__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__le__(other), other=other
        )

    def __lt__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__lt__(other), other=other
        )

    def __and__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__and__(other), other=other
        )

    def __or__(self: Self, other: SparkLikeExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__or__(other), other=other
        )

    def __invert__(self: Self) -> Self:
        invert = cast("Callable[..., Column]", operator.invert)
        return self._with_callable(invert)

    def abs(self: Self) -> Self:
        return self._with_callable(self._F.abs)

    def alias(self: Self, name: str) -> Self:
        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
            if len(names) != 1:
                msg = f"Expected function with single output, found output names: {names}"
                raise ValueError(msg)
            return [name]

        return self.__class__(
            self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def all(self: Self) -> Self:
        return self._with_callable(self._F.bool_and)

    def any(self: Self) -> Self:
        return self._with_callable(self._F.bool_or)

    def cast(self: Self, dtype: DType | type[DType]) -> Self:
        def _cast(_input: Column) -> Column:
            spark_dtype = narwhals_to_native_dtype(
                dtype, self._version, self._native_dtypes
            )
            return _input.cast(spark_dtype)

        return self._with_callable(_cast)

    def count(self: Self) -> Self:
        return self._with_callable(self._F.count)

    def max(self: Self) -> Self:
        return self._with_callable(self._F.max)

    def mean(self: Self) -> Self:
        return self._with_callable(self._F.mean)

    def median(self: Self) -> Self:
        def _median(_input: Column) -> Column:
            if (
                self._implementation.is_pyspark()
                and (pyspark := get_pyspark()) is not None
                and parse_version(pyspark) < (3, 4)
            ):  # pragma: no cover
                # Use percentile_approx with default accuracy parameter (10000)
                return self._F.percentile_approx(_input.cast("double"), 0.5)

            return self._F.median(_input)

        return self._with_callable(_median)

    def min(self: Self) -> Self:
        return self._with_callable(self._F.min)

    def null_count(self: Self) -> Self:
        def _null_count(_input: Column) -> Column:
            return self._F.count_if(self._F.isnull(_input))

        return self._with_callable(_null_count)
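    # The aggregations in this class (`mean` above, `sum`/`std`/`var` below, ...)
    # reduce each output column to a single value. When such a result must be
    # combined with full-length columns, `broadcast` (defined near the top of the
    # class) re-expands it with an `OVER (PARTITION BY lit(1))` window, so that
    # e.g. `a - a.mean()` evaluates roughly as
    # `col("a") - avg("a") OVER (PARTITION BY lit(1))`.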
    def sum(self: Self) -> Self:
        return self._with_callable(self._F.sum)

    def std(self: Self, ddof: int) -> Self:
        from functools import partial

        import numpy as np  # ignore-banned-import

        from narwhals._spark_like.utils import _std

        func = partial(
            _std,
            ddof=ddof,
            np_version=parse_version(np),
            functions=self._F,
            implementation=self._implementation,
        )

        return self._with_callable(func)

    def var(self: Self, ddof: int) -> Self:
        from functools import partial

        import numpy as np  # ignore-banned-import

        from narwhals._spark_like.utils import _var

        func = partial(
            _var,
            ddof=ddof,
            np_version=parse_version(np),
            functions=self._F,
            implementation=self._implementation,
        )

        return self._with_callable(func)

    def clip(
        self: Self,
        lower_bound: Any | None = None,
        upper_bound: Any | None = None,
    ) -> Self:
        def _clip_lower(_input: Column, lower_bound: Column) -> Column:
            result = _input
            return self._F.when(result < lower_bound, lower_bound).otherwise(result)

        def _clip_upper(_input: Column, upper_bound: Column) -> Column:
            result = _input
            return self._F.when(result > upper_bound, upper_bound).otherwise(result)

        def _clip_both(
            _input: Column, lower_bound: Column, upper_bound: Column
        ) -> Column:
            return (
                self._F.when(_input < lower_bound, lower_bound)
                .when(_input > upper_bound, upper_bound)
                .otherwise(_input)
            )

        if lower_bound is None:
            return self._with_callable(_clip_upper, upper_bound=upper_bound)
        if upper_bound is None:
            return self._with_callable(_clip_lower, lower_bound=lower_bound)
        return self._with_callable(
            _clip_both, lower_bound=lower_bound, upper_bound=upper_bound
        )

    def is_finite(self: Self) -> Self:
        def _is_finite(_input: Column) -> Column:
            # A value is finite if it is neither NaN nor +/-inf; NULLs are preserved.
            is_finite_condition = (
                ~self._F.isnan(_input)
                & (_input != self._F.lit(float("inf")))
                & (_input != self._F.lit(float("-inf")))
            )
            return self._F.when(~self._F.isnull(_input), is_finite_condition).otherwise(
                None
            )

        return self._with_callable(_is_finite)

    def is_in(self: Self, values: Sequence[Any]) -> Self:
        def _is_in(_input: Column) -> Column:
            return _input.isin(values) if values else self._F.lit(False)  # noqa: FBT003

        return self._with_callable(_is_in)

    def is_unique(self: Self) -> Self:
        def _is_unique(_input: Column) -> Column:
            # Create a window spec that treats each value separately
            return self._F.count("*").over(self._Window.partitionBy(_input)) == 1

        return self._with_callable(_is_unique)

    def len(self: Self) -> Self:
        def _len(_input: Column) -> Column:
            # Use count(*) to count all rows including nulls
            return self._F.count("*")

        return self._with_callable(_len)

    def round(self: Self, decimals: int) -> Self:
        def _round(_input: Column) -> Column:
            return self._F.round(_input, decimals)

        return self._with_callable(_round)

    def skew(self: Self) -> Self:
        return self._with_callable(self._F.skewness)

    def n_unique(self: Self) -> Self:
        def _n_unique(_input: Column) -> Column:
            return self._F.count_distinct(_input) + self._F.max(
                self._F.isnull(_input).cast(self._native_dtypes.IntegerType())
            )

        return self._with_callable(_n_unique)
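    # `n_unique` above counts NULL as its own value: `count_distinct` ignores
    # NULLs, so `max(isnull(x)::int)` contributes exactly 1 whenever a NULL is
    # present. E.g. for [1, 1, None]: count_distinct -> 1, max(isnull) -> 1,
    # giving 2.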
    def over(
        self: Self,
        partition_by: Sequence[str],
        order_by: Sequence[str] | None,
    ) -> Self:
        if (window_function := self._window_function) is not None:
            assert order_by is not None  # noqa: S101

            def func(df: SparkLikeLazyFrame) -> list[Column]:
                return [
                    window_function(WindowInputs(expr, partition_by, order_by))
                    for expr in self._call(df)
                ]
        else:

            def func(df: SparkLikeLazyFrame) -> list[Column]:
                return [
                    expr.over(self._Window.partitionBy(*partition_by))
                    for expr in self._call(df)
                ]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            implementation=self._implementation,
        )

    def is_null(self: Self) -> Self:
        return self._with_callable(self._F.isnull)

    def is_nan(self: Self) -> Self:
        def _is_nan(_input: Column) -> Column:
            return self._F.when(self._F.isnull(_input), None).otherwise(
                self._F.isnan(_input)
            )

        return self._with_callable(_is_nan)

    def shift(self, n: int) -> Self:
        def func(window_inputs: WindowInputs) -> Column:
            order_by_cols = [
                self._F.col(x).asc_nulls_first() for x in window_inputs.order_by
            ]
            window = (
                self._Window()
                .partitionBy(list(window_inputs.partition_by))
                .orderBy(order_by_cols)
            )
            return self._F.lag(window_inputs.expr, n).over(window)

        return self._with_window_function(func)

    def is_first_distinct(self) -> Self:
        def func(window_inputs: WindowInputs) -> Column:
            order_by_cols = [
                self._F.col(x).asc_nulls_first() for x in window_inputs.order_by
            ]
            window = (
                self._Window()
                .partitionBy([*window_inputs.partition_by, window_inputs.expr])
                .orderBy(order_by_cols)
            )
            return self._F.row_number().over(window) == 1

        return self._with_window_function(func)

    def is_last_distinct(self) -> Self:
        def func(window_inputs: WindowInputs) -> Column:
            order_by_cols = [
                self._F.col(x).desc_nulls_last() for x in window_inputs.order_by
            ]
            window = (
                self._Window()
                .partitionBy([*window_inputs.partition_by, window_inputs.expr])
                .orderBy(order_by_cols)
            )
            return self._F.row_number().over(window) == 1

        return self._with_window_function(func)
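    # `is_first_distinct`/`is_last_distinct` above partition by the value itself
    # (on top of any partition keys), so `row_number() == 1` marks the first
    # occurrence of each value; the "last" variant reuses the same trick with
    # the ordering reversed.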
    def diff(self) -> Self:
        def func(window_inputs: WindowInputs) -> Column:
            order_by_cols = [
                self._F.col(x).asc_nulls_first() for x in window_inputs.order_by
            ]
            window = (
                self._Window()
                .partitionBy(list(window_inputs.partition_by))
                .orderBy(order_by_cols)
            )
            return window_inputs.expr - self._F.lag(window_inputs.expr).over(window)

        return self._with_window_function(func)

    def cum_sum(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="sum")
        )

    def cum_max(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="max")
        )

    def cum_min(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="min")
        )

    def cum_count(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="count")
        )

    def cum_prod(self, *, reverse: bool) -> Self:
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="product")
        )

    def fill_null(
        self,
        value: Any | None,
        strategy: Literal["forward", "backward"] | None,
        limit: int | None,
    ) -> Self:
        if strategy is not None:
            msg = "Support for strategies is not yet implemented."
            raise NotImplementedError(msg)

        def _fill_null(_input: Column, value: Column) -> Column:
            return self._F.ifnull(_input, value)

        return self._with_callable(_fill_null, value=value)

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_window_function(
            self._rolling_window_func(
                func_name="sum",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
            )
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_window_function(
            self._rolling_window_func(
                func_name="mean",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
            )
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        return self._with_window_function(
            self._rolling_window_func(
                func_name="var",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
                ddof=ddof,
            )
        )

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        return self._with_window_function(
            self._rolling_window_func(
                func_name="std",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
                ddof=ddof,
            )
        )

    def rank(
        self,
        method: Literal["average", "min", "max", "dense", "ordinal"],
        *,
        descending: bool,
    ) -> Self:
        if method == "min":
            func_name = "rank"
        elif method == "dense":
            func_name = "dense_rank"
        else:  # pragma: no cover
            msg = f"Method {method} is not yet implemented."
            raise NotImplementedError(msg)

        def _rank(_input: Column) -> Column:
            if descending:
                order_by_cols = [self._F.desc_nulls_last(_input)]
            else:
                order_by_cols = [self._F.asc_nulls_last(_input)]
            window = self._Window().orderBy(order_by_cols)
            return self._F.when(
                _input.isNotNull(), getattr(self._F, func_name)().over(window)
            )

        return self._with_callable(_rank)

    @property
    def str(self: Self) -> SparkLikeExprStringNamespace:
        return SparkLikeExprStringNamespace(self)

    @property
    def dt(self: Self) -> SparkLikeExprDateTimeNamespace:
        return SparkLikeExprDateTimeNamespace(self)

    @property
    def list(self: Self) -> SparkLikeExprListNamespace:
        return SparkLikeExprListNamespace(self)

    @property
    def struct(self: Self) -> SparkLikeExprStructNamespace:
        return SparkLikeExprStructNamespace(self)

    drop_nulls = not_implemented()
    unique = not_implemented()
    quantile = not_implemented()
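# A minimal end-to-end sketch of how these expressions are reached from the
# public narwhals API (assuming an existing PySpark DataFrame `spark_df`;
# illustrative only):
#
#     import narwhals as nw
#
#     lf = nw.from_native(spark_df)
#     result = lf.select(nw.col("a").mean())  # routes to SparkLikeExpr.mean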