1138 lines
40 KiB
Python
Executable File
1138 lines
40 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Iterable
|
|
from typing import Iterator
|
|
from typing import Literal
|
|
from typing import Sequence
|
|
from typing import overload
|
|
|
|
from narwhals._arrow.utils import cast_for_truediv
|
|
from narwhals._arrow.utils import floordiv_compat
|
|
from narwhals._arrow.utils import narwhals_to_native_dtype
|
|
from narwhals._arrow.utils import native_to_narwhals_dtype
|
|
from narwhals._arrow.utils import parse_datetime_format
|
|
from narwhals._arrow.utils import validate_column_comparand
|
|
from narwhals.utils import Implementation
|
|
from narwhals.utils import generate_temporary_column_name
|
|
|
|
if TYPE_CHECKING:
|
|
from types import ModuleType
|
|
|
|
import pyarrow as pa
|
|
from typing_extensions import Self
|
|
|
|
from narwhals._arrow.dataframe import ArrowDataFrame
|
|
from narwhals._arrow.namespace import ArrowNamespace
|
|
from narwhals.dtypes import DType
|
|
from narwhals.typing import DTypes
|
|
|
|
|
|
class ArrowSeries:
    """Narwhals-compliant Series backed by a ``pyarrow.ChunkedArray``.

    pyarrow (and numpy / pandas where needed) are imported lazily inside the
    methods that use them, so importing this module does not import pyarrow.
    """

    def __init__(
        self,
        native_series: pa.ChunkedArray,
        *,
        name: str,
        backend_version: tuple[int, ...],
        dtypes: DTypes,
    ) -> None:
        """Wrap ``native_series`` under ``name``.

        ``backend_version`` and ``dtypes`` are stored as-is and forwarded to
        every object derived from this series.
        """
        self._name = name
        self._native_series = native_series
        self._implementation = Implementation.PYARROW
        self._backend_version = backend_version
        self._dtypes = dtypes

    def _from_native_series(self, series: Any) -> Self:
        """Wrap a native result, keeping this series' name and settings.

        A plain ``pa.Array`` is promoted to a single-chunk ``ChunkedArray``.
        """
        import pyarrow as pa  # ignore-banned-import()

        if isinstance(series, pa.Array):
            series = pa.chunked_array([series])
        return self.__class__(
            series,
            name=self._name,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    @classmethod
    def _from_iterable(
        cls: type[Self],
        data: Iterable[Any],
        name: str,
        *,
        backend_version: tuple[int, ...],
        dtypes: DTypes,
    ) -> Self:
        """Build a series from ``data``, stored as a single chunk."""
        import pyarrow as pa  # ignore-banned-import()

        return cls(
            pa.chunked_array([data]),
            name=name,
            backend_version=backend_version,
            dtypes=dtypes,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        """Return the Arrow namespace configured like this series."""
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

    def __len__(self) -> int:
        """Return the number of elements."""
        return len(self._native_series)

    # --- comparisons -------------------------------------------------------
    # Every binary operator first passes ``other`` through
    # ``validate_column_comparand`` and then hands the native operands to the
    # matching ``pyarrow.compute`` kernel.

    def __eq__(self, other: object) -> Self:  # type: ignore[override]
        """Element-wise ``self == other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.equal(ser, other))

    def __ne__(self, other: object) -> Self:  # type: ignore[override]
        """Element-wise ``self != other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.not_equal(ser, other))

    def __ge__(self, other: Any) -> Self:
        """Element-wise ``self >= other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.greater_equal(ser, other))

    def __gt__(self, other: Any) -> Self:
        """Element-wise ``self > other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.greater(ser, other))

    def __le__(self, other: Any) -> Self:
        """Element-wise ``self <= other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.less_equal(ser, other))

    def __lt__(self, other: Any) -> Self:
        """Element-wise ``self < other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.less(ser, other))

    # --- boolean logic (Kleene: null-aware) --------------------------------

    def __and__(self, other: Any) -> Self:
        """Element-wise ``self & other`` using Kleene logic."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.and_kleene(ser, other))

    def __rand__(self, other: Any) -> Self:
        """Element-wise ``other & self`` using Kleene logic."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.and_kleene(other, ser))

    def __or__(self, other: Any) -> Self:
        """Element-wise ``self | other`` using Kleene logic."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.or_kleene(ser, other))

    def __ror__(self, other: Any) -> Self:
        """Element-wise ``other | self`` using Kleene logic."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.or_kleene(other, ser))

    # --- arithmetic ---------------------------------------------------------

    def __add__(self, other: Any) -> Self:
        """Element-wise ``self + other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        other = validate_column_comparand(other)
        return self._from_native_series(pc.add(self._native_series, other))

    def __radd__(self, other: Any) -> Self:
        """Element-wise ``other + self`` (delegates to ``self + other``)."""
        return self + other  # type: ignore[no-any-return]

    def __sub__(self, other: Any) -> Self:
        """Element-wise ``self - other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        other = validate_column_comparand(other)
        return self._from_native_series(pc.subtract(self._native_series, other))

    def __rsub__(self, other: Any) -> Self:
        """Element-wise ``other - self``.

        Computed directly with swapped operands, like the other reflected
        operators, instead of negating ``self - other``.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        other = validate_column_comparand(other)
        return self._from_native_series(pc.subtract(other, self._native_series))

    def __mul__(self, other: Any) -> Self:
        """Element-wise ``self * other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        other = validate_column_comparand(other)
        return self._from_native_series(pc.multiply(self._native_series, other))

    def __rmul__(self, other: Any) -> Self:
        """Element-wise ``other * self`` (delegates to ``self * other``)."""
        return self * other  # type: ignore[no-any-return]

    def __pow__(self, other: Any) -> Self:
        """Element-wise ``self ** other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.power(ser, other))

    def __rpow__(self, other: Any) -> Self:
        """Element-wise ``other ** self``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(pc.power(other, ser))

    def __floordiv__(self, other: Any) -> Self:
        """Element-wise ``self // other`` via ``floordiv_compat``."""
        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(floordiv_compat(ser, other))

    def __rfloordiv__(self, other: Any) -> Self:
        """Element-wise ``other // self`` via ``floordiv_compat``."""
        ser = self._native_series
        other = validate_column_comparand(other)
        return self._from_native_series(floordiv_compat(other, ser))

    def __truediv__(self, other: Any) -> Self:
        """Element-wise ``self / other``.

        Both operands go through ``cast_for_truediv`` before dividing.
        """
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        if not isinstance(other, (pa.Array, pa.ChunkedArray)):
            # scalar
            other = pa.scalar(other)
        return self._from_native_series(pc.divide(*cast_for_truediv(ser, other)))

    def __rtruediv__(self, other: Any) -> Self:
        """Element-wise ``other / self``.

        Both operands go through ``cast_for_truediv`` before dividing.
        """
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        if not isinstance(other, (pa.Array, pa.ChunkedArray)):
            # scalar
            other = pa.scalar(other)
        return self._from_native_series(pc.divide(*cast_for_truediv(other, ser)))

    def __mod__(self, other: Any) -> Self:
        """Element-wise ``self % other``, as ``self - (self // other) * other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        floor_div = (self // other)._native_series
        res = pc.subtract(ser, pc.multiply(floor_div, other))
        return self._from_native_series(res)

    def __rmod__(self, other: Any) -> Self:
        """Element-wise ``other % self``, as ``other - (other // self) * self``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        other = validate_column_comparand(other)
        floor_div = (other // self)._native_series
        res = pc.subtract(other, pc.multiply(floor_div, ser))
        return self._from_native_series(res)

    def __invert__(self) -> Self:
        """Element-wise logical negation."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._from_native_series(pc.invert(self._native_series))

    def len(self) -> int:
        """Return the number of elements."""
        return len(self._native_series)

    def filter(self, other: Any) -> Self:
        """Keep the elements selected by the boolean mask ``other``.

        A plain list of Python bools is passed to pyarrow unchanged; anything
        else goes through ``validate_column_comparand`` first.
        """
        if not (isinstance(other, list) and all(isinstance(x, bool) for x in other)):
            other = validate_column_comparand(other)
        return self._from_native_series(self._native_series.filter(other))

    # --- aggregations -------------------------------------------------------
    # These return whatever the pyarrow kernel returns (a pyarrow scalar);
    # no conversion to a Python value happens here.

    def mean(self) -> float:
        """Return the mean as computed by ``pc.mean``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.mean(self._native_series)  # type: ignore[no-any-return]

    def min(self) -> int:
        """Return the minimum as computed by ``pc.min``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.min(self._native_series)  # type: ignore[no-any-return]

    def max(self) -> int:
        """Return the maximum as computed by ``pc.max``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.max(self._native_series)  # type: ignore[no-any-return]

    def sum(self) -> int:
        """Return the sum as computed by ``pc.sum``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.sum(self._native_series)  # type: ignore[no-any-return]

    def drop_nulls(self) -> ArrowSeries:
        """Return the series without its null elements."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._from_native_series(pc.drop_null(self._native_series))

    def shift(self, n: int) -> Self:
        """Shift elements by ``n`` positions, filling vacated slots with nulls.

        Positive ``n`` moves values towards the end, negative ``n`` towards
        the start. The result always has the same length as the input: a
        shift larger than the series yields an all-null series.
        """
        import pyarrow as pa  # ignore-banned-import()

        ca = self._native_series
        if n == 0:
            return self._from_native_series(ca)

        size = len(ca)
        # Never pad with more nulls than there are elements, otherwise the
        # result would be longer than the input.
        pad = min(abs(n), size)
        nulls = pa.nulls(pad, ca.type)
        if n > 0:
            result = pa.concat_arrays([nulls, *ca.slice(0, size - pad).chunks])
        else:
            result = pa.concat_arrays([*ca.slice(pad).chunks, nulls])
        return self._from_native_series(result)

    def std(self, ddof: int = 1) -> float:
        """Return the standard deviation (``pc.stddev``) with ``ddof``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.stddev(self._native_series, ddof=ddof)  # type: ignore[no-any-return]

    def count(self) -> int:
        """Return ``pc.count`` of the series (default mode: non-null values)."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.count(self._native_series)  # type: ignore[no-any-return]

    def n_unique(self) -> int:
        """Return the number of distinct values; null counts as a value."""
        import pyarrow.compute as pc  # ignore-banned-import()

        unique_values = pc.unique(self._native_series)
        return pc.count(unique_values, mode="all")  # type: ignore[no-any-return]

    def __native_namespace__(self: Self) -> ModuleType:
        """Return the native (pyarrow) module for this implementation."""
        if self._implementation is Implementation.PYARROW:
            return self._implementation.to_native_namespace()

        msg = f"Expected pyarrow, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    @property
    def name(self) -> str:
        """Name of the series."""
        return self._name

    def __narwhals_series__(self) -> Self:
        """Return ``self`` (narwhals series protocol hook)."""
        return self

    @overload
    def __getitem__(self, idx: int) -> Any: ...

    @overload
    def __getitem__(self, idx: slice | Sequence[int]) -> Self: ...

    def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self:
        """Index by position.

        An ``int`` returns the native element; a sequence of positions uses
        ``take``; anything else (e.g. a slice) is forwarded to the native
        ``__getitem__``. The latter two return a new series.
        """
        if isinstance(idx, int):
            return self._native_series[idx]
        if isinstance(idx, Sequence):
            return self._from_native_series(self._native_series.take(idx))
        return self._from_native_series(self._native_series[idx])

    def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
        """Return a copy where ``values[i]`` is written at position ``indices[i]``.

        ``indices`` may be a single position or a sequence of positions
        (negative positions count from the end). ``values`` may be another
        series of this class, a pyarrow array, or anything ``pa.array``
        accepts; with a single ``int`` index a bare scalar is accepted too.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow as pa  # ignore-banned-import
        import pyarrow.compute as pc  # ignore-banned-import

        ca = self._native_series
        size = len(ca)

        if isinstance(values, self.__class__):
            values = values._native_series
        if isinstance(values, pa.ChunkedArray):
            values = values.combine_chunks()
        if not isinstance(values, pa.Array):
            single_value = isinstance(indices, int) and not isinstance(
                values, (list, tuple, np.ndarray)
            )
            values = pa.array([values] if single_value else values)

        positions = np.atleast_1d(np.asarray(indices, dtype=np.int64))
        positions = np.where(positions < 0, positions + size, positions)

        mask = np.zeros(size, dtype=bool)
        mask[positions] = True

        # ``replace_with_mask`` consumes replacements in the order the True
        # slots appear in the mask, i.e. by ascending position, so the values
        # must be reordered to follow their sorted target positions.
        order = np.argsort(positions, kind="stable")
        result = pc.replace_with_mask(ca, mask, values.take(pa.array(order)))
        return self._from_native_series(result)

    def to_list(self) -> Any:
        """Return the elements as a Python list."""
        return self._native_series.to_pylist()

    def __array__(self, dtype: Any = None, copy: bool | None = None) -> Any:
        """NumPy array protocol; forwarded to the native ``__array__``."""
        return self._native_series.__array__(dtype=dtype, copy=copy)

    def to_numpy(self) -> Any:
        """Return the elements as a NumPy array."""
        return self._native_series.to_numpy()

    def alias(self, name: str) -> Self:
        """Return the same data under a new ``name``."""
        return self.__class__(
            self._native_series,
            name=name,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    @property
    def dtype(self: Self) -> DType:
        """Narwhals dtype corresponding to the native Arrow type."""
        return native_to_narwhals_dtype(self._native_series.type, self._dtypes)

    def abs(self) -> Self:
        """Element-wise absolute value."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._from_native_series(pc.abs(self._native_series))

    def cum_sum(self) -> Self:
        """Cumulative sum."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._from_native_series(pc.cumulative_sum(self._native_series))

    def round(self, decimals: int) -> Self:
        """Round to ``decimals`` places, ties going towards infinity."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._from_native_series(
            pc.round(self._native_series, decimals, round_mode="half_towards_infinity")
        )

    def diff(self) -> Self:
        """Difference between each element and its predecessor."""
        import pyarrow.compute as pc  # ignore-banned-import()

        # pairwise_diff is applied to a single contiguous array.
        return self._from_native_series(
            pc.pairwise_diff(self._native_series.combine_chunks())
        )

    def any(self) -> bool:
        """Return ``pc.any`` of the series."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.any(self._native_series)  # type: ignore[no-any-return]

    def all(self) -> bool:
        """Return ``pc.all`` of the series."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.all(self._native_series)  # type: ignore[no-any-return]

    def is_between(
        self, lower_bound: Any, upper_bound: Any, closed: str = "both"
    ) -> Self:
        """Element-wise test ``lower_bound <= self <= upper_bound``.

        ``closed`` selects which ends are inclusive: ``"both"``, ``"left"``,
        ``"right"`` or ``"none"``. Bounds are unwrapped with
        ``validate_column_comparand`` like the operands of the comparison
        operators.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        if closed not in {"left", "right", "none", "both"}:  # pragma: no cover
            raise AssertionError

        ser = self._native_series
        lower_bound = validate_column_comparand(lower_bound)
        upper_bound = validate_column_comparand(upper_bound)

        above = pc.greater_equal if closed in {"left", "both"} else pc.greater
        below = pc.less_equal if closed in {"right", "both"} else pc.less
        res = pc.and_kleene(above(ser, lower_bound), below(ser, upper_bound))
        return self._from_native_series(res)

    def is_empty(self) -> bool:
        """Return whether the series has no elements."""
        return len(self) == 0

    def is_null(self) -> Self:
        """Element-wise null test."""
        ser = self._native_series
        return self._from_native_series(ser.is_null())

    def cast(self, dtype: DType) -> Self:
        """Cast to the Arrow type corresponding to the narwhals ``dtype``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        dtype = narwhals_to_native_dtype(dtype, self._dtypes)
        return self._from_native_series(pc.cast(ser, dtype))

    def null_count(self: Self) -> int:
        """Return the number of null elements."""
        return self._native_series.null_count  # type: ignore[no-any-return]

    def head(self, n: int) -> Self:
        """Return the first ``n`` elements; negative ``n`` drops the last ``|n|``."""
        ser = self._native_series
        if n >= 0:
            return self._from_native_series(ser.slice(0, n))
        else:
            num_rows = len(ser)
            return self._from_native_series(ser.slice(0, max(0, num_rows + n)))

    def tail(self, n: int) -> Self:
        """Return the last ``n`` elements; negative ``n`` drops the first ``|n|``."""
        ser = self._native_series
        if n >= 0:
            num_rows = len(ser)
            return self._from_native_series(ser.slice(max(0, num_rows - n)))
        else:
            return self._from_native_series(ser.slice(abs(n)))

    def is_in(self, other: Any) -> Self:
        """Element-wise membership test against the values in ``other``."""
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        value_set = pa.array(other)
        ser = self._native_series
        return self._from_native_series(pc.is_in(ser, value_set=value_set))

    def arg_true(self) -> Self:
        """Return the positions of the non-zero (true) elements."""
        import numpy as np  # ignore-banned-import

        ser = self._native_series
        res = np.flatnonzero(ser)
        return self._from_iterable(
            res,
            name=self.name,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def item(self: Self, index: int | None = None) -> Any:
        """Return one native element.

        Without ``index`` the series must have exactly one element.

        Raises:
            ValueError: if ``index`` is None and the length is not 1.
        """
        if index is None:
            if len(self) != 1:
                msg = (
                    "can only call '.item()' if the Series is of length 1,"
                    f" or an explicit index is provided (Series is of length {len(self)})"
                )
                raise ValueError(msg)
            return self._native_series[0]
        return self._native_series[index]

    def value_counts(
        self: Self,
        *,
        sort: bool = False,
        parallel: bool = False,
        name: str | None = None,
        normalize: bool = False,
    ) -> ArrowDataFrame:
        """Count occurrences of each distinct value.

        The result has two columns: the values (named after the series, or
        ``"index"`` when the series name is None) and the counts (named
        ``name``, defaulting to ``"proportion"`` if ``normalize`` else
        ``"count"``). With ``normalize`` the counts are divided by their sum;
        with ``sort`` rows are ordered by descending count.

        ``parallel`` is unused and exists for compatibility.
        """
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        from narwhals._arrow.dataframe import ArrowDataFrame

        index_name_ = "index" if self._name is None else self._name
        value_name_ = name or ("proportion" if normalize else "count")

        val_count = pc.value_counts(self._native_series)
        values = val_count.field("values")
        counts = val_count.field("counts")

        if normalize:
            counts = pc.divide(*cast_for_truediv(counts, pc.sum(counts)))

        val_count = pa.Table.from_arrays(
            [values, counts], names=[index_name_, value_name_]
        )

        if sort:
            val_count = val_count.sort_by([(value_name_, "descending")])

        return ArrowDataFrame(
            val_count, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def zip_with(self: Self, mask: Self, other: Self) -> Self:
        """Take elements from ``self`` where ``mask`` is true, else from ``other``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        mask = mask._native_series.combine_chunks()
        return self._from_native_series(
            pc.if_else(
                mask,
                self._native_series,
                other._native_series,
            )
        )

    def sample(
        self: Self,
        n: int | None = None,
        *,
        fraction: float | None = None,
        with_replacement: bool = False,
        seed: int | None = None,
    ) -> Self:
        """Return randomly chosen elements.

        When ``n`` is None and ``fraction`` is given, ``n`` becomes
        ``int(len(self) * fraction)``. Positions are drawn with a NumPy
        generator seeded with ``seed``.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        num_rows = len(self)

        if n is None and fraction is not None:
            n = int(num_rows * fraction)

        rng = np.random.default_rng(seed=seed)
        idx = np.arange(0, num_rows)
        mask = rng.choice(idx, size=n, replace=with_replacement)

        return self._from_native_series(pc.take(ser, mask))

    def fill_null(self: Self, value: Any) -> Self:
        """Replace nulls with ``value``, converted to the series' Arrow type."""
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        dtype = ser.type

        return self._from_native_series(pc.fill_null(ser, pa.scalar(value, dtype)))

    def to_frame(self: Self) -> ArrowDataFrame:
        """Return a one-column dataframe holding this series."""
        import pyarrow as pa  # ignore-banned-import()

        from narwhals._arrow.dataframe import ArrowDataFrame

        df = pa.Table.from_arrays([self._native_series], names=[self.name])
        return ArrowDataFrame(
            df, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def to_pandas(self: Self) -> Any:
        """Return a ``pandas.Series`` with the same name."""
        import pandas as pd  # ignore-banned-import()

        return pd.Series(self._native_series, name=self.name)

    def is_duplicated(self: Self) -> ArrowSeries:
        """Element-wise duplicate test, delegated to the dataframe implementation."""
        return self.to_frame().is_duplicated().alias(self.name)

    def is_unique(self: Self) -> ArrowSeries:
        """Element-wise uniqueness test, delegated to the dataframe implementation."""
        return self.to_frame().is_unique().alias(self.name)

    def _is_distinct_at(self: Self, aggregation: str) -> Self:
        """Mark, per distinct value, the row picked by ``aggregation``.

        Rows are numbered, grouped by value, and the ``aggregation``
        (``"min"`` or ``"max"``) of the row numbers in each group is
        computed; a row is marked when its number is one of those results.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        row_number = pa.array(np.arange(len(self)))
        col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
        selected_index = (
            pa.Table.from_arrays([self._native_series], names=[self.name])
            .append_column(col_token, row_number)
            .group_by(self.name)
            .aggregate([(col_token, aggregation)])
            .column(f"{col_token}_{aggregation}")
        )

        return self._from_native_series(pc.is_in(row_number, selected_index))

    def is_first_distinct(self: Self) -> Self:
        """Mark the first occurrence of each distinct value."""
        return self._is_distinct_at("min")

    def is_last_distinct(self: Self) -> Self:
        """Mark the last occurrence of each distinct value."""
        return self._is_distinct_at("max")

    def is_sorted(self: Self, *, descending: bool = False) -> bool:
        """Return whether consecutive elements are ordered.

        Raises:
            TypeError: if ``descending`` is not a bool.
        """
        if not isinstance(descending, bool):
            msg = f"argument 'descending' should be boolean, found {type(descending)}"
            raise TypeError(msg)
        import pyarrow.compute as pc  # ignore-banned-import()

        ser = self._native_series
        # Compare every element with its successor.
        if descending:
            return pc.all(pc.greater_equal(ser[:-1], ser[1:]))  # type: ignore[no-any-return]
        else:
            return pc.all(pc.less_equal(ser[:-1], ser[1:]))  # type: ignore[no-any-return]

    def unique(self: Self) -> ArrowSeries:
        """Return the distinct values."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._from_native_series(pc.unique(self._native_series))

    def sort(
        self: Self, *, descending: bool = False, nulls_last: bool = False
    ) -> ArrowSeries:
        """Return the sorted series; nulls go first unless ``nulls_last``."""
        import pyarrow.compute as pc  # ignore-banned-import()

        series = self._native_series
        order = "descending" if descending else "ascending"
        null_placement = "at_end" if nulls_last else "at_start"
        sorted_indices = pc.array_sort_indices(
            series, order=order, null_placement=null_placement
        )

        return self._from_native_series(pc.take(series, sorted_indices))

    def to_dummies(
        self: Self, *, separator: str = "_", drop_first: bool = False
    ) -> ArrowDataFrame:
        """One-hot encode the series.

        Produces one ``uint8`` column per distinct value, named
        ``f"{name}{separator}{value}"``. Columns are returned sorted by name;
        ``drop_first`` omits the first of them.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow as pa  # ignore-banned-import()

        from narwhals._arrow.dataframe import ArrowDataFrame

        series = self._native_series
        da = series.dictionary_encode().combine_chunks()

        # One row per dictionary entry, one column per element: set a 1 where
        # the element's dictionary index matches the row.
        columns = np.zeros((len(da.dictionary), len(da)), np.uint8)
        columns[da.indices, np.arange(len(da))] = 1
        names = [f"{self._name}{separator}{v}" for v in da.dictionary]

        return ArrowDataFrame(
            pa.Table.from_arrays(columns, names=names),
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        ).select(*sorted(names)[int(drop_first) :])

    def quantile(
        self: Self,
        quantile: float,
        interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"],
    ) -> Any:
        """Return the requested quantile (first element of ``pc.quantile``)."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[
            0
        ]

    def gather_every(self: Self, n: int, offset: int = 0) -> Self:
        """Return every ``n``-th element starting at ``offset``."""
        return self._from_native_series(self._native_series[offset::n])

    def clip(
        self: Self, lower_bound: Any | None = None, upper_bound: Any | None = None
    ) -> Self:
        """Limit values to the range ``[lower_bound, upper_bound]``.

        Each bound is converted to a scalar of the series' Arrow type before
        the element-wise max / min is taken.
        """
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        arr = self._native_series
        arr = pc.max_element_wise(arr, pa.scalar(lower_bound, type=arr.type))
        arr = pc.min_element_wise(arr, pa.scalar(upper_bound, type=arr.type))

        return self._from_native_series(arr)

    def to_arrow(self: Self) -> pa.Array:
        """Return the data as a single contiguous ``pa.Array``."""
        return self._native_series.combine_chunks()

    def mode(self: Self) -> ArrowSeries:
        """Return the most frequent value(s).

        Computed from ``value_counts`` by keeping the rows whose count equals
        the maximum count.
        """
        plx = self.__narwhals_namespace__()
        col_token = generate_temporary_column_name(n_bytes=8, columns=[self.name])
        return self.value_counts(name=col_token, normalize=False).filter(
            plx.col(col_token) == plx.col(col_token).max()
        )[self.name]

    def __iter__(self: Self) -> Iterator[Any]:
        """Iterate over the native elements."""
        yield from self._native_series.__iter__()

    @property
    def shape(self) -> tuple[int]:
        """One-element tuple holding the length."""
        return (len(self._native_series),)

    @property
    def dt(self) -> ArrowSeriesDateTimeNamespace:
        """Datetime / duration operations."""
        return ArrowSeriesDateTimeNamespace(self)

    @property
    def cat(self) -> ArrowSeriesCatNamespace:
        """Categorical operations."""
        return ArrowSeriesCatNamespace(self)

    @property
    def str(self) -> ArrowSeriesStringNamespace:
        """String operations."""
        return ArrowSeriesStringNamespace(self)
|
|
|
|
|
|
class ArrowSeriesDateTimeNamespace:
|
|
def __init__(self: Self, series: ArrowSeries) -> None:
|
|
self._arrow_series = series
|
|
|
|
def to_string(self: Self, format: str) -> ArrowSeries: # noqa: A002
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
# PyArrow differs from other libraries in that %S also prints out
|
|
# the fractional part of the second...:'(
|
|
# https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html
|
|
format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
|
|
return self._arrow_series._from_native_series(
|
|
pc.strftime(self._arrow_series._native_series, format)
|
|
)
|
|
|
|
def replace_time_zone(self: Self, time_zone: str | None) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
if time_zone is not None:
|
|
result = pc.assume_timezone(
|
|
pc.local_timestamp(self._arrow_series._native_series), time_zone
|
|
)
|
|
else:
|
|
result = pc.local_timestamp(self._arrow_series._native_series)
|
|
return self._arrow_series._from_native_series(result)
|
|
|
|
def convert_time_zone(self: Self, time_zone: str) -> ArrowSeries:
|
|
import pyarrow as pa # ignore-banned-import
|
|
|
|
if self._arrow_series.dtype.time_zone is None: # type: ignore[attr-defined]
|
|
result = self.replace_time_zone("UTC")._native_series.cast(
|
|
pa.timestamp(self._arrow_series._native_series.type.unit, time_zone)
|
|
)
|
|
else:
|
|
result = self._arrow_series._native_series.cast(
|
|
pa.timestamp(self._arrow_series._native_series.type.unit, time_zone)
|
|
)
|
|
|
|
return self._arrow_series._from_native_series(result)
|
|
|
|
def timestamp(self: Self, time_unit: Literal["ns", "us", "ms"] = "us") -> ArrowSeries:
|
|
import pyarrow as pa # ignore-banned-import
|
|
import pyarrow.compute as pc # ignore-banned-import
|
|
|
|
s = self._arrow_series._native_series
|
|
dtype = self._arrow_series.dtype
|
|
if dtype == self._arrow_series._dtypes.Datetime:
|
|
unit = dtype.time_unit # type: ignore[attr-defined]
|
|
s_cast = s.cast(pa.int64())
|
|
if unit == "ns":
|
|
if time_unit == "ns":
|
|
result = s_cast
|
|
elif time_unit == "us":
|
|
result = floordiv_compat(s_cast, 1_000)
|
|
else:
|
|
result = floordiv_compat(s_cast, 1_000_000)
|
|
elif unit == "us":
|
|
if time_unit == "ns":
|
|
result = pc.multiply(s_cast, 1_000)
|
|
elif time_unit == "us":
|
|
result = s_cast
|
|
else:
|
|
result = floordiv_compat(s_cast, 1_000)
|
|
elif unit == "ms":
|
|
if time_unit == "ns":
|
|
result = pc.multiply(s_cast, 1_000_000)
|
|
elif time_unit == "us":
|
|
result = pc.multiply(s_cast, 1_000)
|
|
else:
|
|
result = s_cast
|
|
elif unit == "s":
|
|
if time_unit == "ns":
|
|
result = pc.multiply(s_cast, 1_000_000_000)
|
|
elif time_unit == "us":
|
|
result = pc.multiply(s_cast, 1_000_000)
|
|
else:
|
|
result = pc.multiply(s_cast, 1_000)
|
|
else: # pragma: no cover
|
|
msg = f"unexpected time unit {unit}, please report an issue at https://github.com/narwhals-dev/narwhals"
|
|
raise AssertionError(msg)
|
|
elif dtype == self._arrow_series._dtypes.Date:
|
|
time_s = pc.multiply(s.cast(pa.int32()), 86400)
|
|
if time_unit == "ns":
|
|
result = pc.multiply(time_s, 1_000_000_000)
|
|
elif time_unit == "us":
|
|
result = pc.multiply(time_s, 1_000_000)
|
|
else:
|
|
result = pc.multiply(time_s, 1_000)
|
|
else:
|
|
msg = "Input should be either of Date or Datetime type"
|
|
raise TypeError(msg)
|
|
return self._arrow_series._from_native_series(result)
|
|
|
|
def date(self: Self) -> ArrowSeries:
|
|
import pyarrow as pa # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
self._arrow_series._native_series.cast(pa.date32())
|
|
)
|
|
|
|
def year(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.year(self._arrow_series._native_series)
|
|
)
|
|
|
|
def month(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.month(self._arrow_series._native_series)
|
|
)
|
|
|
|
def day(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.day(self._arrow_series._native_series)
|
|
)
|
|
|
|
def hour(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.hour(self._arrow_series._native_series)
|
|
)
|
|
|
|
def minute(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.minute(self._arrow_series._native_series)
|
|
)
|
|
|
|
def second(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.second(self._arrow_series._native_series)
|
|
)
|
|
|
|
def millisecond(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.millisecond(self._arrow_series._native_series)
|
|
)
|
|
|
|
def microsecond(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
arr = self._arrow_series._native_series
|
|
result = pc.add(pc.multiply(pc.millisecond(arr), 1000), pc.microsecond(arr))
|
|
|
|
return self._arrow_series._from_native_series(result)
|
|
|
|
def nanosecond(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
arr = self._arrow_series._native_series
|
|
result = pc.add(
|
|
pc.multiply(self.microsecond()._native_series, 1000), pc.nanosecond(arr)
|
|
)
|
|
return self._arrow_series._from_native_series(result)
|
|
|
|
def ordinal_day(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.day_of_year(self._arrow_series._native_series)
|
|
)
|
|
|
|
def total_minutes(self: Self) -> ArrowSeries:
|
|
import pyarrow as pa # ignore-banned-import()
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
arr = self._arrow_series._native_series
|
|
unit = arr.type.unit
|
|
|
|
unit_to_minutes_factor = {
|
|
"s": 60, # seconds
|
|
"ms": 60 * 1e3, # milli
|
|
"us": 60 * 1e6, # micro
|
|
"ns": 60 * 1e9, # nano
|
|
}
|
|
|
|
factor = pa.scalar(unit_to_minutes_factor[unit], type=pa.int64())
|
|
return self._arrow_series._from_native_series(
|
|
pc.cast(pc.divide(arr, factor), pa.int64())
|
|
)
|
|
|
|
def total_seconds(self: Self) -> ArrowSeries:
|
|
import pyarrow as pa # ignore-banned-import()
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
arr = self._arrow_series._native_series
|
|
unit = arr.type.unit
|
|
|
|
unit_to_seconds_factor = {
|
|
"s": 1, # seconds
|
|
"ms": 1e3, # milli
|
|
"us": 1e6, # micro
|
|
"ns": 1e9, # nano
|
|
}
|
|
factor = pa.scalar(unit_to_seconds_factor[unit], type=pa.int64())
|
|
|
|
return self._arrow_series._from_native_series(
|
|
pc.cast(pc.divide(arr, factor), pa.int64())
|
|
)
|
|
|
|
def total_milliseconds(self: Self) -> ArrowSeries:
    """Return each value converted to milliseconds, as an int64 series.

    Values stored in seconds are multiplied by the conversion factor; all
    other supported units (``"ms"``, ``"us"``, ``"ns"``) are divided by it.
    An unsupported unit raises ``KeyError``. Assumes the native array has
    a type exposing ``.unit`` (e.g. a duration type).
    """
    import pyarrow as pa  # ignore-banned-import()
    import pyarrow.compute as pc  # ignore-banned-import()

    native = self._arrow_series._native_series
    unit = native.type.unit
    # Ratio between one millisecond and one native unit (whichever is larger
    # goes on top, so every entry is a whole number).
    milli_ratio = {
        "s": 1_000,
        "ms": 1,
        "us": 1_000,
        "ns": 1_000_000,
    }
    factor = pa.scalar(milli_ratio[unit], type=pa.int64())
    # Seconds are coarser than the target, so scale up; the rest scale down.
    rescale = pc.multiply if unit == "s" else pc.divide
    millis = pc.cast(rescale(native, factor), pa.int64())
    return self._arrow_series._from_native_series(millis)
|
|
|
|
def total_microseconds(self: Self) -> ArrowSeries:
    """Return each value converted to microseconds, as an int64 series.

    Values stored in seconds or milliseconds are multiplied by the
    conversion factor; ``"us"`` and ``"ns"`` values are divided by it.
    An unsupported unit raises ``KeyError``. Assumes the native array has
    a type exposing ``.unit`` (e.g. a duration type).
    """
    import pyarrow as pa  # ignore-banned-import()
    import pyarrow.compute as pc  # ignore-banned-import()

    native = self._arrow_series._native_series
    unit = native.type.unit
    # Ratio between one microsecond and one native unit (always a whole number).
    micro_ratio = {
        "s": 1_000_000,
        "ms": 1_000,
        "us": 1,
        "ns": 1_000,
    }
    factor = pa.scalar(micro_ratio[unit], type=pa.int64())
    # Units coarser than the target scale up; the rest scale down.
    rescale = pc.multiply if unit in {"s", "ms"} else pc.divide
    micros = pc.cast(rescale(native, factor), pa.int64())
    return self._arrow_series._from_native_series(micros)
|
|
|
|
def total_nanoseconds(self: Self) -> ArrowSeries:
    """Return each value converted to nanoseconds, as an int64 series.

    The native values are multiplied by a factor chosen from the time unit
    of the array's type (``"s"``, ``"ms"``, ``"us"`` or ``"ns"``); any
    other unit raises ``KeyError``. Assumes the native array has a type
    exposing ``.unit`` (e.g. a duration type).
    """
    import pyarrow as pa  # ignore-banned-import()
    import pyarrow.compute as pc  # ignore-banned-import()

    native = self._arrow_series._native_series
    # How many nanoseconds fit in one native unit.
    nanos_per_unit = {
        "s": 1_000_000_000,
        "ms": 1_000_000,
        "us": 1_000,
        "ns": 1,
    }
    multiplier = pa.scalar(nanos_per_unit[native.type.unit], type=pa.int64())
    nanos = pc.cast(pc.multiply(native, multiplier), pa.int64())
    return self._arrow_series._from_native_series(nanos)
|
|
|
|
|
|
class ArrowSeriesCatNamespace:
    """Categorical operations for an ``ArrowSeries``."""

    def __init__(self, series: ArrowSeries) -> None:
        # Series whose native chunked array the namespace methods read.
        self._arrow_series = series

    def get_categories(self) -> ArrowSeries:
        """Return the unique dictionary values found across all chunks.

        The dictionaries of every chunk are concatenated and de-duplicated
        with ``unique()``. A chunked array with no chunks yields an empty
        series of the dictionary's value type instead of failing.
        """
        import pyarrow as pa  # ignore-banned-import()

        ca = self._arrow_series._native_series
        if ca.num_chunks == 0:
            # `pa.concat_arrays` cannot be called with an empty list, so a
            # chunk-less column is answered with an empty category array.
            return self._arrow_series._from_native_series(
                pa.chunked_array([], type=ca.type.value_type)
            )

        # TODO(Unassigned): this looks potentially expensive - is there no better way?
        # https://github.com/narwhals-dev/narwhals/issues/464
        dictionaries = [chunk.dictionary for chunk in ca.chunks]
        out = pa.chunked_array([pa.concat_arrays(dictionaries).unique()])
        return self._arrow_series._from_native_series(out)
|
|
|
|
|
|
class ArrowSeriesStringNamespace:
    """String operations for an ``ArrowSeries``, built on ``pyarrow.compute``."""

    def __init__(self: Self, series: ArrowSeries) -> None:
        # Series whose native chunked array the namespace methods read.
        self._arrow_series = series

    @staticmethod
    def _slice_stop(offset: int, length: int | None) -> int | None:
        """Translate ``(offset, length)`` into an exclusive ``stop`` index.

        Returns ``None`` (slice to the end of the string) when no length is
        given, or when a negative ``offset`` plus ``length`` reaches or
        passes the end of the string.
        """
        if length is None:
            return None
        stop = offset + length
        if offset < 0 <= stop:
            # A non-negative stop would be counted from the start of the
            # string while the start is counted from the end; the request
            # really means "everything from `offset` onwards".
            return None
        return stop

    def len_chars(self) -> ArrowSeries:
        """Return the result of ``pc.utf8_length`` on the native array."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._arrow_series._from_native_series(
            pc.utf8_length(self._arrow_series._native_series)
        )

    def replace(
        self, pattern: str, value: str, *, literal: bool = False, n: int = 1
    ) -> ArrowSeries:
        """Replace up to ``n`` matches of ``pattern`` with ``value``.

        ``pattern`` is matched as a plain substring when ``literal`` is true
        and as a regular expression otherwise. ``n`` is forwarded as
        ``max_replacements``.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        replace_func = (
            pc.replace_substring if literal else pc.replace_substring_regex
        )
        return self._arrow_series._from_native_series(
            replace_func(
                self._arrow_series._native_series,
                pattern=pattern,
                replacement=value,
                max_replacements=n,
            )
        )

    def replace_all(
        self, pattern: str, value: str, *, literal: bool = False
    ) -> ArrowSeries:
        """Replace every match of ``pattern``; delegates to ``replace`` with ``n=-1``."""
        return self.replace(pattern, value, literal=literal, n=-1)

    def strip_chars(self: Self, characters: str | None = None) -> ArrowSeries:
        """Trim ``characters`` from both ends of each string.

        When ``characters`` is ``None`` or empty, a fixed set of whitespace
        characters is trimmed instead.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        whitespace = " \t\n\r\v\f"
        return self._arrow_series._from_native_series(
            pc.utf8_trim(
                self._arrow_series._native_series,
                characters or whitespace,
            )
        )

    def starts_with(self: Self, prefix: str) -> ArrowSeries:
        """Return whether each string begins with ``prefix``.

        Compares the leading ``len(prefix)`` characters with ``prefix``.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        head = self.slice(0, len(prefix))
        return self._arrow_series._from_native_series(
            pc.equal(head._native_series, prefix)
        )

    def ends_with(self: Self, suffix: str) -> ArrowSeries:
        """Return whether each string ends with ``suffix``.

        Compares the trailing ``len(suffix)`` characters with ``suffix``.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        if suffix:
            tail = self.slice(-len(suffix))
        else:
            # `-len("")` is 0, which would select the whole string rather
            # than an empty tail, so take an explicit empty slice instead.
            tail = self.slice(0, 0)
        return self._arrow_series._from_native_series(
            pc.equal(tail._native_series, suffix)
        )

    def contains(self: Self, pattern: str, *, literal: bool = False) -> ArrowSeries:
        """Return whether each string contains ``pattern``.

        ``pattern`` is matched as a plain substring when ``literal`` is true
        and as a regular expression otherwise.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        check_func = pc.match_substring if literal else pc.match_substring_regex
        return self._arrow_series._from_native_series(
            check_func(self._arrow_series._native_series, pattern)
        )

    def slice(self: Self, offset: int, length: int | None = None) -> ArrowSeries:
        """Return the substring starting at ``offset`` spanning ``length`` characters.

        ``length=None`` slices to the end of the string; ``length=0`` yields
        empty strings. A negative ``offset`` counts from the end.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        stop = self._slice_stop(offset, length)
        return self._arrow_series._from_native_series(
            pc.utf8_slice_codeunits(
                self._arrow_series._native_series, start=offset, stop=stop
            ),
        )

    def to_datetime(self: Self, format: str | None) -> ArrowSeries:  # noqa: A002
        """Parse each string with ``pc.strptime`` using microsecond resolution.

        When ``format`` is ``None`` it is obtained from
        ``parse_datetime_format`` applied to the native array.
        """
        import pyarrow.compute as pc  # ignore-banned-import()

        if format is None:
            format = parse_datetime_format(self._arrow_series._native_series)

        return self._arrow_series._from_native_series(
            pc.strptime(self._arrow_series._native_series, format=format, unit="us")
        )

    def to_uppercase(self: Self) -> ArrowSeries:
        """Return the result of ``pc.utf8_upper`` on the native array."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._arrow_series._from_native_series(
            pc.utf8_upper(self._arrow_series._native_series),
        )

    def to_lowercase(self: Self) -> ArrowSeries:
        """Return the result of ``pc.utf8_lower`` on the native array."""
        import pyarrow.compute as pc  # ignore-banned-import()

        return self._arrow_series._from_native_series(
            pc.utf8_lower(self._arrow_series._native_series),
        )
|