516 lines
18 KiB
Python
Executable File
516 lines
18 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
from functools import reduce
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Iterable
|
|
from typing import Literal
|
|
from typing import cast
|
|
|
|
from narwhals._arrow.dataframe import ArrowDataFrame
|
|
from narwhals._arrow.expr import ArrowExpr
|
|
from narwhals._arrow.selectors import ArrowSelectorNamespace
|
|
from narwhals._arrow.series import ArrowSeries
|
|
from narwhals._arrow.utils import broadcast_series
|
|
from narwhals._arrow.utils import horizontal_concat
|
|
from narwhals._arrow.utils import vertical_concat
|
|
from narwhals._expression_parsing import combine_root_names
|
|
from narwhals._expression_parsing import parse_into_exprs
|
|
from narwhals._expression_parsing import reduce_output_names
|
|
from narwhals.utils import Implementation
|
|
|
|
if TYPE_CHECKING:
|
|
from typing import Callable
|
|
|
|
from narwhals._arrow.typing import IntoArrowExpr
|
|
from narwhals.dtypes import DType
|
|
from narwhals.typing import DTypes
|
|
|
|
|
|
class ArrowNamespace:
    """Expression namespace for the PyArrow backend.

    Every expression, series and dataframe created through this namespace is
    handed the namespace's ``backend_version`` and ``dtypes`` so that they are
    propagated consistently.
    """

    def _create_expr_from_callable(
        self,
        func: Callable[[ArrowDataFrame], list[ArrowSeries]],
        *,
        depth: int,
        function_name: str,
        root_names: list[str] | None,
        output_names: list[str] | None,
    ) -> ArrowExpr:
        """Wrap ``func`` in an ``ArrowExpr`` carrying this namespace's settings."""
        return ArrowExpr(
            func,
            depth=depth,
            function_name=function_name,
            root_names=root_names,
            output_names=output_names,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def _create_horizontal_expr(
        self,
        func: Callable[[ArrowDataFrame], list[ArrowSeries]],
        parsed_exprs: list[ArrowExpr],
        *,
        function_name: str,
    ) -> ArrowExpr:
        """Build the expression for an operation combining ``parsed_exprs``.

        The depth is one more than the deepest input expression; root and
        output names are derived from the inputs.
        """
        return self._create_expr_from_callable(
            func=func,
            depth=max(x._depth for x in parsed_exprs) + 1,
            function_name=function_name,
            root_names=combine_root_names(parsed_exprs),
            output_names=reduce_output_names(parsed_exprs),
        )

    def _create_expr_from_series(self, series: ArrowSeries) -> ArrowExpr:
        """Return an expression that ignores its dataframe and yields ``series``."""
        return ArrowExpr(
            lambda _df: [series],
            depth=0,
            function_name="series",
            root_names=None,
            output_names=None,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSeries:
        """Return a one-element series holding ``value``, named like ``series``."""
        # On backend versions below 13, values exposing `as_py` are converted
        # to plain Python objects before the series is built.
        if self._backend_version < (13,) and hasattr(value, "as_py"):  # pragma: no cover
            value = value.as_py()
        return ArrowSeries._from_iterable(
            [value],
            name=series.name,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def _create_compliant_series(self, value: Any) -> ArrowSeries:
        """Wrap ``value`` in a single-chunk chunked array as an unnamed series."""
        import pyarrow as pa  # ignore-banned-import()

        return ArrowSeries(
            native_series=pa.chunked_array([value]),
            name="",
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    # --- not in spec ---
    def __init__(self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None:
        """Store the backend version and dtypes passed on to created objects."""
        self._backend_version = backend_version
        self._implementation = Implementation.PYARROW
        self._dtypes = dtypes

    # --- selection ---
    def col(self, *column_names: str) -> ArrowExpr:
        """Return an expression selecting the columns named ``column_names``."""
        return ArrowExpr.from_column_names(
            *column_names, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def nth(self, *column_indices: int) -> ArrowExpr:
        """Return an expression selecting the columns at ``column_indices``."""
        return ArrowExpr.from_column_indices(
            *column_indices, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def len(self) -> ArrowExpr:
        """Return an expression yielding a one-element series named ``"len"``.

        The single value is ``len()`` of the dataframe's native frame.
        """
        # coverage bug? this is definitely hit
        return ArrowExpr(  # pragma: no cover
            lambda df: [
                ArrowSeries._from_iterable(
                    [len(df._native_frame)],
                    name="len",
                    backend_version=self._backend_version,
                    dtypes=self._dtypes,
                )
            ],
            depth=0,
            function_name="len",
            root_names=None,
            output_names=["len"],
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def all(self) -> ArrowExpr:
        """Return an expression yielding every column of the dataframe."""
        return ArrowExpr(
            lambda df: [
                ArrowSeries(
                    df._native_frame[column_name],
                    name=column_name,
                    # The series inherit the settings of the dataframe they
                    # are taken from rather than those of this namespace.
                    backend_version=df._backend_version,
                    dtypes=df._dtypes,
                )
                for column_name in df.columns
            ],
            depth=0,
            function_name="all",
            root_names=None,
            output_names=None,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def lit(self, value: Any, dtype: DType | None) -> ArrowExpr:
        """Return an expression yielding a one-element series named ``"lit"``.

        Args:
            value: The literal value to wrap.
            dtype: If not ``None``, the series is cast to this dtype.
        """

        def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
            arrow_series = ArrowSeries._from_iterable(
                data=[value],
                name="lit",
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )
            # Explicit None check: the cast is requested by passing a dtype,
            # not by the dtype object's truthiness.
            if dtype is not None:
                return arrow_series.cast(dtype)
            return arrow_series

        return ArrowExpr(
            lambda df: [_lit_arrow_series(df)],
            depth=0,
            function_name="lit",
            root_names=None,
            output_names=["lit"],
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def all_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
        """Combine the series of all ``exprs`` with ``&``, left to right."""
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = (s for _expr in parsed_exprs for s in _expr._call(df))
            return [reduce(lambda x, y: x & y, series)]

        return self._create_horizontal_expr(
            func, parsed_exprs, function_name="all_horizontal"
        )

    def any_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
        """Combine the series of all ``exprs`` with ``|``, left to right."""
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = (s for _expr in parsed_exprs for s in _expr._call(df))
            return [reduce(lambda x, y: x | y, series)]

        return self._create_horizontal_expr(
            func, parsed_exprs, function_name="any_horizontal"
        )

    def sum_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
        """Add the series of all ``exprs`` together, with nulls filled by 0."""
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = (s.fill_null(0) for _expr in parsed_exprs for s in _expr._call(df))
            return [reduce(lambda x, y: x + y, series)]

        return self._create_horizontal_expr(
            func, parsed_exprs, function_name="sum_horizontal"
        )

    def mean_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
        """Average the series of all ``exprs`` row-wise over non-null values.

        The numerator is the sum of the series with nulls filled by 0; the
        denominator is the per-row count of non-null entries.
        """
        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            # Evaluate every expression once and reuse the result for both
            # the sum and the non-null count.
            evaluated = [s for _expr in parsed_exprs for s in _expr._call(df)]
            total = reduce(lambda x, y: x + y, (s.fill_null(0) for s in evaluated))
            non_null_count = reduce(
                lambda x, y: x + y,
                (1 - s.is_null().cast(self._dtypes.Int64()) for s in evaluated),
            )
            return [total / non_null_count]

        return self._create_horizontal_expr(
            func, parsed_exprs, function_name="mean_horizontal"
        )

    def min_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
        """Take the element-wise minimum across the series of all ``exprs``.

        The result keeps the name of the first series.
        """
        import pyarrow.compute as pc  # ignore-banned-import

        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            init_series, *series = [s for _expr in parsed_exprs for s in _expr._call(df)]
            return [
                ArrowSeries(
                    native_series=reduce(
                        pc.min_element_wise,
                        [s._native_series for s in series],
                        init_series._native_series,
                    ),
                    name=init_series.name,
                    backend_version=self._backend_version,
                    dtypes=self._dtypes,
                )
            ]

        return self._create_horizontal_expr(
            func, parsed_exprs, function_name="min_horizontal"
        )

    def max_horizontal(self, *exprs: IntoArrowExpr) -> ArrowExpr:
        """Take the element-wise maximum across the series of all ``exprs``.

        The result keeps the name of the first series.
        """
        import pyarrow.compute as pc  # ignore-banned-import

        parsed_exprs = parse_into_exprs(*exprs, namespace=self)

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            init_series, *series = [s for _expr in parsed_exprs for s in _expr._call(df)]
            return [
                ArrowSeries(
                    native_series=reduce(
                        pc.max_element_wise,
                        [s._native_series for s in series],
                        init_series._native_series,
                    ),
                    name=init_series.name,
                    backend_version=self._backend_version,
                    dtypes=self._dtypes,
                )
            ]

        return self._create_horizontal_expr(
            func, parsed_exprs, function_name="max_horizontal"
        )

    def concat(
        self,
        items: Iterable[ArrowDataFrame],
        *,
        how: Literal["horizontal", "vertical"],
    ) -> ArrowDataFrame:
        """Concatenate the native frames of ``items`` into one dataframe.

        Raises:
            NotImplementedError: If ``how`` is neither ``"horizontal"`` nor
                ``"vertical"``.
        """
        dfs: list[Any] = [item._native_frame for item in items]

        if how == "horizontal":
            native_frame = horizontal_concat(dfs)
        elif how == "vertical":
            native_frame = vertical_concat(dfs)
        else:
            raise NotImplementedError

        return ArrowDataFrame(
            native_frame,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def sum(self, *column_names: str) -> ArrowExpr:
        """Return ``col(*column_names).sum()``."""
        return self.col(*column_names).sum()

    def mean(self, *column_names: str) -> ArrowExpr:
        """Return ``col(*column_names).mean()``."""
        return self.col(*column_names).mean()

    def max(self, *column_names: str) -> ArrowExpr:
        """Return ``col(*column_names).max()``."""
        return self.col(*column_names).max()

    def min(self, *column_names: str) -> ArrowExpr:
        """Return ``col(*column_names).min()``."""
        return self.col(*column_names).min()

    @property
    def selectors(self) -> ArrowSelectorNamespace:
        """A selector namespace sharing this namespace's settings."""
        return ArrowSelectorNamespace(
            backend_version=self._backend_version, dtypes=self._dtypes
        )

    def when(
        self,
        *predicates: IntoArrowExpr,
    ) -> ArrowWhen:
        """Start a when/then chain whose condition is all ``predicates`` combined.

        Raises:
            TypeError: If no predicate is given.
        """
        # Validate first so nothing is constructed for an invalid call.
        if not predicates:
            msg = "at least one predicate needs to be provided"
            raise TypeError(msg)

        plx = self.__class__(backend_version=self._backend_version, dtypes=self._dtypes)
        condition = plx.all_horizontal(*predicates)

        return ArrowWhen(condition, self._backend_version, dtypes=self._dtypes)

    def concat_str(
        self,
        exprs: Iterable[IntoArrowExpr],
        *more_exprs: IntoArrowExpr,
        separator: str = "",
        ignore_nulls: bool = False,
    ) -> ArrowExpr:
        """Join the string-cast series of all expressions element-wise.

        Args:
            exprs: Expressions to concatenate.
            *more_exprs: Further expressions, appended after ``exprs``.
            separator: String placed between the joined values.
            ignore_nulls: If true, nulls are skipped (``"skip"``); otherwise
                the join uses ``"emit_null"`` null handling.

        Returns:
            An expression yielding a single series with an empty name.
        """
        import pyarrow.compute as pc  # ignore-banned-import

        parsed_exprs: list[ArrowExpr] = [
            *parse_into_exprs(*exprs, namespace=self),
            *parse_into_exprs(*more_exprs, namespace=self),
        ]

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series = (
                s._native_series
                for _expr in parsed_exprs
                for s in _expr.cast(self._dtypes.String())._call(df)
            )
            null_handling = "skip" if ignore_nulls else "emit_null"
            # The separator is passed as the last positional argument.
            result_series = pc.binary_join_element_wise(
                *series, separator, null_handling=null_handling
            )
            return [
                ArrowSeries(
                    native_series=result_series,
                    name="",
                    backend_version=self._backend_version,
                    dtypes=self._dtypes,
                )
            ]

        return self._create_horizontal_expr(
            func, parsed_exprs, function_name="concat_str"
        )
|
|
|
|
|
|
class ArrowWhen:
    """Callable implementing ``when(condition).then(value).otherwise(value)``.

    Calling an instance with a dataframe evaluates the condition and picks,
    element-wise, the "then" value where it holds and the "otherwise" value
    elsewhere.
    """

    def __init__(
        self,
        condition: ArrowExpr,
        backend_version: tuple[int, ...],
        then_value: Any = None,
        otherwise_value: Any = None,
        *,
        dtypes: DTypes,
    ) -> None:
        """Store the condition, the branch values and the backend settings."""
        self._backend_version = backend_version
        self._condition = condition
        self._then_value = then_value
        self._otherwise_value = otherwise_value
        self._dtypes = dtypes

    def __call__(self, df: ArrowDataFrame) -> list[ArrowSeries]:
        """Evaluate the conditional against ``df``.

        Returns:
            A one-element list holding the selected values, built from the
            "then" series via ``_from_native_series``.
        """
        import pyarrow as pa  # ignore-banned-import
        import pyarrow.compute as pc  # ignore-banned-import

        from narwhals._arrow.namespace import ArrowNamespace
        from narwhals._expression_parsing import parse_into_expr

        plx = ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

        condition = parse_into_expr(self._condition, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
        try:
            value_series = parse_into_expr(self._then_value, namespace=plx)._call(df)[0]  # type: ignore[arg-type]
        except TypeError:
            # `self._then_value` is a scalar and can't be converted to an expression,
            # so repeat it to the length of the condition.
            value_series = condition.__class__._from_iterable(  # type: ignore[call-arg]
                [self._then_value] * len(condition),
                name="literal",
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )
        value_series = cast(ArrowSeries, value_series)

        value_series_native = value_series._native_series
        condition_native = condition._native_series.combine_chunks()

        if self._otherwise_value is None:
            # No `otherwise` given: fall back to nulls of the "then" type.
            otherwise_native = pa.nulls(
                len(condition_native), type=value_series_native.type
            )
        else:
            try:
                otherwise_series = parse_into_expr(
                    self._otherwise_value, namespace=plx
                )._call(df)[0]  # type: ignore[arg-type]
            except TypeError:
                # `self._otherwise_value` is a scalar and can't be converted to an
                # expression; it is handed to `if_else` as-is.
                # Remark that string values _are_ converted into expressions!
                otherwise_native = self._otherwise_value
            else:
                otherwise_series = cast(ArrowSeries, otherwise_series)
                condition = cast(ArrowSeries, condition)
                condition_native, otherwise_native = broadcast_series(
                    [condition, otherwise_series]
                )

        return [
            value_series._from_native_series(
                pc.if_else(condition_native, value_series_native, otherwise_native)
            )
        ]

    def then(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowThen:
        """Record ``value`` as the "then" branch and return an ``ArrowThen``.

        The returned expression uses this object as its callable.
        """
        self._then_value = value

        return ArrowThen(
            self,
            depth=0,
            function_name="whenthen",
            root_names=None,
            output_names=None,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )
|
|
|
|
|
|
class ArrowThen(ArrowExpr):
    """Expression whose callable is an ``ArrowWhen``.

    ``otherwise`` may be used to set the fallback value on that ``ArrowWhen``.
    """

    def __init__(
        self,
        call: ArrowWhen,
        *,
        depth: int,
        function_name: str,
        root_names: list[str] | None,
        output_names: list[str] | None,
        backend_version: tuple[int, ...],
        dtypes: DTypes,
    ) -> None:
        """Set the expression attributes directly from the given arguments."""
        # What gets evaluated, and how the expression describes itself.
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._root_names = root_names
        self._output_names = output_names
        # Backend settings.
        self._backend_version = backend_version
        self._dtypes = dtypes

    def otherwise(self, value: ArrowExpr | ArrowSeries | Any) -> ArrowExpr:
        """Set ``value`` as the fallback branch and return this expression."""
        # type ignore because `_call` holds a callable object of type
        # `ArrowWhen`, while the base class declares the attribute as only a
        # `Callable`
        when = self._call
        when._otherwise_value = value  # type: ignore[attr-defined]
        self._function_name = "whenotherwise"
        return self
|