from __future__ import annotations from functools import lru_cache from itertools import chain from typing import TYPE_CHECKING from typing import Any from typing import Iterable from typing import Iterator from typing import Sequence from typing import cast from typing import overload import pyarrow as pa import pyarrow.compute as pc from narwhals._compliant.series import _SeriesNamespace from narwhals.exceptions import ShapeError from narwhals.utils import import_dtypes_module from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: from typing import TypeVar from typing_extensions import Self from typing_extensions import TypeAlias from typing_extensions import TypeIs from narwhals._arrow.series import ArrowSeries from narwhals._arrow.typing import ArrowArray from narwhals._arrow.typing import ArrowChunkedArray from narwhals.dtypes import DType from narwhals.typing import _AnyDArray from narwhals.utils import Version # NOTE: stubs don't allow for `ChunkedArray[StructArray]` # Intended to represent the `.chunks` property storing `list[pa.StructArray]` ChunkedArrayStructArray: TypeAlias = ArrowChunkedArray ArrayAny: TypeAlias = "ArrowArray | ArrowChunkedArray" _T = TypeVar("_T") def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ... def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ... def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ... def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ... def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ... def is_dictionary( t: Any, ) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ... def extract_regex( strings: ArrowChunkedArray, /, pattern: str, *, options: Any = None, memory_pool: Any = None, ) -> ChunkedArrayStructArray: ... else: from pyarrow.compute import extract_regex from pyarrow.types import is_dictionary # noqa: F401 from pyarrow.types import is_duration from pyarrow.types import is_fixed_size_list from pyarrow.types import is_large_list from pyarrow.types import is_list from pyarrow.types import is_timestamp lit = pa.scalar """Alias for `pyarrow.scalar`.""" def extract_py_scalar(value: Any, /) -> Any: from narwhals._arrow.series import maybe_extract_py_scalar return maybe_extract_py_scalar(value, return_py_scalar=True) def chunked_array( arr: ArrayAny | list[Iterable[Any]], dtype: pa.DataType | None = None, / ) -> ArrowChunkedArray: if isinstance(arr, pa.ChunkedArray): return arr if isinstance(arr, list): return pa.chunked_array(arr, dtype) else: return pa.chunked_array([arr], arr.type) def nulls_like(n: int, series: ArrowSeries) -> ArrowArray: """Create a strongly-typed Array instance with all elements null. Uses the type of `series`, without upseting `mypy`. """ return pa.nulls(n, series.native.type) @lru_cache(maxsize=16) def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: dtypes = import_dtypes_module(version) if pa.types.is_int64(dtype): return dtypes.Int64() if pa.types.is_int32(dtype): return dtypes.Int32() if pa.types.is_int16(dtype): return dtypes.Int16() if pa.types.is_int8(dtype): return dtypes.Int8() if pa.types.is_uint64(dtype): return dtypes.UInt64() if pa.types.is_uint32(dtype): return dtypes.UInt32() if pa.types.is_uint16(dtype): return dtypes.UInt16() if pa.types.is_uint8(dtype): return dtypes.UInt8() if pa.types.is_boolean(dtype): return dtypes.Boolean() if pa.types.is_float64(dtype): return dtypes.Float64() if pa.types.is_float32(dtype): return dtypes.Float32() # bug in coverage? it shows `31->exit` (where `31` is currently the line number of # the next line), even though both when the if condition is true and false are covered if ( # pragma: no cover pa.types.is_string(dtype) or pa.types.is_large_string(dtype) or getattr(pa.types, "is_string_view", lambda _: False)(dtype) ): return dtypes.String() if pa.types.is_date32(dtype): return dtypes.Date() if is_timestamp(dtype): return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz) if is_duration(dtype): return dtypes.Duration(time_unit=dtype.unit) if pa.types.is_dictionary(dtype): return dtypes.Categorical() if pa.types.is_struct(dtype): return dtypes.Struct( [ dtypes.Field( dtype.field(i).name, native_to_narwhals_dtype(dtype.field(i).type, version), ) for i in range(dtype.num_fields) ] ) if is_list(dtype) or is_large_list(dtype): return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version)) if is_fixed_size_list(dtype): return dtypes.Array( native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size ) if pa.types.is_decimal(dtype): return dtypes.Decimal() if pa.types.is_time32(dtype) or pa.types.is_time64(dtype): return dtypes.Time() if pa.types.is_binary(dtype): return dtypes.Binary() return dtypes.Unknown() # pragma: no cover def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa.DataType: dtypes = import_dtypes_module(version) if isinstance_or_issubclass(dtype, dtypes.Decimal): msg = "Casting to Decimal is not supported yet." raise NotImplementedError(msg) if isinstance_or_issubclass(dtype, dtypes.Float64): return pa.float64() if isinstance_or_issubclass(dtype, dtypes.Float32): return pa.float32() if isinstance_or_issubclass(dtype, dtypes.Int64): return pa.int64() if isinstance_or_issubclass(dtype, dtypes.Int32): return pa.int32() if isinstance_or_issubclass(dtype, dtypes.Int16): return pa.int16() if isinstance_or_issubclass(dtype, dtypes.Int8): return pa.int8() if isinstance_or_issubclass(dtype, dtypes.UInt64): return pa.uint64() if isinstance_or_issubclass(dtype, dtypes.UInt32): return pa.uint32() if isinstance_or_issubclass(dtype, dtypes.UInt16): return pa.uint16() if isinstance_or_issubclass(dtype, dtypes.UInt8): return pa.uint8() if isinstance_or_issubclass(dtype, dtypes.String): return pa.string() if isinstance_or_issubclass(dtype, dtypes.Boolean): return pa.bool_() if isinstance_or_issubclass(dtype, dtypes.Categorical): return pa.dictionary(pa.uint32(), pa.string()) if isinstance_or_issubclass(dtype, dtypes.Datetime): return pa.timestamp(dtype.time_unit, tz=dtype.time_zone) # pyright: ignore[reportArgumentType] if isinstance_or_issubclass(dtype, dtypes.Duration): return pa.duration(dtype.time_unit) if isinstance_or_issubclass(dtype, dtypes.Date): return pa.date32() if isinstance_or_issubclass(dtype, dtypes.List): return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version)) if isinstance_or_issubclass(dtype, dtypes.Struct): return pa.struct( [ (field.name, narwhals_to_native_dtype(field.dtype, version=version)) for field in dtype.fields ] ) if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover inner = narwhals_to_native_dtype(dtype.inner, version=version) list_size = dtype.size return pa.list_(inner, list_size=list_size) if isinstance_or_issubclass(dtype, dtypes.Time): return pa.time64("ns") if isinstance_or_issubclass(dtype, dtypes.Binary): return pa.binary() msg = f"Unknown dtype: {dtype}" # pragma: no cover raise AssertionError(msg) def extract_native( lhs: ArrowSeries, rhs: ArrowSeries | object ) -> tuple[ ArrowChunkedArray | pa.Scalar[Any], ArrowChunkedArray | pa.Scalar[Any] | object ]: """Extract native objects in binary operation. If the comparison isn't supported, return `NotImplemented` so that the "right-hand-side" operation (e.g. `__radd__`) can be tried. If one of the two sides has a `_broadcast` flag, then extract the scalar underneath it so that PyArrow can do its own broadcasting. """ from narwhals._arrow.dataframe import ArrowDataFrame from narwhals._arrow.series import ArrowSeries if rhs is None: # pragma: no cover return lhs.native, lit(None, type=lhs._type) if isinstance(rhs, ArrowDataFrame): return NotImplemented if isinstance(rhs, ArrowSeries): if lhs._broadcast and not rhs._broadcast: return lhs.native[0], rhs.native if rhs._broadcast: return lhs.native, rhs.native[0] return lhs.native, rhs.native if isinstance(rhs, list): msg = "Expected Series or scalar, got list." raise TypeError(msg) return lhs.native, rhs def align_series_full_broadcast(*series: ArrowSeries) -> Sequence[ArrowSeries]: # Ensure all of `series` are of the same length. lengths = [len(s) for s in series] max_length = max(lengths) fast_path = all(_len == max_length for _len in lengths) if fast_path: return series reshaped = [] for s in series: if s._broadcast: value = s.native[0] if s._backend_version < (13,) and hasattr(value, "as_py"): value = value.as_py() reshaped.append(s._with_native(pa.array([value] * max_length, type=s._type))) else: if (actual_len := len(s)) != max_length: msg = f"Expected object of length {max_length}, got {actual_len}." raise ShapeError(msg) reshaped.append(s) return reshaped def horizontal_concat(dfs: list[pa.Table]) -> pa.Table: """Concatenate (native) DataFrames horizontally. Should be in namespace. """ names = [name for df in dfs for name in df.column_names] if len(set(names)) < len(names): # pragma: no cover msg = "Expected unique column names" raise ValueError(msg) arrays = list(chain.from_iterable(df.itercolumns() for df in dfs)) return pa.Table.from_arrays(arrays, names=names) def vertical_concat(dfs: list[pa.Table]) -> pa.Table: """Concatenate (native) DataFrames vertically. Should be in namespace. """ cols_0 = dfs[0].column_names for i, df in enumerate(dfs[1:], start=1): cols_current = df.column_names if cols_current != cols_0: msg = ( "unable to vstack, column names don't match:\n" f" - dataframe 0: {cols_0}\n" f" - dataframe {i}: {cols_current}\n" ) raise TypeError(msg) return pa.concat_tables(dfs) def diagonal_concat(dfs: list[pa.Table], backend_version: tuple[int, ...]) -> pa.Table: """Concatenate (native) DataFrames diagonally. Should be in namespace. """ kwargs: dict[str, Any] = ( {"promote": True} if backend_version < (14, 0, 0) else {"promote_options": "default"} ) return pa.concat_tables(dfs, **kwargs) def floordiv_compat(left: Any, right: Any) -> Any: # The following lines are adapted from pandas' pyarrow implementation. # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 if isinstance(left, (int, float)): left = lit(left) if isinstance(right, (int, float)): right = lit(right) if pa.types.is_integer(left.type) and pa.types.is_integer(right.type): divided = pc.divide_checked(left, right) if pa.types.is_signed_integer(divided.type): # GH 56676 has_remainder = pc.not_equal(pc.multiply(divided, right), left) has_one_negative_operand = pc.less( pc.bit_wise_xor(left, right), lit(0, type=divided.type) ) result = pc.if_else( pc.and_(has_remainder, has_one_negative_operand), pc.subtract(divided, lit(1, type=divided.type)), divided, ) else: result = divided # pragma: no cover result = result.cast(left.type) else: divided = pc.divide(left, right) result = pc.floor(divided) return result def cast_for_truediv( arrow_array: ArrowChunkedArray | pa.Scalar[Any], pa_object: ArrowChunkedArray | ArrowArray | pa.Scalar[Any], ) -> tuple[ ArrowChunkedArray | pa.Scalar[Any], ArrowChunkedArray | ArrowArray | pa.Scalar[Any], ]: # Lifted from: # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122 # Ensure int / int -> float mirroring Python/Numpy behavior # as pc.divide_checked(int, int) -> int if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type): # GH: 56645. # noqa: ERA001 # https://github.com/apache/arrow/issues/35563 return pc.cast(arrow_array, pa.float64(), safe=False), pc.cast( pa_object, pa.float64(), safe=False ) return arrow_array, pa_object @overload def convert_slice_to_nparray(num_rows: int, rows_slice: slice) -> _AnyDArray: ... @overload def convert_slice_to_nparray(num_rows: int, rows_slice: _T) -> _T: ... def convert_slice_to_nparray(num_rows: int, rows_slice: slice | _T) -> _AnyDArray | _T: if isinstance(rows_slice, slice): import numpy as np # ignore-banned-import return np.arange(num_rows)[rows_slice] else: return rows_slice def select_rows( table: pa.Table, rows: slice | int | Sequence[int] | _AnyDArray ) -> pa.Table: if isinstance(rows, slice) and rows == slice(None): selected_rows = table elif isinstance(rows, Sequence) and not rows: selected_rows = table.slice(0, 0) else: range_ = convert_slice_to_nparray(num_rows=len(table), rows_slice=rows) selected_rows = table.take(cast("list[int]", range_)) return selected_rows def convert_str_slice_to_int_slice( str_slice: slice, columns: list[str] ) -> tuple[int | None, int | None, int | None]: start = columns.index(str_slice.start) if str_slice.start is not None else None stop = columns.index(str_slice.stop) + 1 if str_slice.stop is not None else None step = str_slice.step return (start, stop, step) # Regex for date, time, separator and timezone components DATE_RE = r"(?P\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})" SEP_RE = r"(?P\s|T)" TIME_RE = r"(?P