Files
Buffteks-Website/streamlit-venv/lib/python3.10/site-packages/narwhals/_polars/utils.py
2025-01-10 21:40:35 +00:00

159 lines
5.6 KiB
Python
Executable File

from __future__ import annotations
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
if TYPE_CHECKING:
from narwhals.dtypes import DType
from narwhals.typing import DTypes
from narwhals.utils import parse_version
def extract_native(obj: Any) -> Any:
    """Return the native Polars object wrapped by a narwhals Polars wrapper.

    ``PolarsDataFrame`` / ``PolarsLazyFrame`` yield their ``_native_frame``,
    ``PolarsSeries`` its ``_native_series`` and ``PolarsExpr`` its
    ``_native_expr``. Any other object is returned unchanged, so this can be
    applied to arbitrary call arguments.
    """
    # Imported inside the function, presumably to avoid circular imports
    # between the Polars backend modules — TODO confirm.
    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    # Checked in order; the first matching wrapper type decides the attribute.
    unwrap_rules = (
        ((PolarsDataFrame, PolarsLazyFrame), "_native_frame"),
        (PolarsSeries, "_native_series"),
        (PolarsExpr, "_native_expr"),
    )
    for wrapper_types, native_attr in unwrap_rules:
        if isinstance(obj, wrapper_types):
            return getattr(obj, native_attr)
    return obj
def extract_args_kwargs(args: Any, kwargs: Any) -> tuple[list[Any], dict[str, Any]]:
    """Unwrap every positional and keyword argument with ``extract_native``.

    Returns a new list of positional arguments (same order as ``args``) and a
    new dict of keyword arguments (same keys as ``kwargs``). Non-wrapper
    values pass through unchanged.
    """
    native_args = list(map(extract_native, args))
    native_kwargs = {name: extract_native(value) for name, value in kwargs.items()}
    return native_args, native_kwargs
def native_to_narwhals_dtype(dtype: Any, dtypes: DTypes) -> DType:
    """Map a Polars data type to the equivalent narwhals data type.

    Parameters
    ----------
    dtype
        A Polars dtype, compared with ``==`` against Polars dtype classes
        (temporal types are also accepted as instances).
    dtypes
        The narwhals ``dtypes`` namespace used to construct the result.

    Returns
    -------
    DType
        A narwhals dtype instance, or ``dtypes.Unknown()`` when no mapping
        applies. Struct, List and Array inner types are converted recursively.
    """
    import polars as pl  # ignore-banned-import()

    # Parameter-free types carry the same attribute name in both namespaces.
    # They are tried in this order and each attribute is only looked up when
    # its turn comes, so the first match wins.
    simple_type_names = (
        "Float64",
        "Float32",
        "Int64",
        "Int32",
        "Int16",
        "Int8",
        "UInt64",
        "UInt32",
        "UInt16",
        "UInt8",
        "String",
        "Boolean",
        "Object",
        "Categorical",
        "Enum",
        "Date",
    )
    for type_name in simple_type_names:
        if dtype == getattr(pl, type_name):
            return getattr(dtypes, type_name)()

    # Temporal types: a missing time_unit falls back to "us", a missing
    # time_zone to None.
    if dtype == pl.Datetime or isinstance(dtype, pl.Datetime):
        datetime_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
        datetime_zone = getattr(dtype, "time_zone", None)
        return dtypes.Datetime(time_unit=datetime_unit, time_zone=datetime_zone)
    if dtype == pl.Duration or isinstance(dtype, pl.Duration):
        duration_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
        return dtypes.Duration(time_unit=duration_unit)

    # Nested types: convert the inner dtype(s) recursively.
    if dtype == pl.Struct:
        struct_fields = [
            dtypes.Field(field_name, native_to_narwhals_dtype(field_dtype, dtypes))
            for field_name, field_dtype in dtype
        ]
        return dtypes.Struct(struct_fields)
    if dtype == pl.List:
        list_inner = native_to_narwhals_dtype(dtype.inner, dtypes)
        return dtypes.List(list_inner)
    if dtype == pl.Array:
        # Polars releases before 0.20.30 expose the array length as ``width``;
        # later releases expose it as ``size``.
        legacy_polars = parse_version(pl.__version__) < (0, 20, 30)
        array_inner = native_to_narwhals_dtype(dtype.inner, dtypes)
        array_length = dtype.width if legacy_polars else dtype.size
        return dtypes.Array(array_inner, array_length)

    return dtypes.Unknown()
