Skip to content

Commit

Permalink
fixed model signature logging for global models
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuta-yoshimatsu committed May 28, 2024
1 parent ce9a29b commit 102671d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
21 changes: 15 additions & 6 deletions mmf_sa/Forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import mlflow
from mlflow.exceptions import MlflowException
from mlflow.tracking import MlflowClient
from mlflow.models import infer_signature
from mlflow.models import ModelSignature, infer_signature
from mlflow.types.schema import Schema, ColSpec
from omegaconf import OmegaConf, DictConfig
from omegaconf.basecontainer import BaseContainer
from pyspark.sql import SparkSession, DataFrame
Expand Down Expand Up @@ -240,13 +241,21 @@ def train_global_model(
model: ForecastingRegressor,
):
print(f"Started training {model_conf['name']}")
# Todo fix
model.fit(pd.concat([train_df, val_df]))
# TODO fix
signature = infer_signature(
model_input=train_df,
model_output=train_df,

input_example = train_df[train_df[self.conf['group_id']] == train_df[self.conf['group_id']]\
.unique()[0]].sort_values(by=[self.conf['date_col']])
input_schema = infer_signature(model_input=input_example).inputs
output_schema = Schema(
[
ColSpec("integer", "index"),
ColSpec("string", self.conf['group_id']),
ColSpec("datetime", self.conf['date_col']),
ColSpec("float", self.conf['target']),
]
)
input_example = train_df
signature = ModelSignature(inputs=input_schema, outputs=output_schema)
model_info = mlflow.sklearn.log_model(
model,
"model",
Expand Down
1 change: 0 additions & 1 deletion mmf_sa/models/neuralforecast/NeuralForecastPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def forecast(self, df: pd.DataFrame):
& (df[self.params["date_col"]]
<= np.datetime64(_last_date + self.prediction_length_offset))
]

_dynamic_future = self.prepare_data(_future_df, future=True)
_dynamic_future = None if _dynamic_future.empty else _dynamic_future
_static_df = self.prepare_static_features(_future_df)
Expand Down

0 comments on commit 102671d

Please sign in to comment.