732 lines
28 KiB
Python
732 lines
28 KiB
Python
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import operator
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Callable
|
|
from typing import Literal
|
|
from typing import Sequence
|
|
from typing import cast
|
|
|
|
from duckdb import CoalesceOperator
|
|
from duckdb import FunctionExpression
|
|
from duckdb.typing import DuckDBPyType
|
|
|
|
from narwhals._compliant import LazyExpr
|
|
from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
|
|
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
|
|
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
|
|
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
|
|
from narwhals._duckdb.utils import WindowInputs
|
|
from narwhals._duckdb.utils import col
|
|
from narwhals._duckdb.utils import generate_order_by_sql
|
|
from narwhals._duckdb.utils import generate_partition_by_sql
|
|
from narwhals._duckdb.utils import lit
|
|
from narwhals._duckdb.utils import narwhals_to_native_dtype
|
|
from narwhals._duckdb.utils import when
|
|
from narwhals._expression_parsing import ExprKind
|
|
from narwhals.utils import Implementation
|
|
from narwhals.utils import not_implemented
|
|
|
|
if TYPE_CHECKING:
|
|
import duckdb
|
|
from typing_extensions import Self
|
|
|
|
from narwhals._compliant.typing import AliasNames
|
|
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
|
from narwhals._duckdb.namespace import DuckDBNamespace
|
|
from narwhals._duckdb.typing import WindowFunction
|
|
from narwhals._expression_parsing import ExprMetadata
|
|
from narwhals.dtypes import DType
|
|
from narwhals.utils import Version
|
|
from narwhals.utils import _FullContext
|
|
|
|
with contextlib.suppress(ImportError): # requires duckdb>=1.3.0
|
|
from duckdb import SQLExpression # type: ignore[attr-defined, unused-ignore]
|
|
|
|
|
|
class DuckDBExpr(LazyExpr["DuckDBLazyFrame", "duckdb.Expression"]):
    """Narwhals expression backed by native `duckdb.Expression` objects.

    An instance wraps a callable which, given a `DuckDBLazyFrame`, returns one
    native expression per output column, together with callables that compute
    (and optionally rename) the output column names. Methods return a new
    `DuckDBExpr` instead of mutating `self`.

    Order-dependent operations (`shift`, `diff`, `cum_*`, `rolling_*`,
    `is_first_distinct`, `is_last_distinct`) only attach a window function to
    the returned expression; that function is applied by `over`.
    """

    _implementation = Implementation.DUCKDB

    def __init__(
        self: Self,
        call: Callable[[DuckDBLazyFrame], Sequence[duckdb.Expression]],
        *,
        evaluate_output_names: Callable[[DuckDBLazyFrame], Sequence[str]],
        alias_output_names: Callable[[Sequence[str]], Sequence[str]] | None,
        backend_version: tuple[int, ...],
        version: Version,
    ) -> None:
        """Store the evaluation callables and backend information.

        Arguments:
            call: Callable from a lazy frame to the native expressions.
            evaluate_output_names: Callable from a lazy frame to the output names.
            alias_output_names: Optional callable renaming the output names.
            backend_version: Version tuple of the DuckDB backend.
            version: Narwhals API version.
        """
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        # Set by `_with_window_function`; consumed by `over`.
        self._window_function: WindowFunction | None = None
        self._metadata: ExprMetadata | None = None

    def __call__(self: Self, df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
        """Evaluate the expression against `df`, returning native expressions."""
        return self._call(df)

    def __narwhals_expr__(self) -> None:
        """Marker method with no behaviour of its own."""

    def __narwhals_namespace__(self) -> DuckDBNamespace:  # pragma: no cover
        """Return a `DuckDBNamespace` sharing this expression's versions."""
        # Unused, just for compatibility with PandasLikeExpr
        from narwhals._duckdb.namespace import DuckDBNamespace

        return DuckDBNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def _cum_window_func(
        self,
        *,
        reverse: bool,
        func_name: Literal["sum", "max", "min", "count", "product"],
    ) -> WindowFunction:
        """Build a window function computing a cumulative `func_name`.

        The frame runs from the first row of the partition to the current row.
        `reverse=True` flips the ordering (`ascending=not reverse`).
        """

        def func(window_inputs: WindowInputs) -> duckdb.Expression:
            order_by_sql = generate_order_by_sql(
                *window_inputs.order_by, ascending=not reverse
            )
            partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
            sql = (
                f"{func_name} ({window_inputs.expr}) over ({partition_by_sql} {order_by_sql} "
                "rows between unbounded preceding and current row)"
            )
            return SQLExpression(sql)  # type: ignore[no-any-return, unused-ignore]

        return func

    def _rolling_window_func(
        self,
        *,
        func_name: Literal["sum", "mean", "std", "var"],
        center: bool,
        window_size: int,
        min_samples: int,
        ddof: int | None = None,
    ) -> WindowFunction:
        """Build a window function computing a rolling `func_name`.

        Arguments:
            func_name: Rolling aggregation to compute.
            center: If true, the frame extends on both sides of the current
                row; otherwise it ends at the current row.
            window_size: Number of rows in the frame.
            min_samples: Minimum `count(expr)` over the frame required to emit
                a value; below it the generated `case` has no matching branch.
            ddof: Delta degrees of freedom, used for "std" and "var" only.

        Raises:
            ValueError: When the returned function is called with an
                unsupported `func_name`, or with `ddof` other than 0 or 1 for
                "std"/"var".
        """
        supported_funcs = ["sum", "mean", "std", "var"]
        if center:
            # For even window sizes the extra row goes on the preceding side.
            half = (window_size - 1) // 2
            remainder = (window_size - 1) % 2
            start = f"{half + remainder} preceding"
            end = f"{half} following"
        else:
            start = f"{window_size - 1} preceding"
            end = "current row"

        def func(window_inputs: WindowInputs) -> duckdb.Expression:
            order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
            partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
            window = f"({partition_by_sql} {order_by_sql} rows between {start} and {end})"
            if func_name in {"sum", "mean"}:
                func_: str = func_name
            elif func_name == "var" and ddof == 0:
                func_ = "var_pop"
            elif func_name == "var" and ddof == 1:
                func_ = "var_samp"
            elif func_name == "std" and ddof == 0:
                func_ = "stddev_pop"
            elif func_name == "std" and ddof == 1:
                func_ = "stddev_samp"
            elif func_name in {"var", "std"}:  # pragma: no cover
                msg = f"Only ddof=0 and ddof=1 are currently supported for rolling_{func_name}."
                raise ValueError(msg)
            else:  # pragma: no cover
                msg = f"Only the following functions are supported: {supported_funcs}.\nGot: {func_name}."
                raise ValueError(msg)
            # The trailing space keeps `min_samples` and `then` as separate tokens.
            sql = (
                f"case when count({window_inputs.expr}) over {window} >= {min_samples} "
                f"then {func_}({window_inputs.expr}) over {window} end"
            )
            return SQLExpression(sql)  # type: ignore[no-any-return, unused-ignore]

        return func

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        """Return an expression whose native expressions are wrapped in `over ()`.

        Literals are returned unchanged.

        Raises:
            NotImplementedError: If the DuckDB version is older than 1.3.
        """
        if kind is ExprKind.LITERAL:
            return self
        if self._backend_version < (1, 3):
            msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns."
            raise NotImplementedError(msg)

        template = "{expr} over ()"

        def func(df: DuckDBLazyFrame) -> Sequence[duckdb.Expression]:
            return [SQLExpression(template.format(expr=expr)) for expr in self(df)]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: Callable[[DuckDBLazyFrame], Sequence[str]],
        /,
        *,
        context: _FullContext,
    ) -> Self:
        """Create an expression selecting the columns named by `evaluate_column_names`.

        Arguments:
            evaluate_column_names: Callable from a lazy frame to column names.
            context: Object supplying `_backend_version` and `_version`.
        """

        def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
            return [col(name) for name in evaluate_column_names(df)]

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(
        cls: type[Self], *column_indices: int, context: _FullContext
    ) -> Self:
        """Create an expression selecting columns by position in `df.columns`.

        Arguments:
            column_indices: Positions of the columns to select.
            context: Object supplying `_backend_version` and `_version`.
        """

        def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
            columns = df.columns
            return [col(columns[i]) for i in column_indices]

        return cls(
            func,
            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def _with_callable(
        self: Self,
        call: Callable[..., duckdb.Expression],
        /,
        **expressifiable_args: Self | Any,
    ) -> Self:
        """Create a new expression by applying `call` to each native expression.

        Arguments:
            call: Callable receiving one native expression of `self`
                positionally, plus the evaluated `expressifiable_args` as
                keyword arguments, and returning a native expression.
            expressifiable_args: arguments pass to expression which should be parsed
                as expressions (e.g. in `nw.col('a').is_between('b', 'c')`).
                Values that are not expressions are wrapped with `lit`.
        """

        def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
            native_series_list = self(df)
            other_native_series = {
                key: df._evaluate_expr(value) if self._is_expr(value) else lit(value)
                for key, value in expressifiable_args.items()
            }
            return [
                call(native_series, **other_native_series)
                for native_series in native_series_list
            ]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        """Return a copy of this expression using `func` to rename outputs."""
        return type(self)(
            call=self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=func,
            backend_version=self._backend_version,
            version=self._version,
        )

    def _with_window_function(
        self: Self,
        window_function: WindowFunction,
    ) -> Self:
        """Return a copy of this expression with `window_function` attached.

        The attached function is only used by `over`.
        """
        result = self.__class__(
            self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )
        result._window_function = window_function
        return result

    # Binary operators. Reflected variants swap the operands and name the
    # result "literal".

    def __and__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self & other`."""
        return self._with_callable(lambda _input, other: _input & other, other=other)

    def __or__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self | other`."""
        return self._with_callable(lambda _input, other: _input | other, other=other)

    def __add__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self + other`."""
        return self._with_callable(lambda _input, other: _input + other, other=other)

    def __truediv__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self / other`."""
        return self._with_callable(lambda _input, other: _input / other, other=other)

    def __rtruediv__(self: Self, other: DuckDBExpr) -> Self:
        """Return `other / self`, aliased to "literal"."""
        return self._with_callable(
            lambda _input, other: other.__truediv__(_input), other=other
        ).alias("literal")

    def __floordiv__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self // other`."""
        return self._with_callable(
            lambda _input, other: _input.__floordiv__(other), other=other
        )

    def __rfloordiv__(self: Self, other: DuckDBExpr) -> Self:
        """Return `other // self`, aliased to "literal"."""
        return self._with_callable(
            lambda _input, other: other.__floordiv__(_input), other=other
        ).alias("literal")

    def __mod__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self % other`."""
        return self._with_callable(
            lambda _input, other: _input.__mod__(other), other=other
        )

    def __rmod__(self: Self, other: DuckDBExpr) -> Self:
        """Return `other % self`, aliased to "literal"."""
        return self._with_callable(
            lambda _input, other: other.__mod__(_input), other=other
        ).alias("literal")

    def __sub__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self - other`."""
        return self._with_callable(lambda _input, other: _input - other, other=other)

    def __rsub__(self: Self, other: DuckDBExpr) -> Self:
        """Return `other - self`, aliased to "literal"."""
        return self._with_callable(
            lambda _input, other: other.__sub__(_input), other=other
        ).alias("literal")

    def __mul__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self * other`."""
        return self._with_callable(lambda _input, other: _input * other, other=other)

    def __pow__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self ** other`."""
        return self._with_callable(lambda _input, other: _input**other, other=other)

    def __rpow__(self: Self, other: DuckDBExpr) -> Self:
        """Return `other ** self`, aliased to "literal"."""
        return self._with_callable(
            lambda _input, other: other.__pow__(_input), other=other
        ).alias("literal")

    def __lt__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self < other`."""
        return self._with_callable(lambda _input, other: _input < other, other=other)

    def __gt__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self > other`."""
        return self._with_callable(lambda _input, other: _input > other, other=other)

    def __le__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self <= other`."""
        return self._with_callable(lambda _input, other: _input <= other, other=other)

    def __ge__(self: Self, other: DuckDBExpr) -> Self:
        """Return `self >= other`."""
        return self._with_callable(lambda _input, other: _input >= other, other=other)

    def __eq__(self: Self, other: DuckDBExpr) -> Self:  # type: ignore[override]
        """Return the expression `self == other` (not a Python bool)."""
        return self._with_callable(lambda _input, other: _input == other, other=other)

    def __ne__(self: Self, other: DuckDBExpr) -> Self:  # type: ignore[override]
        """Return the expression `self != other` (not a Python bool)."""
        return self._with_callable(lambda _input, other: _input != other, other=other)

    def __invert__(self: Self) -> Self:
        """Return `~self`."""
        invert = cast("Callable[..., duckdb.Expression]", operator.invert)
        return self._with_callable(invert)

    def alias(self: Self, name: str) -> Self:
        """Return a copy of this expression whose single output is named `name`.

        Raises:
            ValueError: When the output names are resolved and the expression
                does not have exactly one output.
        """

        def alias_output_names(names: Sequence[str]) -> Sequence[str]:
            if len(names) != 1:
                msg = f"Expected function with single output, found output names: {names}"
                raise ValueError(msg)
            return [name]

        return self.__class__(
            self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def abs(self: Self) -> Self:
        """Apply DuckDB's `abs` function."""
        return self._with_callable(lambda _input: FunctionExpression("abs", _input))

    def mean(self: Self) -> Self:
        """Apply DuckDB's `mean` aggregate."""
        return self._with_callable(lambda _input: FunctionExpression("mean", _input))

    def skew(self: Self) -> Self:
        """Compute skewness from DuckDB's `skewness` with a count-based correction.

        Special cases on `count(expr)`: 0 gives null, 1 gives NaN, 2 gives 0.0.
        """

        def func(_input: duckdb.Expression) -> duckdb.Expression:
            count = FunctionExpression("count", _input)
            # Adjust population skewness by correction factor to get sample skewness
            sample_skewness = (
                FunctionExpression("skewness", _input)
                * (count - lit(2))
                / FunctionExpression("sqrt", count * (count - lit(1)))
            )
            return when(count == lit(0), lit(None)).otherwise(
                when(count == lit(1), lit(float("nan"))).otherwise(
                    when(count == lit(2), lit(0.0)).otherwise(sample_skewness)
                )
            )

        return self._with_callable(func)

    def median(self: Self) -> Self:
        """Apply DuckDB's `median` aggregate."""
        return self._with_callable(lambda _input: FunctionExpression("median", _input))

    def all(self: Self) -> Self:
        """Apply DuckDB's `bool_and` aggregate."""
        return self._with_callable(lambda _input: FunctionExpression("bool_and", _input))

    def any(self: Self) -> Self:
        """Apply DuckDB's `bool_or` aggregate."""
        return self._with_callable(lambda _input: FunctionExpression("bool_or", _input))

    def quantile(
        self: Self,
        quantile: float,
        interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
    ) -> Self:
        """Apply DuckDB's `quantile_cont` for linear interpolation.

        Raises:
            NotImplementedError: When the expression is evaluated with an
                `interpolation` other than "linear".
        """

        def func(_input: duckdb.Expression) -> duckdb.Expression:
            if interpolation == "linear":
                return FunctionExpression("quantile_cont", _input, lit(quantile))
            msg = "Only linear interpolation methods are supported for DuckDB quantile."
            raise NotImplementedError(msg)

        return self._with_callable(func)

    def clip(self: Self, lower_bound: Any, upper_bound: Any) -> Self:
        """Limit values using DuckDB's `greatest` and/or `least`.

        A bound of `None` is skipped; `lower_bound` is checked first.
        """

        def _clip_lower(_input: duckdb.Expression, lower_bound: Any) -> duckdb.Expression:
            return FunctionExpression("greatest", _input, lower_bound)

        def _clip_upper(_input: duckdb.Expression, upper_bound: Any) -> duckdb.Expression:
            return FunctionExpression("least", _input, upper_bound)

        def _clip_both(
            _input: duckdb.Expression, lower_bound: Any, upper_bound: Any
        ) -> duckdb.Expression:
            return FunctionExpression(
                "greatest", FunctionExpression("least", _input, upper_bound), lower_bound
            )

        if lower_bound is None:
            return self._with_callable(_clip_upper, upper_bound=upper_bound)
        if upper_bound is None:
            return self._with_callable(_clip_lower, lower_bound=lower_bound)
        return self._with_callable(
            _clip_both, lower_bound=lower_bound, upper_bound=upper_bound
        )

    def sum(self: Self) -> Self:
        """Apply DuckDB's `sum` aggregate."""
        return self._with_callable(lambda _input: FunctionExpression("sum", _input))

    def n_unique(self: Self) -> Self:
        """Return `array_unique(array_agg(expr))` plus 1 if any value is null."""

        def func(_input: duckdb.Expression) -> duckdb.Expression:
            # https://stackoverflow.com/a/79338887/4451315
            return FunctionExpression(
                "array_unique", FunctionExpression("array_agg", _input)
            ) + FunctionExpression(
                # 1 as soon as one null is present, otherwise 0.
                "max", when(_input.isnotnull(), lit(0)).otherwise(lit(1))
            )

        return self._with_callable(func)

    def count(self: Self) -> Self:
        """Apply DuckDB's `count` aggregate to the expression."""
        return self._with_callable(lambda _input: FunctionExpression("count", _input))

    def len(self: Self) -> Self:
        """Call DuckDB's `count` with no arguments; the input expression is unused."""
        return self._with_callable(lambda _input: FunctionExpression("count"))

    def std(self: Self, ddof: int) -> Self:
        """Compute the standard deviation with `ddof` delta degrees of freedom.

        `ddof` 0 and 1 map to `stddev_pop` and `stddev_samp`; any other value
        rescales `stddev_pop` by `sqrt(n) / sqrt(n - ddof)`.
        """
        if ddof == 0:
            return self._with_callable(
                lambda _input: FunctionExpression("stddev_pop", _input)
            )
        if ddof == 1:
            return self._with_callable(
                lambda _input: FunctionExpression("stddev_samp", _input)
            )

        def _std(_input: duckdb.Expression) -> duckdb.Expression:
            n_samples = FunctionExpression("count", _input)
            return (
                FunctionExpression("stddev_pop", _input)
                * FunctionExpression("sqrt", n_samples)
                / (FunctionExpression("sqrt", (n_samples - lit(ddof))))
            )

        return self._with_callable(_std)

    def var(self: Self, ddof: int) -> Self:
        """Compute the variance with `ddof` delta degrees of freedom.

        `ddof` 0 and 1 map to `var_pop` and `var_samp`; any other value
        rescales `var_pop` by `n / (n - ddof)`.
        """
        if ddof == 0:
            return self._with_callable(
                lambda _input: FunctionExpression("var_pop", _input)
            )
        if ddof == 1:
            return self._with_callable(
                lambda _input: FunctionExpression("var_samp", _input)
            )

        def _var(_input: duckdb.Expression) -> duckdb.Expression:
            n_samples = FunctionExpression("count", _input)
            return (
                FunctionExpression("var_pop", _input)
                * n_samples
                / (n_samples - lit(ddof))
            )

        return self._with_callable(_var)

    def max(self: Self) -> Self:
        """Apply DuckDB's `max` aggregate."""
        return self._with_callable(lambda _input: FunctionExpression("max", _input))

    def min(self: Self) -> Self:
        """Apply DuckDB's `min` aggregate."""
        return self._with_callable(lambda _input: FunctionExpression("min", _input))

    def null_count(self: Self) -> Self:
        """Sum the null indicator (`isnull()` cast to int) of the expression."""
        return self._with_callable(
            lambda _input: FunctionExpression("sum", _input.isnull().cast("int")),
        )

    def over(
        self: Self,
        partition_by: Sequence[str],
        order_by: Sequence[str] | None,
    ) -> Self:
        """Evaluate the expression over a window.

        If a window function was attached (see `_with_window_function`), it is
        called with the expression, `partition_by` and `order_by`. Otherwise
        the expression is wrapped in `over (<partition by clause>)`.

        Raises:
            NotImplementedError: If the DuckDB version is older than 1.3.
        """
        if self._backend_version < (1, 3):
            msg = "At least version 1.3 of DuckDB is required for `over` operation."
            raise NotImplementedError(msg)
        if (window_function := self._window_function) is not None:
            # Attached window functions all build an ORDER BY clause.
            assert order_by is not None  # noqa: S101

            def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
                return [
                    window_function(WindowInputs(expr, partition_by, order_by))
                    for expr in self._call(df)
                ]
        else:
            partition_by_sql = generate_partition_by_sql(*partition_by)
            template = f"{{expr}} over ({partition_by_sql})"

            def func(df: DuckDBLazyFrame) -> list[duckdb.Expression]:
                return [
                    SQLExpression(template.format(expr=expr)) for expr in self._call(df)
                ]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def is_null(self: Self) -> Self:
        """Apply the native `isnull()` method."""
        return self._with_callable(lambda _input: _input.isnull())

    def is_nan(self: Self) -> Self:
        """Apply DuckDB's `isnan` function."""
        return self._with_callable(lambda _input: FunctionExpression("isnan", _input))

    def is_finite(self: Self) -> Self:
        """Apply DuckDB's `isfinite` function."""
        return self._with_callable(lambda _input: FunctionExpression("isfinite", _input))

    def is_in(self: Self, other: Sequence[Any]) -> Self:
        """Apply DuckDB's `contains` with `other` as a literal and the expression as needle."""
        return self._with_callable(
            lambda _input: FunctionExpression("contains", lit(other), _input)
        )

    def round(self: Self, decimals: int) -> Self:
        """Apply DuckDB's `round` with `decimals` as second argument."""
        return self._with_callable(
            lambda _input: FunctionExpression("round", _input, lit(decimals))
        )

    def shift(self, n: int) -> Self:
        """Attach a window function computing `lag(expr, n)` in ascending order."""

        def func(window_inputs: WindowInputs) -> duckdb.Expression:
            order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
            partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
            sql = (
                f"lag({window_inputs.expr}, {n}) over ({partition_by_sql} {order_by_sql})"
            )
            return SQLExpression(sql)  # type: ignore[no-any-return, unused-ignore]

        return self._with_window_function(func)

    def is_first_distinct(self) -> Self:
        """Attach a window function flagging the first row of each distinct value.

        The expression itself is added to the partition keys and rows are
        numbered in ascending order; row number 1 is the first occurrence.
        """

        def func(window_inputs: WindowInputs) -> duckdb.Expression:
            order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
            if window_inputs.partition_by:
                partition_by_sql = (
                    generate_partition_by_sql(*window_inputs.partition_by)
                    + f", {window_inputs.expr}"
                )
            else:
                partition_by_sql = f"partition by {window_inputs.expr}"
            sql = f"row_number() over({partition_by_sql} {order_by_sql}) == 1"
            return SQLExpression(sql)  # type: ignore[no-any-return, unused-ignore]

        return self._with_window_function(func)

    def is_last_distinct(self) -> Self:
        """Attach a window function flagging the last row of each distinct value.

        Same as `is_first_distinct` but rows are numbered in descending order.
        """

        def func(window_inputs: WindowInputs) -> duckdb.Expression:
            order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=False)
            if window_inputs.partition_by:
                partition_by_sql = (
                    generate_partition_by_sql(*window_inputs.partition_by)
                    + f", {window_inputs.expr}"
                )
            else:
                partition_by_sql = f"partition by {window_inputs.expr}"
            sql = f"row_number() over({partition_by_sql} {order_by_sql}) == 1"
            return SQLExpression(sql)  # type: ignore[no-any-return, unused-ignore]

        return self._with_window_function(func)

    def diff(self) -> Self:
        """Attach a window function computing `expr - lag(expr)` in ascending order."""

        def func(window_inputs: WindowInputs) -> duckdb.Expression:
            order_by_sql = generate_order_by_sql(*window_inputs.order_by, ascending=True)
            partition_by_sql = generate_partition_by_sql(*window_inputs.partition_by)
            sql = f"lag({window_inputs.expr}) over ({partition_by_sql} {order_by_sql})"
            return window_inputs.expr - SQLExpression(sql)  # type: ignore[no-any-return, unused-ignore]

        return self._with_window_function(func)

    def cum_sum(self, *, reverse: bool) -> Self:
        """Attach a cumulative `sum` window function."""
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="sum")
        )

    def cum_max(self, *, reverse: bool) -> Self:
        """Attach a cumulative `max` window function."""
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="max")
        )

    def cum_min(self, *, reverse: bool) -> Self:
        """Attach a cumulative `min` window function."""
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="min")
        )

    def cum_count(self, *, reverse: bool) -> Self:
        """Attach a cumulative `count` window function."""
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="count")
        )

    def cum_prod(self, *, reverse: bool) -> Self:
        """Attach a cumulative `product` window function."""
        return self._with_window_function(
            self._cum_window_func(reverse=reverse, func_name="product")
        )

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Attach a rolling `sum` window function."""
        return self._with_window_function(
            self._rolling_window_func(
                func_name="sum",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
            )
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        """Attach a rolling `mean` window function."""
        return self._with_window_function(
            self._rolling_window_func(
                func_name="mean",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
            )
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Attach a rolling variance window function (`ddof` must be 0 or 1)."""
        return self._with_window_function(
            self._rolling_window_func(
                func_name="var",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
                ddof=ddof,
            )
        )

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        """Attach a rolling standard deviation window function (`ddof` must be 0 or 1)."""
        return self._with_window_function(
            self._rolling_window_func(
                func_name="std",
                center=center,
                window_size=window_size,
                min_samples=min_samples,
                ddof=ddof,
            )
        )

    def fill_null(
        self: Self, value: Self | Any, strategy: Any, limit: int | None
    ) -> Self:
        """Replace nulls with `value` using DuckDB's coalesce operator.

        Arguments:
            value: Replacement value or expression.
            strategy: Must be `None`; fill strategies are not implemented.
            limit: Not used by this implementation.

        Raises:
            NotImplementedError: If `strategy` is not `None`.
        """
        if strategy is not None:
            msg = "`fill_null` with `strategy` is not yet supported for DuckDB."
            raise NotImplementedError(msg)

        def func(_input: duckdb.Expression, value: Any) -> duckdb.Expression:
            return CoalesceOperator(_input, value)

        return self._with_callable(func, value=value)

    def cast(self: Self, dtype: DType | type[DType]) -> Self:
        """Cast to the DuckDB type corresponding to the Narwhals `dtype`."""

        def func(_input: duckdb.Expression) -> duckdb.Expression:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            return _input.cast(DuckDBPyType(native_dtype))

        return self._with_callable(func)

    def is_unique(self: Self) -> Self:
        """Flag values whose `count(*)` partitioned by the expression equals 1.

        Raises:
            NotImplementedError: If the DuckDB version is older than 1.3.
        """
        # `SQLExpression` is only importable from duckdb>=1.3.0; fail with the
        # same kind of error as `over` rather than a `NameError` at evaluation.
        if self._backend_version < (1, 3):
            msg = "At least version 1.3 of DuckDB is required for `is_unique` operation."
            raise NotImplementedError(msg)

        def func(_input: duckdb.Expression) -> duckdb.Expression:
            sql = f"count(*) over (partition by {_input})"
            return SQLExpression(sql) == lit(1)  # type: ignore[no-any-return, unused-ignore]

        return self._with_callable(func)

    def rank(
        self,
        method: Literal["average", "min", "max", "dense", "ordinal"],
        *,
        descending: bool,
    ) -> Self:
        """Rank values with DuckDB's `rank` ("min") or `dense_rank` ("dense").

        Nulls are ordered last and receive no rank (the `when` has no
        `otherwise` branch).

        Raises:
            NotImplementedError: If `method` is neither "min" nor "dense", or
                if the DuckDB version is older than 1.3.
        """
        if method == "min":
            func_name = "rank"
        elif method == "dense":
            func_name = "dense_rank"
        else:  # pragma: no cover
            msg = f"Method {method} is not yet implemented."
            raise NotImplementedError(msg)
        # `SQLExpression` is only importable from duckdb>=1.3.0; fail with the
        # same kind of error as `over` rather than a `NameError` at evaluation.
        if self._backend_version < (1, 3):
            msg = "At least version 1.3 of DuckDB is required for `rank` operation."
            raise NotImplementedError(msg)

        def _rank(_input: duckdb.Expression) -> duckdb.Expression:
            if descending:
                by_sql = f"{_input} desc nulls last"
            else:
                by_sql = f"{_input} asc nulls last"
            sql = f"{func_name}() OVER (order by {by_sql})"
            return when(_input.isnotnull(), SQLExpression(sql))

        return self._with_callable(_rank)

    @property
    def str(self: Self) -> DuckDBExprStringNamespace:
        """String namespace for this expression."""
        return DuckDBExprStringNamespace(self)

    @property
    def dt(self: Self) -> DuckDBExprDateTimeNamespace:
        """Datetime namespace for this expression."""
        return DuckDBExprDateTimeNamespace(self)

    @property
    def list(self: Self) -> DuckDBExprListNamespace:
        """List namespace for this expression."""
        return DuckDBExprListNamespace(self)

    @property
    def struct(self: Self) -> DuckDBExprStructNamespace:
        """Struct namespace for this expression."""
        return DuckDBExprStructNamespace(self)

    # Declared through the `not_implemented` helper rather than defined here.
    drop_nulls = not_implemented()
    unique = not_implemented()
|