From 102671da3c10ac5062c356dd0279083107927357 Mon Sep 17 00:00:00 2001 From: Ryuta Yoshimatsu Date: Tue, 28 May 2024 09:00:34 +0200 Subject: [PATCH] fixed model signature logging for global models --- mmf_sa/Forecaster.py | 21 +++++++++++++------ .../neuralforecast/NeuralForecastPipeline.py | 1 - 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/mmf_sa/Forecaster.py b/mmf_sa/Forecaster.py index 7c4d52e..5a937cc 100644 --- a/mmf_sa/Forecaster.py +++ b/mmf_sa/Forecaster.py @@ -12,7 +12,8 @@ import mlflow from mlflow.exceptions import MlflowException from mlflow.tracking import MlflowClient -from mlflow.models import infer_signature +from mlflow.models import ModelSignature, infer_signature +from mlflow.types.schema import Schema, ColSpec from omegaconf import OmegaConf, DictConfig from omegaconf.basecontainer import BaseContainer from pyspark.sql import SparkSession, DataFrame @@ -240,13 +241,21 @@ def train_global_model( model: ForecastingRegressor, ): print(f"Started training {model_conf['name']}") + # Todo fix model.fit(pd.concat([train_df, val_df])) - # TODO fix - signature = infer_signature( - model_input=train_df, - model_output=train_df, + + input_example = train_df[train_df[self.conf['group_id']] == train_df[self.conf['group_id']]\ + .unique()[0]].sort_values(by=[self.conf['date_col']]) + input_schema = infer_signature(model_input=input_example).inputs + output_schema = Schema( + [ + ColSpec("integer", "index"), + ColSpec("string", self.conf['group_id']), + ColSpec("datetime", self.conf['date_col']), + ColSpec("float", self.conf['target']), + ] ) - input_example = train_df + signature = ModelSignature(inputs=input_schema, outputs=output_schema) model_info = mlflow.sklearn.log_model( model, "model", diff --git a/mmf_sa/models/neuralforecast/NeuralForecastPipeline.py b/mmf_sa/models/neuralforecast/NeuralForecastPipeline.py index b81df36..4933f58 100644 --- a/mmf_sa/models/neuralforecast/NeuralForecastPipeline.py +++ b/mmf_sa/models/neuralforecast/NeuralForecastPipeline.py @@ -135,7 +135,6 @@ def forecast(self, df: pd.DataFrame): & (df[self.params["date_col"]] <= np.datetime64(_last_date + self.prediction_length_offset)) ] - _dynamic_future = self.prepare_data(_future_df, future=True) _dynamic_future = None if _dynamic_future.empty else _dynamic_future _static_df = self.prepare_static_features(_future_df)