From 6fbfb7783e510b5b27b5989b20738a4fe629f8ef Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Wed, 14 Aug 2024 08:22:42 +0100 Subject: [PATCH] chore: import overhaul (#788) --- CONTRIBUTING.md | 20 +++ narwhals/_arrow/dataframe.py | 27 +-- narwhals/_arrow/group_by.py | 8 +- narwhals/_arrow/namespace.py | 4 +- narwhals/_arrow/series.py | 266 ++++++++++++++++++----------- narwhals/_arrow/utils.py | 28 +-- narwhals/_dask/dataframe.py | 5 +- narwhals/_dask/namespace.py | 6 +- narwhals/_pandas_like/dataframe.py | 5 +- narwhals/_pandas_like/series.py | 8 +- narwhals/dataframe.py | 3 - narwhals/dependencies.py | 67 +++++--- narwhals/expr.py | 9 - narwhals/series.py | 6 - narwhals/translate.py | 42 ++--- narwhals/utils.py | 24 +-- tests/no_imports_test.py | 68 ++++++++ 17 files changed, 382 insertions(+), 214 deletions(-) create mode 100644 tests/no_imports_test.py diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 03d80fec9..b1eb91b0d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -67,6 +67,26 @@ Please adhere to the following guidelines: If Narwhals looks like underwater unicorn magic to you, then please read [how it works](https://narwhals-dev.github.io/narwhals/how-it-works/). +## Imports + +In Narwhals, we are very particular about imports. When it comes to importing +heavy third-party libraries (pandas, NumPy, Polars, etc.), please follow these rules: + +- Never import anything to do `isinstance` checks. Instead, just use the functions + in `narwhals.dependencies` (such as `is_pandas_dataframe`). +- If you need to import anything, do it in a place where you know that the library + is definitely available. For example, NumPy is a required dependency of PyArrow, + so it's OK to import NumPy to implement a PyArrow function - however, NumPy + should never be imported to implement a Polars function. The only exception is + when there's simply no way around it by definition - for example, `Series.to_numpy` + always requires NumPy to be installed. +- Don't place a third-party import at the top of a file. Instead, place it in the + function where it's used, so that we minimise the chances of it being imported + unnecessarily. + +We're trying to keep Narwhals lightweight and minimal-overhead, and +unnecessary imports can slow things down. + ## Happy contributing! Please remember to abide by the code of conduct, else you'll be conducted away from this project.
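To make the new CONTRIBUTING guidance concrete, here is a minimal sketch of the pattern this patch applies throughout the codebase. The helper `frame_to_arrow_table` is hypothetical (it is not part of this diff); `is_pandas_dataframe` and the `# ignore-banned-import()` marker are the ones the patch itself uses:

```python
from typing import Any

from narwhals.dependencies import is_pandas_dataframe


def frame_to_arrow_table(df: Any) -> Any:
    """Hypothetical helper written to follow the import rules above."""
    # isinstance checks go through narwhals.dependencies, so pandas is never
    # imported just to check a type.
    if is_pandas_dataframe(df):
        # The heavy import lives inside the function, not at module level.
        # Converting to Arrow requires pyarrow by definition (the same kind of
        # exception as `Series.to_numpy` needing NumPy), so importing it here is fine.
        import pyarrow as pa  # ignore-banned-import()

        return pa.Table.from_pandas(df)
    msg = f"Expected pandas DataFrame, got: {type(df)}"
    raise TypeError(msg)
```

The `# ignore-banned-import()` comment appears to mark a deliberate exception to the project's lint rule banning these imports; keeping it on every intentional local import is what the hunks below do consistently.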
diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index 75b9068a0..865d17098 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -13,8 +13,6 @@ from narwhals._arrow.utils import validate_dataframe_comparand from narwhals._expression_parsing import evaluate_into_exprs from narwhals.dependencies import get_pyarrow -from narwhals.dependencies import get_pyarrow_compute -from narwhals.dependencies import get_pyarrow_parquet from narwhals.dependencies import is_numpy_array from narwhals.utils import Implementation from narwhals.utils import flatten @@ -182,12 +180,13 @@ def select( *exprs: IntoArrowExpr, **named_exprs: IntoArrowExpr, ) -> Self: + import pyarrow as pa # ignore-banned-import() + new_series = evaluate_into_exprs(self, *exprs, **named_exprs) if not new_series: # return empty dataframe, like Polars does return self._from_native_frame(self._native_frame.__class__.from_arrays([])) names = [s.name for s in new_series] - pa = get_pyarrow() df = pa.Table.from_arrays( broadcast_series(new_series), names=names, @@ -337,7 +336,8 @@ def to_dict(self, *, as_series: bool) -> Any: return {name: col.to_pylist() for name, col in names_and_values} def with_row_index(self, name: str) -> Self: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + df = self._native_frame row_indices = pa.array(range(df.num_rows)) @@ -354,7 +354,8 @@ def filter( return self._from_native_frame(self._native_frame.filter(mask._native_series)) def null_count(self) -> Self: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + df = self._native_frame names_and_values = zip(df.column_names, df.columns) @@ -415,16 +416,17 @@ def rename(self, mapping: dict[str, str]) -> Self: return self._from_native_frame(df.rename_columns(new_cols)) def write_parquet(self, file: Any) -> Any: - pp = get_pyarrow_parquet() + import pyarrow.parquet as pp # ignore-banned-import + pp.write_table(self._native_frame, file) def is_duplicated(self: Self) -> ArrowSeries: import numpy as np # ignore-banned-import + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() from narwhals._arrow.series import ArrowSeries - pa = get_pyarrow() - pc = get_pyarrow_compute() df = self._native_frame columns = self.columns @@ -443,9 +445,10 @@ def is_duplicated(self: Self) -> ArrowSeries: return ArrowSeries(is_duplicated, name="", backend_version=self._backend_version) def is_unique(self: Self) -> ArrowSeries: + import pyarrow.compute as pc # ignore-banned-import() + from narwhals._arrow.series import ArrowSeries - pc = get_pyarrow_compute() is_duplicated = self.is_duplicated()._native_series return ArrowSeries( @@ -464,11 +467,9 @@ def unique( The param `maintain_order` is only here for compatibility with the polars API and has no effect on the output. 
""" - import numpy as np # ignore-banned-import - - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() df = self._native_frame diff --git a/narwhals/_arrow/group_by.py b/narwhals/_arrow/group_by.py index ecdfc02a6..27c7ff368 100644 --- a/narwhals/_arrow/group_by.py +++ b/narwhals/_arrow/group_by.py @@ -8,8 +8,6 @@ from narwhals._expression_parsing import is_simple_aggregation from narwhals._expression_parsing import parse_into_exprs -from narwhals.dependencies import get_pyarrow -from narwhals.dependencies import get_pyarrow_compute from narwhals.utils import remove_prefix if TYPE_CHECKING: @@ -20,7 +18,8 @@ class ArrowGroupBy: def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + self._df = df self._keys = list(keys) self._grouped = pa.TableGroupBy(self._df._native_frame, list(self._keys)) @@ -79,7 +78,8 @@ def agg_arrow( output_names: list[str], from_dataframe: Callable[[Any], ArrowDataFrame], ) -> ArrowDataFrame: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + all_simple_aggs = True for expr in exprs: if not is_simple_aggregation(expr): diff --git a/narwhals/_arrow/namespace.py b/narwhals/_arrow/namespace.py index ffb3f2d15..bb90b1792 100644 --- a/narwhals/_arrow/namespace.py +++ b/narwhals/_arrow/namespace.py @@ -13,7 +13,6 @@ from narwhals._arrow.utils import horizontal_concat from narwhals._arrow.utils import vertical_concat from narwhals._expression_parsing import parse_into_exprs -from narwhals.dependencies import get_pyarrow from narwhals.utils import Implementation if TYPE_CHECKING: @@ -87,9 +86,10 @@ def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSe ) def _create_compliant_series(self, value: Any) -> ArrowSeries: + import pyarrow as pa # ignore-banned-import() + from narwhals._arrow.series import ArrowSeries - pa = get_pyarrow() return ArrowSeries( native_series=pa.chunked_array([value]), name="", diff --git a/narwhals/_arrow/series.py b/narwhals/_arrow/series.py index 20513c39e..fb15f3aaf 100644 --- a/narwhals/_arrow/series.py +++ b/narwhals/_arrow/series.py @@ -14,7 +14,6 @@ from narwhals._arrow.utils import validate_column_comparand from narwhals.dependencies import get_pandas from narwhals.dependencies import get_pyarrow -from narwhals.dependencies import get_pyarrow_compute from narwhals.utils import Implementation from narwhals.utils import generate_unique_token @@ -35,7 +34,8 @@ def __init__( self._backend_version = backend_version def _from_native_series(self, series: Any) -> Self: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + if isinstance(series, pa.Array): series = pa.chunked_array([series]) return self.__class__( @@ -52,7 +52,8 @@ def _from_iterable( *, backend_version: tuple[int, ...], ) -> Self: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + return cls( pa.chunked_array([data]), name=name, @@ -63,67 +64,78 @@ def __len__(self) -> int: return len(self._native_series) def __eq__(self, other: object) -> Self: # type: ignore[override] - pc = get_pyarrow_compute() + import pyarrow.compute as pc + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.equal(ser, other)) def __ne__(self, other: object) -> Self: # type: ignore[override] - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = 
self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.not_equal(ser, other)) def __ge__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.greater_equal(ser, other)) def __gt__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.greater(ser, other)) def __le__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.less_equal(ser, other)) def __lt__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.less(ser, other)) def __and__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.and_kleene(ser, other)) def __rand__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.and_kleene(other, ser)) def __or__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.or_kleene(ser, other)) def __ror__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.or_kleene(other, ser)) def __add__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + other = validate_column_comparand(other) return self._from_native_series(pc.add(self._native_series, other)) @@ -131,7 +143,8 @@ def __radd__(self, other: Any) -> Self: return self + other # type: ignore[no-any-return] def __sub__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + other = validate_column_comparand(other) return self._from_native_series(pc.subtract(self._native_series, other)) @@ -139,7 +152,8 @@ def __rsub__(self, other: Any) -> Self: return (self - other) * (-1) # type: ignore[no-any-return] def __mul__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + other = validate_column_comparand(other) return self._from_native_series(pc.multiply(self._native_series, other)) @@ -147,13 +161,15 @@ def __rmul__(self, other: Any) -> Self: return self * other # type: ignore[no-any-return] def __pow__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.power(ser, other)) def __rpow__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + 
ser = self._native_series other = validate_column_comparand(other) return self._from_native_series(pc.power(other, ser)) @@ -169,8 +185,9 @@ def __rfloordiv__(self, other: Any) -> Self: return self._from_native_series(floordiv_compat(other, ser)) def __truediv__(self, other: Any) -> Self: - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) if not isinstance(other, (pa.Array, pa.ChunkedArray)): @@ -179,8 +196,9 @@ def __truediv__(self, other: Any) -> Self: return self._from_native_series(pc.divide(*cast_for_truediv(ser, other))) def __rtruediv__(self, other: Any) -> Self: - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) if not isinstance(other, (pa.Array, pa.ChunkedArray)): @@ -189,7 +207,8 @@ def __rtruediv__(self, other: Any) -> Self: return self._from_native_series(pc.divide(*cast_for_truediv(other, ser))) def __mod__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) floor_div = (self // other)._native_series @@ -197,7 +216,8 @@ def __mod__(self, other: Any) -> Self: return self._from_native_series(res) def __rmod__(self, other: Any) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series other = validate_column_comparand(other) floor_div = (other // self)._native_series @@ -205,7 +225,8 @@ def __rmod__(self, other: Any) -> Self: return self._from_native_series(res) def __invert__(self) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._from_native_series(pc.invert(self._native_series)) def len(self) -> int: @@ -216,27 +237,33 @@ def filter(self, other: Any) -> Self: return self._from_native_series(self._native_series.filter(other)) def mean(self) -> int: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return pc.mean(self._native_series) # type: ignore[no-any-return] def min(self) -> int: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return pc.min(self._native_series) # type: ignore[no-any-return] def max(self) -> int: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return pc.max(self._native_series) # type: ignore[no-any-return] def sum(self) -> int: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return pc.sum(self._native_series) # type: ignore[no-any-return] def drop_nulls(self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._from_native_series(pc.drop_null(self._native_series)) def shift(self, n: int) -> Self: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + ca = self._native_series if n > 0: @@ -248,15 +275,18 @@ def shift(self, n: int) -> Self: return self._from_native_series(result) def std(self, ddof: int = 1) -> int: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return pc.stddev(self._native_series, ddof=ddof) # type: ignore[no-any-return] def count(self) -> int: - pc = get_pyarrow_compute() + import 
pyarrow.compute as pc # ignore-banned-import() + return pc.count(self._native_series) # type: ignore[no-any-return] def n_unique(self) -> int: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + unique_values = pc.unique(self._native_series) return pc.count(unique_values, mode="all") # type: ignore[no-any-return] @@ -302,35 +332,42 @@ def dtype(self) -> DType: return translate_dtype(self._native_series.type) def abs(self) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._from_native_series(pc.abs(self._native_series)) def cum_sum(self) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._from_native_series(pc.cumulative_sum(self._native_series)) def round(self, decimals: int) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._from_native_series( pc.round(self._native_series, decimals, round_mode="half_towards_infinity") ) def diff(self) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._from_native_series( pc.pairwise_diff(self._native_series.combine_chunks()) ) def any(self) -> bool: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return pc.any(self._native_series) # type: ignore[no-any-return] def all(self) -> bool: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return pc.all(self._native_series) # type: ignore[no-any-return] def is_between(self, lower_bound: Any, upper_bound: Any, closed: str = "both") -> Any: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series if closed == "left": ge = pc.greater_equal(ser, lower_bound) @@ -360,7 +397,8 @@ def is_null(self) -> Self: return self._from_native_series(ser.is_null()) def cast(self, dtype: DType) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series dtype = narwhals_to_native_dtype(dtype) return self._from_native_series(pc.cast(ser, dtype)) @@ -385,8 +423,9 @@ def tail(self, n: int) -> Self: return self._from_native_series(ser.slice(abs(n))) def is_in(self, other: Any) -> Self: - pc = get_pyarrow_compute() - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + value_set = pa.array(other) ser = self._native_series return self._from_native_series(pc.is_in(ser, value_set=value_set)) @@ -420,10 +459,10 @@ def value_counts( normalize: bool = False, ) -> ArrowDataFrame: """Parallel is unused, exists for compatibility""" - from narwhals._arrow.dataframe import ArrowDataFrame + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() - pc = get_pyarrow_compute() - pa = get_pyarrow() + from narwhals._arrow.dataframe import ArrowDataFrame index_name_ = "index" if self._name is None else self._name value_name_ = name or ("proportion" if normalize else "count") @@ -448,7 +487,7 @@ def value_counts( ) def zip_with(self: Self, mask: Self, other: Self) -> Self: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() return self._from_native_series( pc.replace_with_mask( @@ -466,8 +505,8 @@ def sample( with_replacement: bool = False, ) -> Self: import numpy as np # ignore-banned-import + import pyarrow.compute as pc # ignore-banned-import() - pc = 
get_pyarrow_compute() ser = self._native_series num_rows = len(self) @@ -479,17 +518,19 @@ def sample( return self._from_native_series(pc.take(ser, mask)) def fill_null(self: Self, value: Any) -> Self: - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series dtype = ser.type return self._from_native_series(pc.fill_null(ser, pa.scalar(value, dtype))) def to_frame(self: Self) -> ArrowDataFrame: + import pyarrow as pa # ignore-banned-import() + from narwhals._arrow.dataframe import ArrowDataFrame - pa = get_pyarrow() df = pa.Table.from_arrays([self._native_series], names=[self.name]) return ArrowDataFrame(df, backend_version=self._backend_version) @@ -505,9 +546,8 @@ def is_unique(self: Self) -> ArrowSeries: def is_first_distinct(self: Self) -> Self: import numpy as np # ignore-banned-import - - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() row_number = pa.array(np.arange(len(self))) col_token = generate_unique_token(n_bytes=8, columns=[self.name]) @@ -523,9 +563,8 @@ def is_first_distinct(self: Self) -> Self: def is_last_distinct(self: Self) -> Self: import numpy as np # ignore-banned-import - - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() row_number = pa.array(np.arange(len(self))) col_token = generate_unique_token(n_bytes=8, columns=[self.name]) @@ -543,7 +582,8 @@ def is_sorted(self: Self, *, descending: bool = False) -> bool: if not isinstance(descending, bool): msg = f"argument 'descending' should be boolean, found {type(descending)}" raise TypeError(msg) - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + ser = self._native_series if descending: return pc.all(pc.greater_equal(ser[:-1], ser[1:])) # type: ignore[no-any-return] @@ -551,13 +591,15 @@ def is_sorted(self: Self, *, descending: bool = False) -> bool: return pc.all(pc.less_equal(ser[:-1], ser[1:])) # type: ignore[no-any-return] def unique(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._from_native_series(pc.unique(self._native_series)) def sort( self: Self, *, descending: bool = False, nulls_last: bool = False ) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + series = self._native_series order = "descending" if descending else "ascending" null_placement = "at_end" if nulls_last else "at_start" @@ -571,11 +613,10 @@ def to_dummies( self: Self, *, separator: str = "_", drop_first: bool = False ) -> ArrowDataFrame: import numpy as np # ignore-banned-import + import pyarrow as pa # ignore-banned-import() from narwhals._arrow.dataframe import ArrowDataFrame - pa = get_pyarrow() - series = self._native_series da = series.dictionary_encode().combine_chunks() @@ -593,7 +634,8 @@ def quantile( quantile: float, interpolation: Literal["nearest", "higher", "lower", "midpoint", "linear"], ) -> Any: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return pc.quantile(self._native_series, q=quantile, interpolation=interpolation)[ 0 ] @@ -604,8 +646,8 @@ def gather_every(self: Self, n: int, offset: int = 0) -> Self: def clip( self: Self, lower_bound: Any | None = None, upper_bound: Any | None = None ) -> 
Self: - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() arr = self._native_series arr = pc.max_element_wise(arr, pa.scalar(lower_bound, type=arr.type)) @@ -638,7 +680,8 @@ def __init__(self: Self, series: ArrowSeries) -> None: self._arrow_series = series def to_string(self: Self, format: str) -> ArrowSeries: # noqa: A002 - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + # PyArrow differs from other libraries in that %S also prints out # the fractional part of the second...:'( # https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html @@ -648,63 +691,72 @@ def to_string(self: Self, format: str) -> ArrowSeries: # noqa: A002 ) def date(self: Self) -> ArrowSeries: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + return self._arrow_series._from_native_series( self._arrow_series._native_series.cast(pa.date64()) ) def year(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.year(self._arrow_series._native_series) ) def month(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.month(self._arrow_series._native_series) ) def day(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.day(self._arrow_series._native_series) ) def hour(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.hour(self._arrow_series._native_series) ) def minute(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.minute(self._arrow_series._native_series) ) def second(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.second(self._arrow_series._native_series) ) def millisecond(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.millisecond(self._arrow_series._native_series) ) def microsecond(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + arr = self._arrow_series._native_series result = pc.add(pc.multiply(pc.millisecond(arr), 1000), pc.microsecond(arr)) return self._arrow_series._from_native_series(result) def nanosecond(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + arr = self._arrow_series._native_series result = pc.add( pc.multiply(self.microsecond()._native_series, 1000), pc.nanosecond(arr) @@ -712,14 +764,16 @@ def nanosecond(self: Self) -> ArrowSeries: return self._arrow_series._from_native_series(result) def ordinal_day(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.day_of_year(self._arrow_series._native_series) ) def total_minutes(self: Self) -> ArrowSeries: - pa = get_pyarrow() - pc = 
get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + arr = self._arrow_series._native_series unit = arr.type.unit @@ -736,8 +790,9 @@ def total_minutes(self: Self) -> ArrowSeries: ) def total_seconds(self: Self) -> ArrowSeries: - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + arr = self._arrow_series._native_series unit = arr.type.unit @@ -754,8 +809,9 @@ def total_seconds(self: Self) -> ArrowSeries: ) def total_milliseconds(self: Self) -> ArrowSeries: - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + arr = self._arrow_series._native_series unit = arr.type.unit @@ -778,8 +834,9 @@ def total_milliseconds(self: Self) -> ArrowSeries: ) def total_microseconds(self: Self) -> ArrowSeries: - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + arr = self._arrow_series._native_series unit = arr.type.unit @@ -801,8 +858,9 @@ def total_microseconds(self: Self) -> ArrowSeries: ) def total_nanoseconds(self: Self) -> ArrowSeries: - pa = get_pyarrow() - pc = get_pyarrow_compute() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() + arr = self._arrow_series._native_series unit = arr.type.unit @@ -825,7 +883,8 @@ def __init__(self, series: ArrowSeries) -> None: self._arrow_series = series def get_categories(self) -> ArrowSeries: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + ca = self._arrow_series._native_series # TODO(Unassigned): this looks potentially expensive - is there no better way? 
out = pa.chunked_array( @@ -841,7 +900,8 @@ def __init__(self: Self, series: ArrowSeries) -> None: def replace( self, pattern: str, value: str, *, literal: bool = False, n: int = 1 ) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + method = "replace_substring" if literal else "replace_substring_regex" return self._arrow_series._from_native_series( getattr(pc, method)( @@ -858,7 +918,8 @@ def replace_all( return self.replace(pattern, value, literal=literal, n=-1) def strip_chars(self: Self, characters: str | None = None) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + whitespace = " \t\n\r\v\f" return self._arrow_series._from_native_series( pc.utf8_trim( @@ -868,26 +929,30 @@ def strip_chars(self: Self, characters: str | None = None) -> ArrowSeries: ) def starts_with(self: Self, prefix: str) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.equal(self.slice(0, len(prefix))._native_series, prefix) ) def ends_with(self: Self, suffix: str) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.equal(self.slice(-len(suffix))._native_series, suffix) ) def contains(self: Self, pattern: str, *, literal: bool = False) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + check_func = pc.match_substring if literal else pc.match_substring_regex return self._arrow_series._from_native_series( check_func(self._arrow_series._native_series, pattern) ) def slice(self: Self, offset: int, length: int | None = None) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + stop = offset + length if length else None return self._arrow_series._from_native_series( pc.utf8_slice_codeunits( @@ -896,19 +961,22 @@ def slice(self: Self, offset: int, length: int | None = None) -> ArrowSeries: ) def to_datetime(self: Self, format: str | None = None) -> ArrowSeries: # noqa: A002 - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.strptime(self._arrow_series._native_series, format=format, unit="us") ) def to_uppercase(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.utf8_upper(self._arrow_series._native_series), ) def to_lowercase(self: Self) -> ArrowSeries: - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + return self._arrow_series._from_native_series( pc.utf8_lower(self._arrow_series._native_series), ) diff --git a/narwhals/_arrow/utils.py b/narwhals/_arrow/utils.py index a6b56a355..6f7517aeb 100644 --- a/narwhals/_arrow/utils.py +++ b/narwhals/_arrow/utils.py @@ -4,8 +4,6 @@ from typing import Any from narwhals import dtypes -from narwhals.dependencies import get_pyarrow -from narwhals.dependencies import get_pyarrow_compute from narwhals.utils import isinstance_or_issubclass if TYPE_CHECKING: @@ -13,7 +11,8 @@ def translate_dtype(dtype: Any) -> dtypes.DType: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + if pa.types.is_int64(dtype): return dtypes.Int64() if pa.types.is_int32(dtype): @@ -56,9 +55,9 @@ def translate_dtype(dtype: Any) -> dtypes.DType: def 
narwhals_to_native_dtype(dtype: dtypes.DType | type[dtypes.DType]) -> Any: - from narwhals import dtypes + import pyarrow as pa # ignore-banned-import() - pa = get_pyarrow() + from narwhals import dtypes if isinstance_or_issubclass(dtype, dtypes.Float64): return pa.float64() @@ -143,7 +142,8 @@ def validate_dataframe_comparand( return NotImplemented if isinstance(other, ArrowSeries): if len(other) == 1: - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + value = other.item() if backend_version < (13,) and hasattr(value, "as_py"): # pragma: no cover value = value.as_py() @@ -159,7 +159,8 @@ def horizontal_concat(dfs: list[Any]) -> Any: Should be in namespace. """ - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + if not dfs: msg = "No dataframes to concatenate" # pragma: no cover raise AssertionError(msg) @@ -191,15 +192,16 @@ def vertical_concat(dfs: list[Any]) -> Any: msg = "unable to vstack, column names don't match" raise TypeError(msg) - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + return pa.concat_tables(dfs).combine_chunks() def floordiv_compat(left: Any, right: Any) -> Any: # The following lines are adapted from pandas' pyarrow implementation. # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154 - pc = get_pyarrow_compute() - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() if isinstance(left, (int, float)): left = pa.scalar(left) @@ -237,8 +239,8 @@ def floordiv_compat(left: Any, right: Any) -> Any: def cast_for_truediv(arrow_array: Any, pa_object: Any) -> tuple[Any, Any]: # Lifted from: # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122 - pc = get_pyarrow_compute() - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + import pyarrow.compute as pc # ignore-banned-import() # Ensure int / int -> float mirroring Python/Numpy behavior # as pc.divide_checked(int, int) -> int @@ -260,7 +262,7 @@ def broadcast_series(series: list[ArrowSeries]) -> list[Any]: if fast_path: return [s._native_series for s in series] - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() reshaped = [] for s, length in zip(series, lengths): diff --git a/narwhals/_dask/dataframe.py b/narwhals/_dask/dataframe.py index cf3c6cc12..99ed430a9 100644 --- a/narwhals/_dask/dataframe.py +++ b/narwhals/_dask/dataframe.py @@ -87,7 +87,7 @@ def select( *exprs: IntoDaskExpr, **named_exprs: IntoDaskExpr, ) -> Self: - dd = get_dask_dataframe() + import dask.dataframe as dd # ignore-banned-import if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs: # This is a simple slice => fastpath! 
@@ -97,7 +97,8 @@ def select( if not new_series: # return empty dataframe, like Polars does - pd = get_pandas() + import pandas as pd # ignore-banned-import + return self._from_native_frame( dd.from_pandas(pd.DataFrame(), npartitions=self._native_frame.npartitions) ) diff --git a/narwhals/_dask/namespace.py b/narwhals/_dask/namespace.py index afff9fee5..2baf1cf3f 100644 --- a/narwhals/_dask/namespace.py +++ b/narwhals/_dask/namespace.py @@ -10,8 +10,6 @@ from narwhals._dask.expr import DaskExpr from narwhals._dask.selectors import DaskSelectorNamespace from narwhals._expression_parsing import parse_into_exprs -from narwhals.dependencies import get_dask_dataframe -from narwhals.dependencies import get_pandas if TYPE_CHECKING: from narwhals._dask.dataframe import DaskLazyFrame @@ -104,8 +102,8 @@ def sum(self, *column_names: str) -> DaskExpr: ).sum() def len(self) -> DaskExpr: - pd = get_pandas() - dd = get_dask_dataframe() + import dask.dataframe as dd # ignore-banned-import + import pandas as pd # ignore-banned-import def func(df: DaskLazyFrame) -> list[Any]: if not df.columns: diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index b2a819a0e..c815b2a0c 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -19,7 +19,6 @@ from narwhals.dependencies import get_cudf from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_pyarrow from narwhals.dependencies import is_numpy_array from narwhals.utils import Implementation from narwhals.utils import flatten @@ -491,7 +490,6 @@ def unique( The param `maintain_order` is only here for compatibility with the polars API and has no effect on the output. """ - mapped_keep = {"none": False, "any": "first"}.get(keep, keep) subset = flatten(subset) if subset else None return self._from_native_frame( @@ -603,5 +601,6 @@ def to_arrow(self: Self) -> Any: msg = "`to_arrow` is not implemented for CuDF backend." raise NotImplementedError(msg) - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + return pa.Table.from_pandas(self._native_frame) diff --git a/narwhals/_pandas_like/series.py b/narwhals/_pandas_like/series.py index 51b4cbd72..3db0fb73a 100644 --- a/narwhals/_pandas_like/series.py +++ b/narwhals/_pandas_like/series.py @@ -16,8 +16,6 @@ from narwhals.dependencies import get_cudf from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas -from narwhals.dependencies import get_pyarrow -from narwhals.dependencies import get_pyarrow_compute from narwhals.utils import Implementation if TYPE_CHECKING: @@ -638,7 +636,8 @@ def to_arrow(self: Self) -> Any: msg = "`to_arrow` is not implemented for CuDF backend." 
raise NotImplementedError(msg) - pa = get_pyarrow() + import pyarrow as pa # ignore-banned-import() + return pa.Array.from_pandas(self._native_series) @property @@ -786,7 +785,8 @@ def microsecond(self) -> PandasLikeSeries: self._pandas_series._native_series.dtype ): # crazy workaround for https://github.com/pandas-dev/pandas/issues/59154 - pc = get_pyarrow_compute() + import pyarrow.compute as pc # ignore-banned-import() + native_series = self._pandas_series._native_series arr = native_series.array.__arrow_array__() result_arr = pc.add( diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index dfcdce87b..9aa7bb64c 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -1261,7 +1261,6 @@ def head(self, n: int = 5) -> Self: │ 3 ┆ 8 ┆ c │ └─────┴─────┴─────┘ """ - return super().head(n) def tail(self, n: int = 5) -> Self: @@ -1833,7 +1832,6 @@ def is_empty(self: Self) -> bool: >>> func(df_pd), func(df_pl) (False, False) """ - return self._compliant_frame.is_empty() # type: ignore[no-any-return] def is_unique(self: Self) -> Series: @@ -1939,7 +1937,6 @@ def null_count(self: Self) -> Self: │ 1 ┆ 1 ┆ 0 │ └─────┴─────┴─────┘ """ - return self._from_compliant_dataframe(self._compliant_frame.null_count()) def item(self: Self, row: int | None = None, column: int | str | None = None) -> Any: diff --git a/narwhals/dependencies.py b/narwhals/dependencies.py index 663b5dde6..e2d67f03c 100644 --- a/narwhals/dependencies.py +++ b/narwhals/dependencies.py @@ -14,8 +14,12 @@ from typing import TypeGuard else: from typing_extensions import TypeGuard + import cudf + import dask.dataframe as dd + import modin.pandas as mpd import pandas as pd import polars as pl + import pyarrow as pa def get_polars() -> Any: @@ -45,26 +49,6 @@ def get_pyarrow() -> Any: # pragma: no cover return sys.modules.get("pyarrow", None) -def get_pyarrow_compute() -> Any: # pragma: no cover - """Get pyarrow.compute module (if pyarrow has already been imported - else return None).""" - # TODO(marco): remove this one, as it's at odds with the others, as it imports - # something new - if "pyarrow" in sys.modules: - import pyarrow.compute as pc - - return pc - return None - - -def get_pyarrow_parquet() -> Any: # pragma: no cover - """Get pyarrow.parquet module (if pyarrow has already been imported - else return None).""" - if "pyarrow" in sys.modules: - import pyarrow.parquet as pp - - return pp - return None - - def get_numpy() -> Any: """Get numpy module (if already imported - else return None).""" return sys.modules.get("numpy", None) @@ -91,10 +75,35 @@ def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]: def is_pandas_series(ser: Any) -> TypeGuard[pd.Series[Any]]: - """Check whether `df` is a pandas Series without importing pandas.""" + """Check whether `ser` is a pandas Series without importing pandas.""" return bool((pd := get_pandas()) is not None and isinstance(ser, pd.Series)) +def is_modin_dataframe(df: Any) -> TypeGuard[mpd.DataFrame]: + """Check whether `df` is a modin DataFrame without importing modin.""" + return bool((pd := get_modin()) is not None and isinstance(df, pd.DataFrame)) + + +def is_modin_series(ser: Any) -> TypeGuard[mpd.Series]: + """Check whether `ser` is a modin Series without importing modin.""" + return bool((pd := get_modin()) is not None and isinstance(ser, pd.Series)) + + +def is_cudf_dataframe(df: Any) -> TypeGuard[cudf.DataFrame]: + """Check whether `df` is a cudf DataFrame without importing cudf.""" + return bool((pd := get_cudf()) is not None and isinstance(df, 
pd.DataFrame)) + + +def is_cudf_series(ser: Any) -> TypeGuard[pd.Series[Any]]: + """Check whether `ser` is a cudf Series without importing cudf.""" + return bool((pd := get_cudf()) is not None and isinstance(ser, pd.Series)) + + +def is_dask_dataframe(df: Any) -> TypeGuard[dd.DataFrame]: + """Check whether `df` is a Dask DataFrame without importing Dask.""" + return bool((dd := get_dask_dataframe()) is not None and isinstance(df, dd.DataFrame)) + + def is_polars_dataframe(df: Any) -> TypeGuard[pl.DataFrame]: """Check whether `df` is a Polars DataFrame without importing Polars.""" return bool((pl := get_polars()) is not None and isinstance(df, pl.DataFrame)) @@ -105,6 +114,21 @@ def is_polars_lazyframe(df: Any) -> TypeGuard[pl.LazyFrame]: return bool((pl := get_polars()) is not None and isinstance(df, pl.LazyFrame)) +def is_polars_series(ser: Any) -> TypeGuard[pl.Series]: + """Check whether `ser` is a Polars Series without importing Polars.""" + return bool((pl := get_polars()) is not None and isinstance(ser, pl.Series)) + + +def is_pyarrow_chunked_array(ser: Any) -> TypeGuard[pa.ChunkedArray]: + """Check whether `ser` is a PyArrow ChunkedArray without importing PyArrow.""" + return bool((pa := get_pyarrow()) is not None and isinstance(ser, pa.ChunkedArray)) + + +def is_pyarrow_table(df: Any) -> TypeGuard[pa.Table]: + """Check whether `df` is a PyArrow Table without importing PyArrow.""" + return bool((pa := get_pyarrow()) is not None and isinstance(df, pa.Table)) + + def is_numpy_array(arr: Any) -> TypeGuard[np.ndarray]: """Check whether `arr` is a NumPy Array without importing NumPy.""" return bool((np := get_numpy()) is not None and isinstance(arr, np.ndarray)) @@ -116,7 +140,6 @@ def is_numpy_array(arr: Any) -> TypeGuard[np.ndarray]: "get_modin", "get_cudf", "get_pyarrow", - "get_pyarrow_compute", "get_numpy", "is_pandas_dataframe", ] diff --git a/narwhals/expr.py b/narwhals/expr.py index 8d0f4956a..b89fb3d5a 100644 --- a/narwhals/expr.py +++ b/narwhals/expr.py @@ -162,7 +162,6 @@ def cast( │ 3.0 ┆ 8 │ └─────┴─────┘ """ - return self.__class__( lambda plx: self._call(plx).cast(dtype), ) @@ -499,7 +498,6 @@ def min(self) -> Self: │ 1 ┆ 3 │ └─────┴─────┘ """ - return self.__class__(lambda plx: self._call(plx).min()) def max(self) -> Self: @@ -1409,7 +1407,6 @@ def is_unique(self) -> Self: │ false ┆ true │ └───────┴───────┘ """ - return self.__class__(lambda plx: self._call(plx).is_unique()) def null_count(self) -> Self: @@ -1623,7 +1620,6 @@ def head(self, n: int = 10) -> Self: │ 2 │ └─────┘ """ - return self.__class__(lambda plx: self._call(plx).head(n)) def tail(self, n: int = 10) -> Self: @@ -1667,7 +1663,6 @@ def tail(self, n: int = 10) -> Self: │ 9 │ └─────┘ """ - return self.__class__(lambda plx: self._call(plx).tail(n)) def round(self, decimals: int = 0) -> Self: @@ -1719,7 +1714,6 @@ def round(self, decimals: int = 0) -> Self: │ 3.9 │ └─────┘ """ - return self.__class__(lambda plx: self._call(plx).round(decimals)) def len(self) -> Self: @@ -2242,7 +2236,6 @@ def contains(self, pattern: str, *, literal: bool = False) -> Expr: │ null ┆ null ┆ null ┆ null │ └───────────────────┴───────────────┴────────────────────────┴───────────────┘ """ - return self._expr.__class__( lambda plx: self._expr._call(plx).str.contains(pattern, literal=literal) ) @@ -3445,7 +3438,6 @@ def keep(self: Self) -> Expr: >>> func(df_pl).columns ['foo'] """ - return self._expr.__class__(lambda plx: self._expr._call(plx).name.keep()) def map(self: Self, function: Callable[[str], str]) -> Expr: @@ -3482,7 +3474,6 
@@ def map(self: Self, function: Callable[[str], str]) -> Expr: >>> func(df_pl).columns ['oof', 'RAB'] """ - return self._expr.__class__(lambda plx: self._expr._call(plx).name.map(function)) def prefix(self: Self, prefix: str) -> Expr: diff --git a/narwhals/series.py b/narwhals/series.py index 3c79024c8..a1bcae18b 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -1698,7 +1698,6 @@ def null_count(self: Self) -> int: >>> func(s_pl) 2 """ - return self._compliant_series.null_count() # type: ignore[no-any-return] def is_first_distinct(self: Self) -> Self: @@ -1969,7 +1968,6 @@ def zip_with(self: Self, mask: Self, other: Self) -> Self: 4 5 dtype: int64 """ - return self._from_compliant_series( self._compliant_series.zip_with( self._extract_native(mask), self._extract_native(other) @@ -2043,7 +2041,6 @@ def head(self: Self, n: int = 10) -> Self: 2 ] """ - return self._from_compliant_series(self._compliant_series.head(n)) def tail(self: Self, n: int = 10) -> Self: @@ -2084,7 +2081,6 @@ def tail(self: Self, n: int = 10) -> Self: 9 ] """ - return self._from_compliant_series(self._compliant_series.tail(n)) def round(self: Self, decimals: int = 0) -> Self: @@ -2200,7 +2196,6 @@ def to_dummies( │ 0 ┆ 1 │ └─────┴─────┘ """ - from narwhals.dataframe import DataFrame return DataFrame( @@ -2284,7 +2279,6 @@ def to_arrow(self: Self) -> pa.Array: 4 ] """ - return self._compliant_series.to_arrow() @property diff --git a/narwhals/translate.py b/narwhals/translate.py index ed33b376b..f396a8982 100644 --- a/narwhals/translate.py +++ b/narwhals/translate.py @@ -10,16 +10,23 @@ from narwhals.dependencies import get_cudf from narwhals.dependencies import get_dask -from narwhals.dependencies import get_dask_dataframe from narwhals.dependencies import get_dask_expr from narwhals.dependencies import get_modin from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow +from narwhals.dependencies import is_cudf_dataframe +from narwhals.dependencies import is_cudf_series +from narwhals.dependencies import is_dask_dataframe +from narwhals.dependencies import is_modin_dataframe +from narwhals.dependencies import is_modin_series from narwhals.dependencies import is_pandas_dataframe from narwhals.dependencies import is_pandas_series from narwhals.dependencies import is_polars_dataframe from narwhals.dependencies import is_polars_lazyframe +from narwhals.dependencies import is_polars_series +from narwhals.dependencies import is_pyarrow_chunked_array +from narwhals.dependencies import is_pyarrow_table if TYPE_CHECKING: from narwhals.dataframe import DataFrame @@ -369,7 +376,8 @@ def from_native( # noqa: PLR0915 PolarsLazyFrame(native_object, backend_version=parse_version(pl.__version__)), level="full", ) - elif (pl := get_polars()) is not None and isinstance(native_object, pl.Series): + elif is_polars_series(native_object): + pl = get_polars() if not allow_series: msg = "Please set `allow_series=True`" raise TypeError(msg) @@ -407,9 +415,8 @@ def from_native( # noqa: PLR0915 ) # Modin - elif (mpd := get_modin()) is not None and isinstance( - native_object, mpd.DataFrame - ): # pragma: no cover + elif is_modin_dataframe(native_object): # pragma: no cover + mpd = get_modin() if series_only: msg = "Cannot only use `series_only` with modin.DataFrame" raise TypeError(msg) @@ -421,9 +428,8 @@ def from_native( # noqa: PLR0915 ), level="full", ) - elif (mpd := get_modin()) is not None and isinstance( - native_object, mpd.Series - ): # pragma: no 
cover + elif is_modin_series(native_object): # pragma: no cover + mpd = get_modin() if not allow_series: msg = "Please set `allow_series=True`" raise TypeError(msg) @@ -437,9 +443,8 @@ def from_native( # noqa: PLR0915 ) # cuDF - elif (cudf := get_cudf()) is not None and isinstance( # pragma: no cover - native_object, cudf.DataFrame - ): + elif is_cudf_dataframe(native_object): # pragma: no cover + cudf = get_cudf() if series_only: msg = "Cannot only use `series_only` with cudf.DataFrame" raise TypeError(msg) @@ -451,9 +456,8 @@ def from_native( # noqa: PLR0915 ), level="full", ) - elif (cudf := get_cudf()) is not None and isinstance( - native_object, cudf.Series - ): # pragma: no cover + elif is_cudf_series(native_object): # pragma: no cover + cudf = get_cudf() if not allow_series: msg = "Please set `allow_series=True`" raise TypeError(msg) @@ -467,7 +471,8 @@ def from_native( # noqa: PLR0915 ) # PyArrow - elif (pa := get_pyarrow()) is not None and isinstance(native_object, pa.Table): + elif is_pyarrow_table(native_object): + pa = get_pyarrow() if series_only: msg = "Cannot only use `series_only` with arrow table" raise TypeError(msg) @@ -475,7 +480,8 @@ def from_native( # noqa: PLR0915 ArrowDataFrame(native_object, backend_version=parse_version(pa.__version__)), level="full", ) - elif (pa := get_pyarrow()) is not None and isinstance(native_object, pa.ChunkedArray): + elif is_pyarrow_chunked_array(native_object): + pa = get_pyarrow() if not allow_series: msg = "Please set `allow_series=True`" raise TypeError(msg) @@ -487,9 +493,7 @@ def from_native( # noqa: PLR0915 ) # Dask - elif (dd := get_dask_dataframe()) is not None and isinstance( - native_object, dd.DataFrame - ): + elif is_dask_dataframe(native_object): if series_only: msg = "Cannot only use `series_only` with dask DataFrame" raise TypeError(msg) diff --git a/narwhals/utils.py b/narwhals/utils.py index 512099cc5..1a0b752d9 100644 --- a/narwhals/utils.py +++ b/narwhals/utils.py @@ -18,6 +18,12 @@ from narwhals.dependencies import get_pandas from narwhals.dependencies import get_polars from narwhals.dependencies import get_pyarrow +from narwhals.dependencies import is_cudf_series +from narwhals.dependencies import is_modin_series +from narwhals.dependencies import is_pandas_dataframe +from narwhals.dependencies import is_pandas_series +from narwhals.dependencies import is_polars_series +from narwhals.dependencies import is_pyarrow_chunked_array from narwhals.translate import to_native if TYPE_CHECKING: @@ -95,7 +101,7 @@ def tupleify(arg: Any) -> Any: def _is_iterable(arg: Any | Iterable[Any]) -> bool: from narwhals.series import Series - if (pd := get_pandas()) is not None and isinstance(arg, (pd.Series, pd.DataFrame)): + if is_pandas_dataframe(arg) or is_pandas_series(arg): msg = f"Expected Narwhals class or scalar, got: {type(arg)}. Perhaps you forgot a `nw.from_native` somewhere?" 
raise TypeError(msg) if (pl := get_polars()) is not None and isinstance( @@ -352,19 +358,15 @@ def is_ordered_categorical(series: Series) -> bool: if series.dtype != dtypes.Categorical: return False native_series = to_native(series) - if (pl := get_polars()) is not None and isinstance(native_series, pl.Series): - return native_series.dtype.ordering == "physical" # type: ignore[no-any-return] - if (pd := get_pandas()) is not None and isinstance(native_series, pd.Series): + if is_polars_series(native_series): + return native_series.dtype.ordering == "physical" # type: ignore[attr-defined, no-any-return] + if is_pandas_series(native_series): return native_series.cat.ordered # type: ignore[no-any-return] - if (mpd := get_modin()) is not None and isinstance( - native_series, mpd.Series - ): # pragma: no cover + if is_modin_series(native_series): # pragma: no cover return native_series.cat.ordered # type: ignore[no-any-return] - if (cudf := get_cudf()) is not None and isinstance( - native_series, cudf.Series - ): # pragma: no cover + if is_cudf_series(native_series): # pragma: no cover return native_series.cat.ordered # type: ignore[no-any-return] - if (pa := get_pyarrow()) is not None and isinstance(native_series, pa.ChunkedArray): + if is_pyarrow_chunked_array(native_series): return native_series.type.ordered # type: ignore[no-any-return] # If it doesn't match any of the above, let's just play it safe and return False. return False # pragma: no cover diff --git a/tests/no_imports_test.py b/tests/no_imports_test.py new file mode 100644 index 000000000..a89ed0ed8 --- /dev/null +++ b/tests/no_imports_test.py @@ -0,0 +1,68 @@ +import sys + +import pandas as pd +import polars as pl +import pyarrow as pa +import pytest + +import narwhals.stable.v1 as nw + + +def test_polars(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delitem(sys.modules, "pandas") + monkeypatch.delitem(sys.modules, "numpy") + monkeypatch.delitem(sys.modules, "pyarrow") + monkeypatch.delitem(sys.modules, "dask", raising=False) + df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) + nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean()).filter( + nw.col("a") > 1 + ) + assert "polars" in sys.modules + assert "pandas" not in sys.modules + assert "numpy" not in sys.modules + assert "pyarrow" not in sys.modules + assert "dask" not in sys.modules + + +def test_pandas(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.delitem(sys.modules, "polars") + monkeypatch.delitem(sys.modules, "pyarrow") + monkeypatch.delitem(sys.modules, "dask", raising=False) + df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]}) + nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean()).filter( + nw.col("a") > 1 + ) + assert "polars" not in sys.modules + assert "pandas" in sys.modules + assert "numpy" in sys.modules + assert "pyarrow" not in sys.modules + assert "dask" not in sys.modules + + +def test_dask(monkeypatch: pytest.MonkeyPatch) -> None: + pytest.importorskip("dask") + pytest.importorskip("dask_expr", exc_type=ImportError) + import dask.dataframe as dd + + monkeypatch.delitem(sys.modules, "polars") + monkeypatch.delitem(sys.modules, "pyarrow") + df = dd.from_pandas(pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})) + nw.from_native(df).group_by("a").agg(nw.col("b").mean()).filter(nw.col("a") > 1) + assert "polars" not in sys.modules + assert "pandas" in sys.modules + assert "numpy" in sys.modules + assert "pyarrow" not in sys.modules + assert "dask" in sys.modules + + +def test_pyarrow(monkeypatch: 
pytest.MonkeyPatch) -> None: + monkeypatch.delitem(sys.modules, "polars") + monkeypatch.delitem(sys.modules, "pandas") + monkeypatch.delitem(sys.modules, "dask", raising=False) + df = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]}) + nw.from_native(df).group_by("a").agg(nw.col("b").mean()).filter(nw.col("a") > 1) + assert "polars" not in sys.modules + assert "pandas" not in sys.modules + assert "numpy" in sys.modules + assert "pyarrow" in sys.modules + assert "dask" not in sys.modules
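A closing note on why these assertions can hold: the helpers in `narwhals/dependencies.py` never import a backend themselves; they only look it up in `sys.modules`. A self-contained sketch of that detection pattern (simplified from the real functions, which return `TypeGuard`s and cover more backends) might look like:

```python
import sys
from typing import Any


def get_pandas() -> Any:
    """Return the pandas module if the user has already imported it, else None."""
    # sys.modules.get never triggers an import, so this costs nothing for
    # users who never load pandas.
    return sys.modules.get("pandas", None)


def is_pandas_dataframe(df: Any) -> bool:
    """Check whether `df` is a pandas DataFrame without importing pandas."""
    return bool((pd := get_pandas()) is not None and isinstance(df, pd.DataFrame))


if __name__ == "__main__":
    # A non-pandas object is rejected, and pandas has still not been imported.
    assert not is_pandas_dataframe({"a": [1, 2, 3]})
    assert "pandas" not in sys.modules
```

This is why the tests above can delete entries from `sys.modules` via `monkeypatch` and still run the full group-by/filter pipeline: nothing on that code path re-imports the backends that were removed.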