Files
Buffteks-Website/buffteks/lib/python3.12/site-packages/narwhals/_duckdb/utils.py
2025-05-08 21:10:14 -05:00

238 lines
8.5 KiB
Python

from __future__ import annotations
import re
from functools import lru_cache
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
import duckdb
from narwhals.utils import import_dtypes_module
from narwhals.utils import isinstance_or_issubclass
if TYPE_CHECKING:
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.expr import DuckDBExpr
from narwhals.dtypes import DType
from narwhals.utils import Version
# Short module-level aliases for DuckDB's expression constructors.
col = duckdb.ColumnExpression
"""Alias for `duckdb.ColumnExpression`."""
lit = duckdb.ConstantExpression
"""Alias for `duckdb.ConstantExpression`."""
when = duckdb.CaseExpression
"""Alias for `duckdb.CaseExpression`."""
class WindowInputs:
    """Plain holder for a native expression and its window column names.

    Attributes are declared via `__slots__`, so instances carry no `__dict__`.

    Arguments:
        expr: Native DuckDB expression.
        partition_by: Names of the columns to partition by.
        order_by: Names of the columns to order by.
    """

    __slots__ = ("expr", "order_by", "partition_by")

    def __init__(
        self,
        expr: duckdb.Expression,
        partition_by: Sequence[str],
        order_by: Sequence[str],
    ) -> None:
        self.expr, self.partition_by, self.order_by = expr, partition_by, order_by
def concat_str(*exprs: duckdb.Expression, separator: str = "") -> duckdb.Expression:
    """Join several string expressions into one, skipping NULL inputs.

    Dispatches to DuckDB's [concat] when `separator` is empty and to
    [concat_ws] otherwise.

    Arguments:
        exprs: Native columns.
        separator: String placed between the values of each column.

    Returns:
        A new native expression.

    [concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring-
    [concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string-
    """
    if not separator:
        return duckdb.FunctionExpression("concat", *exprs)
    # `concat_ws` takes the separator as its first argument.
    return duckdb.FunctionExpression("concat_ws", lit(separator), *exprs)
