731 lines
25 KiB
Python
Executable File
731 lines
25 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Iterable
|
|
from typing import Iterator
|
|
from typing import Literal
|
|
from typing import Sequence
|
|
from typing import overload
|
|
|
|
from narwhals._arrow.utils import broadcast_series
|
|
from narwhals._arrow.utils import convert_str_slice_to_int_slice
|
|
from narwhals._arrow.utils import native_to_narwhals_dtype
|
|
from narwhals._arrow.utils import select_rows
|
|
from narwhals._arrow.utils import validate_dataframe_comparand
|
|
from narwhals._expression_parsing import evaluate_into_exprs
|
|
from narwhals.dependencies import is_numpy_array
|
|
from narwhals.utils import Implementation
|
|
from narwhals.utils import flatten
|
|
from narwhals.utils import generate_temporary_column_name
|
|
from narwhals.utils import is_sequence_but_not_str
|
|
from narwhals.utils import parse_columns_to_drop
|
|
|
|
if TYPE_CHECKING:
|
|
from types import ModuleType
|
|
|
|
import numpy as np
|
|
import pyarrow as pa
|
|
from typing_extensions import Self
|
|
|
|
from narwhals._arrow.group_by import ArrowGroupBy
|
|
from narwhals._arrow.namespace import ArrowNamespace
|
|
from narwhals._arrow.series import ArrowSeries
|
|
from narwhals._arrow.typing import IntoArrowExpr
|
|
from narwhals.dtypes import DType
|
|
from narwhals.typing import DTypes
|
|
|
|
|
|
class ArrowDataFrame:
|
|
# --- not in the spec ---
|
|
    def __init__(
        self,
        native_dataframe: pa.Table,
        *,
        backend_version: tuple[int, ...],
        dtypes: DTypes,
    ) -> None:
        """Wrap a native ``pyarrow.Table`` in the narwhals-compliant interface.

        Arguments:
            native_dataframe: The pyarrow table being wrapped.
            backend_version: Installed pyarrow version as a tuple of ints,
                used for version-gated behavior (e.g. ``unpivot``).
            dtypes: Namespace of narwhals dtype classes used when mapping
                pyarrow types to narwhals types.
        """
        # The wrapped pyarrow.Table; all operations delegate to it.
        self._native_frame = native_dataframe
        # Marks this compliant object as pyarrow-backed.
        self._implementation = Implementation.PYARROW
        self._backend_version = backend_version
        self._dtypes = dtypes
|
|
|
|
def __narwhals_namespace__(self) -> ArrowNamespace:
|
|
from narwhals._arrow.namespace import ArrowNamespace
|
|
|
|
return ArrowNamespace(backend_version=self._backend_version, dtypes=self._dtypes)
|
|
|
|
def __native_namespace__(self: Self) -> ModuleType:
|
|
if self._implementation is Implementation.PYARROW:
|
|
return self._implementation.to_native_namespace()
|
|
|
|
msg = f"Expected pyarrow, got: {type(self._implementation)}" # pragma: no cover
|
|
raise AssertionError(msg)
|
|
|
|
def __narwhals_dataframe__(self) -> Self:
|
|
return self
|
|
|
|
def __narwhals_lazyframe__(self) -> Self:
|
|
return self
|
|
|
|
def _from_native_frame(self, df: Any) -> Self:
|
|
return self.__class__(
|
|
df, backend_version=self._backend_version, dtypes=self._dtypes
|
|
)
|
|
|
|
    @property
    def shape(self) -> tuple[int, int]:
        # (number of rows, number of columns) of the wrapped pyarrow table.
        return self._native_frame.shape  # type: ignore[no-any-return]
