Files
Buffteks-Website/venv/lib/python3.12/site-packages/narwhals/_compliant/dataframe.py
2025-05-08 21:10:14 -05:00

478 lines
17 KiB
Python

from __future__ import annotations
from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
from typing import Literal
from typing import Mapping
from typing import Protocol
from typing import Sequence
from typing import Sized
from typing import TypeVar
from typing import overload
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._compliant.typing import CompliantExprT_contra
from narwhals._compliant.typing import CompliantLazyFrameAny
from narwhals._compliant.typing import CompliantSeriesT
from narwhals._compliant.typing import EagerExprT
from narwhals._compliant.typing import EagerSeriesT
from narwhals._compliant.typing import NativeFrameT
from narwhals._compliant.typing import NativeSeriesT
from narwhals._translate import ArrowConvertible
from narwhals._translate import DictConvertible
from narwhals._translate import FromNative
from narwhals._translate import NumpyConvertible
from narwhals._translate import ToNarwhals
from narwhals._translate import ToNarwhalsT_co
from narwhals.utils import Version
from narwhals.utils import _StoresNative
from narwhals.utils import deprecated
from narwhals.utils import is_compliant_series
from narwhals.utils import is_index_selector
from narwhals.utils import is_range
from narwhals.utils import is_sequence_like
from narwhals.utils import is_sized_multi_index_selector
from narwhals.utils import is_slice_index
from narwhals.utils import is_slice_none
if TYPE_CHECKING:
from io import BytesIO
from pathlib import Path
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import Self
from typing_extensions import TypeAlias
from narwhals._compliant.group_by import CompliantGroupBy
from narwhals._compliant.group_by import DataFrameGroupBy
from narwhals._compliant.namespace import EagerNamespace
from narwhals._translate import IntoArrowTable
from narwhals.dataframe import DataFrame
from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.typing import AsofJoinStrategy
from narwhals.typing import JoinStrategy
from narwhals.typing import LazyUniqueKeepStrategy
from narwhals.typing import MultiColSelector
from narwhals.typing import MultiIndexSelector
from narwhals.typing import PivotAgg
from narwhals.typing import SingleIndexSelector
from narwhals.typing import SizedMultiIndexSelector
from narwhals.typing import SizedMultiNameSelector
from narwhals.typing import SizeUnit
from narwhals.typing import UniqueKeepStrategy
from narwhals.typing import _2DArray
from narwhals.typing import _SliceIndex
from narwhals.typing import _SliceName
from narwhals.utils import Implementation
from narwhals.utils import _FullContext
# Alias that resolves to `Any`; in this module it appears in the `filter`
# signatures as an alternative to the expression type.
# NOTE(review): the name suggests a placeholder for a not-yet-precise type — confirm.
Incomplete: TypeAlias = Any
# Names exported by this module.
__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]
# General-purpose type variable.
T = TypeVar("T")
# String-form alias covering the two dictionary shapes that `to_dict` is
# annotated to return: column name -> series, or column name -> list of values.
_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]" # noqa: PYI047
class CompliantDataFrame(
    NumpyConvertible["_2DArray", "_2DArray"],
    DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
    ArrowConvertible["pa.Table", "IntoArrowTable"],
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Sized,
    Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    """Protocol describing the dataframe interface a backend has to provide.

    Almost every member is a signature-only stub. The only members that carry
    a default implementation are `aggregate`, `native` and `_evaluate_aliases`.
    """

    # ------------------------------------------------------------------
    # State every implementation is expected to carry.
    # ------------------------------------------------------------------
    _native_frame: NativeFrameT
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    # ------------------------------------------------------------------
    # Narwhals marker methods.
    # ------------------------------------------------------------------
    def __narwhals_dataframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...

    # ------------------------------------------------------------------
    # Alternate constructors.
    # ------------------------------------------------------------------
    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: ...
    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self: ...
    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self: ...

    # ------------------------------------------------------------------
    # Conversion / indexing dunders.
    # ------------------------------------------------------------------
    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
    def __getitem__(
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
            MultiColSelector[CompliantSeriesT],
        ],
    ) -> Self: ...

    # ------------------------------------------------------------------
    # Specialised flavours of `select`.
    # ------------------------------------------------------------------
    def simple_select(self, *column_names: str) -> Self:
        """Variant of `select` for when every argument is a column name."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """Variant of `select` for when every argument is an aggregation or a literal.

        No broadcasting is needed in that situation, so the default simply
        forwards to `select`.
        """
        # NOTE: the ignore below silences an intermittent false positive.
        return self.select(*exprs)  # pyright: ignore[reportArgumentType]

    def _with_version(self, version: Version) -> Self: ...

    # ------------------------------------------------------------------
    # Properties.
    # ------------------------------------------------------------------
    @property
    def native(self) -> NativeFrameT:
        """The wrapped native frame, as stored in `_native_frame`."""
        return self._native_frame

    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...
    @property
    def shape(self) -> tuple[int, int]: ...

    # ------------------------------------------------------------------
    # Frame operations (signature-only).
    # ------------------------------------------------------------------
    def clone(self) -> Self: ...
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def estimated_size(self, unit: SizeUnit) -> int | float: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def get_column(self, name: str) -> CompliantSeriesT: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> DataFrameGroupBy[Self, Any]: ...
    def head(self, n: int) -> Self: ...
    def item(self, row: int | None, column: int | str | None) -> Any: ...
    def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
    def is_unique(self) -> CompliantSeriesT: ...
    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str | None,
        right_on: str | None,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def row(self, index: int) -> tuple[Any, ...]: ...
    def rows(
        self, *, named: bool
    ) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    def tail(self, n: int) -> Self: ...

    # ------------------------------------------------------------------
    # Export to other representations.
    # ------------------------------------------------------------------
    def to_arrow(self) -> pa.Table: ...
    def to_pandas(self) -> pd.DataFrame: ...
    def to_polars(self) -> pl.DataFrame: ...
    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str) -> Self: ...
    @overload
    def write_csv(self, file: None) -> str: ...
    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...
    def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
    def write_parquet(self, file: str | Path | BytesIO) -> None: ...

    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
        """Return one flat list with the aliases each expression reports for this frame.

        Expressions are visited in the order given, and each expression's
        aliases keep their own order.
        """
        return [
            alias for expr in exprs for alias in expr._evaluate_aliases(self)
        ]
class CompliantLazyFrame(
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    """Protocol describing the lazy-frame interface a backend has to provide.

    Most members are signature-only stubs. Default implementations exist for
    `native`, `_evaluate_expr` and `_evaluate_aliases`.
    """

    # ------------------------------------------------------------------
    # State every implementation is expected to carry.
    # ------------------------------------------------------------------
    _native_frame: NativeFrameT
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    # ------------------------------------------------------------------
    # Narwhals marker methods and constructor.
    # ------------------------------------------------------------------
    def __narwhals_lazyframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...
    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...

    # ------------------------------------------------------------------
    # Specialised flavours of `select`.
    # ------------------------------------------------------------------
    def simple_select(self, *column_names: str) -> Self:
        """Variant of `select` for when every argument is a column name."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """Variant of `select` for when every argument is an aggregation or a literal.

        No broadcasting is needed in that situation.
        """
        ...

    def _with_version(self, version: Version) -> Self: ...

    # ------------------------------------------------------------------
    # Properties.
    # ------------------------------------------------------------------
    @property
    def native(self) -> NativeFrameT:
        """The wrapped native frame, as stored in `_native_frame`."""
        return self._native_frame

    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...

    # ------------------------------------------------------------------
    # Frame operations (signature-only).
    # ------------------------------------------------------------------
    def _iter_columns(self) -> Iterator[Any]: ...
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    @deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    def gather_every(self, n: int, offset: int) -> Self: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
    def head(self, n: int) -> Self: ...
    def join(
        self,
        other: Self,
        *,
        how: Literal["left", "inner", "cross", "anti", "semi"],
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str | None,
        right_on: str | None,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    @deprecated("`LazyFrame.tail` is deprecated and will be removed in a future version.")
    def tail(self, n: int) -> Self: ...
    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str) -> Self: ...

    # ------------------------------------------------------------------
    # Expression helpers with default implementations.
    # ------------------------------------------------------------------
    def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
        """Call `expr` on this frame and return its one and only output."""
        outputs = expr(self)
        assert len(outputs) == 1  # debug assertion  # noqa: S101
        return outputs[0]

    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
        """Return one flat list with the aliases each expression reports for this frame.

        Expressions are visited in the order given, and each expression's
        aliases keep their own order.
        """
        return [
            alias for expr in exprs for alias in expr._evaluate_aliases(self)
        ]
class EagerDataFrame(
    CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
    """Protocol for eager backends, combining both compliant frame protocols.

    On top of the inherited interface it supplies default expression
    evaluation helpers and a `__getitem__` that dispatches to the
    `_select_*` / `_gather*` hooks declared here.
    """

    def __narwhals_namespace__(
        self,
    ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ...

    def to_narwhals(self) -> DataFrame[NativeFrameT]:
        """Wrap this frame via `self._version.dataframe` at level `"full"`."""
        return self._version.dataframe(self, level="full")

    def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
        """Call `expr` on this frame and return its one and only output series."""
        outputs: Sequence[EagerSeriesT] = expr(self)
        assert len(outputs) == 1  # debug assertion  # noqa: S101
        return outputs[0]

    def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
        """Evaluate every expression in order and concatenate the resulting series."""
        evaluated: list[EagerSeriesT] = []
        for expr in exprs:
            # NOTE: the ignore below silences an intermittent false positive.
            evaluated.extend(self._evaluate_into_expr(expr))  # pyright: ignore[reportArgumentType]
        return evaluated

    def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
        """Evaluate `expr` and hand back its raw output columns.

        Eager backends alias operations at every step, so as a safety net the
        names of the produced series are compared against the names that the
        `evaluate_output_names` / `alias_output_names` machinery predicted.
        (PySpark / DuckDB have less freedom to set aliases at will.)

        Raises:
            AssertionError: if the produced names differ from the predicted ones.
        """
        aliases = expr._evaluate_aliases(self)
        result = expr(self)
        expected_names = list(aliases)
        actual_names = [series.name for series in result]
        if expected_names != actual_names:  # pragma: no cover
            msg = f"Safety assertion failed, expected {aliases}, got {actual_names}"
            raise AssertionError(msg)
        return result

    def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
        """Pull out the native Series, broadcast to `len(self)` when required."""
        ...

    @staticmethod
    def _numpy_column_names(
        data: _2DArray, columns: Sequence[str] | None, /
    ) -> list[str]:
        """Return `columns` as a list, or `column_<i>` names when none are given.

        The fallback generates one name per entry along `data.shape[1]` and is
        used whenever `columns` is falsy (so for `None` and for empty input).
        """
        if columns:
            return list(columns)
        return [f"column_{idx}" for idx in range(data.shape[1])]

    # ------------------------------------------------------------------
    # Hooks that `__getitem__` dispatches to (signature-only).
    # ------------------------------------------------------------------
    def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
    def _select_multi_index(
        self, columns: SizedMultiIndexSelector[NativeSeriesT]
    ) -> Self: ...
    def _select_multi_name(
        self, columns: SizedMultiNameSelector[NativeSeriesT]
    ) -> Self: ...
    def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
    def _select_slice_name(self, columns: _SliceName) -> Self: ...

    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
            MultiColSelector[EagerSeriesT],
        ],
    ) -> Self:
        """Index with a `(rows, columns)` pair.

        Columns are narrowed first and the row selector is then applied to
        that intermediate frame. An all-`None` slice on either axis leaves
        that axis untouched.

        Raises:
            AssertionError: if a selector has a type none of the branches handle.
        """
        rows, columns = item
        frame = self

        # --- Column phase -------------------------------------------------
        if not is_slice_none(columns):
            # A sized-but-empty selector short-circuits: the row selector is
            # never consulted in this case.
            if isinstance(columns, Sized) and len(columns) == 0:
                return frame.select()
            if is_index_selector(columns):
                if is_slice_index(columns) or is_range(columns):
                    frame = frame._select_slice_index(columns)
                elif is_compliant_series(columns):
                    frame = frame._select_multi_index(columns.native)
                else:
                    frame = frame._select_multi_index(columns)
            elif isinstance(columns, slice):
                frame = frame._select_slice_name(columns)
            elif is_compliant_series(columns):
                frame = frame._select_multi_name(columns.native)
            elif is_sequence_like(columns):
                frame = frame._select_multi_name(columns)
            else:  # pragma: no cover
                msg = f"Unreachable code, got unexpected type: {type(columns)}"
                raise AssertionError(msg)

        # --- Row phase ----------------------------------------------------
        if is_slice_none(rows):
            return frame
        if isinstance(rows, int):
            # A single integer is gathered as a one-element list.
            return frame._gather([rows])
        if isinstance(rows, (slice, range)):
            return frame._gather_slice(rows)
        if is_compliant_series(rows):
            return frame._gather(rows.native)
        if is_sized_multi_index_selector(rows):
            return frame._gather(rows)
        msg = f"Unreachable code, got unexpected type: {type(rows)}"  # pragma: no cover
        raise AssertionError(msg)  # pragma: no cover