From 50037a55992d422c5dc732bba828347683a8621f Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Wed, 1 Jan 2025 14:42:50 +0000 Subject: [PATCH] test: dont convert everything to nan --- tests/expr_and_series/fill_null_test.py | 8 ++--- tests/expr_and_series/max_horizontal_test.py | 2 +- tests/expr_and_series/mean_horizontal_test.py | 2 +- tests/expr_and_series/min_horizontal_test.py | 2 +- tests/expr_and_series/rolling_mean_test.py | 13 ++++---- tests/expr_and_series/rolling_std_test.py | 31 ++++++++++++++----- tests/expr_and_series/rolling_sum_test.py | 12 +++---- tests/expr_and_series/rolling_var_test.py | 8 ++--- tests/expr_and_series/skew_test.py | 4 +-- tests/expr_and_series/unary_test.py | 20 ++++++------ tests/expr_and_series/when_test.py | 11 +++---- tests/frame/drop_nulls_test.py | 4 +-- tests/frame/join_test.py | 18 +++++------ tests/frame/pivot_test.py | 4 +-- tests/frame/sort_test.py | 4 +-- tests/group_by_test.py | 2 +- tests/hypothesis/join_test.py | 4 +-- tests/spark_like_test.py | 16 +++++----- tests/utils.py | 8 +++-- 19 files changed, 96 insertions(+), 77 deletions(-) diff --git a/tests/expr_and_series/fill_null_test.py b/tests/expr_and_series/fill_null_test.py index 32e4b9cdd..57f767d4d 100644 --- a/tests/expr_and_series/fill_null_test.py +++ b/tests/expr_and_series/fill_null_test.py @@ -136,7 +136,7 @@ def test_fill_null_limits(constructor: Constructor) -> None: nw.col("a", "b").fill_null(strategy="forward", limit=2) ) expected_forward = { - "a": [1, 1, 1, float("nan"), 5, 6, 6, 6, float("nan"), 10], + "a": [1, 1, 1, None, 5, 6, 6, 6, None, 10], "b": ["a", "a", "a", None, "b", "c", "c", "c", None, "d"], } assert_equal_data(result_forward, expected_forward) @@ -146,7 +146,7 @@ def test_fill_null_limits(constructor: Constructor) -> None: ) expected_backward = { - "a": [1, float("nan"), 5, 5, 5, 6, float("nan"), 10, 10, 10], + "a": [1, None, 5, 5, 5, 6, None, 10, 10, 10], "b": ["a", None, "b", "b", "b", "c", None, "d", "d", "d"], } assert_equal_data(result_backward, expected_backward) @@ -203,7 +203,7 @@ def test_fill_null_series_limits(constructor_eager: ConstructorEager) -> None: "ignore", message="The 'downcast' keyword in fillna is deprecated" ) expected_forward = { - "a_forward": [0.0, 1, 1, float("nan"), 2, 2, float("nan"), 3], + "a_forward": [0.0, 1, 1, None, 2, 2, None, 3], "b_forward": ["", "a", "a", None, "c", "c", None, "e"], } result_forward = df.select( @@ -214,7 +214,7 @@ def test_fill_null_series_limits(constructor_eager: ConstructorEager) -> None: assert_equal_data(result_forward, expected_forward) expected_backward = { - "a_backward": [0.0, 1, float("nan"), 2, 2, float("nan"), 3, 3], + "a_backward": [0.0, 1, None, 2, 2, None, 3, 3], "b_backward": ["", "a", None, "c", "c", None, "e", "e"], } diff --git a/tests/expr_and_series/max_horizontal_test.py b/tests/expr_and_series/max_horizontal_test.py index 3becb36be..c86e11318 100644 --- a/tests/expr_and_series/max_horizontal_test.py +++ b/tests/expr_and_series/max_horizontal_test.py @@ -9,7 +9,7 @@ from tests.utils import assert_equal_data data = {"a": [1, 3, None, None], "b": [4, None, 6, None], "z": [3, 1, None, None]} -expected_values = [4, 3, 6, float("nan")] +expected_values = [4, 3, 6, None] @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) diff --git a/tests/expr_and_series/mean_horizontal_test.py b/tests/expr_and_series/mean_horizontal_test.py index 31b4b2109..485bf1750 100644 --- a/tests/expr_and_series/mean_horizontal_test.py +++ b/tests/expr_and_series/mean_horizontal_test.py @@ -14,7 +14,7 @@ def test_meanh(constructor: Constructor, col_expr: Any) -> None: data = {"a": [1, 3, None, None], "b": [4, None, 6, None]} df = nw.from_native(constructor(data)) result = df.select(horizontal_mean=nw.mean_horizontal(col_expr, nw.col("b"))) - expected = {"horizontal_mean": [2.5, 3.0, 6.0, float("nan")]} + expected = {"horizontal_mean": [2.5, 3.0, 6.0, None]} assert_equal_data(result, expected) diff --git a/tests/expr_and_series/min_horizontal_test.py b/tests/expr_and_series/min_horizontal_test.py index 5fb7fce97..787e3e2a4 100644 --- a/tests/expr_and_series/min_horizontal_test.py +++ b/tests/expr_and_series/min_horizontal_test.py @@ -9,7 +9,7 @@ from tests.utils import assert_equal_data data = {"a": [1, 3, None, None], "b": [4, None, 6, None], "z": [3, 1, None, None]} -expected_values = [1, 1, 6, float("nan")] +expected_values = [1, 1, 6, None] @pytest.mark.parametrize("col_expr", [nw.col("a"), "a"]) diff --git a/tests/expr_and_series/rolling_mean_test.py b/tests/expr_and_series/rolling_mean_test.py index 33c817bf3..a6dd41935 100644 --- a/tests/expr_and_series/rolling_mean_test.py +++ b/tests/expr_and_series/rolling_mean_test.py @@ -1,6 +1,7 @@ from __future__ import annotations import random +from typing import Any import hypothesis.strategies as st import pandas as pd @@ -16,15 +17,15 @@ data = {"a": [None, 1, 2, None, 4, 6, 11]} -kwargs_and_expected = { - "x1": {"kwargs": {"window_size": 3}, "expected": [float("nan")] * 6 + [7.0]}, +kwargs_and_expected: dict[str, dict[str, Any]] = { + "x1": {"kwargs": {"window_size": 3}, "expected": [None] * 6 + [7.0]}, "x2": { "kwargs": {"window_size": 3, "min_periods": 1}, - "expected": [float("nan"), 1.0, 1.5, 1.5, 3.0, 5.0, 7.0], + "expected": [None, 1.0, 1.5, 1.5, 3.0, 5.0, 7.0], }, "x3": { "kwargs": {"window_size": 2, "min_periods": 1}, - "expected": [float("nan"), 1.0, 1.5, 2.0, 4.0, 5.0, 8.5], + "expected": [None, 1.0, 1.5, 2.0, 4.0, 5.0, 8.5], }, "x4": { "kwargs": {"window_size": 5, "min_periods": 1, "center": True}, @@ -52,7 +53,7 @@ def test_rolling_mean_expr( df = nw.from_native(constructor(data)) result = df.select( **{ - name: nw.col("a").rolling_mean(**values["kwargs"]) # type: ignore[arg-type] + name: nw.col("a").rolling_mean(**values["kwargs"]) for name, values in kwargs_and_expected.items() } ) @@ -69,7 +70,7 @@ def test_rolling_mean_series(constructor_eager: ConstructorEager) -> None: result = df.select( **{ - name: df["a"].rolling_mean(**values["kwargs"]) # type: ignore[arg-type] + name: df["a"].rolling_mean(**values["kwargs"]) for name, values in kwargs_and_expected.items() } ) diff --git a/tests/expr_and_series/rolling_std_test.py b/tests/expr_and_series/rolling_std_test.py index 3fdba9493..b937f8430 100644 --- a/tests/expr_and_series/rolling_std_test.py +++ b/tests/expr_and_series/rolling_std_test.py @@ -1,8 +1,8 @@ from __future__ import annotations +from math import sqrt from typing import Any -import numpy as np import pytest import narwhals.stable.v1 as nw @@ -17,32 +17,49 @@ { "name": "x1", "kwargs": {"window_size": 3}, - "expected": np.sqrt([float("nan"), float("nan"), 1 / 3, 1, 4 / 3, 7 / 3, 3]), + "expected": [ + sqrt(x) if x is not None else x + for x in [None, None, 1 / 3, 1, 4 / 3, 7 / 3, 3] + ], }, { "name": "x2", "kwargs": {"window_size": 3, "min_periods": 1}, - "expected": np.sqrt([float("nan"), 0.5, 1 / 3, 1.0, 4 / 3, 7 / 3, 3]), + "expected": [ + sqrt(x) if x is not None else x + for x in [None, 0.5, 1 / 3, 1.0, 4 / 3, 7 / 3, 3] + ], }, { "name": "x3", "kwargs": {"window_size": 2, "min_periods": 1}, - "expected": np.sqrt([float("nan"), 0.5, 0.5, 2.0, 2.0, 4.5, 4.5]), + "expected": [ + sqrt(x) if x is not None else x for x in [None, 0.5, 0.5, 2.0, 2.0, 4.5, 4.5] + ], }, { "name": "x4", "kwargs": {"window_size": 5, "min_periods": 1, "center": True}, - "expected": np.sqrt([1 / 3, 11 / 12, 4 / 5, 17 / 10, 2.0, 2.25, 3]), + "expected": [ + sqrt(x) if x is not None else x + for x in [1 / 3, 11 / 12, 4 / 5, 17 / 10, 2.0, 2.25, 3] + ], }, { "name": "x5", "kwargs": {"window_size": 4, "min_periods": 1, "center": True}, - "expected": np.sqrt([0.5, 1 / 3, 11 / 12, 11 / 12, 2.25, 2.25, 3]), + "expected": [ + sqrt(x) if x is not None else x + for x in [0.5, 1 / 3, 11 / 12, 11 / 12, 2.25, 2.25, 3] + ], }, { "name": "x6", "kwargs": {"window_size": 3, "ddof": 2}, - "expected": np.sqrt([float("nan"), float("nan"), 2 / 3, 2.0, 8 / 3, 14 / 3, 6.0]), + "expected": [ + sqrt(x) if x is not None else x + for x in [None, None, 2 / 3, 2.0, 8 / 3, 14 / 3, 6.0] + ], }, ) diff --git a/tests/expr_and_series/rolling_sum_test.py b/tests/expr_and_series/rolling_sum_test.py index fae22552b..8c4537e49 100644 --- a/tests/expr_and_series/rolling_sum_test.py +++ b/tests/expr_and_series/rolling_sum_test.py @@ -18,15 +18,15 @@ data = {"a": [None, 1, 2, None, 4, 6, 11]} -kwargs_and_expected = { - "x1": {"kwargs": {"window_size": 3}, "expected": [float("nan")] * 6 + [21]}, +kwargs_and_expected: dict[str, dict[str, Any]] = { + "x1": {"kwargs": {"window_size": 3}, "expected": [None] * 6 + [21]}, "x2": { "kwargs": {"window_size": 3, "min_periods": 1}, - "expected": [float("nan"), 1.0, 3.0, 3.0, 6.0, 10.0, 21.0], + "expected": [None, 1.0, 3.0, 3.0, 6.0, 10.0, 21.0], }, "x3": { "kwargs": {"window_size": 2, "min_periods": 1}, - "expected": [float("nan"), 1.0, 3.0, 2.0, 4.0, 10.0, 17.0], + "expected": [None, 1.0, 3.0, 2.0, 4.0, 10.0, 17.0], }, "x4": { "kwargs": {"window_size": 5, "min_periods": 1, "center": True}, @@ -54,7 +54,7 @@ def test_rolling_sum_expr( df = nw.from_native(constructor(data)) result = df.select( **{ - name: nw.col("a").rolling_sum(**values["kwargs"]) # type: ignore[arg-type] + name: nw.col("a").rolling_sum(**values["kwargs"]) for name, values in kwargs_and_expected.items() } ) @@ -71,7 +71,7 @@ def test_rolling_sum_series(constructor_eager: ConstructorEager) -> None: result = df.select( **{ - name: df["a"].rolling_sum(**values["kwargs"]) # type: ignore[arg-type] + name: df["a"].rolling_sum(**values["kwargs"]) for name, values in kwargs_and_expected.items() } ) diff --git a/tests/expr_and_series/rolling_var_test.py b/tests/expr_and_series/rolling_var_test.py index 32767c990..37475e76a 100644 --- a/tests/expr_and_series/rolling_var_test.py +++ b/tests/expr_and_series/rolling_var_test.py @@ -23,17 +23,17 @@ { "name": "x1", "kwargs": {"window_size": 3}, - "expected": [float("nan"), float("nan"), 1 / 3, 1, 4 / 3, 7 / 3, 3], + "expected": [None, None, 1 / 3, 1, 4 / 3, 7 / 3, 3], }, { "name": "x2", "kwargs": {"window_size": 3, "min_periods": 1}, - "expected": [float("nan"), 0.5, 1 / 3, 1.0, 4 / 3, 7 / 3, 3], + "expected": [None, 0.5, 1 / 3, 1.0, 4 / 3, 7 / 3, 3], }, { "name": "x3", "kwargs": {"window_size": 2, "min_periods": 1}, - "expected": [float("nan"), 0.5, 0.5, 2.0, 2.0, 4.5, 4.5], + "expected": [None, 0.5, 0.5, 2.0, 2.0, 4.5, 4.5], }, { "name": "x4", @@ -48,7 +48,7 @@ { "name": "x6", "kwargs": {"window_size": 3, "ddof": 2}, - "expected": [float("nan"), float("nan"), 2 / 3, 2.0, 8 / 3, 14 / 3, 6.0], + "expected": [None, None, 2 / 3, 2.0, 8 / 3, 14 / 3, 6.0], }, ) diff --git a/tests/expr_and_series/skew_test.py b/tests/expr_and_series/skew_test.py index b2029d08e..849496807 100644 --- a/tests/expr_and_series/skew_test.py +++ b/tests/expr_and_series/skew_test.py @@ -13,9 +13,9 @@ ("data", "expected"), [ ([], None), - ([1], float("nan")), + ([1], None), ([1, 2], 0.0), - ([0.0, 0.0, 0.0], float("nan")), + ([0.0, 0.0, 0.0], None), ([1, 2, 3, 2, 1], 0.343622), ], ) diff --git a/tests/expr_and_series/unary_test.py b/tests/expr_and_series/unary_test.py index 3a580b726..f6b9affc1 100644 --- a/tests/expr_and_series/unary_test.py +++ b/tests/expr_and_series/unary_test.py @@ -93,7 +93,7 @@ def test_unary_two_elements(constructor: Constructor) -> None: "b_nunique": [2], "b_skew": [0.0], "c_nunique": [2], - "c_skew": [float("nan")], + "c_skew": [None], } assert_equal_data(result, expected) @@ -115,13 +115,13 @@ def test_unary_two_elements_series(constructor_eager: ConstructorEager) -> None: "b_nunique": [2], "b_skew": [0.0], "c_nunique": [2], - "c_skew": [float("nan")], + "c_skew": [None], } assert_equal_data(result, expected) def test_unary_one_element(constructor: Constructor) -> None: - data = {"a": [1], "b": [2], "c": [float("nan")]} + data = {"a": [1], "b": [2], "c": [None]} # Dask runs into a divide by zero RuntimeWarning for 1 element skew. context = ( pytest.warns(RuntimeWarning, match="invalid value encountered in scalar divide") @@ -139,17 +139,17 @@ def test_unary_one_element(constructor: Constructor) -> None: ) expected = { "a_nunique": [1], - "a_skew": [float("nan")], + "a_skew": [None], "b_nunique": [1], - "b_skew": [float("nan")], + "b_skew": [None], "c_nunique": [1], - "c_skew": [float("nan")], + "c_skew": [None], } assert_equal_data(result, expected) def test_unary_one_element_series(constructor_eager: ConstructorEager) -> None: - data = {"a": [1], "b": [2], "c": [float("nan")]} + data = {"a": [1], "b": [2], "c": [None]} df = nw.from_native(constructor_eager(data)) result = { "a_nunique": [df["a"].n_unique()], @@ -161,10 +161,10 @@ def test_unary_one_element_series(constructor_eager: ConstructorEager) -> None: } expected = { "a_nunique": [1], - "a_skew": [float("nan")], + "a_skew": [None], "b_nunique": [1], - "b_skew": [float("nan")], + "b_skew": [None], "c_nunique": [1], - "c_skew": [float("nan")], + "c_skew": [None], } assert_equal_data(result, expected) diff --git a/tests/expr_and_series/when_test.py b/tests/expr_and_series/when_test.py index 3cef177fa..8648ae4fb 100644 --- a/tests/expr_and_series/when_test.py +++ b/tests/expr_and_series/when_test.py @@ -21,7 +21,7 @@ def test_when(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(value=3).alias("a_when")) expected = { - "a_when": [3, np.nan, np.nan], + "a_when": [3, None, None], } assert_equal_data(result, expected) @@ -41,7 +41,7 @@ def test_multiple_conditions(constructor: Constructor) -> None: nw.when(nw.col("a") < 3, nw.col("c") < 5.0).then(3).alias("a_when") ) expected = { - "a_when": [3, np.nan, np.nan], + "a_when": [3, None, None], } assert_equal_data(result, expected) @@ -65,7 +65,7 @@ def test_value_numpy_array( nw.when(nw.col("a") == 1).then(np.asanyarray([3, 4, 5])).alias("a_when") ) expected = { - "a_when": [3, np.nan, np.nan], + "a_when": [3, None, None], } assert_equal_data(result, expected) @@ -77,7 +77,7 @@ def test_value_series(constructor_eager: ConstructorEager) -> None: assert isinstance(s, nw.Series) result = df.select(nw.when(nw.col("a") == 1).then(s).alias("a_when")) expected = { - "a_when": [3, np.nan, np.nan], + "a_when": [3, None, None], } assert_equal_data(result, expected) @@ -86,7 +86,7 @@ def test_value_expression(constructor: Constructor) -> None: df = nw.from_native(constructor(data)) result = df.select(nw.when(nw.col("a") == 1).then(nw.col("a") + 9).alias("a_when")) expected = { - "a_when": [10, np.nan, np.nan], + "a_when": [10, None, None], } assert_equal_data(result, expected) @@ -98,7 +98,6 @@ def test_otherwise_numpy_array( request.applymarker(pytest.mark.xfail) df = nw.from_native(constructor(data)) - import numpy as np result = df.select( nw.when(nw.col("a") == 1).then(-1).otherwise(np.array([0, 9, 10])).alias("a_when") diff --git a/tests/frame/drop_nulls_test.py b/tests/frame/drop_nulls_test.py index 680cbd4c4..bb55439eb 100644 --- a/tests/frame/drop_nulls_test.py +++ b/tests/frame/drop_nulls_test.py @@ -24,8 +24,8 @@ def test_drop_nulls(constructor: Constructor) -> None: @pytest.mark.parametrize( ("subset", "expected"), [ - ("a", {"a": [1, 2.0, 4.0], "b": [float("nan"), 3.0, 5.0]}), - (["a"], {"a": [1, 2.0, 4.0], "b": [float("nan"), 3.0, 5.0]}), + ("a", {"a": [1, 2.0, 4.0], "b": [None, 3.0, 5.0]}), + (["a"], {"a": [1, 2.0, 4.0], "b": [None, 3.0, 5.0]}), (["a", "b"], {"a": [2.0, 4.0], "b": [3.0, 5.0]}), ], ) diff --git a/tests/frame/join_test.py b/tests/frame/join_test.py index c743893d0..1abe2b90f 100644 --- a/tests/frame/join_test.py +++ b/tests/frame/join_test.py @@ -235,22 +235,20 @@ def test_left_join(constructor: Constructor) -> None: } df_left = nw.from_native(constructor(data_left)) df_right = nw.from_native(constructor(data_right)) - result = df_left.join(df_right, left_on="bob", right_on="co", how="left").select( # type: ignore[arg-type] - nw.all().fill_null(float("nan")) - ) + result = df_left.join(df_right, left_on="bob", right_on="co", how="left") # type: ignore[arg-type] result = result.sort("index") result = result.drop("index_right") expected = { "antananarivo": [1, 2, 3], "bob": [4, 5, 6], - "antananarivo_right": [1, 2, float("nan")], + "antananarivo_right": [1, 2, None], "index": [0, 1, 2], } result_on_list = df_left.join( df_right, # type: ignore[arg-type] on=["antananarivo", "index"], how="left", - ).select(nw.all().fill_null(float("nan"))) + ) result_on_list = result_on_list.sort("index") expected_on_list = { "antananarivo": [1, 2, 3], @@ -312,15 +310,15 @@ def test_left_join_overlapping_column(constructor: Constructor) -> None: left_on="antananarivo", right_on="d", how="left", - ).select(nw.all().fill_null(float("nan"))) + ) result = result.sort("index") result = result.drop("index_right") expected = { "antananarivo": [1, 2, 3], "bob": [4, 5, 6], "d": [1, 4, 2], - "antananarivo_right": [1.0, 3.0, float("nan")], - "c": [4.0, 6.0, float("nan")], + "antananarivo_right": [1.0, 3.0, None], + "c": [4.0, 6.0, None], "index": [0, 1, 2], } assert_equal_data(result, expected) @@ -397,7 +395,7 @@ def test_joinasof_numeric( expected_forward = { "antananarivo": [1, 5, 10], "val": ["a", "b", "c"], - "val_right": [1, 6, float("nan")], + "val_right": [1, 6, None], } expected_nearest = { "antananarivo": [1, 5, 10], @@ -523,7 +521,7 @@ def test_joinasof_by( "antananarivo": [1, 5, 7, 10], "bob": ["D", "D", "C", "A"], "c": [9, 2, 1, 1], - "d": [1, 3, float("nan"), 4], + "d": [1, 3, None, 4], } assert_equal_data(result, expected) assert_equal_data(result_by, expected) diff --git a/tests/frame/pivot_test.py b/tests/frame/pivot_test.py index 98ef7466f..0e3860292 100644 --- a/tests/frame/pivot_test.py +++ b/tests/frame/pivot_test.py @@ -271,7 +271,7 @@ def test_pivot_no_index( expected = { "ix": [1, 1, 2, 2], "bar": ["x", "y", "w", "z"], - "a": [1.0, float("nan"), float("nan"), 3.0], - "b": [float("nan"), 2.0, 4.0, float("nan")], + "a": [1.0, None, None, 3.0], + "b": [None, 2.0, 4.0, None], } assert_equal_data(result, expected) diff --git a/tests/frame/sort_test.py b/tests/frame/sort_test.py index 4e12cc95a..5147c6f56 100644 --- a/tests/frame/sort_test.py +++ b/tests/frame/sort_test.py @@ -29,8 +29,8 @@ def test_sort(constructor: Constructor) -> None: @pytest.mark.parametrize( ("nulls_last", "expected"), [ - (True, {"a": [0, 2, 0, -1], "b": [3, 2, 1, float("nan")]}), - (False, {"a": [-1, 0, 2, 0], "b": [float("nan"), 3, 2, 1]}), + (True, {"a": [0, 2, 0, -1], "b": [3, 2, 1, None]}), + (False, {"a": [-1, 0, 2, 0], "b": [None, 3, 2, 1]}), ], ) def test_sort_nulls( diff --git a/tests/group_by_test.py b/tests/group_by_test.py index 188c17c76..cd350f6ea 100644 --- a/tests/group_by_test.py +++ b/tests/group_by_test.py @@ -293,7 +293,7 @@ def test_key_with_nulls( .sort("a") .with_columns(nw.col("b").cast(nw.Float64)) ) - expected = {"b": [4.0, 5, float("nan")], "len": [1, 1, 1], "a": [1, 2, 3]} + expected = {"b": [4.0, 5, None], "len": [1, 1, 1], "a": [1, 2, 3]} assert_equal_data(result, expected) diff --git a/tests/hypothesis/join_test.py b/tests/hypothesis/join_test.py index 5b498db65..7f1cd8103 100644 --- a/tests/hypothesis/join_test.py +++ b/tests/hypothesis/join_test.py @@ -161,7 +161,7 @@ def test_left_join( # pragma: no cover left_on=left_key, right_on=right_key, ) - ).select(pl.all().fill_null(float("nan"))) + ) assert_equal_data( result_pd.to_dict(as_series=False), result_pl.to_dict(as_series=False) ) @@ -174,7 +174,7 @@ def test_left_join( # pragma: no cover left_on=left_key, right_on=right_key, ) - .select(nw.all().cast(nw.Float64).fill_null(float("nan"))) + .select(nw.all().cast(nw.Float64)) .pipe(lambda df: df.sort(df.columns)) ) assert_equal_data( diff --git a/tests/spark_like_test.py b/tests/spark_like_test.py index 99682b8f7..eb6f89fdb 100644 --- a/tests/spark_like_test.py +++ b/tests/spark_like_test.py @@ -26,7 +26,7 @@ def _pyspark_constructor_with_session(obj: Any, spark_session: SparkSession) -> IntoFrame: # NaN and NULL are not the same in PySpark - pd_df = pd.DataFrame(obj).replace({float("nan"): None}).reset_index() + pd_df = pd.DataFrame(obj).replace({None: None}).reset_index() return ( # type: ignore[no-any-return] spark_session.createDataFrame(pd_df).repartition(2).orderBy("index").drop("index") ) @@ -235,8 +235,8 @@ def test_sort(pyspark_constructor: Constructor) -> None: @pytest.mark.parametrize( ("nulls_last", "expected"), [ - (True, {"a": [0, 2, 0, -1], "b": [3, 2, 1, float("nan")]}), - (False, {"a": [-1, 0, 2, 0], "b": [float("nan"), 3, 2, 1]}), + (True, {"a": [0, 2, 0, -1], "b": [3, 2, 1, None]}), + (False, {"a": [-1, 0, 2, 0], "b": [None, 3, 2, 1]}), ], ) def test_sort_nulls( @@ -511,8 +511,8 @@ def test_drop_nulls(pyspark_constructor: Constructor) -> None: @pytest.mark.parametrize( ("subset", "expected"), [ - ("a", {"a": [1, 2.0, 4.0], "b": [float("nan"), 3.0, 5.0]}), - (["a"], {"a": [1, 2.0, 4.0], "b": [float("nan"), 3.0, 5.0]}), + ("a", {"a": [1, 2.0, 4.0], "b": [None, 3.0, 5.0]}), + (["a"], {"a": [1, 2.0, 4.0], "b": [None, 3.0, 5.0]}), (["a", "b"], {"a": [2.0, 4.0], "b": [3.0, 5.0]}), ], ) @@ -782,7 +782,7 @@ def test_left_join(pyspark_constructor: Constructor) -> None: expected = { "antananarivo": [1, 2, 3], "bob": [4, 5, 6], - "antananarivo_right": [1, 2, float("nan")], + "antananarivo_right": [1, 2, None], "idx": [0, 1, 2], } result_on_list = df_left.join( @@ -863,8 +863,8 @@ def test_left_join_overlapping_column(pyspark_constructor: Constructor) -> None: "antananarivo": [1, 2, 3], "bob": [4, 5, 6], "d": [1, 4, 2], - "antananarivo_right": [1.0, 3.0, float("nan")], - "c": [4.0, 6.0, float("nan")], + "antananarivo_right": [1.0, 3.0, None], + "c": [4.0, 6.0, None], "idx": [0, 1, 2], } assert_equal_data(result, expected) diff --git a/tests/utils.py b/tests/utils.py index 60933046b..6c1caad07 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -90,8 +90,12 @@ def assert_equal_data(result: Any, expected: dict[str, Any]) -> None: for i, (lhs, rhs) in enumerate(zip_strict(result_value, expected_value)): if isinstance(lhs, float) and not math.isnan(lhs): are_equivalent_values = math.isclose(lhs, rhs, rel_tol=0, abs_tol=1e-6) - elif isinstance(lhs, float) and math.isnan(lhs) and rhs is not None: - are_equivalent_values = math.isnan(rhs) # pragma: no cover + elif isinstance(lhs, float) and math.isnan(lhs): + are_equivalent_values = rhs is None or math.isnan(rhs) + elif isinstance(rhs, float) and math.isnan(rhs): + are_equivalent_values = lhs is None or math.isnan(lhs) + elif lhs is None: + are_equivalent_values = rhs is None elif pd.isna(lhs): are_equivalent_values = pd.isna(rhs) else: