Merge pull request #36 from databricks-industry-solutions/integrate-moment

moment (foundation model) integration
ryuta-yoshimatsu authored May 26, 2024
2 parents fef34ce + b251794 commit 2fddf6c
Showing 6 changed files with 179 additions and 6 deletions.
1 change: 1 addition & 0 deletions mmf_sa/base_forecasting_conf.yaml
@@ -57,6 +57,7 @@ active_models:
- ChronosT5Small
- ChronosT5Base
- ChronosT5Large
- Moment1Large

#Here we can override hyperparameters for built-in models
models:
18 changes: 13 additions & 5 deletions mmf_sa/models/chronosforecast/ChronosPipeline.py
@@ -1,15 +1,14 @@
from abc import ABC
import subprocess
import sys
import pandas as pd
import numpy as np
import torch
from chronos import ChronosPipeline
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
from typing import Iterator
from pyspark.sql.functions import collect_list, pandas_udf
from pyspark.sql import DataFrame
import mlflow
from mmf_sa.models.abstract_model import ForecastingRegressor
mlflow.set_registry_uri("databricks-uc")


class ChronosForecaster(ForecastingRegressor):
@@ -18,6 +17,10 @@ def __init__(self, params):
self.params = params
self.device = None
self.model = None
self.install("git+https://github.com/amazon-science/chronos-forecasting.git")

def install(self, package: str):
subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])
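Note: the same runtime pip-install pattern is used by the new MomentPipeline.py below, and it replaces the static chronos git dependency that this commit removes from requirements.txt at the end of the diff.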

def create_horizon_timestamps_udf(self):
@pandas_udf('array<timestamp>')
@@ -57,7 +60,7 @@ def predict(self,

horizon_timestamps_udf = self.create_horizon_timestamps_udf()

-# Todo figure out the distribution
+# Todo figure out the distribution strategy
forecast_df = (
hist_df.repartition(4)
.select(
@@ -74,6 +77,7 @@ }
}
)

# Todo
#forecast_df[self.params.target] = forecast_df[self.params.target].clip(0.01)

return forecast_df, self.model
@@ -117,6 +121,7 @@ def __init__(self, params):
super().__init__(params)
self.params = params
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from chronos import ChronosPipeline
self.model = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-tiny",
device_map=self.device, # use "cuda" for GPU and "cpu" for CPU inference
@@ -129,6 +134,7 @@ def __init__(self, params):
super().__init__(params)
self.params = params
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from chronos import ChronosPipeline
self.model = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-mini",
device_map=self.device, # use "cuda" for GPU and "cpu" for CPU inference
@@ -141,6 +147,7 @@ def __init__(self, params):
super().__init__(params)
self.params = params
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from chronos import ChronosPipeline
self.model = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-small",
device_map=self.device, # use "cuda" for GPU and "cpu" for CPU inference
@@ -153,6 +160,7 @@ def __init__(self, params):
super().__init__(params)
self.params = params
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from chronos import ChronosPipeline
self.model = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-base",
device_map=self.device, # use "cuda" for GPU and "cpu" for CPU inference
@@ -165,6 +173,7 @@ def __init__(self, params):
super().__init__(params)
self.params = params
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
from chronos import ChronosPipeline
self.model = ChronosPipeline.from_pretrained(
"amazon/chronos-t5-large",
device_map=self.device, # use "cuda" for GPU and "cpu" for CPU inference
@@ -173,7 +182,6 @@ def __init__(self, params):


def create_predict_udf(prediction_length: int, num_samples: int):

@pandas_udf('array<double>')
def predict_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
# initialization step
8 changes: 8 additions & 0 deletions mmf_sa/models/models_conf.yaml
@@ -407,3 +407,11 @@ models:
model_type: foundation
trainable: false
num_samples: 20

Moment1Large:
module: mmf_sa.models.momentforecast.MomentPipeline
model_class: Moment1Large
framework: Moment
model_type: foundation
trainable: false
num_samples: 20
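With the entry above registered and Moment1Large listed under active_models in base_forecasting_conf.yaml, the model becomes selectable like any other foundation model. A minimal usage sketch, assuming mmf_sa exposes a run_forecast entry point; the parameter names below are illustrative and not confirmed by this diff:

from mmf_sa import run_forecast

# Hypothetical invocation: parameter names are assumptions, not taken from this diff.
run_forecast(
    spark=spark,                      # an active SparkSession
    train_data="train_table",         # long-format data: group id, date, target
    group_id="unique_id",
    date_col="ds",
    target="y",
    freq="D",
    prediction_length=10,
    active_models=["ChronosT5Large", "Moment1Large"],
)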
157 changes: 157 additions & 0 deletions mmf_sa/models/momentforecast/MomentPipeline.py
@@ -0,0 +1,157 @@
from abc import ABC
import subprocess
import sys
import pandas as pd
import numpy as np
import torch
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error
from typing import Iterator
from pyspark.sql.functions import collect_list, pandas_udf
from pyspark.sql import DataFrame
from mmf_sa.models.abstract_model import ForecastingRegressor


class MomentForecaster(ForecastingRegressor):
def __init__(self, params):
super().__init__(params)
self.params = params
self.device = None
self.model = None
self.install("git+https://github.com/moment-timeseries-foundation-model/moment.git")

def install(self, package: str):
subprocess.check_call([sys.executable, "-m", "pip", "install", package, "--quiet"])

def create_horizon_timestamps_udf(self):
@pandas_udf('array<timestamp>')
def horizon_timestamps_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
batch_horizon_timestamps = []
for batch in batch_iterator:
for series in batch:
last = series.max()
horizon_timestamps = []
for i in range(self.params["prediction_length"]):
last = last + self.one_ts_offset
horizon_timestamps.append(last)
batch_horizon_timestamps.append(np.array(horizon_timestamps))
yield pd.Series(batch_horizon_timestamps)
return horizon_timestamps_udf
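create_horizon_timestamps_udf generates, for each series, the prediction_length timestamps that follow the last observed one, stepping by one_ts_offset (derived from the configured frequency in the abstract model). A standalone sketch of the same loop for a single series, assuming a daily offset:

import numpy as np
import pandas as pd

def horizon_timestamps(last_ts: pd.Timestamp, horizon: int,
                       offset: pd.DateOffset = pd.DateOffset(days=1)) -> np.ndarray:
    # Step forward from the last observed timestamp, one offset per horizon step
    out, last = [], last_ts
    for _ in range(horizon):
        last = last + offset
        out.append(last)
    return np.array(out)

print(horizon_timestamps(pd.Timestamp("2024-01-31"), 3))
# three daily timestamps: 2024-02-01, 2024-02-02, 2024-02-03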

def prepare_data(self, df: pd.DataFrame, future: bool = False, spark=None) -> DataFrame:
df = spark.createDataFrame(df)
df = (
df.groupBy(self.params.group_id)
.agg(
collect_list(self.params.date_col).alias('ds'),
collect_list(self.params.target).alias('y'),
))
return df
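prepare_data reshapes a long-format pandas DataFrame into one Spark row per series, collecting the timestamps and target values into array columns. A minimal sketch of that reshape, using illustrative column names:

import pandas as pd
from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_list

spark = SparkSession.builder.getOrCreate()

pdf = pd.DataFrame({
    "unique_id": ["A", "A", "B", "B"],
    "ds": pd.to_datetime(["2024-01-01", "2024-01-02"] * 2),
    "y": [1.0, 2.0, 3.0, 4.0],
})

# One row per series: ds and y become arrays, mirroring prepare_data above
nested = (
    spark.createDataFrame(pdf)
    .groupBy("unique_id")
    .agg(collect_list("ds").alias("ds"), collect_list("y").alias("y"))
)
nested.show(truncate=False)  # 2 rows, each holding a full series as arrays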

def predict(self,
hist_df: pd.DataFrame,
val_df: pd.DataFrame = None,
curr_date=None,
spark=None):
hist_df = self.prepare_data(hist_df, spark=spark)
forecast_udf = self.create_predict_udf()
horizon_timestamps_udf = self.create_horizon_timestamps_udf()
# Todo figure out the distribution strategy
forecast_df = (
hist_df.repartition(4)
.select(
hist_df.unique_id,
horizon_timestamps_udf(hist_df.ds).alias("ds"),
forecast_udf(hist_df.y).alias("y"))
).toPandas()

forecast_df = forecast_df.reset_index(drop=False).rename(
columns={
"unique_id": self.params.group_id,
"ds": self.params.date_col,
"y": self.params.target,
}
)

# Todo
#forecast_df[self.params.target] = forecast_df[self.params.target].clip(0.01)

return forecast_df, self.model

def forecast(self, df: pd.DataFrame, spark=None):
return self.predict(df, spark=spark)

def calculate_metrics(
self, hist_df: pd.DataFrame, val_df: pd.DataFrame, curr_date, spark=None
) -> list:
pred_df, model_pretrained = self.predict(hist_df, val_df, curr_date, spark)
keys = pred_df[self.params["group_id"]].unique()
metrics = []
if self.params["metric"] == "smape":
metric_name = "smape"
else:
raise Exception(f"Metric {self.params['metric']} not supported!")
for key in keys:
actual = val_df[val_df[self.params["group_id"]] == key][self.params["target"]].to_numpy()
forecast = pred_df[pred_df[self.params["group_id"]] == key][self.params["target"]].to_numpy()[0]
try:
if metric_name == "smape":
metric_value = mean_absolute_percentage_error(actual, forecast, symmetric=True)
metrics.extend(
[(
key,
curr_date,
metric_name,
metric_value,
actual,
forecast,
b'',
)])
except Exception:
pass  # skip series for which the metric cannot be computed
return metrics
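For reference, the symmetric MAPE used here averages |actual − forecast| / ((|actual| + |forecast|) / 2) across the horizon. A quick numeric check with sktime:

import numpy as np
from sktime.performance_metrics.forecasting import mean_absolute_percentage_error

actual = np.array([100.0, 200.0])
forecast = np.array([110.0, 180.0])

# (10/105 + 20/190) / 2 ≈ 0.10
print(mean_absolute_percentage_error(actual, forecast, symmetric=True))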

def create_predict_udf(self):
@pandas_udf('array<double>')
def predict_udf(batch_iterator: Iterator[pd.Series]) -> Iterator[pd.Series]:
import torch
import pandas as pd
for batch in batch_iterator:
batch_forecast = []
for series in batch:
# takes in tensor of shape [batchsize, n_channels, context_length]
context = list(series)
if len(context) < 512:
input_mask = [1] * len(context) + [0] * (512 - len(context))
context = context + [0] * (512 - len(context))
else:
input_mask = [1] * 512
context = context[-512:]
input_mask = torch.reshape(torch.tensor(input_mask), (1, 512))
context = torch.reshape(torch.tensor(context), (1, 1, 512)).to(dtype=torch.float32)
output = self.model(context, input_mask=input_mask)
forecast = output.forecast.squeeze().tolist()
batch_forecast.append(forecast)
yield pd.Series(batch_forecast)
return predict_udf
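The UDF pads short series (and truncates long ones) to MOMENT's fixed 512-step context window, with a mask marking which positions hold real observations. A standalone sketch of that step:

import torch

SEQ_LEN = 512
series = [0.1, 0.3, 0.2]  # toy history, shorter than the 512-step window

pad = SEQ_LEN - len(series)
input_mask = torch.tensor([1] * len(series) + [0] * pad).reshape(1, SEQ_LEN)
context = torch.tensor(series + [0.0] * pad, dtype=torch.float32).reshape(1, 1, SEQ_LEN)

# context is [batch_size, n_channels, context_length]; the mask flags the
# 3 observed steps with 1 and the right-padding with 0
print(context.shape, input_mask.shape)  # torch.Size([1, 1, 512]) torch.Size([1, 512])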


class Moment1Large(MomentForecaster):
def __init__(self, params):
super().__init__(params)
from momentfm import MOMENTPipeline
self.params = params
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.model = MOMENTPipeline.from_pretrained(
"AutonLab/MOMENT-1-large",
model_kwargs={
'task_name': 'forecasting',
'forecast_horizon': self.params["prediction_length"],
'head_dropout': 0.1,
'weight_decay': 0,
'freeze_encoder': True, # Freeze the patch embedding layer
'freeze_embedder': True, # Freeze the transformer encoder
'freeze_head': False, # The linear forecasting head must be trained
},
)
self.model.init()
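Moment1Large loads the pretrained pipeline in forecasting mode with the encoder and embedder frozen; only the linear head would need training, and here the model is used zero-shot. A minimal sketch of a direct forward pass, mirroring the call in predict_udf above (the horizon of 10 is illustrative):

import torch
from momentfm import MOMENTPipeline

model = MOMENTPipeline.from_pretrained(
    "AutonLab/MOMENT-1-large",
    model_kwargs={"task_name": "forecasting", "forecast_horizon": 10},
)
model.init()

context = torch.randn(1, 1, 512)  # [batch_size, n_channels, context_length]
input_mask = torch.ones(1, 512)   # all 512 positions are real observations
output = model(context, input_mask=input_mask)
print(output.forecast.squeeze().shape)  # 10 point forecasts for the single series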

Empty file.
1 change: 0 additions & 1 deletion requirements.txt
@@ -10,4 +10,3 @@ sktime==0.29.0
lightgbm==4.3.0
datasetsforecast==0.0.8
fugue==0.9.0
git+https://github.com/amazon-science/chronos-forecasting.git
