Files
Buffteks-Website/venv/lib/python3.12/site-packages/narwhals/_polars/utils.py
2025-05-08 21:10:14 -05:00

248 lines
8.8 KiB
Python

from __future__ import annotations
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Iterator
from typing import Mapping
from typing import TypeVar
from typing import cast
from typing import overload
import polars as pl
from narwhals.exceptions import ColumnNotFoundError
from narwhals.exceptions import ComputeError
from narwhals.exceptions import DuplicateError
from narwhals.exceptions import InvalidOperationError
from narwhals.exceptions import NarwhalsError
from narwhals.exceptions import ShapeError
from narwhals.utils import Version
from narwhals.utils import _DeferredIterable
from narwhals.utils import isinstance_or_issubclass
if TYPE_CHECKING:
from typing_extensions import TypeIs
from narwhals.dtypes import DType
from narwhals.utils import _StoresNative
# Unconstrained type variable; `extract_native` uses it to type the
# pass-through case where the argument is returned unchanged.
T = TypeVar("T")
# Type variable restricted to the native Polars objects (eager/lazy frame,
# series, expression); it parametrises `_StoresNative` in this module.
NativeT = TypeVar(
"NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
)
@overload
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...


@overload
def extract_native(obj: T) -> T: ...


def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
    """Unwrap a Narwhals Polars wrapper into the object it stores.

    Args:
        obj: Either a wrapper recognised by `_is_compliant_polars` or any
            other value.

    Returns:
        `obj.native` when `obj` is a recognised wrapper; otherwise `obj`
        itself, unchanged.
    """
    # Guard clause: anything that is not one of our wrappers passes through.
    if not _is_compliant_polars(obj):
        return obj
    return obj.native
def _is_compliant_polars(
    obj: _StoresNative[NativeT] | Any,
) -> TypeIs[_StoresNative[NativeT]]:
    """Report whether `obj` is one of the Narwhals wrappers for Polars objects.

    The wrapper classes are imported at call time rather than at module load.
    NOTE(review): presumably this avoids a circular import between this module
    and the dataframe/series/expr modules - confirm.

    Args:
        obj: Any object.

    Returns:
        True when `obj` is an instance of `PolarsDataFrame`, `PolarsLazyFrame`,
        `PolarsSeries` or `PolarsExpr`; False otherwise.
    """
    from narwhals._polars.dataframe import PolarsDataFrame
    from narwhals._polars.dataframe import PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    wrapper_types = (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)
    return isinstance(obj, wrapper_types)
