Skip to content

Commit

Permalink
model logging and registry for foundation models
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuta-yoshimatsu committed Jun 3, 2024
1 parent 1ce9e0e commit d087995
Show file tree
Hide file tree
Showing 14 changed files with 40 additions and 131 deletions.
45 changes: 0 additions & 45 deletions .github/workflows/integration-test-gcp-pr.yml

This file was deleted.

49 changes: 0 additions & 49 deletions .github/workflows/integration-test-gcp-push.yml

This file was deleted.

25 changes: 13 additions & 12 deletions mmf_sa/Forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def resolve_source(self, key: str) -> DataFrame:
else:
return self.spark.read.table(self.conf[key])

def prepare_data_for_global_model(self, mode: str):
def prepare_data_for_global_model(self, mode: str = None):
src_df = self.resolve_source("train_data")
src_df, removed = DataQualityChecks(src_df, self.conf, self.spark).run()
if (mode == "scoring") \
Expand Down Expand Up @@ -284,7 +284,7 @@ def evaluate_global_model(self, model_conf):
model=model,
train_df=train_df,
val_df=val_df,
model_uri=model_info.model_uri, # This model_uri is from the final model
model_uri=model_info.model_uri, # This model_uri is from the final model
write=True,
)
print(f"Finished training {model_conf.get('name')}")
Expand Down Expand Up @@ -360,17 +360,15 @@ def evaluate_foundation_model(self, model_conf):
)
hist_df, removed = self.prepare_data_for_global_model("evaluating") # Reuse the same data preparation as the global models
train_df, val_df = self.split_df_train_val(hist_df)
model_uri = f"runs:/{run.info.run_id}/model"
metrics = self.backtest_global_model( # Reuse the same backtesting routine as the global models
model=model,
train_df=train_df,
val_df=val_df,
model_uri="",
model_uri=model_uri,
write=True,
)

mlflow.log_metric(self.conf["metric"], metrics)
mlflow.set_tag("action", "evaluate")
mlflow.set_tag("candidate", "true")
mlflow.set_tag("model_name", model.params["name"])
mlflow.set_tag("run_id", self.run_id)
mlflow.log_params(model.get_params())
Expand Down Expand Up @@ -485,8 +483,9 @@ def score_global_model(self, model_conf):
def score_foundation_model(self, model_conf):
print(f"Running scoring for {model_conf['name']}...")
model_name = model_conf["name"]
_, model_uri = self.get_model_for_scoring(model_conf)
model = self.model_registry.get_model(model_name)
hist_df, removed = self.prepare_data_for_global_model("evaluating")
hist_df, removed = self.prepare_data_for_global_model()
prediction_df, model_pretrained = model.forecast(hist_df, spark=self.spark)
sdf = self.spark.createDataFrame(prediction_df).drop('index')
(
Expand All @@ -496,20 +495,22 @@ def score_foundation_model(self, model_conf):
.withColumn("run_date", lit(self.run_date))
.withColumn("use_case", lit(self.conf["use_case_name"]))
.withColumn("model_pickle", lit(b""))
.withColumn("model_uri", lit(""))
.withColumn("model_uri", lit(model_uri))
.write.mode("append")
.saveAsTable(self.conf["scoring_output"])
)

def get_model_for_scoring(self, model_conf):
client = MlflowClient()
registered_name = f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}"
model_info = self.get_latest_model_info(client, registered_name)
model_version = model_info.version
model_uri = f"runs:/{model_info.run_id}/model"
if model_conf.get("model_type", None) == "global":
registered_name = f"{self.conf['model_output']}.{model_conf['name']}_{self.conf['use_case_name']}"
model_info = Forecaster.get_latest_model_info(client, registered_name)
model_version = model_info.version
model_uri = f"runs:/{model_info.run_id}/model"
model = mlflow.sklearn.load_model(f"models:/{registered_name}/{model_version}")
return model, model_uri
elif model_conf.get("model_type", None) == "foundation":
return None, model_uri
else:
return self.model_registry.get_model(model_conf["name"]), None

Expand Down
2 changes: 1 addition & 1 deletion mmf_sa/models/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def predict(self, x, y=None):
pass

