Skip to content

Commit

Permalink
bulk update
Browse files Browse the repository at this point in the history
  • Loading branch information
ryuta-yoshimatsu committed May 16, 2024
1 parent 6763e15 commit 2b83572
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 15 deletions.
1 change: 0 additions & 1 deletion RUNME.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,6 @@
},
"num_workers": 0,
"node_type_id": {"AWS": "i3.xlarge", "MSA": "Standard_DS3_v2", "GCP": "n1-highmem-4"},
"driver_node_type_id": {"AWS": "i3.xlarge", "MSA": "Standard_DS3_v2", "GCP": "n1-highmem-4"},
"custom_tags": {
"ResourceClass": "SingleNode",
"usage": "solacc_testing"
Expand Down
23 changes: 9 additions & 14 deletions forecasting_sa/models/sktime/SKTimeForecastingPipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,13 @@ def __init__(self, params):
self.model = None
self.param_grid = self.create_param_grid()

@abstractmethod
def create_model(self) -> BaseForecaster:
pass

def create_param_grid(self) -> Dict[str, Any]:
return {}

def prepare_data(self, df: pd.DataFrame) -> pd.DataFrame:
df = df.copy().fillna(0.1)
df[self.params.target] = df[self.params.target].clip(0.1)
Expand All @@ -36,10 +43,7 @@ def prepare_data(self, df: pd.DataFrame) -> pd.DataFrame:
df = df.set_index(self.params.date_col)
df = df.reindex(date_idx, method="backfill")
df = df.sort_index()
df = pd.DataFrame(
{"y": df[self.params.target].values},
index=df.index.to_period(self.params.freq),
)
df = pd.DataFrame({"y": df[self.params.target].values}, index=df.index.to_period(self.params.freq))
return df

def fit(self, x, y=None):
Expand Down Expand Up @@ -82,13 +86,6 @@ def predict(self, hist_df: pd.DataFrame, val_df: pd.DataFrame = None):
def forecast(self, x):
return self.predict(x)

@abstractmethod
def create_model(self) -> BaseForecaster:
pass

def create_param_grid(self) -> Dict[str, Any]:
return {}


class SKTimeLgbmDsDt(SKTimeForecastingPipeline):
def __init__(self, params):
Expand All @@ -97,9 +94,7 @@ def __init__(self, params):
def create_model(self) -> BaseForecaster:
model = TransformedTargetForecaster(
[
(
"deseasonalise",
ConditionalDeseasonalizer(
("deseasonalise", ConditionalDeseasonalizer(
model=self.model_spec.get("deseasonalise_model", "additive"),
sp=int(self.model_spec.get("season_length", 1)),
),
Expand Down

0 comments on commit 2b83572

Please sign in to comment.