Skip to content

Commit

Permalink
feat: semi-join for duckdb (#1767)
Browse files: browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Jan 9, 2025
1 parent 36dacf9 commit 145e4de
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 12 deletions.
17 changes: 10 additions & 7 deletions narwhals/_duckdb/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def join(
if isinstance(right_on, str):
right_on = [right_on]

if how not in ("inner", "left"):
if how not in ("inner", "left", "semi"):
msg = "Only inner and left join is implemented for DuckDB"
raise NotImplementedError(msg)

Expand All @@ -242,12 +242,15 @@ def join(
other._native_frame.set_alias("rhs"), condition=condition, how=how
)

select = [f"lhs.{x}" for x in self._native_frame.columns]
for col in other._native_frame.columns:
if col in self._native_frame.columns and col not in right_on:
select.append(f"rhs.{col} as {col}{suffix}")
elif col not in right_on:
select.append(col)
if how in ("inner", "left"):
select = [f"lhs.{x}" for x in self._native_frame.columns]
for col in other._native_frame.columns:
if col in self._native_frame.columns and col not in right_on:
select.append(f"rhs.{col} as {col}{suffix}")
elif col not in right_on:
select.append(col)
else: # semi
select = [f"lhs.{x}" for x in self._native_frame.columns]

res = rel.select(", ".join(select)).set_alias(original_alias)
return self._from_native_frame(res)
Expand Down
3 changes: 0 additions & 3 deletions tests/frame/join_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,7 @@ def test_semi_join(
join_key: list[str],
filter_expr: nw.Expr,
expected: dict[str, list[Any]],
request: pytest.FixtureRequest,
) -> None:
if "duckdb" in str(constructor):
request.applymarker(pytest.mark.xfail)
data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]}
df = nw.from_native(constructor(data))
other = df.filter(filter_expr)
Expand Down
4 changes: 2 additions & 2 deletions tpch/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
"dask": lambda x: x.compute(),
}

DUCKDB_XFAILS = ["q11", "q14", "q15", "q18", "q22"]
DUCKDB_SKIPS = ["q11", "q14", "q15", "q22"]

QUERY_DATA_PATH_MAP = {
"q1": (LINEITEM_PATH,),
Expand Down Expand Up @@ -95,7 +95,7 @@ def execute_query(query_id: str) -> None:
data_paths = QUERY_DATA_PATH_MAP[query_id]

for backend, (native_namespace, kwargs) in BACKEND_NAMESPACE_KWARGS_MAP.items():
if backend == "duckdb" and query_id in DUCKDB_XFAILS:
if backend == "duckdb" and query_id in DUCKDB_SKIPS:
print(f"\nSkipping {query_id} for DuckDB") # noqa: T201
continue

Expand Down

0 comments on commit 145e4de

Please sign in to comment.