|
|
|
|
def __len__(self) -> int:
|
|
return len(self._native_frame)
|
|
|
|
def row(self, index: int) -> tuple[Any, ...]:
|
|
return tuple(col[index] for col in self._native_frame)
|
|
|
|
@overload
|
|
def rows(
|
|
self,
|
|
*,
|
|
named: Literal[True],
|
|
) -> list[dict[str, Any]]: ...
|
|
|
|
@overload
|
|
def rows(
|
|
self,
|
|
*,
|
|
named: Literal[False] = False,
|
|
) -> list[tuple[Any, ...]]: ...
|
|
@overload
|
|
def rows(
|
|
self,
|
|
*,
|
|
named: bool,
|
|
) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...
|
|
|
|
def rows(
|
|
self, *, named: bool = False
|
|
) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
|
|
if not named:
|
|
return list(self.iter_rows(named=False)) # type: ignore[return-value]
|
|
return self._native_frame.to_pylist() # type: ignore[no-any-return]
|
|
|
|
def iter_rows(
|
|
self,
|
|
*,
|
|
named: bool = False,
|
|
buffer_size: int = 512,
|
|
) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
|
|
df = self._native_frame
|
|
num_rows = df.num_rows
|
|
|
|
if not named:
|
|
for i in range(0, num_rows, buffer_size):
|
|
rows = df[i : i + buffer_size].to_pydict().values()
|
|
yield from zip(*rows)
|
|
else:
|
|
for i in range(0, num_rows, buffer_size):
|
|
yield from df[i : i + buffer_size].to_pylist()
|
|
|
|
def get_column(self, name: str) -> ArrowSeries:
|
|
from narwhals._arrow.series import ArrowSeries
|
|
|
|
if not isinstance(name, str):
|
|
msg = f"Expected str, got: {type(name)}"
|
|
raise TypeError(msg)
|
|
|
|
return ArrowSeries(
|
|
self._native_frame[name],
|
|
name=name,
|
|
backend_version=self._backend_version,
|
|
dtypes=self._dtypes,
|
|
)
|
|
|
|
    def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
        # Delegate numpy conversion to pyarrow's own __array__ implementation.
        return self._native_frame.__array__(dtype, copy=copy)
