from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
from typing import Literal
from typing import Mapping
from typing import Protocol
from typing import Sequence
from typing import Sized
from typing import TypeVar
from typing import overload

from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._compliant.typing import CompliantExprT_contra
from narwhals._compliant.typing import CompliantLazyFrameAny
from narwhals._compliant.typing import CompliantSeriesT
from narwhals._compliant.typing import EagerExprT
from narwhals._compliant.typing import EagerSeriesT
from narwhals._compliant.typing import NativeFrameT
from narwhals._compliant.typing import NativeSeriesT
from narwhals._translate import ArrowConvertible
from narwhals._translate import DictConvertible
from narwhals._translate import FromNative
from narwhals._translate import NumpyConvertible
from narwhals._translate import ToNarwhals
from narwhals._translate import ToNarwhalsT_co
from narwhals.utils import Version
from narwhals.utils import _StoresNative
from narwhals.utils import deprecated
from narwhals.utils import is_compliant_series
from narwhals.utils import is_index_selector
from narwhals.utils import is_range
from narwhals.utils import is_sequence_like
from narwhals.utils import is_sized_multi_index_selector
from narwhals.utils import is_slice_index
from narwhals.utils import is_slice_none

if TYPE_CHECKING:
    from io import BytesIO
    from pathlib import Path

    import pandas as pd
    import polars as pl
    import pyarrow as pa
    from typing_extensions import Self
    from typing_extensions import TypeAlias

    from narwhals._compliant.group_by import CompliantGroupBy
    from narwhals._compliant.group_by import DataFrameGroupBy
    from narwhals._compliant.namespace import EagerNamespace
    from narwhals._translate import IntoArrowTable
    from narwhals.dataframe import DataFrame
    from narwhals.dtypes import DType
    from narwhals.schema import Schema
    from narwhals.typing import AsofJoinStrategy
    from narwhals.typing import JoinStrategy
    from narwhals.typing import LazyUniqueKeepStrategy
    from narwhals.typing import MultiColSelector
    from narwhals.typing import MultiIndexSelector
    from narwhals.typing import PivotAgg
    from narwhals.typing import SingleIndexSelector
    from narwhals.typing import SizedMultiIndexSelector
    from narwhals.typing import SizedMultiNameSelector
    from narwhals.typing import SizeUnit
    from narwhals.typing import UniqueKeepStrategy
    from narwhals.typing import _2DArray
    from narwhals.typing import _SliceIndex
    from narwhals.typing import _SliceName
    from narwhals.utils import Implementation
    from narwhals.utils import _FullContext

    Incomplete: TypeAlias = Any

__all__ = ["CompliantDataFrame", "CompliantLazyFrame", "EagerDataFrame"]

T = TypeVar("T")

_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]"  # noqa: PYI047


class CompliantDataFrame(
    NumpyConvertible["_2DArray", "_2DArray"],
    DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
    ArrowConvertible["pa.Table", "IntoArrowTable"],
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Sized,
    Protocol[CompliantSeriesT, CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    _native_frame: NativeFrameT
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    def __narwhals_dataframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...
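    # Illustrative note (not part of the protocol): concrete backends such as
    # `ArrowDataFrame` or `PandasLikeDataFrame` implement this interface, and
    # construction goes through the classmethod constructors below.  Assuming a
    # hypothetical backend `SomeDataFrame` and a `_FullContext` instance `ctx`
    # carrying the backend metadata, intended usage is roughly:
    #
    #     frame = SomeDataFrame.from_dict({"a": [1, 2, 3]}, context=ctx, schema=None)
    #     assert frame.columns == ["a"]
    #     assert frame.shape == (3, 1)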
    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self: ...
    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self: ...
    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self: ...
    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
    def __getitem__(
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
            MultiColSelector[CompliantSeriesT],
        ],
    ) -> Self: ...
    def simple_select(self, *column_names: str) -> Self:
        """`select` where all args are column names."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """`select` where all args are aggregations or literals.

        (so, no broadcasting is necessary).
        """
        # NOTE: Ignore is to avoid an intermittent false positive
        return self.select(*exprs)  # pyright: ignore[reportArgumentType]

    def _with_version(self, version: Version) -> Self: ...
    @property
    def native(self) -> NativeFrameT:
        return self._native_frame

    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...
    @property
    def shape(self) -> tuple[int, int]: ...
    def clone(self) -> Self: ...
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def estimated_size(self, unit: SizeUnit) -> int | float: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def get_column(self, name: str) -> CompliantSeriesT: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> DataFrameGroupBy[Self, Any]: ...
    def head(self, n: int) -> Self: ...
    def item(self, row: int | None, column: int | str | None) -> Any: ...
    def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
    def is_unique(self) -> CompliantSeriesT: ...
    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str | None,
        right_on: str | None,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def lazy(self, *, backend: Implementation | None) -> CompliantLazyFrameAny: ...
    def pivot(
        self,
        on: Sequence[str],
        *,
        index: Sequence[str] | None,
        values: Sequence[str] | None,
        aggregate_function: PivotAgg | None,
        sort_columns: bool,
        separator: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def row(self, index: int) -> tuple[Any, ...]: ...
    def rows(
        self, *, named: bool
    ) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
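    # Illustrative sketch (hypothetical `frame` with a single column "a" holding
    # [1, 2, 3]): the `to_dict` overloads defined below are keyed on `as_series`.
    # With `as_series=True` the values are the backend's compliant Series; with
    # `as_series=False` they are plain Python lists:
    #
    #     frame.to_dict(as_series=False)  # {"a": [1, 2, 3]}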
    def tail(self, n: int) -> Self: ...
    def to_arrow(self) -> pa.Table: ...
    def to_pandas(self) -> pd.DataFrame: ...
    def to_polars(self) -> pl.DataFrame: ...
    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str) -> Self: ...
    @overload
    def write_csv(self, file: None) -> str: ...
    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...
    def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
    def write_parquet(self, file: str | Path | BytesIO) -> None: ...
    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
        it = (expr._evaluate_aliases(self) for expr in exprs)
        return list(chain.from_iterable(it))


class CompliantLazyFrame(
    _StoresNative[NativeFrameT],
    FromNative[NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Protocol[CompliantExprT_contra, NativeFrameT, ToNarwhalsT_co],
):
    _native_frame: NativeFrameT
    _implementation: Implementation
    _backend_version: tuple[int, ...]
    _version: Version

    def __narwhals_lazyframe__(self) -> Self: ...
    def __narwhals_namespace__(self) -> Any: ...
    @classmethod
    def from_native(cls, data: NativeFrameT, /, *, context: _FullContext) -> Self: ...
    def simple_select(self, *column_names: str) -> Self:
        """`select` where all args are column names."""
        ...

    def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
        """`select` where all args are aggregations or literals.

        (so, no broadcasting is necessary).
        """
        ...

    def _with_version(self, version: Version) -> Self: ...
    @property
    def native(self) -> NativeFrameT:
        return self._native_frame

    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def schema(self) -> Mapping[str, DType]: ...
    def _iter_columns(self) -> Iterator[Any]: ...
    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny: ...
    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    @deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    def gather_every(self, n: int, offset: int) -> Self: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
    def head(self, n: int) -> Self: ...
    def join(
        self,
        other: Self,
        *,
        how: Literal["left", "inner", "cross", "anti", "semi"],
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str | None,
        right_on: str | None,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    @deprecated("`LazyFrame.tail` is deprecated and will be removed in a future version.")
    def tail(self, n: int) -> Self: ...
    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str) -> Self: ...
    def _evaluate_expr(self, expr: CompliantExprT_contra, /) -> Any:
        result = expr(self)
        assert len(result) == 1  # debug assertion  # noqa: S101
        return result[0]

    def _evaluate_aliases(self, *exprs: CompliantExprT_contra) -> list[str]:
        it = (expr._evaluate_aliases(self) for expr in exprs)
        return list(chain.from_iterable(it))


class EagerDataFrame(
    CompliantDataFrame[EagerSeriesT, EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    CompliantLazyFrame[EagerExprT, NativeFrameT, "DataFrame[NativeFrameT]"],
    Protocol[EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
    def __narwhals_namespace__(
        self,
    ) -> EagerNamespace[Self, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT]: ...
    def to_narwhals(self) -> DataFrame[NativeFrameT]:
        return self._version.dataframe(self, level="full")

    def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
        """Evaluate `expr` and ensure it has a **single** output."""
        result: Sequence[EagerSeriesT] = expr(self)
        assert len(result) == 1  # debug assertion  # noqa: S101
        return result[0]

    def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
        # NOTE: Ignore is to avoid an intermittent false positive
        return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs))  # pyright: ignore[reportArgumentType]

    def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
        """Return list of raw columns.

        For eager backends we alias operations at each step.

        As a safety precaution, we check here that the result names match those
        expected from the various `evaluate_output_names` / `alias_output_names`
        calls.

        Note that for PySpark / DuckDB, we are less free to set aliases
        liberally whenever we want.
        """
        aliases = expr._evaluate_aliases(self)
        result = expr(self)
        if list(aliases) != (
            result_aliases := [s.name for s in result]
        ):  # pragma: no cover
            msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
            raise AssertionError(msg)
        return result

    def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
        """Extract native Series, broadcasting to `len(self)` if necessary."""
        ...

    @staticmethod
    def _numpy_column_names(
        data: _2DArray, columns: Sequence[str] | None, /
    ) -> list[str]:
        return list(columns or (f"column_{x}" for x in range(data.shape[1])))

    def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
    def _select_multi_index(
        self, columns: SizedMultiIndexSelector[NativeSeriesT]
    ) -> Self: ...
    def _select_multi_name(
        self, columns: SizedMultiNameSelector[NativeSeriesT]
    ) -> Self: ...
    def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
    def _select_slice_name(self, columns: _SliceName) -> Self: ...
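    # The `_gather*` / `_select_*` hooks above are the primitives that
    # `__getitem__` below dispatches to.  Illustrative sketch (hypothetical
    # `frame`): a narwhals-level `df[[0, 2], ["a", "b"]]` call reaches this
    # class as `frame[([0, 2], ["a", "b"])]` and reduces to roughly
    #
    #     frame._select_multi_name(["a", "b"])._gather([0, 2])
    #
    # i.e. columns are resolved first, then rows are gathered.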
    def __getitem__(  # noqa: C901, PLR0912
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
            MultiColSelector[EagerSeriesT],
        ],
    ) -> Self:
        rows, columns = item
        compliant = self
        if not is_slice_none(columns):
            if isinstance(columns, Sized) and len(columns) == 0:
                # Empty column selection, e.g. `df[:, []]`.
                return compliant.select()
            if is_index_selector(columns):
                # Positional column selection (integers, ranges, integer Series).
                if is_slice_index(columns) or is_range(columns):
                    compliant = compliant._select_slice_index(columns)
                elif is_compliant_series(columns):
                    compliant = self._select_multi_index(columns.native)
                else:
                    compliant = compliant._select_multi_index(columns)
            elif isinstance(columns, slice):
                # Name-based column selection.
                compliant = compliant._select_slice_name(columns)
            elif is_compliant_series(columns):
                compliant = self._select_multi_name(columns.native)
            elif is_sequence_like(columns):
                compliant = self._select_multi_name(columns)
            else:  # pragma: no cover
                msg = f"Unreachable code, got unexpected type: {type(columns)}"
                raise AssertionError(msg)
        if not is_slice_none(rows):
            # Row gathering happens after column selection.
            if isinstance(rows, int):
                compliant = compliant._gather([rows])
            elif isinstance(rows, (slice, range)):
                compliant = compliant._gather_slice(rows)
            elif is_compliant_series(rows):
                compliant = compliant._gather(rows.native)
            elif is_sized_multi_index_selector(rows):
                compliant = compliant._gather(rows)
            else:  # pragma: no cover
                msg = f"Unreachable code, got unexpected type: {type(rows)}"
                raise AssertionError(msg)
        return compliant
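

# Illustrative sketch only (not part of the library): a minimal demonstration of
# how the public narwhals API ends up exercising the protocols above.  Assumes
# pandas is installed; any eager backend would do.
if __name__ == "__main__":  # pragma: no cover
    import pandas as pd

    import narwhals as nw

    nw_df = nw.from_native(pd.DataFrame({"a": [1, 2, 3]}), eager_only=True)
    # `DataFrame.select` delegates to the compliant frame's `select`, which an
    # eager backend builds on `_evaluate_into_exprs` above.
    print(nw_df.select(nw.col("a") + 1).to_native())
    # Tuple indexing is routed through `EagerDataFrame.__getitem__`.
    print(nw_df[[0, 2], ["a"]].to_native())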