Files
Buffteks-Website/streamlit-venv/lib/python3.10/site-packages/narwhals/_polars/dataframe.py
2025-01-10 21:40:35 +00:00

351 lines
13 KiB
Python
Executable File

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.utils import convert_str_slice_to_int_slice
from narwhals._polars.utils import extract_args_kwargs
from narwhals._polars.utils import native_to_narwhals_dtype
from narwhals.utils import Implementation
from narwhals.utils import is_sequence_but_not_str
from narwhals.utils import parse_columns_to_drop
if TYPE_CHECKING:
from types import ModuleType
import numpy as np
from typing_extensions import Self
from narwhals.typing import DTypes
class PolarsDataFrame:
    """Narwhals-side wrapper around a native Polars eager frame.

    The wrapped object is kept in ``_native_frame``. Methods that are not
    defined explicitly on this class are forwarded to the native frame by
    ``__getattr__`` and their results are re-wrapped by
    ``_from_native_object``.
    """

    def __init__(
        self, df: Any, *, backend_version: tuple[int, ...], dtypes: DTypes
    ) -> None:
        """Store the native frame together with version and dtype context.

        Args:
            df: The native frame to wrap.
            backend_version: Version tuple of the backend, used throughout
                the class to pick version-specific code paths.
            dtypes: Dtype collection handed to ``native_to_narwhals_dtype``
                and propagated to every object created from this frame.
        """
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._dtypes = dtypes

    def __repr__(self) -> str:  # pragma: no cover
        """Return a fixed, data-independent representation."""
        return "PolarsDataFrame"

    def __narwhals_dataframe__(self) -> Self:
        """Return ``self`` (protocol hook)."""
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Build a ``PolarsNamespace`` sharing this frame's version and dtypes."""
        return PolarsNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

    def __native_namespace__(self: Self) -> ModuleType:
        """Return the native module for this implementation.

        Raises:
            AssertionError: If the stored implementation is not
                ``Implementation.POLARS``.
        """
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()

        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _from_native_frame(self, df: Any) -> Self:
        """Wrap ``df`` in a new instance carrying the same version and dtypes."""
        return self.__class__(
            df, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def _from_native_object(self, obj: Any) -> Any:
        """Wrap a native result according to its type.

        A ``pl.Series`` becomes a ``PolarsSeries``, a ``pl.DataFrame`` becomes
        a new instance of this class, and anything else is returned unchanged.
        """
        import polars as pl  # ignore-banned-import()

        if isinstance(obj, pl.Series):
            from narwhals._polars.series import PolarsSeries

            return PolarsSeries(
                obj, backend_version=self._backend_version, dtypes=self._dtypes
            )
        if isinstance(obj, pl.DataFrame):
            return self._from_native_frame(obj)
        # scalar
        return obj

    def __getattr__(self, attr: str) -> Any:
        """Resolve attributes that are not defined on this class.

        * ``collect`` raises ``AttributeError`` instead of being forwarded.
        * ``schema`` returns a ``dict`` mapping column name to the result of
          ``native_to_narwhals_dtype`` for that column.
        * Any other name returns a function which, when called, passes its
          arguments through ``extract_args_kwargs``, calls the same-named
          attribute of the native frame and wraps the result with
          ``_from_native_object``.
        """
        if attr == "collect":  # pragma: no cover
            raise AttributeError
        if attr == "schema":
            schema = self._native_frame.schema
            return {
                name: native_to_narwhals_dtype(dtype, self._dtypes)
                for name, dtype in schema.items()
            }

        def func(*args: Any, **kwargs: Any) -> Any:
            # The native attribute is looked up at call time, so an unknown
            # name only fails once the returned function is invoked.
            args, kwargs = extract_args_kwargs(args, kwargs)  # type: ignore[assignment]
            return self._from_native_object(
                getattr(self._native_frame, attr)(*args, **kwargs)
            )

        return func

    def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.ndarray:
        """Convert the native frame through its own ``__array__``.

        Args:
            dtype: Forwarded to the native ``__array__``.
            copy: Forwarded to the native ``__array__`` when the backend
                version is at least 0.20.28.

        Raises:
            NotImplementedError: If ``copy`` is given and the backend version
                is older than 0.20.28.
        """
        if self._backend_version < (0, 20, 28) and copy is not None:  # pragma: no cover
            msg = "`copy` in `__array__` is only supported for Polars>=0.20.28"
            raise NotImplementedError(msg)
        if self._backend_version < (0, 20, 28):  # pragma: no cover
            return self._native_frame.__array__(dtype)
        # Forward `copy` so the caller's request is honoured rather than
        # silently dropped.
        return self._native_frame.__array__(dtype, copy=copy)

    def collect_schema(self) -> dict[str, Any]:
        """Return the schema as a ``dict`` of column name to converted dtype.

        Backend versions below 1 read ``.schema``; newer ones call
        ``.collect_schema()`` on the native frame.
        """
        if self._backend_version < (1,):  # pragma: no cover
            schema = self._native_frame.schema
        else:
            schema = dict(self._native_frame.collect_schema())
        return {
            name: native_to_narwhals_dtype(dtype, self._dtypes)
            for name, dtype in schema.items()
        }

    @property
    def shape(self) -> tuple[int, int]:
        """The native frame's ``shape``."""
        return self._native_frame.shape  # type: ignore[no-any-return]

    def __getitem__(self, item: Any) -> Any:
        """Index the frame, wrapping whatever the native frame returns.

        For backend versions above 0.20.30 the key is handed directly to the
        native ``__getitem__``. Older versions go through a compatibility
        branch that normalises the key first.

        Raises:
            TypeError: In the compatibility branch, when ``item`` is a
                ``(rows, slice)`` pair whose slice bounds are neither strings
                nor integers (and the slice is not ``slice(None)``).
        """
        if self._backend_version > (0, 20, 30):
            return self._from_native_object(self._native_frame.__getitem__(item))
        else:  # pragma: no cover
            # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
            # Polars version we support
            if isinstance(item, tuple):
                # Turn non-string sequences inside the key into plain lists.
                item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item)

            columns = self.columns
            if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
                if item[1] == slice(None):
                    # All columns requested: only the row selector matters.
                    if isinstance(item[0], Sequence) and not len(item[0]):
                        return self._from_native_frame(self._native_frame[0:0])
                    return self._from_native_frame(
                        self._native_frame.__getitem__(item[0])
                    )
                if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
                    # Column slice given by name: translate to positions first.
                    start, stop, step = convert_str_slice_to_int_slice(item[1], columns)
                    return self._from_native_frame(
                        self._native_frame.select(columns[start:stop:step]).__getitem__(
                            item[0]
                        )
                    )
                if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
                    return self._from_native_frame(
                        self._native_frame.select(
                            columns[item[1].start : item[1].stop : item[1].step]
                        ).__getitem__(item[0])
                    )
                msg = f"Expected slice of integers or strings, got: {type(item[1])}"  # pragma: no cover
                raise TypeError(msg)  # pragma: no cover
            import polars as pl  # ignore-banned-import()

            if (
                isinstance(item, tuple)
                and (len(item) == 2)
                and is_sequence_but_not_str(item[1])
                and (len(item[1]) == 0)
            ):
                # Empty column selector.
                result = self._native_frame.select(item[1])
            elif isinstance(item, slice) and (
                isinstance(item.start, str) or isinstance(item.stop, str)
            ):
                start, stop, step = convert_str_slice_to_int_slice(item, columns)
                return self._from_native_frame(
                    self._native_frame.select(columns[start:stop:step])
                )
            elif is_sequence_but_not_str(item) and (len(item) == 0):
                # Empty row selector.
                result = self._native_frame.slice(0, 0)
            else:
                result = self._native_frame.__getitem__(item)
            if isinstance(result, pl.Series):
                from narwhals._polars.series import PolarsSeries

                return PolarsSeries(
                    result, backend_version=self._backend_version, dtypes=self._dtypes
                )
            return self._from_native_object(result)

    def get_column(self, name: str) -> Any:
        """Return the native column ``name`` wrapped in a ``PolarsSeries``."""
        from narwhals._polars.series import PolarsSeries

        return PolarsSeries(
            self._native_frame.get_column(name),
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def is_empty(self) -> bool:
        """Return ``True`` when ``len`` of the native frame is zero."""
        return len(self._native_frame) == 0

    @property
    def columns(self) -> list[str]:
        """The native frame's ``columns``."""
        return self._native_frame.columns  # type: ignore[no-any-return]

    def lazy(self) -> PolarsLazyFrame:
        """Wrap ``self._native_frame.lazy()`` in a ``PolarsLazyFrame``."""
        return PolarsLazyFrame(
            self._native_frame.lazy(),
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def to_dict(self, *, as_series: bool) -> Any:
        """Convert the frame to a ``dict`` keyed by column name.

        Args:
            as_series: When true, each value is a ``PolarsSeries`` wrapping
                the native column; otherwise the native
                ``to_dict(as_series=False)`` result is returned as-is.
        """
        df = self._native_frame
        if as_series:
            from narwhals._polars.series import PolarsSeries

            return {
                name: PolarsSeries(
                    col, backend_version=self._backend_version, dtypes=self._dtypes
                )
                for name, col in df.to_dict(as_series=True).items()
            }
        else:
            return df.to_dict(as_series=False)

    def group_by(self, *by: str, drop_null_keys: bool) -> Any:
        """Return a ``PolarsGroupBy`` over this frame keyed on ``by``."""
        from narwhals._polars.group_by import PolarsGroupBy

        return PolarsGroupBy(self, list(by), drop_null_keys=drop_null_keys)

    def with_row_index(self, name: str) -> Any:
        """Add a row-index column called ``name``.

        Backend versions below 0.20.4 use the native ``with_row_count``;
        newer ones use ``with_row_index``.
        """
        if self._backend_version < (0, 20, 4):  # pragma: no cover
            return self._from_native_frame(self._native_frame.with_row_count(name))
        return self._from_native_frame(self._native_frame.with_row_index(name))

    def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
        """Drop ``columns`` from the frame.

        For backend versions below 1.0.0 the column list is first resolved by
        ``parse_columns_to_drop`` (which receives ``strict``) and the native
        ``drop`` is called without ``strict``; newer versions pass ``strict``
        straight to the native ``drop``.
        """
        if self._backend_version < (1, 0, 0):  # pragma: no cover
            to_drop = parse_columns_to_drop(
                compliant_frame=self, columns=columns, strict=strict
            )
            return self._from_native_frame(self._native_frame.drop(to_drop))
        return self._from_native_frame(self._native_frame.drop(columns, strict=strict))

    def unpivot(
        self: Self,
        on: str | list[str] | None,
        index: str | list[str] | None,
        variable_name: str | None,
        value_name: str | None,
    ) -> Self:
        """Unpivot the frame via the native API.

        Backend versions below 1.0.0 call the native ``melt`` with ``index``
        as ``id_vars`` and ``on`` as ``value_vars``; newer versions call the
        native ``unpivot`` with the arguments unchanged.
        """
        if self._backend_version < (1, 0, 0):  # pragma: no cover
            return self._from_native_frame(
                self._native_frame.melt(
                    id_vars=index,
                    value_vars=on,
                    variable_name=variable_name,
                    value_name=value_name,
                )
            )
        return self._from_native_frame(
            self._native_frame.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        )
class PolarsLazyFrame:
    """Narwhals-side wrapper around a native Polars lazy frame.

    The wrapped object lives in ``_native_frame``. Any method not spelled out
    here is forwarded to it by ``__getattr__`` and the result is wrapped in a
    new ``PolarsLazyFrame``.
    """

    def __init__(
        self, df: Any, *, backend_version: tuple[int, ...], dtypes: DTypes
    ) -> None:
        """Remember the native frame, backend version and dtype collection."""
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._dtypes = dtypes

    def __repr__(self) -> str:  # pragma: no cover
        """Return a constant representation that does not touch the data."""
        return "PolarsLazyFrame"

    def __narwhals_lazyframe__(self) -> Self:
        """Protocol hook; hands back this very object."""
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Create a ``PolarsNamespace`` with this frame's version and dtypes."""
        namespace = PolarsNamespace(
            backend_version=self._backend_version, dtypes=self._dtypes
        )
        return namespace

    def __native_namespace__(self: Self) -> ModuleType:
        """Return the native module, or raise ``AssertionError`` if the
        stored implementation is not ``Implementation.POLARS``."""
        if self._implementation is not Implementation.POLARS:  # pragma: no cover
            msg = f"Expected polars, got: {type(self._implementation)}"
            raise AssertionError(msg)
        return self._implementation.to_native_namespace()

    def _from_native_frame(self, df: Any) -> Self:
        """Build a sibling instance around ``df`` with the same context."""
        context = {"backend_version": self._backend_version, "dtypes": self._dtypes}
        return self.__class__(df, **context)

    def __getattr__(self, attr: str) -> Any:
        """Return a forwarding function for any attribute not defined here.

        The function runs its arguments through ``extract_args_kwargs``,
        invokes the same-named attribute on the native frame and wraps the
        outcome via ``_from_native_frame``.
        """

        def func(*args: Any, **kwargs: Any) -> Any:
            native_args, native_kwargs = extract_args_kwargs(args, kwargs)
            # Resolved only when called, after the arguments were converted.
            native_method = getattr(self._native_frame, attr)
            outcome = native_method(*native_args, **native_kwargs)
            return self._from_native_frame(outcome)

        return func

    @property
    def columns(self) -> list[str]:
        """Column names as reported by the native frame."""
        native = self._native_frame
        return native.columns  # type: ignore[no-any-return]

    @property
    def schema(self) -> dict[str, Any]:
        """Mapping of column name to dtype converted by
        ``native_to_narwhals_dtype``, read from the native ``.schema``."""
        converted: dict[str, Any] = {}
        for column_name, native_dtype in self._native_frame.schema.items():
            converted[column_name] = native_to_narwhals_dtype(native_dtype, self._dtypes)
        return converted

    def collect_schema(self) -> dict[str, Any]:
        """Mapping of column name to converted dtype.

        Uses the native ``collect_schema()`` from backend version 1 onwards
        and the ``.schema`` attribute before that.
        """
        if self._backend_version >= (1,):
            native_schema = dict(self._native_frame.collect_schema())
        else:  # pragma: no cover
            native_schema = self._native_frame.schema
        converted: dict[str, Any] = {}
        for column_name, native_dtype in native_schema.items():
            converted[column_name] = native_to_narwhals_dtype(native_dtype, self._dtypes)
        return converted

    def collect(self) -> PolarsDataFrame:
        """Call the native ``collect()`` and wrap it in a ``PolarsDataFrame``."""
        collected = self._native_frame.collect()
        return PolarsDataFrame(
            collected, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def group_by(self, *by: str, drop_null_keys: bool) -> Any:
        """Return a ``PolarsLazyGroupBy`` over this frame keyed on ``by``."""
        from narwhals._polars.group_by import PolarsLazyGroupBy

        keys = list(by)
        return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def with_row_index(self, name: str) -> Any:
        """Add a row-index column called ``name``.

        The native ``with_row_count`` is used below backend version 0.20.4,
        ``with_row_index`` otherwise.
        """
        if self._backend_version >= (0, 20, 4):
            indexed = self._native_frame.with_row_index(name)
        else:  # pragma: no cover
            indexed = self._native_frame.with_row_count(name)
        return self._from_native_frame(indexed)

    def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
        """Drop ``columns`` through the native ``drop``.

        ``strict`` is only passed on for backend versions of at least 1.0.0;
        older versions call the native ``drop`` with the columns alone.
        """
        if self._backend_version >= (1, 0, 0):
            remaining = self._native_frame.drop(columns, strict=strict)
        else:  # pragma: no cover
            remaining = self._native_frame.drop(columns)
        return self._from_native_frame(remaining)

    def unpivot(
        self: Self,
        on: str | list[str] | None,
        index: str | list[str] | None,
        variable_name: str | None,
        value_name: str | None,
    ) -> Self:
        """Unpivot through the native API.

        From backend version 1.0.0 the native ``unpivot`` is called with the
        arguments as given; earlier versions call ``melt`` with ``index`` as
        ``id_vars`` and ``on`` as ``value_vars``.
        """
        native = self._native_frame
        if self._backend_version >= (1, 0, 0):
            reshaped = native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        else:  # pragma: no cover
            reshaped = native.melt(
                id_vars=index,
                value_vars=on,
                variable_name=variable_name,
                value_name=value_name,
            )
        return self._from_native_frame(reshaped)