From 25b03f9d3471ea57ba3d18a7d7fbe0be05a306eb Mon Sep 17 00:00:00 2001 From: Haejoon Lee Date: Sat, 19 Oct 2024 19:56:37 +0900 Subject: [PATCH] [SPARK-50023][PYTHON][CONNECT] API compatibility check for Functions ### What changes were proposed in this pull request? This PR proposes to add an API compatibility check for Spark SQL Functions ### Why are the changes needed? To guarantee the same behavior between Spark Classic and Spark Connect ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added UTs ### Was this patch authored or co-authored using generative AI tooling? No Closes #48536 from itholic/SPARK-50023. Authored-by: Haejoon Lee Signed-off-by: Hyukjin Kwon --- .../pyspark/sql/connect/functions/builtin.py | 7 ++++--- python/pyspark/sql/functions/builtin.py | 8 ++++---- .../sql/tests/test_connect_compatibility.py | 18 ++++++++++++++++++ 3 files changed, 26 insertions(+), 7 deletions(-) diff --git a/python/pyspark/sql/connect/functions/builtin.py b/python/pyspark/sql/connect/functions/builtin.py index 9341442a1733b..1e3d41825f06c 100644 --- a/python/pyspark/sql/connect/functions/builtin.py +++ b/python/pyspark/sql/connect/functions/builtin.py @@ -1498,7 +1498,7 @@ def lead(col: "ColumnOrName", offset: int = 1, default: Optional[Any] = None) -> lead.__doc__ = pysparkfuncs.lead.__doc__ -def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = None) -> Column: +def nth_value(col: "ColumnOrName", offset: int, ignoreNulls: Optional[bool] = False) -> Column: if ignoreNulls is None: return _invoke_function("nth_value", _to_col(col), lit(offset)) else: @@ -2236,7 +2236,7 @@ def size(col: "ColumnOrName") -> Column: def slice( - col: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int] + x: "ColumnOrName", start: Union["ColumnOrName", int], length: Union["ColumnOrName", int] ) -> Column: start = _enum_to_value(start) if isinstance(start, (Column, str)): @@ -2260,7 +2260,7 @@ def 
slice( messageParameters={"arg_name": "length", "arg_type": type(length).__name__}, ) - return _invoke_function_over_columns("slice", col, _start, _length) + return _invoke_function_over_columns("slice", x, _start, _length) slice.__doc__ = pysparkfuncs.slice.__doc__ @@ -4195,6 +4195,7 @@ def unwrap_udt(col: "ColumnOrName") -> Column: def udf( f: Optional[Union[Callable[..., Any], "DataTypeOrString"]] = None, returnType: "DataTypeOrString" = StringType(), + *, useArrow: Optional[bool] = None, ) -> Union["UserDefinedFunctionLike", Callable[[Callable[..., Any]], "UserDefinedFunctionLike"]]: if f is None or isinstance(f, (str, DataType)): diff --git a/python/pyspark/sql/functions/builtin.py b/python/pyspark/sql/functions/builtin.py index 55da50fd4a5a5..dbc66cab3f9b3 100644 --- a/python/pyspark/sql/functions/builtin.py +++ b/python/pyspark/sql/functions/builtin.py @@ -41,7 +41,7 @@ from pyspark.errors import PySparkTypeError, PySparkValueError from pyspark.sql.column import Column -from pyspark.sql.dataframe import DataFrame +from pyspark.sql.dataframe import DataFrame as ParentDataFrame from pyspark.sql.types import ArrayType, DataType, StringType, StructType, _from_numpy_type # Keep UserDefinedFunction import for backwards compatible import; moved in SPARK-22409 @@ -5590,7 +5590,7 @@ def approx_count_distinct(col: "ColumnOrName", rsd: Optional[float] = None) -> C @_try_remote_functions -def broadcast(df: DataFrame) -> DataFrame: +def broadcast(df: "ParentDataFrame") -> "ParentDataFrame": """ Marks a DataFrame as small enough for use in broadcast joins. 
@@ -5621,7 +5621,7 @@ def broadcast(df: DataFrame) -> DataFrame: from py4j.java_gateway import JVMView sc = _get_active_spark_context() - return DataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession) + return ParentDataFrame(cast(JVMView, sc._jvm).functions.broadcast(df._jdf), df.sparkSession) @_try_remote_functions @@ -9678,7 +9678,7 @@ def from_utc_timestamp(timestamp: "ColumnOrName", tz: Union[Column, str]) -> Col @_try_remote_functions -def to_utc_timestamp(timestamp: "ColumnOrName", tz: "ColumnOrName") -> Column: +def to_utc_timestamp(timestamp: "ColumnOrName", tz: Union[Column, str]) -> Column: """ This is a common function for databases supporting TIMESTAMP WITHOUT TIMEZONE. This function takes a timestamp which is timezone-agnostic, and interprets it as a timestamp in the given diff --git a/python/pyspark/sql/tests/test_connect_compatibility.py b/python/pyspark/sql/tests/test_connect_compatibility.py index f081385f44894..3ebb6b7aea7d0 100644 --- a/python/pyspark/sql/tests/test_connect_compatibility.py +++ b/python/pyspark/sql/tests/test_connect_compatibility.py @@ -30,6 +30,7 @@ from pyspark.sql.readwriter import DataFrameWriterV2 as ClassicDataFrameWriterV2 from pyspark.sql.window import Window as ClassicWindow from pyspark.sql.window import WindowSpec as ClassicWindowSpec +import pyspark.sql.functions as ClassicFunctions if should_test_connect: from pyspark.sql.connect.dataframe import DataFrame as ConnectDataFrame @@ -41,6 +42,7 @@ from pyspark.sql.connect.readwriter import DataFrameWriterV2 as ConnectDataFrameWriterV2 from pyspark.sql.connect.window import Window as ConnectWindow from pyspark.sql.connect.window import WindowSpec as ConnectWindowSpec + import pyspark.sql.connect.functions as ConnectFunctions class ConnectCompatibilityTestsMixin: @@ -339,6 +341,22 @@ def test_window_spec_compatibility(self): expected_missing_classic_methods, ) + def test_functions_compatibility(self): + """Test Functions compatibility between 
classic and connect.""" + expected_missing_connect_properties = set() + expected_missing_classic_properties = set() + expected_missing_connect_methods = set() + expected_missing_classic_methods = {"check_dependencies"} + self.check_compatibility( + ClassicFunctions, + ConnectFunctions, + "Functions", + expected_missing_connect_properties, + expected_missing_classic_properties, + expected_missing_connect_methods, + expected_missing_classic_methods, + ) + @unittest.skipIf(not should_test_connect, connect_requirement_message) class ConnectCompatibilityTests(ConnectCompatibilityTestsMixin, ReusedSQLTestCase):