Files
Buffteks-Website/buffteks/lib/python3.12/site-packages/narwhals/_polars/dataframe.py
2025-05-08 21:10:14 -05:00

708 lines
25 KiB
Python

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterator
from typing import Literal
from typing import Mapping
from typing import Sequence
from typing import cast
from typing import overload
import polars as pl
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import catch_polars_exception
from narwhals._polars.utils import convert_str_slice_to_int_slice
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import native_to_narwhals_dtype
from narwhals.exceptions import ColumnNotFoundError
from narwhals.utils import Implementation
from narwhals.utils import _into_arrow_table
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version
from narwhals.utils import validate_backend_version
if TYPE_CHECKING:
from types import ModuleType
from typing import Callable
from typing import TypeVar
import pandas as pd
import pyarrow as pa
from typing_extensions import Self
from typing_extensions import TypeAlias
from narwhals._polars.group_by import PolarsGroupBy
from narwhals._polars.group_by import PolarsLazyGroupBy
from narwhals._polars.series import PolarsSeries
from narwhals._translate import IntoArrowTable
from narwhals.dtypes import DType
from narwhals.schema import Schema
from narwhals.typing import CompliantDataFrame
from narwhals.typing import CompliantLazyFrame
from narwhals.typing import _2DArray
from narwhals.utils import Version
from narwhals.utils import _FullContext
# Type variables used only in annotations in this module:
# `T` is the pass-through type of `PolarsDataFrame._from_native_object`,
# `R` is the return type captured by the `Method` alias below.
T = TypeVar("T")
R = TypeVar("R")
# The alias value is a string, so `Callable` (imported under `TYPE_CHECKING`)
# is never evaluated at runtime.
Method: TypeAlias = "Callable[..., R]"
"""Generic alias representing all methods implemented via `__getattr__`.
Where `R` is the return type.
"""
class PolarsDataFrame:
    """Narwhals wrapper around a native eager `pl.DataFrame`.

    Methods that are not defined explicitly on this class are forwarded to the
    native frame by `__getattr__`. The class-level annotations below only
    declare the return types of those forwarded methods for type checkers.
    """

    clone: Method[Self]
    collect: Method[CompliantDataFrame[Any, Any, Any]]
    drop_nulls: Method[Self]
    estimated_size: Method[int | float]
    explode: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    item: Method[Any]
    iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
    is_unique: Method[PolarsSeries]
    join_asof: Method[Self]
    rename: Method[Self]
    row: Method[tuple[Any, ...]]
    rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
    sample: Method[Self]
    select: Method[Self]
    sort: Method[Self]
    to_arrow: Method[pa.Table]
    to_numpy: Method[_2DArray]
    to_pandas: Method[pd.DataFrame]
    unique: Method[Self]
    with_columns: Method[Self]
    # NOTE: `write_csv` requires an `@overload` for `str | None`
    # Can't do that here 😟
    write_csv: Method[Any]
    write_parquet: Method[None]

    def __init__(
        self: Self,
        df: pl.DataFrame,
        *,
        backend_version: tuple[int, ...],
        version: Version,
    ) -> None:
        """Store the native frame and validate the Polars backend version."""
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._version = version
        validate_backend_version(self._implementation, self._backend_version)

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _FullContext) -> Self:
        """Build a frame from Arrow-like `data`.

        For Polars >= 1.3 `data` is handed straight to `pl.DataFrame`; older
        versions go through `_into_arrow_table` and `pl.from_arrow`.
        """
        backend_version = context._backend_version
        if backend_version >= (1, 3):
            native = pl.DataFrame(data)
        else:
            native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
        return cls(native, backend_version=backend_version, version=context._version)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _FullContext,
        schema: Mapping[str, DType] | Schema | None,
    ) -> Self:
        """Build a frame from a mapping of column name to values.

        A non-`None` `schema` is converted to its Polars form first.
        """
        from narwhals.schema import Schema

        pl_schema = Schema(schema).to_polars() if schema is not None else schema
        native = pl.from_dict(data, pl_schema)
        return cls(
            native, backend_version=context._backend_version, version=context._version
        )

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _FullContext,  # NOTE: Maybe only `Implementation`?
        schema: Mapping[str, DType] | Schema | Sequence[str] | None,
    ) -> Self:
        """Build a frame from a 2D array.

        A mapping / `Schema` is converted to its Polars form; any other value
        (a sequence of names or `None`) is passed to `pl.from_numpy` unchanged.
        """
        from narwhals.schema import Schema

        pl_schema = (
            Schema(schema).to_polars()
            if isinstance(schema, (Mapping, Schema))
            else schema
        )
        native = pl.from_numpy(data, pl_schema)
        return cls(
            native, backend_version=context._backend_version, version=context._version
        )

    @property
    def native(self) -> pl.DataFrame:
        """The wrapped native `pl.DataFrame`."""
        return self._native_frame

    def __repr__(self: Self) -> str:  # pragma: no cover
        return "PolarsDataFrame"

    def __narwhals_dataframe__(self: Self) -> Self:
        return self

    def __narwhals_namespace__(self: Self) -> PolarsNamespace:
        return PolarsNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __native_namespace__(self: Self) -> ModuleType:
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()

        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _with_version(self: Self, version: Version) -> Self:
        """Return a new wrapper around the same native frame with `version`."""
        return self.__class__(
            self.native, backend_version=self._backend_version, version=version
        )

    def _with_native(self: Self, df: pl.DataFrame) -> Self:
        """Return a new wrapper around `df`, keeping backend version and version."""
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    @overload
    def _from_native_object(self: Self, obj: pl.Series) -> PolarsSeries: ...

    @overload
    def _from_native_object(self: Self, obj: pl.DataFrame) -> Self: ...

    @overload
    def _from_native_object(self: Self, obj: T) -> T: ...

    def _from_native_object(
        self: Self, obj: pl.Series | pl.DataFrame | T
    ) -> Self | PolarsSeries | T:
        """Wrap a native result: series and frames are wrapped, anything else is returned as-is."""
        if isinstance(obj, pl.Series):
            from narwhals._polars.series import PolarsSeries

            return PolarsSeries(
                obj, backend_version=self._backend_version, version=self._version
            )
        if isinstance(obj, pl.DataFrame):
            return self._with_native(obj)
        # scalar
        return obj

    def __len__(self) -> int:
        return len(self.native)

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n))

    def tail(self, n: int) -> Self:
        return self._with_native(self.native.tail(n))

    def __getattr__(self: Self, attr: str) -> Any:
        """Forward an unknown attribute to the native frame as a method call.

        The returned callable converts its arguments with `extract_args_kwargs`,
        calls the native method named `attr`, and wraps the result with
        `_from_native_object`. Polars errors are re-raised through
        `catch_polars_exception`; a missing column additionally gets a hint
        listing the available columns.
        """

        def func(*args: Any, **kwargs: Any) -> Any:
            args, kwargs = extract_args_kwargs(args, kwargs)  # type: ignore[assignment]
            try:
                return self._from_native_object(
                    getattr(self.native, attr)(*args, **kwargs)
                )
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
                raise ColumnNotFoundError(msg) from e
            except Exception as e:  # noqa: BLE001
                raise catch_polars_exception(e, self._backend_version) from None

        return func

    def __array__(
        self: Self, dtype: Any | None = None, *, copy: bool | None = None
    ) -> _2DArray:
        """Convert to a 2D array via the native `__array__`.

        Raises:
            NotImplementedError: if `copy` is given and Polars < 0.20.28.
        """
        if self._backend_version < (0, 20, 28):
            if copy is not None:
                msg = "`copy` in `__array__` is only supported for Polars>=0.20.28"
                raise NotImplementedError(msg)
            return self.native.__array__(dtype)
        # Forward `copy` so the caller's request is honoured instead of being
        # silently dropped.
        return self.native.__array__(dtype, copy=copy)

    def collect_schema(self: Self) -> dict[str, DType]:
        """Return the schema as Narwhals dtypes.

        Uses `native.schema` for Polars < 1 and `native.collect_schema()` otherwise.
        """
        native_schema = (
            self.native.schema
            if self._backend_version < (1,)
            else self.native.collect_schema()
        )
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in native_schema.items()
        }

    @property
    def shape(self: Self) -> tuple[int, int]:
        return self.native.shape

    def __getitem__(self: Self, item: Any) -> Any:
        """Index the frame, wrapping the native result.

        For Polars > 0.20.30 the item is passed to the native `__getitem__`
        directly; the `else` branch handles the same item shapes manually for
        older versions.
        """
        if self._backend_version > (0, 20, 30):
            return self._from_native_object(self.native.__getitem__(item))
        else:  # pragma: no cover
            # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
            # Polars version we support
            if isinstance(item, tuple):
                item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item)

            columns = self.columns
            if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
                if item[1] == slice(None):
                    if isinstance(item[0], Sequence) and not len(item[0]):
                        return self._with_native(self.native[0:0])
                    return self._with_native(self.native.__getitem__(item[0]))
                if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
                    start, stop, step = convert_str_slice_to_int_slice(item[1], columns)
                    return self._with_native(
                        self.native.select(columns[start:stop:step]).__getitem__(item[0])
                    )
                if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
                    return self._with_native(
                        self.native.select(
                            columns[item[1].start : item[1].stop : item[1].step]
                        ).__getitem__(item[0])
                    )
                msg = f"Expected slice of integers or strings, got: {type(item[1])}"  # pragma: no cover
                raise TypeError(msg)  # pragma: no cover
            if (
                isinstance(item, tuple)
                and (len(item) == 2)
                and is_sequence_but_not_str(item[1])
                and (len(item[1]) == 0)
            ):
                result = self.native.select(item[1])
            elif isinstance(item, slice) and (
                isinstance(item.start, str) or isinstance(item.stop, str)
            ):
                start, stop, step = convert_str_slice_to_int_slice(item, columns)
                return self._with_native(self.native.select(columns[start:stop:step]))
            elif is_sequence_but_not_str(item) and (len(item) == 0):
                result = self.native.slice(0, 0)
            else:
                result = self.native.__getitem__(item)
            # `_from_native_object` wraps series, frames and scalars alike.
            return self._from_native_object(result)

    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(self.native.select(*column_names))

    def aggregate(self: Self, *exprs: Any) -> Self:
        return self.select(*exprs)

    def get_column(self: Self, name: str) -> PolarsSeries:
        from narwhals._polars.series import PolarsSeries

        return PolarsSeries(
            self.native.get_column(name),
            backend_version=self._backend_version,
            version=self._version,
        )

    def iter_columns(self) -> Iterator[PolarsSeries]:
        """Yield each native column wrapped as a `PolarsSeries`."""
        from narwhals._polars.series import PolarsSeries

        for series in self.native.iter_columns():
            yield PolarsSeries(
                series, backend_version=self._backend_version, version=self._version
            )

    @property
    def columns(self: Self) -> list[str]:
        return self.native.columns

    @property
    def schema(self: Self) -> dict[str, DType]:
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in self.native.schema.items()
        }

    def lazy(
        self: Self, *, backend: Implementation | None = None
    ) -> CompliantLazyFrame[Any, Any]:
        """Convert to a lazy frame for `backend` (Polars by default, DuckDB or Dask)."""
        if backend is None or backend is Implementation.POLARS:
            return PolarsLazyFrame(
                self.native.lazy(),
                backend_version=self._backend_version,
                version=self._version,
            )
        elif backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            # NOTE: (F841) is a false positive
            # The local `df` is referenced by name in `duckdb.table("df")` below.
            df = self.native  # noqa: F841
            return DuckDBLazyFrame(
                duckdb.table("df"),
                backend_version=parse_version(duckdb),
                version=self._version,
            )
        elif backend is Implementation.DASK:
            import dask  # ignore-banned-import
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                backend_version=parse_version(dask),
                version=self._version,
            )
        raise AssertionError  # pragma: no cover

    @overload
    def to_dict(self: Self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...

    @overload
    def to_dict(self: Self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

    def to_dict(
        self: Self, *, as_series: bool
    ) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
        """Return columns as a dict of `PolarsSeries` (`as_series=True`) or of lists."""
        if as_series:
            from narwhals._polars.series import PolarsSeries

            return {
                name: PolarsSeries(
                    col, backend_version=self._backend_version, version=self._version
                )
                for name, col in self.native.to_dict().items()
            }
        else:
            return self.native.to_dict(as_series=False)

    def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PolarsGroupBy:
        from narwhals._polars.group_by import PolarsGroupBy

        return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def with_row_index(self: Self, name: str) -> Self:
        """Add a row-index column; Polars < 0.20.4 uses `with_row_count` instead."""
        if self._backend_version < (0, 20, 4):
            return self._with_native(self.native.with_row_count(name))
        return self._with_native(self.native.with_row_index(name))

    def drop(self: Self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(
            compliant_frame=self, columns=columns, strict=strict
        )
        return self._with_native(self.native.drop(to_drop))

    def unpivot(
        self: Self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot from wide to long; Polars < 1.0.0 uses `melt` instead."""
        if self._backend_version < (1, 0, 0):
            return self._with_native(
                self.native.melt(
                    id_vars=index,
                    value_vars=on,
                    variable_name=variable_name,
                    value_name=value_name,
                )
            )
        return self._with_native(
            self.native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        )

    def pivot(
        self: Self,
        on: list[str],
        *,
        index: list[str] | None,
        values: list[str] | None,
        aggregate_function: Literal[
            "min", "max", "first", "last", "sum", "mean", "median", "len"
        ]
        | None,
        sort_columns: bool,
        separator: str,
    ) -> Self:
        """Pivot via the native `pivot`.

        Raises:
            NotImplementedError: for Polars < 1.0.0.
        """
        if self._backend_version < (1, 0, 0):  # pragma: no cover
            msg = "`pivot` is only supported for Polars>=1.0.0"
            raise NotImplementedError(msg)
        try:
            result = self.native.pivot(
                on,
                index=index,
                values=values,
                aggregate_function=aggregate_function,
                sort_columns=sort_columns,
                separator=separator,
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None
        return self._from_native_object(result)

    def to_polars(self: Self) -> pl.DataFrame:
        return self.native

    def join(
        self: Self,
        other: Self,
        *,
        how: Literal["inner", "left", "full", "cross", "semi", "anti"],
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other`; `"full"` is passed as `"outer"` for Polars < 0.20.29."""
        how_native = (
            "outer" if (self._backend_version < (0, 20, 29) and how == "full") else how
        )
        try:
            return self._with_native(
                self.native.join(
                    other=other.native,
                    how=how_native,  # type: ignore[arg-type]
                    left_on=left_on,
                    right_on=right_on,
                    suffix=suffix,
                )
            )
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None
class PolarsLazyFrame:
    """Narwhals wrapper around a native `pl.LazyFrame`.

    Methods that are not defined explicitly on this class are forwarded to the
    native lazy frame by `__getattr__`. The class-level annotations below only
    declare those forwarded methods for type checkers.
    """

    drop_nulls: Method[Self]
    explode: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    join_asof: Method[Self]
    rename: Method[Self]
    select: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    with_columns: Method[Self]
    # NOTE: Temporary, just trying to factor out utils
    _evaluate_expr: Any

    def __init__(
        self: Self,
        df: pl.LazyFrame,
        *,
        backend_version: tuple[int, ...],
        version: Version,
    ) -> None:
        """Keep a reference to the native lazy frame and check the backend version."""
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._version = version
        validate_backend_version(self._implementation, self._backend_version)

    def __repr__(self: Self) -> str:  # pragma: no cover
        return "PolarsLazyFrame"

    def __narwhals_lazyframe__(self: Self) -> Self:
        return self

    def __narwhals_namespace__(self: Self) -> PolarsNamespace:
        return PolarsNamespace(
            backend_version=self._backend_version, version=self._version
        )

    def __native_namespace__(self: Self) -> ModuleType:
        if self._implementation is not Implementation.POLARS:
            msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
            raise AssertionError(msg)
        return self._implementation.to_native_namespace()

    def _with_native(self: Self, df: pl.LazyFrame) -> Self:
        """Wrap `df` in a new instance that shares this one's versions."""
        return self.__class__(
            df, backend_version=self._backend_version, version=self._version
        )

    def _with_version(self: Self, version: Version) -> Self:
        """Re-wrap the same native lazy frame under a different `version`."""
        return self.__class__(
            self.native, backend_version=self._backend_version, version=version
        )

    def __getattr__(self: Self, attr: str) -> Any:
        """Return a callable that forwards `attr` to the native lazy frame.

        Arguments are converted with `extract_args_kwargs`, and the native
        result is re-wrapped with `_with_native`. A Polars
        `ColumnNotFoundError` is re-raised as the Narwhals one.
        """

        def forward(*args: Any, **kwargs: Any) -> Any:
            native_args, native_kwargs = extract_args_kwargs(args, kwargs)
            try:
                native_method = getattr(self.native, attr)
                return self._with_native(native_method(*native_args, **native_kwargs))
            except pl.exceptions.ColumnNotFoundError as e:  # pragma: no cover
                raise ColumnNotFoundError(str(e)) from e

        return forward

    def _iter_columns(self) -> Iterator[PolarsSeries]:  # pragma: no cover
        """Collect with the Polars backend and yield the resulting columns."""
        collected = self.collect(self._implementation)
        yield from collected.iter_columns()

    @property
    def native(self) -> pl.LazyFrame:
        """The wrapped native `pl.LazyFrame`."""
        return self._native_frame

    @property
    def columns(self: Self) -> list[str]:
        return self.native.columns

    @property
    def schema(self: Self) -> dict[str, DType]:
        """Native `schema` translated to Narwhals dtypes."""
        native_schema = self.native.schema
        version, backend_version = self._version, self._backend_version
        return {
            name: native_to_narwhals_dtype(dtype, version, backend_version)
            for name, dtype in native_schema.items()
        }

    def collect_schema(self: Self) -> dict[str, DType]:
        """Resolve the schema and translate it to Narwhals dtypes.

        Polars < 1 reads `native.schema`; newer versions call
        `native.collect_schema()` and route failures through
        `catch_polars_exception`.
        """
        if self._backend_version < (1,):
            native_schema = self.native.schema
        else:
            try:
                native_schema = self.native.collect_schema()
            except Exception as e:  # noqa: BLE001
                raise catch_polars_exception(e, self._backend_version) from None
        return {
            name: native_to_narwhals_dtype(dtype, self._version, self._backend_version)
            for name, dtype in native_schema.items()
        }

    def collect(
        self: Self,
        backend: Implementation | None,
        **kwargs: Any,
    ) -> CompliantDataFrame[Any, Any, Any]:
        """Materialise the lazy frame and wrap the result for `backend`.

        The native `collect(**kwargs)` always runs first. The result is then
        wrapped as a Polars frame (`None` / POLARS), converted to pandas
        (PANDAS) or converted to Arrow (PYARROW).

        Raises:
            ValueError: for any other `backend`.
        """
        try:
            eager = self.native.collect(**kwargs)
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e, self._backend_version) from None

        if backend is None or backend is Implementation.POLARS:
            return PolarsDataFrame(
                eager, backend_version=self._backend_version, version=self._version
            )
        elif backend is Implementation.PANDAS:
            import pandas as pd  # ignore-banned-import

            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                eager.to_pandas(),
                implementation=Implementation.PANDAS,
                backend_version=parse_version(pd),
                version=self._version,
                validate_column_names=False,
            )
        elif backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                eager.to_arrow(),
                backend_version=parse_version(pa),
                version=self._version,
                validate_column_names=False,
            )
        else:
            msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
            raise ValueError(msg)  # pragma: no cover

    def group_by(self: Self, *keys: str, drop_null_keys: bool) -> PolarsLazyGroupBy:
        from narwhals._polars.group_by import PolarsLazyGroupBy

        return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def with_row_index(self: Self, name: str) -> Self:
        """Add a row-index column; Polars < 0.20.4 uses `with_row_count` instead."""
        frame = self.native
        indexed = (
            frame.with_row_count(name)
            if self._backend_version < (0, 20, 4)
            else frame.with_row_index(name)
        )
        return self._with_native(indexed)

    def drop(self: Self, columns: Sequence[str], *, strict: bool) -> Self:
        """Drop `columns`; `strict` is only passed to Polars >= 1.0.0."""
        frame = self.native
        dropped = (
            frame.drop(columns)
            if self._backend_version < (1, 0, 0)
            else frame.drop(columns, strict=strict)
        )
        return self._with_native(dropped)

    def unpivot(
        self: Self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        """Unpivot from wide to long; Polars < 1.0.0 uses `melt` instead."""
        frame = self.native
        if self._backend_version < (1, 0, 0):
            reshaped = frame.melt(
                id_vars=index,
                value_vars=on,
                variable_name=variable_name,
                value_name=value_name,
            )
        else:
            reshaped = frame.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        return self._with_native(reshaped)

    def simple_select(self, *column_names: str) -> Self:
        selected = self.native.select(*column_names)
        return self._with_native(selected)

    def aggregate(self: Self, *exprs: Any) -> Self:
        return self.select(*exprs)

    def join(
        self: Self,
        other: Self,
        *,
        how: Literal["inner", "left", "full", "cross", "semi", "anti"],
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        """Join with `other`; `"full"` is passed as `"outer"` for Polars < 0.20.29."""
        uses_legacy_name = how == "full" and self._backend_version < (0, 20, 29)
        how_native = "outer" if uses_legacy_name else how
        joined = self.native.join(
            other=other.native,
            how=how_native,  # type: ignore[arg-type]
            left_on=left_on,
            right_on=right_on,
            suffix=suffix,
        )
        return self._with_native(joined)