Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add minimal PySpark support #908

Merged
merged 98 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from 55 commits
Commits
Show all changes
98 commits
Select commit Hold shift + click to select a range
72c1b49
first pyspark draft
EdAbati Sep 3, 2024
e67140a
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Sep 3, 2024
3316460
added schema
EdAbati Sep 4, 2024
12f62c1
add methods needed for compliant types
EdAbati Sep 4, 2024
2b114eb
fix all_horizontal
EdAbati Sep 7, 2024
378b421
add xfail to some tests
EdAbati Sep 8, 2024
b5957dc
draft with sql
EdAbati Sep 8, 2024
9f8f944
merge upstream
EdAbati Sep 10, 2024
b2aee0e
making all frame tests pass
EdAbati Sep 11, 2024
0e4b2f2
group by
EdAbati Sep 12, 2024
741cdde
skipping tests
EdAbati Sep 12, 2024
2bdfe31
restore type
EdAbati Sep 12, 2024
c0b1a18
smaller diff + mypy fix
EdAbati Sep 12, 2024
ec0b26f
remove print
EdAbati Sep 12, 2024
32b87a3
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Sep 12, 2024
a053b07
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Sep 12, 2024
a415bd0
smaller diff
EdAbati Sep 12, 2024
6065eb2
reenable pyspark
EdAbati Sep 12, 2024
1688f7d
count without window
EdAbati Oct 6, 2024
191dcb7
revert expr series tests
EdAbati Oct 6, 2024
41368ef
revert rest of tests
EdAbati Oct 6, 2024
b0dffad
placeholder pyspark test
EdAbati Oct 6, 2024
37ecc70
merge main
EdAbati Oct 6, 2024
1c76b0b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Oct 6, 2024
9802fdc
moved test_column
EdAbati Oct 6, 2024
267f2ff
moved select filter and with_columns
EdAbati Oct 6, 2024
8adee30
add schema head sort tests
EdAbati Oct 6, 2024
9186687
add test add
EdAbati Oct 6, 2024
38d326d
fix rename
EdAbati Oct 8, 2024
223ea88
added more tests
EdAbati Oct 8, 2024
3337014
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Oct 13, 2024
d7b2752
fix all_horizontal
EdAbati Oct 13, 2024
95b8395
fixing all tests πŸŽ‰πŸŽ‰
EdAbati Oct 13, 2024
734c140
rename test
EdAbati Oct 13, 2024
1b9a7e7
add backend_version
EdAbati Oct 13, 2024
9d326a4
added group by tests
EdAbati Oct 13, 2024
1a2e804
add pyspark in requirement dev
EdAbati Oct 13, 2024
411f67d
use pyspark.sql to create empty df
EdAbati Oct 13, 2024
3a59240
stddev for older pyspark
EdAbati Oct 13, 2024
08120da
coverage up
EdAbati Oct 13, 2024
177ec5e
min pyspark version test
EdAbati Oct 14, 2024
77e6687
fix for pyspark 3.2
EdAbati Oct 14, 2024
9ccab80
pyspark 3.3 as minimum
EdAbati Oct 14, 2024
ef1944c
trying debugging windows
EdAbati Oct 14, 2024
a8b228f
no test pyspark with pandas <1.0.5
EdAbati Oct 14, 2024
c74772d
removing debug windows
EdAbati Oct 14, 2024
dd0dd39
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Oct 14, 2024
d00a2da
testing 3.3.0
EdAbati Oct 14, 2024
6b25971
trying with repartition 2
EdAbati Oct 14, 2024
3713a6d
remove unused data
EdAbati Oct 15, 2024
eb0a2ce
trying to fix sorting problems in tests
EdAbati Oct 15, 2024
df1a37f
no pyspark in minimum_versions
EdAbati Oct 15, 2024
ce503fa
trying to make windows happy
EdAbati Oct 15, 2024
94656b3
fix repartition
EdAbati Oct 15, 2024
33739de
exclude pyspark for python 3.12
EdAbati Oct 15, 2024
5808d71
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Oct 27, 2024
5d4b02f
use assert_equal_data
EdAbati Oct 27, 2024
92617f1
only use self._native_frame.sparkSession
EdAbati Oct 27, 2024
5733069
add drop_null_keys in groupby
EdAbati Oct 27, 2024
9b6c4e0
rename _spark
EdAbati Oct 27, 2024
e2344c7
rename spark_test
EdAbati Oct 27, 2024
bb1de48
use PYSPARK_VERSION
EdAbati Oct 28, 2024
36d0886
rename PySpark... classes to Spark...
EdAbati Oct 28, 2024
24676d0
_ in func signature
EdAbati Oct 28, 2024
3defa39
make coverage happy
EdAbati Oct 28, 2024
a8946f2
exception public
EdAbati Nov 17, 2024
86c459d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 17, 2024
dc7fb71
fix docs
EdAbati Nov 17, 2024
d720b58
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Nov 17, 2024
a1141f7
rename to _spark_like
EdAbati Nov 17, 2024
94b6777
rename exceptions
EdAbati Nov 17, 2024
10c1b11
update coverage to ignore `_spark_like`
EdAbati Nov 17, 2024
5193bca
better comment
EdAbati Nov 18, 2024
7b513f4
invalidintoexpr error
EdAbati Nov 18, 2024
f25969b
fix pytest warning error
EdAbati Nov 21, 2024
08987cf
small comment
EdAbati Nov 21, 2024
0fb0478
fix F.std for ddof more than 1
EdAbati Nov 21, 2024
50e2e4d
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Nov 21, 2024
e86fc5c
fix stddev imports for py <3.5
EdAbati Nov 21, 2024
522a1aa
use F
EdAbati Nov 21, 2024
84d5b6a
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Dec 3, 2024
b9c21df
update to latest changes
EdAbati Dec 3, 2024
bb82020
add implementation to expr
EdAbati Dec 3, 2024
010a362
rename SparkLike...
EdAbati Dec 3, 2024
dac5901
rename native_to_narwhals_dtype
EdAbati Dec 3, 2024
6d67b0c
dtype unknown for decimal
EdAbati Dec 3, 2024
15ca58e
simplify return unknown
EdAbati Dec 3, 2024
ce4e2fb
update no_imports_tests
EdAbati Dec 4, 2024
d841ec5
level lazy for spark
EdAbati Dec 4, 2024
ac68a7e
add _change_dtypes
EdAbati Dec 4, 2024
9a1f741
Merge remote-tracking branch 'upstream/main' into pyspark
EdAbati Dec 4, 2024
2121c40
_change_version is back
EdAbati Dec 4, 2024
c0f44b6
fix no imports tests
EdAbati Dec 4, 2024
4b7895f
rename spark_like tests
EdAbati Dec 4, 2024
638c402
same error message as dask
EdAbati Dec 4, 2024
b46f1b5
remove extra expr._call
EdAbati Dec 5, 2024
a3e3dba
update coverage
EdAbati Dec 5, 2024
d8e6064
extract _columns_from_expr
EdAbati Dec 5, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions .github/workflows/extremes.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ jobs:
- name: install-minimum-versions
run: uv pip install tox virtualenv setuptools pandas==0.25.3 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
- name: install-reqs
run: uv pip install -r requirements-dev.txt --system
run: |
uv pip install -r requirements-dev.txt --system
: # pyspark >= 3.3.0 is not compatible with pandas==0.25.3
uv pip uninstall pyspark --system
- name: show-deps
run: uv pip freeze
- name: Run pytest
Expand All @@ -52,7 +55,7 @@ jobs:
cache-suffix: ${{ matrix.python-version }}
cache-dependency-glob: "**requirements*.txt"
- name: install-minimum-versions
run: uv pip install tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
run: uv pip install tox virtualenv setuptools pandas==1.1.5 polars==0.20.3 numpy==1.17.5 pyarrow==11.0.0 pyspark==3.3.0 scipy==1.5.0 scikit-learn==1.1.0 tzdata --system
- name: install-reqs
run: uv pip install -r requirements-dev.txt --system
- name: show-deps
Expand Down Expand Up @@ -81,7 +84,7 @@ jobs:
cache-suffix: ${{ matrix.python-version }}
cache-dependency-glob: "**requirements*.txt"
- name: install-minimum-versions
run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==14.0.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.7 tzdata --system
run: uv pip install tox virtualenv setuptools pandas==2.0.3 polars==0.20.8 numpy==1.24.4 pyarrow==14.0.0 pyspark==3.4.0 scipy==1.8.0 scikit-learn==1.3.0 dask[dataframe]==2024.7 tzdata --system
- name: install-reqs
run: uv pip install -r requirements-dev.txt --system
- name: show-deps
Expand Down
32 changes: 27 additions & 5 deletions narwhals/_expression_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,21 @@
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.series import PolarsSeries
from narwhals._polars.typing import IntoPolarsExpr
from narwhals._pyspark.dataframe import PySparkLazyFrame
from narwhals._pyspark.expr import PySparkExpr
from narwhals._pyspark.namespace import PySparkNamespace
from narwhals._pyspark.typing import IntoPySparkExpr
EdAbati marked this conversation as resolved.
Show resolved Hide resolved

