From bb70cc546c8ca937355a5ffb6473cb802a02f7d7 Mon Sep 17 00:00:00 2001 From: Oleksandr Shchur Date: Fri, 13 Oct 2023 14:20:13 +0000 Subject: [PATCH] Move from pytorch_lightning to lightning --- .../advanced_topics/howto_pytorch_lightning.md.template | 2 +- docs/tutorials/advanced_topics/index.md | 2 +- requirements/requirements-pytorch.txt | 3 ++- src/gluonts/dataset/repository/_lstnet.py | 5 +---- src/gluonts/torch/model/d_linear/estimator.py | 2 +- src/gluonts/torch/model/d_linear/lightning_module.py | 2 +- src/gluonts/torch/model/deepar/lightning_module.py | 2 +- src/gluonts/torch/model/estimator.py | 4 ++-- src/gluonts/torch/model/lag_tst/estimator.py | 2 +- src/gluonts/torch/model/lag_tst/lightning_module.py | 2 +- src/gluonts/torch/model/lightning_util.py | 2 +- src/gluonts/torch/model/mqf2/lightning_module.py | 2 +- src/gluonts/torch/model/patch_tst/estimator.py | 2 +- src/gluonts/torch/model/patch_tst/lightning_module.py | 2 +- src/gluonts/torch/model/simple_feedforward/estimator.py | 2 +- .../torch/model/simple_feedforward/lightning_module.py | 2 +- src/gluonts/torch/model/tft/lightning_module.py | 2 +- src/gluonts/torch/model/wavenet/estimator.py | 2 +- src/gluonts/torch/model/wavenet/lightning_module.py | 2 +- 19 files changed, 21 insertions(+), 23 deletions(-) diff --git a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template index 693f070dcf..34e6937b19 100644 --- a/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template +++ b/docs/tutorials/advanced_topics/howto_pytorch_lightning.md.template @@ -133,7 +133,7 @@ To train the model using PyTorch Lightning, we only need to extend the class wit ```python -import pytorch_lightning as pl +import lightning.pytorch as pl ``` diff --git a/docs/tutorials/advanced_topics/index.md b/docs/tutorials/advanced_topics/index.md index c09840d7e8..303baf2d98 100644 --- a/docs/tutorials/advanced_topics/index.md +++ b/docs/tutorials/advanced_topics/index.md @@ -2,7 +2,7 @@ ```{toctree} :maxdepth: 1 -howto_pytorch_lightning +howto_lightning.pytorch hp_tuning_with_optuna trainer_callbacks ``` diff --git a/requirements/requirements-pytorch.txt b/requirements/requirements-pytorch.txt index 16a40f64a3..9947b4e9ae 100644 --- a/requirements/requirements-pytorch.txt +++ b/requirements/requirements-pytorch.txt @@ -1,5 +1,6 @@ torch>=1.9,<3 -pytorch-lightning>=1.5,<3 +lightning>=1.8,<2.2 +lightning.pytorch>=1.8,<2.2 # Need to pin protobuf (for now) # See: https://github.com/PyTorchLightning/pytorch-lightning/issues/13159 protobuf~=3.19.0 diff --git a/src/gluonts/dataset/repository/_lstnet.py b/src/gluonts/dataset/repository/_lstnet.py index ba46cf03ae..ed38aa90b5 100644 --- a/src/gluonts/dataset/repository/_lstnet.py +++ b/src/gluonts/dataset/repository/_lstnet.py @@ -141,10 +141,7 @@ def generate_lstnet_dataset( pd.read_csv(ds_info.url, header=None), # type: ignore ) - assert df.shape == ( - ds_info.num_time_steps, - ds_info.num_series, - ), ( + assert df.shape == (ds_info.num_time_steps, ds_info.num_series,), ( "expected num_time_steps/num_series" f" {(ds_info.num_time_steps, ds_info.num_series)} but got {df.shape}" ) diff --git a/src/gluonts/torch/model/d_linear/estimator.py b/src/gluonts/torch/model/d_linear/estimator.py index f62429d162..f8dc0453c1 100644 --- a/src/gluonts/torch/model/d_linear/estimator.py +++ b/src/gluonts/torch/model/d_linear/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/d_linear/lightning_module.py b/src/gluonts/torch/model/d_linear/lightning_module.py index 28dccf1b97..bd081b45dd 100644 --- a/src/gluonts/torch/model/d_linear/lightning_module.py +++ b/src/gluonts/torch/model/d_linear/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/deepar/lightning_module.py b/src/gluonts/torch/model/deepar/lightning_module.py index fc676dfab3..8d190e2329 100644 --- a/src/gluonts/torch/model/deepar/lightning_module.py +++ b/src/gluonts/torch/model/deepar/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/src/gluonts/torch/model/estimator.py b/src/gluonts/torch/model/estimator.py index 003d88f7fa..ba91f0d725 100644 --- a/src/gluonts/torch/model/estimator.py +++ b/src/gluonts/torch/model/estimator.py @@ -15,7 +15,7 @@ import logging import numpy as np -import pytorch_lightning as pl +import lightning.pytorch as pl import torch.nn as nn from gluonts.core.component import validated @@ -217,7 +217,7 @@ def train_model( logger.info( f"Loading best model from {checkpoint.best_model_path}" ) - best_model = training_network.load_from_checkpoint( + best_model = training_network.__class__.load_from_checkpoint( checkpoint.best_model_path ) else: diff --git a/src/gluonts/torch/model/lag_tst/estimator.py b/src/gluonts/torch/model/lag_tst/estimator.py index 330a1ff4b9..96f2b2a603 100644 --- a/src/gluonts/torch/model/lag_tst/estimator.py +++ b/src/gluonts/torch/model/lag_tst/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any, List import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/lag_tst/lightning_module.py b/src/gluonts/torch/model/lag_tst/lightning_module.py index 2510944cfa..5c9e70e9e4 100644 --- a/src/gluonts/torch/model/lag_tst/lightning_module.py +++ b/src/gluonts/torch/model/lag_tst/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/lightning_util.py b/src/gluonts/torch/model/lightning_util.py index 6742c8c7cf..73e2396140 100644 --- a/src/gluonts/torch/model/lightning_util.py +++ b/src/gluonts/torch/model/lightning_util.py @@ -13,7 +13,7 @@ from packaging import version -import pytorch_lightning as pl +import lightning.pytorch as pl def has_validation_loop(trainer: pl.Trainer): diff --git a/src/gluonts/torch/model/mqf2/lightning_module.py b/src/gluonts/torch/model/mqf2/lightning_module.py index 6dc824beb4..16916c3c41 100644 --- a/src/gluonts/torch/model/mqf2/lightning_module.py +++ b/src/gluonts/torch/model/mqf2/lightning_module.py @@ -13,7 +13,7 @@ from typing import Dict -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from torch.optim.lr_scheduler import ReduceLROnPlateau diff --git a/src/gluonts/torch/model/patch_tst/estimator.py b/src/gluonts/torch/model/patch_tst/estimator.py index de7b880f36..34dfa5dacb 100644 --- a/src/gluonts/torch/model/patch_tst/estimator.py +++ b/src/gluonts/torch/model/patch_tst/estimator.py @@ -14,7 +14,7 @@ from typing import Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/patch_tst/lightning_module.py b/src/gluonts/torch/model/patch_tst/lightning_module.py index f5e95158b2..d80137ae05 100644 --- a/src/gluonts/torch/model/patch_tst/lightning_module.py +++ b/src/gluonts/torch/model/patch_tst/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/simple_feedforward/estimator.py b/src/gluonts/torch/model/simple_feedforward/estimator.py index e43956d1ad..a909d4ee59 100644 --- a/src/gluonts/torch/model/simple_feedforward/estimator.py +++ b/src/gluonts/torch/model/simple_feedforward/estimator.py @@ -14,7 +14,7 @@ from typing import List, Optional, Iterable, Dict, Any import torch -import pytorch_lightning as pl +import lightning.pytorch as pl from gluonts.core.component import validated from gluonts.dataset.common import Dataset diff --git a/src/gluonts/torch/model/simple_feedforward/lightning_module.py b/src/gluonts/torch/model/simple_feedforward/lightning_module.py index b7cf9a529a..f03473e78d 100644 --- a/src/gluonts/torch/model/simple_feedforward/lightning_module.py +++ b/src/gluonts/torch/model/simple_feedforward/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated diff --git a/src/gluonts/torch/model/tft/lightning_module.py b/src/gluonts/torch/model/tft/lightning_module.py index f6f7daa335..4647d740fd 100644 --- a/src/gluonts/torch/model/tft/lightning_module.py +++ b/src/gluonts/torch/model/tft/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated from gluonts.itertools import select diff --git a/src/gluonts/torch/model/wavenet/estimator.py b/src/gluonts/torch/model/wavenet/estimator.py index ab9fc9db54..e7fea4b0d7 100644 --- a/src/gluonts/torch/model/wavenet/estimator.py +++ b/src/gluonts/torch/model/wavenet/estimator.py @@ -13,7 +13,7 @@ from typing import Any, Dict, List, Optional, Iterable -import pytorch_lightning as pl +import lightning.pytorch as pl import torch import numpy as np diff --git a/src/gluonts/torch/model/wavenet/lightning_module.py b/src/gluonts/torch/model/wavenet/lightning_module.py index daf78e451c..85dd0a6671 100644 --- a/src/gluonts/torch/model/wavenet/lightning_module.py +++ b/src/gluonts/torch/model/wavenet/lightning_module.py @@ -11,7 +11,7 @@ # express or implied. See the License for the specific language governing # permissions and limitations under the License. -import pytorch_lightning as pl +import lightning.pytorch as pl import torch from gluonts.core.component import validated