from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import Literal
from typing import Sequence
from typing import overload

from narwhals._arrow.utils import broadcast_series
from narwhals._arrow.utils import convert_str_slice_to_int_slice
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._arrow.utils import select_rows
from narwhals._arrow.utils import validate_dataframe_comparand
from narwhals._expression_parsing import evaluate_into_exprs
from narwhals.dependencies import is_numpy_array
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import generate_temporary_column_name
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
    from types import ModuleType

    import numpy as np
    import pyarrow as pa
    from typing_extensions import Self

    from narwhals._arrow.group_by import ArrowGroupBy
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import IntoArrowExpr
    from narwhals.dtypes import DType
    from narwhals.typing import DTypes


class ArrowDataFrame:
    """PyArrow-backed dataframe wrapping a ``pa.Table``.

    Implements both the ``__narwhals_dataframe__`` and ``__narwhals_lazyframe__``
    protocols (it is eager, so ``lazy``/``collect`` are no-ops returning
    the same data). All operations return new instances via
    ``_from_native_frame``; the wrapped table is never mutated in place.
    """

    # --- not in the spec ---
    def __init__(
        self,
        native_dataframe: pa.Table,
        *,
        backend_version: tuple[int, ...],
        dtypes: DTypes,
    ) -> None:
        # The wrapped pyarrow Table.
        self._native_frame = native_dataframe
        self._implementation = Implementation.PYARROW
        # pyarrow version tuple; used for feature gating (e.g. in `unpivot`).
        self._backend_version = backend_version
        # Narwhals dtype namespace used when converting Arrow types.
        self._dtypes = dtypes

    def __narwhals_namespace__(self) -> ArrowNamespace:
        """Return the expression namespace for this backend."""
        # Imported locally to avoid a circular import at module load time.
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

    def __native_namespace__(self: Self) -> ModuleType:
        """Return the native module (``pyarrow``) backing this frame."""
        if self._implementation is Implementation.PYARROW:
            return self._implementation.to_native_namespace()

        msg = f"Expected pyarrow, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_dataframe__(self) -> Self:
        return self

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _from_native_frame(self, df: Any) -> Self:
        """Wrap a native table in a new instance, preserving backend metadata."""
        return self.__class__(
            df, backend_version=self._backend_version, dtypes=self._dtypes
        )

    @property
    def shape(self) -> tuple[int, int]:
        return self._native_frame.shape  # type: ignore[no-any-return]

    def __len__(self) -> int:
        return len(self._native_frame)

    def row(self, index: int) -> tuple[Any, ...]:
        """Return row ``index`` as a tuple of (Arrow scalar) values, one per column."""
        return tuple(col[index] for col in self._native_frame)

    @overload
    def rows(
        self,
        *,
        named: Literal[True],
    ) -> list[dict[str, Any]]: ...

    @overload
    def rows(
        self,
        *,
        named: Literal[False] = False,
    ) -> list[tuple[Any, ...]]: ...

    @overload
    def rows(
        self,
        *,
        named: bool,
    ) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...

    def rows(
        self, *, named: bool = False
    ) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
        """Materialize all rows as tuples, or as dicts when ``named=True``."""
        if not named:
            return list(self.iter_rows(named=False))  # type: ignore[return-value]
        return self._native_frame.to_pylist()  # type: ignore[no-any-return]

    def iter_rows(
        self,
        *,
        named: bool = False,
        buffer_size: int = 512,
    ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
        """Iterate over rows, converting at most ``buffer_size`` rows to Python
        objects at a time to bound peak memory use."""
        df = self._native_frame
        num_rows = df.num_rows

        if not named:
            for i in range(0, num_rows, buffer_size):
                # to_pydict gives column-major data; zip(*...) transposes the
                # chunk into row tuples.
                rows = df[i : i + buffer_size].to_pydict().values()
                yield from zip(*rows)
        else:
            for i in range(0, num_rows, buffer_size):
                yield from df[i : i + buffer_size].to_pylist()

    def get_column(self, name: str) -> ArrowSeries:
        """Return column ``name`` as an ArrowSeries; raises TypeError for non-str."""
        from narwhals._arrow.series import ArrowSeries

        if not isinstance(name, str):
            msg = f"Expected str, got: {type(name)}"
            raise TypeError(msg)
        return ArrowSeries(
            self._native_frame[name],
            name=name,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
        return self._native_frame.__array__(dtype, copy=copy)

    @overload
    def __getitem__(self, item: tuple[Sequence[int], str | int]) -> ArrowSeries: ...  # type: ignore[overload-overlap]

    @overload
    def __getitem__(self, item: Sequence[int]) -> ArrowDataFrame: ...

    @overload
    def __getitem__(self, item: str) -> ArrowSeries: ...

    @overload
    def __getitem__(self, item: slice) -> ArrowDataFrame: ...

    @overload
    def __getitem__(self, item: tuple[slice, slice]) -> ArrowDataFrame: ...

    def __getitem__(
        self,
        item: (
            str
            | slice
            | Sequence[int]
            | Sequence[str]
            | tuple[Sequence[int], str | int]
            | tuple[slice, str | int]
            | tuple[slice, slice]
        ),
    ) -> ArrowSeries | ArrowDataFrame:
        """Polars-style indexing.

        Supported forms (dispatched in order):
        - ``df["a"]`` -> single column as ArrowSeries;
        - ``df[rows, cols_seq]`` -> sub-frame (rows then column selection);
        - ``df[rows, col]`` with ``col`` a str/int/slice -> series or sub-frame;
        - ``df[slice]`` -> row slice (str bounds select columns instead);
        - ``df[seq_of_str]`` / ``df[seq_of_int]`` / 1-D numpy array -> column
          selection or row take.
        """
        if isinstance(item, tuple):
            # Normalize any sequence elements to plain lists so downstream
            # pyarrow calls accept them.
            item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item)  # type: ignore[assignment]

        if isinstance(item, str):
            from narwhals._arrow.series import ArrowSeries

            return ArrowSeries(
                self._native_frame[item],
                name=item,
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )
        elif (
            isinstance(item, tuple)
            and len(item) == 2
            and is_sequence_but_not_str(item[1])
        ):
            if len(item[1]) == 0:
                # Return empty dataframe
                return self._from_native_frame(self._native_frame.slice(0, 0).select([]))
            selected_rows = select_rows(self._native_frame, item[0])
            return self._from_native_frame(selected_rows.select(item[1]))

        elif isinstance(item, tuple) and len(item) == 2:
            if isinstance(item[1], slice):
                columns = self.columns
                if item[1] == slice(None):
                    # All columns requested: only row selection is needed.
                    if isinstance(item[0], Sequence) and len(item[0]) == 0:
                        return self._from_native_frame(self._native_frame.slice(0, 0))
                    return self._from_native_frame(self._native_frame.take(item[0]))
                if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
                    # Column slice by name, e.g. df[rows, "a":"c"].
                    start, stop, step = convert_str_slice_to_int_slice(item[1], columns)
                    return self._from_native_frame(
                        self._native_frame.take(item[0]).select(columns[start:stop:step])
                    )
                if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
                    # Column slice by position.
                    return self._from_native_frame(
                        self._native_frame.take(item[0]).select(
                            columns[item[1].start : item[1].stop : item[1].step]
                        )
                    )
                msg = f"Expected slice of integers or strings, got: {type(item[1])}"  # pragma: no cover
                raise TypeError(msg)  # pragma: no cover
            from narwhals._arrow.series import ArrowSeries

            # PyArrow columns are always strings
            col_name = item[1] if isinstance(item[1], str) else self.columns[item[1]]
            if isinstance(item[0], str):  # pragma: no cover
                msg = "Can not slice with tuple with the first element as a str"
                raise TypeError(msg)
            if (isinstance(item[0], slice)) and (item[0] == slice(None)):
                # df[:, col] — whole column, no row filtering needed.
                return ArrowSeries(
                    self._native_frame[col_name],
                    name=col_name,
                    backend_version=self._backend_version,
                    dtypes=self._dtypes,
                )
            selected_rows = select_rows(self._native_frame, item[0])
            return ArrowSeries(
                selected_rows[col_name],
                name=col_name,
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )

        elif isinstance(item, slice):
            if item.step is not None and item.step != 1:
                msg = "Slicing with step is not supported on PyArrow tables"
                raise NotImplementedError(msg)
            columns = self.columns
            if isinstance(item.start, str) or isinstance(item.stop, str):
                # String-bounded slice selects columns, not rows.
                start, stop, step = convert_str_slice_to_int_slice(item, columns)
                return self._from_native_frame(
                    self._native_frame.select(columns[start:stop:step])
                )
            start = item.start or 0
            stop = item.stop if item.stop is not None else len(self._native_frame)
            return self._from_native_frame(
                self._native_frame.slice(start, stop - start),
            )

        elif isinstance(item, Sequence) or (is_numpy_array(item) and item.ndim == 1):
            if (
                isinstance(item, Sequence)
                and all(isinstance(x, str) for x in item)
                and len(item) > 0
            ):
                # Sequence of column names -> column selection.
                return self._from_native_frame(self._native_frame.select(item))
            if isinstance(item, Sequence) and len(item) == 0:
                return self._from_native_frame(self._native_frame.slice(0, 0))
            # Otherwise treat as row indices.
            return self._from_native_frame(self._native_frame.take(item))

        else:  # pragma: no cover
            msg = f"Expected str or slice, got: {type(item)}"
            raise TypeError(msg)

    @property
    def schema(self) -> dict[str, DType]:
        """Mapping of column name to narwhals dtype."""
        schema = self._native_frame.schema
        return {
            name: native_to_narwhals_dtype(dtype, self._dtypes)
            for name, dtype in zip(schema.names, schema.types)
        }

    def collect_schema(self) -> dict[str, DType]:
        # Eager backend: the schema is already known, so this is just `schema`.
        return self.schema

    @property
    def columns(self) -> list[str]:
        return self._native_frame.schema.names  # type: ignore[no-any-return]

    def select(
        self,
        *exprs: IntoArrowExpr,
        **named_exprs: IntoArrowExpr,
    ) -> Self:
        """Evaluate expressions and build a new frame containing only their results."""
        import pyarrow as pa  # ignore-banned-import()

        new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
        if not new_series:
            # return empty dataframe, like Polars does
            return self._from_native_frame(self._native_frame.__class__.from_arrays([]))
        names = [s.name for s in new_series]
        # broadcast_series aligns lengths (e.g. scalars vs full columns).
        df = pa.Table.from_arrays(
            broadcast_series(new_series),
            names=names,
        )
        return self._from_native_frame(df)

    def with_columns(
        self,
        *exprs: IntoArrowExpr,
        **named_exprs: IntoArrowExpr,
    ) -> Self:
        """Evaluate expressions and add/replace columns, keeping existing order.

        Replaced columns stay in their original position; genuinely new
        columns are appended at the end.
        """
        new_columns = evaluate_into_exprs(self, *exprs, **named_exprs)
        new_column_name_to_new_column_map = {s.name: s for s in new_columns}
        to_concat = []
        output_names = []
        # Make sure to preserve column order
        length = len(self)
        for name in self.columns:
            if name in new_column_name_to_new_column_map:
                # Replacement: validate length/broadcast, then consume from the map.
                to_concat.append(
                    validate_dataframe_comparand(
                        length=length,
                        other=new_column_name_to_new_column_map.pop(name),
                        backend_version=self._backend_version,
                    )
                )
            else:
                to_concat.append(self._native_frame[name])
            output_names.append(name)
        # Whatever remains in the map is a brand-new column; append in order.
        for s in new_column_name_to_new_column_map:
            to_concat.append(
                validate_dataframe_comparand(
                    length=length,
                    other=new_column_name_to_new_column_map[s],
                    backend_version=self._backend_version,
                )
            )
            output_names.append(s)
        df = self._native_frame.__class__.from_arrays(to_concat, names=output_names)
        return self._from_native_frame(df)

    def group_by(self, *keys: str, drop_null_keys: bool) -> ArrowGroupBy:
        from narwhals._arrow.group_by import ArrowGroupBy

        return ArrowGroupBy(self, list(keys), drop_null_keys=drop_null_keys)

    def join(
        self,
        other: Self,
        *,
        how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner",
        left_on: str | list[str] | None,
        right_on: str | list[str] | None,
        suffix: str,
    ) -> Self:
        """Join with ``other``.

        ``how`` values are translated to pyarrow join types; ``cross`` is
        emulated with an inner join on a temporary constant key column
        (pyarrow's Table.join has no cross-join mode).
        """
        # narwhals `how` -> pyarrow `join_type`.
        how_to_join_map = {
            "anti": "left anti",
            "semi": "left semi",
            "inner": "inner",
            "left": "left outer",
        }

        if how == "cross":
            plx = self.__narwhals_namespace__()
            # Temporary key name guaranteed not to clash with either frame.
            key_token = generate_temporary_column_name(
                n_bytes=8, columns=[*self.columns, *other.columns]
            )

            # Add the same constant key to both sides, inner-join on it
            # (every row matches every row), then drop the helper column.
            return self._from_native_frame(
                self.with_columns(**{key_token: plx.lit(0, None)})
                ._native_frame.join(
                    other.with_columns(**{key_token: plx.lit(0, None)})._native_frame,
                    keys=key_token,
                    right_keys=key_token,
                    join_type="inner",
                    right_suffix=suffix,
                )
                .drop([key_token]),
            )

        return self._from_native_frame(
            self._native_frame.join(
                other._native_frame,
                keys=left_on,
                right_keys=right_on,
                join_type=how_to_join_map[how],
                right_suffix=suffix,
            ),
        )

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str | None = None,
        right_on: str | None = None,
        on: str | None = None,
        by_left: str | list[str] | None = None,
        by_right: str | list[str] | None = None,
        by: str | list[str] | None = None,
        strategy: Literal["backward", "forward", "nearest"] = "backward",
    ) -> Self:
        msg = "join_asof is not yet supported on PyArrow tables"
        raise NotImplementedError(msg)

    def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
        """Drop columns; `strict` controls whether unknown names raise (delegated
        to `parse_columns_to_drop`)."""
        to_drop = parse_columns_to_drop(
            compliant_frame=self, columns=columns, strict=strict
        )
        return self._from_native_frame(self._native_frame.drop(to_drop))

    def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
        """Drop rows with nulls; with a `subset`, only nulls in those columns count."""
        if subset is None:
            return self._from_native_frame(self._native_frame.drop_null())
        subset = [subset] if isinstance(subset, str) else subset
        plx = self.__narwhals_namespace__()
        # Keep rows where none of the subset columns is null.
        return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))

    def sort(
        self,
        by: str | Iterable[str],
        *more_by: str,
        descending: bool | Sequence[bool],
        nulls_last: bool,
    ) -> Self:
        """Sort by one or more keys.

        ``descending`` may be a single bool applied to every key or a
        per-key sequence zipped against the flattened key list.
        """
        flat_keys = flatten([*flatten([by]), *more_by])
        df = self._native_frame
        if isinstance(descending, bool):
            order = "descending" if descending else "ascending"
            sorting = [(key, order) for key in flat_keys]
        else:
            sorting = [
                (key, "descending" if is_descending else "ascending")
                for key, is_descending in zip(flat_keys, descending)
            ]

        null_placement = "at_end" if nulls_last else "at_start"

        return self._from_native_frame(df.sort_by(sorting, null_placement=null_placement))

    def to_pandas(self) -> Any:
        return self._native_frame.to_pandas()

    def to_numpy(self) -> Any:
        import numpy as np  # ignore-banned-import

        # Column-wise conversion stacked into a 2-D array (rows x columns).
        return np.column_stack([col.to_numpy() for col in self._native_frame.columns])

    def to_dict(self, *, as_series: bool) -> Any:
        """Return ``{name: column}`` with columns as ArrowSeries or Python lists."""
        df = self._native_frame

        names_and_values = zip(df.column_names, df.columns)
        if as_series:
            from narwhals._arrow.series import ArrowSeries

            return {
                name: ArrowSeries(
                    col,
                    name=name,
                    backend_version=self._backend_version,
                    dtypes=self._dtypes,
                )
                for name, col in names_and_values
            }
        else:
            return {name: col.to_pylist() for name, col in names_and_values}

    def with_row_index(self, name: str) -> Self:
        """Append a 0-based row-index column called ``name``."""
        import pyarrow as pa  # ignore-banned-import()

        df = self._native_frame

        row_indices = pa.array(range(df.num_rows))
        return self._from_native_frame(df.append_column(name, row_indices))

    def filter(
        self,
        *predicates: IntoArrowExpr,
    ) -> Self:
        """Keep rows where all predicates hold.

        A single plain list of bools is used directly as a mask; otherwise
        predicates are combined with ``all_horizontal`` and evaluated.
        """
        if (
            len(predicates) == 1
            and isinstance(predicates[0], list)
            and all(isinstance(x, bool) for x in predicates[0])
        ):
            mask = predicates[0]
        else:
            plx = self.__narwhals_namespace__()
            expr = plx.all_horizontal(*predicates)
            # Safety: all_horizontal's expression only returns a single column.
            mask = expr._call(self)[0]._native_series
        return self._from_native_frame(self._native_frame.filter(mask))

    def null_count(self) -> Self:
        """Return a one-row frame with the null count of each column."""
        import pyarrow as pa  # ignore-banned-import()

        df = self._native_frame
        names_and_values = zip(df.column_names, df.columns)

        return self._from_native_frame(
            pa.table({name: [col.null_count] for name, col in names_and_values})
        )

    def head(self, n: int) -> Self:
        """First ``n`` rows; negative ``n`` means all but the last ``|n|`` rows."""
        df = self._native_frame
        if n >= 0:
            return self._from_native_frame(df.slice(0, n))
        else:
            num_rows = df.num_rows
            return self._from_native_frame(df.slice(0, max(0, num_rows + n)))

    def tail(self, n: int) -> Self:
        """Last ``n`` rows; negative ``n`` means all but the first ``|n|`` rows."""
        df = self._native_frame
        if n >= 0:
            num_rows = df.num_rows
            return self._from_native_frame(df.slice(max(0, num_rows - n)))
        else:
            return self._from_native_frame(df.slice(abs(n)))

    def lazy(self) -> Self:
        # Eager backend: `lazy` is a no-op.
        return self

    def collect(self) -> ArrowDataFrame:
        # Eager backend: data is already materialized; re-wrap as-is.
        return ArrowDataFrame(
            self._native_frame,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def clone(self) -> Self:
        msg = "clone is not yet supported on PyArrow tables"
        raise NotImplementedError(msg)

    def is_empty(self: Self) -> bool:
        return self.shape[0] == 0

    def item(self: Self, row: int | None = None, column: int | str | None = None) -> Any:
        """Return a single element.

        With no arguments the frame must be exactly 1x1; otherwise both
        ``row`` and ``column`` (index or name) must be given.

        Raises:
            ValueError: if the shape is not (1, 1) when called without
                arguments, or if only one of ``row``/``column`` is given.
        """
        if row is None and column is None:
            if self.shape != (1, 1):
                msg = (
                    "can only call `.item()` if the dataframe is of shape (1, 1),"
                    " or if explicit row/col values are provided;"
                    f" frame has shape {self.shape!r}"
                )
                raise ValueError(msg)
            return self._native_frame[0][0]

        elif row is None or column is None:
            msg = "cannot call `.item()` with only one of `row` or `column`"
            raise ValueError(msg)

        _col = self.columns.index(column) if isinstance(column, str) else column
        return self._native_frame[_col][row]

    def rename(self, mapping: dict[str, str]) -> Self:
        """Rename columns per ``mapping``; names not in the mapping are kept."""
        df = self._native_frame
        new_cols = [mapping.get(c, c) for c in df.column_names]
        return self._from_native_frame(df.rename_columns(new_cols))

    def write_parquet(self, file: Any) -> Any:
        import pyarrow.parquet as pp  # ignore-banned-import

        pp.write_table(self._native_frame, file)

    def write_csv(self, file: Any) -> Any:
        """Write CSV to ``file``; with ``file=None`` return the CSV as a string."""
        import pyarrow as pa  # ignore-banned-import
        import pyarrow.csv as pa_csv  # ignore-banned-import

        pa_table = self._native_frame
        if file is None:
            csv_buffer = pa.BufferOutputStream()
            pa_csv.write_csv(pa_table, csv_buffer)
            return csv_buffer.getvalue().to_pybytes().decode()
        return pa_csv.write_csv(pa_table, file)

    def is_duplicated(self: Self) -> ArrowSeries:
        """Boolean series: True for each row that occurs more than once.

        Implemented by counting occurrences per distinct row (group-by over
        all columns on a temporary row-index column) and joining the counts
        back onto the original rows.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        from narwhals._arrow.series import ArrowSeries

        df = self._native_frame

        columns = self.columns
        col_token = generate_temporary_column_name(n_bytes=8, columns=columns)
        row_count = (
            df.append_column(col_token, pa.array(np.arange(len(self))))
            .group_by(columns)
            .aggregate([(col_token, "count")])
        )
        # pyarrow names the aggregate output "<col>_count".
        is_duplicated = pc.greater(
            df.join(
                row_count, keys=columns, right_keys=columns, join_type="inner"
            ).column(f"{col_token}_count"),
            1,
        )
        return ArrowSeries(
            is_duplicated,
            name="",
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def is_unique(self: Self) -> ArrowSeries:
        """Boolean series: the element-wise negation of `is_duplicated`."""
        import pyarrow.compute as pc  # ignore-banned-import()

        from narwhals._arrow.series import ArrowSeries

        is_duplicated = self.is_duplicated()._native_series

        return ArrowSeries(
            pc.invert(is_duplicated),
            name="",
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def unique(
        self: Self,
        subset: str | list[str] | None,
        *,
        keep: Literal["any", "first", "last", "none"] = "any",
        maintain_order: bool = False,
    ) -> Self:
        """Drop duplicate rows, judged on ``subset`` columns (all by default).

        For ``keep`` in {"any", "first", "last"}, a temporary row-index
        column is aggregated per group (min for first/any, max for last) to
        pick which row of each group to keep. For ``keep="none"`` every row
        whose key occurs more than once is dropped.

        Note:
            The param `maintain_order` is only here for compatibility with
            the polars API and has no effect on the output.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        df = self._native_frame

        if isinstance(subset, str):
            subset = [subset]
        subset = subset or self.columns

        if keep in {"any", "first", "last"}:
            agg_func_map = {"any": "min", "first": "min", "last": "max"}

            agg_func = agg_func_map[keep]
            col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
            keep_idx = (
                df.append_column(col_token, pa.array(np.arange(len(self))))
                .group_by(subset)
                .aggregate([(col_token, agg_func)])
                .column(f"{col_token}_{agg_func}")
            )

            return self._from_native_frame(pc.take(df, keep_idx))

        # keep == "none": retain only rows whose subset-key is globally unique.
        keep_idx = self.select(*subset).is_unique()
        return self.filter(keep_idx)

    def gather_every(self: Self, n: int, offset: int = 0) -> Self:
        """Take every ``n``-th row starting at ``offset``."""
        return self._from_native_frame(self._native_frame[offset::n])

    def to_arrow(self: Self) -> Any:
        return self._native_frame

    def sample(
        self: Self,
        n: int | None = None,
        *,
        fraction: float | None = None,
        with_replacement: bool = False,
        seed: int | None = None,
    ) -> Self:
        """Randomly sample ``n`` rows (or ``fraction`` of the rows).

        Uses a numpy RNG seeded with ``seed`` to pick row indices, so the
        result is reproducible for a fixed seed.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow.compute as pc  # ignore-banned-import()

        frame = self._native_frame
        num_rows = len(self)
        if n is None and fraction is not None:
            n = int(num_rows * fraction)

        rng = np.random.default_rng(seed=seed)
        idx = np.arange(0, num_rows)
        mask = rng.choice(idx, size=n, replace=with_replacement)

        return self._from_native_frame(pc.take(frame, mask))

    def unpivot(
        self: Self,
        on: str | list[str] | None,
        index: str | list[str] | None,
        variable_name: str | None,
        value_name: str | None,
    ) -> Self:
        """Melt from wide to long format.

        For each column in ``on`` (default: all non-index columns), emit the
        ``index`` columns, a constant ``variable_name`` column holding the
        source column's name, and its values under ``value_name``; the
        per-column tables are then concatenated vertically.
        """
        import pyarrow as pa  # ignore-banned-import

        native_frame = self._native_frame
        variable_name = variable_name if variable_name is not None else "variable"
        value_name = value_name if value_name is not None else "value"

        # Normalize `index` and `on` to lists; `on=None` means every
        # column not used as an index.
        index_: list[str] = (
            [] if index is None else [index] if isinstance(index, str) else index
        )
        on_: list[str] = (
            [c for c in self.columns if c not in index_]
            if on is None
            else [on]
            if isinstance(on, str)
            else on
        )

        n_rows = len(self)

        # promote_options lets concat_tables unify differing value dtypes;
        # only available from pyarrow 14.
        promote_kwargs = (
            {"promote_options": "permissive"}
            if self._backend_version >= (14, 0, 0)
            else {}
        )
        return self._from_native_frame(
            pa.concat_tables(
                [
                    pa.Table.from_arrays(
                        [
                            *[native_frame.column(idx_col) for idx_col in index_],
                            pa.array([on_col] * n_rows, pa.string()),
                            native_frame.column(on_col),
                        ],
                        names=[*index_, variable_name, value_name],
                    )
                    for on_col in on_
                ],
                **promote_kwargs,
            )
        )
        # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
        # upcast numeric to non-numeric (e.g. string) datatypes