Files
Buffteks-Website/buffteks/lib/python3.12/site-packages/narwhals/_polars/utils.py
2025-05-08 21:10:14 -05:00

252 lines
8.6 KiB
Python

from __future__ import annotations
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import TypeVar
from typing import overload
import polars as pl
from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import ComputeError
from narwhals.exceptions import DuplicateError
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import NarwhalsError
from narwhals.exceptions import ShapeError
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass
if TYPE_CHECKING:
from narwhals._polars.dataframe import PolarsDataFrame
from narwhals._polars.dataframe import PolarsLazyFrame
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals.dtypes import DType
from narwhals.utils import Version
T = TypeVar("T")


@overload
def extract_native(obj: PolarsDataFrame) -> pl.DataFrame: ...


@overload
def extract_native(obj: PolarsLazyFrame) -> pl.LazyFrame: ...


@overload
def extract_native(obj: PolarsSeries) -> pl.Series: ...


@overload
def extract_native(obj: PolarsExpr) -> pl.Expr: ...


@overload
def extract_native(obj: T) -> T: ...


def extract_native(
    obj: PolarsDataFrame | PolarsLazyFrame | PolarsSeries | PolarsExpr | T,
) -> pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr | T:
    """Return the native Polars object held by a Narwhals Polars wrapper.

    Arguments:
        obj: Either one of the Polars-backed wrappers (`PolarsDataFrame`,
            `PolarsLazyFrame`, `PolarsSeries`, `PolarsExpr`) or any other value.

    Returns:
        `obj.native` when `obj` is one of the four wrapper types; otherwise
        `obj` itself, untouched.
    """
    # Imported at call time rather than module level
    # (presumably to avoid a circular import -- confirm before hoisting).
    from narwhals._polars.dataframe import PolarsDataFrame
    from narwhals._polars.dataframe import PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    wrapper_types = (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)
    if not isinstance(obj, wrapper_types):
        # Plain values (scalars, native objects, ...) pass straight through.
        return obj
    return obj.native
def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, Any]]:
    """Apply `extract_native` to every positional and keyword argument.

    Arguments:
        args: Iterable of positional arguments.
        kwargs: Mapping of keyword arguments (must provide `.items()`).

    Returns:
        A `(list, dict)` pair holding the unwrapped positional values and the
        unwrapped keyword values under their original keys.
    """
    native_args = list(map(extract_native, args))
    native_kwargs = {key: extract_native(value) for key, value in kwargs.items()}
    return native_args, native_kwargs
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(
    dtype: pl.DataType, version: Version, backend_version: tuple[int, ...]
) -> DType:
    """Translate a Polars dtype into the corresponding Narwhals dtype.

    Calls are memoised by `lru_cache` (up to 16 entries), keyed on all three
    arguments.

    Arguments:
        dtype: Polars dtype to translate.
        version: Forwarded to `import_dtypes_module` to select the Narwhals
            dtypes module.
        backend_version: Polars version tuple; decides whether the outer size
            of an `Array` dtype is read from `.width` or `.size`.

    Returns:
        The matching Narwhals dtype, or `Unknown()` when nothing matches.
    """
    nw_dtypes = import_dtypes_module(version)

    # Parameter-free dtypes carry the same name in Polars and in Narwhals.
    # Names are resolved lazily and in this exact order, so an attribute is only
    # touched once every earlier candidate has failed to match.
    leading_names = (
        "Float64",
        "Float32",
        "Int128",
        "Int64",
        "Int32",
        "Int16",
        "Int8",
        "UInt128",
        "UInt64",
        "UInt32",
        "UInt16",
        "UInt8",
        "String",
        "Boolean",
        "Object",
        "Categorical",
        "Enum",
        "Date",
    )
    for type_name in leading_names:
        if type_name in ("Int128", "UInt128") and not hasattr(
            pl, type_name
        ):  # pragma: no cover
            # Not available for Polars pre 1.8.0
            continue
        if dtype == getattr(pl, type_name):
            return getattr(nw_dtypes, type_name)()

    if isinstance_or_issubclass(dtype, pl.Datetime):
        if dtype is pl.Datetime:
            # Bare class, no parameters to carry over.
            return nw_dtypes.Datetime()
        return nw_dtypes.Datetime(dtype.time_unit, dtype.time_zone)

    if isinstance_or_issubclass(dtype, pl.Duration):
        if dtype is pl.Duration:
            # Bare class, no parameters to carry over.
            return nw_dtypes.Duration()
        return nw_dtypes.Duration(dtype.time_unit)

    if isinstance_or_issubclass(dtype, pl.Struct):
        # Convert each (name, dtype) pair of the struct recursively.
        return nw_dtypes.Struct(
            [
                nw_dtypes.Field(
                    field_name,
                    native_to_narwhals_dtype(field_dtype, version, backend_version),
                )
                for field_name, field_dtype in dtype
            ]
        )

    if isinstance_or_issubclass(dtype, pl.List):
        inner = native_to_narwhals_dtype(dtype.inner, version, backend_version)
        return nw_dtypes.List(inner)

    if isinstance_or_issubclass(dtype, pl.Array):
        # The attribute holding the outer size differs across Polars versions.
        if backend_version < (0, 20, 30):
            outer_shape = dtype.width
        else:
            outer_shape = dtype.size
        inner = native_to_narwhals_dtype(dtype.inner, version, backend_version)
        return nw_dtypes.Array(inner, outer_shape)

    for type_name in ("Decimal", "Time", "Binary"):
        if dtype == getattr(pl, type_name):
            return getattr(nw_dtypes, type_name)()

    return nw_dtypes.Unknown()