|
|
|
|
    @overload
    def __getitem__(self, item: tuple[Sequence[int], str | int]) -> ArrowSeries: ...  # type: ignore[overload-overlap]

    @overload
    def __getitem__(self, item: Sequence[int]) -> ArrowDataFrame: ...

    @overload
    def __getitem__(self, item: str) -> ArrowSeries: ...

    @overload
    def __getitem__(self, item: slice) -> ArrowDataFrame: ...

    @overload
    def __getitem__(self, item: tuple[slice, slice]) -> ArrowDataFrame: ...

    def __getitem__(
        self,
        item: (
            str
            | slice
            | Sequence[int]
            | Sequence[str]
            | tuple[Sequence[int], str | int]
            | tuple[slice, str | int]
            | tuple[slice, slice]
        ),
    ) -> ArrowSeries | ArrowDataFrame:
        """Polars-style indexing.

        Supports a single column name, a row slice/sequence, a column-name
        slice, or a two-element ``(rows, columns)`` key; single-column keys
        return an ``ArrowSeries``, everything else an ``ArrowDataFrame``.
        """
        # Normalise sequence members of a 2-tuple key into concrete lists so
        # the checks below can use len() and indexing safely.
        if isinstance(item, tuple):
            item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item)  # type: ignore[assignment]

        # Single column name -> series.
        if isinstance(item, str):
            from narwhals._arrow.series import ArrowSeries

            return ArrowSeries(
                self._native_frame[item],
                name=item,
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )
        # (rows, sequence-of-columns) -> dataframe.
        elif (
            isinstance(item, tuple)
            and len(item) == 2
            and is_sequence_but_not_str(item[1])
        ):
            if len(item[1]) == 0:
                # Return empty dataframe
                return self._from_native_frame(self._native_frame.slice(0, 0).select([]))
            selected_rows = select_rows(self._native_frame, item[0])
            return self._from_native_frame(selected_rows.select(item[1]))

        elif isinstance(item, tuple) and len(item) == 2:
            # (rows, column-slice) -> dataframe.
            if isinstance(item[1], slice):
                columns = self.columns
                if item[1] == slice(None):
                    # All columns: only rows need selecting.
                    if isinstance(item[0], Sequence) and len(item[0]) == 0:
                        return self._from_native_frame(self._native_frame.slice(0, 0))
                    return self._from_native_frame(self._native_frame.take(item[0]))
                if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
                    # Column slice by name, e.g. df[rows, "a":"c"].
                    start, stop, step = convert_str_slice_to_int_slice(item[1], columns)
                    return self._from_native_frame(
                        self._native_frame.take(item[0]).select(columns[start:stop:step])
                    )
                if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
                    # Column slice by position.
                    return self._from_native_frame(
                        self._native_frame.take(item[0]).select(
                            columns[item[1].start : item[1].stop : item[1].step]
                        )
                    )
                msg = f"Expected slice of integers or strings, got: {type(item[1])}"  # pragma: no cover
                raise TypeError(msg)  # pragma: no cover
            from narwhals._arrow.series import ArrowSeries

            # (rows, single column) -> series.
            # PyArrow columns are always strings
            col_name = item[1] if isinstance(item[1], str) else self.columns[item[1]]
            if isinstance(item[0], str):  # pragma: no cover
                msg = "Can not slice with tuple with the first element as a str"
                raise TypeError(msg)
            if (isinstance(item[0], slice)) and (item[0] == slice(None)):
                # Whole column requested: skip the row-selection step.
                return ArrowSeries(
                    self._native_frame[col_name],
                    name=col_name,
                    backend_version=self._backend_version,
                    dtypes=self._dtypes,
                )
            selected_rows = select_rows(self._native_frame, item[0])
            return ArrowSeries(
                selected_rows[col_name],
                name=col_name,
                backend_version=self._backend_version,
                dtypes=self._dtypes,
            )

        # Bare slice: row slice by position, or column selection when the
        # bounds are strings.
        elif isinstance(item, slice):
            if item.step is not None and item.step != 1:
                msg = "Slicing with step is not supported on PyArrow tables"
                raise NotImplementedError(msg)
            columns = self.columns
            if isinstance(item.start, str) or isinstance(item.stop, str):
                start, stop, step = convert_str_slice_to_int_slice(item, columns)
                return self._from_native_frame(
                    self._native_frame.select(columns[start:stop:step])
                )
            start = item.start or 0
            stop = item.stop if item.stop is not None else len(self._native_frame)
            return self._from_native_frame(
                self._native_frame.slice(start, stop - start),
            )

        # Sequence of names (column selection) or indices (row selection),
        # or a 1-D numpy array of row indices.
        elif isinstance(item, Sequence) or (is_numpy_array(item) and item.ndim == 1):
            if (
                isinstance(item, Sequence)
                and all(isinstance(x, str) for x in item)
                and len(item) > 0
            ):
                return self._from_native_frame(self._native_frame.select(item))
            if isinstance(item, Sequence) and len(item) == 0:
                return self._from_native_frame(self._native_frame.slice(0, 0))
            return self._from_native_frame(self._native_frame.take(item))

        else:  # pragma: no cover
            msg = f"Expected str or slice, got: {type(item)}"
            raise TypeError(msg)
|
|
|
|
@property
|
|
def schema(self) -> dict[str, DType]:
|
|
schema = self._native_frame.schema
|
|
return {
|
|
name: native_to_narwhals_dtype(dtype, self._dtypes)
|
|
for name, dtype in zip(schema.names, schema.types)
|
|
}
|
|
|
|
def collect_schema(self) -> dict[str, DType]:
|
|
return self.schema
|
|
|
|
    @property
    def columns(self) -> list[str]:
        # Column names, in table order, taken from the pyarrow schema.
        return self._native_frame.schema.names  # type: ignore[no-any-return]
