Files
Buffteks-Website/buffteks/lib/python3.12/site-packages/narwhals/_arrow/namespace.py
2025-05-08 21:10:14 -05:00

295 lines
11 KiB
Python

from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable
from typing import Iterable
from typing import Literal

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import align_series_full_broadcast
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._arrow.utils import diagonal_concat
from narwhals._arrow.utils import horizontal_concat
from narwhals._arrow.utils import vertical_concat
from narwhals._compliant import CompliantThen
from narwhals._compliant import EagerNamespace
from narwhals._compliant import EagerWhen
from narwhals._expression_parsing import combine_alias_output_names
from narwhals._expression_parsing import combine_evaluate_output_names
from narwhals.utils import Implementation
from narwhals.utils import import_dtypes_module

if TYPE_CHECKING:
    from typing_extensions import Self

    from narwhals._arrow.typing import ArrowChunkedArray
    from narwhals._arrow.typing import Incomplete
    from narwhals.dtypes import DType
    from narwhals.utils import Version


class ArrowNamespace(EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr]):
    """Expression constructors for the PyArrow backend.

    Every public method returns an `ArrowExpr` (or an `ArrowDataFrame` for
    `concat`) built from the backend version and narwhals API version that
    were supplied at construction time.
    """

    @property
    def _dataframe(self) -> type[ArrowDataFrame]:
        """Dataframe class associated with this namespace."""
        return ArrowDataFrame

    @property
    def _expr(self) -> type[ArrowExpr]:
        """Expression class associated with this namespace."""
        return ArrowExpr

    @property
    def _series(self) -> type[ArrowSeries]:
        """Series class associated with this namespace."""
        return ArrowSeries

    # --- not in spec ---
    def __init__(
        self: Self, *, backend_version: tuple[int, ...], version: Version
    ) -> None:
        """Store the backend/API versions and mark the implementation as PyArrow.

        Arguments:
            backend_version: Version tuple forwarded to every object this
                namespace creates.
            version: narwhals API version forwarded to every object this
                namespace creates.
        """
        self._backend_version = backend_version
        self._implementation = Implementation.PYARROW
        self._version = version

    def _from_horizontal_func(
        self: Self,
        func: Callable[[ArrowDataFrame], list[ArrowSeries]],
        function_name: str,
        *exprs: ArrowExpr,
    ) -> ArrowExpr:
        """Wrap `func` in an expression that combines `exprs`.

        Shared by all horizontal operations: the depth is one more than the
        deepest input expression, and the output-name callbacks are derived
        from `exprs`.

        Arguments:
            func: Callable producing the result series for a dataframe.
            function_name: Name recorded on the resulting expression.
            *exprs: Input expressions being combined.

        Raises:
            ValueError: If `exprs` is empty (raised by `max` on an empty
                sequence).
        """
        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name=function_name,
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def _extreme_element_wise(
        self: Self,
        native_op: Callable[[Any, Any], Any],
        function_name: str,
        *exprs: ArrowExpr,
    ) -> ArrowExpr:
        """Fold `native_op` over the native arrays produced by `exprs`.

        Arguments:
            native_op: Binary PyArrow compute function applied pairwise
                (e.g. `pc.min_element_wise`).
            function_name: Name recorded on the resulting expression.
            *exprs: Input expressions being combined.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            # Star-unpacking consumes the iterator directly; no intermediate
            # list is needed.
            init_series, *series = chain.from_iterable(expr(df) for expr in exprs)
            init_series, *series = align_series_full_broadcast(init_series, *series)
            native_series = reduce(
                native_op, [s.native for s in series], init_series.native
            )
            # The result keeps the name of the first input series.
            return [
                ArrowSeries(
                    native_series,
                    name=init_series.name,
                    backend_version=self._backend_version,
                    version=self._version,
                )
            ]

        return self._from_horizontal_func(func, function_name, *exprs)

    def len(self: Self) -> ArrowExpr:
        """Return an expression yielding a one-element series named ``"len"``.

        The single value is `len(df.native)` of the dataframe the expression
        is evaluated against.
        """
        # coverage bug? this is definitely hit
        return self._expr(  # pragma: no cover
            lambda df: [
                ArrowSeries.from_iterable([len(df.native)], name="len", context=self)
            ],
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
        )

    def lit(self: Self, value: Any, dtype: DType | type[DType] | None) -> ArrowExpr:
        """Return an expression yielding a one-element series named ``"literal"``.

        Arguments:
            value: The literal value placed in the series.
            dtype: Target dtype; when not `None` the series is cast to it.
        """

        def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
            arrow_series = ArrowSeries.from_iterable(
                data=[value], name="literal", context=self
            )
            # Explicit `None` check: whether to cast must not depend on the
            # truthiness of the dtype object.
            if dtype is not None:
                return arrow_series.cast(dtype)
            return arrow_series

        return self._expr(
            lambda df: [_lit_arrow_series(df)],
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            backend_version=self._backend_version,
            version=self._version,
        )

    def all_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
        """Combine the series produced by `exprs` row-wise with ``&``."""

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = chain.from_iterable(expr(df) for expr in exprs)
            return [reduce(operator.and_, align_series_full_broadcast(*series))]

        return self._from_horizontal_func(func, "all_horizontal", *exprs)

    def any_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
        """Combine the series produced by `exprs` row-wise with ``|``."""

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = chain.from_iterable(expr(df) for expr in exprs)
            return [reduce(operator.or_, align_series_full_broadcast(*series))]

        return self._from_horizontal_func(func, "any_horizontal", *exprs)

    def sum_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
        """Add the series produced by `exprs` row-wise, with nulls filled as 0."""

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            it = chain.from_iterable(expr(df) for expr in exprs)
            series = (s.fill_null(0, strategy=None, limit=None) for s in it)
            return [reduce(operator.add, align_series_full_broadcast(*series))]

        return self._from_horizontal_func(func, "sum_horizontal", *exprs)

    def mean_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
        """Average the series produced by `exprs` row-wise.

        The numerator is the row-wise sum with nulls filled as 0; the
        denominator is the row-wise count of non-null entries.
        """
        dtypes = import_dtypes_module(self._version)

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            # Materialised because the results are iterated twice below.
            expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
            series = align_series_full_broadcast(
                *(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
            )
            # 1 where a value is present, 0 where it is null.
            non_na = align_series_full_broadcast(
                *(1 - s.is_null().cast(dtypes.Int64()) for s in expr_results)
            )
            return [reduce(operator.add, series) / reduce(operator.add, non_na)]

        return self._from_horizontal_func(func, "mean_horizontal", *exprs)

    def min_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
        """Row-wise minimum of `exprs`, computed with `pc.min_element_wise`.

        The result is named after the first input series.
        """
        return self._extreme_element_wise(
            pc.min_element_wise, "min_horizontal", *exprs
        )

    def max_horizontal(self: Self, *exprs: ArrowExpr) -> ArrowExpr:
        """Row-wise maximum of `exprs`, computed with `pc.max_element_wise`.

        The result is named after the first input series.
        """
        return self._extreme_element_wise(
            pc.max_element_wise, "max_horizontal", *exprs
        )

    def concat(
        self: Self,
        items: Iterable[ArrowDataFrame],
        *,
        how: Literal["horizontal", "vertical", "diagonal"],
    ) -> ArrowDataFrame:
        """Concatenate dataframes according to `how`.

        Arguments:
            items: Dataframes to concatenate; must not be empty.
            how: One of ``"horizontal"``, ``"vertical"`` or ``"diagonal"``.

        Returns:
            A new `ArrowDataFrame` wrapping the concatenated table, created
            with column-name validation enabled.

        Raises:
            AssertionError: If `items` is empty.
            NotImplementedError: If `how` is not a supported strategy.
        """
        dfs = [item.native for item in items]
        if not dfs:
            msg = "No dataframes to concatenate"  # pragma: no cover
            raise AssertionError(msg)
        if how == "horizontal":
            result_table = horizontal_concat(dfs)
        elif how == "vertical":
            result_table = vertical_concat(dfs)
        elif how == "diagonal":
            result_table = diagonal_concat(dfs, self._backend_version)
        else:
            # Name the rejected strategy so the failure is diagnosable.
            msg = f"Unsupported concat strategy: {how!r}"
            raise NotImplementedError(msg)
        return ArrowDataFrame(
            result_table,
            backend_version=self._backend_version,
            version=self._version,
            validate_column_names=True,
        )

    @property
    def selectors(self: Self) -> ArrowSelectorNamespace:
        """Selector namespace bound to this namespace."""
        return ArrowSelectorNamespace(self)

    def when(self: Self, predicate: ArrowExpr) -> ArrowWhen:
        """Start a when/then chain from `predicate`."""
        return ArrowWhen.from_expr(predicate, context=self)

    def concat_str(
        self: Self,
        *exprs: ArrowExpr,
        separator: str,
        ignore_nulls: bool,
    ) -> ArrowExpr:
        """Join the series produced by `exprs` row-wise into strings.

        Arguments:
            *exprs: Expressions whose results are joined.
            separator: String placed between joined values.
            ignore_nulls: If true, PyArrow is asked to ``"skip"`` nulls;
                otherwise ``"emit_null"`` is used.

        The result is named after the first input series.
        """

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            compliant_series_list = align_series_full_broadcast(
                *(chain.from_iterable(expr(df) for expr in exprs))
            )
            name = compliant_series_list[0].name
            null_handling: Literal["skip", "emit_null"] = (
                "skip" if ignore_nulls else "emit_null"
            )
            it, separator_scalar = cast_to_comparable_string_types(
                *(s.native for s in compliant_series_list), separator=separator
            )
            # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
            # Reality: `str` is fine
            concat_str: Incomplete = pc.binary_join_element_wise
            compliant = self._series(
                concat_str(*it, separator_scalar, null_handling=null_handling),
                name=name,
                backend_version=self._backend_version,
                version=self._version,
            )
            return [compliant]

        return self._from_horizontal_func(func, "concat_str", *exprs)


class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ArrowChunkedArray"]):
    """When/then/otherwise builder operating on PyArrow chunked arrays."""

    @property
    def _then(self) -> type[ArrowThen]:
        """Class used for the `then` stage of the chain."""
        return ArrowThen

    def _if_then_else(
        self, when: ArrowChunkedArray, then: ArrowChunkedArray, otherwise: Any, /
    ) -> ArrowChunkedArray:
        """Select from `then` where `when` holds, else from `otherwise`.

        A missing `otherwise` (`None`) is replaced by an all-null array with
        the length of `when` and the type of `then`.
        """
        if otherwise is None:
            otherwise = pa.nulls(len(when), then.type)
        return pc.if_else(when, then, otherwise)


class ArrowThen(CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr], ArrowExpr):
    """`then` stage for the PyArrow backend.

    Combines `CompliantThen` with `ArrowExpr`; defines no members of its own.
    """