[python-package][PySpark] Expose Training and Validation Metrics #11133

Open
wants to merge 3 commits into master
26 changes: 18 additions & 8 deletions python-package/xgboost/spark/core.py
@@ -82,6 +82,7 @@
HasFeaturesCols,
HasQueryIdCol,
)
from .summary import XGBoostTrainingSummary
from .utils import (
CommunicatorContext,
_get_default_params_from_func,
@@ -704,8 +705,10 @@ def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
"""
raise NotImplementedError()

def _create_pyspark_model(self, xgb_model: XGBModel) -> "_SparkXGBModel":
return self._pyspark_model_cls()(xgb_model)
def _create_pyspark_model(
self, xgb_model: XGBModel, training_summary: XGBoostTrainingSummary
) -> "_SparkXGBModel":
return self._pyspark_model_cls()(xgb_model, training_summary)

def _convert_to_sklearn_model(self, booster: bytearray, config: str) -> XGBModel:
xgb_sklearn_params = self._gen_xgb_params_dict(
@@ -1148,7 +1151,7 @@ def _train_booster(
if dvalid is not None:
dval = [(dtrain, "training"), (dvalid, "validation")]
else:
dval = None
dval = [(dtrain, "training")]
Contributor:
@trivialfis, could you check whether it is OK to enable this by default?

booster = worker_train(
params=booster_params,
dtrain=dtrain,
@@ -1159,6 +1162,7 @@
context.barrier()

if context.partitionId() == 0:
yield pd.DataFrame({"data": [json.dumps(dict(evals_result))]})
config = booster.save_config()
yield pd.DataFrame({"data": [config]})
booster_json = booster.save_raw("json").decode("utf-8")
@@ -1167,7 +1171,7 @@
booster_chunk = booster_json[offset : offset + _MODEL_CHUNK_SIZE]
yield pd.DataFrame({"data": [booster_chunk]})

def _run_job() -> Tuple[str, str]:
def _run_job() -> Tuple[str, str, str]:
rdd = (
dataset.mapInPandas(
_train_booster, # type: ignore
@@ -1179,7 +1183,7 @@ def _run_job() -> Tuple[str, str]:
rdd_with_resource = self._try_stage_level_scheduling(rdd)
ret = rdd_with_resource.collect()
data = [v[0] for v in ret]
return data[0], "".join(data[1:])
return data[0], data[1], "".join(data[2:])

get_logger(_LOG_TAG).info(
"Running xgboost-%s on %s workers with"
@@ -1192,13 +1196,14 @@ def _run_job() -> Tuple[str, str]:
train_call_kwargs_params,
dmatrix_kwargs,
)
(config, booster) = _run_job()
(evals_result, config, booster) = _run_job()
get_logger(_LOG_TAG).info("Finished xgboost training!")

result_xgb_model = self._convert_to_sklearn_model(
bytearray(booster, "utf-8"), config
)
spark_model = self._create_pyspark_model(result_xgb_model)
training_summary = XGBoostTrainingSummary.from_metrics(json.loads(evals_result))
spark_model = self._create_pyspark_model(result_xgb_model, training_summary)
# According to pyspark ML convention, the model uid should be the same
# with estimator uid.
spark_model._resetUid(self.uid)
@@ -1219,9 +1224,14 @@ def read(cls) -> "SparkXGBReader":


class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
def __init__(self, xgb_sklearn_model: Optional[XGBModel] = None) -> None:
def __init__(
self,
xgb_sklearn_model: Optional[XGBModel] = None,
training_summary: Optional[XGBoostTrainingSummary] = None,
) -> None:
super().__init__()
self._xgb_sklearn_model = xgb_sklearn_model
self.training_summary = training_summary

@classmethod
def _xgb_cls(cls) -> Type[XGBModel]:
43 changes: 43 additions & 0 deletions python-package/xgboost/spark/summary.py
@@ -0,0 +1,43 @@
"""Xgboost training summary integration submodule."""

from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class XGBoostTrainingSummary:
"""
A class that holds the training and validation objective history
of an XGBoost model during its training process.
"""

train_objective_history: Dict[str, List[float]] = field(default_factory=dict)
validation_objective_history: Dict[str, List[float]] = field(default_factory=dict)

@staticmethod
def from_metrics(
metrics: Dict[str, Dict[str, List[float]]]
) -> "XGBoostTrainingSummary":
"""
Create an XGBoostTrainingSummary instance from a nested dictionary of metrics.

Parameters
----------
metrics : dict of str to dict of str to list of float
A dictionary containing training and validation metrics.
Example format:
{
"training": {"logloss": [0.1, 0.08]},
"validation": {"logloss": [0.12, 0.1]}
}

Returns
-------
A new instance of XGBoostTrainingSummary.

"""
train_objective_history = metrics.get("training", {})
validation_objective_history = metrics.get("validation", {})
return XGBoostTrainingSummary(
train_objective_history, validation_objective_history
)
233 changes: 233 additions & 0 deletions tests/test_distributed/test_with_spark/test_xgboost_summary.py
@@ -0,0 +1,233 @@
import logging
Contributor:
I'm wondering if we could put the tests in this file into the existing test_spark_local.py and reuse the existing test data?

Contributor Author:
Yes, I can move them there without much effort; let me know if you'd like me to proceed with that. However, in my humble opinion, it's better to keep them in this separate file, and here's the rationale: test_spark_local.py already exceeds 1,800 lines of code, which makes it increasingly difficult to read, maintain, and navigate. As new features are added to PySpark XGBoost, this file will only continue to grow, compounding the problem. I think refactoring the tests to organize them by key feature, rather than bundling everything under the TestPySparkLocal class, would be a better long-term approach.

If we decide to keep the tests in this file, I can either leave the examples here as they are or, as you suggested, import them from test_spark_local.py for better modularity and data reuse. Another option is to store all shared data in a separate file, allowing both test_spark_local.py and test_xgboost_summary.py to import what they need from it.

Let me know what you think; I have no strong opinion on this.

Contributor:
Yeah, that's a good point. Originally, I wanted to separate the tests per estimator (XGBoostClassifier/Regressor/Ranker) rather than per feature, so the same dataset can be shared across different features.

from typing import Union

import pytest
from pyspark.ml.linalg import Vectors
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.functions import lit

from xgboost import testing as tm
from xgboost.spark import (
SparkXGBClassifier,
SparkXGBClassifierModel,
SparkXGBRanker,
SparkXGBRankerModel,
SparkXGBRegressor,
SparkXGBRegressorModel,
)

from .test_spark_local import spark as spark_local

logging.getLogger("py4j").setLevel(logging.INFO)
Contributor:
Is this for debugging?

Contributor Author:
Yes, and since it was also set in test_spark_local.py, I kept it. Do you prefer that we remove it?


pytestmark = [tm.timeout(60), pytest.mark.skipif(**tm.no_spark())]


@pytest.fixture
def clf_and_reg_df(spark_local: SparkSession) -> DataFrame:
"""
Fixture to create a DataFrame with example data.
"""
data = [
(Vectors.dense([1.0, 2.0, 3.0]), 1),
(Vectors.dense([4.0, 5.0, 6.0]), 1),
(Vectors.dense([9.0, 4.0, 8.0]), 0),
(Vectors.dense([6.0, 2.0, 2.0]), 1),
(Vectors.dense([5.0, 4.0, 3.0]), 0),
]
columns = ["features", "label"]
return spark_local.createDataFrame(data, schema=columns)


@pytest.fixture
def clf_and_reg_df_with_validation(clf_and_reg_df: DataFrame) -> DataFrame:
"""
Fixture to split the example DataFrame into training and validation sets
and mark each row with a validation indicator column.
"""
# split data into training and validation sets
train_df, validation_df = clf_and_reg_df.randomSplit([0.8, 0.2], seed=42)

# Add a column to indicate validation rows
train_df = train_df.withColumn("validation_indicator_col", lit(False))
validation_df = validation_df.withColumn("validation_indicator_col", lit(True))
return train_df.union(validation_df)


@pytest.fixture
def ranker_df(spark_local: SparkSession) -> DataFrame:
"""
Fixture to create a DataFrame with sample data for ranking tasks.
"""
data = [
(Vectors.dense([1.0, 2.0, 3.0]), 0, 0),
(Vectors.dense([4.0, 5.0, 6.0]), 1, 0),
(Vectors.dense([9.0, 4.0, 8.0]), 0, 0),
(Vectors.dense([6.0, 2.0, 2.0]), 1, 0),
(Vectors.dense([5.0, 4.0, 3.0]), 0, 0),
]
columns = ["features", "label", "qid"]
return spark_local.createDataFrame(data, schema=columns)


@pytest.fixture
def ranker_df_with_validation(ranker_df: DataFrame) -> DataFrame:
"""
Fixture to split the ranking DataFrame into training and validation sets,
add validation indicator, and merge them back into a single DataFrame.
"""
# Split the data into training and validation sets (80-20 split)
train_df, validation_df = ranker_df.randomSplit([0.8, 0.2], seed=42)

# Add a column to indicate whether the row is from the validation set
train_df = train_df.withColumn("validation_indicator_col", lit(False))
validation_df = validation_df.withColumn("validation_indicator_col", lit(True))

# Union the training and validation DataFrames
return train_df.union(validation_df)


class TestXGBoostTrainingSummary:
@staticmethod
def assert_empty_validation_objective_history(
xgb_model: Union[
SparkXGBClassifierModel, SparkXGBRankerModel, SparkXGBRegressorModel
]
) -> None:
assert hasattr(xgb_model.training_summary, "validation_objective_history")
assert isinstance(xgb_model.training_summary.validation_objective_history, dict)
assert not xgb_model.training_summary.validation_objective_history

@staticmethod
def assert_non_empty_training_objective_history(
Contributor:
I'm wondering if we could get the evals_result from xgboost itself and the training summary from xgboost-pyspark on the same dataset, and then check whether they are equal? Some tests in test_spark_local.py do the same kind of comparison.

Contributor Author (ayoub317), Jan 6, 2025:
Yes, absolutely. Thank you for pointing this out! I tested this on a simple DataFrame locally, and the results matched perfectly. We should definitely add such tests; I'll take care of that!
(A sketch of such a comparison test is given after the end of this file's diff.)

xgb_model: Union[
SparkXGBClassifierModel, SparkXGBRankerModel, SparkXGBRegressorModel
],
metric: str,
n_estimators: int,
) -> None:
assert hasattr(xgb_model.training_summary, "train_objective_history")
assert isinstance(xgb_model.training_summary.train_objective_history, dict)

assert metric in xgb_model.training_summary.train_objective_history
assert (
len(xgb_model.training_summary.train_objective_history[metric])
== n_estimators
)

for (
training_metric,
loss_evolution,
) in xgb_model.training_summary.train_objective_history.items():
assert isinstance(training_metric, str)
assert len(loss_evolution) == n_estimators
assert all(isinstance(value, float) for value in loss_evolution)

@staticmethod
def assert_non_empty_validation_objective_history(
xgb_model: Union[
SparkXGBClassifierModel, SparkXGBRankerModel, SparkXGBRegressorModel
],
metric: str,
n_estimators: int,
) -> None:
assert hasattr(xgb_model.training_summary, "validation_objective_history")
assert isinstance(xgb_model.training_summary.validation_objective_history, dict)

assert metric in xgb_model.training_summary.validation_objective_history
assert (
len(xgb_model.training_summary.validation_objective_history[metric])
== n_estimators
)

for (
validation_metric,
loss_evolution,
) in xgb_model.training_summary.validation_objective_history.items():
assert isinstance(validation_metric, str)
assert len(loss_evolution) == n_estimators
assert all(isinstance(value, float) for value in loss_evolution)

@pytest.mark.parametrize(
"spark_xgb_estimator, metric",
[
(SparkXGBClassifier, "logloss"),
(SparkXGBClassifier, "error"),
(SparkXGBRegressor, "rmse"),
(SparkXGBRegressor, "mae"),
],
)
def test_xgb_summary_classification_regression(
self,
clf_and_reg_df: DataFrame,
spark_xgb_estimator: Union[SparkXGBClassifier, SparkXGBRegressor],
metric: str,
) -> None:
n_estimators = 10
spark_xgb_model = spark_xgb_estimator(
eval_metric=metric, n_estimators=n_estimators
).fit(clf_and_reg_df)
self.assert_non_empty_training_objective_history(
spark_xgb_model, metric, n_estimators
)
self.assert_empty_validation_objective_history(spark_xgb_model)

@pytest.mark.parametrize(
"spark_xgb_estimator, metric",
[
(SparkXGBClassifier, "logloss"),
(SparkXGBClassifier, "error"),
(SparkXGBRegressor, "rmse"),
(SparkXGBRegressor, "mae"),
],
)
def test_xgb_summary_classification_regression_with_validation(
self,
clf_and_reg_df_with_validation: DataFrame,
spark_xgb_estimator: Union[SparkXGBClassifier, SparkXGBRegressor],
metric: str,
) -> None:
n_estimators = 10
spark_xgb_model = spark_xgb_estimator(
eval_metric=metric,
validation_indicator_col="validation_indicator_col",
n_estimators=n_estimators,
).fit(clf_and_reg_df_with_validation)

self.assert_non_empty_training_objective_history(
spark_xgb_model, metric, n_estimators
)
self.assert_non_empty_validation_objective_history(
spark_xgb_model, metric, n_estimators
)

@pytest.mark.parametrize("metric", ["ndcg", "map"])
def test_xgb_summary_ranker(self, ranker_df: DataFrame, metric: str) -> None:
n_estimators = 10
xgb_ranker = SparkXGBRanker(
qid_col="qid", eval_metric=metric, n_estimators=n_estimators
)
xgb_ranker_model = xgb_ranker.fit(ranker_df)

self.assert_non_empty_training_objective_history(
xgb_ranker_model, metric, n_estimators
)
self.assert_empty_validation_objective_history(xgb_ranker_model)

@pytest.mark.parametrize("metric", ["ndcg", "map"])
def test_xgb_summary_ranker_with_validation(
self, ranker_df_with_validation: DataFrame, metric: str
) -> None:
n_estimators = 10
xgb_ranker_model = SparkXGBRanker(
qid_col="qid",
validation_indicator_col="validation_indicator_col",
eval_metric=metric,
n_estimators=n_estimators,
).fit(ranker_df_with_validation)

self.assert_non_empty_training_objective_history(
xgb_ranker_model, metric, n_estimators
)
self.assert_non_empty_validation_objective_history(
xgb_ranker_model, metric, n_estimators
)