From deee14ce4182e76041fd4daa34de8cc679a85249 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Thu, 9 Jan 2025 14:56:24 +0000 Subject: [PATCH] feat: implement cross-join for duckdb (#1773) --- .github/workflows/check_tpch_queries.yml | 2 +- .github/workflows/extremes.yml | 16 +++++---- narwhals/_duckdb/dataframe.py | 42 +++++++++++++++--------- tests/frame/join_test.py | 5 +-- tests/utils.py | 1 + tpch/execute.py | 2 +- 6 files changed, 41 insertions(+), 27 deletions(-) diff --git a/.github/workflows/check_tpch_queries.yml b/.github/workflows/check_tpch_queries.yml index 723fa6e80..ce7da6f8e 100644 --- a/.github/workflows/check_tpch_queries.yml +++ b/.github/workflows/check_tpch_queries.yml @@ -25,7 +25,7 @@ jobs: cache-suffix: ${{ matrix.python-version }} cache-dependency-glob: "pyproject.toml" - name: local-install - run: uv pip install -e ".[dev, core, dask]" --system + run: uv pip install -U --pre -e ".[dev, core, dask]" --system - name: generate-data run: cd tpch && python generate_data.py - name: tpch-tests diff --git a/.github/workflows/extremes.yml b/.github/workflows/extremes.yml index 47ebc85ea..0e7e6a205 100644 --- a/.github/workflows/extremes.yml +++ b/.github/workflows/extremes.yml @@ -158,8 +158,6 @@ jobs: run: | uv pip uninstall pyarrow --system uv pip install --extra-index-url https://pypi.fury.io/arrow-nightlies/ --pre pyarrow --system - - name: show-deps - run: uv pip freeze - name: install numpy nightly run: | uv pip uninstall numpy --system @@ -167,18 +165,22 @@ jobs: - name: install dask run: | uv pip uninstall dask dask-expr --system - python -m pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask git+https://github.com/dask/dask-expr + python -m pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask + - name: install duckdb + run: | + python -m pip install -U --pre duckdb - name: show-deps run: uv pip freeze - name: Assert nightlies dependencies run: | DEPS=$(uv pip freeze) - echo "$DEPS" | grep 'polars' + echo "$DEPS" | grep 'polars.*@' echo "$DEPS" | grep 'pandas.*dev' echo "$DEPS" | grep 'pyarrow.*dev' - echo "$DEPS" | grep 'numpy' - echo "$DEPS" | grep 'dask' + echo "$DEPS" | grep 'numpy.*dev' + echo "$DEPS" | grep 'dask.*@' + echo "$DEPS" | grep 'duckdb.*dev' - name: Run pytest run: | pytest tests --cov=narwhals --cov=tests --cov-fail-under=50 --runslow \ - --constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,polars[eager],polars[lazy],dask + --constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,polars[eager],polars[lazy],dask,duckdb diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 2ff0e085a..e1c0f994c 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -224,30 +224,40 @@ def join( left_on = [left_on] if isinstance(right_on, str): right_on = [right_on] + original_alias = self._native_frame.alias - if how not in ("inner", "left", "semi"): + if how not in ("inner", "left", "semi", "cross"): msg = "Only inner and left join is implemented for DuckDB" raise NotImplementedError(msg) - # help mypy - assert left_on is not None # noqa: S101 - assert right_on is not None # noqa: S101 - - conditions = [ - f"lhs.{left} = rhs.{right}" for left, right in zip(left_on, right_on) - ] - original_alias = self._native_frame.alias - condition = " and ".join(conditions) - rel = self._native_frame.set_alias("lhs").join( - other._native_frame.set_alias("rhs"), condition=condition, how=how - ) + if how == "cross": + if self._backend_version < (1, 1, 4): + msg = f"DuckDB>=1.1.4 is required for cross-join, found version: {self._backend_version}" + raise NotImplementedError(msg) + rel = self._native_frame.set_alias("lhs").cross( # pragma: no cover + other._native_frame.set_alias("rhs") + ) + else: + # help mypy + assert left_on is not None # noqa: S101 + assert right_on is not None # noqa: S101 + + conditions = [ + f"lhs.{left} = rhs.{right}" for left, right in zip(left_on, right_on) + ] + condition = " and ".join(conditions) + rel = self._native_frame.set_alias("lhs").join( + other._native_frame.set_alias("rhs"), condition=condition, how=how + ) - if how in ("inner", "left"): + if how in ("inner", "left", "cross"): select = [f"lhs.{x}" for x in self._native_frame.columns] for col in other._native_frame.columns: - if col in self._native_frame.columns and col not in right_on: + if col in self._native_frame.columns and ( + right_on is None or col not in right_on + ): select.append(f"rhs.{col} as {col}{suffix}") - elif col not in right_on: + elif right_on is None or col not in right_on: select.append(col) else: # semi select = [f"lhs.{x}" for x in self._native_frame.columns] diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index 242696394..5bf5c91f0 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -10,6 +10,7 @@ import narwhals.stable.v1 as nw from narwhals.utils import Implementation +from tests.utils import DUCKDB_VERSION from tests.utils import PANDAS_VERSION from tests.utils import Constructor from tests.utils import assert_equal_data @@ -75,7 +76,7 @@ def test_inner_join_single_key(constructor: Constructor) -> None: def test_cross_join(constructor: Constructor, request: pytest.FixtureRequest) -> None: - if "duckdb" in str(constructor): + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 1, 4): request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2]} df = nw.from_native(constructor(data)) @@ -117,7 +118,7 @@ def test_suffix(constructor: Constructor, how: str, suffix: str) -> None: def test_cross_join_suffix( constructor: Constructor, suffix: str, request: pytest.FixtureRequest ) -> None: - if "duckdb" in str(constructor): + if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 1, 4): request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2]} df = nw.from_native(constructor(data)) diff --git a/tests/utils.py b/tests/utils.py index 005b4eee2..2d41d6782 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -33,6 +33,7 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]: IBIS_VERSION: tuple[int, ...] = get_module_version_as_tuple("ibis") NUMPY_VERSION: tuple[int, ...] = get_module_version_as_tuple("numpy") PANDAS_VERSION: tuple[int, ...] = get_module_version_as_tuple("pandas") +DUCKDB_VERSION: tuple[int, ...] = get_module_version_as_tuple("duckdb") POLARS_VERSION: tuple[int, ...] = get_module_version_as_tuple("polars") DASK_VERSION: tuple[int, ...] = get_module_version_as_tuple("dask") PYARROW_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyarrow") diff --git a/tpch/execute.py b/tpch/execute.py index ea4cc3a8a..5209ad48e 100644 --- a/tpch/execute.py +++ b/tpch/execute.py @@ -40,7 +40,7 @@ "dask": lambda x: x.compute(), } -DUCKDB_SKIPS = ["q11", "q14", "q15", "q22"] +DUCKDB_SKIPS = ["q14", "q15"] QUERY_DATA_PATH_MAP = { "q1": (LINEITEM_PATH,),