Skip to content

Commit

Permalink
Renamed xgboost_training_summary.py to summary.py and _XGBoostTrainingSummary to XGBoostTrainingSummary
Browse files Browse the repository at this point in the history
  • Loading branch information
a.cherkaoui committed Jan 3, 2025
1 parent 25a06b8 commit 984bc8e
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 9 deletions.
10 changes: 4 additions & 6 deletions python-package/xgboost/spark/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@
HasFeaturesCols,
HasQueryIdCol,
)
from .summary import XGBoostTrainingSummary
from .utils import (
CommunicatorContext,
_get_default_params_from_func,
Expand All @@ -100,7 +101,6 @@
serialize_booster,
use_cuda,
)
from .xgboost_training_summary import _XGBoostTrainingSummary

# Put pyspark specific params here, they won't be passed to XGBoost.
# like `validationIndicatorCol`, `base_margin_col`
Expand Down Expand Up @@ -706,7 +706,7 @@ def _pyspark_model_cls(cls) -> Type["_SparkXGBModel"]:
raise NotImplementedError()

def _create_pyspark_model(
self, xgb_model: XGBModel, training_summary: _XGBoostTrainingSummary
self, xgb_model: XGBModel, training_summary: XGBoostTrainingSummary
) -> "_SparkXGBModel":
return self._pyspark_model_cls()(xgb_model, training_summary)

Expand Down Expand Up @@ -1202,9 +1202,7 @@ def _run_job() -> Tuple[str, str, str]:
result_xgb_model = self._convert_to_sklearn_model(
bytearray(booster, "utf-8"), config
)
training_summary = _XGBoostTrainingSummary.from_metrics(
json.loads(evals_result)
)
training_summary = XGBoostTrainingSummary.from_metrics(json.loads(evals_result))
spark_model = self._create_pyspark_model(result_xgb_model, training_summary)
# According to pyspark ML convention, the model uid should be the same
# with estimator uid.
Expand All @@ -1229,7 +1227,7 @@ class _SparkXGBModel(Model, _SparkXGBParams, MLReadable, MLWritable):
def __init__(
self,
xgb_sklearn_model: Optional[XGBModel] = None,
training_summary: Optional[_XGBoostTrainingSummary] = None,
training_summary: Optional[XGBoostTrainingSummary] = None,
) -> None:
super().__init__()
self._xgb_sklearn_model = xgb_sklearn_model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@


@dataclass
class _XGBoostTrainingSummary:
class XGBoostTrainingSummary:
"""
A class that holds the training and validation objective history
of an XGBoost model during its training process.
Expand All @@ -17,7 +17,7 @@ class _XGBoostTrainingSummary:
@staticmethod
def from_metrics(
metrics: Dict[str, Dict[str, List[float]]]
) -> "_XGBoostTrainingSummary":
) -> "XGBoostTrainingSummary":
"""
Create an XGBoostTrainingSummary instance from a nested dictionary of metrics.
Expand All @@ -38,6 +38,6 @@ def from_metrics(
"""
train_objective_history = metrics.get("training", {})
validation_objective_history = metrics.get("validation", {})
return _XGBoostTrainingSummary(
return XGBoostTrainingSummary(
train_objective_history, validation_objective_history
)

0 comments on commit 984bc8e

Please sign in to comment.