Commit 882cbb5
chronos models integration
ryuta-yoshimatsu committed May 24, 2024
1 parent 5216ba6 commit 882cbb5
Showing 16 changed files with 408 additions and 118 deletions.
70 changes: 48 additions & 22 deletions mmf_sa/Forecaster.py
@@ -292,6 +292,7 @@ def backtest_global_model(
                pd.concat([train_df, val_df]),
                start=train_df[self.conf["date_col"]].max(),
                retrain=self.conf["backtest_retrain"],
+                spark=self.spark,
            ))

        group_id_dtype = IntegerType() \
@@ -507,21 +508,23 @@ def evaluate_global_model(self, model_conf):
        print(f"Champion alias assigned to the new model")

    def evaluate_foundation_model(self, model_conf):
-        model_name = model_conf["name"]
-        model = self.model_registry.get_model(model_name)
-        hist_df, removed = self.prepare_data_for_global_model("evaluating")
-        train_df, val_df = self.split_df_train_val(hist_df)
-        metrics = self.backtest_global_model(
-            model=model,
-            train_df=train_df,
-            val_df=val_df,
-            model_uri="",
-            write=True,
-        )
-        mlflow.set_tag("action", "train")
-        mlflow.set_tag("candidate", "true")
-        mlflow.set_tag("model_name", model.params["name"])
-        print(f"Finished training {model_conf.get('name')}")
+        with mlflow.start_run(experiment_id=self.experiment_id) as run:
+            model_name = model_conf["name"]
+            model = self.model_registry.get_model(model_name)
+            hist_df, removed = self.prepare_data_for_global_model("evaluating")  # Reuse the same as global
+            train_df, val_df = self.split_df_train_val(hist_df)
+            metrics = self.backtest_global_model(  # Reuse the same as global
+                model=model,
+                train_df=train_df,
+                val_df=val_df,
+                model_uri="",
+                write=True,
+            )
+            mlflow.log_metric(self.conf["metric"], metrics)
+            mlflow.set_tag("action", "evaluate")
+            mlflow.set_tag("candidate", "true")
+            mlflow.set_tag("model_name", model.params["name"])
+            print(f"Finished evaluating {model_conf.get('name')}")

    def score_models(self):
        print("Starting run_scoring")
@@ -532,6 +535,8 @@ def score_models(self):
                self.score_global_model(model_conf)
            elif model_conf["model_type"] == "local":
                self.score_local_model(model_conf)
+            elif model_conf["model_type"] == "foundation":
+                self.score_foundation_model(model_conf)
            print(f"Finished scoring with {model_name}")
        print("Finished run_scoring")

@@ -627,13 +632,24 @@ def score_global_model(self, model_conf):
            .saveAsTable(self.conf["scoring_output"])
        )

-    def get_latest_model_version(self, mlflow_client, registered_name):
-        latest_version = 1
-        for mv in mlflow_client.search_model_versions(f"name='{registered_name}'"):
-            version_int = int(mv.version)
-            if version_int > latest_version:
-                latest_version = version_int
-        return latest_version
+    def score_foundation_model(self, model_conf):
+        print(f"Running scoring for {model_conf['name']}...")
+        model_name = model_conf["name"]
+        model = self.model_registry.get_model(model_name)
+        hist_df, removed = self.prepare_data_for_global_model("evaluating")
+        prediction_df, model_pretrained = model.forecast(hist_df, spark=self.spark)
+        sdf = self.spark.createDataFrame(prediction_df).drop('index')
+        (
+            sdf.withColumn(self.conf["group_id"], col(self.conf["group_id"]).cast(StringType()))
+            .withColumn("model", lit(model_conf["name"]))
+            .withColumn("run_id", lit(self.run_id))
+            .withColumn("run_date", lit(self.run_date))
+            .withColumn("use_case", lit(self.conf["use_case_name"]))
+            .withColumn("model_pickle", lit(b""))
+            .withColumn("model_uri", lit(""))
+            .write.mode("append")
+            .saveAsTable(self.conf["scoring_output"])
+        )

    def get_model_for_scoring(self, model_conf):
        mlflow_client = MlflowClient()
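
The rows appended by score_foundation_model can be inspected straight from the scoring output table. A minimal sketch; "catalog.db.scoring_output" is a hypothetical value of conf["scoring_output"], and "unique_id" is an assumed group_id column:

    # Hypothetical table name; columns match those written above.
    sdf = spark.table("catalog.db.scoring_output")
    (
        sdf.filter(sdf.model == "ChronosT5Base")   # one of the new Chronos entries
        .select("unique_id", "run_date", "use_case")
        .show()
    )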
@@ -649,6 +665,7 @@
        else:
            return self.model_registry.get_model(model_conf["name"]), None

+
def flatten_nested_parameters(d):
    out = {}
    for key, val in d.items():
@@ -661,3 +678,12 @@
        else:
            out[key] = val
    return out
+
+
+def get_latest_model_version(self, mlflow_client, registered_name):
+    latest_version = 1
+    for mv in mlflow_client.search_model_versions(f"name='{registered_name}'"):
+        version_int = int(mv.version)
+        if version_int > latest_version:
+            latest_version = version_int
+    return latest_version
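
For reference, the same latest-version lookup can be expressed directly against the registry client; a short sketch with a hypothetical registered model name:

    from mlflow import MlflowClient

    client = MlflowClient()
    # max() over the returned versions mirrors the loop above.
    versions = client.search_model_versions("name='main.default.my_forecaster'")
    latest = max((int(mv.version) for mv in versions), default=1)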
4 changes: 4 additions & 0 deletions mmf_sa/base_forecasting_conf.yaml
@@ -52,6 +52,10 @@ active_models:
  - NeuralForecastAutoNHITS
  - NeuralForecastAutoTiDE
  - NeuralForecastAutoPatchTST
+  - ChronosT5Tiny
+  - ChronosT5Mini
+  - ChronosT5Small
+  - ChronosT5Base
+  - ChronosT5Large

#Here we can override hyperparameters for built-in models
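With these entries in active_models, the Chronos checkpoints run through the same pipeline as the other models. A minimal usage sketch; the run_forecast argument names are assumptions patterned after the project's examples, and the table names are hypothetical:

    from mmf_sa import run_forecast

    run_forecast(
        spark=spark,
        train_data="catalog.db.train",              # hypothetical input table
        scoring_output="catalog.db.scoring_output", # hypothetical output table
        group_id="unique_id",
        date_col="ds",
        target="y",
        freq="D",
        prediction_length=10,
        active_models=["ChronosT5Tiny", "ChronosT5Base"],
    )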
5 changes: 4 additions & 1 deletion mmf_sa/models/__init__.py
@@ -46,7 +46,10 @@ def load_models_conf():
        return conf

    def get_model(
-        self, model_name: str, override_conf: DictConfig = None
+        self,
+        model_name: str,
+        override_conf: DictConfig = None,
+        spark=None,
    ) -> ForecastingRegressor:
        model_conf = self.active_models.get(model_name)
        if override_conf is not None:
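The widened signature lets callers hand an active SparkSession to the registry when a model is instantiated; a one-line sketch (registry construction elided):

    # Sketch: the registry can forward the SparkSession to the model it builds.
    model = model_registry.get_model("ChronosT5Base", spark=spark)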
9 changes: 7 additions & 2 deletions mmf_sa/models/abstract_model.py
@@ -3,8 +3,11 @@
import pandas as pd
import cloudpickle
from typing import Dict, Union
+from transformers import pipeline
from sklearn.base import BaseEstimator, RegressorMixin
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
+import mlflow
+mlflow.set_registry_uri("databricks-uc")


class ForecastingRegressor(BaseEstimator, RegressorMixin):
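
Pointing the MLflow registry at Unity Catalog means registered model names become three-level (catalog.schema.model). A short sketch with a hypothetical name and a placeholder run URI:

    import mlflow

    mlflow.set_registry_uri("databricks-uc")
    # Hypothetical UC model name; "runs:/<run_id>/model" is a placeholder.
    mlflow.register_model("runs:/<run_id>/model", "main.default.my_forecaster")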
@@ -45,6 +48,7 @@ def backtest(
        group_id: Union[str, int] = None,
        stride: int = None,
        retrain: bool = True,
+        spark=None,
    ) -> pd.DataFrame:
        if stride is None:
            stride = int(self.params.get("stride", 7))
@@ -73,7 +77,7 @@
            if retrain:
                self.fit(_df)

-            metrics = self.calculate_metrics(_df, actuals_df, curr_date)
+            metrics = self.calculate_metrics(_df, actuals_df, curr_date, spark)

            if isinstance(metrics, dict):
                evaluation_results = [
@@ -103,10 +107,11 @@
                        "actual",
                        "model_pickle"],
                )
+
        return res_df

    def calculate_metrics(
-        self, hist_df: pd.DataFrame, val_df: pd.DataFrame, curr_date
+        self, hist_df: pd.DataFrame, val_df: pd.DataFrame, curr_date, spark=None
    ) -> Dict[str, Union[str, float, bytes]]:
        pred_df, model_fitted = self.predict(hist_df, val_df)
        smape = mean_absolute_percentage_error(
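Taken together, Forecaster.backtest_global_model now passes its SparkSession into backtest, which forwards it to calculate_metrics, so a model subclass can distribute inference during backtesting. A hypothetical override sketch:

    # Hypothetical subclass: use the forwarded SparkSession when present.
    class SparkAwareModel(ForecastingRegressor):
        def calculate_metrics(self, hist_df, val_df, curr_date, spark=None):
            if spark is not None:
                pass  # e.g. fan per-group inference out as a Spark job first
            return super().calculate_metrics(hist_df, val_df, curr_date, spark)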
