# Files
# Buffteks-Website/venv/lib/python3.12/site-packages/narwhals/_dask/expr.py
# 2025-05-08 21:10:14 -05:00
#
# 677 lines
# 25 KiB
# Python

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Literal
from typing import Sequence

from narwhals._compliant import LazyExpr
from narwhals._compliant.expr import DepthTrackingExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import add_row_index
from narwhals._dask.utils import maybe_evaluate_expr
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._expression_parsing import ExprKind
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import native_to_narwhals_dtype
from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import InvalidOperationError
from narwhals.utils import Implementation
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import not_implemented

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self

    from narwhals._compliant.typing import AliasNames
    from narwhals._compliant.typing import EvalNames
    from narwhals._compliant.typing import EvalSeries
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprKind
    from narwhals._expression_parsing import ExprMetadata
    from narwhals.dtypes import DType
    from narwhals.typing import FillNullStrategy
    from narwhals.typing import NonNestedLiteral
    from narwhals.typing import NumericLiteral
    from narwhals.typing import RollingInterpolationMethod
    from narwhals.typing import TemporalLiteral
    from narwhals.utils import Version
    from narwhals.utils import _FullContext

class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],
):
    """Narwhals compliant expression for the Dask backend.

    An instance wraps ``call``: a function that takes a ``DaskLazyFrame`` and
    returns the native ``dx.Series`` this expression produces, one per output
    column. Most methods return a new ``DaskExpr`` whose ``call`` applies one
    more operation on top of this one (see ``_with_callable``).
    """

    _implementation: Implementation = Implementation.DASK

    def __init__(
        self,
        call: EvalSeries[DaskLazyFrame, dx.Series],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[DaskLazyFrame],
        alias_output_names: AliasNames | None,
        backend_version: tuple[int, ...],
        version: Version,
        # Kwargs with metadata which we may need in group-by agg
        # (e.g. `ddof` for `std` and `var`).
        call_kwargs: dict[str, Any] | None = None,
    ) -> None:
        """Store the evaluation function and its metadata.

        Arguments:
            call: Maps a frame to the list of native series this expression yields.
            depth: Number of operations stacked on the root expression (root
                column expressions use 0; each ``_with_callable`` adds 1).
            function_name: Operation names joined with ``"->"``.
            evaluate_output_names: Maps a frame to the output column names.
            alias_output_names: Optional renaming applied to the output names.
            backend_version: Version tuple of the backend library.
            version: Narwhals API version.
            call_kwargs: Extra keyword arguments forwarded by ``over`` to the
                group-by transform (e.g. ``ddof``). Defaults to an empty dict.
        """
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._call_kwargs = call_kwargs or {}
        # Left unset here; not assigned anywhere in this class.
        self._metadata: ExprMetadata | None = None

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        """Evaluate the expression against ``df``."""
        return self._call(df)

    def __narwhals_expr__(self) -> None: ...

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        # Unused, just for compatibility with PandasLikeExpr.
        # Imported locally (presumably to avoid an import cycle with the namespace module).
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(backend_version=self._backend_version, version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        """Return a copy of this expression that indexes each result with ``[0]``.

        Intended for aggregation/literal expressions, whose results are
        presumably single-value series. ``kind`` is not used by this backend.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [result[0] for result in self(df)]

        return self.__class__(
            func,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=self._call_kwargs,
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[DaskLazyFrame],
        /,
        *,
        context: _FullContext,
        function_name: str = "",
    ) -> Self:
        """Build a root expression that selects columns by name.

        Raises:
            ColumnNotFoundError: At evaluation time, when a requested name is
                not a column of the frame.
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # Resolve the names once so the lookup and the error report agree.
            column_names = evaluate_column_names(df)
            native_frame = df._native_frame
            try:
                return [native_frame[name] for name in column_names]
            except KeyError as e:
                missing_columns = [x for x in column_names if x not in df.columns]
                raise ColumnNotFoundError.from_missing_and_available_column_names(
                    missing_columns=missing_columns,
                    available_columns=df.columns,
                ) from e

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(
        cls: type[Self], *column_indices: int, context: _FullContext
    ) -> Self:
        """Build a root expression that selects columns by position (``iloc``)."""

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [
                df._native_frame.iloc[:, column_index] for column_index in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def _with_callable(
        self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        expr_name: str = "",
        call_kwargs: dict[str, Any] | None = None,
        **expressifiable_args: Self | Any,
    ) -> Self:
        """Return a new expression that applies ``call`` to each output series.

        Arguments:
            call: Receives one native series positionally, plus the resolved
                ``expressifiable_args`` as keyword arguments.
            expr_name: Appended to ``function_name`` after ``"->"``.
            call_kwargs: Stored on the new expression (read by ``over``); NOT
                inherited from ``self``.
            **expressifiable_args: Extra arguments for ``call``. Each is passed
                through ``maybe_evaluate_expr`` against the frame (presumably
                evaluating expressions and leaving literals untouched).
        """

        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # Evaluate this expression first, then its arguments, once per frame.
            native_series_list = self._call(df)
            other_native_series = {
                key: maybe_evaluate_expr(df, value)
                for key, value in expressifiable_args.items()
            }
            return [
                call(native_series, **other_native_series)
                for native_series in native_series_list
            ]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=f"{self._function_name}->{expr_name}",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=call_kwargs,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        """Return a copy of this expression with ``alias_output_names`` replaced."""
        return type(self)(
            call=self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=func,
            backend_version=self._backend_version,
            version=self._version,
            call_kwargs=self._call_kwargs,
        )

    # --- Binary operators -------------------------------------------------
    # `other` may be an expression or a literal; `_with_callable` resolves it
    # per frame. Reflected operators (`__r*__`) rename their output to
    # "literal", presumably because the left operand is the literal.

    def __add__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__add__(other), "__add__", other=other
        )

    def __sub__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__sub__(other), "__sub__", other=other
        )

    def __rsub__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other - _input, "__rsub__", other=other
        ).alias("literal")

    def __mul__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__mul__(other), "__mul__", other=other
        )

    def __truediv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__truediv__(other), "__truediv__", other=other
        )

    def __rtruediv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other / _input, "__rtruediv__", other=other
        ).alias("literal")

    def __floordiv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__floordiv__(other), "__floordiv__", other=other
        )

    def __rfloordiv__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other // _input, "__rfloordiv__", other=other
        ).alias("literal")

    def __pow__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__pow__(other), "__pow__", other=other
        )

    def __rpow__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other**_input, "__rpow__", other=other
        ).alias("literal")

    def __mod__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__mod__(other), "__mod__", other=other
        )

    def __rmod__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: other % _input, "__rmod__", other=other
        ).alias("literal")

    def __eq__(self, other: Any) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda _input, other: _input.__eq__(other), "__eq__", other=other
        )

    def __ne__(self, other: Any) -> Self:  # type: ignore[override]
        return self._with_callable(
            lambda _input, other: _input.__ne__(other), "__ne__", other=other
        )

    def __ge__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__ge__(other), "__ge__", other=other
        )

    def __gt__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__gt__(other), "__gt__", other=other
        )

    def __le__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__le__(other), "__le__", other=other
        )

    def __lt__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__lt__(other), "__lt__", other=other
        )

    def __and__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__and__(other), "__and__", other=other
        )

    def __or__(self, other: Any) -> Self:
        return self._with_callable(
            lambda _input, other: _input.__or__(other), "__or__", other=other
        )

    def __invert__(self) -> Self:
        return self._with_callable(lambda _input: _input.__invert__(), "__invert__")

    # --- Reductions ---------------------------------------------------------
    # Dask reductions are converted back into series with `.to_series()`, so
    # every expression keeps yielding series.

    def mean(self) -> Self:
        return self._with_callable(lambda _input: _input.mean().to_series(), "mean")

    def median(self) -> Self:
        """Approximate median (``median_approximate``) of numeric input.

        Raises:
            InvalidOperationError: At evaluation time, for non-numeric input.
        """

        def func(s: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(s.dtype, self._version, self._implementation)
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func, "median")

    def min(self) -> Self:
        return self._with_callable(lambda _input: _input.min().to_series(), "min")

    def max(self) -> Self:
        return self._with_callable(lambda _input: _input.max().to_series(), "max")

    def std(self, ddof: int) -> Self:
        # `ddof` is also stored in `call_kwargs` so `over` can forward it.
        return self._with_callable(
            lambda _input: _input.std(ddof=ddof).to_series(),
            "std",
            call_kwargs={"ddof": ddof},
        )

    def var(self, ddof: int) -> Self:
        # `ddof` is also stored in `call_kwargs` so `over` can forward it.
        return self._with_callable(
            lambda _input: _input.var(ddof=ddof).to_series(),
            "var",
            call_kwargs={"ddof": ddof},
        )

    def skew(self) -> Self:
        return self._with_callable(lambda _input: _input.skew().to_series(), "skew")

    def shift(self, n: int) -> Self:
        return self._with_callable(lambda _input: _input.shift(n), "shift")

    # --- Cumulative operations (only forward direction is supported) --------

    def cum_sum(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)
        return self._with_callable(lambda _input: _input.cumsum(), "cum_sum")

    def cum_count(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)
        # Count non-null values seen so far.
        return self._with_callable(
            lambda _input: (~_input.isna()).astype(int).cumsum(), "cum_count"
        )

    def cum_min(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)
        return self._with_callable(lambda _input: _input.cummin(), "cum_min")

    def cum_max(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)
        return self._with_callable(lambda _input: _input.cummax(), "cum_max")

    def cum_prod(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)
        return self._with_callable(lambda _input: _input.cumprod(), "cum_prod")

    # --- Rolling windows ----------------------------------------------------

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum(),
            "rolling_sum",
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean(),
            "rolling_mean",
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling variance; only ``ddof=1`` is supported (raises otherwise)."""
        if ddof != 1:
            msg = "Dask backend only supports `ddof=1` for `rolling_var`"
            raise NotImplementedError(msg)
        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).var(),
            "rolling_var",
        )

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Rolling standard deviation; only ``ddof=1`` is supported (raises otherwise)."""
        if ddof != 1:
            msg = "Dask backend only supports `ddof=1` for `rolling_std`"
            raise NotImplementedError(msg)
        return self._with_callable(
            lambda _input: _input.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).std(),
            "rolling_std",
        )

    def sum(self) -> Self:
        return self._with_callable(lambda _input: _input.sum().to_series(), "sum")

    def count(self) -> Self:
        return self._with_callable(lambda _input: _input.count().to_series(), "count")

    def round(self, decimals: int) -> Self:
        return self._with_callable(lambda _input: _input.round(decimals), "round")

    def unique(self) -> Self:
        return self._with_callable(lambda _input: _input.unique(), "unique")

    def drop_nulls(self) -> Self:
        return self._with_callable(lambda _input: _input.dropna(), "drop_nulls")

    def abs(self) -> Self:
        return self._with_callable(lambda _input: _input.abs(), "abs")

    def all(self) -> Self:
        return self._with_callable(
            lambda _input: _input.all(
                axis=None, skipna=True, split_every=False, out=None
            ).to_series(),
            "all",
        )

    def any(self) -> Self:
        return self._with_callable(
            lambda _input: _input.any(axis=0, skipna=True, split_every=False).to_series(),
            "any",
        )

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        """Fill nulls with ``value``, or by forward/backward fill when it is None.

        ``value`` may be a literal or an expression. It is passed to
        ``_with_callable`` as an argument so that an expression is resolved
        against the frame instead of being handed to ``fillna`` as a raw
        ``DaskExpr`` object. When ``value`` is None, ``strategy == "forward"``
        uses ``ffill`` and any other strategy uses ``bfill``; both honour ``limit``.
        """

        def func(_input: dx.Series, value: Any) -> dx.Series:
            if value is not None:
                return _input.fillna(value)
            if strategy == "forward":
                return _input.ffill(limit=limit)
            return _input.bfill(limit=limit)

        return self._with_callable(func, "fillna", value=value)

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self:
        """Clip values to ``[lower_bound, upper_bound]``; bounds may be expressions."""
        return self._with_callable(
            lambda _input, lower_bound, upper_bound: _input.clip(
                lower=lower_bound, upper=upper_bound
            ),
            "clip",
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def diff(self) -> Self:
        return self._with_callable(lambda _input: _input.diff(), "diff")

    def n_unique(self) -> Self:
        # `dropna=False`: null counts as a distinct value.
        return self._with_callable(
            lambda _input: _input.nunique(dropna=False).to_series(), "n_unique"
        )

    def is_null(self) -> Self:
        return self._with_callable(lambda _input: _input.isna(), "is_null")

    def is_nan(self) -> Self:
        """Flag NaN values of numeric input.

        Raises:
            InvalidOperationError: At evaluation time, for non-numeric input.
        """

        def func(_input: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                _input.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                # NaN is the only value not equal to itself.
                return _input != _input  # pyright: ignore[reportReturnType] # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        return self._with_callable(func, "is_nan")

    def len(self) -> Self:
        return self._with_callable(lambda _input: _input.size.to_series(), "len")

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        """Quantile with ``"linear"`` interpolation only.

        Raises:
            NotImplementedError: For any other interpolation (immediately), or
                at evaluation time when the input has more than one partition.
        """
        if interpolation != "linear":
            msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
            raise NotImplementedError(msg)

        def func(_input: dx.Series, quantile: float) -> dx.Series:
            if _input.npartitions > 1:
                msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                raise NotImplementedError(msg)
            return _input.quantile(
                q=quantile, method="dask"
            ).to_series()  # pragma: no cover

        return self._with_callable(func, "quantile", quantile=quantile)

    def _is_extreme_distinct(self, agg: Literal["min", "max"], expr_name: str) -> Self:
        """Flag the first (``agg="min"``) or last (``agg="max"``) row of each value.

        A temporary row-index column is attached, the min/max row index per
        distinct value is computed, and a row is flagged when its own index is
        among those. ``dropna=False`` makes nulls one group (as ``is_unique``
        does), so a null value is flagged once rather than never.
        """

        def func(_input: dx.Series) -> dx.Series:
            _name = _input.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(
                _input.to_frame(), col_token, self._backend_version, self._implementation
            )
            extreme_index = frame.groupby(_name, dropna=False).agg({col_token: agg})[
                col_token
            ]
            return frame[col_token].isin(extreme_index)

        return self._with_callable(func, expr_name)

    def is_first_distinct(self) -> Self:
        return self._is_extreme_distinct("min", "is_first_distinct")

    def is_last_distinct(self) -> Self:
        return self._is_extreme_distinct("max", "is_last_distinct")

    def is_unique(self) -> Self:
        """Flag values occurring exactly once (nulls grouped together)."""

        def func(_input: dx.Series) -> dx.Series:
            _name = _input.name
            group_sizes = (
                _input.to_frame()
                .groupby(_name, dropna=False)
                .transform("size", meta=(_name, int))
            )
            return group_sizes == 1

        return self._with_callable(func, "is_unique")

    def is_in(self, other: Any) -> Self:
        return self._with_callable(lambda _input: _input.isin(other), "is_in")

    def null_count(self) -> Self:
        return self._with_callable(
            lambda _input: _input.isna().sum().to_series(), "null_count"
        )

    def over(
        self,
        partition_by: Sequence[str],
        order_by: Sequence[str] | None,
    ) -> Self:
        """Evaluate this expression as a window function.

        - Without ``partition_by``: the frame is sorted by ``order_by``
          (ascending, nulls first) and this expression is evaluated on it.
        - With ``partition_by``: only elementary expressions whose leaf
          function appears in ``PandasLikeGroupBy._REMAP_AGGS`` are supported;
          they are computed with ``groupby(partition_by).transform`` and
          forwarded ``self._call_kwargs`` (e.g. ``ddof``).

        Raises:
            NotImplementedError: For non-elementary expressions, for
                ``order_by`` combined with ``partition_by``, or for leaf
                functions without a Dask group-by equivalent.
        """
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        if not partition_by:
            assert order_by is not None  # help type checkers # noqa: S101

            # This is something like `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                return self(df.sort(*order_by, descending=False, nulls_last=False))

        elif not self._is_elementary():  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/pandas_like_concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)
        elif order_by:
            # Wrong results https://github.com/dask/dask/issues/11806.
            msg = "`over` with `order_by` is not yet supported in Dask."
            raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            try:
                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
            except KeyError:
                # window functions are unsupported: https://github.com/dask/dask/issues/11806
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
                )
                raise NotImplementedError(msg) from None

            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

                with warnings.catch_warnings():
                    # https://github.com/dask/dask/issues/11804
                    warnings.filterwarnings(
                        "ignore",
                        message=".*`meta` is not specified",
                        category=UserWarning,
                    )
                    grouped = df.native.groupby(partition_by)
                    if dask_function_name == "size":
                        # `size` is computed on the whole group, so exactly
                        # one output column is expected.
                        if len(output_names) != 1:  # pragma: no cover
                            msg = "Safety check failed, please report a bug."
                            raise AssertionError(msg)
                        res_native = grouped.transform(
                            dask_function_name, **self._call_kwargs
                        ).to_frame(output_names[0])
                    else:
                        res_native = grouped[list(output_names)].transform(
                            dask_function_name, **self._call_kwargs
                        )
                result_frame = df._with_native(
                    res_native.rename(columns=dict(zip(output_names, aliases)))
                ).native
                return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def cast(self, dtype: DType | type[DType]) -> Self:
        """Cast to the native Dask dtype corresponding to narwhals ``dtype``."""

        def func(_input: dx.Series) -> dx.Series:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return _input.astype(native_dtype)

        return self._with_callable(func, "cast")

    def is_finite(self) -> Self:
        """Flag finite values using ``dask.array.isfinite``."""
        import dask.array as da

        return self._with_callable(da.isfinite, "is_finite")

    @property
    def str(self) -> DaskExprStringNamespace:
        return DaskExprStringNamespace(self)

    @property
    def dt(self) -> DaskExprDateTimeNamespace:
        return DaskExprDateTimeNamespace(self)

    # Namespaces/methods not available on the Dask backend.
    list = not_implemented()  # pyright: ignore[reportAssignmentType]
    struct = not_implemented()  # pyright: ignore[reportAssignmentType]
    rank = not_implemented()  # pyright: ignore[reportAssignmentType]