diff --git a/examples/foundation-model-examples/chronos/01_chronos_load_inference.py b/examples/foundation-model-examples/chronos/01_chronos_load_inference.py index 71e10d5..31de779 100644 --- a/examples/foundation-model-examples/chronos/01_chronos_load_inference.py +++ b/examples/foundation-model-examples/chronos/01_chronos_load_inference.py @@ -435,7 +435,3 @@ def forecast(input_data, url=endpoint_url, databricks_token=token): # Delete the serving endpoint func_delete_model_serving_endpoint(model_serving_endpoint_name) - -# COMMAND ---------- - - diff --git a/examples/foundation_daily.py b/examples/foundation_daily.py index 3217f10..981658e 100644 --- a/examples/foundation_daily.py +++ b/examples/foundation_daily.py @@ -195,7 +195,3 @@ def transform_group(df): # COMMAND ---------- display(spark.sql(f"delete from {catalog}.{db}.daily_scoring_output")) - -# COMMAND ---------- - - diff --git a/examples/local_univariate_daily.py b/examples/local_univariate_daily.py index 3d4990c..6ed05e5 100644 --- a/examples/local_univariate_daily.py +++ b/examples/local_univariate_daily.py @@ -234,7 +234,3 @@ def transform_group(df): # COMMAND ---------- display(spark.sql(f"delete from {catalog}.{db}.daily_scoring_output")) - -# COMMAND ---------- - - diff --git a/examples/local_univariate_external_regressors_daily.py b/examples/local_univariate_external_regressors_daily.py index 6c18072..1921266 100644 --- a/examples/local_univariate_external_regressors_daily.py +++ b/examples/local_univariate_external_regressors_daily.py @@ -40,7 +40,7 @@ catalog = "mmf" # Name of the catalog we use to manage our assets db = "rossmann" # Name of the schema we use to manage our assets (e.g. datasets) -volume = "csv" # Name of the volume where you have your rossmann dataset csv sotred +volume = "csv" # Name of the volume where you have your rossmann dataset csv stored user = spark.sql('select current_user() as user').collect()[0]['user'] # User email address # COMMAND ---------- diff --git a/mmf_sa/Forecaster.py b/mmf_sa/Forecaster.py index 99c0a7d..77b564a 100644 --- a/mmf_sa/Forecaster.py +++ b/mmf_sa/Forecaster.py @@ -222,11 +222,12 @@ def evaluate_local_model(self, model_conf): evaluate_one_local_model_fn = functools.partial( Forecaster.evaluate_one_local_model, model=model ) + res_sdf = ( src_df.groupby(self.conf["group_id"]) .applyInPandas(evaluate_one_local_model_fn, schema=output_schema) ) - + # Write evaluation result to a delta table if self.conf.get("evaluation_output", None) is not None: ( diff --git a/mmf_sa/models/abstract_model.py b/mmf_sa/models/abstract_model.py index 1b20791..0d74d03 100644 --- a/mmf_sa/models/abstract_model.py +++ b/mmf_sa/models/abstract_model.py @@ -137,37 +137,25 @@ def calculate_metrics( Returns: metrics (Dict[str, Union[str, float, bytes]]): A dictionary specifying the metrics. """ pred_df, model_fitted = self.predict(hist_df, val_df) + + actual = val_df[self.params["target"]].to_numpy() + forecast = pred_df[self.params["target"]].to_numpy() if self.params["metric"] == "smape": smape = MeanAbsolutePercentageError(symmetric=True) - metric_value = smape( - val_df[self.params["target"]], - pred_df[self.params["target"]], - ) + metric_value = smape(actual, forecast) elif self.params["metric"] == "mape": mape = MeanAbsolutePercentageError(symmetric=False) - metric_value = mape( - val_df[self.params["target"]], - pred_df[self.params["target"]], - ) + metric_value = mape(actual, forecast) elif self.params["metric"] == "mae": mae = MeanAbsoluteError() - metric_value = mae( - val_df[self.params["target"]], - pred_df[self.params["target"]], - ) + metric_value = mae(actual, forecast) elif self.params["metric"] == "mse": mse = MeanSquaredError(square_root=False) - metric_value = mse( - val_df[self.params["target"]], - pred_df[self.params["target"]], - ) + metric_value = mse(actual, forecast) elif self.params["metric"] == "rmse": rmse = MeanSquaredError(square_root=True) - metric_value = rmse( - val_df[self.params["target"]], - pred_df[self.params["target"]], - ) + metric_value = rmse(actual, forecast) else: raise Exception(f"Metric {self.params['metric']} not supported!")