Skip to content

Commit

Permalink
[SPARK-50023][PYTHON][CONNECT] API compatibility check for Functions
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?

This PR proposes to add API compatibility check for Spark SQL Functions

### Why are the changes needed?

To guarantee of the same behavior between Spark Classic and Spark Connect

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Added UTs

### Was this patch authored or co-authored using generative AI tooling?

No

Closes apache#48536 from itholic/SPARK-50023.

Authored-by: Haejoon Lee <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
itholic authored and HyukjinKwon committed Oct 19, 2024
1 parent 14ed86e commit 25b03f9
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 7 deletions.
7 changes: 4 additions & 3 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1498,7 +1498,7 @@ def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) ->
lead.__doc__ = pysparkfuncs.lead.__doc__


def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = None) -> Column:
def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = False) -> Column:
if ignoreNulls is None:
return _invoke_function("nth_value", _to_col(col), lit(offset))
else:
Expand Down Expand Up @@ -2236,7 +2236,7 @@ def size(col: "ColumnOrName") -> Column:


def slice(
col: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]
x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int]
) -> Column:
start = _enum_to_value(start)
if isinstance(start, (Column, str)):
Expand All @@ -2260,7 +2260,7 @@ def slice(
messageParameters={"arg_name": "length", "arg_type": type(length).__name__},
)

return _invoke_function_over_columns("slice", col, _start, _length)
return _invoke_function_over_columns("slice", x, _start, _length)


slice.__doc__ = pysparkfuncs.slice.__doc__
Expand Down Expand Up @@ -4195,6 +4195,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column:
def udf(
f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None,
returnType: "DataTypeOrString" = StringType(),
*,
useArrow: Optional[bool] = None,
) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]:
if f is None or isinstance(f, (str, DataType)):
Expand Down
8 changes: 4 additions & 4 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@

from pyspark.errors import PySparkTypeError, PySparkValueError
from pyspark.sql.column import Column
from pyspark.sql.dataframe import DataFrame
from pyspark.sql.dataframe import DataFrame as ParentDataFrame
from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from_numpy_type

# Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409
Expand Down Expand Up @@ -5590,7 +5590,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C


@_try_remote_functions
def broadcast(df: DataFrame) -> DataFrame:
def broadcast(df: "ParentDataFrame") -> "ParentDataFrame":
"""
Marks a DataFrame as small enough for use in broadcast joins.

Expand Down Expand Up @@ -5621,7 +5621,7 @@ def broadcast(df: DataFrame) -> DataFrame:
from py4j.java_gateway import JVMView

sc = _get_active_spark_context()
return DataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession)
return ParentDataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession)


@_try_remote_functions
Expand Down Expand Up @@ -9678,7 +9678,7 @@ def from_utc_timestamp(timestamp: "ColumnOrName", tz: Union[Column, str]) -> Col


@_try_remote_functions
def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column:
def to_utc_timestamp(timestamp: "ColumnOrName", tz: Union[Column, str]) -> Column:
"""
This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function
takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given
Expand Down
18 changes: 18 additions & 0 deletions python/pyspark/sql/tests/test_connect_compatibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from pyspark.sql.readwriter import DataFrameWriterV2 as ClassicDataFrameWriterV2
from pyspark.sql.window import Window as ClassicWindow
from pyspark.sql.window import WindowSpec as ClassicWindowSpec
import pyspark.sql.functions as ClassicFunctions

if should_test_connect:
from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame
Expand All @@ -41,6 +42,7 @@
from pyspark.sql.connect.readwriter import DataFrameWriterV2 as ConnectDataFrameWriterV2
from pyspark.sql.connect.window import Window as ConnectWindow
from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec
import pyspark.sql.connect.functions as ConnectFunctions


class ConnectCompatibilityTestsMixin:
Expand Down Expand Up @@ -339,6 +341,22 @@ def test_window_spec_compatibility(self):
expected_missing_classic_methods,
)

def test_functions_compatibility(self):
"""Test Functions compatibility between classic and connect."""
expected_missing_connect_properties = set()
expected_missing_classic_properties = set()
expected_missing_connect_methods = set()
expected_missing_classic_methods = {"check_dependencies"}
self.check_compatibility(
ClassicFunctions,
ConnectFunctions,
"Functions",
expected_missing_connect_properties,
expected_missing_classic_properties,
expected_missing_connect_methods,
expected_missing_classic_methods,
)


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase):
Expand Down

0 comments on commit 25b03f9

Please sign in to comment.