From 25a06b8b4cab541649b6efab4c4188f9f5194f41 Mon Sep 17 00:00:00 2001
From: "a.cherkaoui"
Date: Sun, 29 Dec 2024 03:03:21 +0100
Subject: [PATCH 1/3] [python-package][PySpark] Expose Training and Validation Metrics

---
 python-package/xgboost/spark/core.py          | 28 ++++++++----
 .../xgboost/spark/xgboost_training_summary.py | 43 +++++++++++++++++++
 2 files changed, 63 insertions(+), 8 deletions(-)
 create mode 100644 python-package/xgboost/spark/xgboost_training_summary.py

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 689e747e8a5c..1e06ef9c514c 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -100,6 +100,7 @@
     serialize_booster,
     use_cuda,
 )
+from .xgboost_training_summary import _XGBoostTrainingSummary
 
 # Put pyspark specific params here, they won't be passed to XGBoost.
 # like `validationIndicatorCol`, `base_margin_col`
@@ -704,8 +705,10 @@ def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
         """
         raise NotImplementedError()
 
-    def _create_pyspark_model(self, xgb_model: XGBModel) -> "_SparkXGBModel":
-        return self._pyspark_model_cls()(xgb_model)
+    def _create_pyspark_model(
+        self, xgb_model: XGBModel, training_summary: _XGBoostTrainingSummary
+    ) -> "_SparkXGBModel":
+        return self._pyspark_model_cls()(xgb_model, training_summary)
 
     def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel:
         xgb_sklearn_params = self._gen_xgb_params_dict(
@@ -1148,7 +1151,7 @@ def _train_booster(
                 if dvalid is not None:
                     dval = [(dtrain, "training"), (dvalid, "validation")]
                 else:
-                    dval = None
+                    dval = [(dtrain, "training")]
                 booster = worker_train(
                     params=booster_params,
                     dtrain=dtrain,
@@ -1159,6 +1162,7 @@ def _train_booster(
 
             context.barrier()
             if context.partitionId() == 0:
+                yield pd.DataFrame({"data": [json.dumps(dict(evals_result))]})
                 config = booster.save_config()
                 yield pd.DataFrame({"data": [config]})
                 booster_json = booster.save_raw("json").decode("utf-8")
@@ -1167,7 +1171,7 @@ def _train_booster(
                     booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
                     yield pd.DataFrame({"data": [booster_chunk]})
 
-        def _run_job() -> Tuple[str, str]:
+        def _run_job() -> Tuple[str, str, str]:
             rdd = (
                 dataset.mapInPandas(
                     _train_booster,  # type: ignore
@@ -1179,7 +1183,7 @@ def _run_job() -> Tuple[str, str]:
             rdd_with_resource = self._try_stage_level_scheduling(rdd)
             ret = rdd_with_resource.collect()
             data = [v[0] for v in ret]
-            return data[0], "".join(data[1:])
+            return data[0], data[1], "".join(data[2:])
 
         get_logger(_LOG_TAG).info(
             "Running xgboost-%s on %s workers with"
@@ -1192,13 +1196,16 @@ def _run_job() -> Tuple[str, str]:
             train_call_kwargs_params,
             dmatrix_kwargs,
         )
-        (config, booster) = _run_job()
+        (evals_result, config, booster) = _run_job()
         get_logger(_LOG_TAG).info("Finished xgboost training!")
 
         result_xgb_model = self._convert_to_sklearn_model(
             bytearray(booster, "utf-8"), config
         )
-        spark_model = self._create_pyspark_model(result_xgb_model)
+        training_summary = _XGBoostTrainingSummary.from_metrics(
+            json.loads(evals_result)
+        )
+        spark_model = self._create_pyspark_model(result_xgb_model, training_summary)
         # According to pyspark ML convention, the model uid should be the same
         # with estimator uid.
         spark_model._resetUid(self.uid)
@@ -1219,9 +1226,14 @@ def read(cls) -> "SparkXGBReader":
 
 
 class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
-    def __init__(self, xgb_sklearn_model: Optional[XGBModel] = None) -> None:
+    def __init__(
+        self,
+        xgb_sklearn_model: Optional[XGBModel] = None,
+        training_summary: Optional[_XGBoostTrainingSummary] = None,
+    ) -> None:
         super().__init__()
         self._xgb_sklearn_model = xgb_sklearn_model
+        self.training_summary = training_summary
 
     @classmethod
     def _xgb_cls(cls) -> Type[XGBModel]:
diff --git a/python-package/xgboost/spark/xgboost_training_summary.py b/python-package/xgboost/spark/xgboost_training_summary.py
new file mode 100644
index 000000000000..7c3c6f5093d5
--- /dev/null
+++ b/python-package/xgboost/spark/xgboost_training_summary.py
@@ -0,0 +1,43 @@
+"""Xgboost training summary integration submodule."""
+
+from dataclasses import dataclass, field
+from typing import Dict, List
+
+
+@dataclass
+class _XGBoostTrainingSummary:
+    """
+    A class that holds the training and validation objective history
+    of an XGBoost model during its training process.
+    """
+
+    train_objective_history: Dict[str, List[float]] = field(default_factory=dict)
+    validation_objective_history: Dict[str, List[float]] = field(default_factory=dict)
+
+    @staticmethod
+    def from_metrics(
+        metrics: Dict[str, Dict[str, List[float]]]
+    ) -> "_XGBoostTrainingSummary":
+        """
+        Create an XGBoostTrainingSummary instance from a nested dictionary of metrics.
+
+        Parameters
+        ----------
+        metrics : dict of str to dict of str to list of float
+            A dictionary containing training and validation metrics.
+            Example format:
+            {
+                "training": {"logloss": [0.1, 0.08]},
+                "validation": {"logloss": [0.12, 0.1]}
+            }
+
+        Returns
+        -------
+        A new instance of XGBoostTrainingSummary.
+
+        """
+        train_objective_history = metrics.get("training", {})
+        validation_objective_history = metrics.get("validation", {})
+        return _XGBoostTrainingSummary(
+            train_objective_history, validation_objective_history
+        )

From 984bc8ea9bbe6ddb24adb41fa67f56a517613586 Mon Sep 17 00:00:00 2001
From: "a.cherkaoui"
Date: Thu, 2 Jan 2025 07:25:07 +0100
Subject: [PATCH 2/3] Rename xgboost_training_summary.py to summary.py and _XGBoostTrainingSummary to XGBoostTrainingSummary

---
 python-package/xgboost/spark/core.py                  | 10 ++++------
 .../spark/{xgboost_training_summary.py => summary.py} |  6 +++---
 2 files changed, 7 insertions(+), 9 deletions(-)
 rename python-package/xgboost/spark/{xgboost_training_summary.py => summary.py} (92%)

diff --git a/python-package/xgboost/spark/core.py b/python-package/xgboost/spark/core.py
index 1e06ef9c514c..df9a57ba8428 100644
--- a/python-package/xgboost/spark/core.py
+++ b/python-package/xgboost/spark/core.py
@@ -82,6 +82,7 @@
     HasFeaturesCols,
     HasQueryIdCol,
 )
+from .summary import XGBoostTrainingSummary
 from .utils import (
     CommunicatorContext,
     _get_default_params_from_func,
@@ -100,7 +101,6 @@
     serialize_booster,
     use_cuda,
 )
-from .xgboost_training_summary import _XGBoostTrainingSummary
 
 # Put pyspark specific params here, they won't be passed to XGBoost.
 # like `validationIndicatorCol`, `base_margin_col`
@@ -706,7 +706,7 @@ def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
         raise NotImplementedError()
 
     def _create_pyspark_model(
-        self, xgb_model: XGBModel, training_summary: _XGBoostTrainingSummary
+        self, xgb_model: XGBModel, training_summary: XGBoostTrainingSummary
     ) -> "_SparkXGBModel":
         return self._pyspark_model_cls()(xgb_model, training_summary)
 
@@ -1202,9 +1202,7 @@ def _run_job() -> Tuple[str, str, str]:
         result_xgb_model = self._convert_to_sklearn_model(
             bytearray(booster, "utf-8"), config
         )
-        training_summary = _XGBoostTrainingSummary.from_metrics(
-            json.loads(evals_result)
-        )
+        training_summary = XGBoostTrainingSummary.from_metrics(json.loads(evals_result))
         spark_model = self._create_pyspark_model(result_xgb_model, training_summary)
         # According to pyspark ML convention, the model uid should be the same
         # with estimator uid.
@@ -1229,7 +1227,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
     def __init__(
         self,
         xgb_sklearn_model: Optional[XGBModel] = None,
-        training_summary: Optional[_XGBoostTrainingSummary] = None,
+        training_summary: Optional[XGBoostTrainingSummary] = None,
     ) -> None:
         super().__init__()
         self._xgb_sklearn_model = xgb_sklearn_model
diff --git a/python-package/xgboost/spark/xgboost_training_summary.py b/python-package/xgboost/spark/summary.py
similarity index 92%
rename from python-package/xgboost/spark/xgboost_training_summary.py
rename to python-package/xgboost/spark/summary.py
index 7c3c6f5093d5..eca5f6b128b7 100644
--- a/python-package/xgboost/spark/xgboost_training_summary.py
+++ b/python-package/xgboost/spark/summary.py
@@ -5,7 +5,7 @@
 
 
 @dataclass
-class _XGBoostTrainingSummary:
+class XGBoostTrainingSummary:
     """
     A class that holds the training and validation objective history
     of an XGBoost model during its training process.
@@ -17,7 +17,7 @@ class _XGBoostTrainingSummary:
     @staticmethod
     def from_metrics(
         metrics: Dict[str, Dict[str, List[float]]]
-    ) -> "_XGBoostTrainingSummary":
+    ) -> "XGBoostTrainingSummary":
         """
         Create an XGBoostTrainingSummary instance from a nested dictionary of metrics.
 
@@ -38,6 +38,6 @@ def from_metrics(
         """
         train_objective_history = metrics.get("training", {})
         validation_objective_history = metrics.get("validation", {})
-        return _XGBoostTrainingSummary(
+        return XGBoostTrainingSummary(
             train_objective_history, validation_objective_history
         )

From 3e60eec33c049d11aaa25305f5547d139f6d7d7a Mon Sep 17 00:00:00 2001
From: "a.cherkaoui"
Date: Sat, 4 Jan 2025 14:12:02 +0100
Subject: [PATCH 3/3] Add tests for the PySpark XGBoost summary

---
 .../test_with_spark/test_xgboost_summary.py | 233 ++++++++++++++++++
 1 file changed, 233 insertions(+)
 create mode 100644 tests/test_distributed/test_with_spark/test_xgboost_summary.py

diff --git a/tests/test_distributed/test_with_spark/test_xgboost_summary.py b/tests/test_distributed/test_with_spark/test_xgboost_summary.py
new file mode 100644
index 000000000000..759ccea98aaa
--- /dev/null
+++ b/tests/test_distributed/test_with_spark/test_xgboost_summary.py
@@ -0,0 +1,233 @@
+import logging
+from typing import Type, Union
+
+import pytest
+from pyspark.ml.linalg import Vectors
+from pyspark.sql import DataFrame, SparkSession
+from pyspark.sql.functions import lit
+
+from xgboost import testing as tm
+from xgboost.spark import (
+    SparkXGBClassifier,
+    SparkXGBClassifierModel,
+    SparkXGBRanker,
+    SparkXGBRankerModel,
+    SparkXGBRegressor,
+    SparkXGBRegressorModel,
+)
+
+from .test_spark_local import spark as spark_local
+
+logging.getLogger("py4j").setLevel(logging.INFO)
+
+pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_spark())]
+
+
+@pytest.fixture
+def clf_and_reg_df(spark_local: SparkSession) -> DataFrame:
+    """
+    Fixture to create a DataFrame with example data.
+    """
+    data = [
+        (Vectors.dense([1.0, 2.0, 3.0]), 1),
+        (Vectors.dense([4.0, 5.0, 6.0]), 1),
+        (Vectors.dense([9.0, 4.0, 8.0]), 0),
+        (Vectors.dense([6.0, 2.0, 2.0]), 1),
+        (Vectors.dense([5.0, 4.0, 3.0]), 0),
+    ]
+    columns = ["features", "label"]
+    return spark_local.createDataFrame(data, schema=columns)
+
+
+@pytest.fixture
+def clf_and_reg_df_with_validation(clf_and_reg_df: DataFrame) -> DataFrame:
+    """
+    Fixture to split the example data and add a validation indicator column.
+    """
+    # Split data into training and validation sets
+    train_df, validation_df = clf_and_reg_df.randomSplit([0.8, 0.2], seed=42)
+
+    # Add a column to indicate validation rows
+    train_df = train_df.withColumn("validation_indicator_col", lit(False))
+    validation_df = validation_df.withColumn("validation_indicator_col", lit(True))
+    return train_df.union(validation_df)
+
+
+@pytest.fixture
+def ranker_df(spark_local: SparkSession) -> DataFrame:
+    """
+    Fixture to create a DataFrame with sample data for ranking tasks.
+    """
+    data = [
+        (Vectors.dense([1.0, 2.0, 3.0]), 0, 0),
+        (Vectors.dense([4.0, 5.0, 6.0]), 1, 0),
+        (Vectors.dense([9.0, 4.0, 8.0]), 0, 0),
+        (Vectors.dense([6.0, 2.0, 2.0]), 1, 0),
+        (Vectors.dense([5.0, 4.0, 3.0]), 0, 0),
+    ]
+    columns = ["features", "label", "qid"]
+    return spark_local.createDataFrame(data, schema=columns)
+
+
+@pytest.fixture
+def ranker_df_with_validation(ranker_df: DataFrame) -> DataFrame:
+    """
+    Fixture to split the ranking DataFrame into training and validation sets,
+    add a validation indicator column, and merge them back into a single DataFrame.
+    """
+    # Split the data into training and validation sets (80-20 split)
+    train_df, validation_df = ranker_df.randomSplit([0.8, 0.2], seed=42)
+
+    # Add a column to indicate whether the row is from the validation set
+    train_df = train_df.withColumn("validation_indicator_col", lit(False))
+    validation_df = validation_df.withColumn("validation_indicator_col", lit(True))
+
+    # Union the training and validation DataFrames
+    return train_df.union(validation_df)
+
+
+class TestXGBoostTrainingSummary:
+    @staticmethod
+    def assert_empty_validation_objective_history(
+        xgb_model: Union[
+            SparkXGBClassifierModel, SparkXGBRankerModel, SparkXGBRegressorModel
+        ]
+    ) -> None:
+        assert hasattr(xgb_model.training_summary, "validation_objective_history")
+        assert isinstance(xgb_model.training_summary.validation_objective_history, dict)
+        assert not xgb_model.training_summary.validation_objective_history
+
+    @staticmethod
+    def assert_non_empty_training_objective_history(
+        xgb_model: Union[
+            SparkXGBClassifierModel, SparkXGBRankerModel, SparkXGBRegressorModel
+        ],
+        metric: str,
+        n_estimators: int,
+    ) -> None:
+        assert hasattr(xgb_model.training_summary, "train_objective_history")
+        assert isinstance(xgb_model.training_summary.train_objective_history, dict)
+
+        assert metric in xgb_model.training_summary.train_objective_history
+        assert (
+            len(xgb_model.training_summary.train_objective_history[metric])
+            == n_estimators
+        )
+
+        for (
+            training_metric,
+            loss_evolution,
+        ) in xgb_model.training_summary.train_objective_history.items():
+            assert isinstance(training_metric, str)
+            assert len(loss_evolution) == n_estimators
+            assert all(isinstance(value, float) for value in loss_evolution)
+
+    @staticmethod
+    def assert_non_empty_validation_objective_history(
+        xgb_model: Union[
+            SparkXGBClassifierModel, SparkXGBRankerModel, SparkXGBRegressorModel
+        ],
+        metric: str,
+        n_estimators: int,
+    ) -> None:
+        assert hasattr(xgb_model.training_summary, "validation_objective_history")
+        assert isinstance(xgb_model.training_summary.validation_objective_history, dict)
+
+        assert metric in xgb_model.training_summary.validation_objective_history
+        assert (
+            len(xgb_model.training_summary.validation_objective_history[metric])
+            == n_estimators
+        )
+
+        for (
+            validation_metric,
+            loss_evolution,
+        ) in xgb_model.training_summary.validation_objective_history.items():
+            assert isinstance(validation_metric, str)
+            assert len(loss_evolution) == n_estimators
+            assert all(isinstance(value, float) for value in loss_evolution)
+
+    @pytest.mark.parametrize(
+        "spark_xgb_estimator, metric",
+        [
+            (SparkXGBClassifier, "logloss"),
+            (SparkXGBClassifier, "error"),
+            (SparkXGBRegressor, "rmse"),
+            (SparkXGBRegressor, "mae"),
+        ],
+    )
+    def test_xgb_summary_classification_regression(
+        self,
+        clf_and_reg_df: DataFrame,
+        spark_xgb_estimator: Type[Union[SparkXGBClassifier, SparkXGBRegressor]],
+        metric: str,
+    ) -> None:
+        n_estimators = 10
+        spark_xgb_model = spark_xgb_estimator(
+            eval_metric=metric, n_estimators=n_estimators
+        ).fit(clf_and_reg_df)
+        self.assert_non_empty_training_objective_history(
+            spark_xgb_model, metric, n_estimators
+        )
+        self.assert_empty_validation_objective_history(spark_xgb_model)
+
+    @pytest.mark.parametrize(
+        "spark_xgb_estimator, metric",
+        [
+            (SparkXGBClassifier, "logloss"),
+            (SparkXGBClassifier, "error"),
+            (SparkXGBRegressor, "rmse"),
+            (SparkXGBRegressor, "mae"),
+        ],
+    )
+    def test_xgb_summary_classification_regression_with_validation(
+        self,
+        clf_and_reg_df_with_validation: DataFrame,
+        spark_xgb_estimator: Type[Union[SparkXGBClassifier, SparkXGBRegressor]],
+        metric: str,
+    ) -> None:
+        n_estimators = 10
+        spark_xgb_model = spark_xgb_estimator(
+            eval_metric=metric,
+            validation_indicator_col="validation_indicator_col",
+            n_estimators=n_estimators,
+        ).fit(clf_and_reg_df_with_validation)
+
+        self.assert_non_empty_training_objective_history(
+            spark_xgb_model, metric, n_estimators
+        )
+        self.assert_non_empty_validation_objective_history(
+            spark_xgb_model, metric, n_estimators
+        )
+
+    @pytest.mark.parametrize("metric", ["ndcg", "map"])
+    def test_xgb_summary_ranker(self, ranker_df: DataFrame, metric: str) -> None:
+        n_estimators = 10
+        xgb_ranker = SparkXGBRanker(
+            qid_col="qid", eval_metric=metric, n_estimators=n_estimators
+        )
+        xgb_ranker_model = xgb_ranker.fit(ranker_df)
+
+        self.assert_non_empty_training_objective_history(
+            xgb_ranker_model, metric, n_estimators
+        )
+        self.assert_empty_validation_objective_history(xgb_ranker_model)
+
+    @pytest.mark.parametrize("metric", ["ndcg", "map"])
+    def test_xgb_summary_ranker_with_validation(
+        self, ranker_df_with_validation: DataFrame, metric: str
+    ) -> None:
+        n_estimators = 10
+        xgb_ranker_model = SparkXGBRanker(
+            qid_col="qid",
+            validation_indicator_col="validation_indicator_col",
+            eval_metric=metric,
+            n_estimators=n_estimators,
+        ).fit(ranker_df_with_validation)
+
+        self.assert_non_empty_training_objective_history(
+            xgb_ranker_model, metric, n_estimators
+        )
+        self.assert_non_empty_validation_objective_history(
+            xgb_ranker_model, metric, n_estimators
+        )
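
Reviewer note (not part of the patches): a minimal usage sketch of the API these
patches add, assuming a live SparkSession named `spark` (not created here). The
toy rows and the `validation_indicator_col` column mirror the test fixtures
above; nothing beyond what the patches introduce is assumed to exist.

    from pyspark.ml.linalg import Vectors
    from xgboost.spark import SparkXGBClassifier

    # Three labeled rows; the boolean column marks the single validation row.
    df = spark.createDataFrame(
        [
            (Vectors.dense([1.0, 2.0, 3.0]), 1, False),
            (Vectors.dense([4.0, 5.0, 6.0]), 1, False),
            (Vectors.dense([9.0, 4.0, 8.0]), 0, True),
        ],
        schema=["features", "label", "validation_indicator_col"],
    )

    model = SparkXGBClassifier(
        eval_metric="logloss",
        validation_indicator_col="validation_indicator_col",
        n_estimators=5,
    ).fit(df)

    # Each history is a dict mapping a metric name to one float per boosting
    # round, as collected from evals_result on the rank-0 worker.
    summary = model.training_summary
    print(summary.train_objective_history["logloss"])       # 5 values
    print(summary.validation_objective_history["logloss"])  # 5 values

When no validation_indicator_col is supplied, validation_objective_history comes
back as an empty dict, as exercised by the non-validation tests above.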