234 lines
8.0 KiB
Python
Executable File
234 lines
8.0 KiB
Python
Executable File
from __future__ import annotations
|
|
|
|
from typing import TYPE_CHECKING
|
|
from typing import Any
|
|
from typing import Iterable
|
|
from typing import Literal
|
|
from typing import Sequence
|
|
|
|
from narwhals._expression_parsing import parse_into_exprs
|
|
from narwhals._polars.utils import extract_args_kwargs
|
|
from narwhals._polars.utils import narwhals_to_native_dtype
|
|
from narwhals.utils import Implementation
|
|
|
|
if TYPE_CHECKING:
|
|
from narwhals._polars.dataframe import PolarsDataFrame
|
|
from narwhals._polars.dataframe import PolarsLazyFrame
|
|
from narwhals._polars.expr import PolarsExpr
|
|
from narwhals._polars.typing import IntoPolarsExpr
|
|
from narwhals.dtypes import DType
|
|
from narwhals.typing import DTypes
|
|
|
|
|
|
class PolarsNamespace:
|
|
def __init__(self, *, backend_version: tuple[int, ...], dtypes: DTypes) -> None:
|
|
self._backend_version = backend_version
|
|
self._implementation = Implementation.POLARS
|
|
self._dtypes = dtypes
|
|
|
|
def __getattr__(self, attr: str) -> Any:
|
|
import polars as pl # ignore-banned-import
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
def func(*args: Any, **kwargs: Any) -> Any:
|
|
args, kwargs = extract_args_kwargs(args, kwargs) # type: ignore[assignment]
|
|
return PolarsExpr(getattr(pl, attr)(*args, **kwargs), dtypes=self._dtypes)
|
|
|
|
return func
|
|
|
|
def nth(self, *indices: int) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
if self._backend_version < (1, 0, 0): # pragma: no cover
|
|
msg = "`nth` is only supported for Polars>=1.0.0. Please use `col` for columns selection instead."
|
|
raise AttributeError(msg)
|
|
return PolarsExpr(pl.nth(*indices), dtypes=self._dtypes)
|
|
|
|
def len(self) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
if self._backend_version < (0, 20, 5): # pragma: no cover
|
|
return PolarsExpr(pl.count().alias("len"), dtypes=self._dtypes)
|
|
return PolarsExpr(pl.len(), dtypes=self._dtypes)
|
|
|
|
def concat(
|
|
self,
|
|
items: Sequence[PolarsDataFrame | PolarsLazyFrame],
|
|
*,
|
|
how: Literal["vertical", "horizontal"],
|
|
) -> PolarsDataFrame | PolarsLazyFrame:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.dataframe import PolarsDataFrame
|
|
from narwhals._polars.dataframe import PolarsLazyFrame
|
|
|
|
dfs: list[Any] = [item._native_frame for item in items]
|
|
result = pl.concat(dfs, how=how)
|
|
if isinstance(result, pl.DataFrame):
|
|
return PolarsDataFrame(
|
|
result, backend_version=items[0]._backend_version, dtypes=items[0]._dtypes
|
|
)
|
|
return PolarsLazyFrame(
|
|
result, backend_version=items[0]._backend_version, dtypes=items[0]._dtypes
|
|
)
|
|
|
|
def lit(self, value: Any, dtype: DType | None = None) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
if dtype is not None:
|
|
return PolarsExpr(
|
|
pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._dtypes)),
|
|
dtypes=self._dtypes,
|
|
)
|
|
return PolarsExpr(pl.lit(value), dtypes=self._dtypes)
|
|
|
|
def mean(self, *column_names: str) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
if self._backend_version < (0, 20, 4): # pragma: no cover
|
|
return PolarsExpr(pl.mean([*column_names]), dtypes=self._dtypes) # type: ignore[arg-type]
|
|
return PolarsExpr(pl.mean(*column_names), dtypes=self._dtypes)
|
|
|
|
def mean_horizontal(self, *exprs: IntoPolarsExpr) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
polars_exprs = parse_into_exprs(*exprs, namespace=self)
|
|
|
|
if self._backend_version < (0, 20, 8): # pragma: no cover
|
|
return PolarsExpr(
|
|
pl.sum_horizontal(e._native_expr for e in polars_exprs)
|
|
/ pl.sum_horizontal(1 - e.is_null()._native_expr for e in polars_exprs),
|
|
dtypes=self._dtypes,
|
|
)
|
|
|
|
return PolarsExpr(
|
|
pl.mean_horizontal(e._native_expr for e in polars_exprs),
|
|
dtypes=self._dtypes,
|
|
)
|
|
|
|
def concat_str(
|
|
self,
|
|
exprs: Iterable[IntoPolarsExpr],
|
|
*more_exprs: IntoPolarsExpr,
|
|
separator: str = "",
|
|
ignore_nulls: bool = False,
|
|
) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
pl_exprs: list[pl.Expr] = [
|
|
expr._native_expr
|
|
for expr in (
|
|
*parse_into_exprs(*exprs, namespace=self),
|
|
*parse_into_exprs(*more_exprs, namespace=self),
|
|
)
|
|
]
|
|
|
|
if self._backend_version < (0, 20, 6): # pragma: no cover
|
|
null_mask = [expr.is_null() for expr in pl_exprs]
|
|
sep = pl.lit(separator)
|
|
|
|
if not ignore_nulls:
|
|
null_mask_result = pl.any_horizontal(*null_mask)
|
|
output_expr = pl.reduce(
|
|
lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()), # type: ignore[arg-type,return-value]
|
|
pl_exprs,
|
|
)
|
|
result = pl.when(~null_mask_result).then(output_expr)
|
|
else:
|
|
init_value, *values = [
|
|
pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String()))
|
|
for expr, nm in zip(pl_exprs, null_mask)
|
|
]
|
|
separators = [
|
|
pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1]
|
|
]
|
|
|
|
result = pl.fold( # type: ignore[assignment]
|
|
acc=init_value,
|
|
function=lambda x, y: x + y,
|
|
exprs=[s + v for s, v in zip(separators, values)],
|
|
)
|
|
|
|
return PolarsExpr(
|
|
result,
|
|
dtypes=self._dtypes,
|
|
)
|
|
|
|
return PolarsExpr(
|
|
pl.concat_str(
|
|
pl_exprs,
|
|
separator=separator,
|
|
ignore_nulls=ignore_nulls,
|
|
),
|
|
dtypes=self._dtypes,
|
|
)
|
|
|
|
@property
|
|
def selectors(self) -> PolarsSelectors:
|
|
return PolarsSelectors(self._dtypes)
|
|
|
|
|
|
class PolarsSelectors:
|
|
def __init__(self, dtypes: DTypes) -> None:
|
|
self._dtypes = dtypes
|
|
|
|
def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
return PolarsExpr(
|
|
pl.selectors.by_dtype(
|
|
[narwhals_to_native_dtype(dtype, self._dtypes) for dtype in dtypes]
|
|
),
|
|
dtypes=self._dtypes,
|
|
)
|
|
|
|
def numeric(self) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
return PolarsExpr(pl.selectors.numeric(), dtypes=self._dtypes)
|
|
|
|
def boolean(self) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
return PolarsExpr(pl.selectors.boolean(), dtypes=self._dtypes)
|
|
|
|
def string(self) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
return PolarsExpr(pl.selectors.string(), dtypes=self._dtypes)
|
|
|
|
def categorical(self) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
return PolarsExpr(pl.selectors.categorical(), dtypes=self._dtypes)
|
|
|
|
def all(self) -> PolarsExpr:
|
|
import polars as pl # ignore-banned-import()
|
|
|
|
from narwhals._polars.expr import PolarsExpr
|
|
|
|
return PolarsExpr(pl.selectors.all(), dtypes=self._dtypes)
|