CompliantNamespace = Union[
PandasLikeNamespace, ArrowNamespace, DaskNamespace, PolarsNamespace
PandasLikeNamespace,
ArrowNamespace,
DaskNamespace,
PolarsNamespace,
PySparkNamespace,
]
CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr]
CompliantExpr = Union[PandasLikeExpr, ArrowExpr, DaskExpr, PolarsExpr, PySparkExpr]
IntoCompliantExpr = Union[
IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr
IntoPandasLikeExpr, IntoArrowExpr, IntoDaskExpr, IntoPolarsExpr, IntoPySparkExpr
]
IntoCompliantExprT = TypeVar("IntoCompliantExprT", bound=IntoCompliantExpr)
CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExpr)
Expand All @@ -48,9 +56,15 @@
list[PandasLikeSeries], list[ArrowSeries], list[DaskExpr], list[PolarsSeries]
]
ListOfCompliantExpr = Union[
list[PandasLikeExpr], list[ArrowExpr], list[DaskExpr], list[PolarsExpr]
list[PandasLikeExpr],
list[ArrowExpr],
list[DaskExpr],
list[PolarsExpr],
list[PySparkExpr],
]
CompliantDataFrame = Union[
PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame, PySparkLazyFrame
]
CompliantDataFrame = Union[PandasLikeDataFrame, ArrowDataFrame, DaskLazyFrame]

