Skip to content

Commit

Permalink
Merge pull request #51 from MarcoGorelli/to-numpy
Browse files Browse the repository at this point in the history
To numpy
  • Loading branch information
MarcoGorelli authored Apr 25, 2024
2 parents c18bd13 + 47f905b commit 8f2de4d
Show file tree
Hide file tree
Showing 12 changed files with 826 additions and 36 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/extremes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
- name: install-modin
run: python -m pip install pandas==2.0.0 polars==0.20.13 modin[dask]
- name: Run pytest
run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=50
run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=50 --runslow
- name: Run doctests
run: pytest narwhals --doctest-modules

Expand Down Expand Up @@ -64,7 +64,7 @@ jobs:
- name: install-pandas-nightly
run: python -m pip install --pre --extra-index https://pypi.anaconda.org/scientific-python-nightly-wheels/simple pandas
- name: Run pytest
run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=50
run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=50 --runslow
- name: Run doctests
run: pytest narwhals --doctest-modules

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,6 @@ jobs:
- name: install-modin
run: python -m pip install --upgrade modin[dask]
- name: Run pytest
run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=100
run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=100 --runslow
- name: Run doctests
run: pytest narwhals --doctest-modules
13 changes: 12 additions & 1 deletion narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def select(
new_series = evaluate_into_exprs(self, *exprs, **named_exprs)
new_series = validate_indices(new_series)
df = horizontal_concat(
[series._series for series in new_series],
new_series,
implementation=self._implementation,
)
return self._from_dataframe(df)
Expand Down Expand Up @@ -227,6 +227,17 @@ def to_dict(self, *, as_series: bool = False) -> dict[str, Any]:
return self._dataframe.to_dict(orient="list") # type: ignore[no-any-return]

def to_numpy(self) -> Any:
    """Return the contents of this frame as a numpy array.

    When no column's dtype name appears in ``PANDAS_TO_NUMPY_DTYPE_MISSING``
    the underlying dataframe's own ``to_numpy()`` is used directly.
    Otherwise each column is converted on its own and the resulting
    column vectors are stacked side by side.
    """
    from narwhals._pandas_like.series import PANDAS_TO_NUMPY_DTYPE_MISSING

    # pandas returns `object` dtype for nullable dtypes, so in that case we
    # convert column by column and let numpy find a common dtype. If no
    # column has a dtype for which `to_numpy()` is "broken" (i.e. it
    # returns object), the plain DataFrame conversion is fine.
    needs_per_column_conversion = any(
        str(dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING
        for dtype in self._dataframe.dtypes
    )
    if not needs_per_column_conversion:
        return self._dataframe.to_numpy()

    import numpy as np

    # `[:, None]` turns each 1-D column array into an (n, 1) column vector
    # so that `hstack` places the columns next to each other.
    column_vectors = [self[name].to_numpy()[:, None] for name in self.columns]
    return np.hstack(column_vectors)

def to_pandas(self) -> Any:
Expand Down
68 changes: 67 additions & 1 deletion narwhals/_pandas_like/series.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING
from typing import Any
from typing import Sequence
Expand All @@ -16,6 +17,51 @@
from narwhals._pandas_like.namespace import PandasNamespace
from narwhals.dtypes import DType

# Numpy dtype to request when converting a series that has NO missing
# values. Each numpy dtype is reachable from two pandas dtype names: the
# capitalised one (e.g. "Int64") and the "[pyarrow]"-suffixed one
# (e.g. "int64[pyarrow]").
PANDAS_TO_NUMPY_DTYPE_NO_MISSING = {
    pandas_name: numpy_name
    for capitalised_name, numpy_name in (
        ("Int64", "int64"),
        ("Int32", "int32"),
        ("Int16", "int16"),
        ("Int8", "int8"),
        ("UInt64", "uint64"),
        ("UInt32", "uint32"),
        ("UInt16", "uint16"),
        ("UInt8", "uint8"),
        ("Float64", "float64"),
        ("Float32", "float32"),
    )
    for pandas_name in (capitalised_name, f"{numpy_name}[pyarrow]")
}
# Numpy dtype to request when converting a series that DOES contain missing
# values. Every listed dtype maps to "float64" except the 32-bit float
# dtypes, which map to "float32".
PANDAS_TO_NUMPY_DTYPE_MISSING = {
    **dict.fromkeys(
        (
            "Int64",
            "int64[pyarrow]",
            "Int32",
            "int32[pyarrow]",
            "Int16",
            "int16[pyarrow]",
            "Int8",
            "int8[pyarrow]",
            "UInt64",
            "uint64[pyarrow]",
            "UInt32",
            "uint32[pyarrow]",
            "UInt16",
            "uint16[pyarrow]",
            "UInt8",
            "uint8[pyarrow]",
            "Float64",
            "float64[pyarrow]",
        ),
        "float64",
    ),
    **dict.fromkeys(("Float32", "float32[pyarrow]"), "float32"),
}


class PandasSeries:
def __init__(
Expand Down Expand Up @@ -102,7 +148,14 @@ def is_in(self, other: Any) -> PandasSeries:
import pandas as pd

ser = self._series
res = ser.isin(other).convert_dtypes()
with warnings.catch_warnings():
# np.find_common_type is deprecated. Please use `np.result_type` or `np.promote_types`
warnings.filterwarnings(
"ignore",
message="np.find_common_type is deprecated.*",
category=DeprecationWarning,
)
res = ser.isin(other).convert_dtypes()
res[ser.isna()] = pd.NA
return self._from_series(res)

Expand Down Expand Up @@ -317,6 +370,19 @@ def alias(self, name: str) -> Self:
return self._from_series(self._rename(ser, name))

def to_numpy(self) -> Any:
    """Return the values of this series as a numpy array.

    The dtype name of the wrapped series is looked up in one of two tables,
    chosen by whether the series contains missing values:

    * with missing values: ``PANDAS_TO_NUMPY_DTYPE_MISSING`` supplies the
      target dtype and missing entries are filled with NaN;
    * without missing values: ``PANDAS_TO_NUMPY_DTYPE_NO_MISSING`` supplies
      the target dtype.

    If the dtype name is not in the relevant table, the series' default
    ``to_numpy()`` conversion is used.
    """
    native = self._series
    dtype_name = str(native.dtype)

    if native.isna().any():
        if dtype_name in PANDAS_TO_NUMPY_DTYPE_MISSING:
            return native.to_numpy(
                dtype=PANDAS_TO_NUMPY_DTYPE_MISSING[dtype_name],
                na_value=float("nan"),
            )
    elif dtype_name in PANDAS_TO_NUMPY_DTYPE_NO_MISSING:
        return native.to_numpy(dtype=PANDAS_TO_NUMPY_DTYPE_NO_MISSING[dtype_name])

    return native.to_numpy()

def to_pandas(self) -> Any:
Expand Down
25 changes: 8 additions & 17 deletions narwhals/_pandas_like/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@ def validate_column_comparand(index: Any, other: Any) -> Any:
# broadcast
return other.item()
if other._series.index is not index and not (other._series.index == index).all():
msg = (
"Narwhals does not support automated index alignment. "
"You may need to do a join before this operation."
)
raise ValueError(msg)
return other._series.set_axis(index, axis=0)
return other._series
return other

Expand All @@ -74,11 +70,7 @@ def validate_dataframe_comparand(index: Any, other: Any) -> Any:
# broadcast
return item(other._series)
if other._series.index is not index and not (other._series.index == index).all():
msg = (
"Narwhals does not support automated index alignment. "
"You may need to do a join before this operation."
)
raise ValueError(msg)
return other._series.set_axis(index, axis=0)
return other._series
raise AssertionError("Please report a bug")

Expand Down Expand Up @@ -362,13 +354,12 @@ def reverse_translate_dtype(dtype: DType | type[DType]) -> Any:
raise AssertionError(msg)


def validate_indices(series: list[PandasSeries]) -> list[PandasSeries]:
def validate_indices(series: list[PandasSeries]) -> list[Any]:
idx = series[0]._series.index
reindexed = [series[0]._series]
for s in series[1:]:
if s._series.index is not idx and not (s._series.index == idx).all():
msg = (
"Narwhals does not support automated index alignment. "
"You may need to do a join before this operation."
)
raise RuntimeError(msg)
return series
reindexed.append(s._series.set_axis(idx.rename(s._series.index.name), axis=0))
else:
reindexed.append(s._series)
return reindexed
2 changes: 1 addition & 1 deletion narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def get_pandas() -> Any:


@functools.lru_cache
def get_modin() -> Any:
def get_modin() -> Any: # pragma: no cover
try:
import modin.pandas as mpd
except ImportError: # pragma: no cover
Expand Down
23 changes: 23 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Any

import pytest


def pytest_addoption(parser: Any) -> None:
    """Register the ``--runslow`` command-line flag (off by default)."""
    parser.addoption(
        "--runslow",
        help="run slow tests",
        default=False,
        action="store_true",
    )


def pytest_configure(config: Any) -> None:
    """Add the ``slow`` marker description to the ``markers`` ini setting."""
    slow_marker_line = "slow: mark test as slow to run"
    config.addinivalue_line("markers", slow_marker_line)


def pytest_collection_modifyitems(config: Any, items: Any) -> Any:  # pragma: no cover
    """Mark every item carrying the ``slow`` keyword as skipped.

    Nothing is changed when ``--runslow`` was given on the command line.
    """
    if not config.getoption("--runslow"):
        # --runslow absent: attach one shared skip marker to each slow item.
        skip_slow = pytest.mark.skip(reason="need --runslow option to run")
        for collected in items:
            if "slow" in collected.keywords:
                collected.add_marker(skip_slow)
2 changes: 2 additions & 0 deletions tests/hypothesis/test_join.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import pandas as pd
import polars as pl
import pytest
from hypothesis import example
from hypothesis import given
from hypothesis import strategies as st
Expand Down Expand Up @@ -34,6 +35,7 @@
unique=True,
),
) # type: ignore[misc]
@pytest.mark.slow()
def test_join(
integers: st.SearchStrategy[list[int]],
other_integers: st.SearchStrategy[list[int]],
Expand Down
25 changes: 16 additions & 9 deletions tests/test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,14 +322,17 @@ def test_convert_pandas(df_raw: Any) -> None:
pd.testing.assert_frame_equal(result, expected)


@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd])
@pytest.mark.parametrize(
"df_raw", [df_polars, df_pandas, df_mpd, df_pandas_nullable, df_pandas_pyarrow]
)
@pytest.mark.filterwarnings(
r"ignore:np\.find_common_type is deprecated\.:DeprecationWarning"
)
def test_convert_numpy(df_raw: Any) -> None:
result = nw.DataFrame(df_raw).to_numpy()
expected = np.array([[1, 3, 2], [4, 4, 6], [7.0, 8, 9]]).T
np.testing.assert_array_equal(result, expected)
assert result.dtype == "float64"


