351 lines
13 KiB
Python
Executable File
351 lines
13 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Sequence
|
|
|
|
from narwhals._polars.namespace import PolarsNamespace
|
|
from narwhals._polars.utils import convert_str_slice_to_int_slice
|
|
from narwhals._polars.utils import extract_args_kwargs
|
|
from narwhals._polars.utils import native_to_narwhals_dtype
|
|
from narwhals.utils import Implementation
|
|
from narwhals.utils import is_sequence_but_not_str
|
|
from narwhals.utils import parse_columns_to_drop
|
|
|
|
if TYPE_CHECKING:
|
|
from types import ModuleType
|
|
|
|
import numpy as np
|
|
from typing_extensions import Self
|
|
|
|
from narwhals.typing import DTypes
|
|
|
|
|
|
class PolarsDataFrame:
    """Compliant wrapper around a native eager Polars frame.

    The wrapper stores the native frame together with the backend version
    and the ``dtypes`` object. Methods that are not defined explicitly on
    this class are forwarded to the native frame through ``__getattr__``.
    """

    def __init__(
        self, df: Any, *, backend_version: tuple[int, ...], dtypes: DTypes
    ) -> None:
        """Store the native frame and the metadata needed to re-wrap results."""
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._dtypes = dtypes

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsDataFrame"

    def __narwhals_dataframe__(self) -> Self:
        """Return ``self``; marks the object as a compliant dataframe."""
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Return a ``PolarsNamespace`` sharing this frame's version and dtypes."""
        return PolarsNamespace(backend_version=self._backend_version, dtypes=self._dtypes)

    def __native_namespace__(self: Self) -> ModuleType:
        """Return the native module reported by the stored implementation.

        Raises:
            AssertionError: if the stored implementation is not Polars.
        """
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()

        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _from_native_frame(self, df: Any) -> Self:
        """Wrap ``df`` in a new instance carrying the same version and dtypes."""
        return self.__class__(
            df, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def _from_native_object(self, obj: Any) -> Any:
        """Wrap a native result according to its type.

        Native series become ``PolarsSeries``, native dataframes become new
        instances of this class, and anything else is returned unchanged.
        """
        import polars as pl  # ignore-banned-import()

        if isinstance(obj, pl.Series):
            from narwhals._polars.series import PolarsSeries

            return PolarsSeries(
                obj, backend_version=self._backend_version, dtypes=self._dtypes
            )
        if isinstance(obj, pl.DataFrame):
            return self._from_native_frame(obj)
        # scalar
        return obj

    def __getattr__(self, attr: str) -> Any:
        """Forward unknown attributes to the native frame.

        ``collect`` is reported as missing, ``schema`` is returned as a dict
        of converted dtypes, and every other name yields a callable that
        invokes the native method and wraps its result.
        """
        if attr == "collect":  # pragma: no cover
            raise AttributeError
        if attr == "schema":
            schema = self._native_frame.schema
            return {
                name: native_to_narwhals_dtype(dtype, self._dtypes)
                for name, dtype in schema.items()
            }

        def func(*args: Any, **kwargs: Any) -> Any:
            args, kwargs = extract_args_kwargs(args, kwargs)  # type: ignore[assignment]
            return self._from_native_object(
                getattr(self._native_frame, attr)(*args, **kwargs)
            )

        return func

    def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.ndarray:
        """Convert the native frame to an array via its ``__array__`` method.

        Args:
            dtype: forwarded to the native ``__array__``.
            copy: forwarded to the native ``__array__`` on Polars>=0.20.28.

        Raises:
            NotImplementedError: if ``copy`` is given on Polars<0.20.28.
        """
        if self._backend_version < (0, 20, 28):  # pragma: no cover
            if copy is not None:
                msg = "`copy` in `__array__` is only supported for Polars>=0.20.28"
                raise NotImplementedError(msg)
            return self._native_frame.__array__(dtype)
        # Forward `copy` so the caller's request is honoured instead of being
        # silently dropped.
        return self._native_frame.__array__(dtype, copy=copy)

    def collect_schema(self) -> dict[str, Any]:
        """Return a mapping of column name to converted dtype."""
        if self._backend_version < (1,):  # pragma: no cover
            schema = self._native_frame.schema
        else:
            schema = dict(self._native_frame.collect_schema())
        return {
            name: native_to_narwhals_dtype(dtype, self._dtypes)
            for name, dtype in schema.items()
        }

    @property
    def shape(self) -> tuple[int, int]:
        """The native frame's ``shape``."""
        return self._native_frame.shape  # type: ignore[no-any-return]

    def __getitem__(self, item: Any) -> Any:
        """Index the frame and wrap the result.

        On Polars versions above 0.20.30 indexing is delegated directly to
        the native frame. Older versions go through a compatibility path
        that handles empty selections and column slices explicitly.

        Raises:
            TypeError: on the legacy path, when a column slice has neither
                string nor integer bounds.
        """
        if self._backend_version > (0, 20, 30):
            return self._from_native_object(self._native_frame.__getitem__(item))
        else:  # pragma: no cover
            # TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
            # Polars version we support
            if isinstance(item, tuple):
                # Normalise non-string sequences inside the tuple to lists.
                item = tuple(list(i) if is_sequence_but_not_str(i) else i for i in item)

            columns = self.columns
            if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
                if item[1] == slice(None):
                    # All columns requested: only the row selector matters.
                    if isinstance(item[0], Sequence) and not len(item[0]):
                        return self._from_native_frame(self._native_frame[0:0])
                    return self._from_native_frame(
                        self._native_frame.__getitem__(item[0])
                    )
                if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
                    # Column slice given by name: translate to positions first.
                    start, stop, step = convert_str_slice_to_int_slice(item[1], columns)
                    return self._from_native_frame(
                        self._native_frame.select(columns[start:stop:step]).__getitem__(
                            item[0]
                        )
                    )
                if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
                    return self._from_native_frame(
                        self._native_frame.select(
                            columns[item[1].start : item[1].stop : item[1].step]
                        ).__getitem__(item[0])
                    )
                msg = f"Expected slice of integers or strings, got: {type(item[1])}"  # pragma: no cover
                raise TypeError(msg)  # pragma: no cover

            if (
                isinstance(item, tuple)
                and (len(item) == 2)
                and is_sequence_but_not_str(item[1])
                and (len(item[1]) == 0)
            ):
                # Empty column selection.
                result = self._native_frame.select(item[1])
            elif isinstance(item, slice) and (
                isinstance(item.start, str) or isinstance(item.stop, str)
            ):
                start, stop, step = convert_str_slice_to_int_slice(item, columns)
                return self._from_native_frame(
                    self._native_frame.select(columns[start:stop:step])
                )
            elif is_sequence_but_not_str(item) and (len(item) == 0):
                # Empty row selection.
                result = self._native_frame.slice(0, 0)
            else:
                result = self._native_frame.__getitem__(item)
            # `_from_native_object` already wraps series, frames and scalars.
            return self._from_native_object(result)

    def get_column(self, name: str) -> Any:
        """Return the native column ``name`` wrapped in a ``PolarsSeries``."""
        from narwhals._polars.series import PolarsSeries

        return PolarsSeries(
            self._native_frame.get_column(name),
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def is_empty(self) -> bool:
        """Return ``True`` when the native frame has length zero."""
        return len(self._native_frame) == 0

    @property
    def columns(self) -> list[str]:
        """The native frame's ``columns``."""
        return self._native_frame.columns  # type: ignore[no-any-return]

    def lazy(self) -> PolarsLazyFrame:
        """Return a ``PolarsLazyFrame`` wrapping ``self._native_frame.lazy()``."""
        return PolarsLazyFrame(
            self._native_frame.lazy(),
            backend_version=self._backend_version,
            dtypes=self._dtypes,
        )

    def to_dict(self, *, as_series: bool) -> Any:
        """Convert the frame to a dict keyed by column name.

        With ``as_series=True`` each value is wrapped in a ``PolarsSeries``;
        otherwise the native ``to_dict(as_series=False)`` result is returned.
        """
        df = self._native_frame

        if as_series:
            from narwhals._polars.series import PolarsSeries

            return {
                name: PolarsSeries(
                    col, backend_version=self._backend_version, dtypes=self._dtypes
                )
                for name, col in df.to_dict(as_series=True).items()
            }
        else:
            return df.to_dict(as_series=False)

    def group_by(self, *by: str, drop_null_keys: bool) -> Any:
        """Return a ``PolarsGroupBy`` over this frame for the keys ``by``."""
        from narwhals._polars.group_by import PolarsGroupBy

        return PolarsGroupBy(self, list(by), drop_null_keys=drop_null_keys)

    def with_row_index(self, name: str) -> Any:
        """Add a row-index column called ``name``.

        Uses ``with_row_count`` on Polars<0.20.4 and ``with_row_index``
        otherwise.
        """
        if self._backend_version < (0, 20, 4):  # pragma: no cover
            return self._from_native_frame(self._native_frame.with_row_count(name))
        return self._from_native_frame(self._native_frame.with_row_index(name))

    def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
        """Drop ``columns`` from the frame.

        On Polars<1.0.0 the columns to drop are resolved with
        ``parse_columns_to_drop``; newer versions receive ``strict`` directly.
        """
        if self._backend_version < (1, 0, 0):  # pragma: no cover
            to_drop = parse_columns_to_drop(
                compliant_frame=self, columns=columns, strict=strict
            )
            return self._from_native_frame(self._native_frame.drop(to_drop))
        return self._from_native_frame(self._native_frame.drop(columns, strict=strict))

    def unpivot(
        self: Self,
        on: str | list[str] | None,
        index: str | list[str] | None,
        variable_name: str | None,
        value_name: str | None,
    ) -> Self:
        """Unpivot the frame from wide to long format.

        Calls the native ``melt`` on Polars<1.0.0 (``index`` -> ``id_vars``,
        ``on`` -> ``value_vars``) and the native ``unpivot`` otherwise.
        """
        if self._backend_version < (1, 0, 0):  # pragma: no cover
            return self._from_native_frame(
                self._native_frame.melt(
                    id_vars=index,
                    value_vars=on,
                    variable_name=variable_name,
                    value_name=value_name,
                )
            )
        return self._from_native_frame(
            self._native_frame.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        )
|
|
|
|
|
|
class PolarsLazyFrame:
    """Compliant wrapper around a native lazy Polars frame.

    Keeps the native frame next to the backend version and ``dtypes``
    object; attribute names not defined here resolve to callables that
    invoke the same-named native method and re-wrap its result.
    """

    def __init__(
        self, df: Any, *, backend_version: tuple[int, ...], dtypes: DTypes
    ) -> None:
        """Record the native frame plus the metadata used when re-wrapping."""
        self._native_frame = df
        self._backend_version = backend_version
        self._implementation = Implementation.POLARS
        self._dtypes = dtypes

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsLazyFrame"

    def __narwhals_lazyframe__(self) -> Self:
        """Return ``self``; marks the object as a compliant lazy frame."""
        return self

    def __narwhals_namespace__(self) -> PolarsNamespace:
        """Build a ``PolarsNamespace`` with this frame's version and dtypes."""
        namespace = PolarsNamespace(
            backend_version=self._backend_version, dtypes=self._dtypes
        )
        return namespace

    def __native_namespace__(self: Self) -> ModuleType:
        """Return the native module reported by the stored implementation.

        Raises:
            AssertionError: if the stored implementation is not Polars.
        """
        implementation = self._implementation
        if implementation is not Implementation.POLARS:
            msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
            raise AssertionError(msg)
        return implementation.to_native_namespace()

    def _from_native_frame(self, df: Any) -> Self:
        """Wrap ``df`` in a fresh instance with the same version and dtypes."""
        cls = self.__class__
        return cls(df, backend_version=self._backend_version, dtypes=self._dtypes)

    def __getattr__(self, attr: str) -> Any:
        """Resolve unknown names to a forwarding callable for the native frame."""

        def func(*args: Any, **kwargs: Any) -> Any:
            native_args, native_kwargs = extract_args_kwargs(args, kwargs)
            native_method = getattr(self._native_frame, attr)
            return self._from_native_frame(native_method(*native_args, **native_kwargs))

        return func

    @property
    def columns(self) -> list[str]:
        """The native frame's ``columns``."""
        native = self._native_frame
        return native.columns  # type: ignore[no-any-return]

    @property
    def schema(self) -> dict[str, Any]:
        """Mapping of column name to converted dtype, read from ``.schema``."""
        converted: dict[str, Any] = {}
        for column_name, native_dtype in self._native_frame.schema.items():
            converted[column_name] = native_to_narwhals_dtype(native_dtype, self._dtypes)
        return converted

    def collect_schema(self) -> dict[str, Any]:
        """Return a mapping of column name to converted dtype.

        Reads ``.schema`` on Polars<1 and ``collect_schema()`` otherwise.
        """
        native = self._native_frame
        if self._backend_version >= (1,):
            native_schema = dict(native.collect_schema())
        else:  # pragma: no cover
            native_schema = native.schema
        converted: dict[str, Any] = {}
        for column_name, native_dtype in native_schema.items():
            converted[column_name] = native_to_narwhals_dtype(native_dtype, self._dtypes)
        return converted

    def collect(self) -> PolarsDataFrame:
        """Collect the native frame and wrap the result in ``PolarsDataFrame``."""
        collected = self._native_frame.collect()
        return PolarsDataFrame(
            collected, backend_version=self._backend_version, dtypes=self._dtypes
        )

    def group_by(self, *by: str, drop_null_keys: bool) -> Any:
        """Return a ``PolarsLazyGroupBy`` over this frame for the keys ``by``."""
        from narwhals._polars.group_by import PolarsLazyGroupBy

        keys = list(by)
        return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def with_row_index(self, name: str) -> Any:
        """Add a row-index column called ``name``.

        Uses ``with_row_count`` on Polars<0.20.4 and ``with_row_index``
        otherwise.
        """
        native = self._native_frame
        if self._backend_version >= (0, 20, 4):
            indexed = native.with_row_index(name)
        else:  # pragma: no cover
            indexed = native.with_row_count(name)
        return self._from_native_frame(indexed)

    def drop(self: Self, columns: list[str], strict: bool) -> Self:  # noqa: FBT001
        """Drop ``columns``; ``strict`` is only passed on Polars>=1.0.0."""
        native = self._native_frame
        if self._backend_version >= (1, 0, 0):
            remaining = native.drop(columns, strict=strict)
        else:  # pragma: no cover
            remaining = native.drop(columns)
        return self._from_native_frame(remaining)

    def unpivot(
        self: Self,
        on: str | list[str] | None,
        index: str | list[str] | None,
        variable_name: str | None,
        value_name: str | None,
    ) -> Self:
        """Unpivot the frame from wide to long format.

        Calls the native ``melt`` on Polars<1.0.0 (``index`` -> ``id_vars``,
        ``on`` -> ``value_vars``) and the native ``unpivot`` otherwise.
        """
        native = self._native_frame
        if self._backend_version >= (1, 0, 0):
            reshaped = native.unpivot(
                on=on, index=index, variable_name=variable_name, value_name=value_name
            )
        else:  # pragma: no cover
            reshaped = native.melt(
                id_vars=index,
                value_vars=on,
                variable_name=variable_name,
                value_name=value_name,
            )
        return self._from_native_frame(reshaped)
|