from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Literal
from typing import Sequence

from narwhals._compliant import LazyExpr
from narwhals._compliant.expr import DepthTrackingExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate_expr
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import not_implemented

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self

    from narwhals._compliant.typing import AliasNames
    from narwhals._compliant.typing import EvalNames
    from narwhals._compliant.typing import EvalSeries
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprKind
    from narwhals._expression_parsing import ExprMetadata
    from narwhals.dtypes import DType
    from narwhals.typing import FillNullStrategy
    from narwhals.typing import NonNestedLiteral
    from narwhals.typing import NumericLiteral
    from narwhals.typing import RollingInterpolationMethod
    from narwhals.typing import TemporalLiteral
    from narwhals.utils import Version
    from narwhals.utils import _FullContext


class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
):
    """Expression implementation for the Dask backend.

    An instance wraps a callable that maps a `DaskLazyFrame` to a sequence of
    native Dask series, together with bookkeeping used elsewhere in this class:
    the chain depth, a `->`-joined function name, and the callables that
    evaluate / alias output names.
    """

    _implementation: Implementation = Implementation.DASK

    def __init__(
        self,
        call: EvalSeries[DaskLazyFrame, dx.Series],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[DaskLazyFrame],
        alias_output_names: AliasNames | None,
        backend_version: tuple[int, ...],
        version: Version,
        # Kwargs with metadata which we may need in group-by agg
        # (e.g. `ddof` for `std` and `var`).
        call_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Store the evaluation callable and its metadata.

        `call_kwargs=None` is normalised to an empty dict.
        """
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._call_kwargs = call_kwargs or {}
        self._metadata: ExprMetadata | None = None

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        """Evaluate the expression against `df` by delegating to the stored callable."""
        return self._call(df)

    def __narwhals_expr__(self) -> None: ...

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        """Return a `DaskNamespace` built from this expression's versions."""
        # Unused, just for compatibility with PandasLikeExpr
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(backend_version=self._backend_version, version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        """Return an expression yielding `result[0]` of each evaluated series.

        `kind` is accepted for interface compatibility and is not read here.
        All other metadata (including `call_kwargs`) is carried over unchanged.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [result[0] for result in self(df)]

        return self.__class__(
            func,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=self._call_kwargs,
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[DaskLazyFrame],
        /,
        *,
        context: _FullContext,
        function_name: str = "",
    ) -> Self:
        """Build a depth-0 expression selecting columns by name.

        The names are resolved lazily by calling `evaluate_column_names(df)`.

        Raises (at evaluation time):
            ColumnNotFoundError: if indexing the native frame raises `KeyError`;
                the error lists the requested names absent from `df.columns`.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            try:
                return [
                    df._native_frame[column_name]
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                # Re-evaluate the names so the error can report exactly which
                # of them are missing.
                missing_columns = [
                    x for x in evaluate_column_names(df) if x not in df.columns
                ]
                raise ColumnNotFoundError.from_missing_and_available_column_names(
                    missing_columns=missing_columns,
                    available_columns=df.columns,
                ) from e

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(
        cls: type[Self], *column_indices: int, context: _FullContext
    ) -> Self:
        """Build a depth-0 `"nth"` expression selecting columns by position."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [
                df._native_frame.iloc[:, column_index] for column_index in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def _with_callable(
        self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        expr_name: str = "",
        call_kwargs: dict[str, Any] | None = None,
        **expressifiable_args: Self | Any,
    ) -> Self:
        """Chain `call` onto this expression.

        At evaluation time every series produced by `self` is passed to `call`
        as the first positional argument; each entry of `expressifiable_args`
        is first passed through `maybe_evaluate_expr(df, value)` and then
        forwarded to `call` as a keyword argument.

        The returned expression has `depth + 1` and a function name of
        `f"{self._function_name}->{expr_name}"`. Note that it receives the
        `call_kwargs` given here, not the ones stored on `self`.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            native_series_list = self._call(df)
            # Evaluated once per frame and shared by every output series.
            other_native_series = {
                key: maybe_evaluate_expr(df, value)
                for key, value in expressifiable_args.items()
            }
            return [
                call(native_series, **other_native_series)
                for native_series in native_series_list
            ]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=f"{self._function_name}->{expr_name}",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=call_kwargs,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        """Return a copy of this expression with `alias_output_names` set to `func`."""
        return type(self)(
            call=self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=func,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=self._call_kwargs,
        )

    # --- Binary operators -------------------------------------------------
    # The reflected variants (`__r*__`) additionally alias their output to
    # "literal".

    def __add__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__add__(other), "__add__", other=other
        )

    def __sub__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__sub__(other), "__sub__", other=other
        )

    def __rsub__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other - _input, "__rsub__", other=other
        ).alias("literal")

    def __mul__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__mul__(other), "__mul__", other=other
        )

    def __truediv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__truediv__(other), "__truediv__", other=other
        )

    def __rtruediv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other / _input, "__rtruediv__", other=other
        ).alias("literal")

    def __floordiv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__floordiv__(other), "__floordiv__", other=other
        )

    def __rfloordiv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other // _input, "__rfloordiv__", other=other
        ).alias("literal")

    def __pow__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__pow__(other), "__pow__", other=other
        )

    def __rpow__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other**_input, "__rpow__", other=other
        ).alias("literal")

    def __mod__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__mod__(other), "__mod__", other=other
        )

    def __rmod__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other % _input, "__rmod__", other=other
        ).alias("literal")

    def __eq__(self, other: DaskExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda _input, other: _input.__eq__(other), "__eq__", other=other
        )

    def __ne__(self, other: DaskExpr) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda _input, other: _input.__ne__(other), "__ne__", other=other
        )

    def __ge__(self, other: DaskExpr | Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__ge__(other), "__ge__", other=other
        )

    def __gt__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__gt__(other), "__gt__", other=other
        )

    def __le__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__le__(other), "__le__", other=other
        )

    def __lt__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__lt__(other), "__lt__", other=other
        )

    def __and__(self, other: DaskExpr | Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__and__(other), "__and__", other=other
        )

    def __or__(self, other: DaskExpr) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__or__(other), "__or__", other=other
        )

    def __invert__(self) -> Self:
        return self._with_callable(lambda _input: _input.__invert__(), "__invert__")

    # --- Reductions and element-wise operations ---------------------------

    def mean(self) -> Self:
        return self._with_callable(lambda _input: _input.mean().to_series(), "mean")

    def median(self) -> Self:
        """Median via Dask's `median_approximate`.

        Raises (at evaluation time):
            InvalidOperationError: if the input dtype is not numeric.
        """

        def func(s: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(s.dtype, self._version, self._implementation)
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func, "median")

    def min(self) -> Self:
        return self._with_callable(lambda _input: _input.min().to_series(), "min")

    def max(self) -> Self:
        return self._with_callable(lambda _input: _input.max().to_series(), "max")

    def std(self, ddof: int) -> Self:
        """Standard deviation; `ddof` is also recorded in `call_kwargs`."""
        return self._with_callable(
            lambda _input: _input.std(ddof=ddof).to_series(),
            "std",
            call_kwargs={"ddof": ddof},
        )

    def var(self, ddof: int) -> Self:
        """Variance; `ddof` is also recorded in `call_kwargs`."""
        return self._with_callable(
            lambda _input: _input.var(ddof=ddof).to_series(),
            "var",
            call_kwargs={"ddof": ddof},
        )

    def skew(self) -> Self:
        return self._with_callable(lambda _input: _input.skew().to_series(), "skew")

    def shift(self, n: int) -> Self:
        return self._with_callable(lambda _input: _input.shift(n), "shift")

    def cum_sum(self, *, reverse: bool) -> Self:
        """Cumulative sum. Raises `NotImplementedError` when `reverse` is true."""
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda _input: _input.cumsum(), "cum_sum")

    def cum_count(self, *, reverse: bool) -> Self:
        """Cumulative count of non-null values.

        Raises `NotImplementedError` when `reverse` is true.
        """
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda _input: (~_input.isna()).astype(int).cumsum(), "cum_count"
        )

    def cum_min(self, *, reverse: bool) -> Self:
        """Cumulative minimum. Raises `NotImplementedError` when `reverse` is true."""
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda _input: _input.cummin(), "cum_min")

    def cum_max(self, *, reverse: bool) -> Self:
        """Cumulative maximum. Raises `NotImplementedError` when `reverse` is true."""
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda _input: _input.cummax(), "cum_max")

    def cum_prod(self, *, reverse: bool) -> Self:
        """Cumulative product. Raises `NotImplementedError` when `reverse` is true."""
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda _input: _input.cumprod(), "cum_prod")

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum(),
            "rolling_sum",
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean(),
            "rolling_mean",
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling variance. Raises `NotImplementedError` unless `ddof == 1`."""
        if ddof != 1:
            msg = "Dask backend only supports `ddof=1` for `rolling_var`"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).var(),
            "rolling_var",
        )

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling standard deviation. Raises `NotImplementedError` unless `ddof == 1`."""
        if ddof != 1:
            msg = "Dask backend only supports `ddof=1` for `rolling_std`"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).std(),
            "rolling_std",
        )

    def sum(self) -> Self:
        return self._with_callable(lambda _input: _input.sum().to_series(), "sum")

    def count(self) -> Self:
        return self._with_callable(lambda _input: _input.count().to_series(), "count")

    def round(self, decimals: int) -> Self:
        return self._with_callable(lambda _input: _input.round(decimals), "round")

    def unique(self) -> Self:
        return self._with_callable(lambda _input: _input.unique(), "unique")

    def drop_nulls(self) -> Self:
        return self._with_callable(lambda _input: _input.dropna(), "drop_nulls")

    def abs(self) -> Self:
        return self._with_callable(lambda _input: _input.abs(), "abs")

    def all(self) -> Self:
        return self._with_callable(
            lambda _input: _input.all(
                axis=None, skipna=True, split_every=False, out=None
            ).to_series(),
            "all",
        )

    def any(self) -> Self:
        return self._with_callable(
            lambda _input: _input.any(axis=0, skipna=True, split_every=False).to_series(),
            "any",
        )

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        """Fill nulls with `value`, or forward/backward fill when `value` is None.

        When `value` is None, `strategy == "forward"` selects `ffill`; any
        other strategy falls through to `bfill`. `limit` is only applied on
        the strategy path.
        """

        def func(_input: dx.Series) -> dx.Series:
            if value is not None:
                return _input.fillna(value)
            if strategy == "forward":
                return _input.ffill(limit=limit)
            return _input.bfill(limit=limit)

        # NOTE(review): registered as "fillna" rather than "fill_null", unlike
        # the narwhals-style names used by sibling methods -- confirm intent.
        return self._with_callable(func, "fillna")

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self:
        return self._with_callable(
            lambda _input, lower_bound, upper_bound: _input.clip(
                lower=lower_bound, upper=upper_bound
            ),
            "clip",
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def diff(self) -> Self:
        return self._with_callable(lambda _input: _input.diff(), "diff")

    def n_unique(self) -> Self:
        return self._with_callable(
            lambda _input: _input.nunique(dropna=False).to_series(), "n_unique"
        )

    def is_null(self) -> Self:
        return self._with_callable(lambda _input: _input.isna(), "is_null")

    def is_nan(self) -> Self:
        """Element-wise NaN check for numeric inputs (computed as `x != x`).

        Raises (at evaluation time):
            InvalidOperationError: if the input dtype is not numeric.
        """

        def func(_input: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                _input.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                return _input != _input  # pyright: ignore[reportReturnType] # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        # Registered under its own name; it was previously mislabelled "is_null".
        return self._with_callable(func, "is_nan")

    def len(self) -> Self:
        return self._with_callable(lambda _input: _input.size.to_series(), "len")

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        """Quantile with linear interpolation only.

        Raises:
            NotImplementedError: immediately if `interpolation != "linear"`,
                or at evaluation time if the input has more than one partition.
        """
        if interpolation != "linear":
            msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
            raise NotImplementedError(msg)

        def func(_input: dx.Series, quantile: float) -> dx.Series:
            if _input.npartitions > 1:
                msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                raise NotImplementedError(msg)
            return _input.quantile(
                q=quantile, method="dask"
            ).to_series()  # pragma: no cover

        return self._with_callable(func, "quantile", quantile=quantile)

    def is_first_distinct(self) -> Self:
        """Flag rows whose row index is the minimum within their value group."""

        def func(_input: dx.Series) -> dx.Series:
            _name = _input.name
            # Temporary row-index column; the name is generated so that it
            # does not collide with the series' own name.
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(
                _input.to_frame(), col_token, self._backend_version, self._implementation
            )
            first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
            return frame[col_token].isin(first_distinct_index)

        return self._with_callable(func, "is_first_distinct")

    def is_last_distinct(self) -> Self:
        """Flag rows whose row index is the maximum within their value group."""

        def func(_input: dx.Series) -> dx.Series:
            _name = _input.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(
                _input.to_frame(), col_token, self._backend_version, self._implementation
            )
            last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
            return frame[col_token].isin(last_distinct_index)

        return self._with_callable(func, "is_last_distinct")

    def is_unique(self) -> Self:
        """Flag rows whose value group (nulls kept as a group) has size 1."""

        def func(_input: dx.Series) -> dx.Series:
            _name = _input.name
            return (
                _input.to_frame()
                .groupby(_name, dropna=False)
                .transform("size", meta=(_name, int))
                == 1
            )

        return self._with_callable(func, "is_unique")

    def is_in(self, other: Any) -> Self:
        return self._with_callable(lambda _input: _input.isin(other), "is_in")

    def null_count(self) -> Self:
        return self._with_callable(
            lambda _input: _input.isna().sum().to_series(), "null_count"
        )

    def over(
        self,
        partition_by: Sequence[str],
        order_by: Sequence[str] | None,
    ) -> Self:
        """Evaluate this expression over partitions and/or an ordering.

        Branches, in order:
        * empty `partition_by`: evaluate `self` on the frame sorted by
          `order_by` (ascending, nulls first);
        * non-elementary expression: `NotImplementedError`;
        * `order_by` given together with `partition_by`: `NotImplementedError`;
        * otherwise: map the leaf function name through
          `PandasLikeGroupBy._REMAP_AGGS` and run it as a grouped `transform`;
          an unmapped name raises `NotImplementedError`.

        The returned expression does not carry over `call_kwargs`.
        """
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        if not partition_by:
            assert order_by is not None  # help type checkers  # noqa: S101

            # This is something like `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                return self(df.sort(*order_by, descending=False, nulls_last=False))
        elif not self._is_elementary():  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)
        elif order_by:
            # Wrong results https://github.com/dask/dask/issues/11806.
            msg = "`over` with `order_by` is not yet supported in Dask."
            raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            try:
                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
            except KeyError:
                # window functions are unsupported: https://github.com/dask/dask/issues/11806
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
                )
                raise NotImplementedError(msg) from None

            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

                with warnings.catch_warnings():
                    # https://github.com/dask/dask/issues/11804
                    warnings.filterwarnings(
                        "ignore",
                        message=".*`meta` is not specified",
                        category=UserWarning,
                    )
                    grouped = df.native.groupby(partition_by)
                    if dask_function_name == "size":
                        if len(output_names) != 1:  # pragma: no cover
                            msg = "Safety check failed, please report a bug."
                            raise AssertionError(msg)
                        # "size" is transformed on the whole group object and
                        # then turned into a one-column frame named after the
                        # single output.
                        res_native = grouped.transform(
                            dask_function_name, **self._call_kwargs
                        ).to_frame(output_names[0])
                    else:
                        res_native = grouped[list(output_names)].transform(
                            dask_function_name, **self._call_kwargs
                        )
                result_frame = df._with_native(
                    res_native.rename(columns=dict(zip(output_names, aliases)))
                ).native
                return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def cast(self, dtype: DType | type[DType]) -> Self:
        """Cast to the native dtype that `narwhals_to_native_dtype` maps `dtype` to."""

        def func(_input: dx.Series) -> dx.Series:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return _input.astype(native_dtype)

        return self._with_callable(func, "cast")

    def is_finite(self) -> Self:
        import dask.array as da

        return self._with_callable(da.isfinite, "is_finite")

    @property
    def str(self) -> DaskExprStringNamespace:
        return DaskExprStringNamespace(self)

    @property
    def dt(self) -> DaskExprDateTimeNamespace:
        return DaskExprDateTimeNamespace(self)

    # Kept at the end of the class body: `list` shadows the builtin inside the
    # class namespace from this point on.
    list = not_implemented()  # pyright: ignore[reportAssignmentType]
    struct = not_implemented()  # pyright: ignore[reportAssignmentType]
    rank = not_implemented()  # pyright: ignore[reportAssignmentType]