def extract_args_kwargs(
    args: Iterable[Any], kwds: Mapping[str, Any], /
) -> tuple[Iterator[Any], dict[str, Any]]:
    """Apply `extract_native` to positional and keyword arguments.

    Args:
        args: Positional arguments, each possibly a Narwhals Polars wrapper.
        kwds: Keyword arguments, each value possibly a Narwhals Polars wrapper.

    Returns:
        A pair `(positional, keyword)`. The positional part is a lazy
        generator, so each item is unwrapped only when it is consumed; the
        keyword part is a fully built dict with the same keys as `kwds`.
    """
    unwrapped_args = (extract_native(positional) for positional in args)
    unwrapped_kwds: dict[str, Any] = {}
    for key, value in kwds.items():
        unwrapped_kwds[key] = extract_native(value)
    return unwrapped_args, unwrapped_kwds
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: pl.DataType, version: Version, backend_version: tuple[int, ...]
) -> DType:
    """Convert a Polars dtype into the matching Narwhals dtype.

    Results are memoised by `lru_cache` (up to 16 distinct argument triples).

    Args:
        dtype: The Polars dtype to convert.
        version: Narwhals API version; its `dtypes` namespace supplies the
            result classes, and `Version.V1` yields an `Enum` without
            categories.
        backend_version: Installed Polars version, used to pick the attribute
            names that changed between releases (`Enum.categories`,
            `Array.width` vs `Array.size`).

    Returns:
        The corresponding Narwhals dtype, or `dtypes.Unknown()` when no rule
        matches.
    """
    dtypes = version.dtypes

    # Parameter-free dtypes carry the same name in Polars and in Narwhals, so
    # they are matched by name, in this exact order.
    # The 128-bit integers are not available for Polars pre 1.8.0 and are
    # skipped when the installed Polars does not expose them.
    maybe_missing = ("Int128", "UInt128")
    leading_scalars = (
        "Float64",
        "Float32",
        "Int128",
        "Int64",
        "Int32",
        "Int16",
        "Int8",
        "UInt128",
        "UInt64",
        "UInt32",
        "UInt16",
        "UInt8",
        "String",
        "Boolean",
        "Object",
        "Categorical",
    )
    for name in leading_scalars:
        if name in maybe_missing and not hasattr(pl, name):  # pragma: no cover
            continue
        if dtype == getattr(pl, name):
            return getattr(dtypes, name)()

    if isinstance_or_issubclass(dtype, pl.Enum):
        if version is Version.V1:
            return dtypes.Enum()  # type: ignore[call-arg]
        # The categories are wrapped in `_DeferredIterable` together with a
        # zero-argument callable that produces them as a list.
        if backend_version >= (0, 20, 4):
            categories = _DeferredIterable(dtype.categories.to_list)
        else:
            categories = _DeferredIterable(
                lambda: cast("list[str]", dtype.categories)
            )
        return dtypes.Enum(categories)

    if dtype == pl.Date:
        return dtypes.Date()

    if isinstance_or_issubclass(dtype, pl.Datetime):
        # The bare class has no unit/zone to forward; use Narwhals defaults.
        if dtype is pl.Datetime:
            return dtypes.Datetime()
        return dtypes.Datetime(dtype.time_unit, dtype.time_zone)

    if isinstance_or_issubclass(dtype, pl.Duration):
        if dtype is pl.Duration:
            return dtypes.Duration()
        return dtypes.Duration(dtype.time_unit)

    if isinstance_or_issubclass(dtype, pl.Struct):
        # Each struct field is converted recursively.
        converted_fields = []
        for field_name, field_dtype in dtype:
            converted = native_to_narwhals_dtype(
                field_dtype, version, backend_version
            )
            converted_fields.append(dtypes.Field(field_name, converted))
        return dtypes.Struct(converted_fields)

    if isinstance_or_issubclass(dtype, pl.List):
        element = native_to_narwhals_dtype(dtype.inner, version, backend_version)
        return dtypes.List(element)

    if isinstance_or_issubclass(dtype, pl.Array):
        # Polars before 0.20.30 is read via `width`; later versions via `size`.
        if backend_version < (0, 20, 30):
            outer_shape = dtype.width
        else:
            outer_shape = dtype.size
        element = native_to_narwhals_dtype(dtype.inner, version, backend_version)
        return dtypes.Array(element, outer_shape)

    # Remaining parameter-free dtypes, checked after the nested ones.
    for name in ("Decimal", "Time", "Binary"):
        if dtype == getattr(pl, name):
            return getattr(dtypes, name)()

    return dtypes.Unknown()