def narwhals_to_native_dtype(dtype: DType | type[DType], dtypes: DTypes) -> Any:
    """Map a narwhals data type to the equivalent Polars data type.

    Parameters
    ----------
    dtype
        A narwhals dtype, either a class or an instance (temporal types are
        also matched via ``isinstance`` so their parameters are carried over).
    dtypes
        The narwhals ``dtypes`` namespace that ``dtype`` is compared against.

    Returns
    -------
    Any
        A Polars dtype instance, or ``pl.Unknown()`` when no mapping applies.

    Raises
    ------
    NotImplementedError
        If ``dtype`` is ``Enum``, ``List``, ``Struct`` or ``Array``, which
        cannot be converted to Polars yet.
    """
    import polars as pl  # ignore-banned-import()

    if dtype == dtypes.Float64:
        return pl.Float64()
    if dtype == dtypes.Float32:
        return pl.Float32()
    if dtype == dtypes.Int64:
        return pl.Int64()
    if dtype == dtypes.Int32:
        return pl.Int32()
    if dtype == dtypes.Int16:
        return pl.Int16()
    if dtype == dtypes.Int8:
        return pl.Int8()
    if dtype == dtypes.UInt64:
        return pl.UInt64()
    if dtype == dtypes.UInt32:
        return pl.UInt32()
    if dtype == dtypes.UInt16:
        return pl.UInt16()
    if dtype == dtypes.UInt8:
        return pl.UInt8()
    if dtype == dtypes.String:
        return pl.String()
    if dtype == dtypes.Boolean:
        return pl.Boolean()
    if dtype == dtypes.Object:  # pragma: no cover
        return pl.Object()
    if dtype == dtypes.Categorical:
        return pl.Categorical()
    if dtype == dtypes.Enum:
        msg = "Converting to Enum is not (yet) supported"
        raise NotImplementedError(msg)
    if dtype == dtypes.Date:
        return pl.Date()
    if dtype == dtypes.Datetime or isinstance(dtype, dtypes.Datetime):
        # A dtype without these attributes falls back to "us" / no time zone.
        dt_time_unit = getattr(dtype, "time_unit", "us")
        dt_time_zone = getattr(dtype, "time_zone", None)
        return pl.Datetime(dt_time_unit, dt_time_zone)  # type: ignore[arg-type]
    if dtype == dtypes.Duration or isinstance(dtype, dtypes.Duration):
        du_time_unit: Literal["us", "ns", "ms"] = getattr(dtype, "time_unit", "us")
        return pl.Duration(time_unit=du_time_unit)
    # Nested types are not convertible yet. These branches previously
    # *returned* the exception object, handing callers a bogus "dtype";
    # they now raise, consistent with the Enum branch above.
    if dtype == dtypes.List:  # pragma: no cover
        msg = "Converting to List dtype is not supported yet"
        raise NotImplementedError(msg)
    if dtype == dtypes.Struct:  # pragma: no cover
        msg = "Converting to Struct dtype is not supported yet"
        raise NotImplementedError(msg)
    if dtype == dtypes.Array:  # pragma: no cover
        msg = "Converting to Array dtype is not supported yet"
        raise NotImplementedError(msg)
    return pl.Unknown()  # pragma: no cover
def convert_str_slice_to_int_slice(
    str_slice: slice, columns: list[str]
) -> tuple[int | None, int | None, int | None]:  # pragma: no cover
    """Translate a column-name slice into positional ``(start, stop, step)``.

    ``start`` becomes the index of that column name in ``columns``. ``stop``
    becomes one past the index of its column name, so the named end column is
    included. ``step`` is passed through unchanged, and ``None`` bounds stay
    ``None``. An unknown column name raises ``ValueError`` (from
    ``list.index``).
    """

    def _position(label: str | None, offset: int) -> int | None:
        # Look up a column label's index, shifted by ``offset``; None passes through.
        if label is None:
            return None
        return columns.index(label) + offset

    first = _position(str_slice.start, 0)
    past_last = _position(str_slice.stop, 1)
    return (first, past_last, str_slice.step)