|
|
|
|
def select(
|
|
self,
|
|
*exprs: IntoArrowExpr,
|
|
**named_exprs: IntoArrowExpr,
|
|
) -> Self:
|
|
import pyarrow as pa # ignore-banned-import()
|
|
|
|
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
|
|
if not new_series:
|
|
# return empty dataframe, like Polars does
|
|
return self._from_native_frame(self._native_frame.__class__.from_arrays([]))
|
|
names = [s.name for s in new_series]
|
|
df = pa.Table.from_arrays(
|
|
broadcast_series(new_series),
|
|
names=names,
|
|
)
|
|
return self._from_native_frame(df)
|
|
|
|
    def with_columns(
        self,
        *exprs: IntoArrowExpr,
        **named_exprs: IntoArrowExpr,
    ) -> Self:
        """Evaluate expressions and add/replace them as columns.

        Columns whose names already exist replace the originals in place
        (preserving column order); brand-new columns are appended at the end.
        """
        new_columns = evaluate_into_exprs(self, *exprs, **named_exprs)
        new_column_name_to_new_column_map = {s.name: s for s in new_columns}
        to_concat = []
        output_names = []
        # Make sure to preserve column order
        length = len(self)
        for name in self.columns:
            if name in new_column_name_to_new_column_map:
                # Replacement: validate/broadcast the new column to frame
                # length, and pop it so only truly-new columns remain below.
                to_concat.append(
                    validate_dataframe_comparand(
                        length=length,
                        other=new_column_name_to_new_column_map.pop(name),
                        backend_version=self._backend_version,
                    )
                )
            else:
                to_concat.append(self._native_frame[name])
            output_names.append(name)
        # Anything not popped above is a brand-new column: append at the end.
        for s in new_column_name_to_new_column_map:
            to_concat.append(
                validate_dataframe_comparand(
                    length=length,
                    other=new_column_name_to_new_column_map[s],
                    backend_version=self._backend_version,
                )
            )
            output_names.append(s)
        df = self._native_frame.__class__.from_arrays(to_concat, names=output_names)
        return self._from_native_frame(df)
|
|
|
|
def group_by(self, *keys: str, drop_null_keys: bool) -> ArrowGroupBy:
|
|
from narwhals._arrow.group_by import ArrowGroupBy
|
|
|
|
return ArrowGroupBy(self, list(keys), drop_null_keys=drop_null_keys)
|
|
|
|
    def join(
        self,
        other: Self,
        *,
        how: Literal["left", "inner", "outer", "cross", "anti", "semi"] = "inner",
        left_on: str | list[str] | None,
        right_on: str | list[str] | None,
        suffix: str,
    ) -> Self:
        """Join this frame with ``other``.

        ``how`` is translated to pyarrow ``Table.join`` join types; a cross
        join is emulated with an inner join on a temporary constant key column
        which is dropped afterwards. Overlapping right-hand column names get
        ``suffix`` appended by pyarrow's ``right_suffix``.
        """
        # Map narwhals join names onto pyarrow Table.join join_type values.
        # NOTE(review): "outer" has no entry here and would raise KeyError
        # below — presumably rewritten/validated upstream; confirm.
        how_to_join_map = {
            "anti": "left anti",
            "semi": "left semi",
            "inner": "inner",
            "left": "left outer",
        }

        if how == "cross":
            plx = self.__narwhals_namespace__()
            # Temporary key guaranteed not to clash with either frame's columns.
            key_token = generate_temporary_column_name(
                n_bytes=8, columns=[*self.columns, *other.columns]
            )

            return self._from_native_frame(
                self.with_columns(**{key_token: plx.lit(0, None)})
                ._native_frame.join(
                    other.with_columns(**{key_token: plx.lit(0, None)})._native_frame,
                    keys=key_token,
                    right_keys=key_token,
                    join_type="inner",
                    right_suffix=suffix,
                )
                .drop([key_token]),
            )

        return self._from_native_frame(
            self._native_frame.join(
                other._native_frame,
                keys=left_on,
                right_keys=right_on,
                join_type=how_to_join_map[how],
                right_suffix=suffix,
            ),
        )