def narwhals_to_native_dtype(  # noqa: C901, PLR0912
    dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...]
) -> pl.DataType:
    """Convert a Narwhals dtype (class or instance) into a Polars dtype.

    Args:
        dtype: Narwhals dtype class or instance to convert.
        version: Narwhals API version; its `dtypes` namespace is what `dtype`
            is compared against.
        backend_version: Installed Polars version, used to choose the keyword
            accepted by `pl.Array` (`width` before 0.20.30, `shape` after).

    Returns:
        The corresponding Polars dtype, or `pl.Unknown()` when no rule matches.

    Raises:
        NotImplementedError: For `Enum` under `Version.V1`, and for `Decimal`.
        ValueError: For an `Enum` given as a class, i.e. without categories.
    """
    dtypes = version.dtypes
    if dtype == dtypes.Float64:
        return pl.Float64()
    if dtype == dtypes.Float32:
        return pl.Float32()
    if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
        # Not available for Polars pre 1.8.0
        return pl.Int128()
    if dtype == dtypes.Int64:
        return pl.Int64()
    if dtype == dtypes.Int32:
        return pl.Int32()
    if dtype == dtypes.Int16:
        return pl.Int16()
    if dtype == dtypes.Int8:
        return pl.Int8()
    if hasattr(pl, "UInt128") and dtype == dtypes.UInt128:  # pragma: no cover
        # Mirror of the guarded `pl.UInt128` branch in `native_to_narwhals_dtype`,
        # so the unsigned 128-bit type survives a round trip whenever the
        # installed Polars exposes it. The `hasattr` check comes first so that
        # nothing changes on Polars builds without the type.
        return pl.UInt128()
    if dtype == dtypes.UInt64:
        return pl.UInt64()
    if dtype == dtypes.UInt32:
        return pl.UInt32()
    if dtype == dtypes.UInt16:
        return pl.UInt16()
    if dtype == dtypes.UInt8:
        return pl.UInt8()
    if dtype == dtypes.String:
        return pl.String()
    if dtype == dtypes.Boolean:
        return pl.Boolean()
    if dtype == dtypes.Object:  # pragma: no cover
        return pl.Object()
    if dtype == dtypes.Categorical:
        return pl.Categorical()
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        # Only an instance carries categories; the bare class cannot be built.
        if isinstance(dtype, dtypes.Enum):
            return pl.Enum(dtype.categories)
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if dtype == dtypes.Date:
        return pl.Date()
    if dtype == dtypes.Time:
        return pl.Time()
    if dtype == dtypes.Binary:
        return pl.Binary()
    if dtype == dtypes.Decimal:
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return pl.Datetime(dtype.time_unit, dtype.time_zone)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pl.Duration(dtype.time_unit)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.List):
        # Nested element type is converted recursively.
        return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = [
            pl.Field(
                field.name,
                narwhals_to_native_dtype(field.dtype, version, backend_version),
            )
            for field in dtype.fields
        ]
        return pl.Struct(fields)
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        size = dtype.size
        # The keyword naming the array size differs across Polars versions.
        kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size}
        return pl.Array(
            narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs
        )
    return pl.Unknown()  # pragma: no cover
def catch_polars_exception(
    exception: Exception, backend_version: tuple[int, ...]
) -> NarwhalsError | Exception:
    """Map a Polars exception onto its Narwhals counterpart.

    The exception is returned, not raised; raising is left to the caller.

    Args:
        exception: The exception to translate.
        backend_version: Installed Polars version; decides which generic
            fallback check is used for otherwise unmatched Polars errors.

    Returns:
        A new Narwhals exception whose message is `str(exception)` when a
        translation applies; otherwise `exception` itself, unchanged.
    """
    polars_errors = pl.exceptions
    # Ordered (Polars exception name, Narwhals class) pairs. Each Polars class
    # is looked up only when its turn comes, and the first match wins.
    translations = (
        ("ColumnNotFoundError", ColumnNotFoundError),
        ("ShapeError", ShapeError),
        ("InvalidOperationError", InvalidOperationError),
        ("DuplicateError", DuplicateError),
        ("ComputeError", ComputeError),
    )
    for polars_name, narwhals_cls in translations:
        if isinstance(exception, getattr(polars_errors, polars_name)):
            return narwhals_cls(str(exception))

    if backend_version >= (1,):
        # Old versions of Polars didn't have PolarsError.
        if isinstance(exception, polars_errors.PolarsError):
            return NarwhalsError(str(exception))  # pragma: no cover
    elif "polars.exceptions" in str(type(exception)):  # pragma: no cover
        # Last attempt, for old Polars versions.
        return NarwhalsError(str(exception))

    # Just return exception as-is.
    return exception