Skip to content

Commit

Permalink
chore: import overhaul (#788)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 14, 2024
1 parent 350fe7d commit 6fbfb77
Show file tree
Hide file tree
Showing 17 changed files with 382 additions and 214 deletions.
20 changes: 20 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,26 @@ Please adhere to the following guidelines:
If Narwhals looks like underwater unicorn magic to you, then please read
[how it works](https://narwhals-dev.github.io/narwhals/how-it-works/).

## Imports

In Narwhals, we are very particular about imports. When it comes to importing
heavy third-party libraries (pandas, NumPy, Polars, etc...) please follow these rules:

- Never import anything to do `isinstance` checks. Instead, just use the functions
in `narwhals.dependencies` (such as `is_pandas_dataframe`);
- If you need to import anything, do it in a place where you know that the import
is definitely available. For example, NumPy is a required dependency of PyArrow,
so it's OK to import NumPy to implement a PyArrow function - however, NumPy
should never be imported to implement a Polars function. The only exception is
for when there's simply no way around it by definition - for example, `Series.to_numpy`
always requires NumPy to be installed.
- Don't place a third-party import at the top of a file. Instead, place it in the
function where it's used, so that we minimise the chances of it being imported
unnecessarily.

We're trying to be really lightweight and minimal-overhead, and
unnecessary imports can slow things down.

## Happy contributing!

Please remember to abide by the code of conduct, else you'll be conducted away from this project.
Expand Down
27 changes: 14 additions & 13 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
from narwhals._arrow.utils import validate_dataframe_comparand
from narwhals._expression_parsing import evaluate_into_exprs
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import get_pyarrow_compute
from narwhals.dependencies import get_pyarrow_parquet
from narwhals.dependencies import is_numpy_array
from narwhals.utils import Implementation
from narwhals.utils import flatten
Expand Down Expand Up @@ -182,12 +180,13 @@ def select(
*exprs: IntoArrowExpr,
**named_exprs: IntoArrowExpr,
) -> Self:
import pyarrow as pa # ignore-banned-import()

new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._from_native_frame(self._native_frame.__class__.from_arrays([]))
names = [s.name for s in new_series]
pa = get_pyarrow()
df = pa.Table.from_arrays(
broadcast_series(new_series),
names=names,
Expand Down Expand Up @@ -337,7 +336,8 @@ def to_dict(self, *, as_series: bool) -> Any:
return {name: col.to_pylist() for name, col in names_and_values}

def with_row_index(self, name: str) -> Self:
pa = get_pyarrow()
import pyarrow as pa # ignore-banned-import()

df = self._native_frame

row_indices = pa.array(range(df.num_rows))
Expand All @@ -354,7 +354,8 @@ def filter(
return self._from_native_frame(self._native_frame.filter(mask._native_series))

def null_count(self) -> Self:
pa = get_pyarrow()
import pyarrow as pa # ignore-banned-import()

df = self._native_frame
names_and_values = zip(df.column_names, df.columns)

Expand Down Expand Up @@ -415,16 +416,17 @@ def rename(self, mapping: dict[str, str]) -> Self:
return self._from_native_frame(df.rename_columns(new_cols))

def write_parquet(self, file: Any) -> Any:
pp = get_pyarrow_parquet()
import pyarrow.parquet as pp # ignore-banned-import

pp.write_table(self._native_frame, file)

def is_duplicated(self: Self) -> ArrowSeries:
import numpy as np # ignore-banned-import
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()

from narwhals._arrow.series import ArrowSeries

pa = get_pyarrow()
pc = get_pyarrow_compute()
df = self._native_frame

columns = self.columns
Expand All @@ -443,9 +445,10 @@ def is_duplicated(self: Self) -> ArrowSeries:
return ArrowSeries(is_duplicated, name="", backend_version=self._backend_version)

def is_unique(self: Self) -> ArrowSeries:
import pyarrow.compute as pc # ignore-banned-import()

from narwhals._arrow.series import ArrowSeries

pc = get_pyarrow_compute()
is_duplicated = self.is_duplicated()._native_series

return ArrowSeries(
Expand All @@ -464,11 +467,9 @@ def unique(
The param `maintain_order` is only here for compatibility with the polars API
and has no effect on the output.
"""

import numpy as np # ignore-banned-import

pa = get_pyarrow()
pc = get_pyarrow_compute()
import pyarrow as pa # ignore-banned-import()
import pyarrow.compute as pc # ignore-banned-import()

df = self._native_frame

Expand Down
8 changes: 4 additions & 4 deletions narwhals/_arrow/group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@

from narwhals._expression_parsing import is_simple_aggregation
from narwhals._expression_parsing import parse_into_exprs
from narwhals.dependencies import get_pyarrow
from narwhals.dependencies import get_pyarrow_compute
from narwhals.utils import remove_prefix

if TYPE_CHECKING:
Expand All @@ -20,7 +18,8 @@

class ArrowGroupBy:
def __init__(self, df: ArrowDataFrame, keys: list[str]) -> None:
pa = get_pyarrow()
import pyarrow as pa # ignore-banned-import()

self._df = df
self._keys = list(keys)
self._grouped = pa.TableGroupBy(self._df._native_frame, list(self._keys))
Expand Down Expand Up @@ -79,7 +78,8 @@ def agg_arrow(
output_names: list[str],
from_dataframe: Callable[[Any], ArrowDataFrame],
) -> ArrowDataFrame:
pc = get_pyarrow_compute()
import pyarrow.compute as pc # ignore-banned-import()

all_simple_aggs = True
for expr in exprs:
if not is_simple_aggregation(expr):
Expand Down
4 changes: 2 additions & 2 deletions narwhals/_arrow/namespace.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
from narwhals._arrow.utils import horizontal_concat
from narwhals._arrow.utils import vertical_concat
from narwhals._expression_parsing import parse_into_exprs
from narwhals.dependencies import get_pyarrow
from narwhals.utils import Implementation

if TYPE_CHECKING:
Expand Down Expand Up @@ -87,9 +86,10 @@ def _create_series_from_scalar(self, value: Any, series: ArrowSeries) -> ArrowSe
)

def _create_compliant_series(self, value: Any) -> ArrowSeries:
import pyarrow as pa # ignore-banned-import()

from narwhals._arrow.series import ArrowSeries

pa = get_pyarrow()
return ArrowSeries(
native_series=pa.chunked_array([value]),
name="",
Expand Down
Loading

0 comments on commit 6fbfb77

Please sign in to comment.