@pytest.mark.parametrize("df_raw", [df_polars, df_pandas, df_mpd])
Expand Down Expand Up @@ -569,15 +572,19 @@ def test_invalid() -> None:
@pytest.mark.parametrize("df_raw", [df_pandas])
def test_reindex(df_raw: Any) -> None:
df = nw.DataFrame(df_raw)
with pytest.raises(RuntimeError, match="automated index alignment"):
df.select("a", df["b"].sort(descending=True))
with pytest.raises(RuntimeError, match="automated index alignment"):
df.select("a", nw.col("b").sort(descending=True))
result = df.select("b", df["a"].sort(descending=True))
expected = {"b": [4, 4, 6], "a": [3, 2, 1]}
compare_dicts(result, expected)
result = df.select("b", nw.col("a").sort(descending=True))
compare_dicts(result, expected)

s = df["a"]
with pytest.raises(ValueError, match="index alignment"):
nw.to_native(s > s.sort())
with pytest.raises(ValueError, match="index alignment"):
nw.to_native(df.with_columns(s.sort()))
result_s = s > s.sort()
assert not result_s[0]
assert result_s[1]
assert not result_s[2]
result = df.with_columns(s.sort())
expected = {"a": [1, 2, 3], "b": [4, 4, 6], "z": [7.0, 8.0, 9.0]} # type: ignore[list-item]
compare_dicts(result, expected)
with pytest.raises(ValueError, match="Multi-output expressions are not supported"):
nw.to_native(df.with_columns(nw.all() + nw.all()))
15 changes: 11 additions & 4 deletions tests/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,14 @@
def test_len(df_raw: Any) -> None:
result = len(nw.Series(df_raw["a"]))
assert result == 3
result = len(nw.to_native(nw.LazyFrame(df_raw).collect()["a"]))
result = len(nw.LazyFrame(df_raw).collect()["a"])
assert result == 3


