Skip to content

Commit

Permalink
feat: Better ibis interchange (#901)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Sep 3, 2024
1 parent 1091072 commit 3d246b7
Show file tree
Hide file tree
Showing 11 changed files with 213 additions and 5 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,11 @@ jobs:
run: uv pip install --upgrade modin[dask] --system
- name: show-deps
run: uv pip freeze
- name: install ibis
run: uv pip install ibis-framework[duckdb] --system
# Ibis puts upper bounds on dependencies, and requires Python3.10+,
# which messes with other dependencies on lower Python versions
if: matrix.python-version == '3.12'
- name: Run pytest
run: pytest tests --cov=narwhals --cov=tests --cov-fail-under=100 --runslow
- name: Run doctests
Expand Down
2 changes: 2 additions & 0 deletions docs/api-reference/dependencies.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
options:
members:
- get_cudf
- get_ibis
- get_modin
- get_pandas
- get_polars
- get_pyarrow
- is_cudf_dataframe
- is_cudf_series
- is_dask_dataframe
- is_ibis_table
- is_modin_dataframe
- is_modin_series
- is_numpy_array
Expand Down
Empty file added narwhals/_ibis/__init__.py
Empty file.
75 changes: 75 additions & 0 deletions narwhals/_ibis/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

from narwhals import dtypes

if TYPE_CHECKING:
from narwhals._ibis.series import IbisInterchangeSeries


def map_ibis_dtype_to_narwhals_dtype(
    ibis_dtype: Any,
) -> dtypes.DType:
    """Translate an Ibis data type into the matching Narwhals dtype.

    The supported Ibis type predicates are probed in a fixed order and the
    first one that answers true decides which Narwhals dtype is
    instantiated and returned.

    Raises:
        AssertionError: if none of the supported predicates match.
    """
    # (Ibis predicate method, Narwhals dtype class name), in probing order.
    predicate_to_dtype_name = (
        ("is_int64", "Int64"),
        ("is_int32", "Int32"),
        ("is_int16", "Int16"),
        ("is_int8", "Int8"),
        ("is_uint64", "UInt64"),
        ("is_uint32", "UInt32"),
        ("is_uint16", "UInt16"),
        ("is_uint8", "UInt8"),
        ("is_boolean", "Boolean"),
        ("is_float64", "Float64"),
        ("is_float32", "Float32"),
        ("is_string", "String"),
        ("is_date", "Date"),
        ("is_timestamp", "Datetime"),
    )
    for predicate_name, dtype_name in predicate_to_dtype_name:
        if getattr(ibis_dtype, predicate_name)():
            return getattr(dtypes, dtype_name)()
    msg = (  # pragma: no cover
        f"Invalid dtype, got: {ibis_dtype}.\n\n"
        "If you believe this dtype should be supported in Narwhals, "
        "please report an issue at https://github.com/narwhals-dev/narwhals."
    )
    raise AssertionError(msg)


class IbisInterchangeFrame:
    """Metadata-only Narwhals wrapper around a native Ibis table.

    Column access via ``frame[name]`` and the ``schema`` attribute are
    supported; any other missing attribute raises ``NotImplementedError``.
    """

    def __init__(self, df: Any) -> None:
        self._native_frame = df

    def __narwhals_dataframe__(self) -> Any:
        return self

    def __getitem__(self, item: str) -> IbisInterchangeSeries:
        # Imported lazily: the series module imports from this one.
        from narwhals._ibis.series import IbisInterchangeSeries

        native_column = self._native_frame[item]
        return IbisInterchangeSeries(native_column)

    def __getattr__(self, attr: str) -> Any:
        # Python only calls this hook when regular attribute lookup fails.
        if attr != "schema":
            msg = (
                f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
                "If you would like to see this kind of object better supported in "
                "Narwhals, please open a feature request "
                "at https://github.com/narwhals-dev/narwhals/issues."
            )
            raise NotImplementedError(msg)
        native_schema = self._native_frame.schema()
        return {
            name: map_ibis_dtype_to_narwhals_dtype(native_dtype)
            for name, native_dtype in native_schema.items()
        }
24 changes: 24 additions & 0 deletions narwhals/_ibis/series.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Any

from narwhals._ibis.dataframe import map_ibis_dtype_to_narwhals_dtype


class IbisInterchangeSeries:
    """Metadata-only Narwhals wrapper around a native Ibis column.

    Only the ``dtype`` attribute is supported; any other missing attribute
    raises ``NotImplementedError``.
    """

    def __init__(self, df: Any) -> None:
        self._native_series = df

    def __narwhals_series__(self) -> Any:
        return self

    def __getattr__(self, attr: str) -> Any:
        # Python only calls this hook when regular attribute lookup fails.
        if attr != "dtype":
            msg = (
                f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
                "If you would like to see this kind of object better supported in "
                "Narwhals, please open a feature request "
                "at https://github.com/narwhals-dev/narwhals/issues."
            )
            raise NotImplementedError(msg)
        native_dtype = self._native_series.type()
        return map_ibis_dtype_to_narwhals_dtype(native_dtype)
7 changes: 4 additions & 3 deletions narwhals/_interchange/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,22 +70,23 @@ def map_interchange_dtype_to_narwhals_dtype(
class InterchangeFrame:
def __init__(self, df: Any) -> None:
self._native_frame = df
self._interchange_frame = df.__dataframe__()

def __narwhals_dataframe__(self) -> Any:
return self

def __getitem__(self, item: str) -> InterchangeSeries:
from narwhals._interchange.series import InterchangeSeries

return InterchangeSeries(self._native_frame.get_column_by_name(item))
return InterchangeSeries(self._interchange_frame.get_column_by_name(item))

@property
def schema(self) -> dict[str, dtypes.DType]:
return {
column_name: map_interchange_dtype_to_narwhals_dtype(
self._native_frame.get_column_by_name(column_name).dtype
self._interchange_frame.get_column_by_name(column_name).dtype
)
for column_name in self._native_frame.column_names()
for column_name in self._interchange_frame.column_names()
}

def __getattr__(self, attr: str) -> NoReturn:
Expand Down
13 changes: 13 additions & 0 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing_extensions import TypeGuard
import cudf
import dask.dataframe as dd
import ibis
import modin.pandas as mpd
import pandas as pd
import polars as pl
Expand Down Expand Up @@ -69,6 +70,11 @@ def get_dask_expr() -> Any:
return sys.modules.get("dask_expr", None)


def get_ibis() -> Any:
    """Return the ibis module if it has already been imported, else None."""
    try:
        return sys.modules["ibis"]
    except KeyError:
        return None


def is_pandas_dataframe(df: Any) -> TypeGuard[pd.DataFrame]:
"""Check whether `df` is a pandas DataFrame without importing pandas."""
return bool((pd := get_pandas()) is not None and isinstance(df, pd.DataFrame))
Expand Down Expand Up @@ -104,6 +110,11 @@ def is_dask_dataframe(df: Any) -> TypeGuard[dd.DataFrame]:
return bool((dd := get_dask_dataframe()) is not None and isinstance(df, dd.DataFrame))


def is_ibis_table(df: Any) -> TypeGuard[ibis.Table]:
    """Check whether `df` is an Ibis Table without importing Ibis."""
    ibis = get_ibis()
    if ibis is None:
        # Ibis was never imported, so `df` cannot be one of its tables.
        return False
    return isinstance(df, ibis.Table)


def is_polars_dataframe(df: Any) -> TypeGuard[pl.DataFrame]:
"""Check whether `df` is a Polars DataFrame without importing Polars."""
return bool((pl := get_polars()) is not None and isinstance(df, pl.DataFrame))
Expand Down Expand Up @@ -159,6 +170,8 @@ def is_pandas_like_series(arr: Any) -> bool:
"get_cudf",
"get_pyarrow",
"get_numpy",
"get_ibis",
"is_ibis_table",
"is_pandas_dataframe",
"is_pandas_series",
"is_polars_dataframe",
Expand Down
17 changes: 16 additions & 1 deletion narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from narwhals.dependencies import is_cudf_dataframe
from narwhals.dependencies import is_cudf_series
from narwhals.dependencies import is_dask_dataframe
from narwhals.dependencies import is_ibis_table
from narwhals.dependencies import is_modin_dataframe
from narwhals.dependencies import is_modin_series
from narwhals.dependencies import is_pandas_dataframe
Expand Down Expand Up @@ -331,6 +332,7 @@ def from_native( # noqa: PLR0915
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.series import ArrowSeries
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._ibis.dataframe import IbisInterchangeFrame
from narwhals._interchange.dataframe import InterchangeFrame
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.series import PandasLikeSeries
Expand Down Expand Up @@ -546,6 +548,19 @@ def from_native( # noqa: PLR0915
level="full",
)

# Ibis
elif is_ibis_table(native_object): # pragma: no cover
if eager_only or series_only:
msg = (
"Cannot only use `series_only=True` or `eager_only=False` "
"with Ibis table"
)
raise TypeError(msg)
return DataFrame(
IbisInterchangeFrame(native_object),
level="interchange",
)

# Interchange protocol
elif hasattr(native_object, "__dataframe__"):
if eager_only or series_only:
Expand All @@ -555,7 +570,7 @@ def from_native( # noqa: PLR0915
)
raise TypeError(msg)
return DataFrame(
InterchangeFrame(native_object.__dataframe__()),
InterchangeFrame(native_object),
level="interchange",
)

Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,11 @@ env = [
plugins = ["covdefaults"]

[tool.coverage.report]
omit = ['narwhals/typing.py']
omit = [
'narwhals/typing.py',
# we can't run this in every environment that we measure coverage on due to upper-bound constraints
'narwhals/_ibis/*',
]
exclude_also = [
"> POLARS_VERSION",
"if sys.version_info() <",
Expand Down
63 changes: 63 additions & 0 deletions tests/frame/interchange_schema_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from datetime import date
from datetime import datetime

import polars as pl
import pytest
Expand Down Expand Up @@ -63,6 +64,68 @@ def test_interchange_schema() -> None:
assert df["a"].dtype == nw.Int64


def test_interchange_schema_ibis() -> None:  # pragma: no cover
    """Check the Narwhals schema reported for an Ibis memtable."""
    ibis = pytest.importorskip("ibis")
    # One row per column: (name, values, Polars dtype, expected Narwhals dtype).
    columns = [
        ("a", [1, 1, 2], pl.Int64, nw.Int64),
        ("b", [4, 5, 6], pl.Int32, nw.Int32),
        ("c", [4, 5, 6], pl.Int16, nw.Int16),
        ("d", [4, 5, 6], pl.Int8, nw.Int8),
        ("e", [4, 5, 6], pl.UInt64, nw.UInt64),
        ("f", [4, 5, 6], pl.UInt32, nw.UInt32),
        ("g", [4, 5, 6], pl.UInt16, nw.UInt16),
        ("h", [4, 5, 6], pl.UInt8, nw.UInt8),
        ("i", [4, 5, 6], pl.Float64, nw.Float64),
        ("j", [4, 5, 6], pl.Float32, nw.Float32),
        ("k", ["fdafsd", "fdas", "ad"], pl.String, nw.String),
        # The expected mapping pairs the Polars Categorical column with String.
        ("l", ["fdafsd", "fdas", "ad"], pl.Categorical, nw.String),
        (
            "m",
            [date(2021, 1, 1), date(2021, 1, 1), date(2021, 1, 1)],
            pl.Date,
            nw.Date,
        ),
        (
            "n",
            [datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 1)],
            pl.Datetime,
            nw.Datetime,
        ),
        ("o", [True, True, False], pl.Boolean, nw.Boolean),
    ]
    df_pl = pl.DataFrame(
        {name: values for name, values, _, _ in columns},
        schema={name: pl_dtype for name, _, pl_dtype, _ in columns},
    )
    tbl = ibis.memtable(df_pl)
    df = nw.from_native(tbl, eager_or_interchange_only=True)
    expected = {name: nw_dtype for name, _, _, nw_dtype in columns}
    assert df.schema == expected
    assert df["a"].dtype == nw.Int64


def test_invalid() -> None:
df = pl.DataFrame({"a": [1, 2, 3]}).__dataframe__()
with pytest.raises(
Expand Down
6 changes: 6 additions & 0 deletions tests/no_imports_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def test_polars(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delitem(sys.modules, "numpy")
monkeypatch.delitem(sys.modules, "pyarrow")
monkeypatch.delitem(sys.modules, "dask", raising=False)
monkeypatch.delitem(sys.modules, "ibis", raising=False)
df = pl.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})
nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean()).filter(
nw.col("a") > 1
Expand All @@ -22,12 +23,14 @@ def test_polars(monkeypatch: pytest.MonkeyPatch) -> None:
assert "numpy" not in sys.modules
assert "pyarrow" not in sys.modules
assert "dask" not in sys.modules
assert "ibis" not in sys.modules


def test_pandas(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delitem(sys.modules, "polars")
monkeypatch.delitem(sys.modules, "pyarrow")
monkeypatch.delitem(sys.modules, "dask", raising=False)
monkeypatch.delitem(sys.modules, "ibis", raising=False)
df = pd.DataFrame({"a": [1, 1, 2], "b": [4, 5, 6]})
nw.from_native(df, eager_only=True).group_by("a").agg(nw.col("b").mean()).filter(
nw.col("a") > 1
Expand All @@ -37,6 +40,7 @@ def test_pandas(monkeypatch: pytest.MonkeyPatch) -> None:
assert "numpy" in sys.modules
assert "pyarrow" not in sys.modules
assert "dask" not in sys.modules
assert "ibis" not in sys.modules


def test_dask(monkeypatch: pytest.MonkeyPatch) -> None:
Expand All @@ -59,10 +63,12 @@ def test_pyarrow(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delitem(sys.modules, "polars")
monkeypatch.delitem(sys.modules, "pandas")
monkeypatch.delitem(sys.modules, "dask", raising=False)
monkeypatch.delitem(sys.modules, "ibis", raising=False)
df = pa.table({"a": [1, 2, 3], "b": [4, 5, 6]})
nw.from_native(df).group_by("a").agg(nw.col("b").mean()).filter(nw.col("a") > 1)
assert "polars" not in sys.modules
assert "pandas" not in sys.modules
assert "numpy" in sys.modules
assert "pyarrow" in sys.modules
assert "dask" not in sys.modules
assert "ibis" not in sys.modules

0 comments on commit 3d246b7

Please sign in to comment.