@abstractmethod
def forecast(self, x):
def forecast(self, x, spark=None):
# TODO: Shouldn't x be optional if we have a trainable model and provide a prediction length?
pass

Expand Down
1 change: 1 addition & 0 deletions mmf_sa/models/chronosforecast/ChronosPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def register(self, registered_model_name: str):
pip_requirements=[
"git+https://github.com/amazon-science/chronos-forecasting.git",
"git+https://github.com/databricks-industry-solutions/many-model-forecasting.git",
"pyspark==3.5.0",
],
)

Expand Down
6 changes: 3 additions & 3 deletions notebooks/demo_foundation_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,12 +149,12 @@ def transform_group(df):

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_evaluation_output
# MAGIC %sql delete from solacc_uc.mmf.daily_evaluation_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_scoring_output
# MAGIC %sql delete from solacc_uc.mmf.daily_scoring_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_ensemble_output
# MAGIC %sql delete from solacc_uc.mmf.daily_ensemble_output
6 changes: 3 additions & 3 deletions notebooks/demo_foundation_monthly.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,12 @@ def transform_group(df):

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_evaluation_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_evaluation_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_scoring_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_scoring_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_ensemble_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_ensemble_output
6 changes: 3 additions & 3 deletions notebooks/demo_global_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,12 +150,12 @@ def transform_group(df):

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_evaluation_output
# MAGIC %sql delete from solacc_uc.mmf.daily_evaluation_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_scoring_output
# MAGIC %sql delete from solacc_uc.mmf.daily_scoring_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_ensemble_output
# MAGIC %sql delete from solacc_uc.mmf.daily_ensemble_output
6 changes: 3 additions & 3 deletions notebooks/demo_global_external_regressors_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,12 @@

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.rossmann_daily_evaluation_output
# MAGIC %sql delete from solacc_uc.mmf.rossmann_daily_evaluation_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.rossmann_daily_scoring_output
# MAGIC %sql delete from solacc_uc.mmf.rossmann_daily_scoring_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.rossmann_daily_ensemble_output
# MAGIC %sql delete from solacc_uc.mmf.rossmann_daily_ensemble_output
6 changes: 3 additions & 3 deletions notebooks/demo_global_monthly.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,12 +160,12 @@ def transform_group(df):

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_evaluation_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_evaluation_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_scoring_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_scoring_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_ensemble_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_ensemble_output
6 changes: 3 additions & 3 deletions notebooks/demo_local_univariate_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,12 +180,12 @@ def transform_group(df):

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_evaluation_output
# MAGIC %sql delete from solacc_uc.mmf.daily_evaluation_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_scoring_output
# MAGIC %sql delete from solacc_uc.mmf.daily_scoring_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.daily_ensemble_output
# MAGIC %sql delete from solacc_uc.mmf.daily_ensemble_output
6 changes: 3 additions & 3 deletions notebooks/demo_local_univariate_external_regressors_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,12 +131,12 @@

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.rossmann_daily_evaluation_output
# MAGIC %sql delete from solacc_uc.mmf.rossmann_daily_evaluation_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.rossmann_daily_scoring_output
# MAGIC %sql delete from solacc_uc.mmf.rossmann_daily_scoring_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.rossmann_daily_ensemble_output
# MAGIC %sql delete from solacc_uc.mmf.rossmann_daily_ensemble_output
6 changes: 3 additions & 3 deletions notebooks/demo_local_univariate_monthly.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,12 @@ def transform_group(df):

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_evaluation_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_evaluation_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_scoring_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_scoring_output

# COMMAND ----------

# MAGIC #%sql delete from solacc_uc.mmf.monthly_ensemble_output
# MAGIC %sql delete from solacc_uc.mmf.monthly_ensemble_output
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
name = "mmf_sa"
version = "0.0.1"
dependencies = [
#"rpy2==3.5.16", # causes an issue when deploying the model to Model Serving
"kaleido==0.2.1",
"Jinja2",
"omegaconf==2.3.0",
Expand Down

0 comments on commit d087995

Please sign in to comment.