T = TypeVar("T")

Expand Down Expand Up @@ -150,6 +164,14 @@ def parse_into_exprs(
) -> list[PolarsExpr]: ...


@overload
def parse_into_exprs(
*exprs: IntoPySparkExpr,
namespace: PySparkNamespace,
**named_exprs: IntoPySparkExpr,
) -> list[PySparkExpr]: ...


def parse_into_exprs(
*exprs: IntoCompliantExpr,
namespace: CompliantNamespace,
Expand Down
Empty file added narwhals/_pyspark/__init__.py
Empty file.
183 changes: 183 additions & 0 deletions narwhals/_pyspark/dataframe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,183 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Any
from typing import Iterable
from typing import Sequence

from narwhals._pyspark.utils import parse_exprs_and_named_exprs
from narwhals._pyspark.utils import translate_sql_api_dtype
from narwhals.utils import Implementation
from narwhals.utils import flatten
from narwhals.utils import parse_columns_to_drop
from narwhals.utils import parse_version

if TYPE_CHECKING:
from pyspark.sql import DataFrame
from typing_extensions import Self

from narwhals._pyspark.expr import PySparkExpr
from narwhals._pyspark.group_by import PySparkLazyGroupBy
from narwhals._pyspark.namespace import PySparkNamespace
from narwhals._pyspark.typing import IntoPySparkExpr
from narwhals.dtypes import DType
from narwhals.typing import DTypes


class PySparkLazyFrame:
def __init__(
self,
native_dataframe: DataFrame,
*,
backend_version: tuple[int, ...],
dtypes: DTypes,
) -> None:
self._native_frame = native_dataframe
self._backend_version = backend_version
self._implementation = Implementation.PYSPARK
self._dtypes = dtypes

def __native_namespace__(self) -> Any: # pragma: no cover
if self._implementation is Implementation.PYSPARK:
return self._implementation.to_native_namespace()

msg = f"Expected pyspark, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)

def __narwhals_namespace__(self) -> PySparkNamespace:
from narwhals._pyspark.namespace import PySparkNamespace

return PySparkNamespace(
backend_version=self._backend_version, dtypes=self._dtypes
)

def __narwhals_lazyframe__(self) -> Self:
return self

def _from_native_frame(self, df: DataFrame) -> Self:
return self.__class__(
df, backend_version=self._backend_version, dtypes=self._dtypes
)

@property
def columns(self) -> list[str]:
return self._native_frame.columns # type: ignore[no-any-return]

def collect(self) -> Any:
import pandas as pd # ignore-banned-import()

from narwhals._pandas_like.dataframe import PandasLikeDataFrame

return PandasLikeDataFrame(
native_dataframe=self._native_frame.toPandas(),
implementation=Implementation.PANDAS,
backend_version=parse_version(pd.__version__),
dtypes=self._dtypes,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A similar discussion was happening when I opened #1042 with Marco's concern for how to collect duckdb.
My opinion is that we should let the use decide to which eager backend collect (maybe we one as default).

Now I am not using pyspark in a couple of years, but if pandas is not a dependency, then this collect may also fail.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pandas (as well as pyarrow and numpy) is still a dependency of PySpark https://github.com/apache/spark/blob/master/python/packaging/classic/setup.py#L350-L359

toPandas() will not fail, but I agree that it would be nice to give the user a choice.

To me, this sounds like a followup what do you think?


def select(
self: Self,
*exprs: IntoPySparkExpr,
**named_exprs: IntoPySparkExpr,
) -> Self:
if exprs and all(isinstance(x, str) for x in exprs) and not named_exprs:
# This is a simple select
return self._from_native_frame(self._native_frame.select(*exprs))

new_columns = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)

if not new_columns:
# return empty dataframe, like Polars does
from pyspark.sql.types import StructType

if self._backend_version >= (3, 3, 0):
spark_session = self._native_frame.sparkSession
else: # pragma: no cover
from pyspark.sql import SparkSession

spark_session = SparkSession.builder.getOrCreate()
EdAbati marked this conversation as resolved.
Show resolved Hide resolved

spark_df = spark_session.createDataFrame([], StructType([]))

return self._from_native_frame(spark_df)

new_columns_list = [col.alias(col_name) for col_name, col in new_columns.items()]
return self._from_native_frame(self._native_frame.select(*new_columns_list))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice πŸ‘Œ how do aggregations/reductions behave?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make an example of what you mean?

are you referring to something like df.select(nw.max("a")) ?


def filter(self, *predicates: PySparkExpr) -> Self:
from narwhals._pyspark.namespace import PySparkNamespace

if (
len(predicates) == 1
and isinstance(predicates[0], list)
and all(isinstance(x, bool) for x in predicates[0])
):
msg = "`LazyFrame.filter` is not supported for PySpark backend with boolean masks."
raise NotImplementedError(msg)
plx = PySparkNamespace(backend_version=self._backend_version, dtypes=self._dtypes)
expr = plx.all_horizontal(*predicates)
# Safety: all_horizontal's expression only returns a single column.
condition = expr._call(self)[0]
spark_df = self._native_frame.where(condition)
return self._from_native_frame(spark_df)

@property
def schema(self) -> dict[str, DType]:
return {
field.name: translate_sql_api_dtype(field.dataType)
for field in self._native_frame.schema
}

def collect_schema(self) -> dict[str, DType]:
return self.schema

def with_columns(
self: Self,
*exprs: IntoPySparkExpr,
**named_exprs: IntoPySparkExpr,
) -> Self:
new_columns_map = parse_exprs_and_named_exprs(self, *exprs, **named_exprs)
return self._from_native_frame(self._native_frame.withColumns(new_columns_map))

def drop(self: Self, columns: list[str], strict: bool) -> Self: # noqa: FBT001
columns_to_drop = parse_columns_to_drop(
compliant_frame=self, columns=columns, strict=strict
)
return self._from_native_frame(self._native_frame.drop(*columns_to_drop))

def head(self: Self, n: int) -> Self:
spark_session = self._native_frame.sparkSession

return self._from_native_frame(
spark_session.createDataFrame(self._native_frame.take(num=n))
)

def group_by(self: Self, *by: str) -> PySparkLazyGroupBy:
from narwhals._pyspark.group_by import PySparkLazyGroupBy

return PySparkLazyGroupBy(df=self, keys=list(by))

def sort(
self: Self,
by: str | Iterable[str],
*more_by: str,
descending: bool | Sequence[bool] = False,
nulls_last: bool = False,
) -> Self:
import pyspark.sql.functions as F # noqa: N812

flat_by = flatten([*flatten([by]), *more_by])
if isinstance(descending, bool):
descending = [descending]

if nulls_last:
sort_funcs = [
F.desc_nulls_last if d else F.asc_nulls_last for d in descending
]
else:
sort_funcs = [
F.desc_nulls_first if d else F.asc_nulls_first for d in descending
]

sort_cols = [sort_f(col) for col, sort_f in zip(flat_by, sort_funcs)]
return self._from_native_frame(self._native_frame.sort(*sort_cols))
Loading
Loading