Skip to content

Commit

Permalink
feat: implement cross-join for duckdb (#1773)
Browse files: browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jan 9, 2025
1 parent 8f0cf50 commit deee14c
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 27 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_tpch_queries.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ jobs:
cache-suffix: ${{ matrix.python-version }}
cache-dependency-glob: "pyproject.toml"
- name: local-install
run: uv pip install -e ".[dev, core, dask]" --system
run: uv pip install -U --pre -e ".[dev, core, dask]" --system
- name: generate-data
run: cd tpch && python generate_data.py
- name: tpch-tests
Expand Down
16 changes: 9 additions & 7 deletions .github/workflows/extremes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -158,27 +158,29 @@ jobs:
run: |
uv pip uninstall pyarrow --system
uv pip install --extra-index-url https://pypi.fury.io/arrow-nightlies/ --pre pyarrow --system
- name: show-deps
run: uv pip freeze
- name: install numpy nightly
run: |
uv pip uninstall numpy --system
uv pip install --pre --extra-index-url https://pypi.anaconda.org/scientific-python-nightly-wheels/simple numpy --system
- name: install dask
run: |
uv pip uninstall dask dask-expr --system
python -m pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask git+https://github.com/dask/dask-expr
python -m pip install git+https://github.com/dask/distributed git+https://github.com/dask/dask
- name: install duckdb
run: |
python -m pip install -U --pre duckdb
- name: show-deps
run: uv pip freeze
- name: Assert nightlies dependencies
run: |
DEPS=$(uv pip freeze)
echo "$DEPS" | grep 'polars'
echo "$DEPS" | grep 'polars.*@'
echo "$DEPS" | grep 'pandas.*dev'
echo "$DEPS" | grep 'pyarrow.*dev'
echo "$DEPS" | grep 'numpy'
echo "$DEPS" | grep 'dask'
echo "$DEPS" | grep 'numpy.*dev'
echo "$DEPS" | grep 'dask.*@'
echo "$DEPS" | grep 'duckdb.*dev'
- name: Run pytest
run: |
pytest tests --cov=narwhals --cov=tests --cov-fail-under=50 --runslow \
--constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,polars[eager],polars[lazy],dask
--constructors=pandas,pandas[nullable],pandas[pyarrow],pyarrow,polars[eager],polars[lazy],dask,duckdb
42 changes: 26 additions & 16 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,30 +224,40 @@ def join(
left_on = [left_on]
if isinstance(right_on, str):
right_on = [right_on]
original_alias = self._native_frame.alias

if how not in ("inner", "left", "semi"):
if how not in ("inner", "left", "semi", "cross"):
msg = "Only inner and left join is implemented for DuckDB"
raise NotImplementedError(msg)

# help mypy
assert left_on is not None # noqa: S101
assert right_on is not None # noqa: S101

conditions = [
f"lhs.{left} = rhs.{right}" for left, right in zip(left_on, right_on)
]
original_alias = self._native_frame.alias
condition = " and ".join(conditions)
rel = self._native_frame.set_alias("lhs").join(
other._native_frame.set_alias("rhs"), condition=condition, how=how
)
if how == "cross":
if self._backend_version < (1, 1, 4):
msg = f"DuckDB>=1.1.4 is required for cross-join, found version: {self._backend_version}"
raise NotImplementedError(msg)
rel = self._native_frame.set_alias("lhs").cross( # pragma: no cover
other._native_frame.set_alias("rhs")
)
else:
# help mypy
assert left_on is not None # noqa: S101
assert right_on is not None # noqa: S101

conditions = [
f"lhs.{left} = rhs.{right}" for left, right in zip(left_on, right_on)
]
condition = " and ".join(conditions)
rel = self._native_frame.set_alias("lhs").join(
other._native_frame.set_alias("rhs"), condition=condition, how=how
)

if how in ("inner", "left"):
if how in ("inner", "left", "cross"):
select = [f"lhs.{x}" for x in self._native_frame.columns]
for col in other._native_frame.columns:
if col in self._native_frame.columns and col not in right_on:
if col in self._native_frame.columns and (
right_on is None or col not in right_on
):
select.append(f"rhs.{col} as {col}{suffix}")
elif col not in right_on:
elif right_on is None or col not in right_on:
select.append(col)
else: # semi
select = [f"lhs.{x}" for x in self._native_frame.columns]
Expand Down
5 changes: 3 additions & 2 deletions tests/frame/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import narwhals.stable.v1 as nw
from narwhals.utils import Implementation
from tests.utils import DUCKDB_VERSION
from tests.utils import PANDAS_VERSION
from tests.utils import Constructor
from tests.utils import assert_equal_data
Expand Down Expand Up @@ -75,7 +76,7 @@ def test_inner_join_single_key(constructor: Constructor) -> None:


def test_cross_join(constructor: Constructor, request: pytest.FixtureRequest) -> None:
if "duckdb" in str(constructor):
if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 1, 4):
request.applymarker(pytest.mark.xfail)
data = {"antananarivo": [1, 3, 2]}
df = nw.from_native(constructor(data))
Expand Down Expand Up @@ -117,7 +118,7 @@ def test_suffix(constructor: Constructor, how: str, suffix: str) -> None:
def test_cross_join_suffix(
constructor: Constructor, suffix: str, request: pytest.FixtureRequest
) -> None:
if "duckdb" in str(constructor):
if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 1, 4):
request.applymarker(pytest.mark.xfail)
data = {"antananarivo": [1, 3, 2]}
df = nw.from_native(constructor(data))
Expand Down
1 change: 1 addition & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def get_module_version_as_tuple(module_name: str) -> tuple[int, ...]:
IBIS_VERSION: tuple[int, ...] = get_module_version_as_tuple("ibis")
NUMPY_VERSION: tuple[int, ...] = get_module_version_as_tuple("numpy")
PANDAS_VERSION: tuple[int, ...] = get_module_version_as_tuple("pandas")
DUCKDB_VERSION: tuple[int, ...] = get_module_version_as_tuple("duckdb")
POLARS_VERSION: tuple[int, ...] = get_module_version_as_tuple("polars")
DASK_VERSION: tuple[int, ...] = get_module_version_as_tuple("dask")
PYARROW_VERSION: tuple[int, ...] = get_module_version_as_tuple("pyarrow")
Expand Down
2 changes: 1 addition & 1 deletion tpch/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"dask": lambda x: x.compute(),
}

DUCKDB_SKIPS = ["q11", "q14", "q15", "q22"]
DUCKDB_SKIPS = ["q14", "q15"]

QUERY_DATA_PATH_MAP = {
"q1": (LINEITEM_PATH,),
Expand Down

0 comments on commit deee14c

Please sign in to comment.