|
|
|
|
def join_asof(
|
|
self,
|
|
other: Self,
|
|
*,
|
|
left_on: str | None = None,
|
|
right_on: str | None = None,
|
|
on: str | None = None,
|
|
by_left: str | list[str] | None = None,
|
|
by_right: str | list[str] | None = None,
|
|
by: str | list[str] | None = None,
|
|
strategy: Literal["backward", "forward", "nearest"] = "backward",
|
|
) -> Self:
|
|
msg = "join_asof is not yet supported on PyArrow tables"
|
|
raise NotImplementedError(msg)
|
|
|
|
def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
|
|
to_drop = parse_columns_to_drop(
|
|
compliant_frame=self, columns=columns, strict=strict
|
|
)
|
|
return self._from_native_frame(self._native_frame.drop(to_drop))
|
|
|
|
def drop_nulls(self: Self, subset: str | list[str] | None) -> Self:
|
|
if subset is None:
|
|
return self._from_native_frame(self._native_frame.drop_null())
|
|
subset = [subset] if isinstance(subset, str) else subset
|
|
plx = self.__narwhals_namespace__()
|
|
return self.filter(~plx.any_horizontal(plx.col(*subset).is_null()))
|
|
|
|
def sort(
|
|
self,
|
|
by: str | Iterable[str],
|
|
*more_by: str,
|
|
descending: bool | Sequence[bool],
|
|
nulls_last: bool,
|
|
) -> Self:
|
|
flat_keys = flatten([*flatten([by]), *more_by])
|
|
df = self._native_frame
|
|
|
|
if isinstance(descending, bool):
|
|
order = "descending" if descending else "ascending"
|
|
sorting = [(key, order) for key in flat_keys]
|
|
else:
|
|
sorting = [
|
|
(key, "descending" if is_descending else "ascending")
|
|
for key, is_descending in zip(flat_keys, descending)
|
|
]
|
|
|
|
null_placement = "at_end" if nulls_last else "at_start"
|
|
|
|
return self._from_native_frame(df.sort_by(sorting, null_placement=null_placement))
|
|
|
|
    def to_pandas(self) -> Any:
        # Convert via pyarrow's built-in pandas bridge.
        return self._native_frame.to_pandas()
