Skip to content

Commit

Permalink
Merge pull request #65 from databricks-industry-solutions/fix-r-models
Browse files Browse the repository at this point in the history
fixed r model metric calculation
  • Loading branch information
ryuta-yoshimatsu authored Jul 8, 2024
2 parents 6201e40 + bdec305 commit e30596e
Show file tree
Hide file tree
Showing 6 changed files with 11 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,3 @@ def forecast(input_data, url=endpoint_url, databricks_token=token):

# Delete the serving endpoint
func_delete_model_serving_endpoint(model_serving_endpoint_name)

# COMMAND ----------


4 changes: 0 additions & 4 deletions examples/foundation_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,3 @@ def transform_group(df):
# COMMAND ----------

display(spark.sql(f"delete from {catalog}.{db}.daily_scoring_output"))

# COMMAND ----------


4 changes: 0 additions & 4 deletions examples/local_univariate_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,3 @@ def transform_group(df):
# COMMAND ----------

display(spark.sql(f"delete from {catalog}.{db}.daily_scoring_output"))

# COMMAND ----------


2 changes: 1 addition & 1 deletion examples/local_univariate_external_regressors_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@

catalog = "mmf" # Name of the catalog we use to manage our assets
db = "rossmann" # Name of the schema we use to manage our assets (e.g. datasets)
volume = "csv" # Name of the volume where you have your rossmann dataset csv sotred
volume = "csv" # Name of the volume where you have your rossmann dataset csv stored
user = spark.sql('select current_user() as user').collect()[0]['user'] # User email address

# COMMAND ----------
Expand Down
3 changes: 2 additions & 1 deletion mmf_sa/Forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,12 @@ def evaluate_local_model(self, model_conf):
evaluate_one_local_model_fn = functools.partial(
Forecaster.evaluate_one_local_model, model=model
)

res_sdf = (
src_df.groupby(self.conf["group_id"])
.applyInPandas(evaluate_one_local_model_fn, schema=output_schema)
)

# Write evaluation result to a delta table
if self.conf.get("evaluation_output", None) is not None:
(
Expand Down
28 changes: 8 additions & 20 deletions mmf_sa/models/abstract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,37 +137,25 @@ def calculate_metrics(
Returns: metrics (Dict[str, Union[str, float, bytes]]): A dictionary specifying the metrics.
"""
pred_df, model_fitted = self.predict(hist_df, val_df)

actual = val_df[self.params["target"]].to_numpy()
forecast = pred_df[self.params["target"]].to_numpy()

if self.params["metric"] == "smape":
smape = MeanAbsolutePercentageError(symmetric=True)
metric_value = smape(
val_df[self.params["target"]],
pred_df[self.params["target"]],
)
metric_value = smape(actual, forecast)
elif self.params["metric"] == "mape":
mape = MeanAbsolutePercentageError(symmetric=False)
metric_value = mape(
val_df[self.params["target"]],
pred_df[self.params["target"]],
)
metric_value = mape(actual, forecast)
elif self.params["metric"] == "mae":
mae = MeanAbsoluteError()
metric_value = mae(
val_df[self.params["target"]],
pred_df[self.params["target"]],
)
metric_value = mae(actual, forecast)
elif self.params["metric"] == "mse":
mse = MeanSquaredError(square_root=False)
metric_value = mse(
val_df[self.params["target"]],
pred_df[self.params["target"]],
)
metric_value = mse(actual, forecast)
elif self.params["metric"] == "rmse":
rmse = MeanSquaredError(square_root=True)
metric_value = rmse(
val_df[self.params["target"]],
pred_df[self.params["target"]],
)
metric_value = rmse(actual, forecast)
else:
raise Exception(f"Metric {self.params['metric']} not supported!")

Expand Down

0 comments on commit e30596e

Please sign in to comment.