def evaluate_exprs(
    df: DuckDBLazyFrame, /, *exprs: DuckDBExpr
) -> list[tuple[str, duckdb.Expression]]:
    """Evaluate expressions against `df` and pair each result with its name.

    Arguments:
        df: Frame the expressions are evaluated against.
        exprs: Expressions to evaluate, in order.

    Returns:
        A flat list of `(output_name, native_expression)` tuples, in the
        order the expressions were given.

    Raises:
        AssertionError: If an expression yields a different number of names
            than native results.
    """
    pairs: list[tuple[str, duckdb.Expression]] = []
    for expression in exprs:
        columns = expression._call(df)
        names = expression._evaluate_output_names(df)
        # An alias function, when present, rewrites the evaluated names.
        if expression._alias_output_names is not None:
            names = expression._alias_output_names(names)
        if len(names) != len(columns):  # pragma: no cover
            msg = f"Internal error: got output names {names}, but only got {len(columns)} results"
            raise AssertionError(msg)
        pairs.extend(zip(names, columns))
    return pairs
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(duckdb_dtype: str, version: Version) -> DType:
    """Translate a DuckDB type string into the matching Narwhals dtype.

    Results are memoised with `lru_cache`, so both arguments must be hashable.

    Arguments:
        duckdb_dtype: DuckDB type rendered as a string, e.g. `"BIGINT"`,
            `"INTEGER[]"` or `"STRUCT(a INTEGER, b VARCHAR)"`.
        version: Narwhals API version, used to select the dtypes module.

    Returns:
        The corresponding Narwhals dtype, or `Unknown` when nothing matches.
    """
    dtypes = import_dtypes_module(version)
    if duckdb_dtype == "HUGEINT":
        return dtypes.Int128()
    if duckdb_dtype == "BIGINT":
        return dtypes.Int64()
    if duckdb_dtype == "INTEGER":
        return dtypes.Int32()
    if duckdb_dtype == "SMALLINT":
        return dtypes.Int16()
    if duckdb_dtype == "TINYINT":
        return dtypes.Int8()
    if duckdb_dtype == "UHUGEINT":
        return dtypes.UInt128()
    if duckdb_dtype == "UBIGINT":
        return dtypes.UInt64()
    if duckdb_dtype == "UINTEGER":
        return dtypes.UInt32()
    if duckdb_dtype == "USMALLINT":
        return dtypes.UInt16()
    if duckdb_dtype == "UTINYINT":
        return dtypes.UInt8()
    if duckdb_dtype == "DOUBLE":
        return dtypes.Float64()
    if duckdb_dtype == "FLOAT":
        return dtypes.Float32()
    if duckdb_dtype == "VARCHAR":
        return dtypes.String()
    if duckdb_dtype == "DATE":
        return dtypes.Date()
    if duckdb_dtype == "TIMESTAMP":
        return dtypes.Datetime()
    if duckdb_dtype == "TIMESTAMP WITH TIME ZONE":
        # TODO(marco): is UTC correct, or should we be getting the connection timezone?
        # https://github.com/narwhals-dev/narwhals/issues/2165
        return dtypes.Datetime(time_zone="UTC")
    if duckdb_dtype == "BOOLEAN":
        return dtypes.Boolean()
    if duckdb_dtype == "INTERVAL":
        return dtypes.Duration()
    # The trailing `[]` is the outermost type constructor, so it has to be
    # peeled off before the STRUCT prefix is inspected: `STRUCT(a INTEGER)[]`
    # is a list of structs, not a struct.
    if match_ := re.match(r"(.*)\[\]$", duckdb_dtype):
        return dtypes.List(native_to_narwhals_dtype(match_.group(1), version))
    if duckdb_dtype.startswith("STRUCT"):
        # NOTE: this regex only captures `name type` word pairs, so field types
        # that are nested or span several words (e.g. `INTEGER[]`,
        # `TIMESTAMP WITH TIME ZONE`) are not parsed faithfully.
        struct_fields = re.findall(r"(\w+)\s+(\w+)", duckdb_dtype)
        return dtypes.Struct(
            [
                dtypes.Field(
                    field_name, native_to_narwhals_dtype(field_type, version)
                )
                for field_name, field_type in struct_fields
            ]
        )
    if match_ := re.match(r"(\w+)((?:\[\d+\])+)", duckdb_dtype):
        # Fixed-size array: an inner type followed by one `[n]` per dimension.
        duckdb_inner_type = match_.group(1)
        duckdb_shape = match_.group(2)
        shape = tuple(int(value) for value in re.findall(r"\[(\d+)\]", duckdb_shape))
        return dtypes.Array(
            inner=native_to_narwhals_dtype(duckdb_inner_type, version),
            shape=shape,
        )
    if duckdb_dtype.startswith("DECIMAL("):
        return dtypes.Decimal()
    if duckdb_dtype == "TIME":
        return dtypes.Time()
    if duckdb_dtype == "BLOB":
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover
def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> str:
    """Render a Narwhals dtype as a DuckDB type string.

    Arguments:
        dtype: Narwhals dtype, given as an instance or as a class.
        version: Narwhals API version, used to select the dtypes module.

    Returns:
        The DuckDB spelling of `dtype`, e.g. `"BIGINT"` or `"VARCHAR[]"`.

    Raises:
        NotImplementedError: For Decimal, Categorical, Datetime and Duration.
        AssertionError: If `dtype` matches none of the handled dtypes.
    """
    dtypes = import_dtypes_module(version)
    if isinstance_or_issubclass(dtype, dtypes.Decimal):
        msg = "Casting to Decimal is not supported yet."
        raise NotImplementedError(msg)

    # Dtypes with a one-to-one DuckDB name. They are probed in this order and
    # each attribute is looked up only when its turn comes.
    scalar_names = (
        ("Float64", "DOUBLE"),
        ("Float32", "FLOAT"),
        ("Int128", "INT128"),
        ("Int64", "BIGINT"),
        ("Int32", "INTEGER"),
        ("Int16", "SMALLINT"),
        ("Int8", "TINYINT"),
        ("UInt128", "UINT128"),
        ("UInt64", "UBIGINT"),
        ("UInt32", "UINTEGER"),
        ("UInt16", "USMALLINT"),
        ("UInt8", "UTINYINT"),
        ("String", "VARCHAR"),
        ("Boolean", "BOOLEAN"),
        ("Time", "TIME"),
        ("Binary", "BLOB"),
    )
    for attr_name, native_name in scalar_names:
        if isinstance_or_issubclass(dtype, getattr(dtypes, attr_name)):
            return native_name

    if isinstance_or_issubclass(dtype, dtypes.Categorical):
        msg = "Categorical not supported by DuckDB"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        # The unit and zone are read but not used yet: conversion is unimplemented.
        _time_unit = dtype.time_unit
        _time_zone = dtype.time_zone
        msg = "todo"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Duration):  # pragma: no cover
        # The unit is read but not used yet: conversion is unimplemented.
        _time_unit = dtype.time_unit
        msg = "todo"
        raise NotImplementedError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Date):  # pragma: no cover
        return "DATE"

    if isinstance_or_issubclass(dtype, dtypes.List):
        element = narwhals_to_native_dtype(dtype.inner, version)
        return f"{element}[]"
    if isinstance_or_issubclass(dtype, dtypes.Struct):  # pragma: no cover
        members = []
        for field in dtype.fields:
            member_type = narwhals_to_native_dtype(field.dtype, version)
            members.append(f'"{field.name}" {member_type}')
        inner = ", ".join(members)
        return f"STRUCT({inner})"
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        dims = dtype.shape
        suffix = "".join(f"[{dim}]" for dim in dims)
        # Step through one `.inner` per dimension to reach the element dtype.
        element_dtype: Any = dtype
        for _ in dims:
            element_dtype = element_dtype.inner
        base = narwhals_to_native_dtype(element_dtype, version)
        return f"{base}{suffix}"

    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
def generate_partition_by_sql(*partition_by: str) -> str:
    """Build the `partition by` clause of a window specification.

    Column names are emitted as double-quoted identifiers. A double quote
    inside a name is doubled so that the identifier stays well-formed.

    Arguments:
        partition_by: Names of the columns to partition by.

    Returns:
        SQL such as `partition by "a", "b"`, or an empty string when no
        columns are given.
    """
    if not partition_by:
        return ""
    by_sql = ", ".join(
        '"' + name.replace('"', '""') + '"' for name in partition_by
    )
    return f"partition by {by_sql}"
def generate_order_by_sql(*order_by: str, ascending: bool) -> str:
    """Build the `order by` clause of a window specification.

    Column names are emitted as double-quoted identifiers. A double quote
    inside a name is doubled so that the identifier stays well-formed.
    Ascending order places nulls first and descending order places them last.

    Arguments:
        order_by: Names of the columns to order by.
        ascending: Whether to sort in ascending order.

    Returns:
        SQL such as `order by "a" asc nulls first`, or an empty string when
        no columns are given.
    """
    if not order_by:
        # A bare `order by` with no columns is not a valid SQL fragment.
        return ""
    direction = "asc nulls first" if ascending else "desc nulls last"
    by_sql = ", ".join(
        '"' + name.replace('"', '""') + f'" {direction}' for name in order_by
    )
    return f"order by {by_sql}"