215 lines
8.2 KiB
Python
215 lines
8.2 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Sequence
|
|
|
|
import pyarrow.compute as pc
|
|
|
|
from narwhals._arrow.series import ArrowSeries
|
|
from narwhals._compliant import EagerExpr
|
|
from narwhals._expression_parsing import evaluate_output_names_and_aliases
|
|
from narwhals._expression_parsing import is_scalar_like
|
|
from narwhals.exceptions import ColumnNotFoundError
|
|
from narwhals.utils import Implementation
|
|
from narwhals.utils import generate_temporary_column_name
|
|
from narwhals.utils import not_implemented
|
|
|
|
if TYPE_CHECKING:
|
|
from typing_extensions import Self
|
|
|
|
from narwhals._arrow.dataframe import ArrowDataFrame
|
|
from narwhals._arrow.namespace import ArrowNamespace
|
|
from narwhals._compliant.typing import AliasNames
|
|
from narwhals._compliant.typing import EvalNames
|
|
from narwhals._compliant.typing import EvalSeries
|
|
from narwhals._expression_parsing import ExprMetadata
|
|
from narwhals.typing import RankMethod
|
|
from narwhals.utils import Version
|
|
from narwhals.utils import _FullContext
|
|
|
|
|
|
class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    """Compliant-level expression for the PyArrow backend.

    Wraps a callable that takes an ``ArrowDataFrame`` and returns the
    ``ArrowSeries`` produced by the expression, together with the callables
    used to resolve its output names and aliases.
    """

    # Fixed for this class: the ``implementation`` argument of ``__init__`` is
    # accepted but not used.
    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        backend_version: tuple[int, ...],
        version: Version,
        call_kwargs: dict[str, Any] | None = None,
        implementation: Implementation | None = None,
    ) -> None:
        """Store the evaluation callable and its naming/version metadata.

        Arguments:
            call: Callable evaluating the expression against a dataframe.
            depth: Nesting depth of the expression (``0`` for column selections;
                ``over`` increments it by one).
            function_name: Name of the operation chain, e.g. ``"nth"``.
            evaluate_output_names: Callable returning the output column names.
            alias_output_names: Optional callable mapping output names to aliases.
            backend_version: PyArrow version tuple.
            version: Narwhals API version.
            call_kwargs: Extra keyword arguments to keep on the expression;
                ``None`` is stored as an empty dict.
            implementation: Accepted but ignored; ``_implementation`` is always
                ``Implementation.PYARROW`` for this class.
        """
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._backend_version = backend_version
        self._version = version
        self._call_kwargs = call_kwargs or {}
        # Starts unset; `over` asserts it has been assigned before use.
        # NOTE(review): presumably set by the expression-parsing layer — confirm.
        self._metadata: ExprMetadata | None = None

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _FullContext,
        function_name: str = "",
    ) -> Self:
        """Build an expression that selects columns by name.

        Arguments:
            evaluate_column_names: Callable returning, for a given dataframe, the
                names of the columns to select. It is called once per evaluation.
            context: Provides ``_backend_version`` and ``_version``.
            function_name: Name recorded on the resulting expression.

        Returns:
            A depth-0 expression whose output names are the selected names.

        Raises:
            ColumnNotFoundError: At evaluation time, if any requested name is not
                a column of the dataframe.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            # Evaluate the names once, outside the `try`, so a KeyError raised by
            # the name evaluator itself is not mistaken for a missing column, and
            # the error path reports against the same list of names.
            column_names = evaluate_column_names(df)
            try:
                return [
                    ArrowSeries(
                        df.native[column_name],
                        name=column_name,
                        backend_version=df._backend_version,
                        version=df._version,
                    )
                    for column_name in column_names
                ]
            except KeyError as e:
                # Looking up an unknown column name on the native table raises
                # KeyError; translate it into narwhals' own error type.
                missing_columns = [x for x in column_names if x not in df.columns]
                raise ColumnNotFoundError.from_missing_and_available_column_names(
                    missing_columns=missing_columns, available_columns=df.columns
                ) from e

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    @classmethod
    def from_column_indices(
        cls: type[Self], *column_indices: int, context: _FullContext
    ) -> Self:
        """Build an ``nth`` expression that selects columns by position.

        Arguments:
            *column_indices: Positions of the columns to select, in output order.
            context: Provides ``_backend_version`` and ``_version``.

        Returns:
            A depth-0 expression named ``"nth"``.
        """
        # `ArrowSeries` is already imported at module level; no local import needed.

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            return [
                ArrowSeries(
                    df.native[column_index],
                    name=df.native.column_names[column_index],
                    backend_version=df._backend_version,
                    version=df._version,
                )
                for column_index in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=lambda df: [df.columns[i] for i in column_indices],
            alias_output_names=None,
            backend_version=context._backend_version,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        """Return an ``ArrowNamespace`` sharing this expression's versions."""
        # Imported lazily, as in the original.
        # NOTE(review): presumably to avoid an import cycle — confirm.
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __narwhals_expr__(self) -> None: ...

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        """Return extra keyword arguments for series methods reused by this expression.

        For scalar-returning methods this requests ``_return_py_scalar=False``.
        NOTE(review): presumably so reductions stay native PyArrow scalars instead
        of being converted to Python objects — confirm against ``ArrowSeries``.
        """
        return {"_return_py_scalar": False} if returns_scalar else {}

    def cum_sum(self, *, reverse: bool) -> Self:
        """Cumulative sum, delegated to the series method ``"cum_sum"``."""
        return self._reuse_series("cum_sum", reverse=reverse)

    def shift(self, n: int) -> Self:
        """Shift values by ``n``, delegated to the series method ``"shift"``."""
        return self._reuse_series("shift", n=n)

    def over(self, partition_by: Sequence[str], order_by: Sequence[str] | None) -> Self:
        """Evaluate this expression as a window function.

        Two cases are handled:

        - No ``partition_by`` keys: the frame is sorted by ``order_by``
          (ascending, nulls first), the expression is evaluated on the sorted
          frame, and each result is permuted back to the original row order.
        - With ``partition_by`` keys: only scalar-like expressions (aggregations
          or literals) are supported. They are computed per group (null keys are
          not dropped) and broadcast back to every row via a left join on the
          keys. ``order_by`` is not used in this branch.

        Raises:
            NotImplementedError: If ``partition_by`` is non-empty and the
                expression is not scalar-like. Also raised at evaluation time if
                an output name of the expression is also a ``partition_by`` key.
        """
        assert self._metadata is not None  # noqa: S101
        if partition_by and not is_scalar_like(self._metadata.kind):
            msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
            raise NotImplementedError(msg)

        if not partition_by:
            # Ordered-only window (e.g. a cumulative sum evaluated with
            # `order_by` keys and no partition keys). This never needs grouping,
            # so it is always supported.
            assert order_by is not None  # help type checkers # noqa: S101

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                # Tag each row with its original position before sorting so the
                # original order can be restored afterwards.
                token = generate_temporary_column_name(8, df.columns)
                df = df.with_row_index(token).sort(
                    *order_by, descending=False, nulls_last=False
                )
                result = self(df.drop([token], strict=True))
                # TODO(marco): is there a way to do this efficiently without
                # doing 2 sorts? Here we're sorting the dataframe and then
                # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
                # `sort_indices` of the saved positions gives, for each original
                # position, where that row sits in the sorted frame — i.e. the
                # inverse permutation, so `take` restores the original order.
                sorting_indices = pc.sort_indices(df.get_column(token).native)
                return [s._with_native(s.native.take(sorting_indices)) for s in result]

        else:

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if overlap := set(output_names).intersection(partition_by):
                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                    # we just don't support it yet.
                    msg = (
                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                        "This is not yet supported."
                    )
                    raise NotImplementedError(msg)

                # One row per group, holding the aggregated value(s)...
                tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
                # ...broadcast back to every original row by joining on the keys.
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            backend_version=self._backend_version,
            version=self._version,
        )

    def cum_count(self, *, reverse: bool) -> Self:
        """Cumulative count, delegated to the series method ``"cum_count"``."""
        return self._reuse_series("cum_count", reverse=reverse)

    def cum_min(self, *, reverse: bool) -> Self:
        """Cumulative minimum, delegated to the series method ``"cum_min"``."""
        return self._reuse_series("cum_min", reverse=reverse)

    def cum_max(self, *, reverse: bool) -> Self:
        """Cumulative maximum, delegated to the series method ``"cum_max"``."""
        return self._reuse_series("cum_max", reverse=reverse)

    def cum_prod(self, *, reverse: bool) -> Self:
        """Cumulative product, delegated to the series method ``"cum_prod"``."""
        return self._reuse_series("cum_prod", reverse=reverse)

    def rank(self, method: RankMethod, *, descending: bool) -> Self:
        """Rank values, delegated to the series method ``"rank"``."""
        return self._reuse_series("rank", method=method, descending=descending)

    # Not supported by the PyArrow backend.
    ewm_mean = not_implemented()
|