Files
Buffteks-Website/streamlit-venv/lib/python3.10/site-packages/narwhals/_dask/namespace.py
2025-01-10 21:40:35 +00:00

486 lines
18 KiB
Python
Executable File

from __future__ import annotations
from functools import reduce
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Literal
from typing import NoReturn
from typing import cast
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import name_preserving_div
from narwhals._dask.utils import name_preserving_sum
from narwhals._dask.utils import narwhals_to_native_dtype
from narwhals._dask.utils import validate_comparand
from narwhals._expression_parsing import combine_root_names
from narwhals._expression_parsing import parse_into_exprs
from narwhals._expression_parsing import reduce_output_names
if TYPE_CHECKING:
import dask_expr
from narwhals._dask.typing import IntoDaskExpr
from narwhals.dtypes import DType
from narwhals.typing import DTypes
class DaskNamespace:
    """Dask implementation of the narwhals expression namespace.

    Provides the top-level expression constructors (``col``, ``lit``,
    ``all``, horizontal reductions, ``when``, ``concat``, ...) that build
    `DaskExpr` / `DaskLazyFrame` objects evaluating lazily against native
    Dask frames.
    """

    @property
    def selectors(self) -> DaskSelectorNamespace:
        """Column-selector sub-namespace for this backend."""
        return DaskSelectorNamespace(
            backend_version=self._backend_version, dtypes=self._dtypes
        )

    def __init__(self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None:
        self._backend_version = backend_version
        self._dtypes = dtypes

    def all(self) -> DaskExpr:
        """Expression selecting every column of the frame."""

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            return [df._native_frame.loc[:, column_name] for column_name in df.columns]

        return DaskExpr(
            func,
            depth=0,
            function_name="all",
            root_names=None,
            output_names=None,
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def col(self, *column_names: str) -> DaskExpr:
        """Expression selecting the given columns by name."""
        return DaskExpr.from_column_names(
            *column_names, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def nth(self, *column_indices: int) -> DaskExpr:
        """Expression selecting the given columns by position."""
        return DaskExpr.from_column_indices(
            *column_indices, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def lit(self, value: Any, dtype: DType | None) -> DaskExpr:
        """Literal `value` broadcast as a column named ``"lit"``.

        If `dtype` is provided, the resulting column is cast to the
        corresponding native dtype.
        """

        def convert_if_dtype(
            series: dask_expr.Series, dtype: DType | type[DType]
        ) -> dask_expr.Series:
            # Only cast when a dtype was actually requested.
            return (
                series.astype(narwhals_to_native_dtype(dtype, self._dtypes))
                if dtype
                else series
            )

        return DaskExpr(
            lambda df: [
                df._native_frame.assign(lit=value)
                .loc[:, "lit"]
                .pipe(convert_if_dtype, dtype)
            ],
            depth=0,
            function_name="lit",
            root_names=None,
            output_names=["lit"],
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def min(self, *column_names: str) -> DaskExpr:
        """Minimum of each named column."""
        return DaskExpr.from_column_names(
            *column_names, backend_version=self._backend_version, dtypes=self._dtypes
        ).min()

    def max(self, *column_names: str) -> DaskExpr:
        """Maximum of each named column."""
        return DaskExpr.from_column_names(
            *column_names, backend_version=self._backend_version, dtypes=self._dtypes
        ).max()

    def mean(self, *column_names: str) -> DaskExpr:
        """Mean of each named column."""
        return DaskExpr.from_column_names(
            *column_names, backend_version=self._backend_version, dtypes=self._dtypes
        ).mean()

    def sum(self, *column_names: str) -> DaskExpr:
        """Sum of each named column."""
        return DaskExpr.from_column_names(
            *column_names, backend_version=self._backend_version, dtypes=self._dtypes
        ).sum()

    def len(self) -> DaskExpr:
        """Row count of the frame as a scalar expression named ``"len"``.

        A frame with no columns is treated as having zero rows.
        """
        import dask.dataframe as dd  # ignore-banned-import
        import pandas as pd  # ignore-banned-import

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            if not df.columns:
                # No columns to measure: report length 0.
                return [
                    dd.from_pandas(
                        pd.Series([0], name="len"),
                        npartitions=df._native_frame.npartitions,
                    )
                ]
            return [df._native_frame.loc[:, df.columns[0]].size.to_series().rename("len")]

        # coverage bug? this is definitely hit
        return DaskExpr(  # pragma: no cover
            func,
            depth=0,
            function_name="len",
            root_names=None,
            output_names=["len"],
            returns_scalar=True,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def all_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
        """Row-wise boolean AND across the given expressions."""
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            series = [s for _expr in parsed_exprs for s in _expr._call(df)]
            # Result keeps the name of the first evaluated series.
            return [reduce(lambda x, y: x & y, series).rename(series[0].name)]

        return DaskExpr(
            call=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name="all_horizontal",
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def any_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
        """Row-wise boolean OR across the given expressions."""
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            series = [s for _expr in parsed_exprs for s in _expr._call(df)]
            return [reduce(lambda x, y: x | y, series).rename(series[0].name)]

        return DaskExpr(
            call=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name="any_horizontal",
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def sum_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
        """Row-wise sum across the given expressions (nulls counted as 0)."""
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            series = [s.fillna(0) for _expr in parsed_exprs for s in _expr._call(df)]
            return [reduce(lambda x, y: x + y, series).rename(series[0].name)]

        return DaskExpr(
            call=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name="sum_horizontal",
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def concat(
        self,
        items: Iterable[DaskLazyFrame],
        *,
        how: Literal["horizontal", "vertical"],
    ) -> DaskLazyFrame:
        """Concatenate frames vertically (matching columns) or horizontally.

        Raises AssertionError if `items` is empty, if a vertical concat has
        non-matching columns, or if a horizontal concat would produce
        duplicate column names.
        """
        import dask.dataframe as dd  # ignore-banned-import

        # Materialise once: `items` may be a one-shot iterator, and we need
        # both the emptiness check and a second pass over the frames.
        items = list(items)
        if not items:
            msg = "No items to concatenate"  # pragma: no cover
            raise AssertionError(msg)
        native_frames = [i._native_frame for i in items]
        if how == "vertical":
            if not all(
                tuple(i.columns) == tuple(native_frames[0].columns) for i in native_frames
            ):  # pragma: no cover
                msg = "unable to vstack with non-matching columns"
                raise AssertionError(msg)
            return DaskLazyFrame(
                dd.concat(native_frames, axis=0, join="inner"),
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )
        if how == "horizontal":
            all_column_names: list[str] = [
                column for frame in native_frames for column in frame.columns
            ]
            if len(all_column_names) != len(set(all_column_names)):  # pragma: no cover
                duplicates = [
                    i for i in all_column_names if all_column_names.count(i) > 1
                ]
                msg = (
                    f"Columns with name(s): {', '.join(duplicates)} "
                    "have more than one occurrence"
                )
                raise AssertionError(msg)
            return DaskLazyFrame(
                dd.concat(native_frames, axis=1, join="outer"),
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )
        raise NotImplementedError

    def mean_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
        """Row-wise mean across the given expressions, ignoring nulls."""
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            # Evaluate each expression once, then derive both the
            # null-filled values and the non-null counts from that.
            evaluated = [s for _expr in parsed_exprs for s in _expr._call(df)]
            series = (s.fillna(0) for s in evaluated)
            non_na = (1 - s.isna() for s in evaluated)
            return [
                name_preserving_div(
                    reduce(name_preserving_sum, series),
                    reduce(name_preserving_sum, non_na),
                )
            ]

        return DaskExpr(
            call=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name="mean_horizontal",
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def min_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
        """Row-wise minimum across the given expressions."""
        import dask.dataframe as dd  # ignore-banned-import

        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            series = [s for _expr in parsed_exprs for s in _expr._call(df)]
            return [dd.concat(series, axis=1).min(axis=1).rename(series[0].name)]

        return DaskExpr(
            call=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name="min_horizontal",
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def max_horizontal(self, *exprs: IntoDaskExpr) -> DaskExpr:
        """Row-wise maximum across the given expressions."""
        import dask.dataframe as dd  # ignore-banned-import

        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            series = [s for _expr in parsed_exprs for s in _expr._call(df)]
            return [dd.concat(series, axis=1).max(axis=1).rename(series[0].name)]

        return DaskExpr(
            call=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name="max_horizontal",
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def _create_expr_from_series(self, _: Any) -> NoReturn:
        msg = "`_create_expr_from_series` for DaskNamespace exists only for compatibility"
        raise NotImplementedError(msg)

    def _create_compliant_series(self, _: Any) -> NoReturn:
        msg = "`_create_compliant_series` for DaskNamespace exists only for compatibility"
        raise NotImplementedError(msg)

    def _create_series_from_scalar(self, *_: Any) -> NoReturn:
        msg = (
            "`_create_series_from_scalar` for DaskNamespace exists only for compatibility"
        )
        raise NotImplementedError(msg)

    def _create_expr_from_callable(  # pragma: no cover
        self,
        func: Callable[[DaskLazyFrame], list[DaskExpr]],
        *,
        depth: int,
        function_name: str,
        root_names: list[str] | None,
        output_names: list[str] | None,
    ) -> DaskExpr:
        msg = (
            "`_create_expr_from_callable` for DaskNamespace exists only for compatibility"
        )
        raise NotImplementedError(msg)

    def when(
        self,
        *predicates: IntoDaskExpr,
    ) -> DaskWhen:
        """Start a when/then/otherwise chain; predicates are AND-ed together.

        Raises TypeError if no predicate is given.
        """
        plx = self.__class__(backend_version=self._backend_version, dtypes=self._dtypes)
        if predicates:
            condition = plx.all_horizontal(*predicates)
        else:
            msg = "at least one predicate needs to be provided"
            raise TypeError(msg)
        return DaskWhen(
            condition, self._backend_version, returns_scalar=False, dtypes=self._dtypes
        )

    def concat_str(
        self,
        exprs: Iterable[IntoDaskExpr],
        *more_exprs: IntoDaskExpr,
        separator: str = "",
        ignore_nulls: bool = False,
    ) -> DaskExpr:
        """Row-wise string concatenation of the given expressions.

        With `ignore_nulls=False` (default), any null in a row makes the
        whole result null; with `ignore_nulls=True`, null values and their
        separators are skipped.
        """
        parsed_exprs: list[DaskExpr] = [
            *parse_into_exprs(*exprs, namespace=self),
            *parse_into_exprs(*more_exprs, namespace=self),
        ]

        def func(df: DaskLazyFrame) -> list[dask_expr.Series]:
            series = (s.astype(str) for _expr in parsed_exprs for s in _expr._call(df))
            null_mask = [s for _expr in parsed_exprs for s in _expr.is_null()._call(df)]
            if not ignore_nulls:
                # Any null in the row nulls out the concatenated result.
                null_mask_result = reduce(lambda x, y: x | y, null_mask)
                result = reduce(lambda x, y: x + separator + y, series).where(
                    ~null_mask_result, None
                )
            else:
                # Null values become empty strings, and the separator that
                # would precede them is dropped as well.
                init_value, *values = [
                    s.where(~nm, "") for s, nm in zip(series, null_mask)
                ]
                separators = (
                    nm.map({True: "", False: separator}, meta=str)
                    for nm in null_mask[:-1]
                )
                result = reduce(
                    lambda x, y: x + y,
                    (s + v for s, v in zip(separators, values)),
                    init_value,
                )
            return [result]

        return DaskExpr(
            call=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name="concat_str",
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
            returns_scalar=False,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )
class DaskWhen:
    """Backing object for ``when(...)``.

    Holds the boolean condition plus the ``then``/``otherwise`` values and,
    when called with a frame, evaluates the whole chain into a single series.
    """

    def __init__(
        self,
        condition: DaskExpr,
        backend_version: tuple[int, ...],
        then_value: Any = None,
        otherwise_value: Any = None,
        *,
        returns_scalar: bool,
        dtypes: DTypes,
    ) -> None:
        self._backend_version = backend_version
        self._condition = condition
        self._then_value = then_value
        self._otherwise_value = otherwise_value
        self._returns_scalar = returns_scalar
        self._dtypes = dtypes

    def __call__(self, df: DaskLazyFrame) -> list[dask_expr.Series]:
        """Evaluate the when/then/otherwise chain against `df`.

        Returns a single-element list: the ``then`` values where the
        condition holds, and the ``otherwise`` values (or missing values if
        no ``otherwise`` was set) elsewhere.
        """
        # Imported here (not at module level) to avoid a circular import.
        from narwhals._dask.namespace import DaskNamespace
        from narwhals._expression_parsing import parse_into_expr

        plx = DaskNamespace(backend_version=self._backend_version, dtypes=self._dtypes)
        condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
        condition = cast("dask_expr.Series", condition)
        try:
            value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
        except TypeError:
            # `self._then_value` is a scalar and can't be converted to an expression:
            # broadcast it against the condition's index via a temporary frame.
            _df = condition.to_frame("a")
            _df["tmp"] = self._then_value
            value_series = _df["tmp"]
        value_series = cast("dask_expr.Series", value_series)
        validate_comparand(condition, value_series)
        if self._otherwise_value is None:
            # No `otherwise` given: rows failing the condition become missing.
            return [value_series.where(condition)]
        try:
            otherwise_series = parse_into_expr(
                self._otherwise_value, namespace=plx
            )._call(df)[0]  # type: ignore[arg-type]
        except TypeError:
            # `self._otherwise_value` is a scalar and can't be converted to an expression
            return [value_series.where(condition, self._otherwise_value)]
        validate_comparand(condition, otherwise_series)
        return [value_series.where(condition, otherwise_series)]

    def then(self, value: DaskExpr | Any) -> DaskThen:
        """Record the `then` value and wrap this object as an expression."""
        self._then_value = value
        return DaskThen(
            self,
            depth=0,
            function_name="whenthen",
            root_names=None,
            output_names=None,
            returns_scalar=self._returns_scalar,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )
class DaskThen(DaskExpr):
    """Expression produced by `DaskWhen.then`.

    Behaves like a regular `DaskExpr` whose `_call` is the originating
    `DaskWhen`, and adds the `otherwise` step to complete the chain.
    """

    def __init__(
        self,
        call: DaskWhen,
        *,
        depth: int,
        function_name: str,
        root_names: list[str] | None,
        output_names: list[str] | None,
        returns_scalar: bool,
        backend_version: tuple[int, ...],
        dtypes: DTypes,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._root_names = root_names
        self._output_names = output_names
        self._returns_scalar = returns_scalar
        self._backend_version = backend_version
        self._dtypes = dtypes

    def otherwise(self, value: DaskExpr | Any) -> DaskExpr:
        """Attach the fallback value and finalise the when/then chain."""
        # The base class types `_call` as a plain `Callable`, but here it is
        # always the originating `DaskWhen`, which carries this attribute.
        self._call._otherwise_value = value  # type: ignore[attr-defined]
        self._function_name = "whenotherwise"
        return self