from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING
from typing import Any
from typing import Collection
from typing import Iterator
from typing import Literal
from typing import Mapping
from typing import Sequence
from typing import cast
from typing import overload

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import align_series_full_broadcast
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError
from narwhals.utils import Implementation
from narwhals.utils import Version
from narwhals.utils import check_column_exists
from narwhals.utils import check_column_names_are_unique
from narwhals.utils import convert_str_slice_to_int_slice
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import not_implemented
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import scale_bytes
from narwhals.utils import supports_arrow_c_stream
from narwhals.utils import validate_backend_version

if TYPE_CHECKING:
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import pandas as pd
    import polars as pl
    from typing_extensions import Self
    from typing_extensions import TypeAlias
    from typing_extensions import TypeIs

    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.group_by import ArrowGroupBy
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._arrow.typing import ChunkedArrayAny
    from narwhals._arrow.typing import Mask  # type: ignore[attr-defined]
    from narwhals._arrow.typing import Order  # type: ignore[attr-defined]
    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._compliant.typing import CompliantLazyFrameAny
    from narwhals._translate import IntoArrowTable
    from narwhals.dtypes import DType
    from narwhals.schema import Schema
    from narwhals.typing import JoinStrategy
    from narwhals.typing import SizedMultiIndexSelector
    from narwhals.typing import SizedMultiNameSelector
    from narwhals.typing import SizeUnit
    from narwhals.typing import UniqueKeepStrategy
    from narwhals.typing import _1DArray
    from narwhals.typing import _2DArray
    from narwhals.typing import _SliceIndex
    from narwhals.typing import _SliceName
    from narwhals.utils import Version
    from narwhals.utils import _FullContext

    # Join-type strings passed as `join_type=` to `pa.Table.join` in `ArrowDataFrame.join`.
    JoinType: TypeAlias = Literal[
        "left semi",
        "right semi",
        "left anti",
        "right anti",
        "inner",
        "left outer",
        "right outer",
        "full outer",
    ]
    # Candidate values for `promote_options=`; `unpivot` passes "permissive"
    # to `pa.concat_tables`.
    PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]


