from __future__ import annotations from functools import lru_cache from typing import TYPE_CHECKING from typing import Any from typing import TypeVar from typing import overload import polars as pl from narwhals.exceptions import ColumnNotFoundError from narwhals.exceptions import ComputeError from narwhals.exceptions import DuplicateError from narwhals.exceptions import InvalidOperationError from narwhals.exceptions import NarwhalsError from narwhals.exceptions import ShapeError from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: from narwhals._polars.dataframe import PolarsDataFrame from narwhals._polars.dataframe import PolarsLazyFrame from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries from narwhals.dtypes import DType from narwhals.utils import Version T = TypeVar("T") @overload def extract_native(obj: PolarsDataFrame) -> pl.DataFrame: ... @overload def extract_native(obj: PolarsLazyFrame) -> pl.LazyFrame: ... @overload def extract_native(obj: PolarsSeries) -> pl.Series: ... @overload def extract_native(obj: PolarsExpr) -> pl.Expr: ... @overload def extract_native(obj: T) -> T: ... def extract_native( obj: PolarsDataFrame | PolarsLazyFrame | PolarsSeries | PolarsExpr | T, ) -> pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr | T: from narwhals._polars.dataframe import PolarsDataFrame from narwhals._polars.dataframe import PolarsLazyFrame from narwhals._polars.expr import PolarsExpr from narwhals._polars.series import PolarsSeries if isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr)): return obj.native return obj def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, Any]]: return [extract_native(arg) for arg in args], { k: extract_native(v) for k, v in kwargs.items() } @lru_cache(maxsize=16) def native_to_narwhals_dtype( dtype: pl.DataType, version: Version, backend_version: tuple[int, ...] ) -> DType: dtypes = import_dtypes_module(version) if dtype == pl.Float64: return dtypes.Float64() if dtype == pl.Float32: return dtypes.Float32() if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover # Not available for Polars pre 1.8.0 return dtypes.Int128() if dtype == pl.Int64: return dtypes.Int64() if dtype == pl.Int32: return dtypes.Int32() if dtype == pl.Int16: return dtypes.Int16() if dtype == pl.Int8: return dtypes.Int8() if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover # Not available for Polars pre 1.8.0 return dtypes.UInt128() if dtype == pl.UInt64: return dtypes.UInt64() if dtype == pl.UInt32: return dtypes.UInt32() if dtype == pl.UInt16: return dtypes.UInt16() if dtype == pl.UInt8: return dtypes.UInt8() if dtype == pl.String: return dtypes.String() if dtype == pl.Boolean: return dtypes.Boolean() if dtype == pl.Object: return dtypes.Object() if dtype == pl.Categorical: return dtypes.Categorical() if dtype == pl.Enum: return dtypes.Enum() if dtype == pl.Date: return dtypes.Date() if isinstance_or_issubclass(dtype, pl.Datetime): return ( dtypes.Datetime() if dtype is pl.Datetime else dtypes.Datetime(dtype.time_unit, dtype.time_zone) ) if isinstance_or_issubclass(dtype, pl.Duration): return ( dtypes.Duration() if dtype is pl.Duration else dtypes.Duration(dtype.time_unit) ) if isinstance_or_issubclass(dtype, pl.Struct): fields = [ dtypes.Field(name, native_to_narwhals_dtype(tp, version, backend_version)) for name, tp in dtype ] return dtypes.Struct(fields) if isinstance_or_issubclass(dtype, pl.List): return dtypes.List( native_to_narwhals_dtype(dtype.inner, version, backend_version) ) if isinstance_or_issubclass(dtype, pl.Array): outer_shape = dtype.width if backend_version < (0, 20, 30) else dtype.size return dtypes.Array( native_to_narwhals_dtype(dtype.inner, version, backend_version), outer_shape ) if dtype == pl.Decimal: return dtypes.Decimal() if dtype == pl.Time: return dtypes.Time() if dtype == pl.Binary: return dtypes.Binary() return dtypes.Unknown() def narwhals_to_native_dtype( dtype: DType | type[DType], version: Version, backend_version: tuple[int, ...] ) -> pl.DataType: dtypes = import_dtypes_module(version) if dtype == dtypes.Float64: return pl.Float64() if dtype == dtypes.Float32: return pl.Float32() if dtype == dtypes.Int128 and hasattr(pl, "Int128"): # Not available for Polars pre 1.8.0 return pl.Int128() if dtype == dtypes.Int64: return pl.Int64() if dtype == dtypes.Int32: return pl.Int32() if dtype == dtypes.Int16: return pl.Int16() if dtype == dtypes.Int8: return pl.Int8() if dtype == dtypes.UInt64: return pl.UInt64() if dtype == dtypes.UInt32: return pl.UInt32() if dtype == dtypes.UInt16: return pl.UInt16() if dtype == dtypes.UInt8: return pl.UInt8() if dtype == dtypes.String: return pl.String() if dtype == dtypes.Boolean: return pl.Boolean() if dtype == dtypes.Object: # pragma: no cover return pl.Object() if dtype == dtypes.Categorical: return pl.Categorical() if dtype == dtypes.Enum: msg = "Converting to Enum is not (yet) supported" raise NotImplementedError(msg) if dtype == dtypes.Date: return pl.Date() if dtype == dtypes.Time: return pl.Time() if dtype == dtypes.Binary: return pl.Binary() if dtype == dtypes.Decimal: msg = "Casting to Decimal is not supported yet." raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.Datetime): return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type] if isinstance_or_issubclass(dtype, dtypes.Duration): return pl.Duration(dtype.time_unit) # type: ignore[arg-type] if isinstance_or_issubclass(dtype, dtypes.List): return pl.List(narwhals_to_native_dtype(dtype.inner, version, backend_version)) if isinstance_or_issubclass(dtype, dtypes.Struct): fields = [ pl.Field( field.name, narwhals_to_native_dtype(field.dtype, version, backend_version), ) for field in dtype.fields ] return pl.Struct(fields) if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover size = dtype.size kwargs = {"width": size} if backend_version < (0, 20, 30) else {"shape": size} return pl.Array( narwhals_to_native_dtype(dtype.inner, version, backend_version), **kwargs ) return pl.Unknown() # pragma: no cover def convert_str_slice_to_int_slice( str_slice: slice, columns: list[str] ) -> tuple[int | None, int | None, int | None]: # pragma: no cover start = columns.index(str_slice.start) if str_slice.start is not None else None stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None step = str_slice.step return (start, stop, step) def catch_polars_exception( exception: Exception, backend_version: tuple[int, ...] ) -> NarwhalsError | Exception: if isinstance(exception, pl.exceptions.ColumnNotFoundError): return ColumnNotFoundError(str(exception)) elif isinstance(exception, pl.exceptions.ShapeError): return ShapeError(str(exception)) elif isinstance(exception, pl.exceptions.InvalidOperationError): return InvalidOperationError(str(exception)) elif isinstance(exception, pl.exceptions.DuplicateError): return DuplicateError(str(exception)) elif isinstance(exception, pl.exceptions.ComputeError): return ComputeError(str(exception)) if backend_version >= (1,) and isinstance(exception, pl.exceptions.PolarsError): # Old versions of Polars didn't have PolarsError. return NarwhalsError(str(exception)) # pragma: no cover elif backend_version < (1,) and "polars.exceptions" in str( type(exception) ): # pragma: no cover # Last attempt, for old Polars versions. return NarwhalsError(str(exception)) # Just return exception as-is. return exception