Skip to content

Commit

Permalink
feat: Duckdb interchange (#902)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Sep 3, 2024
1 parent 3d246b7 commit 780e2bc
Show file tree
Hide file tree
Showing 7 changed files with 199 additions and 0 deletions.
Empty file added narwhals/_duckdb/__init__.py
Empty file.
80 changes: 80 additions & 0 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any

from narwhals import dtypes

if TYPE_CHECKING:
from narwhals._duckdb.series import DuckDBInterchangeSeries


def map_duckdb_dtype_to_narwhals_dtype(
duckdb_dtype: Any,
) -> dtypes.DType:
if duckdb_dtype == "BIGINT":
return dtypes.Int64()
if duckdb_dtype == "INTEGER":
return dtypes.Int32()
if duckdb_dtype == "SMALLINT":
return dtypes.Int16()
if duckdb_dtype == "TINYINT":
return dtypes.Int8()
if duckdb_dtype == "UBIGINT":
return dtypes.UInt64()
if duckdb_dtype == "UINTEGER":
return dtypes.UInt32()
if duckdb_dtype == "USMALLINT":
return dtypes.UInt16()
if duckdb_dtype == "UTINYINT":
return dtypes.UInt8()
if duckdb_dtype == "DOUBLE":
return dtypes.Float64()
if duckdb_dtype == "FLOAT":
return dtypes.Float32()
if duckdb_dtype == "VARCHAR":
return dtypes.String()
if duckdb_dtype == "DATE":
return dtypes.Date()
if duckdb_dtype == "TIMESTAMP":
return dtypes.Datetime()
if duckdb_dtype == "BOOLEAN":
return dtypes.Boolean()
if duckdb_dtype == "INTERVAL":
return dtypes.Duration()
msg = ( # pragma: no cover
f"Invalid dtype, got: {duckdb_dtype}.\n\n"
"If you believe this dtype should be supported in Narwhals, "
"please report an issue at https://github.com/narwhals-dev/narwhals."
)
raise AssertionError(msg)


class DuckDBInterchangeFrame:
def __init__(self, df: Any) -> None:
self._native_frame = df

def __narwhals_dataframe__(self) -> Any:
return self

def __getitem__(self, item: str) -> DuckDBInterchangeSeries:
from narwhals._duckdb.series import DuckDBInterchangeSeries

return DuckDBInterchangeSeries(self._native_frame.select(item))

def __getattr__(self, attr: str) -> Any:
if attr == "schema":
return {
column_name: map_duckdb_dtype_to_narwhals_dtype(duckdb_dtype)
for column_name, duckdb_dtype in zip(
self._native_frame.columns, self._native_frame.types
)
}

msg = ( # pragma: no cover
f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
"If you would like to see this kind of object better supported in "
"Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg) # pragma: no cover
24 changes: 24 additions & 0 deletions narwhals/_duckdb/series.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Any

from narwhals._duckdb.dataframe import map_duckdb_dtype_to_narwhals_dtype


class DuckDBInterchangeSeries:
def __init__(self, df: Any) -> None:
self._native_series = df

def __narwhals_series__(self) -> Any:
return self

def __getattr__(self, attr: str) -> Any:
if attr == "dtype":
return map_duckdb_dtype_to_narwhals_dtype(self._native_series.types[0])
msg = ( # pragma: no cover
f"Attribute {attr} is not supported for metadata-only dataframes.\n\n"
"If you would like to see this kind of object better supported in "
"Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg) # pragma: no cover
13 changes: 13 additions & 0 deletions narwhals/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing_extensions import TypeGuard
import cudf
import dask.dataframe as dd
import duckdb
import ibis
import modin.pandas as mpd
import pandas as pd
Expand Down Expand Up @@ -65,6 +66,11 @@ def get_dask_dataframe() -> Any:
return sys.modules.get("dask.dataframe", None)


def get_duckdb() -> Any:
"""Get duckdb module (if already imported - else return None)."""
return sys.modules.get("duckdb", None)


def get_dask_expr() -> Any:
"""Get dask_expr module (if already imported - else return None)."""
return sys.modules.get("dask_expr", None)
Expand Down Expand Up @@ -110,6 +116,13 @@ def is_dask_dataframe(df: Any) -> TypeGuard[dd.DataFrame]:
return bool((dd := get_dask_dataframe()) is not None and isinstance(df, dd.DataFrame))


def is_duckdb_relation(df: Any) -> TypeGuard[duckdb.DuckDBPyRelation]:
"""Check whether `df` is a DuckDB Relation without importing DuckDB."""
return bool(
(duckdb := get_duckdb()) is not None and isinstance(df, duckdb.DuckDBPyRelation)
)


def is_ibis_table(df: Any) -> TypeGuard[ibis.Table]:
"""Check whether `df` is a Ibis Table without importing Ibis."""
return bool((ibis := get_ibis()) is not None and isinstance(df, ibis.Table))
Expand Down
15 changes: 15 additions & 0 deletions narwhals/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from narwhals.dependencies import is_cudf_dataframe
from narwhals.dependencies import is_cudf_series
from narwhals.dependencies import is_dask_dataframe
from narwhals.dependencies import is_duckdb_relation
from narwhals.dependencies import is_ibis_table
from narwhals.dependencies import is_modin_dataframe
from narwhals.dependencies import is_modin_series
Expand Down Expand Up @@ -332,6 +333,7 @@ def from_native( # noqa: PLR0915
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.series import ArrowSeries
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._duckdb.dataframe import DuckDBInterchangeFrame
from narwhals._ibis.dataframe import IbisInterchangeFrame
from narwhals._interchange.dataframe import InterchangeFrame
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
Expand Down Expand Up @@ -548,6 +550,19 @@ def from_native( # noqa: PLR0915
level="full",
)

# DuckDB
elif is_duckdb_relation(native_object):
if eager_only or series_only: # pragma: no cover
msg = (
"Cannot only use `series_only=True` or `eager_only=False` "
"with DuckDB Relation"
)
raise TypeError(msg)
return DataFrame(
DuckDBInterchangeFrame(native_object),
level="interchange",
)

# Ibis
elif is_ibis_table(native_object): # pragma: no cover
if eager_only or series_only:
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
covdefaults
duckdb
pandas
polars
pre-commit
Expand Down
66 changes: 66 additions & 0 deletions tests/frame/interchange_schema_test.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from datetime import date
from datetime import datetime
from datetime import timedelta

import duckdb
import polars as pl
import pytest

Expand Down Expand Up @@ -126,6 +128,70 @@ def test_interchange_schema_ibis() -> None: # pragma: no cover
assert df["a"].dtype == nw.Int64


def test_interchange_schema_duckdb() -> None:
df_pl = pl.DataFrame( # noqa: F841
{
"a": [1, 1, 2],
"b": [4, 5, 6],
"c": [4, 5, 6],
"d": [4, 5, 6],
"e": [4, 5, 6],
"f": [4, 5, 6],
"g": [4, 5, 6],
"h": [4, 5, 6],
"i": [4, 5, 6],
"j": [4, 5, 6],
"k": ["fdafsd", "fdas", "ad"],
"l": ["fdafsd", "fdas", "ad"],
"m": [date(2021, 1, 1), date(2021, 1, 1), date(2021, 1, 1)],
"n": [datetime(2021, 1, 1), datetime(2021, 1, 1), datetime(2021, 1, 1)],
"o": [timedelta(1)] * 3,
"p": [True, True, False],
},
schema={
"a": pl.Int64,
"b": pl.Int32,
"c": pl.Int16,
"d": pl.Int8,
"e": pl.UInt64,
"f": pl.UInt32,
"g": pl.UInt16,
"h": pl.UInt8,
"i": pl.Float64,
"j": pl.Float32,
"k": pl.String,
"l": pl.Categorical,
"m": pl.Date,
"n": pl.Datetime,
"o": pl.Duration,
"p": pl.Boolean,
},
)
rel = duckdb.sql("select * from df_pl")
df = nw.from_native(rel, eager_or_interchange_only=True)
result = df.schema
expected = {
"a": nw.Int64,
"b": nw.Int32,
"c": nw.Int16,
"d": nw.Int8,
"e": nw.UInt64,
"f": nw.UInt32,
"g": nw.UInt16,
"h": nw.UInt8,
"i": nw.Float64,
"j": nw.Float32,
"k": nw.String,
"l": nw.String,
"m": nw.Date,
"n": nw.Datetime,
"o": nw.Duration,
"p": nw.Boolean,
}
assert result == expected
assert df["a"].dtype == nw.Int64


def test_invalid() -> None:
df = pl.DataFrame({"a": [1, 2, 3]}).__dataframe__()
with pytest.raises(
Expand Down

0 comments on commit 780e2bc

Please sign in to comment.