class ArrowDataFrame(
    EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
):
    """Eager dataframe backed by a `pyarrow.Table`.

    Every operation builds a new `pa.Table` and wraps it in a fresh
    `ArrowDataFrame`; the wrapped table is never modified in place.
    """

    def __init__(
        self,
        native_dataframe: pa.Table,
        *,
        backend_version: tuple[int, ...],
        version: Version,
        validate_column_names: bool,
    ) -> None:
        """Wrap `native_dataframe`.

        Arguments:
            native_dataframe: The table to wrap.
            backend_version: Version tuple of the installed pyarrow; checked with
                `validate_backend_version`.
            version: Narwhals API version stored on the frame.
            validate_column_names: When true, `check_column_names_are_unique` is run
                on the table's column names before the table is stored.
        """
        if validate_column_names:
            check_column_names_are_unique(native_dataframe.column_names)
        self._native_frame = native_dataframe
        self._implementation = Implementation.PYARROW
        self._backend_version = backend_version
        self._version = version
        validate_backend_version(self._implementation, self._backend_version)

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self:
        """Build a frame from an Arrow-compatible object.

        A `pa.Table` is used as-is. Otherwise `pa.table(data)` is called when the
        backend version is at least 14 or `data` is a `Collection`.

        Raises:
            ModuleNotFoundError: `data` supports the Arrow C stream interface but the
                backend version is older than 14.
            TypeError: `data` matches none of the supported cases.
        """
        backend_version = context._backend_version
        if cls._is_native(data):
            native = data
        elif backend_version >= (14,) or isinstance(data, Collection):
            native = pa.table(data)
        elif supports_arrow_c_stream(data):  # pragma: no cover
            msg = f"'pyarrow>=14.0.0' is required for `from_arrow` for object of type {type(data).__name__!r}."
            raise ModuleNotFoundError(msg)
        else:  # pragma: no cover
            msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
            raise TypeError(msg)
        return cls.from_native(native, context=context)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self:
        """Build a frame from a mapping of column name to values.

        A non-None `schema` is converted with `Schema(schema).to_arrow()` and
        handed to `pa.Table.from_pydict`; `None` is passed through unchanged.
        """
        from narwhals.schema import Schema

        pa_schema = Schema(schema).to_arrow() if schema is not None else schema
        native = pa.Table.from_pydict(data, schema=pa_schema)
        return cls.from_native(native, context=context)

    @staticmethod
    def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
        """Return whether `obj` is a `pa.Table`."""
        return isinstance(obj, pa.Table)

    @classmethod
    def from_native(cls, data: pa.Table, /, *, context: _FullContext) -> Self:
        """Wrap `data`, taking backend version and version from `context`.

        Column-name uniqueness is always validated on this path.
        """
        return cls(
            data,
            backend_version=context._backend_version,
            version=context._version,
            validate_column_names=True,
        )

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self:
        """Build a frame from a 2D NumPy array, one table column per array column.

        When `schema` is a `Mapping` or `Schema` it is converted to an Arrow schema;
        otherwise column names come from `cls._numpy_column_names(data, schema)`.
        """
        from narwhals.schema import Schema

        # Iterating the transpose yields the array's columns one at a time.
        arrays = [pa.array(val) for val in data.T]
        if isinstance(schema, (Mapping, Schema)):
            native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
        else:
            native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
        return cls.from_native(native, context=context)

    def __narwhals_namespace__(self) -> ArrowNamespace:
        """Return an `ArrowNamespace` carrying this frame's backend version and version."""
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __native_namespace__(self) -> ModuleType:
        """Return the native namespace module of the implementation.

        Raises:
            AssertionError: The stored implementation is not `Implementation.PYARROW`.
        """
        if self._implementation is Implementation.PYARROW:
            return self._implementation.to_native_namespace()

        msg = f"Expected pyarrow, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_dataframe__(self) -> Self:
        """Return `self`."""
        return self

    def __narwhals_lazyframe__(self) -> Self:
        """Return `self`."""
        return self

    def _with_version(self, version: Version) -> Self:
        """Return a new frame over the same native table with a different `version`.

        Column names are not re-validated.
        """
        return self.__class__(
            self.native,
            backend_version=self._backend_version,
            version=version,
            validate_column_names=False,
        )

    def _with_native(self, df: pa.Table, *, validate_column_names: bool = True) -> Self:
        """Return a new frame wrapping `df` with this frame's backend version and version.

        Arguments:
            df: The table to wrap.
            validate_column_names: Forwarded to the constructor; pass `False` when the
                operation cannot have introduced duplicate column names.
        """
        return self.__class__(
            df,
            backend_version=self._backend_version,
            version=self._version,
            validate_column_names=validate_column_names,
        )

    @property
    def shape(self) -> tuple[int, int]:
        """The native table's `shape`."""
        return self.native.shape

    def __len__(self) -> int:
        """Return `len()` of the native table."""
        return len(self.native)

    def row(self, index: int) -> tuple[Any, ...]:
        """Return the element at position `index` of every column, as a tuple.

        Elements are whatever indexing a native column yields; they are not
        converted here.
        """
        return tuple(col[index] for col in self.native.itercolumns())

    @overload
    def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ...

    @overload
    def rows(self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ...

    @overload
    def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...

    def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
        """Return all rows.

        With `named=False` the rows are tuples collected from `iter_rows` using a
        buffer size of 512; otherwise the result of the table's `to_pylist()`.
        """
        if not named:
            return list(self.iter_rows(named=False, buffer_size=512))  # type: ignore[return-value]
        return self.native.to_pylist()

    def iter_columns(self) -> Iterator[ArrowSeries]:
        """Yield one `ArrowSeries` per column, named after the column."""
        for name, series in zip(self.columns, self.native.itercolumns()):
            yield ArrowSeries.from_native(series, context=self, name=name)

    # Alias: the same function is exposed under the private name as well.
    _iter_columns = iter_columns

    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
        """Yield rows, converting the table in slices of `buffer_size` rows.

        Arguments:
            named: When false, rows are tuples; when true, rows are the items of
                each slice's `to_pylist()`.
            buffer_size: Number of rows converted per slice.
        """
        df = self.native
        num_rows = df.num_rows

        if not named:
            for i in range(0, num_rows, buffer_size):
                # Column-wise value lists for this slice; zipping them yields rows.
                rows = df[i : i + buffer_size].to_pydict().values()
                yield from zip(*rows)
        else:
            for i in range(0, num_rows, buffer_size):
                yield from df[i : i + buffer_size].to_pylist()

    def get_column(self, name: str) -> ArrowSeries:
        """Return column `name` as an `ArrowSeries`.

        Raises:
            TypeError: `name` is not a `str`.
        """
        if not isinstance(name, str):
            msg = f"Expected str, got: {type(name)}"
            raise TypeError(msg)
        return ArrowSeries.from_native(self.native[name], context=self, name=name)

    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray:
        """Delegate to the native table's `__array__`."""
        return self.native.__array__(dtype, copy=copy)

    def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self:
        """Select rows by position.

        An empty selector short-circuits to a zero-length slice of the table.
        """
        if len(rows) == 0:
            return self._with_native(self.native.slice(0, 0))
        # For backend versions below 18 a tuple selector is turned into a list
        # before `take` — presumably tuples are not accepted there; confirm upstream.
        if self._backend_version < (18,) and isinstance(rows, tuple):
            rows = list(rows)
        return self._with_native(self.native.take(rows))

    def _gather_slice(self, rows: _SliceIndex | range) -> Self:
        """Select a contiguous range of rows.

        Negative `start`/`stop` are offset by the table length.

        Raises:
            NotImplementedError: `rows.step` is neither `None` nor 1.
        """
        start = rows.start or 0
        stop = rows.stop if rows.stop is not None else len(self.native)
        if start < 0:
            start = len(self.native) + start
        if stop < 0:
            stop = len(self.native) + stop
        if rows.step is not None and rows.step != 1:
            msg = "Slicing with step is not supported on PyArrow tables"
            raise NotImplementedError(msg)
        # `slice` takes (offset, length), hence `stop - start`.
        return self._with_native(self.native.slice(start, stop - start))

    def _select_slice_name(self, columns: _SliceName) -> Self:
        """Select columns with a slice of column names.

        The slice is translated to integer positions by
        `convert_str_slice_to_int_slice`.
        """
        start, stop, step = convert_str_slice_to_int_slice(columns, self.columns)
        return self._with_native(self.native.select(self.columns[start:stop:step]))

    def _select_slice_index(self, columns: _SliceIndex | range) -> Self:
        """Select columns with an integer slice or range over the column list."""
        return self._with_native(
            self.native.select(self.columns[columns.start : columns.stop : columns.step])
        )

    def _select_multi_index(
        self, columns: SizedMultiIndexSelector[ChunkedArrayAny]
    ) -> Self:
        """Select columns by a collection of integer positions.

        A `pa.ChunkedArray` or 1D NumPy array is converted to a Python list first;
        any other selector is passed to `select` unchanged.
        """
        selector: Sequence[int]
        if isinstance(columns, pa.ChunkedArray):
            # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
            selector = cast("Sequence[int]", columns.to_pylist())
        # TODO @dangotbanned: Fix upstream, it is actually much narrower
        # **Doesn't accept `ndarray`**
        elif is_numpy_array_1d(columns):
            selector = columns.tolist()
        else:
            selector = columns
        return self._with_native(self.native.select(selector))

    def _select_multi_name(
        self, columns: SizedMultiNameSelector[ChunkedArrayAny]
    ) -> Self:
        """Select columns by a collection of names.

        A `pa.ChunkedArray` is converted to a Python list first; any other selector
        is passed to `select` unchanged.
        """
        selector: Sequence[str] | _1DArray
        if isinstance(columns, pa.ChunkedArray):
            # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
            selector = cast("Sequence[str]", columns.to_pylist())
        else:
            selector = columns
        # NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221
        return self._with_native(self.native.select(selector))  # pyright: ignore[reportArgumentType]

    @property
    def schema(self) -> dict[str, DType]:
        """Mapping of column name to Narwhals dtype, in native schema order."""
        schema = self.native.schema
        return {
            name: native_to_narwhals_dtype(dtype, self._version)
            for name, dtype in zip(schema.names, schema.types)
        }

    def collect_schema(self) -> dict[str, DType]:
        """Return `self.schema`."""
        return self.schema

    def estimated_size(self, unit: SizeUnit) -> int | float:
        """Return the native table's `nbytes`, scaled to `unit` by `scale_bytes`."""
        sz = self.native.nbytes
        return scale_bytes(sz, unit)

    explode = not_implemented()

    @property
    def columns(self) -> list[str]:
        """The native table's column names."""
        return self.native.column_names

    def simple_select(self, *column_names: str) -> Self:
        """Select columns by name, skipping column-name validation."""
        return self._with_native(
            self.native.select(list(column_names)), validate_column_names=False
        )

    def select(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
        """Evaluate `exprs` and return a frame made of the resulting series.

        The series are passed through `align_series_full_broadcast` before the new
        table is built; the resulting column names are validated.
        """
        new_series = self._evaluate_into_exprs(*exprs)
        if not new_series:
            # return empty dataframe, like Polars does
            return self._with_native(
                self.native.__class__.from_arrays([]), validate_column_names=False
            )
        names = [s.name for s in new_series]
        reshaped = align_series_full_broadcast(*new_series)
        df = pa.Table.from_arrays([s.native for s in reshaped], names=names)
        return self._with_native(df, validate_column_names=True)

    def _extract_comparand(self, other: ArrowSeries) -> ChunkedArrayAny:
        """Return a native column for `other` with as many rows as this frame.

        A series not flagged `_broadcast` must already have the frame's length.
        A broadcast series has its first value repeated to the frame's length.

        Raises:
            ShapeError: `other` is not a broadcast series and its length differs
                from the frame's.
        """
        length = len(self)
        if not other._broadcast:
            if (len_other := len(other)) != length:
                msg = f"Expected object of length {length}, got: {len_other}."
                raise ShapeError(msg)
            return other.native

        import numpy as np  # ignore-banned-import

        value = other.native[0]
        # On backend versions below 13 the scalar is unwrapped with `as_py()` before
        # `np.full` — presumably older scalars are not accepted as a fill value;
        # confirm before removing.
        if self._backend_version < (13,) and hasattr(value, "as_py"):
            value = value.as_py()
        return pa.chunked_array([np.full(shape=length, fill_value=value)])

    def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
        """Evaluate `exprs` and add the results as columns.

        A result whose name matches an original column replaces that column in
        place; any other result is appended at the end.
        """
        # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame)
        # All `pyarrow` data is immutable, so this is fine
        native_frame = self.native
        new_columns = self._evaluate_into_exprs(*exprs)
        columns = self.columns

        for col_value in new_columns:
            col_name = col_value.name
            column = self._extract_comparand(col_value)
            native_frame = (
                native_frame.set_column(columns.index(col_name), col_name, column=column)
                if col_name in columns
                else native_frame.append_column(col_name, column=column)
            )

        return self._with_native(native_frame, validate_column_names=False)

    def group_by(
        self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool
    ) -> ArrowGroupBy:
        """Return an `ArrowGroupBy` over this frame for `keys`."""
        from narwhals._arrow.group_by import ArrowGroupBy

        return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join this frame with `other`.

        Arguments:
            other: Right-hand frame.
            how: `"cross"`, or one of the keys of the strategy map below
                (`"anti"`, `"semi"`, `"inner"`, `"left"`, `"full"`).
            left_on: Key columns of this frame; an empty list is used when falsy.
            right_on: Key columns of `other`.
            suffix: Passed as `right_suffix` to the native join.
        """
        how_to_join_map: dict[str, JoinType] = {
            "anti": "left anti",
            "semi": "left semi",
            "inner": "inner",
            "left": "left outer",
            "full": "full outer",
        }

        if how == "cross":
            # Cross join: add the same constant temporary key column to both sides,
            # inner-join on it, then drop it again.
            plx = self.__narwhals_namespace__()
            key_token = generate_temporary_column_name(
                n_bytes=8, columns=[*self.columns, *other.columns]
            )

            return self._with_native(
                self.with_columns(
                    plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
                )
                .native.join(
                    other.with_columns(
                        plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
                    ).native,
                    keys=key_token,
                    right_keys=key_token,
                    join_type="inner",
                    right_suffix=suffix,
                )
                .drop([key_token])
            )

        coalesce_keys = how != "full"  # polars full join does not coalesce keys
        return self._with_native(
            self.native.join(
                other.native,
                keys=left_on or [],  # type: ignore[arg-type]
                right_keys=right_on,  # type: ignore[arg-type]
                join_type=how_to_join_map[how],
                right_suffix=suffix,
                coalesce_keys=coalesce_keys,
            ),
        )

    join_asof = not_implemented()

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop the columns resolved by `parse_columns_to_drop` from `columns` and `strict`."""
        to_drop = parse_columns_to_drop(
            compliant_frame=self, columns=columns, strict=strict
        )
        return self._with_native(self.native.drop(to_drop), validate_column_names=False)

    def drop_nulls(self: ArrowDataFrame, subset: Sequence[str] | None) -> ArrowDataFrame:
        """Drop rows containing nulls.

        With `subset=None` the native `drop_null()` is used. Otherwise rows are
        removed where any of the `subset` columns is null.
        """
        if subset is None:
            return self._with_native(self.native.drop_null(), validate_column_names=False)
        plx = self.__narwhals_namespace__()
        return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))

    def sort(
        self,
        *by: str,
        descending: bool | Sequence[bool],
        nulls_last: bool,
    ) -> Self:
        """Sort by the columns in `by`.

        Arguments:
            *by: Column names to sort by.
            descending: One flag applied to every key, or one flag per key
                (paired with `by` via `zip`).
            nulls_last: Place nulls at the end when true, at the start otherwise.
        """
        if isinstance(descending, bool):
            order: Order = "descending" if descending else "ascending"
            sorting: list[tuple[str, Order]] = [(key, order) for key in by]
        else:
            sorting = [
                (key, "descending" if is_descending else "ascending")
                for key, is_descending in zip(by, descending)
            ]

        null_placement = "at_end" if nulls_last else "at_start"

        return self._with_native(
            self.native.sort_by(sorting, null_placement=null_placement),
            validate_column_names=False,
        )

    def to_pandas(self) -> pd.DataFrame:
        """Return the native table's `to_pandas()` result."""
        return self.native.to_pandas()

    def to_polars(self) -> pl.DataFrame:
        """Return `pl.from_arrow` applied to the native table."""
        import polars as pl  # ignore-banned-import

        return pl.from_arrow(self.native)  # type: ignore[return-value]

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
        """Return the columns stacked side by side as a NumPy array.

        NOTE(review): `dtype` and `copy` are accepted but not used by this
        implementation.
        """
        import numpy as np  # ignore-banned-import

        arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns])
        return arr

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ...

    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
        """Return a dict keyed by column name.

        Values are `ArrowSeries` when `as_series` is true, otherwise the result of
        each series' `to_list()`.
        """
        it = self.iter_columns()
        if as_series:
            return {ser.name: ser for ser in it}
        return {ser.name: ser.to_list() for ser in it}

    def with_row_index(self, name: str) -> Self:
        """Return the frame with a leading column `name` holding 0..num_rows-1."""
        df = self.native
        cols = self.columns

        row_indices = pa.array(range(df.num_rows))
        # The index is appended, then `select` reorders so it comes first.
        return self._with_native(
            df.append_column(name, row_indices).select([name, *cols])
        )

    def filter(
        self: ArrowDataFrame, predicate: ArrowExpr | list[bool | None]
    ) -> ArrowDataFrame:
        """Keep the rows selected by `predicate`.

        `predicate` is either a list mask handed to the native `filter` directly, or
        an expression that is evaluated against this frame first.
        """
        if isinstance(predicate, list):
            mask_native: Mask | ChunkedArrayAny = predicate
        else:
            # `[0]` is safe as the predicate's expression only returns a single column
            mask_native = self._evaluate_into_exprs(predicate)[0].native
        return self._with_native(
            self.native.filter(mask_native), validate_column_names=False
        )

    def head(self, n: int) -> Self:
        """Return the first `n` rows; for negative `n`, all rows except the last `abs(n)`."""
        df = self.native
        if n >= 0:
            return self._with_native(df.slice(0, n), validate_column_names=False)
        else:
            num_rows = df.num_rows
            return self._with_native(
                df.slice(0, max(0, num_rows + n)), validate_column_names=False
            )

    def tail(self, n: int) -> Self:
        """Return the last `n` rows; for negative `n`, all rows except the first `abs(n)`."""
        df = self.native
        if n >= 0:
            num_rows = df.num_rows
            return self._with_native(
                df.slice(max(0, num_rows - n)), validate_column_names=False
            )
        else:
            return self._with_native(df.slice(abs(n)), validate_column_names=False)

    def lazy(self, *, backend: Implementation | None = None) -> CompliantLazyFrameAny:
        """Return a lazy frame for `backend`.

        `None` returns `self`. `DUCKDB`, `POLARS` and `DASK` convert the native table
        and wrap it in the matching Narwhals lazy frame.

        Raises:
            AssertionError: `backend` is none of the handled values.
        """
        if backend is None:
            return self
        elif backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            # NOTE(review): `df` is never referenced in Python; presumably
            # `duckdb.table("df")` resolves this local by name, so the variable
            # name and the string must stay in sync — confirm before renaming.
            df = self.native  # noqa: F841
            return DuckDBLazyFrame(
                duckdb.table("df"),
                backend_version=parse_version(duckdb),
                version=self._version,
            )
        elif backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsLazyFrame

            return PolarsLazyFrame(
                cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(),
                backend_version=parse_version(pl),
                version=self._version,
            )
        elif backend is Implementation.DASK:
            import dask  # ignore-banned-import
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                backend_version=parse_version(dask),
                version=self._version,
            )
        raise AssertionError  # pragma: no cover

    def collect(
        self, backend: Implementation | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        """Return an eager frame for `backend`.

        `PYARROW` or `None` re-wraps the same native table; `PANDAS` and `POLARS`
        convert it. `**kwargs` is accepted but not used here.

        Raises:
            AssertionError: `backend` is none of the handled values.
        """
        if backend is Implementation.PYARROW or backend is None:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self.native,
                backend_version=self._backend_version,
                version=self._version,
                validate_column_names=False,
            )

        if backend is Implementation.PANDAS:
            import pandas as pd  # ignore-banned-import

            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.to_pandas(),
                implementation=Implementation.PANDAS,
                backend_version=parse_version(pd),
                version=self._version,
                validate_column_names=False,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                cast("pl.DataFrame", pl.from_arrow(self.native)),
                backend_version=parse_version(pl),
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise AssertionError(msg)  # pragma: no cover

    def clone(self) -> Self:
        """Return a new frame wrapping the same native table object."""
        return self._with_native(self.native, validate_column_names=False)

    def item(self, row: int | None, column: int | str | None) -> Any:
        """Return a single value from the frame.

        With both arguments `None` the frame must have shape (1, 1). Otherwise both
        must be given; `column` may be a name or an integer position.

        Raises:
            ValueError: Both arguments are `None` and the shape is not (1, 1), or
                exactly one of them is `None`.
        """
        from narwhals._arrow.series import maybe_extract_py_scalar

        if row is None and column is None:
            if self.shape != (1, 1):
                msg = (
                    "can only call `.item()` if the dataframe is of shape (1, 1),"
                    " or if explicit row/col values are provided;"
                    f" frame has shape {self.shape!r}"
                )
                raise ValueError(msg)
            return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True)

        elif row is None or column is None:
            msg = "cannot call `.item()` with only one of `row` or `column`"
            raise ValueError(msg)

        # Resolve a column name to its position; integers are used as-is.
        _col = self.columns.index(column) if isinstance(column, str) else column
        return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        """Rename columns according to `mapping` (old name to new name).

        For backend versions from 17 the mapping is passed to `rename_columns`
        directly; for older versions a full list of names is built, keeping names
        that are absent from `mapping`.
        """
        names: dict[str, str] | list[str]
        if self._backend_version >= (17,):
            names = cast("dict[str, str]", mapping)
        else:  # pragma: no cover
            names = [mapping.get(c, c) for c in self.columns]
        return self._with_native(self.native.rename_columns(names))

    def write_parquet(self, file: str | Path | BytesIO) -> None:
        """Write the native table to `file` with `pyarrow.parquet.write_table`."""
        import pyarrow.parquet as pp

        pp.write_table(self.native, file)

    @overload
    def write_csv(self, file: None) -> str: ...

    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...

    def write_csv(self, file: str | Path | BytesIO | None) -> str | None:
        """Write the native table as CSV.

        With `file=None` the CSV is written to an in-memory buffer and returned as a
        decoded string; otherwise it is written to `file` and `None` is returned.
        """
        import pyarrow.csv as pa_csv

        if file is None:
            csv_buffer = pa.BufferOutputStream()
            pa_csv.write_csv(self.native, csv_buffer)
            return csv_buffer.getvalue().to_pybytes().decode()
        pa_csv.write_csv(self.native, file)
        return None

    def is_unique(self) -> ArrowSeries:
        """Return a boolean series flagging rows that have no duplicate.

        Rows are grouped on all columns, recording the minimum and maximum row
        index per group. A row is flagged when its index appears among both the
        minima and the maxima.
        """
        col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        row_index = pa.array(range(len(self)))
        keep_idx = (
            self.native.append_column(col_token, row_index)
            .group_by(self.columns)
            .aggregate([(col_token, "min"), (col_token, "max")])
        )
        native = pa.chunked_array(
            pc.and_(
                pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
                pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
            )
        )
        return ArrowSeries.from_native(native, context=self)

    def unique(
        self: ArrowDataFrame,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> ArrowDataFrame:
        """Drop duplicate rows, comparing only the `subset` columns.

        Arguments:
            subset: Columns that define a duplicate; all columns when falsy.
            keep: For `"any"`, `"first"` or `"last"`, one row per group is kept,
                chosen by the aggregation `ArrowGroupBy._REMAP_UNIQUE[keep]` over
                the row index. Any other value keeps only rows whose `subset`
                values are flagged by `is_unique`.
            maintain_order: Unused; see the note below.
        """
        # The param `maintain_order` is only here for compatibility with the Polars API
        # and has no effect on the output.
        import numpy as np  # ignore-banned-import

        check_column_exists(self.columns, subset)
        subset = list(subset or self.columns)

        if keep in {"any", "first", "last"}:
            from narwhals._arrow.group_by import ArrowGroupBy

            agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
            col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
            # One row index per group of equal `subset` values.
            keep_idx_native = (
                self.native.append_column(col_token, pa.array(np.arange(len(self))))
                .group_by(subset)
                .aggregate([(col_token, agg_func)])
                .column(f"{col_token}_{agg_func}")
            )
            return self._with_native(
                self.native.take(keep_idx_native), validate_column_names=False
            )

        keep_idx = self.simple_select(*subset).is_unique()
        plx = self.__narwhals_namespace__()
        return self.filter(plx._expr._from_series(keep_idx))

    def gather_every(self, n: int, offset: int) -> Self:
        """Return every `n`-th row, starting at row `offset`."""
        return self._with_native(self.native[offset::n], validate_column_names=False)

    def to_arrow(self) -> pa.Table:
        """Return the native table."""
        return self.native

    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self:
        """Return randomly chosen rows.

        Arguments:
            n: Number of rows to draw. When `None` and `fraction` is given,
                `int(num_rows * fraction)` is used.
            fraction: Fraction of rows to draw; only consulted when `n` is `None`.
            with_replacement: Passed as `replace` to the generator's `choice`.
            seed: Seed for `np.random.default_rng`.
        """
        import numpy as np  # ignore-banned-import

        num_rows = len(self)
        if n is None and fraction is not None:
            n = int(num_rows * fraction)
        rng = np.random.default_rng(seed=seed)
        idx = np.arange(0, num_rows)
        # Despite the name, `mask` holds row positions for `take`, not booleans.
        mask = rng.choice(idx, size=n, replace=with_replacement)
        return self._with_native(self.native.take(mask), validate_column_names=False)

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Reshape from wide to long format.

        One table is built per column in `on`: the `index` columns, a string
        column repeating that column's name, and the column's values. The tables
        are then concatenated.

        Arguments:
            on: Columns to unpivot; when `None`, every column not in `index`.
            index: Columns repeated alongside each unpivoted column; none when `None`.
            variable_name: Name of the column holding the source column names.
            value_name: Name of the column holding the values.
        """
        n_rows = len(self)
        index_ = [] if index is None else index
        on_ = [c for c in self.columns if c not in index_] if on is None else on
        # `promote_options` is only passed for backend versions from 14.
        concat = (
            partial(pa.concat_tables, promote_options="permissive")
            if self._backend_version >= (14, 0, 0)
            else pa.concat_tables
        )
        names = [*index_, variable_name, value_name]
        return self._with_native(
            concat(
                [
                    pa.Table.from_arrays(
                        [
                            *(self.native.column(idx_col) for idx_col in index_),
                            cast(
                                "ChunkedArrayAny",
                                pa.array([on_col] * n_rows, pa.string()),
                            ),
                            self.native.column(on_col),
                        ],
                        names=names,
                    )
                    for on_col in on_
                ]
            )
        )
        # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
        # upcast numeric to non-numeric (e.g. string) datatypes

    pivot = not_implemented()