@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
@pytest.mark.filterwarnings("ignore:np.find_common_type is deprecated:DeprecationWarning")
def test_is_in(df_raw: Any) -> None:
result = nw.to_native(nw.Series(df_raw["a"]).is_in([1, 2]))
result = nw.Series(df_raw["a"]).is_in([1, 2])
assert result[0]
assert not result[1]
assert result[2]
Expand All @@ -55,11 +56,11 @@ def test_is_in(df_raw: Any) -> None:
@pytest.mark.parametrize("df_raw", [df_pandas, df_polars])
def test_gt(df_raw: Any) -> None:
s = nw.Series(df_raw["a"])
result = nw.to_native(s > s) # noqa: PLR0124
result = s > s # noqa: PLR0124
assert not result[0]
assert not result[1]
assert not result[2]
result = nw.to_native(s > 1)
result = s > 1
assert not result[0]
assert result[1]
assert result[2]
Expand Down Expand Up @@ -285,3 +286,9 @@ def test_cast() -> None:
n=df["m"].cast(nw.Boolean),
).schema
assert result == expected


def test_to_numpy() -> None:
    """An ``Int64`` series with a missing value converts to a float64 array."""
    nullable_ints = pd.Series([1, 2, None], dtype="Int64")
    converted = nw.Series(nullable_ints).to_numpy()
    assert converted.dtype == "float64"
Loading

0 comments on commit 8f2de4d

Please sign in to comment.