Skip to content

Commit

Permalink
Merge pull request #77 from databricks-industry-solutions/chronos-bug-fix
Browse files Browse the repository at this point in the history

Fixed a bug in the Chronos pipeline.
  • Loading branch information
ryuta-yoshimatsu authored Jan 16, 2025
2 parents 6d4b537 + c0b0581 commit 6e65bf3
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 34 deletions.
1 change: 0 additions & 1 deletion examples/foundation_daily.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,6 @@ def transform_group(df):
"MoiraiLarge",
"MoiraiMoESmall",
"MoiraiMoEBase",
"MoiraiMoELarge",
"TimesFM_1_0_200m",
"TimesFM_2_0_500m",
]
Expand Down
1 change: 0 additions & 1 deletion examples/foundation_monthly.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def transform_group(df):
"MoiraiLarge",
"MoiraiMoESmall",
"MoiraiMoEBase",
"MoiraiMoELarge",
"TimesFM_1_0_200m",
"TimesFM_2_0_500m",
]
Expand Down
1 change: 0 additions & 1 deletion examples/m5-examples/foundation_daily_m5.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
"MoiraiLarge",
"MoiraiMoESmall",
"MoiraiMoEBase",
"MoiraiMoELarge",
"TimesFM_1_0_200m",
"TimesFM_2_0_500m",
]
Expand Down
3 changes: 0 additions & 3 deletions mmf_sa/Forecaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -383,7 +383,6 @@ def backtest_global_model(
spark=self.spark,
# backtest_retrain=self.conf["backtest_retrain"],
))

group_id_dtype = IntegerType() \
if train_df[self.conf["group_id"]].dtype == 'int' else StringType()

Expand All @@ -399,7 +398,6 @@ def backtest_global_model(
]
)
res_sdf = self.spark.createDataFrame(res_pdf, schema)

# Write evaluation results to a delta table
if write:
if self.conf.get("evaluation_output", None):
Expand All @@ -413,7 +411,6 @@ def backtest_global_model(
.write.mode("append")
.saveAsTable(self.conf.get("evaluation_output"))
)

# Compute aggregated metrics
res_df = (
res_sdf.groupby(["metric_name"])
Expand Down
41 changes: 13 additions & 28 deletions mmf_sa/models/chronosforecast/ChronosPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,15 +90,13 @@ def predict(self,
horizon_timestamps_udf(hist_df.ds).alias("ds"),
forecast_udf(hist_df.y).alias("y"))
).toPandas()

forecast_df = forecast_df.reset_index(drop=False).rename(
columns={
"unique_id": self.params.group_id,
"ds": self.params.date_col,
"y": self.params.target,
}
)

# Todo
# forecast_df[self.params.target] = forecast_df[self.params.target].clip(0.01)
return forecast_df, self.model
Expand Down Expand Up @@ -165,19 +163,13 @@ def predict_udf(bulk_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
import numpy as np
import pandas as pd
# Initialize the ChronosPipeline with a pretrained model from the specified repository
from chronos import BaseChronosPipeline, ChronosBoltPipeline
if "bolt" in self.repo:
pipeline = ChronosBoltPipeline.from_pretrained(
self.repo,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
else:
pipeline = BaseChronosPipeline.from_pretrained(
self.repo,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
from chronos import BaseChronosPipeline
pipeline = BaseChronosPipeline.from_pretrained(
self.repo,
device_map='cuda',
torch_dtype=torch.bfloat16,
)

# inference
for bulk in bulk_iterator:
median = []
Expand Down Expand Up @@ -262,19 +254,12 @@ def __init__(self, repository, prediction_length):
self.prediction_length = prediction_length
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# Initialize the ChronosPipeline with a pretrained model from the specified repository
from chronos import BaseChronosPipeline, ChronosBoltPipeline
if "bolt" in self.repository:
self.pipeline = ChronosBoltPipeline.from_pretrained(
self.repository,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
else:
self.pipeline = BaseChronosPipeline.from_pretrained(
self.repository,
device_map=self.device,
torch_dtype=torch.bfloat16,
)
from chronos import BaseChronosPipeline
self.pipeline = BaseChronosPipeline.from_pretrained(
self.repository,
device_map='cuda',
torch_dtype=torch.bfloat16,
)

def predict(self, context, input_data, params=None):
history = [torch.tensor(list(series)) for series in input_data]
Expand Down

0 comments on commit 6e65bf3

Please sign in to comment.