diff --git a/narwhals/_duckdb/dataframe.py b/narwhals/_duckdb/dataframe.py index 76ff68ae0..2ff0e085a 100644 --- a/narwhals/_duckdb/dataframe.py +++ b/narwhals/_duckdb/dataframe.py @@ -225,7 +225,7 @@ def join( if isinstance(right_on, str): right_on = [right_on] - if how not in ("inner", "left"): + if how not in ("inner", "left", "semi"): msg = "Only inner and left join is implemented for DuckDB" raise NotImplementedError(msg) @@ -242,12 +242,15 @@ def join( other._native_frame.set_alias("rhs"), condition=condition, how=how ) - select = [f"lhs.{x}" for x in self._native_frame.columns] - for col in other._native_frame.columns: - if col in self._native_frame.columns and col not in right_on: - select.append(f"rhs.{col} as {col}{suffix}") - elif col not in right_on: - select.append(col) + if how in ("inner", "left"): + select = [f"lhs.{x}" for x in self._native_frame.columns] + for col in other._native_frame.columns: + if col in self._native_frame.columns and col not in right_on: + select.append(f"rhs.{col} as {col}{suffix}") + elif col not in right_on: + select.append(col) + else: # semi + select = [f"lhs.{x}" for x in self._native_frame.columns] res = rel.select(", ".join(select)).set_alias(original_alias) return self._from_native_frame(res) diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index 7332cb254..242696394 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -206,10 +206,7 @@ def test_semi_join( join_key: list[str], filter_expr: nw.Expr, expected: dict[str, list[Any]], - request: pytest.FixtureRequest, ) -> None: - if "duckdb" in str(constructor): - request.applymarker(pytest.mark.xfail) data = {"antananarivo": [1, 3, 2], "bob": [4, 4, 6], "zorro": [7.0, 8, 9]} df = nw.from_native(constructor(data)) other = df.filter(filter_expr) diff --git a/tpch/execute.py b/tpch/execute.py index f2f3041df..ea4cc3a8a 100644 --- a/tpch/execute.py +++ b/tpch/execute.py @@ -40,7 +40,7 @@ "dask": lambda x: x.compute(), } -DUCKDB_XFAILS = ["q11", "q14", "q15", "q18", "q22"] +DUCKDB_SKIPS = ["q11", "q14", "q15", "q22"] QUERY_DATA_PATH_MAP = { "q1": (LINEITEM_PATH,), @@ -95,7 +95,7 @@ def execute_query(query_id: str) -> None: data_paths = QUERY_DATA_PATH_MAP[query_id] for backend, (native_namespace, kwargs) in BACKEND_NAMESPACE_KWARGS_MAP.items(): - if backend == "duckdb" and query_id in DUCKDB_XFAILS: + if backend == "duckdb" and query_id in DUCKDB_SKIPS: print(f"\nSkipping {query_id} for DuckDB") # noqa: T201 continue