from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Literal
from typing import Sequence

from narwhals._compliant import LazyExpr
from narwhals._compliant.expr import DepthTrackingExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate_expr
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import not_implemented

if TYPE_CHECKING:
    from narwhals._compliant.typing import AliasNames
    from narwhals._expression_parsing import ExprKind

    try:
        import dask.dataframe.dask_expr as dx
    except ModuleNotFoundError:
        import dask_expr as dx

    from typing_extensions import Self

    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprMetadata
    from narwhals.dtypes import DType
    from narwhals.utils import Version
    from narwhals.utils import _FullContext


class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
):
    _implementation: Implementation = Implementation.DASK

    def __init__(
        self: Self,
        call: Callable[[DaskLazyFrame], Sequence[dx.Series]],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: Callable[[DaskLazyFrame], Sequence[str]],
        alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
        backend_version: tuple[int, ...],
        version: Version,
        # Kwargs with metadata which we may need in group-by agg
        # (e.g. `ddof` for `std` and `var`).
        call_kwargs: dict[str, Any] | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._call_kwargs = call_kwargs or {}
        self._metadata: ExprMetadata | None = None

    def __call__(self: Self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        return self._call(df)

    def __narwhals_expr__(self) -> None: ...
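
    # A `DaskExpr` is evaluated lazily: `_call` maps a `DaskLazyFrame` to a list
    # of native Dask Series, one per output name. For example, the expression
    # built by `from_column_names` for columns "a" and "b" returns
    # `[df._native_frame["a"], df._native_frame["b"]]` when called. `_depth` and
    # `_function_name` record the shape of the expression so that, where
    # possible, `over` can remap it to a native group-by aggregation.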

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        # Unused, just for compatibility with PandasLikeExpr
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [result[0] for result in self(df)]

        return self.__class__(
            func,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=self._call_kwargs,
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: Callable[[DaskLazyFrame], Sequence[str]],
        /,
        *,
        context: _FullContext,
        function_name: str = "",
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            try:
                return [
                    df._native_frame[column_name]
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                missing_columns = [
                    x for x in evaluate_column_names(df) if x not in df.columns
                ]
                raise ColumnNotFoundError.from_missing_and_available_column_names(
                    missing_columns=missing_columns,
                    available_columns=df.columns,
                ) from e

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(
        cls: type[Self], *column_indices: int, context: _FullContext
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [
                df._native_frame.iloc[:, column_index]
                for column_index in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def _with_callable(
        self: Self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        expr_name: str = "",
        call_kwargs: dict[str, Any] | None = None,
        **expressifiable_args: Self | Any,
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            native_results: list[dx.Series] = []
            native_series_list = self._call(df)
            other_native_series = {
                key: maybe_evaluate_expr(df, value)
                for key, value in expressifiable_args.items()
            }
            for native_series in native_series_list:
                result_native = call(native_series, **other_native_series)
                native_results.append(result_native)
            return native_results

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=f"{self._function_name}->{expr_name}",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=call_kwargs,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        return type(self)(
            call=self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=func,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=self._call_kwargs,
        )

    def alias(self: Self, name: str) -> Self:
        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
            if len(names) != 1:
                msg = f"Expected function with single output, found output names: {names}"
                raise ValueError(msg)
            return [name]

        return self.__class__(
            self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=self._call_kwargs,
        )

    def __add__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__add__(other), "__add__", other=other
        )

    def __sub__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__sub__(other), "__sub__", other=other
        )

    def __rsub__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other - _input, "__rsub__", other=other
        ).alias("literal")

    def __mul__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__mul__(other), "__mul__", other=other
        )

    def __truediv__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__truediv__(other), "__truediv__", other=other
        )

    def __rtruediv__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other / _input, "__rtruediv__", other=other
        ).alias("literal")

    def __floordiv__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__floordiv__(other), "__floordiv__", other=other
        )

    def __rfloordiv__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other // _input, "__rfloordiv__", other=other
        ).alias("literal")

    def __pow__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__pow__(other), "__pow__", other=other
        )

    def __rpow__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other**_input, "__rpow__", other=other
        ).alias("literal")

    def __mod__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__mod__(other), "__mod__", other=other
        )

    def __rmod__(self: Self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other % _input, "__rmod__", other=other
        ).alias("literal")

    def __eq__(self: Self, other: DaskExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda _input, other: _input.__eq__(other), "__eq__", other=other
        )

    def __ne__(self: Self, other: DaskExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda _input, other: _input.__ne__(other), "__ne__", other=other
        )

    def __ge__(self: Self, other: DaskExpr | Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__ge__(other), "__ge__", other=other
        )

    def __gt__(self: Self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__gt__(other), "__gt__", other=other
        )

    def __le__(self: Self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__le__(other), "__le__", other=other
        )

    def __lt__(self: Self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__lt__(other), "__lt__", other=other
        )

    def __and__(self: Self, other: DaskExpr | Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__and__(other), "__and__", other=other
        )

    def __or__(self: Self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__or__(other), "__or__", other=other
        )

    def __invert__(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.__invert__(), "__invert__")

    def mean(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.mean().to_series(), "mean")

    def median(self: Self) -> Self:
        from narwhals.exceptions import InvalidOperationError

        def func(s: dx.Series) -> dx.Series:
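            # Note: Dask has no exact distributed median; `median_approximate`
            # below returns an approximation, so results can differ slightly
            # from exact (e.g. pandas) backends.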
            dtype = native_to_narwhals_dtype(
                s.dtype, self._version, Implementation.DASK
            )
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func, "median")

    def min(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.min().to_series(), "min")

    def max(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.max().to_series(), "max")

    def std(self: Self, ddof: int) -> Self:
        return self._with_callable(
            lambda _input: _input.std(ddof=ddof).to_series(),
            "std",
            call_kwargs={"ddof": ddof},
        )

    def var(self: Self, ddof: int) -> Self:
        return self._with_callable(
            lambda _input: _input.var(ddof=ddof).to_series(),
            "var",
            call_kwargs={"ddof": ddof},
        )

    def skew(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.skew().to_series(), "skew")

    def shift(self: Self, n: int) -> Self:
        return self._with_callable(lambda _input: _input.shift(n), "shift")

    def cum_sum(self: Self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda _input: _input.cumsum(), "cum_sum")

    def cum_count(self: Self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda _input: (~_input.isna()).astype(int).cumsum(), "cum_count"
        )

    def cum_min(self: Self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda _input: _input.cummin(), "cum_min")

    def cum_max(self: Self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda _input: _input.cummax(), "cum_max")

    def cum_prod(self: Self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda _input: _input.cumprod(), "cum_prod")

    def rolling_sum(
        self: Self, window_size: int, *, min_samples: int, center: bool
    ) -> Self:
        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum(),
            "rolling_sum",
        )

    def rolling_mean(
        self: Self, window_size: int, *, min_samples: int, center: bool
    ) -> Self:
        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean(),
            "rolling_mean",
        )

    def rolling_var(
        self: Self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda _input: _input.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).var(),
                "rolling_var",
            )
        else:
            msg = "Dask backend only supports `ddof=1` for `rolling_var`"
            raise NotImplementedError(msg)

    def rolling_std(
        self: Self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda _input: _input.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).std(),
                "rolling_std",
            )
        else:
            msg = "Dask backend only supports `ddof=1` for `rolling_std`"
            raise NotImplementedError(msg)
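
    # The scalar reductions here (`sum`, `count`, ... as well as `mean`/`min`/
    # `max` above) call `.to_series()` on the native Dask scalar so that every
    # expression uniformly evaluates to a sequence of Series; `broadcast` then
    # takes `result[0]` when a scalar is required.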

    def sum(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.sum().to_series(), "sum")

    def count(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.count().to_series(), "count")

    def round(self: Self, decimals: int) -> Self:
        return self._with_callable(lambda _input: _input.round(decimals), "round")

    def unique(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.unique(), "unique")

    def drop_nulls(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.dropna(), "drop_nulls")

    def abs(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.abs(), "abs")

    def all(self: Self) -> Self:
        return self._with_callable(
            lambda _input: _input.all(
                axis=None, skipna=True, split_every=False, out=None
            ).to_series(),
            "all",
        )

    def any(self: Self) -> Self:
        return self._with_callable(
            lambda _input: _input.any(axis=0, skipna=True, split_every=False).to_series(),
            "any",
        )

    def fill_null(
        self: Self,
        value: Self | Any | None,
        strategy: Literal["forward", "backward"] | None,
        limit: int | None,
    ) -> DaskExpr:
        def func(_input: dx.Series) -> dx.Series:
            if value is not None:
                res_ser = _input.fillna(value)
            else:
                res_ser = (
                    _input.ffill(limit=limit)
                    if strategy == "forward"
                    else _input.bfill(limit=limit)
                )
            return res_ser

        return self._with_callable(func, "fillna")

    def clip(
        self: Self,
        lower_bound: Self | Any | None,
        upper_bound: Self | Any | None,
    ) -> Self:
        return self._with_callable(
            lambda _input, lower_bound, upper_bound: _input.clip(
                lower=lower_bound, upper=upper_bound
            ),
            "clip",
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def diff(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.diff(), "diff")

    def n_unique(self: Self) -> Self:
        return self._with_callable(
            lambda _input: _input.nunique(dropna=False).to_series(), "n_unique"
        )

    def is_null(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.isna(), "is_null")

    def is_nan(self: Self) -> Self:
        def func(_input: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                _input.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                return _input != _input  # pyright: ignore[reportReturnType]  # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        return self._with_callable(func, "is_null")

    def len(self: Self) -> Self:
        return self._with_callable(lambda _input: _input.size.to_series(), "len")

    def quantile(
        self: Self,
        quantile: float,
        interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
    ) -> Self:
        if interpolation == "linear":

            def func(_input: dx.Series, quantile: float) -> dx.Series:
                if _input.npartitions > 1:
                    msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                    raise NotImplementedError(msg)
                return _input.quantile(
                    q=quantile, method="dask"
                ).to_series()  # pragma: no cover

            return self._with_callable(func, "quantile", quantile=quantile)
        else:
            msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
            raise NotImplementedError(msg)

    def is_first_distinct(self: Self) -> Self:
        def func(_input: dx.Series) -> dx.Series:
            _name = _input.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(
                _input.to_frame(), col_token, self._backend_version, self._implementation
            )
            first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
            return frame[col_token].isin(first_distinct_index)

        return self._with_callable(func, "is_first_distinct")

    def is_last_distinct(self: Self) -> Self:
        def func(_input: dx.Series) -> dx.Series:
            _name = _input.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(
                _input.to_frame(), col_token, self._backend_version, self._implementation
            )
            last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
            return frame[col_token].isin(last_distinct_index)

        return self._with_callable(func, "is_last_distinct")

    def is_unique(self: Self) -> Self:
        def func(_input: dx.Series) -> dx.Series:
            _name = _input.name
            return (
                _input.to_frame()
                .groupby(_name, dropna=False)
                .transform("size", meta=(_name, int))
                == 1
            )

        return self._with_callable(func, "is_unique")

    def is_in(self: Self, other: Any) -> Self:
        return self._with_callable(lambda _input: _input.isin(other), "is_in")

    def null_count(self: Self) -> Self:
        return self._with_callable(
            lambda _input: _input.isna().sum().to_series(), "null_count"
        )

    def over(
        self: Self,
        partition_by: Sequence[str],
        order_by: Sequence[str] | None,
    ) -> Self:
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        if not partition_by:
            assert order_by is not None  # help type checkers  # noqa: S101

            # This is something like `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
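            # `func` below simply sorts the frame by the `order_by` keys
            # (ascending, nulls first) and evaluates `self` on the sorted frame.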
            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                return self(df.sort(*order_by, descending=False, nulls_last=False))

        elif not self._is_elementary():  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            try:
                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
            except KeyError:
                # window functions are unsupported: https://github.com/dask/dask/issues/11806
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
                )
                raise NotImplementedError(msg) from None

            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

                with warnings.catch_warnings():
                    # https://github.com/dask/dask/issues/11804
                    warnings.filterwarnings(
                        "ignore",
                        message=".*`meta` is not specified",
                        category=UserWarning,
                    )
                    res_native = df.native.groupby(partition_by)[
                        list(output_names)
                    ].transform(dask_function_name, **self._call_kwargs)
                    result_frame = df._with_native(
                        res_native.rename(columns=dict(zip(output_names, aliases)))
                    ).native
                return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def cast(self: Self, dtype: DType | type[DType]) -> Self:
        def func(_input: dx.Series) -> dx.Series:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return _input.astype(native_dtype)

        return self._with_callable(func, "cast")

    def is_finite(self: Self) -> Self:
        import dask.array as da

        return self._with_callable(da.isfinite, "is_finite")

    @property
    def str(self: Self) -> DaskExprStringNamespace:
        return DaskExprStringNamespace(self)

    @property
    def dt(self: Self) -> DaskExprDateTimeNamespace:
        return DaskExprDateTimeNamespace(self)

    list = not_implemented()  # pyright: ignore[reportAssignmentType]
    struct = not_implemented()  # pyright: ignore[reportAssignmentType]
    rank = not_implemented()  # pyright: ignore[reportAssignmentType]
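

# Illustrative usage sketch (assumes narwhals' public API; the data is hypothetical):
#
#     import dask.dataframe as dd
#     import pandas as pd
#     import narwhals as nw
#
#     ddf = dd.from_pandas(pd.DataFrame({"a": [1.0, 2.0, 4.0]}), npartitions=1)
#     out = nw.from_native(ddf).select(nw.col("a").mean()).to_native().compute()
#
# `nw.col("a").mean()` lowers to a `DaskExpr` whose `_call` returns
# `[ddf["a"].mean().to_series()]`, matching `mean` above.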