|
|
|
|
def to_numpy(self) -> Any:
|
|
import numpy as np # ignore-banned-import
|
|
|
|
return np.column_stack([col.to_numpy() for col in self._native_frame.columns])
|
|
|
|
def to_dict(self, *, as_series: bool) -> Any:
|
|
df = self._native_frame
|
|
|
|
names_and_values = zip(df.column_names, df.columns)
|
|
if as_series:
|
|
from narwhals._arrow.series import ArrowSeries
|
|
|
|
return {
|
|
name: ArrowSeries(
|
|
col,
|
|
name=name,
|
|
backend_version=self._backend_version,
|
|
dtypes=self._dtypes,
|
|
)
|
|
for name, col in names_and_values
|
|
}
|
|
else:
|
|
return {name: col.to_pylist() for name, col in names_and_values}
|
|
|
|
def with_row_index(self, name: str) -> Self:
|
|
import pyarrow as pa # ignore-banned-import()
|
|
|
|
df = self._native_frame
|
|
|
|
row_indices = pa.array(range(df.num_rows))
|
|
return self._from_native_frame(df.append_column(name, row_indices))
|
|
|
|
def filter(
|
|
self,
|
|
*predicates: IntoArrowExpr,
|
|
) -> Self:
|
|
if (
|
|
len(predicates) == 1
|
|
and isinstance(predicates[0], list)
|
|
and all(isinstance(x, bool) for x in predicates[0])
|
|
):
|
|
mask = predicates[0]
|
|
else:
|
|
plx = self.__narwhals_namespace__()
|
|
expr = plx.all_horizontal(*predicates)
|
|
# Safety: all_horizontal's expression only returns a single column.
|
|
mask = expr._call(self)[0]._native_series
|
|
return self._from_native_frame(self._native_frame.filter(mask))
|
|
|
|
def null_count(self) -> Self:
|
|
import pyarrow as pa # ignore-banned-import()
|
|
|
|
df = self._native_frame
|
|
names_and_values = zip(df.column_names, df.columns)
|
|
|
|
return self._from_native_frame(
|
|
pa.table({name: [col.null_count] for name, col in names_and_values})
|
|
)
|
|
|
|
def head(self, n: int) -> Self:
|
|
df = self._native_frame
|
|
if n >= 0:
|
|
return self._from_native_frame(df.slice(0, n))
|
|
else:
|
|
num_rows = df.num_rows
|
|
return self._from_native_frame(df.slice(0, max(0, num_rows + n)))
|
|
|
|
def tail(self, n: int) -> Self:
|
|
df = self._native_frame
|
|
if n >= 0:
|
|
num_rows = df.num_rows
|
|
return self._from_native_frame(df.slice(max(0, num_rows - n)))
|
|
else:
|
|
return self._from_native_frame(df.slice(abs(n)))
|
|
|
|
def lazy(self) -> Self:
|
|
return self
|
|
|
|
    def collect(self) -> ArrowDataFrame:
        # Already eager: "collecting" re-wraps the same native table.
        return ArrowDataFrame(
            self._native_frame,
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )
|
|
|
|
def clone(self) -> Self:
|
|
msg = "clone is not yet supported on PyArrow tables"
|
|
raise NotImplementedError(msg)
|
|
|
|
def is_empty(self: Self) -> bool:
|
|
return self.shape[0] == 0
|
|
|
|
def item(self: Self, row: int | None = None, column: int | str | None = None) -> Any:
|
|
if row is None and column is None:
|
|
if self.shape != (1, 1):
|
|
msg = (
|
|
"can only call `.item()` if the dataframe is of shape (1, 1),"
|
|
" or if explicit row/col values are provided;"
|
|
f" frame has shape {self.shape!r}"
|
|
)
|
|
raise ValueError(msg)
|
|
return self._native_frame[0][0]
|
|
|
|
elif row is None or column is None:
|
|
msg = "cannot call `.item()` with only one of `row` or `column`"
|
|
raise ValueError(msg)
|
|
|
|
_col = self.columns.index(column) if isinstance(column, str) else column
|
|
return self._native_frame[_col][row]
|
|
|
|
def rename(self, mapping: dict[str, str]) -> Self:
|
|
df = self._native_frame
|
|
new_cols = [mapping.get(c, c) for c in df.column_names]
|
|
return self._from_native_frame(df.rename_columns(new_cols))
|
|
|
|
def write_parquet(self, file: Any) -> Any:
|
|
import pyarrow.parquet as pp # ignore-banned-import
|
|
|
|
pp.write_table(self._native_frame, file)
|
|
|
|
def write_csv(self, file: Any) -> Any:
|
|
import pyarrow as pa # ignore-banned-import
|
|
import pyarrow.csv as pa_csv # ignore-banned-import
|
|
|
|
pa_table = self._native_frame
|
|
if file is None:
|
|
csv_buffer = pa.BufferOutputStream()
|
|
pa_csv.write_csv(pa_table, csv_buffer)
|
|
return csv_buffer.getvalue().to_pybytes().decode()
|
|
return pa_csv.write_csv(pa_table, file)
|
|
|
|
    def is_duplicated(self: Self) -> ArrowSeries:
        """Boolean series: True for rows that appear more than once."""
        import numpy as np  # ignore-banned-import
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        from narwhals._arrow.series import ArrowSeries

        df = self._native_frame

        columns = self.columns
        # Temporary column name guaranteed not to clash with existing columns.
        col_token = generate_temporary_column_name(n_bytes=8, columns=columns)
        # Count occurrences per distinct row: tag each row with an index, then
        # group by all columns and count the indices in each group.
        row_count = (
            df.append_column(col_token, pa.array(np.arange(len(self))))
            .group_by(columns)
            .aggregate([(col_token, "count")])
        )
        # Join counts back onto the original rows; count > 1 means duplicated.
        # NOTE(review): assumes the inner join yields one output row per left
        # row, aligned with the original order — confirm for supported
        # pyarrow versions.
        is_duplicated = pc.greater(
            df.join(
                row_count, keys=columns, right_keys=columns, join_type="inner"
            ).column(f"{col_token}_count"),
            1,
        )
        return ArrowSeries(
            is_duplicated,
            name="",
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )
|
|
|
|
def is_unique(self: Self) -> ArrowSeries:
|
|
import pyarrow.compute as pc # ignore-banned-import()
|
|
|
|
from narwhals._arrow.series import ArrowSeries
|
|
|
|
is_duplicated = self.is_duplicated()._native_series
|
|
|
|
return ArrowSeries(
|
|
pc.invert(is_duplicated),
|
|
name="",
|
|
backend_version=self._backend_version,
|
|
dtypes=self._dtypes,
|
|
)
|
|
|
|
    def unique(
        self: Self,
        subset: str | list[str] | None,
        *,
        keep: Literal["any", "first", "last", "none"] = "any",
        maintain_order: bool = False,
    ) -> Self:
        """Drop duplicate rows, judged on ``subset`` (or all columns).

        NOTE:
            The param `maintain_order` is only here for compatibility with the polars API
            and has no effect on the output.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow as pa  # ignore-banned-import()
        import pyarrow.compute as pc  # ignore-banned-import()

        df = self._native_frame

        # Normalise subset: a single name becomes a list; None means all columns.
        if isinstance(subset, str):
            subset = [subset]
        subset = subset or self.columns

        if keep in {"any", "first", "last"}:
            # For each group of duplicates keep one representative row:
            # "first"/"any" take the smallest original row index, "last" the
            # largest, via an aggregate over a temporary index column.
            agg_func_map = {"any": "min", "first": "min", "last": "max"}

            agg_func = agg_func_map[keep]
            col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
            keep_idx = (
                df.append_column(col_token, pa.array(np.arange(len(self))))
                .group_by(subset)
                .aggregate([(col_token, agg_func)])
                .column(f"{col_token}_{agg_func}")
            )

            return self._from_native_frame(pc.take(df, keep_idx))

        # keep == "none": discard every row that has any duplicate in `subset`.
        keep_idx = self.select(*subset).is_unique()
        return self.filter(keep_idx)
|
|
|
|
def gather_every(self: Self, n: int, offset: int = 0) -> Self:
|
|
return self._from_native_frame(self._native_frame[offset::n])
|
|
|
|
def to_arrow(self: Self) -> Any:
|
|
return self._native_frame
|
|
|
|
    def sample(
        self: Self,
        n: int | None = None,
        *,
        fraction: float | None = None,
        with_replacement: bool = False,
        seed: int | None = None,
    ) -> Self:
        """Randomly sample rows.

        ``n`` gives an absolute row count; otherwise ``fraction`` of the total
        row count is used. ``seed`` makes the draw reproducible.
        """
        import numpy as np  # ignore-banned-import
        import pyarrow.compute as pc  # ignore-banned-import()

        frame = self._native_frame
        num_rows = len(self)
        if n is None and fraction is not None:
            n = int(num_rows * fraction)

        # NOTE(review): if both `n` and `fraction` are None, `size=None` makes
        # rng.choice return a scalar — presumably validated upstream; confirm.
        rng = np.random.default_rng(seed=seed)
        idx = np.arange(0, num_rows)
        mask = rng.choice(idx, size=n, replace=with_replacement)

        return self._from_native_frame(pc.take(frame, mask))
|
|
|
|
    def unpivot(
        self: Self,
        on: str | list[str] | None,
        index: str | list[str] | None,
        variable_name: str | None,
        value_name: str | None,
    ) -> Self:
        """Melt from wide to long format.

        Every column in ``on`` contributes a block of rows carrying its name
        under ``variable_name`` and its values under ``value_name``; the
        ``index`` columns are repeated in each block.
        """
        import pyarrow as pa  # ignore-banned-import

        native_frame = self._native_frame
        variable_name = variable_name if variable_name is not None else "variable"
        value_name = value_name if value_name is not None else "value"

        # Normalise index/on to lists; `on=None` means all non-index columns.
        index_: list[str] = (
            [] if index is None else [index] if isinstance(index, str) else index
        )
        on_: list[str] = (
            [c for c in self.columns if c not in index_]
            if on is None
            else [on]
            if isinstance(on, str)
            else on
        )

        n_rows = len(self)

        # promote_options lets concat_tables unify compatible dtypes across
        # the per-column blocks; only available from pyarrow 14.
        promote_kwargs = (
            {"promote_options": "permissive"}
            if self._backend_version >= (14, 0, 0)
            else {}
        )
        # Build one long table per `on` column, then stack them vertically.
        return self._from_native_frame(
            pa.concat_tables(
                [
                    pa.Table.from_arrays(
                        [
                            *[native_frame.column(idx_col) for idx_col in index_],
                            pa.array([on_col] * n_rows, pa.string()),
                            native_frame.column(on_col),
                        ],
                        names=[*index_, variable_name, value_name],
                    )
                    for on_col in on_
                ],
                **promote_kwargs,
            )
        )
        # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
        # upcast numeric to non-numeric (e.g. string) datatypes
|