def narwhals_to_native_dtype(
    dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...]
) -> pl.DataType:
    """Translate a Narwhals dtype (class or instance) into a Polars dtype.

    Arguments:
        dtype: Narwhals dtype to translate.
        version: Forwarded to `import_dtypes_module` to select the Narwhals
            dtypes module used for the comparisons.
        backend_version: Polars version tuple; decides whether `pl.Array` is
            built with the `width` or the `shape` keyword.

    Returns:
        The matching Polars dtype instance, or `pl.Unknown()` when nothing
        matches.

    Raises:
        NotImplementedError: If `dtype` is `Enum` or `Decimal`.
    """
    dtypes = import_dtypes_module(version)
    if dtype == dtypes.Float64:
        return pl.Float64()
    if dtype == dtypes.Float32:
        return pl.Float32()
    if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
        # Not available for Polars pre 1.8.0
        return pl.Int128()
    if dtype == dtypes.Int64:
        return pl.Int64()
    if dtype == dtypes.Int32:
        return pl.Int32()
    if dtype == dtypes.Int16:
        return pl.Int16()
    if dtype == dtypes.Int8:
        return pl.Int8()
    if hasattr(pl, "UInt128") and dtype == dtypes.UInt128:  # pragma: no cover
        # Counterpart of the `pl.UInt128` branch in `native_to_narwhals_dtype`:
        # without it UInt128 would silently become `pl.Unknown()`. The `hasattr`
        # check comes first so nothing changes on Polars builds lacking UInt128.
        return pl.UInt128()
    if dtype == dtypes.UInt64:
        return pl.UInt64()
    if dtype == dtypes.UInt32:
        return pl.UInt32()
    if dtype == dtypes.UInt16:
        return pl.UInt16()
    if dtype == dtypes.UInt8:
        return pl.UInt8()
    if dtype == dtypes.String:
        return pl.String()
    if dtype == dtypes.Boolean:
        return pl.Boolean()
    if dtype == dtypes.Object:  # pragma: no cover
        return pl.Object()
    if dtype == dtypes.Categorical:
        return pl.Categorical()
    if dtype == dtypes.Enum:
        msg = "Converting to Enum is not (yet) supported"
        raise NotImplementedError(msg)
    if dtype == dtypes.Date:
        return pl.Date()
    if dtype == dtypes.Time:
        return pl.Time()
    if dtype == dtypes.Binary:
        return pl.Binary()
    if dtype == dtypes.Decimal:
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return pl.Datetime(dtype.time_unit, dtype.time_zone)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pl.Duration(dtype.time_unit)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.List):
        # Nested dtypes are converted recursively.
        return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = [
            pl.Field(
                field.name,
                narwhals_to_native_dtype(field.dtype, version, backend_version),
            )
            for field in dtype.fields
        ]
        return pl.Struct(fields)
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        size = dtype.size
        # The keyword accepted by `pl.Array` differs across Polars versions.
        kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size}
        return pl.Array(
            narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs
        )
    return pl.Unknown()  # pragma: no cover
def convert_str_slice_to_int_slice(
    str_slice: slice, columns: list[str]
) -> tuple[int | None, int | None, int | None]:  # pragma: no cover
    """Turn a slice over column names into positional slice bounds.

    Arguments:
        str_slice: Slice whose `start`/`stop` are column names or `None`.
        columns: Column names, in order; used to look up positions.

    Returns:
        A `(start, stop, step)` tuple. `start` is the position of the start
        column, `stop` is one past the position of the stop column (so the
        named stop column is included), and `step` is passed through unchanged.
        Bounds that are `None` in `str_slice` stay `None`.

    Raises:
        ValueError: If a named bound is not present in `columns`.
    """
    first_name, last_name = str_slice.start, str_slice.stop
    start = None if first_name is None else columns.index(first_name)
    # +1 because the label-based stop bound is inclusive.
    stop = None if last_name is None else columns.index(last_name) + 1
    return (start, stop, str_slice.step)
def catch_polars_exception(
    exception: Exception, backend_version: tuple[int, ...]
) -> NarwhalsError | Exception:
    """Map a Polars exception onto its Narwhals counterpart.

    The exception is returned, not raised.

    Arguments:
        exception: The exception to translate.
        backend_version: Polars version tuple; selects how a generic Polars
            error is recognised.

    Returns:
        A Narwhals exception built from `str(exception)` when `exception` is a
        recognised Polars error, otherwise `exception` itself.
    """
    polars_errors = pl.exceptions

    # (Polars exception name, Narwhals exception class); first match wins.
    # The Polars class is resolved by name only when its turn comes, so the
    # attribute accesses happen in the same order as a chain of checks would.
    translations = (
        ("ColumnNotFoundError", ColumnNotFoundError),
        ("ShapeError", ShapeError),
        ("InvalidOperationError", InvalidOperationError),
        ("DuplicateError", DuplicateError),
        ("ComputeError", ComputeError),
    )
    for polars_name, narwhals_error in translations:
        if isinstance(exception, getattr(polars_errors, polars_name)):
            return narwhals_error(str(exception))

    if backend_version >= (1,):
        # Old versions of Polars didn't have PolarsError.
        from_polars = isinstance(exception, polars_errors.PolarsError)
    else:  # pragma: no cover
        # Last attempt, for old Polars versions.
        from_polars = "polars.exceptions" in str(type(exception))

    if from_polars:
        return NarwhalsError(str(exception))  # pragma: no cover
    # Just return exception as-is.
    return exception