From dade5ae61fd61d47ffb7fc08d3d7adaeed05bd3b Mon Sep 17 00:00:00 2001 From: salvaRC Date: Thu, 5 Dec 2024 03:32:06 -0800 Subject: [PATCH] add code --- .github/workflows/ci.yaml | 46 + .gitignore | 221 +++ .pre-commit-config.yaml | 15 + CONTRIBUTING.md | 91 ++ LICENSE | 4 +- Makefile | 46 + README.md | 145 +- environment/README.md | 30 + environment/install_dependencies.sh | 13 + pyproject.toml | 18 + run_inference.py | 14 + scripts/.gitkeep | 0 setup.cfg | 7 + setup.py | 237 +++ src/__init__.py | 1 + src/ace_inference/LICENSE | 201 +++ src/ace_inference/README.md | 4 + src/ace_inference/__init__.py | 1 + src/ace_inference/core/__init__.py | 0 src/ace_inference/core/aggregator/__init__.py | 1 + .../core/aggregator/climate_data.py | 233 +++ .../core/aggregator/inference/__init__.py | 0 .../core/aggregator/inference/main.py | 219 +++ .../core/aggregator/inference/reduced.py | 293 ++++ .../core/aggregator/inference/time_mean.py | 226 +++ .../aggregator/inference/time_mean_salva.py | 150 ++ .../core/aggregator/inference/video.py | 448 ++++++ .../core/aggregator/inference/zonal_mean.py | 129 ++ src/ace_inference/core/aggregator/null.py | 32 + .../core/aggregator/one_step/__init__.py | 0 .../core/aggregator/one_step/derived.py | 132 ++ .../core/aggregator/one_step/main.py | 97 ++ .../core/aggregator/one_step/reduced.py | 156 ++ .../aggregator/one_step/reduced_metrics.py | 78 + .../core/aggregator/one_step/reduced_salva.py | 136 ++ .../core/aggregator/one_step/snapshot.py | 161 ++ src/ace_inference/core/aggregator/plotting.py | 33 + .../core/aggregator/reduced_metrics.py | 118 ++ src/ace_inference/core/aggregator/train.py | 43 + src/ace_inference/core/constants.py | 6 + src/ace_inference/core/corrector.py | 296 ++++ .../core/data_loading/__init__.py | 0 .../core/data_loading/_xarray.py | 328 +++++ .../core/data_loading/data_typing.py | 110 ++ .../core/data_loading/get_loader.py | 119 ++ .../core/data_loading/getters.py | 173 +++ .../core/data_loading/inference.py | 175 +++ src/ace_inference/core/data_loading/params.py | 77 + .../core/data_loading/requirements.py | 11 + src/ace_inference/core/data_loading/utils.py | 107 ++ src/ace_inference/core/device.py | 12 + src/ace_inference/core/dicts.py | 41 + src/ace_inference/core/distributed.py | 107 ++ src/ace_inference/core/ema.py | 143 ++ src/ace_inference/core/histogram.py | 99 ++ src/ace_inference/core/loss.py | 255 ++++ src/ace_inference/core/metrics.py | 367 +++++ src/ace_inference/core/normalizer.py | 126 ++ src/ace_inference/core/ocean.py | 146 ++ src/ace_inference/core/optimization.py | 190 +++ src/ace_inference/core/packer.py | 70 + src/ace_inference/core/parameter_init.py | 115 ++ src/ace_inference/core/prescriber.py | 134 ++ src/ace_inference/core/registry.py | 194 +++ src/ace_inference/core/scheduler.py | 29 + src/ace_inference/core/stepper.py | 591 ++++++++ src/ace_inference/core/stepper_multistep.py | 463 ++++++ src/ace_inference/core/wandb.py | 189 +++ src/ace_inference/core/weight_ops.py | 166 +++ src/ace_inference/core/wildcard.py | 40 + src/ace_inference/core/winds.py | 170 +++ src/ace_inference/inference/__init__.py | 0 .../inference/data_writer/__init__.py | 0 .../inference/data_writer/histograms.py | 148 ++ .../inference/data_writer/main.py | 187 +++ .../inference/data_writer/prediction.py | 131 ++ .../inference/data_writer/time_coarsen.py | 141 ++ .../inference/data_writer/video.py | 73 + .../inference/derived_variables.py | 132 ++ src/ace_inference/inference/gcs_utils.py | 20 + src/ace_inference/inference/inference.py | 333 +++++ src/ace_inference/inference/logging_utils.py | 150 ++ src/ace_inference/inference/loop.py | 326 +++++ src/ace_inference/training/__init__.py | 0 src/ace_inference/training/registry.py | 197 +++ src/ace_inference/training/train.py | 418 ++++++ src/ace_inference/training/train_config.py | 253 ++++ src/ace_inference/training/utils/__init__.py | 0 .../training/utils/darcy_loss.py | 350 +++++ .../training/utils/data_loader_fv3gfs.py | 252 ++++ .../training/utils/data_loader_multifiles.py | 174 +++ .../training/utils/data_loader_params.py | 40 + .../training/utils/data_requirements.py | 11 + src/ace_inference/training/utils/img_utils.py | 66 + .../ckpts_from_huggingface_10years.yaml | 51 + .../ckpts_from_huggingface_debug.yaml | 51 + src/datamodules/__init__.py | 0 src/datamodules/_dataset_dimensions.py | 27 + src/datamodules/abstract_datamodule.py | 281 ++++ src/datamodules/debug_datamodule.py | 114 ++ src/datamodules/fv3gfs_ensemble.py | 280 ++++ src/dependency_versions_table.py | 34 + src/diffusion/__init__.py | 0 src/diffusion/_base_diffusion.py | 80 + src/diffusion/dyffusion.py | 738 ++++++++++ src/evaluation/__init__.py | 0 src/evaluation/aggregators/__init__.py | 0 .../aggregators/_abstract_aggregator.py | 68 + src/evaluation/aggregators/main.py | 153 ++ src/evaluation/aggregators/snapshot.py | 208 +++ src/evaluation/aggregators/time_mean.py | 116 ++ src/evaluation/aggregators/timestepwise.py | 214 +++ src/evaluation/metrics.py | 456 ++++++ src/evaluation/reduced_metrics.py | 122 ++ src/experiment_types/__init__.py | 0 src/experiment_types/_base_experiment.py | 1275 ++++++++++++++++ .../forecasting_multi_horizon.py | 680 +++++++++ src/experiment_types/interpolation.py | 183 +++ src/interface.py | 313 ++++ src/losses/__init__.py | 0 src/losses/losses.py | 79 + src/models/__init__.py | 0 src/models/_base_model.py | 300 ++++ src/models/modules/__init__.py | 0 src/models/modules/attention.py | 116 ++ src/models/modules/convs.py | 30 + src/models/modules/drop_path.py | 36 + src/models/modules/ema.py | 91 ++ src/models/modules/misc.py | 148 ++ src/models/modules/net_norm.py | 37 + src/models/sfno/__init__.py | 0 src/models/sfno/activations.py | 110 ++ src/models/sfno/contractions.py | 193 +++ src/models/sfno/distributed/__init__.py | 13 + src/models/sfno/distributed/comm.py | 314 ++++ src/models/sfno/distributed/helpers.py | 194 +++ src/models/sfno/distributed/layer_norm.py | 133 ++ src/models/sfno/distributed/layers.py | 539 +++++++ src/models/sfno/distributed/mappings.py | 340 +++++ src/models/sfno/factorizations.py | 225 +++ src/models/sfno/initialization.py | 73 + src/models/sfno/layers.py | 511 +++++++ src/models/sfno/preprocessor.py | 252 ++++ src/models/sfno/s2convolutions.py | 548 +++++++ src/models/sfno/sfnonet.py | 841 +++++++++++ src/models/unet.py | 376 +++++ src/train.py | 196 +++ src/utilities/__init__.py | 0 src/utilities/checkpointing.py | 154 ++ src/utilities/config_utils.py | 916 ++++++++++++ src/utilities/lr_scheduler.py | 201 +++ src/utilities/naming.py | 509 +++++++ src/utilities/normalization.py | 117 ++ src/utilities/packer.py | 77 + src/utilities/s3utils.py | 383 +++++ src/utilities/utils.py | 967 ++++++++++++ src/utilities/wandb_api.py | 1296 +++++++++++++++++ src/utilities/wandb_callbacks.py | 355 +++++ utils/check_copies.py | 203 +++ utils/get_modified_files.py | 34 + utils/release.py | 134 ++ 161 files changed, 28712 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/ci.yaml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 CONTRIBUTING.md create mode 100644 Makefile create mode 100644 environment/README.md create mode 100644 environment/install_dependencies.sh create mode 100644 pyproject.toml create mode 100644 run_inference.py create mode 100644 scripts/.gitkeep create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 src/__init__.py create mode 100644 src/ace_inference/LICENSE create mode 100644 src/ace_inference/README.md create mode 100644 src/ace_inference/__init__.py create mode 100644 src/ace_inference/core/__init__.py create mode 100644 src/ace_inference/core/aggregator/__init__.py create mode 100644 src/ace_inference/core/aggregator/climate_data.py create mode 100644 src/ace_inference/core/aggregator/inference/__init__.py create mode 100644 src/ace_inference/core/aggregator/inference/main.py create mode 100644 src/ace_inference/core/aggregator/inference/reduced.py create mode 100644 src/ace_inference/core/aggregator/inference/time_mean.py create mode 100644 src/ace_inference/core/aggregator/inference/time_mean_salva.py create mode 100644 src/ace_inference/core/aggregator/inference/video.py create mode 100644 src/ace_inference/core/aggregator/inference/zonal_mean.py create mode 100644 src/ace_inference/core/aggregator/null.py create mode 100644 src/ace_inference/core/aggregator/one_step/__init__.py create mode 100644 src/ace_inference/core/aggregator/one_step/derived.py create mode 100644 src/ace_inference/core/aggregator/one_step/main.py create mode 100644 src/ace_inference/core/aggregator/one_step/reduced.py create mode 100644 src/ace_inference/core/aggregator/one_step/reduced_metrics.py create mode 100644 src/ace_inference/core/aggregator/one_step/reduced_salva.py create mode 100644 src/ace_inference/core/aggregator/one_step/snapshot.py create mode 100644 src/ace_inference/core/aggregator/plotting.py create mode 100644 src/ace_inference/core/aggregator/reduced_metrics.py create mode 100644 src/ace_inference/core/aggregator/train.py create mode 100644 src/ace_inference/core/constants.py create mode 100644 src/ace_inference/core/corrector.py create mode 100644 src/ace_inference/core/data_loading/__init__.py create mode 100644 src/ace_inference/core/data_loading/_xarray.py create mode 100644 src/ace_inference/core/data_loading/data_typing.py create mode 100644 src/ace_inference/core/data_loading/get_loader.py create mode 100644 src/ace_inference/core/data_loading/getters.py create mode 100644 src/ace_inference/core/data_loading/inference.py create mode 100644 src/ace_inference/core/data_loading/params.py create mode 100644 src/ace_inference/core/data_loading/requirements.py create mode 100644 src/ace_inference/core/data_loading/utils.py create mode 100644 src/ace_inference/core/device.py create mode 100644 src/ace_inference/core/dicts.py create mode 100644 src/ace_inference/core/distributed.py create mode 100644 src/ace_inference/core/ema.py create mode 100644 src/ace_inference/core/histogram.py create mode 100644 src/ace_inference/core/loss.py create mode 100644 src/ace_inference/core/metrics.py create mode 100644 src/ace_inference/core/normalizer.py create mode 100644 src/ace_inference/core/ocean.py create mode 100644 src/ace_inference/core/optimization.py create mode 100644 src/ace_inference/core/packer.py create mode 100644 src/ace_inference/core/parameter_init.py create mode 100644 src/ace_inference/core/prescriber.py create mode 100644 src/ace_inference/core/registry.py create mode 100644 src/ace_inference/core/scheduler.py create mode 100644 src/ace_inference/core/stepper.py create mode 100644 src/ace_inference/core/stepper_multistep.py create mode 100644 src/ace_inference/core/wandb.py create mode 100644 src/ace_inference/core/weight_ops.py create mode 100644 src/ace_inference/core/wildcard.py create mode 100644 src/ace_inference/core/winds.py create mode 100644 src/ace_inference/inference/__init__.py create mode 100644 src/ace_inference/inference/data_writer/__init__.py create mode 100644 src/ace_inference/inference/data_writer/histograms.py create mode 100644 src/ace_inference/inference/data_writer/main.py create mode 100644 src/ace_inference/inference/data_writer/prediction.py create mode 100644 src/ace_inference/inference/data_writer/time_coarsen.py create mode 100644 src/ace_inference/inference/data_writer/video.py create mode 100644 src/ace_inference/inference/derived_variables.py create mode 100644 src/ace_inference/inference/gcs_utils.py create mode 100755 src/ace_inference/inference/inference.py create mode 100644 src/ace_inference/inference/logging_utils.py create mode 100644 src/ace_inference/inference/loop.py create mode 100644 src/ace_inference/training/__init__.py create mode 100644 src/ace_inference/training/registry.py create mode 100644 src/ace_inference/training/train.py create mode 100644 src/ace_inference/training/train_config.py create mode 100644 src/ace_inference/training/utils/__init__.py create mode 100644 src/ace_inference/training/utils/darcy_loss.py create mode 100644 src/ace_inference/training/utils/data_loader_fv3gfs.py create mode 100644 src/ace_inference/training/utils/data_loader_multifiles.py create mode 100644 src/ace_inference/training/utils/data_loader_params.py create mode 100644 src/ace_inference/training/utils/data_requirements.py create mode 100644 src/ace_inference/training/utils/img_utils.py create mode 100644 src/configs/inference/ckpts_from_huggingface_10years.yaml create mode 100644 src/configs/inference/ckpts_from_huggingface_debug.yaml create mode 100644 src/datamodules/__init__.py create mode 100644 src/datamodules/_dataset_dimensions.py create mode 100644 src/datamodules/abstract_datamodule.py create mode 100644 src/datamodules/debug_datamodule.py create mode 100644 src/datamodules/fv3gfs_ensemble.py create mode 100644 src/dependency_versions_table.py create mode 100644 src/diffusion/__init__.py create mode 100644 src/diffusion/_base_diffusion.py create mode 100644 src/diffusion/dyffusion.py create mode 100644 src/evaluation/__init__.py create mode 100644 src/evaluation/aggregators/__init__.py create mode 100644 src/evaluation/aggregators/_abstract_aggregator.py create mode 100644 src/evaluation/aggregators/main.py create mode 100644 src/evaluation/aggregators/snapshot.py create mode 100644 src/evaluation/aggregators/time_mean.py create mode 100644 src/evaluation/aggregators/timestepwise.py create mode 100644 src/evaluation/metrics.py create mode 100644 src/evaluation/reduced_metrics.py create mode 100644 src/experiment_types/__init__.py create mode 100644 src/experiment_types/_base_experiment.py create mode 100644 src/experiment_types/forecasting_multi_horizon.py create mode 100644 src/experiment_types/interpolation.py create mode 100644 src/interface.py create mode 100644 src/losses/__init__.py create mode 100644 src/losses/losses.py create mode 100644 src/models/__init__.py create mode 100644 src/models/_base_model.py create mode 100644 src/models/modules/__init__.py create mode 100644 src/models/modules/attention.py create mode 100644 src/models/modules/convs.py create mode 100644 src/models/modules/drop_path.py create mode 100644 src/models/modules/ema.py create mode 100644 src/models/modules/misc.py create mode 100644 src/models/modules/net_norm.py create mode 100644 src/models/sfno/__init__.py create mode 100644 src/models/sfno/activations.py create mode 100644 src/models/sfno/contractions.py create mode 100644 src/models/sfno/distributed/__init__.py create mode 100644 src/models/sfno/distributed/comm.py create mode 100644 src/models/sfno/distributed/helpers.py create mode 100644 src/models/sfno/distributed/layer_norm.py create mode 100644 src/models/sfno/distributed/layers.py create mode 100644 src/models/sfno/distributed/mappings.py create mode 100644 src/models/sfno/factorizations.py create mode 100644 src/models/sfno/initialization.py create mode 100644 src/models/sfno/layers.py create mode 100644 src/models/sfno/preprocessor.py create mode 100644 src/models/sfno/s2convolutions.py create mode 100644 src/models/sfno/sfnonet.py create mode 100644 src/models/unet.py create mode 100644 src/train.py create mode 100644 src/utilities/__init__.py create mode 100644 src/utilities/checkpointing.py create mode 100644 src/utilities/config_utils.py create mode 100644 src/utilities/lr_scheduler.py create mode 100644 src/utilities/naming.py create mode 100644 src/utilities/normalization.py create mode 100644 src/utilities/packer.py create mode 100644 src/utilities/s3utils.py create mode 100644 src/utilities/utils.py create mode 100644 src/utilities/wandb_api.py create mode 100644 src/utilities/wandb_callbacks.py create mode 100644 utils/check_copies.py create mode 100644 utils/get_modified_files.py create mode 100644 utils/release.py diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml new file mode 100644 index 0000000..2d51b2e --- /dev/null +++ b/.github/workflows/ci.yaml @@ -0,0 +1,46 @@ +name: Run code quality checks + +on: + push: + branches: + - main + pull_request: + branches: + - main + +jobs: + + check_code_quality: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Check quality + run: | + ruff check scripts src utils --fix +# black --check tests src scripts utils +# doc-builder style src docs/source --max_len 119 --check_only --path_to_docs docs/source + + check_repository_consistency: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: "3.10" + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install .[quality] + - name: Check quality + run: | + python utils/check_copies.py + make deps_table_check_updated \ No newline at end of file diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c357f05 --- /dev/null +++ b/.gitignore @@ -0,0 +1,221 @@ +# Project specific + +results/ +data/*.nc +data/*.csv +data/*.h5 +data/*.json +data/*.pkl +data/*.npy +data/*.npz +data/*.mat +data/*.txt +data/*.zip +data/*.tar +data/*.tar.gz +data/*.tar.bz2 +*.ckpt +*.pt +stats/ +predictions/ +predictions/* +# Created by https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks +# Edit at https://www.toptal.com/developers/gitignore?templates=python,jupyternotebooks + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook + +# IPython + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# End of https://www.toptal.com/developers/gitignore/api/python,jupyternotebooks +# Pycharm +.idea +.idea/** +**/.idea + + +# Ruff +.ruff_cache + +#.DS_Store +.DS_Store +**/.DS_Store + +# logging +src/results/** +src/results/wandb/** +src/results/logs/** +src/configs/local/default.yaml +results/checkpoints/** +results/wandb/** +results/logs/** +outputs/** +wandb/** + +tmp.sh +tmp.py + +# Gifs in docs/ +docs/*.gif + +*.jpg +*.txt +*.txt~ +*.tgz + +logs/** + +notebooks/*.gif + +videos/** +scripts/_servers/* + +# all inference configs that start with '_' +src/configs/inference/_* + diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..369a58d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,15 @@ +repos: + - repo: https://github.com/psf/black + rev: 23.1.0 + hooks: + - id: black + language_version: python3 + types: [python] + stages: [commit] + args: ["--config", "pyproject.toml", "tests", "src", "scripts"] + - repo: https://github.com/charliermarsh/ruff-pre-commit + rev: 'v0.0.255' + hooks: + - id: ruff + stages: [commit] + args: [ "--config", "pyproject.toml", "tests", "src", "scripts", "--fix"] diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..2f20f48 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,91 @@ +# How to contribute to DYffusion? +[![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-2.0-4baaaa.svg)](CODE_OF_CONDUCT.md) + +Spherical DYffusion is an open source project, so all contributions and suggestions are welcome. + +You can contribute in many different ways: giving ideas, answering questions, reporting bugs, proposing enhancements, +improving the documentation, fixing bugs,... + +Many thanks in advance to every contributor. + +In order to facilitate healthy, constructive behavior in an open and inclusive community, we all respect and abide by +our [code of conduct](CODE_OF_CONDUCT.md). + +## How to work on an open Issue? +You have the list of open Issues at: https://github.com/Rose-STL-lab/spherical-dyffusion/issues + +Some of them may have the label `help wanted`: that means that any contributor is welcomed! + +If you would like to work on any of the open Issues: + +1. Make sure it is not already assigned to someone else. You have the assignee (if any) on the top of the right column of the Issue page. + +2. You can self-assign it by commenting on the Issue page with the keyword: `#self-assign`. + +3. Work on your self-assigned issue and eventually create a Pull Request. + +## How to create a Pull Request? + +1. Fork the [repository](https://github.com/Rose-STL-lab/dyffusion) by clicking on the 'Fork' button on the repository's page. This creates a copy of the code under your GitHub user account. + +2. Clone your fork to your local disk, and add the base repository as a remote: + + ```bash + git clone git@github.com:/spherical-dyffusion.git + cd spherical-dyffusion + git remote add upstream https://github.com/Rose-STL-lab/spherical-dyffusion.git + ``` + +3. Create a new branch to hold your development changes: + + ```bash + git checkout -b a-descriptive-name-for-my-changes + ``` + + **do not** work on the `main` branch. + +4. Set up a development environment by running the following command in a virtual environment: + + ```bash + pip install -e ".[dev]" + ``` + + (If `spherical_dyffusion` was already installed in the virtual environment, remove + it with `pip uninstall spherical_dyffusion` before reinstalling it in editable + mode with the `-e` flag.) + +5. Develop the features on your branch. + +6. Format your code. Run `black` and `ruff` so that your newly added files look nice with the following command: + + ```bash + make style + ``` + +7. _(Optional)_ You can also use [`pre-commit`](https://pre-commit.com/) to format your code automatically each time run `git commit`, instead of running `make style` manually. +To do this, install `pre-commit` via `pip install pre-commit` and then run `pre-commit install` in the project's root directory to set up the hooks. +Note that if any files were formatted by `pre-commit` hooks during committing, you have to run `git commit` again . + + +8. Once you're happy with your contribution, add your changed files and make a commit to record your changes locally: + + ```bash + git add -u + git commit + ``` + + It is a good idea to sync your copy of the code with the original + repository regularly. This way you can quickly account for changes: + + ```bash + git fetch upstream + git rebase upstream/main + ``` + +9. Once you are satisfied, push the changes to your fork repo using: + + ```bash + git push -u origin a-descriptive-name-for-my-changes + ``` + + Go the webpage of your fork on GitHub. Click on "Pull request" to send your to the project maintainers for review. diff --git a/LICENSE b/LICENSE index 261eeb9..131b9da 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright [yyyy] [name of copyright owner] + Copyright 2024 Salva Rühling Cachay Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -198,4 +198,4 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and - limitations under the License. + limitations under the License. \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..080d90a --- /dev/null +++ b/Makefile @@ -0,0 +1,46 @@ +.PHONY: deps_table_update modified_only_fixup quality style + +# make sure to test the local checkout in scripts and not the pre-installed one (don't use quotes!) +export PYTHONPATH = src + +check_dirs := scripts src utils + +modified_only_fixup: + $(eval modified_py_files := $(shell python3 utils/get_modified_files.py $(check_dirs))) + @if test -n "$(modified_py_files)"; then \ + echo "Checking/fixing $(modified_py_files)"; \ + black $(modified_py_files); \ + ruff check $(modified_py_files); \ + else \ + echo "No library .py files were modified"; \ + fi + +# Update src/dependency_versions_table.py + +deps_table_update: + @python3 setup.py "deps_table_update" + +deps_table_check_updated: + @md5sum src/dependency_versions_table.py > md5sum.saved + @python3 setup.py deps_table_update + @md5sum -c --quiet md5sum.saved || (printf "\nError: the version dependency table is outdated.\nPlease run 'make fixup' or 'make style' and commit the changes.\n\n" && exit 1) + @rm md5sum.saved + +# autogenerating code + +autogenerate_code: deps_table_update + +# this target runs checks on all files + +quality: + python3 -m black --check $(check_dirs) + python3 -m ruff check $(check_dirs) +# doc-builder style src docs/source --max_len 119 --check_only --path_to_docs docs/source + + +# this target runs checks on all files and potentially modifies some of them + +style: + python3 -m black $(check_dirs) + python3 -m ruff check $(check_dirs) --fix + ${MAKE} autogenerate_code diff --git a/README.md b/README.md index 525c44e..be7291e 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,146 @@ # Probabilistic Emulation of a Global Climate Model with Spherical DYffusion (NeurIPS 2024, Spotlight) -Code will be released soon. Stay tuned! +Python +PyTorch +Lightning +Config: hydra +License + +

✨Official implementation of our Spherical DYffusion paper✨

+ +[//]: # ([![Watch the video](https://img.youtube.com/vi/Hac_xGsJ1qY/hqdefault.jpg)](https://youtu.be/Hac_xGsJ1qY)) + +## | Environment Setup + +We recommend installing in a virtual environment from PyPi or Conda. Then, run: + + python3 -m pip install .[dev] + python3 -m pip install --no-deps nvidia-modulus@git+https://github.com/ai2cm/modulus.git@94f62e1ce2083640829ec12d80b00619c40a47f8 + +Alternatively, use the provided [environment/install_dependencies.sh](environment/install_dependencies.sh) script. + +Note that for some compute setups you may want to install pytorch first for proper GPU support. +For more details about installing [PyTorch](https://pytorch.org/get-started/locally/), please refer to their official documentation. +A typical, but not fully general (!), manual bash script that installs all dependencies is +[install_dependencies.sh](environment/install_dependencies.sh). Feel free to use it. + +## | Dataset + +The final +training and validation data can be downloaded from Google Cloud Storage following the instructions +of the ACE paper at https://zenodo.org/records/10791087. The data are licensed under Creative +Commons Attribution 4.0 International. + +## | Checkpoints + +Model weights are available at [https://huggingface.co/salv47/spherical-dyffusion](https://huggingface.co/salv47/spherical-dyffusion/tree/main). + +## | Running experiments + +### Inference + +Firstly, download the validation data as instructed in the [Dataset](#dataset) section. + +Secondly, use the `run_inference.py` script with a corresponding configuration file. +The configurations files used for our paper can be found in the [src/configs/inference](src/configs/inference) directory. +That is, you can run inference with the following command: + + python run_inference.py .yaml + +The available inference configurations are: +- [ckpts_from_huggingface_debug.yaml](src/configs/inference/ckpts_from_huggingface_debug.yaml): Short inference meant for debugging with checkpoints downloaded from Hugging Face. +- [ckpts_from_huggingface_10years.yaml](src/configs/inference/ckpts_from_huggingface_10years.yaml): 10-year-long inference with checkpoints downloaded from Hugging Face. + +To use these configs, **you need to correctly specify the `dataset.data_path` parameter in the configuration file to point to the validation data.** +### Training + +We use [Hydra](https://hydra.cc/) for configuration management and [PyTorch Lightning](https://www.pytorchlightning.ai/) for training. +We recommend familiarizing yourself with these tools before running training experiments. + + +### Tips + +
+ Memory Considerations and OOM Errors + +To control memory usage and avoid OOM errors, you can adjust the training batch size and evaluation batch size: + +**For training**, you can adjust the `datamodule.batch_size_per_gpu` parameter. +Note that this will automatically adjust `trainer.accumulate_grad_batches` to keep the effective batch size (set by `datamodule.batch_size`) constant (so it need to be divisible by `datamodule.batch_size_per_gpu`). + +**For evaluation** or OOMs during validation, you can adjust the `datamodule.eval_batch_size` parameter. +Note that the effective validation-time batch size is `datamodule.eval_batch_size * module.num_predictions`. Be mindful of that when choosing `eval_batch_size`. You can control how many ensemble members to run in memory +at once with `module.num_predictions_in_memory`. + +Besides those main knobs, you may turn on mixed precision training with `trainer.precision=16` to reduce memory usage and +may also adjust the `datamodule.num_workers` parameter to control the number of data loading processes. +
+ +
+ Wandb Integration + +We use [Weights & Biases](https://wandb.ai/) for logging and checkpointing. +Please set your wandb username/entity with one of the following options: +- Edit the [src/configs/local/default.yaml](src/configs/local/default.yaml) file (recommended, local for you only). +- Edit the [src/configs/logger/wandb.yaml](src/configs/logger/wandb.yaml) file. +- as a command line argument (e.g. `python run.py logger.wandb.entity=my_username`). +
+ +
+ Checkpointing + +By default, checkpoints are saved locally in the `/checkpoints` directory in the root of the repository, +which you can control with the `work_dir=` argument. + +When using the wandb logger (default), checkpoints may be saved to wandb (`logger.wandb.save_to_wandb`) or S3 storage (`logger.wandb.save_to_s3_bucket`). +Set these to `False` to disable saving them to wandb or S3. +If disabling both (only save checkpoints locally), make sure to set `logger.wandb.save_best_ckpt=False logger.wandb.save_last_ckpt=False`. +You can set these preferences in your [local config](src/configs/local/default.yaml) file +(see [src/configs/local/example_local_config.yaml](src/configs/local/example_local_config.yaml) for an example). +
+ +
+ Debugging + +For minimal data and model size, you can use the following: + + python run.py ++model.debug_mode=True ++datamodule.debug_mode=True + +Note that the model and datamodule need to support to appropriately handle the debug mode. +
+ +
+ Code Quality + +Code quality is automatically checked when pushing to the repository. +However, it is recommended that you also run the checks locally with `make quality`. + +To automatically fix some issues (as much as possible), run: + + make style +
+ +
+ hydra.errors.InstantiationException + +The ``hydra.errors.InstantiationException`` itself is not very informative, +so you need to look at the preceding exception(s) (i.e. scroll up) to see what went wrong. +
+ +
+ Local Configurations + +You can use a local config file that, defines the local data dir, working dir etc., by putting a ``default.yaml`` config +in the [src/configs/local/](src/configs/local) subdirectory. Hydra searches for & uses by default the file configs/local/default.yaml, if it exists. +You may take inspiration from the [example_local_config.yaml](src/configs/local/example_local_config.yaml) file. +
+ +## | Citation + + @inproceedings{cachay2024spherical, + title={Probablistic Emulation of a Global Climate Model with Spherical {DY}ffusion}, + author={Salva R{\"u}hling Cachay and Brian Henn and Oliver Watt-Meyer and Christopher S. Bretherton and Rose Yu}, + booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}, + year={2024}, + url={https://openreview.net/forum?id=Ib2iHIJRTh} + } diff --git a/environment/README.md b/environment/README.md new file mode 100644 index 0000000..922cd0a --- /dev/null +++ b/environment/README.md @@ -0,0 +1,30 @@ +# Environment Setup +This project is developed in Python 3.9. + +### 1. Create a virtual environment + +You can use either Python's built-in `venv` or `conda` to create a virtual environment: + +#### 1a. Python venv environment + +To create a python virtual environment, run the following commands from the root of the project: +```bash +python3 -m venv .venv +source .venv/bin/activate +``` + +#### 1b. Conda env: + +Start from a clean environment, e.g. with conda do: + + conda create -n spherical-dyffusion python=3.9 + conda activate spherical-dyffusion # activate the environment called spherical-dyffusion + +### 2. Install dependencies + +After creating your virtual environment, run + + bash environment/install-dependencies.sh + +Note that depending on your CUDA version, you may need to install PyTorch differently than in the bash file. +For more details about installing [PyTorch](https://pytorch.org/get-started/locally/), please refer to their official documentation. \ No newline at end of file diff --git a/environment/install_dependencies.sh b/environment/install_dependencies.sh new file mode 100644 index 0000000..9f6e26e --- /dev/null +++ b/environment/install_dependencies.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +python3 -m pip install --upgrade pip +# ==================== +# Install dependencies. Note: Depending on your system, you may need to install PyTorch manually. +# ==================== +# PyTorch: +python3 -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 +# General dependencies: +python3 -m pip install dacite h5py huggingface_hub matplotlib scipy einops wandb lightning xarray netCDF4 dask hydra-core cachey tensordict timm boto3 black ruff +# For SFNO: +python3 -m pip install --no-deps nvidia-modulus@git+https://github.com/ai2cm/modulus.git@94f62e1ce2083640829ec12d80b00619c40a47f8 +python3 -m pip install torch-harmonics tensorly tensorly-torch diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..48facfb --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,18 @@ +[tool.black] +line-length = 119 +target_version = ['py310'] + +[tool.ruff] +# Ignored rules: +# "E501" -> line length violation +# "F821" -> undefined named in type annotation (e.g. Literal["something"]) +lint.ignore = ["E501", "F821"] +lint.select = ["E", "F", "I", "W"] +line-length = 119 + +[tool.ruff.lint.per-file-ignores] +"src/models/sfno/*" = ["E", "F"] # sfno-net is a third-party library and we don't want to lint it for now + +[tool.ruff.lint.isort] +lines-after-imports = 2 +known-first-party = ["spherical-dyffusion"] \ No newline at end of file diff --git a/run_inference.py b/run_inference.py new file mode 100644 index 0000000..ac07a71 --- /dev/null +++ b/run_inference.py @@ -0,0 +1,14 @@ +# Usage: +# python run_inference.py +# +# Debug with: +# python run_inference.py "src/configs/inference/ckpt_from_local.yaml" +import argparse +from src.ace_inference.inference.inference import main + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("yaml_config", type=str, help="Path to the yaml config file for inference", default="src/configs/inference/ckpts_from_huggingface_debug.yaml") + + args = parser.parse_args() + main(yaml_config=args.yaml_config) diff --git a/scripts/.gitkeep b/scripts/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..8f1d5e0 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,7 @@ +[metadata] +license_files = LICENSE + +[flake8] +exclude = docs +ignore = E203,W293,W503,F541,E402 +max-line-length = 88 diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..50f5449 --- /dev/null +++ b/setup.py @@ -0,0 +1,237 @@ +# Lint as: python3 +""" +Simple check list from AllenNLP repo: https://github.com/allenai/allennlp/blob/main/setup.py + +To create the package for pypi. + +1. Run `make pre-release` (or `make pre-patch` for a patch release) then run `make fix-copies` to fix the index of the + documentation. + + If releasing on a special branch, copy the updated README.md on the main branch for the commit you will make + for the post-release and run `make fix-copies` on the main branch as well. + +2. Unpin specific versions from setup.py that use a git install. + +3. Checkout the release branch (v-release, for example v4.19-release), and commit these changes with the + message: "Release: " and push. + +4. Wait for the tests on main to be completed and be green (otherwise revert and fix bugs) + +5. Add a tag in git to mark the release: "git tag v -m 'Adds tag v for pypi' " + Push the tag to git: git push --tags origin v-release + +6. Build both the sources and the wheel. Do not change anything in setup.py between + creating the wheel and the source distribution (obviously). + + For the wheel, run: "python setup.py bdist_wheel" in the top level directory. + (this will build a wheel for the python version you use to build it). + + For the sources, run: "python setup.py sdist" + You should now have a /dist directory with both .whl and .tar.gz source versions. + + Long story cut short, you need to run both before you can upload the distribution to the + test pypi and the actual pypi servers: + + python setup.py bdist_wheel && python setup.py sdist + +8. Check that everything looks correct by uploading the package to the pypi test server: + + twine upload dist/* -r pypitest + (pypi suggest using twine as other methods upload files via plaintext.) + You may have to specify the repository url, use the following command then: + twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ + + Check that you can install it in a virtualenv by running: + pip install -i https://testpypi.python.org/pypi dyffusion + + If you are testing from a Colab Notebook, for instance, then do: + pip install dyffusion && pip uninstall dyffusion + pip install -i https://testpypi.python.org/pypi dyffusion + + Check you can run the following commands: + python -c "python -c "from dyffusion import __version__; print(__version__)" + python -c "from dyffusion import *" + +9. Upload the final version to actual pypi: + twine upload dist/* -r pypi + +10. Prepare the release notes and publish them on github once everything is looking hunky-dory. + +11. Run `make post-release` (or, for a patch release, `make post-patch`). If you were on a branch for the release, + you need to go back to main before executing this. +""" + +import re + +# Import command from setuptools instead of distutils.core.Command for compatibility with Python>3.12 +from setuptools import Command, find_packages, setup + + +# IMPORTANT: +# 1. all dependencies should be listed here with their version requirements if any +# 2. once modified, run: `make deps_table_update` to update src/dyffusion/dependency_versions_table.py +_deps = [ + "black", + "boto3", + "cachey", + "dacite", + "dask", + "einops", + "h5py", + "hf-doc-builder", + "huggingface_hub", + "hydra-core", + "isort", + "netCDF4", + "numpy", + "omegaconf", + "pytest", + "pytorch-lightning>=2.0", + "rich", + "ruff>=0.0.241", + "regex", + "requests", + "tensordict", + "tensorly", + "tensorly-torch", + "torch>=1.8", + "torch-harmonics", + "transformers", + "urllib3", + "wandb", + "xarray", + # "xbatcher", +# nvidia-modulus@git+https://github.com/ai2cm/modulus.git@94f62e1ce2083640829ec12d80b00619c40a47f8 +] + +# this is a lookup table with items like: +# +# packaging: "packaging" +# +# some of the values are versioned whereas others aren't. +deps = {b: a for a, b in (re.findall(r"^(([^!=<>~]+)(?:[!=<>~].*)?$)", x)[0] for x in _deps)} + +# since we save this data in src/dependency_versions_table.py it can be easily accessed from +# anywhere. If you need to quickly access the data from this table in a shell, you can do so easily with: +# +# python -c 'import sys; from dyffusion.dependency_versions_table import deps; \ +# print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets +# +# Just pass the desired package names to that script as it's shown with 2 packages above. +# +# If dyffusion is not yet installed and the work is done from the cloned repo remember to add `PYTHONPATH=src` to the script above +# +# You can then feed this for example to `pip`: +# +# pip install -U $(python -c 'import sys; from dyffusion.dependency_versions_table import deps; \ +# print(" ".join([ deps[x] for x in sys.argv[1:]]))' tokenizers datasets) +# + + +def deps_list(*pkgs): + return [deps[pkg] for pkg in pkgs] + + +class DepsTableUpdateCommand(Command): + """ + A custom distutils command that updates the dependency table. + usage: python setup.py deps_table_update + """ + + description = "build runtime dependency table" + user_options = [ + # format: (long option, short option, description). + ("dep-table-update", None, "updates src/dependency_versions_table.py"), + ] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + entries = "\n".join([f' "{k}": "{v}",' for k, v in deps.items()]) + content = [ + "# THIS FILE HAS BEEN AUTOGENERATED. To update:", + "# 1. modify the `_deps` dict in setup.py", + "# 2. run `make deps_table_update``", + "deps = {", + entries, + "}", + "", + ] + target = "src/dependency_versions_table.py" + print(f"updating {target}") + with open(target, "w", encoding="utf-8", newline="\n") as f: + f.write("\n".join(content)) + + +extras = {} # defaultdict(list) +extras["quality"] = deps_list("urllib3", "black", "isort", "ruff", "hf-doc-builder") +extras["docs"] = deps_list("hf-doc-builder") +extras["test"] = deps_list("pytest") +extras["run"] = deps_list("xarray", "netCDF4", "dask", "einops", "hydra-core", "wandb") +extras["torch"] = deps_list("torch", "pytorch-lightning", "tensordict", "torch-harmonics") +extras["train"] = extras["torch"] + extras["run"] +extras["optional"] = deps_list("rich") +extras["dev"] = deps_list(*[x.split("<")[0].split(">")[0] for x in _deps]) + +install_requires = [ + deps["numpy"], + deps["regex"], + deps["requests"], +] + +setup( + name="spherical_dyffusion", + version="0.0.1", # expected format is one of x.y.z.dev0, or x.y.z.rc1 or x.y.z (no to dashes, yes to dots) + description="Probabilistic Emulation of a Global Climate Model with Spherical DYffusion", + long_description=open("README.md", "r", encoding="utf-8").read(), + long_description_content_type="text/markdown", + author="Salva Rühling Cachay", + author_email="salvaruehling@gmail.com", + # url="https://github.com/Rose-STL-lab/spherical-dyffusion", + license="Apache 2.0", + package_dir={"": "src"}, + packages=find_packages("src"), + include_package_data=True, + # python_requires=">=3.8.0", + # install_requires=list(install_requires), + extras_require=extras, + classifiers=[ + "Intended Audience :: Developers", + "Intended Audience :: Education", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + keywords="machine learning climate modeling dyffusion forecasting spatiotemporal probabilistic diffusion model", + zip_safe=False, # Required for mypy to find the py.typed file + cmdclass={"deps_table_update": DepsTableUpdateCommand}, +) + +# Release checklist +# 1. Change the version in __init__.py and setup.py. +# 2. Commit these changes with the message: "Release: Release" +# 3. Add a tag in git to mark the release: "git tag RELEASE -m 'Adds tag RELEASE for pypi' " +# Push the tag to git: git push --tags origin main +# 4. Run the following commands in the top-level directory: +# python setup.py bdist_wheel +# python setup.py sdist +# 5. Upload the package to the pypi test server first: +# twine upload dist/* -r pypitest +# twine upload dist/* -r pypitest --repository-url=https://test.pypi.org/legacy/ +# 6. Check that you can install it in a virtualenv by running: +# pip install -i https://testpypi.python.org/pypi dyffusion +# dyffusion env +# dyffusion test +# 7. Upload the final version to actual pypi: +# twine upload dist/* -r pypi +# 8. Add release notes to the tag in github once everything is looking hunky-dory. +# 9. Update the version in __init__.py, setup.py to the new version "-dev" and push to master diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..f102a9c --- /dev/null +++ b/src/__init__.py @@ -0,0 +1 @@ +__version__ = "0.0.1" diff --git a/src/ace_inference/LICENSE b/src/ace_inference/LICENSE new file mode 100644 index 0000000..f49a4e1 --- /dev/null +++ b/src/ace_inference/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/src/ace_inference/README.md b/src/ace_inference/README.md new file mode 100644 index 0000000..58bec05 --- /dev/null +++ b/src/ace_inference/README.md @@ -0,0 +1,4 @@ +Adapted from [https://github.com/ai2cm/ace/tree/main](https://github.com/ai2cm/ace), which +contains code for +"ACE: A fast, skillful learned global atmospheric model for climate prediction" ([arxiv:2310.02074](https://arxiv.org/abs/2310.02074)). + diff --git a/src/ace_inference/__init__.py b/src/ace_inference/__init__.py new file mode 100644 index 0000000..3dc1f76 --- /dev/null +++ b/src/ace_inference/__init__.py @@ -0,0 +1 @@ +__version__ = "0.1.0" diff --git a/src/ace_inference/core/__init__.py b/src/ace_inference/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ace_inference/core/aggregator/__init__.py b/src/ace_inference/core/aggregator/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/ace_inference/core/aggregator/__init__.py @@ -0,0 +1 @@ + diff --git a/src/ace_inference/core/aggregator/climate_data.py b/src/ace_inference/core/aggregator/climate_data.py new file mode 100644 index 0000000..7078bf1 --- /dev/null +++ b/src/ace_inference/core/aggregator/climate_data.py @@ -0,0 +1,233 @@ +import re +from types import MappingProxyType +from typing import Dict, List, Mapping, Union + +import torch + +from src.ace_inference.core import metrics +from src.ace_inference.core.constants import LATENT_HEAT_OF_VAPORIZATION +from src.ace_inference.core.data_loading.data_typing import SigmaCoordinates + + +CLIMATE_FIELD_NAME_PREFIXES = MappingProxyType( + { + "specific_total_water": ["specific_total_water_"], + "surface_pressure": ["PRESsfc", "PS"], + "tendency_of_total_water_path_due_to_advection": ["tendency_of_total_water_path_due_to_advection"], + "latent_heat_flux": ["LHTFLsfc", "LHFLX"], + "sensible_heat_flux": ["SHTFLsfc"], + "precipitation_rate": ["PRATEsfc", "surface_precipitation_rate"], + "sfc_down_sw_radiative_flux": ["DSWRFsfc"], + "sfc_up_sw_radiative_flux": ["USWRFsfc"], + "sfc_down_lw_radiative_flux": ["DLWRFsfc"], + "sfc_up_lw_radiative_flux": ["ULWRFsfc"], + } +) + + +def natural_sort(alist: List[str]) -> List[str]: + """Sort to alphabetical order but with numbers sorted + numerically, e.g. a11 comes after a2. See [1] and [2]. + + [1] https://stackoverflow.com/questions/11150239/natural-sorting + [2] https://en.wikipedia.org/wiki/Natural_sort_order + """ + + def convert(text: str) -> Union[str, int]: + if text.isdigit(): + return int(text) + else: + return text.lower() + + def alphanum_key(item: str) -> List[Union[str, int]]: + return [convert(c) for c in re.split("([0-9]+)", item)] + + return sorted(alist, key=alphanum_key) + + +class ClimateData: + """Container for climate data for accessing variables and providing + torch.Tensor views on data with multiple vertical levels.""" + + def __init__( + self, + climate_data: Mapping[str, torch.Tensor], + climate_field_name_prefixes: Mapping[str, List[str]] = CLIMATE_FIELD_NAME_PREFIXES, + ): + """ + Initializes the instance based on the climate data and prefixes. + + Args: + climate_data: Mapping from field names to tensors. + climate_field_name_prefixes: Mapping from field name prefixes (e.g. + "specific_total_water_") to standardized prefixes, e.g. "PRESsfc" → + "surface_pressure". + """ + self._data = dict(climate_data) + self._prefixes = climate_field_name_prefixes + + def _extract_levels(self, name: List[str]) -> torch.Tensor: + for prefix in name: + try: + return self._extract_prefix_levels(prefix) + except KeyError: + pass + raise KeyError(name) + + def _extract_prefix_levels(self, prefix: str) -> torch.Tensor: + names = [field_name for field_name in self._data if field_name.startswith(prefix)] + + if len(names) == 0: + raise KeyError(prefix) + + names = natural_sort(names) + return torch.stack([self._data[name] for name in names], dim=-1) + + def _get(self, name): + for prefix in self._prefixes[name]: + if prefix in self._data.keys(): + return self._get_prefix(prefix) + raise KeyError(name) + + def _get_prefix(self, prefix): + return self._data[prefix] + + def _set(self, name, value): + for prefix in self._prefixes[name]: + if prefix in self._data.keys(): + self._set_prefix(prefix, value) + return + raise KeyError(name) + + def _set_prefix(self, prefix, value): + self._data[prefix] = value + + @property + def data(self) -> Dict[str, torch.Tensor]: + """Mapping from field names to tensors.""" + return self._data + + @property + def specific_total_water(self) -> torch.Tensor: + """Returns all vertical levels of specific total water, e.g. a tensor of + shape `(..., vertical_level)`.""" + prefix = self._prefixes["specific_total_water"] + return self._extract_levels(prefix) + + @property + def surface_pressure(self) -> torch.Tensor: + return self._get("surface_pressure") + + @surface_pressure.setter + def surface_pressure(self, value: torch.Tensor): + self._set("surface_pressure", value) + + def surface_pressure_due_to_dry_air(self, sigma_coordinates: SigmaCoordinates) -> torch.Tensor: + return metrics.surface_pressure_due_to_dry_air( + self.specific_total_water, + self.surface_pressure, + sigma_coordinates.ak, + sigma_coordinates.bk, + ) + + def total_water_path(self, sigma_coordinates: SigmaCoordinates) -> torch.Tensor: + return metrics.vertical_integral( + self.specific_total_water, + self.surface_pressure, + sigma_coordinates.ak, + sigma_coordinates.bk, + ) + + @property + def net_surface_energy_flux_without_frozen_precip(self) -> torch.Tensor: + return metrics.net_surface_energy_flux( + self._get("sfc_down_lw_radiative_flux"), + self._get("sfc_up_lw_radiative_flux"), + self._get("sfc_down_sw_radiative_flux"), + self._get("sfc_up_sw_radiative_flux"), + self._get("latent_heat_flux"), + self._get("sensible_heat_flux"), + ) + + @property + def precipitation_rate(self) -> torch.Tensor: + """ + Precipitation rate in kg m-2 s-1. + """ + return self._get("precipitation_rate") + + @precipitation_rate.setter + def precipitation_rate(self, value: torch.Tensor): + self._set("precipitation_rate", value) + + @property + def latent_heat_flux(self) -> torch.Tensor: + """ + Latent heat flux in W m-2. + """ + return self._get("latent_heat_flux") + + @latent_heat_flux.setter + def latent_heat_flux(self, value: torch.Tensor): + self._set("latent_heat_flux", value) + + @property + def evaporation_rate(self) -> torch.Tensor: + """ + Evaporation rate in kg m-2 s-1. + """ + lhf = self._get("latent_heat_flux") # W/m^2 + # (W/m^2) / (J/kg) = (J s^-1 m^-2) / (J/kg) = kg/m^2/s + return lhf / LATENT_HEAT_OF_VAPORIZATION + + @evaporation_rate.setter + def evaporation_rate(self, value: torch.Tensor): + self._set("latent_heat_flux", value * LATENT_HEAT_OF_VAPORIZATION) + + @property + def tendency_of_total_water_path_due_to_advection(self) -> torch.Tensor: + """ + Tendency of total water path due to advection in kg m-2 s-1. + """ + return self._get("tendency_of_total_water_path_due_to_advection") + + @tendency_of_total_water_path_due_to_advection.setter + def tendency_of_total_water_path_due_to_advection(self, value: torch.Tensor): + self._set("tendency_of_total_water_path_due_to_advection", value) + + +def compute_dry_air_absolute_differences( + climate_data: ClimateData, area: torch.Tensor, sigma_coordinates: SigmaCoordinates +) -> torch.Tensor: + """ + Computes the absolute value of the dry air tendency of each time step. + + Args: + climate_data: ClimateData object. + area: Area of each grid cell as a [lat, lon] tensor, in m^2. + sigma_coordinates: The sigma coordinates of the model. + + Returns: + A tensor of shape (time,) of the absolute value of the dry air tendency + of each time step. + """ + try: + water = climate_data.specific_total_water + pressure = climate_data.surface_pressure + except KeyError: + return torch.tensor([torch.nan]) + return ( + metrics.weighted_mean( + metrics.surface_pressure_due_to_dry_air( + water, # (sample, time, y, x, level) + pressure, + sigma_coordinates.ak, + sigma_coordinates.bk, + ), + area, + dim=(2, 3), + ) + .diff(dim=-1) + .abs() + .mean(dim=0) + ) diff --git a/src/ace_inference/core/aggregator/inference/__init__.py b/src/ace_inference/core/aggregator/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ace_inference/core/aggregator/inference/main.py b/src/ace_inference/core/aggregator/inference/main.py new file mode 100644 index 0000000..9df2fb8 --- /dev/null +++ b/src/ace_inference/core/aggregator/inference/main.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +from typing import Dict, Iterable, List, Mapping, Optional, Protocol, Union + +import torch +import xarray as xr +from wandb import Table + +from src.ace_inference.core.aggregator.inference.reduced import MeanAggregator +from src.ace_inference.core.aggregator.inference.time_mean import TimeMeanAggregator +from src.ace_inference.core.aggregator.inference.video import VideoAggregator +from src.ace_inference.core.aggregator.inference.zonal_mean import ZonalMeanAggregator +from src.ace_inference.core.aggregator.one_step.reduced import MeanAggregator as OneStepMeanAggregator +from src.ace_inference.core.data_loading.data_typing import SigmaCoordinates, VariableMetadata +from src.ace_inference.core.device import get_device +from src.ace_inference.core.distributed import Distributed +from src.ace_inference.core.wandb import WandB +from src.evaluation.aggregators.snapshot import SnapshotAggregator + + +wandb = WandB.get_instance() + + +class _Aggregator(Protocol): + @torch.no_grad() + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + ): ... + + @torch.no_grad() + def get_logs(self, label: str): ... + + @torch.no_grad() + def get_dataset(self) -> xr.Dataset: ... + + +class InferenceAggregator: + """ + Aggregates statistics for inference. + + To use, call `record_batch` on the results of each batch, then call + `get_logs` to get a dictionary of statistics when you're done. + """ + + def __init__( + self, + area_weights: torch.Tensor, + sigma_coordinates: SigmaCoordinates, + n_timesteps: int, + n_ensemble_members: int = 1, + record_step_20: bool = False, + log_video: bool = False, + enable_extended_videos: bool = False, + log_zonal_mean_images: bool = False, + dist: Optional[Distributed] = None, + metadata: Optional[Mapping[str, VariableMetadata]] = None, + device: torch.device | str = None, + ): + """ + Args: + area_weights: Area weights for each grid cell. + sigma_coordinates: Data sigma coordinates + n_timesteps: Number of timesteps of inference that will be run. + record_step_20: Whether to record the mean of the 20th steps. + log_video: Whether to log videos of the state evolution. + enable_extended_videos: Whether to log videos of statistical + metrics of state evolution + log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a + time dimension. + dist: Distributed object to use for metric aggregation. + metadata: Mapping of variable names their metadata that will + used in generating logged image captions. + """ + self._is_ensemble = n_ensemble_members > 1 + device = device if device is not None else get_device() + kwargs = dict( + area_weights=area_weights.to(device), + dist=dist, + is_ensemble=self._is_ensemble, + device=device, + ) + self._aggregators: Dict[str, _Aggregator] = { + "mean": MeanAggregator(target="denorm", n_timesteps=n_timesteps, **kwargs), + "mean_norm": MeanAggregator(target="norm", n_timesteps=n_timesteps, **kwargs), + "time_mean": TimeMeanAggregator(area_weights, dist=dist, metadata=metadata, is_ensemble=self._is_ensemble), + } + if record_step_20: + self._aggregators["mean_step_20"] = OneStepMeanAggregator(target_time=20, **kwargs) + if log_video: + self._aggregators["video"] = VideoAggregator( + n_timesteps=n_timesteps, + enable_extended_videos=enable_extended_videos, + dist=dist, + metadata=metadata, + ) + if log_zonal_mean_images and not self._is_ensemble: + self._aggregators["zonal_mean"] = ZonalMeanAggregator( + n_timesteps=n_timesteps, dist=dist, metadata=metadata + ) + if n_timesteps is not None: + potential_timesteps = [20, 500, 1400, 5000, 10_000, 14_000, 24_000, 34_000, 43_000] + for t in potential_timesteps: + if n_timesteps >= t: + self._aggregators[f"snapshot/t{t}"] = SnapshotAggregator( + is_ensemble=self._is_ensemble, target_time=t + ) + + @torch.no_grad() + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + i_time_start: int = 0, + ): + if len(target_data) == 0: + raise ValueError("No data in target_data") + if len(gen_data) == 0: + raise ValueError("No data in gen_data") + for aggregator in self._aggregators.values(): + try: + aggregator.record_batch( + loss=loss, + target_data=target_data, + gen_data=gen_data, + target_data_norm=target_data_norm, + gen_data_norm=gen_data_norm, + i_time_start=i_time_start, + ) + except Exception as e: + print("---------------------> Error in aggregator", aggregator, i_time_start) + raise e + + @torch.no_grad() + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + logs = {} + for name, aggregator in self._aggregators.items(): + try: + logs.update(aggregator.get_logs(label=name)) + except RuntimeError as e: + print(f"---------------------> Error in aggregator {name}: {e}") + pass + logs = {f"{label}/{key}": val for key, val in logs.items()} + return logs + + @torch.no_grad() + def get_inference_logs(self, label: str) -> List[Dict[str, Union[float, int]]]: + """ + Returns a list of logs to report to WandB. + + This is done because in inference, we use the wandb step + as the time step, meaning we need to re-organize the logged data + from tables into a list of dictionaries. + """ + return to_inference_logs(self.get_logs(label=label)) + + @torch.no_grad() + def get_datasets(self, aggregator_whitelist: Optional[Iterable[str]] = None) -> Dict[str, xr.Dataset]: + """ + Args: + aggregator_whitelist: aggregator names to include in the output. If + None, return all the datasets associated with all aggregators. + """ + if aggregator_whitelist is None: + aggregators = self._aggregators.keys() + else: + aggregators = aggregator_whitelist + datasets = dict() + for name in aggregators: + if name in self._aggregators.keys(): + datasets[name] = self._aggregators[name].get_dataset() + + return datasets + + +def to_inference_logs(log: Mapping[str, Union[Table, float, int]]) -> List[Dict[str, Union[float, int]]]: + # we have a dictionary which contains WandB tables + # which we will convert to a list of dictionaries, one for each + # row in the tables. Any scalar values will be reported in the last + # dictionary. + n_rows = 0 + for val in log.values(): + if isinstance(val, Table): + n_rows = max(n_rows, len(val.data)) + logs: List[Dict[str, Union[float, int]]] = [] + for i in range(n_rows): + logs.append({}) + for key, val in log.items(): + if isinstance(val, Table): + for i, row in enumerate(val.data): + for j, col in enumerate(val.columns): + key_without_table_name = key[: key.rfind("/")] + logs[i][f"{key_without_table_name}/{col}"] = row[j] + else: + logs[-1][key] = val + return logs + + +def table_to_logs(table: Table) -> List[Dict[str, Union[float, int]]]: + """ + Converts a WandB table into a list of dictionaries. + """ + logs = [] + for row in table.data: + logs.append({table.columns[i]: row[i] for i in range(len(row))}) + return logs diff --git a/src/ace_inference/core/aggregator/inference/reduced.py b/src/ace_inference/core/aggregator/inference/reduced.py new file mode 100644 index 0000000..9dce1c8 --- /dev/null +++ b/src/ace_inference/core/aggregator/inference/reduced.py @@ -0,0 +1,293 @@ +import dataclasses +from collections import defaultdict +from typing import Dict, List, Literal, Mapping, Optional, Protocol + +import numpy as np +import torch +import xarray as xr + +from src.ace_inference.core import metrics +from src.ace_inference.core.data_loading.data_typing import VariableMetadata +from src.ace_inference.core.device import get_device +from src.ace_inference.core.distributed import Distributed +from src.ace_inference.core.metrics import Dimension +from src.ace_inference.core.wandb import WandB + + +wandb = WandB.get_instance() + + +@dataclasses.dataclass +class _SeriesData: + metric_name: str + var_name: str + data: np.ndarray + + def get_wandb_key(self) -> str: + return f"{self.metric_name}/{self.var_name}" + + def get_xarray_key(self) -> str: + return f"{self.metric_name}-{self.var_name}" + + +def get_gen_shape(gen_data: Mapping[str, torch.Tensor]): + for name in gen_data: + return gen_data[name].shape + + +class MeanMetric(Protocol): + def record(self, target: torch.Tensor, gen: torch.Tensor, i_time_start: int): + """ + Update metric for a batch of data. + """ + ... + + def get(self) -> torch.Tensor: + """ + Get the total metric value, not divided by number of recorded batches. + """ + ... + + +class AreaWeightedFunction(Protocol): + """ + A function that computes a metric on the true and predicted values, + weighted by area. + """ + + def __call__( + self, + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +class AreaWeightedSingleTargetFunction(Protocol): + """ + A function that computes a metric on a single value, weighted by area. + """ + + def __call__( + self, + tensor: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +def compute_metric_on( + source: Literal["gen", "target"], metric: AreaWeightedSingleTargetFunction +) -> AreaWeightedFunction: + """Turns a single-target metric function + (computed on only the generated or target data) into a function that takes in + both the generated and target data as arguments, as required for the APIs + which call generic metric functions. + """ + + def metric_wrapper( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: + if source == "gen": + return metric(predicted, weights=weights, dim=dim) + elif source == "target": + return metric(truth, weights=weights, dim=dim) + + return metric_wrapper + + +class AreaWeightedReducedMetric: + """ + A wrapper around an area-weighted metric function. + """ + + def __init__( + self, + area_weights: torch.Tensor, + device: torch.device, + compute_metric: AreaWeightedFunction, + n_timesteps: int, + ): + self._area_weights = area_weights + self._compute_metric = compute_metric + self._total: Optional[torch.Tensor] = None + self._n_batches = torch.zeros(n_timesteps, dtype=torch.int32, device=device) + self._device = device + self._n_timesteps = n_timesteps + + def record(self, target: torch.Tensor, gen: torch.Tensor, i_time_start: int, **kwargs): + """Add a batch of data to the metric. + + Args: + target: Target data. Should have shape [batch, time, height, width]. + gen: Generated data. Should have shape [batch, time, height, width]. + i_time_start: The index of the first timestep in the batch. + """ + new_value = self._compute_metric(target, gen, weights=self._area_weights, dim=(-2, -1), **kwargs).mean(dim=0) + if self._total is None: + self._total = torch.zeros([self._n_timesteps], dtype=new_value.dtype, device=self._device) + time_slice = slice(i_time_start, i_time_start + new_value.shape[0]) + self._total[time_slice] += new_value + self._n_batches[time_slice] += 1 + + def get(self) -> torch.Tensor: + """Returns the mean metric across recorded batches.""" + if self._total is None: + return torch.tensor(torch.nan) + return self._total / self._n_batches + + +class MeanAggregator: + def __init__( + self, + area_weights: torch.Tensor, + target: Literal["norm", "denorm"], + n_timesteps: int, + is_ensemble: bool = False, + dist: Optional[Distributed] = None, + device: torch.device = None, + metadata: Optional[Mapping[str, VariableMetadata]] = None, + ): + self.device = get_device() if device is None else device + self._area_weights = area_weights + self._variable_metrics: Optional[Dict[str, Dict[str, MeanMetric]]] = None + self._shape_x = None + self._shape_y = None + self._target = target + self._n_timesteps = n_timesteps + self.is_ensemble = is_ensemble + self._dist = Distributed.get_instance() if dist is None else dist + if metadata is None: + self._metadata: Mapping[str, VariableMetadata] = {} + else: + self._metadata = metadata + + def _get_variable_metrics(self, gen_data: Mapping[str, torch.Tensor]): + if self._variable_metrics is None: + self._variable_metrics = defaultdict(dict) + + area_weights = self._area_weights + for key in gen_data.keys(): + metrics_zipped = [ + ("weighted_rmse", metrics.root_mean_squared_error), + ("weighted_bias", metrics.weighted_mean_bias), + ("weighted_grad_mag_percent_diff", metrics.gradient_magnitude_percent_diff), + ("weighted_mean_gen", compute_metric_on(source="gen", metric=metrics.weighted_mean)), + ("weighted_mean_target", compute_metric_on(source="target", metric=metrics.weighted_mean)), + ("weighted_std_gen", compute_metric_on(source="gen", metric=metrics.weighted_std)), + ("weighted_std_target", compute_metric_on(source="target", metric=metrics.weighted_std)), + ] + if self.is_ensemble: + metrics_zipped += [ + ("weighted_crps", metrics.weighted_crps), + ("weighted_ssr", metrics.spread_skill_ratio), + ] + + for i, (metric_name, metric) in enumerate(metrics_zipped): + self._variable_metrics[metric_name][key] = AreaWeightedReducedMetric( + area_weights=area_weights, + device=self.device, + compute_metric=metric, + n_timesteps=self._n_timesteps, + ) + + return self._variable_metrics + + @torch.no_grad() + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + i_time_start: int = 0, + ): + if self._target == "norm": + target_data = target_data_norm + gen_data = gen_data_norm + + if self.is_ensemble: + ensemble_mean = {name: member_preds.mean(dim=0) for name, member_preds in gen_data.items()} + else: + ensemble_mean = gen_data + + variable_metrics = self._get_variable_metrics(gen_data) + for name in gen_data.keys(): + for metric in variable_metrics: + kwargs = {} + if "ssr" in metric or "crps" in metric: + gen = gen_data[name] + elif "grad_mag" in metric: + gen = gen_data[name] + kwargs["is_ensemble_prediction"] = self.is_ensemble + else: + gen = ensemble_mean[name] + + variable_metrics[metric][name].record( + target=target_data[name], gen=gen, i_time_start=i_time_start, **kwargs + ) + + def _get_series_data(self) -> List[_SeriesData]: + """Converts internally stored variable_metrics to a list.""" + if self._variable_metrics is None: + raise ValueError("No batches have been recorded.") + data: List[_SeriesData] = [] + for metric in self._variable_metrics: + for key in self._variable_metrics[metric]: + arr = self._variable_metrics[metric][key].get().detach() + datum = _SeriesData( + metric_name=metric, + var_name=key, + data=self._dist.reduce_mean(arr).cpu().numpy(), + ) + data.append(datum) + return data + + @torch.no_grad() + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + logs = {} + series_data: Dict[str, np.ndarray] = {datum.get_wandb_key(): datum.data for datum in self._get_series_data()} + table = data_to_table(series_data) + logs[f"{label}/series"] = table + return logs + + @torch.no_grad() + def get_dataset(self) -> xr.Dataset: + """ + Returns a dataset representation of the logs. + """ + data_vars = {} + for datum in self._get_series_data(): + metadata = self._metadata.get(datum.var_name, VariableMetadata("unknown_units", datum.var_name)) + data_vars[datum.get_xarray_key()] = xr.DataArray( + datum.data, dims=["forecast_step"], attrs=metadata._asdict() + ) + + n_forecast_steps = len(next(iter(data_vars.values()))) + coords = {"forecast_step": np.arange(n_forecast_steps)} + return xr.Dataset(data_vars=data_vars, coords=coords) + + +def data_to_table(data: Dict[str, np.ndarray]): + """ + Convert a dictionary of 1-dimensional timeseries data to a wandb Table. + """ + keys = sorted(list(data.keys())) + table = wandb.Table(columns=["forecast_step"] + keys) + for i in range(len(data[keys[0]])): + row = [i] + for key in keys: + row.append(data[key][i]) + table.add_data(*row) + return table diff --git a/src/ace_inference/core/aggregator/inference/time_mean.py b/src/ace_inference/core/aggregator/inference/time_mean.py new file mode 100644 index 0000000..d3c6435 --- /dev/null +++ b/src/ace_inference/core/aggregator/inference/time_mean.py @@ -0,0 +1,226 @@ +import dataclasses +from typing import Dict, List, Literal, Mapping, MutableMapping, Optional, Union + +import matplotlib.pyplot as plt +import torch +import xarray as xr + +from src.ace_inference.core import metrics +from src.ace_inference.core.aggregator.plotting import get_cmap_limits, plot_imshow +from src.ace_inference.core.data_loading.data_typing import VariableMetadata +from src.ace_inference.core.distributed import Distributed +from src.ace_inference.core.wandb import WandB + + +wandb = WandB.get_instance() + + +@dataclasses.dataclass +class _TargetGenPair: + name: str + target: torch.Tensor + gen: torch.Tensor + + def bias(self): + return self.gen - self.target + + def rmse(self, weights: torch.Tensor) -> float: + ret = float( + metrics.root_mean_squared_error( + predicted=self.gen, + truth=self.target, + weights=weights, + ) + .cpu() + .numpy() + ) + return ret + + def weighted_mean_bias(self, weights: torch.Tensor) -> float: + return float(metrics.weighted_mean_bias(predicted=self.gen, truth=self.target, weights=weights).cpu().numpy()) + + +def get_gen_shape(gen_data: Mapping[str, torch.Tensor]): + for name in gen_data: + return gen_data[name].shape + + +class TimeMeanAggregator: + """Statistics and images on the time-mean state. + + This aggregator keeps track of the time-mean state, then computes + statistics and images on that time-mean state when logs are retrieved. + """ + + _image_captions = { + "bias_map": "{name} time-mean bias (generated - target) [{units}]", + "gen_map": "{name} time-mean generated [{units}]", + } + + def __init__( + self, + area_weights: torch.Tensor, + dist: Optional[Distributed] = None, + target: Literal["norm", "denorm"] = "denorm", + metadata: Optional[Mapping[str, VariableMetadata]] = None, + log_individual_channels: bool = True, + is_ensemble: bool = False, + ): + """ + Args: + area_weights: Area weights for each grid cell. + target: Whether to compute metrics on the normalized or denormalized data, + defaults to "denorm". + metadata: Mapping of variable names their metadata that will + used in generating logged image captions. + log_individual_channels: Whether to log individual channels. + """ + self._area_weights = area_weights + self._is_ensemble = is_ensemble + self._target = target + self._log_individual_channels = log_individual_channels + self._dist = Distributed.get_instance() if dist is None else dist + if metadata is None: + self._metadata: Mapping[str, VariableMetadata] = {} + else: + self._metadata = metadata + # Dictionaries of tensors of shape [n_lat, n_lon] represnting time means + self._target_data: Optional[Dict[str, torch.Tensor]] = None + self._gen_data: Optional[Dict[str, torch.Tensor]] = None + self._target_data_norm = None + self._gen_data_norm = None + self._n_batches = 0 + + @staticmethod + def _add_or_initialize_time_mean( + maybe_dict: Optional[MutableMapping[str, torch.Tensor]], + new_data: Mapping[str, torch.Tensor], + ignore_initial: bool = False, + ) -> Dict[str, torch.Tensor]: + sample_dim = 0 + time_dim = 1 + if ignore_initial: + time_slice = slice(1, None) + else: + time_slice = slice(0, None) + if maybe_dict is None: + d: Dict[str, torch.Tensor] = { + name: tensor[:, time_slice].mean(dim=time_dim).mean(dim=sample_dim) + for name, tensor in new_data.items() + } + else: + d = dict(maybe_dict) + for name, tensor in new_data.items(): + d[name] += tensor[:, time_slice].mean(dim=time_dim).mean(dim=sample_dim) + return d + + @torch.no_grad() + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + i_time_start: int = 0, + ): + if self._target == "norm": + target_data = target_data_norm + gen_data = gen_data_norm + ignore_initial = i_time_start == 0 + self._target_data = self._add_or_initialize_time_mean(self._target_data, target_data, ignore_initial) + if self._is_ensemble: + gen_data = {k: v.mean(dim=0) for k, v in gen_data.items()} # mean over ensemble members + self._gen_data = self._add_or_initialize_time_mean(self._gen_data, gen_data, ignore_initial) + + # we can ignore time slicing and just treat segments as though they're + # different batches, because we can assume all time segments have the + # same length + self._n_batches += 1 + + def _get_target_gen_pairs(self) -> List[_TargetGenPair]: + if self._n_batches == 0 or self._gen_data is None or self._target_data is None: + raise ValueError("No data recorded.") + + ret = [] + for name in self._gen_data.keys(): + gen = self._dist.reduce_mean(self._gen_data[name] / self._n_batches) + target = self._dist.reduce_mean(self._target_data[name] / self._n_batches) + ret.append(_TargetGenPair(gen=gen, target=target, name=name)) + return ret + + @torch.no_grad() + def get_logs(self, label: str) -> Dict[str, Union[float, torch.Tensor]]: + logs = {} + preds = self._get_target_gen_pairs() + bias_map_key, gen_map_key = "bias_map", "gen_map" + rmse_all_channels = {} + for pred in preds: + bias_data = pred.bias().cpu().numpy() + vmin_bias, vmax_bias = get_cmap_limits(bias_data, diverging=True) + vmin_pred, vmax_pred = get_cmap_limits(pred.gen.cpu().numpy()) + bias_fig = plot_imshow(bias_data, vmin=vmin_bias, vmax=vmax_bias, cmap="RdBu_r") + bias_image = wandb.Image( + bias_fig, + caption=self._get_caption(bias_map_key, pred.name, vmin_bias, vmax_bias), + ) + prediction_image = wandb.Image( + plot_imshow(pred.gen.cpu().numpy()), + caption=self._get_caption(gen_map_key, pred.name, vmin_pred, vmax_pred), + ) + plt.close("all") + rmse_all_channels[pred.name] = pred.rmse(weights=self._area_weights) + if self._log_individual_channels: + logs.update( + { + f"{bias_map_key}/{pred.name}": bias_image, + f"{gen_map_key}/{pred.name}": prediction_image, + f"rmse/{pred.name}": rmse_all_channels[pred.name], + f"bias/{pred.name}": pred.weighted_mean_bias(weights=self._area_weights), + } + ) + logs.update( + { + "rmse/channel_mean": sum(rmse_all_channels.values()) / len(rmse_all_channels), + } + ) + + if len(label) != 0: + return {f"{label}/{key}": logs[key] for key in logs} + return logs + + def _get_caption(self, key: str, name: str, vmin: float, vmax: float) -> str: + if name in self._metadata: + caption_name = self._metadata[name].long_name + units = self._metadata[name].units + else: + caption_name, units = name, "unknown_units" + caption = self._image_captions[key].format(name=caption_name, units=units) + caption += f" vmin={vmin:.4g}, vmax={vmax:.4g}." + return caption + + def get_dataset(self) -> xr.Dataset: + data = {} + preds = self._get_target_gen_pairs() + dims = ("lat", "lon") + for pred in preds: + bias_metadata = self._metadata.get( + pred.name, VariableMetadata(units="unknown_units", long_name=pred.name) + )._asdict() + gen_metadata = VariableMetadata(units="", long_name=pred.name)._asdict() + data.update( + { + f"bias_map-{pred.name}": xr.DataArray(pred.bias().cpu(), dims=dims, attrs=bias_metadata), + f"gen_map-{pred.name}": xr.DataArray( + pred.gen.cpu(), + dims=dims, + attrs=gen_metadata, + ), + f"target_map-{pred.name}": xr.DataArray( + pred.target.cpu(), + dims=dims, + attrs=gen_metadata, + ), + } + ) + return xr.Dataset(data) diff --git a/src/ace_inference/core/aggregator/inference/time_mean_salva.py b/src/ace_inference/core/aggregator/inference/time_mean_salva.py new file mode 100644 index 0000000..69b3de8 --- /dev/null +++ b/src/ace_inference/core/aggregator/inference/time_mean_salva.py @@ -0,0 +1,150 @@ +from typing import Dict, Mapping, Optional + +import numpy as np +import torch +import wandb +import xarray as xr + +from src.ace_inference.core import metrics +from src.ace_inference.core.data_loading.data_typing import VariableMetadata + + +def get_gen_shape(gen_data: Mapping[str, torch.Tensor]): + for name in gen_data: + return gen_data[name].shape + + +class TimeMeanAggregator: + """Statistics on the time-mean state. + + This aggregator keeps track of the time-mean state, then computes + statistics on that time-mean state when logs are retrieved. + """ + + _image_captions = { + "bias_map": "{name} time-mean bias (generated - target) [{units}]", + "gen_map": "{name} time-mean generated [{units}]", + } + + def __init__( + self, + area_weights: torch.Tensor, + is_ensemble: bool = False, + sigma_coordinates=None, + metadata: Optional[Mapping[str, VariableMetadata]] = None, + ): + self._area_weights = area_weights + self._target_data: Optional[Dict[str, torch.Tensor]] = None + self._gen_data: Optional[Dict[str, torch.Tensor]] = None + self._target_data_norm = None + self._gen_data_norm = None + self._n_batches = 0 + self._is_ensemble = is_ensemble + + if metadata is None: + self._metadata: Mapping[str, VariableMetadata] = {} + else: + self._metadata = metadata + + @torch.no_grad() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + ): + def add_or_initialize_time_mean( + maybe_dict: Optional[Dict[str, torch.Tensor]], + new_data: Mapping[str, torch.Tensor], + ) -> Mapping[str, torch.Tensor]: + if maybe_dict is None: + d: Dict[str, torch.Tensor] = {name: tensor for name, tensor in new_data.items()} + else: + d = maybe_dict + for name, tensor in new_data.items(): + d[name] += tensor + return d + + self._target_data = add_or_initialize_time_mean(self._target_data, target_data) + self._gen_data = add_or_initialize_time_mean(self._gen_data, gen_data) + self._n_batches += 1 + + @torch.no_grad() + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + if self._n_batches == 0: + raise ValueError("No data recorded.") + area_weights = self._area_weights + logs = {} + # dist = Distributed.get_instance() + for name in self._gen_data.keys(): + gen = self._gen_data[name] / self._n_batches + target = self._target_data[name] / self._n_batches + # gen = dist.reduce_mean(self._gen_data[name] / self._n_batches) + # target = dist.reduce_mean(self._target_data[name] / self._n_batches) + if self._is_ensemble: + gen_ens_mean = gen.mean(dim=0) + logs[f"rmse_member_avg/{name}"] = np.mean( + [ + metrics.root_mean_squared_error(predicted=gen[i], truth=target, weights=area_weights) + .cpu() + .numpy() + for i in range(gen.shape[0]) + ] + ) + logs[f"bias_member_avg/{name}"] = np.mean( + [ + metrics.time_and_global_mean_bias(predicted=gen[i], truth=target, weights=area_weights) + .cpu() + .numpy() + for i in range(gen.shape[0]) + ] + ) + else: + gen_ens_mean = gen + + logs[f"rmse/{name}"] = float( + metrics.root_mean_squared_error(predicted=gen_ens_mean, truth=target, weights=area_weights) + .cpu() + .numpy() + ) + + logs[f"bias/{name}"] = float( + metrics.time_and_global_mean_bias(predicted=gen_ens_mean, truth=target, weights=area_weights) + .cpu() + .numpy() + ) + logs[f"crps/{name}"] = float( + metrics.weighted_crps(predicted=gen, truth=target, weights=area_weights).cpu().numpy() + ) + return {f"{label}/{key}": logs[key] for key in logs}, {} + + def _get_image(self, key: str, name: str, data: torch.Tensor): + sample_dim = 0 + lat_dim = -2 + data = data.mean(dim=sample_dim).flip(dims=[lat_dim]).cpu() + caption = self._get_caption(key, name, data) + return wandb.Image(data, caption=caption) + + def _get_caption(self, key: str, name: str, data: torch.Tensor) -> str: + if name in self._metadata: + caption_name = self._metadata[name].long_name + units = self._metadata[name].units + else: + caption_name, units = name, "unknown_units" + caption = self._image_captions[key].format(name=caption_name, units=units) + caption += f" vmin={data.min():.4g}, vmax={data.max():.4g}." + return caption + + @torch.no_grad() + def get_dataset(self, label: str) -> xr.Dataset: + logs = self.get_logs(label=label) + logs = {key.replace("/", "-"): logs[key] for key in logs} + data_vars = {} + for key, value in logs.items(): + data_vars[key] = xr.DataArray(value) + return xr.Dataset(data_vars=data_vars) diff --git a/src/ace_inference/core/aggregator/inference/video.py b/src/ace_inference/core/aggregator/inference/video.py new file mode 100644 index 0000000..c81f3f9 --- /dev/null +++ b/src/ace_inference/core/aggregator/inference/video.py @@ -0,0 +1,448 @@ +import dataclasses +from typing import Dict, Mapping, Optional, Tuple + +import numpy as np +import torch +import xarray as xr + +from src.ace_inference.core.data_loading.data_typing import VariableMetadata +from src.ace_inference.core.distributed import Distributed +from src.ace_inference.core.wandb import WandB + + +wandb = WandB.get_instance() + + +def _get_gen_shape(gen_data: Mapping[str, torch.Tensor]): + for name in gen_data: + return gen_data[name].shape + raise ValueError("No data in gen_data") + + +@dataclasses.dataclass +class _ErrorData: + rmse: Dict[str, torch.Tensor] + min_err: Dict[str, torch.Tensor] + max_err: Dict[str, torch.Tensor] + + +class _ErrorVideoData: + """ + Record batches of video data and compute statistics on the error. + """ + + def __init__(self, n_timesteps: int, dist: Optional[Distributed] = None): + self._mse_data: Optional[Dict[str, torch.Tensor]] = None + self._min_err_data: Optional[Dict[str, torch.Tensor]] = None + self._max_err_data: Optional[Dict[str, torch.Tensor]] = None + self._n_timesteps = n_timesteps + self._n_batches = torch.zeros([n_timesteps], dtype=torch.int32).cpu() + if dist is None: + dist = Distributed.get_instance() + self._dist = dist + + @torch.no_grad() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + i_time_start: int, + ): + """ + Record a batch of data. + + Args: + target_data: Dict of tensors of shape (n_samples, n_timesteps, ...) + gen_data: Dict of tensors of shape (n_samples, n_timesteps, ...) + i_time_start: Index of the first timestep in the batch. + """ + if self._mse_data is None: + self._mse_data = _initialize_video_from_batch(gen_data, self._n_timesteps) + if self._min_err_data is None: + self._min_err_data = _initialize_video_from_batch(gen_data, self._n_timesteps, fill_value=np.inf) + if self._max_err_data is None: + self._max_err_data = _initialize_video_from_batch(gen_data, self._n_timesteps, fill_value=-np.inf) + + window_steps = next(iter(target_data.values())).shape[1] + time_slice = slice(i_time_start, i_time_start + window_steps) + for name, gen_tensor in gen_data.items(): + target_tensor = target_data[name] + error_tensor = (gen_tensor - target_tensor).cpu() + self._mse_data[name][time_slice, ...] += torch.var(error_tensor, dim=0) + self._min_err_data[name][time_slice, ...] = torch.minimum( + self._min_err_data[name][time_slice, ...], error_tensor.min(dim=0)[0] + ) + self._max_err_data[name][time_slice, ...] = torch.maximum( + self._max_err_data[name][time_slice, ...], error_tensor.max(dim=0)[0] + ) + + self._n_batches[time_slice] += 1 + + @torch.no_grad() + def get( + self, + ) -> _ErrorData: + if self._mse_data is None or self._min_err_data is None or self._max_err_data is None: + raise RuntimeError("No data recorded") + rmse_data = {} + min_err_data = {} + max_err_data = {} + for name, tensor in self._mse_data.items(): + mse = (tensor / self._n_batches[None, :, None, None]).mean(dim=0) + mse = self._dist.reduce_mean(mse) + rmse_data[name] = torch.sqrt(mse) + for name, tensor in self._min_err_data.items(): + min_err_data[name] = self._dist.reduce_min(tensor) + for name, tensor in self._max_err_data.items(): + max_err_data[name] = self._dist.reduce_max(tensor) + return _ErrorData(rmse_data, min_err_data, max_err_data) + + +class _MeanVideoData: + """ + Record batches of video data and compute the mean. + """ + + def __init__(self, n_timesteps: int, dist: Optional[Distributed] = None): + self._target_data: Optional[Dict[str, torch.Tensor]] = None + self._gen_data: Optional[Dict[str, torch.Tensor]] = None + self._n_timesteps = n_timesteps + self._n_batches = torch.zeros([n_timesteps], dtype=torch.int32).cpu() + if dist is None: + dist = Distributed.get_instance() + self._dist = dist + + @torch.no_grad() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + i_time_start: int, + ): + """ + Record a batch of data. + + Args: + target_data: Dict of tensors of shape (n_samples, n_timesteps, ...) + gen_data: Dict of tensors of shape (n_samples, n_timesteps, ...) + i_time_start: Index of the first timestep in the batch. + """ + if self._target_data is None: + self._target_data = _initialize_video_from_batch(target_data, self._n_timesteps) + if self._gen_data is None: + self._gen_data = _initialize_video_from_batch(gen_data, self._n_timesteps) + + window_steps = next(iter(target_data.values())).shape[1] + time_slice = slice(i_time_start, i_time_start + window_steps) + for name, tensor in target_data.items(): + self._target_data[name][time_slice, ...] += tensor.mean(dim=0).cpu() + for name, tensor in gen_data.items(): + self._gen_data[name][time_slice, ...] += tensor.mean(dim=0).cpu() + + self._n_batches[time_slice] += 1 + + @torch.no_grad() + def get(self) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + if self._gen_data is None or self._target_data is None: + raise RuntimeError("No data recorded") + target_data = {} + gen_data = {} + for name, tensor in self._target_data.items(): + target_data[name] = tensor / self._n_batches[:, None, None] + target_data[name] = self._dist.reduce_mean(target_data[name]) + for name, tensor in self._gen_data.items(): + gen_data[name] = tensor / self._n_batches[:, None, None] + gen_data[name] = self._dist.reduce_mean(gen_data[name]) + return gen_data, target_data + + +class _VarianceVideoData: + """ + Record batches of video data and compute the variance. + """ + + def __init__(self, n_timesteps: int, dist: Optional[Distributed] = None): + self._target_means: Optional[Dict[str, torch.Tensor]] = None + self._gen_means: Optional[Dict[str, torch.Tensor]] = None + self._target_squares: Optional[Dict[str, torch.Tensor]] = None + self._gen_squares: Optional[Dict[str, torch.Tensor]] = None + self._n_timesteps = n_timesteps + self._n_batches = torch.zeros([n_timesteps], dtype=torch.int32).cpu() + if dist is None: + dist = Distributed.get_instance() + self._dist = dist + + @torch.no_grad() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + i_time_start: int, + ): + """ + Record a batch of data. + + Args: + target_data: Dict of tensors of shape (n_samples, n_timesteps, ...) + gen_data: Dict of tensors of shape (n_samples, n_timesteps, ...) + i_time_start: Index of the first timestep in the batch. + """ + if self._target_means is None: + self._target_means = _initialize_video_from_batch(target_data, self._n_timesteps) + if self._gen_means is None: + self._gen_means = _initialize_video_from_batch(gen_data, self._n_timesteps) + if self._target_squares is None: + self._target_squares = _initialize_video_from_batch(target_data, self._n_timesteps) + + if self._gen_squares is None: + self._gen_squares = _initialize_video_from_batch(gen_data, self._n_timesteps) + + window_steps = next(iter(target_data.values())).shape[1] + time_slice = slice(i_time_start, i_time_start + window_steps) + for name, tensor in target_data.items(): + self._target_means[name][time_slice, ...] += tensor.mean(dim=0).cpu() + self._target_squares[name][time_slice, ...] += (tensor**2).mean(dim=0).cpu() + for name, tensor in gen_data.items(): + self._gen_means[name][time_slice, ...] += tensor.mean(dim=0).cpu() + self._gen_squares[name][time_slice, ...] += (tensor**2).mean(dim=0).cpu() + self._n_batches[time_slice] += 1 + + @torch.no_grad() + def get(self) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: + if ( + self._gen_means is None + or self._target_means is None + or self._gen_squares is None + or self._target_squares is None + ): + raise RuntimeError("No data recorded") + target_data = {} + gen_data = {} + # calculate variance as E[X^2] - E[X]^2 + for name, tensor in self._target_means.items(): + mean = tensor / self._n_batches[:, None, None] + mean = self._dist.reduce_mean(mean) + square = self._target_squares[name] / self._n_batches[:, None, None] + square = self._dist.reduce_mean(square) + target_data[name] = square - mean**2 + for name, tensor in self._gen_means.items(): + mean = tensor / self._n_batches[:, None, None] + mean = self._dist.reduce_mean(mean) + square = self._gen_squares[name] / self._n_batches[:, None, None] + square = self._dist.reduce_mean(square) + gen_data[name] = square - mean**2 + return gen_data, target_data + + +def _initialize_video_from_batch(batch: Mapping[str, torch.Tensor], n_timesteps: int, fill_value: float = 0.0): + """ + Initialize a video of the same shape as the batch, but with all valeus equal + to fill_value and with n_timesteps timesteps. + """ + video = {} + for name, value in batch.items(): + shape = list(value.shape[1:]) + shape[0] = n_timesteps + video[name] = torch.zeros(shape, dtype=torch.double).cpu() + video[name][:, ...] = fill_value + return video + + +@dataclasses.dataclass +class _MaybePairedVideoData: + caption: str + gen: torch.Tensor + target: Optional[torch.Tensor] = None + + def make_video(self): + return _make_video( + caption=self.caption, + gen=self.gen, + target=self.target, + ) + + +class VideoAggregator: + """Videos of state evolution.""" + + def __init__( + self, + n_timesteps: int, + enable_extended_videos: bool, + dist: Optional[Distributed] = None, + metadata: Optional[Mapping[str, VariableMetadata]] = None, + ): + """ + Args: + n_timesteps: Number of timesteps of inference that will be run. + enable_extended_videos: Whether to log videos of statistical + metrics of state evolution + dist: Distributed object to use for metric aggregation. + metadata: Mapping of variable names their metadata that will + used in generating logged video captions. + """ + self._mean_data = _MeanVideoData(n_timesteps=n_timesteps, dist=dist) + if enable_extended_videos: + self._error_data: Optional[_ErrorVideoData] = _ErrorVideoData(n_timesteps=n_timesteps, dist=dist) + self._variance_data: Optional[_VarianceVideoData] = _VarianceVideoData(n_timesteps=n_timesteps, dist=dist) + self._enable_extended_videos = True + else: + self._error_data = None + self._variance_data = None + self._enable_extended_videos = False + if metadata is None: + self._metadata: Mapping[str, VariableMetadata] = {} + else: + self._metadata = metadata + + @torch.no_grad() + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Optional[Mapping[str, torch.Tensor]] = None, + gen_data_norm: Optional[Mapping[str, torch.Tensor]] = None, + i_time_start: int = 0, + ): + del target_data_norm, gen_data_norm # intentionally unused + self._mean_data.record_batch( + target_data=target_data, + gen_data=gen_data, + i_time_start=i_time_start, + ) + if self._error_data is not None: + self._error_data.record_batch( + target_data=target_data, + gen_data=gen_data, + i_time_start=i_time_start, + ) + if self._variance_data is not None: + self._variance_data.record_batch( + target_data=target_data, + gen_data=gen_data, + i_time_start=i_time_start, + ) + + @torch.no_grad() + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + data = self._get_data(label=label) + videos = {} + for label, d in data.items(): + videos[label] = d.make_video() + return videos + + @torch.no_grad() + def _get_data(self, label: str) -> Mapping[str, _MaybePairedVideoData]: + """ + Returns video data as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + gen_data, target_data = self._mean_data.get() + video_data = {} + for name in gen_data: + video_data[f"{label}/{name}"] = _MaybePairedVideoData( + caption=self._get_caption(name), + gen=gen_data[name], + target=target_data[name], + ) + if self._enable_extended_videos: + video_data[f"{label}/bias/{name}"] = _MaybePairedVideoData( + caption=(f"prediction - target for {name}"), + gen=gen_data[name] - target_data[name], + ) + if self._error_data is not None: + data = self._error_data.get() + for name in data.rmse: + video_data[f"{label}/rmse/{name}"] = _MaybePairedVideoData( + caption=f"RMSE over ensemble for {name}", + gen=data.rmse[name], + ) + for name in data.min_err: + video_data[f"{label}/min_err/{name}"] = _MaybePairedVideoData( + caption=f"Min across ensemble members of min error for {name}", + gen=data.min_err[name], + ) + for name in data.max_err: + video_data[f"{label}/max_err/{name}"] = _MaybePairedVideoData( + caption=f"Max across ensemble members of max error for {name}", + gen=data.max_err[name], + ) + if self._variance_data is not None: + gen_data, target_data = self._variance_data.get() + for name in gen_data: + video_data[f"{label}/gen_var/{name}"] = _MaybePairedVideoData( + caption=(f"Variance of gen data for {name} " "as fraction of target variance"), + gen=gen_data[name] / target_data[name], + ) + return video_data + + @torch.no_grad() + def get_dataset(self) -> xr.Dataset: + """ + Return video data as an xarray Dataset. + """ + data = self._get_data(label="") + video_data = {} + for label, d in data.items(): + label = label.strip("/").replace("/", "_") # remove leading slash + if d.target is not None: + video_data[label] = xr.DataArray( + data=np.concatenate( + [d.gen.cpu().numpy()[None, :], d.target.cpu().numpy()[None, :]], + axis=0, + ), + dims=("source", "timestep", "lat", "lon"), + ) + else: + video_data[label] = xr.DataArray(data=d.gen.cpu().numpy(), dims=("timestep", "lat", "lon")) + return xr.Dataset(video_data) + + def _get_caption(self, name: str) -> str: + caption = "Autoregressive (left) prediction and (right) target for {name} [{units}]" + if name in self._metadata: + caption_name = self._metadata[name].long_name + units = self._metadata[name].units + else: + caption_name, units = name, "unknown units" + return caption.format(name=caption_name, units=units) + + +def _make_video( + caption: str, + gen: torch.Tensor, + target: Optional[torch.Tensor] = None, +): + if target is None: + video_data = np.expand_dims(gen.cpu().numpy(), axis=1) + else: + gen = np.expand_dims(gen.cpu().numpy(), axis=1) + target = np.expand_dims(target.cpu().numpy(), axis=1) + gap = np.zeros([gen.shape[0], 1, gen.shape[2], 10]) + video_data = np.concatenate([gen, gap, target], axis=-1) + if target is None: + data_min = np.nanmin(video_data) + data_max = np.nanmax(video_data) + else: + # use target data to set the color scale + data_min = np.nanmin(target) + data_max = np.nanmax(target) + # video data is brightness values on a 0-255 scale + video_data = 255 * (video_data - data_min) / (data_max - data_min) + video_data = np.minimum(video_data, 255) + video_data = np.maximum(video_data, 0) + video_data[np.isnan(video_data)] = 0 + caption += f"; vmin={data_min:.4g}, vmax={data_max:.4g}" + return wandb.Video( + np.flip(video_data, axis=-2), + caption=caption, + fps=4, + ) diff --git a/src/ace_inference/core/aggregator/inference/zonal_mean.py b/src/ace_inference/core/aggregator/inference/zonal_mean.py new file mode 100644 index 0000000..7970e73 --- /dev/null +++ b/src/ace_inference/core/aggregator/inference/zonal_mean.py @@ -0,0 +1,129 @@ +from typing import Dict, Mapping, Optional + +import torch + +from src.ace_inference.core.data_loading.data_typing import VariableMetadata +from src.ace_inference.core.device import get_device +from src.ace_inference.core.distributed import Distributed +from src.ace_inference.core.wandb import WandB + + +wandb = WandB.get_instance() + + +class ZonalMeanAggregator: + """Images of the zonal-mean state as a function of latitude and time. + + This aggregator keeps track of the generated and target zonal-mean state, + then generates zonal-mean (Hovmoller) images when logs are retrieved. + The zonal-mean images are averaged across the sample dimension. + """ + + _captions = { + "error": ( + "{name} zonal-mean error (generated - target) [{units}], " + "x-axis is time increasing to right, y-axis is latitude increasing upward" + ), + "gen": ( + "{name} zonal-mean generated [{units}], " + "x-axis is time increasing to right, y-axis is latitude increasing upward" + ), + } + + def __init__( + self, + n_timesteps: int, + dist: Optional[Distributed] = None, + metadata: Optional[Mapping[str, VariableMetadata]] = None, + ): + """ + Args: + n_timesteps: Number of timesteps of inference that will be run. + dist: Distributed object to use for communication. + metadata: Mapping of variable names their metadata that will + used in generating logged image captions. + """ + self._n_timesteps = n_timesteps + if dist is None: + self._dist = Distributed.get_instance() + else: + self._dist = dist + if metadata is None: + self._metadata: Mapping[str, VariableMetadata] = {} + else: + self._metadata = metadata + + self._target_data: Optional[Dict[str, torch.Tensor]] = None + self._gen_data: Optional[Dict[str, torch.Tensor]] = None + self._n_batches = torch.zeros(n_timesteps, dtype=torch.int32, device=get_device())[ + None, :, None + ] # sample, time, lat + + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + i_time_start: int, + ): + lon_dim = 3 + if self._target_data is None: + self._target_data = self._initialize_zeros_zonal_mean_from_batch(target_data, self._n_timesteps) + if self._gen_data is None: + self._gen_data = self._initialize_zeros_zonal_mean_from_batch(gen_data, self._n_timesteps) + + window_steps = next(iter(target_data.values())).shape[1] + time_slice = slice(i_time_start, i_time_start + window_steps) + # we can average along longitude without area weighting + for name, tensor in target_data.items(): + self._target_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim) + for name, tensor in gen_data.items(): + self._gen_data[name][:, time_slice, :] += tensor.mean(dim=lon_dim) + self._n_batches[:, time_slice, :] += 1 + + def get_logs(self, label: str) -> Dict[str, torch.Tensor]: + if self._gen_data is None or self._target_data is None: + raise RuntimeError("No data recorded") + sample_dim = 0 + logs = {} + for name in self._gen_data.keys(): + zonal_means = {} + gen = self._dist.reduce_mean(self._gen_data[name] / self._n_batches) + zonal_means["gen"] = gen.mean(dim=sample_dim).cpu() + error = self._dist.reduce_mean((self._gen_data[name] - self._target_data[name]) / self._n_batches) + zonal_means["error"] = error.mean(dim=sample_dim).cpu() + for key, data in zonal_means.items(): + caption = self._get_caption(key, name, data) + # images are y, x from upper left corner + # data is time, lat + # we want lat on y-axis (increasing upward) and time on x-axis + # so transpose and flip along lat axis + data = data.t().flip(dims=[0]) + wandb_image = wandb.Image(data, caption=caption) + logs[f"{label}/{key}/{name}"] = wandb_image + return logs + + def _get_caption(self, caption_key: str, varname: str, data: torch.Tensor) -> str: + if varname in self._metadata: + caption_name = self._metadata[varname].long_name + units = self._metadata[varname].units + else: + caption_name, units = varname, "unknown_units" + caption = self._captions[caption_key].format(name=caption_name, units=units) + caption += f" vmin={data.min():.4g}, vmax={data.max():.4g}." + return caption + + @staticmethod + def _initialize_zeros_zonal_mean_from_batch( + data: Mapping[str, torch.Tensor], n_timesteps: int, lat_dim: int = 2 + ) -> Dict[str, torch.Tensor]: + return { + name: torch.zeros( + (tensor.shape[0], n_timesteps, tensor.shape[lat_dim]), + dtype=tensor.dtype, + device=tensor.device, + ) + for name, tensor in data.items() + } diff --git a/src/ace_inference/core/aggregator/null.py b/src/ace_inference/core/aggregator/null.py new file mode 100644 index 0000000..dff759a --- /dev/null +++ b/src/ace_inference/core/aggregator/null.py @@ -0,0 +1,32 @@ +from typing import Mapping + +import torch + + +class NullAggregator: + """ + An aggregator that does nothing. Null object pattern. + """ + + def __init__(self): + pass + + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + i_time_start: int = 0, + ): + pass + + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + return {} diff --git a/src/ace_inference/core/aggregator/one_step/__init__.py b/src/ace_inference/core/aggregator/one_step/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ace_inference/core/aggregator/one_step/derived.py b/src/ace_inference/core/aggregator/one_step/derived.py new file mode 100644 index 0000000..79a67a6 --- /dev/null +++ b/src/ace_inference/core/aggregator/one_step/derived.py @@ -0,0 +1,132 @@ +"""Derived metrics take the global state as input and usually output a new +variable, e.g. dry air mass.""" + +from dataclasses import dataclass +from typing import Dict, Mapping, Optional, Protocol, Tuple + +import torch + +from src.ace_inference.core import metrics +from src.ace_inference.core.aggregator.climate_data import CLIMATE_FIELD_NAME_PREFIXES, ClimateData +from src.ace_inference.core.data_loading.data_typing import SigmaCoordinates +from src.ace_inference.core.device import get_device + + +@dataclass +class _TargetGenPair: + target: torch.Tensor + gen: torch.Tensor + + +class DerivedMetric(Protocol): + """Derived metrics are computed from the global state and usually output a + new variable, e.g. dry air tendencies.""" + + def record(self, target: ClimateData, gen: ClimateData) -> None: ... + + def get(self) -> _TargetGenPair: + """Returns the derived metric applied to the target and data generated + by the model.""" + ... + + +class DryAir(DerivedMetric): + """Computes absolute value of the dry air tendency of the first time step, + averaged over the batch. If the data does not contain the required fields, + then returns NaN.""" + + def __init__( + self, + area_weights: torch.Tensor, + sigma_coordinates: SigmaCoordinates, + device: torch.device, + spatial_dims=(2, 3), + ): + self._area_weights = area_weights + self._sigma_coordinates = sigma_coordinates + self._dry_air_target_total: Optional[torch.Tensor] = None + self._dry_air_gen_total: Optional[torch.Tensor] = None + self._device = device + self._spatial_dims: Tuple[int, int] = spatial_dims + + def record(self, target: ClimateData, gen: ClimateData) -> None: + def _compute_dry_air_helper(climate_data: ClimateData) -> torch.Tensor: + water = climate_data.specific_total_water + pressure = climate_data.surface_pressure + if water is None or pressure is None: + return torch.tensor(torch.nan) + return ( + metrics.weighted_mean( + metrics.surface_pressure_due_to_dry_air( + water[:, 0:2, ...], # (sample, time, y, x, level) + pressure[:, 0:2, ...], + self._sigma_coordinates.ak, + self._sigma_coordinates.bk, + ), + self._area_weights, + dim=(2, 3), + ) + .diff(dim=-1) + .abs() + .mean() + ) + + dry_air_target = _compute_dry_air_helper(target) + dry_air_gen = _compute_dry_air_helper(gen) + + # initialize + if self._dry_air_target_total is None: + self._dry_air_target_total = torch.zeros_like(dry_air_target, device=self._device) + if self._dry_air_gen_total is None: + self._dry_air_gen_total = torch.zeros_like(dry_air_gen, device=self._device) + + self._dry_air_target_total += dry_air_target + self._dry_air_gen_total += dry_air_gen + + def get(self) -> _TargetGenPair: + if self._dry_air_target_total is None or self._dry_air_gen_total is None: + raise ValueError("No batches have been recorded.") + return _TargetGenPair(target=self._dry_air_target_total, gen=self._dry_air_gen_total) + + +class DerivedMetricsAggregator: + def __init__( + self, + area_weights: torch.Tensor, + sigma_coordinates: SigmaCoordinates, + climate_field_name_prefixes: Mapping[str, str] = CLIMATE_FIELD_NAME_PREFIXES, + ): + self.area_weights = area_weights + self.sigma_coordinates = sigma_coordinates + self.climate_field_name_prefixes = climate_field_name_prefixes + device = get_device() + self._derived_metrics: Dict[str, DerivedMetric] = { + "surface_pressure_due_to_dry_air": DryAir(self.area_weights, self.sigma_coordinates, device=device) + } + self._n_batches = 0 + + @torch.no_grad() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + ): + del target_data_norm, gen_data_norm # unused + target = ClimateData(target_data, self.climate_field_name_prefixes) + gen = ClimateData(gen_data, self.climate_field_name_prefixes) + + for metric_fn in self._derived_metrics.values(): + metric_fn.record(target, gen) + + # only increment n_batches if we actually recorded a batch + self._n_batches += 1 + + def get_logs(self, label: str): + logs = dict() + for metric_name in self._derived_metrics: + values = self._derived_metrics[metric_name].get() + logs[f"{label}/{metric_name}/target"] = values.target / self._n_batches + logs[f"{label}/{metric_name}/gen"] = values.gen / self._n_batches + return logs diff --git a/src/ace_inference/core/aggregator/one_step/main.py b/src/ace_inference/core/aggregator/one_step/main.py new file mode 100644 index 0000000..203eb4c --- /dev/null +++ b/src/ace_inference/core/aggregator/one_step/main.py @@ -0,0 +1,97 @@ +import inspect +from typing import Mapping, Optional, Protocol + +import torch + +from src.ace_inference.core.data_loading.data_typing import SigmaCoordinates, VariableMetadata + +from .reduced_salva import MeanAggregator +from .snapshot import SnapshotAggregator + + +class _Aggregator(Protocol): + def get_logs(self, label: str) -> Mapping[str, torch.Tensor]: ... + + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + ) -> None: ... + + +class OneStepAggregator: + """ + Aggregates statistics for the first timestep. + + To use, call `record_batch` on the results of each batch, then call + `get_logs` to get a dictionary of statistics when you're done. + """ + + def __init__( + self, + area_weights: torch.Tensor, + sigma_coordinates: SigmaCoordinates, + is_ensemble: bool, + use_snapshot_aggregator: bool = True, + metadata: Optional[Mapping[str, VariableMetadata]] = None, + ): + self._snapshot = ( + SnapshotAggregator(is_ensemble=is_ensemble, metadata=metadata) if use_snapshot_aggregator else None + ) + self._mean = MeanAggregator(area_weights=area_weights, is_ensemble=is_ensemble) + self._aggregators: Mapping[str, _Aggregator] = { + "snapshot": self._snapshot, + "mean": self._mean, + # "derived": DerivedMetricsAggregator(area_weights, sigma_coordinates) + } + + @torch.no_grad() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + inputs_norm: Mapping[str, torch.Tensor] = None, + ): + if len(target_data) == 0: + raise ValueError("No data in target_data") + if len(gen_data) == 0: + raise ValueError("No data in gen_data") + for aggregator in self._aggregators.values(): + if aggregator is None: + continue + kwargs = {} + if "inputs_norm" in inspect.signature(aggregator.record_batch).parameters: + kwargs["inputs_norm"] = inputs_norm + + aggregator.record_batch( + target_data=target_data, + gen_data=gen_data, + target_data_norm=target_data_norm, + gen_data_norm=gen_data_norm, + **kwargs, + ) + + @torch.no_grad() + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + logs = {f"{label}/{key}": val for key, val in self._mean.get_logs(label="").items()} + if self._snapshot is not None: + logs_media = self._snapshot.get_logs(label="snapshot") + logs_media = {f"{label}/{key}": val for key, val in logs_media.items()} + else: + logs_media = {} + for agg_label, agg in self._aggregators.items(): + if agg is None or agg_label in ["mean", "snapshot"]: + continue + logs.update({f"{label}/{key}": float(val) for key, val in agg.get_logs(label=agg_label).items()}) + return logs, logs_media diff --git a/src/ace_inference/core/aggregator/one_step/reduced.py b/src/ace_inference/core/aggregator/one_step/reduced.py new file mode 100644 index 0000000..456b730 --- /dev/null +++ b/src/ace_inference/core/aggregator/one_step/reduced.py @@ -0,0 +1,156 @@ +from collections import defaultdict +from typing import Dict, Mapping, Optional + +import torch +import xarray as xr +from torch import nn + +from src.ace_inference.core import metrics +from src.ace_inference.core.distributed import Distributed + +from ..inference.reduced import compute_metric_on +from .reduced_metrics import AreaWeightedReducedMetric, ReducedMetric + + +def get_gen_shape(gen_data: Mapping[str, torch.Tensor]): + for name in gen_data: + return gen_data[name].shape + + +class L1Loss: + def __init__(self, device: torch.device): + self._total = torch.tensor(0.0, device=device) + + def record(self, target: torch.Tensor, gen: torch.Tensor): + self._total += nn.functional.l1_loss( + gen, + target, + ) + + def get(self) -> torch.Tensor: + return self._total + + +class MeanAggregator: + """ + Aggregator for mean-reduced metrics. + + These are metrics such as means which reduce to a single float for each batch, + and then can be averaged across batches to get a single float for the + entire dataset. This is important because the aggregator uses the mean to combine + metrics across batches and processors. + """ + + def __init__( + self, + area_weights: torch.Tensor, + target_time: int = 1, + is_ensemble: bool = False, + dist: Optional[Distributed] = None, + device: torch.device = torch.device("cpu"), + ): + self._area_weights = area_weights + self._shape_x = None + self._shape_y = None + self._n_batches = 0 + self.device = device + self._loss = torch.tensor(0.0, device=self.device) + self._variable_metrics: Optional[Dict[str, Dict[str, ReducedMetric]]] = None + self._target_time = target_time + self.is_ensemble = is_ensemble + if dist is None: + self._dist = Distributed.get_instance() + else: + self._dist = dist + + def _get_variable_metrics(self, gen_data: Mapping[str, torch.Tensor]): + if self._variable_metrics is None: + self._variable_metrics = defaultdict(dict) + + area_weights = self._area_weights + for key in gen_data.keys(): + metrics_zipped = [ + ("weighted_rmse", metrics.root_mean_squared_error), + ("weighted_bias", metrics.weighted_mean_bias), + ("weighted_grad_mag_percent_diff", metrics.gradient_magnitude_percent_diff), + ("weighted_mean_gen", compute_metric_on(source="gen", metric=metrics.weighted_mean)), + ] + if self.is_ensemble: + metrics_zipped += [ + ("weighted_crps", metrics.weighted_crps), + ("weighted_ssr", metrics.spread_skill_ratio), + ] + + for i, (metric_name, metric) in enumerate(metrics_zipped): + self._variable_metrics[metric_name][key] = AreaWeightedReducedMetric( + area_weights=area_weights, + device=self.device, + compute_metric=metric, + ) + + return self._variable_metrics + + @torch.no_grad() + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + i_time_start: int = 0, + ): + self._loss += loss + variable_metrics = self._get_variable_metrics(gen_data) + time_dim = 1 + time_dim_gen = 2 if self.is_ensemble else time_dim + time_len = gen_data[list(gen_data.keys())[0]].shape[time_dim_gen] + target_time = self._target_time - i_time_start + if target_time >= 0 and time_len > target_time: + for name in gen_data.keys(): + target = target_data[name].select(dim=time_dim, index=target_time) + gen_full = gen_data[name].select(dim=time_dim_gen, index=target_time) + if self.is_ensemble: + ensemble_mean = gen_full.mean(dim=0) + else: + ensemble_mean = gen_full + for metric in variable_metrics: + kwargs = {} + if "ssr" in metric or "crps" in metric: + gen = gen_full + elif "grad_mag" in metric: + gen = gen_full + kwargs["is_ensemble_prediction"] = self.is_ensemble + else: + gen = ensemble_mean + + variable_metrics[metric][name].record(target=target, gen=gen, **kwargs) + # only increment n_batches if we actually recorded a batch + self._n_batches += 1 + + @torch.no_grad() + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + if self._variable_metrics is None or self._n_batches == 0: + raise ValueError("No batches have been recorded.") + logs = {f"{label}/loss": self._loss / self._n_batches} + for metric in self._variable_metrics: + for key in self._variable_metrics[metric]: + logs[f"{label}/{metric}/{key}"] = self._variable_metrics[metric][key].get() / self._n_batches + for key in sorted(logs.keys()): + logs[key] = float(self._dist.reduce_mean(logs[key].detach()).cpu().numpy()) + return logs + + @torch.no_grad() + def get_dataset(self, label: str) -> xr.Dataset: + logs = self.get_logs(label=label) + logs = {key.replace("/", "-"): logs[key] for key in logs} + data_vars = {} + for key, value in logs.items(): + data_vars[key] = xr.DataArray(value) + return xr.Dataset(data_vars=data_vars) diff --git a/src/ace_inference/core/aggregator/one_step/reduced_metrics.py b/src/ace_inference/core/aggregator/one_step/reduced_metrics.py new file mode 100644 index 0000000..f1338b0 --- /dev/null +++ b/src/ace_inference/core/aggregator/one_step/reduced_metrics.py @@ -0,0 +1,78 @@ +""" +This file contains code for computing metrics of single variables on batches of data, +and aggregating them into a single metric value. The functions here mainly exist +to turn metric functions that may have different APIs into a common API, +so that they can be iterated over and called in the same way in a loop. +""" + +from typing import Optional, Protocol + +import torch + +from src.ace_inference.core.metrics import Dimension + + +class AreaWeightedFunction(Protocol): + """ + A function that computes a metric on the true and predicted values, + weighted by area. + """ + + def __call__( + self, + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +class ReducedMetric(Protocol): + """Used to record a metric value on batches of data (potentially out-of-memory) + and then get the total metric at the end. + """ + + def record(self, target: torch.Tensor, gen: torch.Tensor): + """ + Update metric for a batch of data. + """ + ... + + def get(self) -> torch.Tensor: + """ + Get the total metric value, not divided by number of recorded batches. + """ + ... + + +class AreaWeightedReducedMetric: + """ + A wrapper around an area-weighted metric function. + """ + + def __init__( + self, + area_weights: torch.Tensor, + device: torch.device, + compute_metric: AreaWeightedFunction, + ): + self._area_weights = area_weights + self._compute_metric = compute_metric + self._total = None + self._device = device + + def record(self, target: torch.Tensor, gen: torch.Tensor, **kwargs): + """Add a batch of data to the metric. + + Args: + target: Target data. Should have shape [batch, time, height, width]. + gen: Generated data. Should have shape [batch, time, height, width]. + """ + new_value = self._compute_metric(target, gen, weights=self._area_weights, dim=(-2, -1), **kwargs).mean(dim=0) + if self._total is None: + self._total = torch.zeros_like(new_value, device=self._device) + self._total += new_value + + def get(self) -> torch.Tensor: + """Returns the metric.""" + return self._total diff --git a/src/ace_inference/core/aggregator/one_step/reduced_salva.py b/src/ace_inference/core/aggregator/one_step/reduced_salva.py new file mode 100644 index 0000000..247df34 --- /dev/null +++ b/src/ace_inference/core/aggregator/one_step/reduced_salva.py @@ -0,0 +1,136 @@ +from collections import defaultdict +from typing import Dict, Mapping, Optional + +import torch +import xarray as xr +from torch import nn + +from src.ace_inference.core import metrics + +from ..reduced_metrics import AreaWeightedReducedMetric, ReducedMetric + + +class AbstractMeanMetric: + def __init__(self, device: torch.device): + self._total = torch.tensor(0.0, device=device) + + def get(self) -> torch.Tensor: + return self._total + + +class L1Loss(AbstractMeanMetric): + def record(self, targets: torch.Tensor, preds: torch.Tensor): + self._total += nn.functional.l1_loss(preds, targets) + + +class SpreadSkillRatio(AbstractMeanMetric): + def record(self, targets: torch.Tensor, preds: torch.Tensor): + rmse = nn.functional.mse_loss(preds.mean(dim=0), targets, reduction="mean").sqrt() + # calculate spread over ensemble dim + spread = preds.var(dim=0).mean().sqrt() + self._total += spread / rmse + + +class MeanAggregator: + """ + Aggregator for mean-reduced metrics. + + These are metrics such as means which reduce to a single float for each batch, + and then can be averaged across batches to get a single float for the + entire dataset. This is important because the aggregator uses the mean to combine + metrics across batches and processors. + """ + + def __init__(self, area_weights: torch.Tensor, is_ensemble: bool): + self._area_weights = area_weights + self._n_batches = 0 + self._variable_metrics: Optional[Dict[str, Dict[str, ReducedMetric]]] = None + self.is_ensemble = is_ensemble + + def _get_variable_metrics(self, gen_data: Mapping[str, torch.Tensor]): + if self._variable_metrics is None: + self._variable_metrics = defaultdict(dict) + any_gen_data = gen_data[list(gen_data.keys())[0]] + self.device = any_gen_data.device + area_weights = self._area_weights.to(self.device) + + for var_name in gen_data.keys(): + self._variable_metrics["l1"][var_name] = L1Loss(device=self.device) + metrics_zipped = [ + ("weighted_rmse", metrics.root_mean_squared_error), + ("weighted_bias", metrics.weighted_mean_bias), + ("weighted_grad_mag_percent_diff", metrics.gradient_magnitude_percent_diff), + ] + if self.is_ensemble: + self._variable_metrics["ssr"][var_name] = SpreadSkillRatio(device=self.device) + metrics_zipped += [("weighted_crps", metrics.weighted_crps)] + + for i, (metric_name, metric) in enumerate(metrics_zipped): + self._variable_metrics[metric_name][var_name] = AreaWeightedReducedMetric( + area_weights=area_weights, device=self.device, compute_metric=metric + ) + + return self._variable_metrics + + @torch.no_grad() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor] = None, + gen_data_norm: Mapping[str, torch.Tensor] = None, + inputs_norm: Mapping[str, torch.Tensor] = None, + ): + variable_metrics = self._get_variable_metrics(gen_data) + if self.is_ensemble: + ensemble_mean = {name: member_preds.mean(dim=0) for name, member_preds in gen_data.items()} + else: + ensemble_mean = gen_data + + for name in gen_data.keys(): # e.g. temperature, precipitation, etc + for metric in variable_metrics: # e.g. l1, weighted_rmse, etc + kwargs = {} + # compute gradf mag differently, and potentially rmse + if "ssr" in metric or "crps" in metric: + pred = gen_data[name] + elif "grad_mag" in metric: + pred = gen_data[name] + kwargs["is_ensemble_prediction"] = self.is_ensemble + else: + pred = ensemble_mean[name] + + # time_s = time.time() + variable_metrics[metric][name].record(targets=target_data[name], preds=pred, **kwargs) + # time_taken = time.time() - time_s + # print(f"Time taken for {metric} {name} in s: {time_taken:.5f}") + self._n_batches += 1 + + @torch.no_grad() + def get_logs(self, label: str = ""): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + if self._variable_metrics is None or self._n_batches == 0: + raise ValueError("No batches have been recorded.") + logs = {} + label = label + "/" if label else "" + for metric in self._variable_metrics: + for key in self._variable_metrics[metric]: + logs[f"{label}{metric}/{key}"] = (self._variable_metrics[metric][key].get() / self._n_batches).detach() + # dist = Distributed.get_instance() + for key in sorted(logs.keys()): + logs[key] = float(logs[key].cpu()) # .numpy() + # logs[key] = float(dist.reduce_mean(logs[key]).cpu().numpy()) + return logs + + @torch.no_grad() + def get_dataset(self, label: str) -> xr.Dataset: + logs = self.get_logs(label=label) + logs = {key.replace("/", "-"): logs[key] for key in logs} + data_vars = {} + for key, value in logs.items(): + data_vars[key] = xr.DataArray(value) + return xr.Dataset(data_vars=data_vars) diff --git a/src/ace_inference/core/aggregator/one_step/snapshot.py b/src/ace_inference/core/aggregator/one_step/snapshot.py new file mode 100644 index 0000000..c4ccbc3 --- /dev/null +++ b/src/ace_inference/core/aggregator/one_step/snapshot.py @@ -0,0 +1,161 @@ +from typing import Mapping, Optional + +import numpy as np +import torch + +from src.ace_inference.core.data_loading.data_typing import VariableMetadata +from src.ace_inference.core.wandb import WandB + + +wandb = WandB.get_instance() + + +class SnapshotAggregator: + """ + An aggregator that records the first sample of the last batch of data. + > The way it works is that it gets called once per batch, but in the end (when using get_logs) + it only returns information based on the last batch. + """ + + _captions = { + "full-field": ("{name} one step full field for last sample; " "(left) generated and (right) target [{units}]"), + "residual": ( + "{name} one step residual (prediction - previous time) for last sample; " + "(left) generated and (right) target [{units}]" + ), + "error": ("{name} one step full field error (generated - target) " "for last sample [{units}]"), + } + + def __init__( + self, + is_ensemble: bool, + target_time: Optional[int] = None, + metadata: Optional[Mapping[str, VariableMetadata]] = None, + ): + """ + Args: + metadata: Mapping of variable names their metadata that will + used in generating logged image captions. + """ + self.is_ensemble = is_ensemble + assert target_time is None or target_time > 0 + self.target_time = target_time # account for 0-indexing not needed because initial condition is included + self.target_time_in_batch = None + if metadata is None: + self._metadata: Mapping[str, VariableMetadata] = {} + else: + self._metadata = metadata + + @torch.no_grad() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + inputs_norm: Mapping[str, torch.Tensor] = None, + loss=None, + i_time_start: int = 0, + ): + data_steps = target_data_norm[list(target_data_norm.keys())[0]].shape[1] + if self.target_time is not None: + diff = self.target_time - i_time_start + # target time needs to be in the batch (between i_time_start and i_time_start + data_steps) + if diff < 0 or diff >= data_steps: + return # skip this batch, since it doesn't contain the target time + else: + self.target_time_in_batch = diff + + def to_cpu(x): + return {k: v.cpu() for k, v in x.items()} if isinstance(x, dict) else x.cpu() + + self._target_data = to_cpu(target_data) + self._gen_data = to_cpu(gen_data) + self._target_data_norm = to_cpu(target_data_norm) + self._gen_data_norm = to_cpu(gen_data_norm) + self._inputs_norm = to_cpu(inputs_norm) if inputs_norm is not None else None + if self.target_time is not None: + assert ( + self.target_time_in_batch <= data_steps + ), f"target_time={self.target_time}, time_in_batch={self.target_time_in_batch} is larger than the number of timesteps in the data={data_steps}!" + + @torch.no_grad() + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + if self.target_time_in_batch is None and self.target_time is not None: + return {} # skip this batch, since it doesn't contain the target time + image_logs = {} + max_snapshots = 3 + for name in self._gen_data.keys(): + if name in self._gen_data_norm.keys(): + gen_data = self._gen_data_norm + target_data = self._target_data_norm + else: + gen_data = self._gen_data + target_data = self._target_data + if self.is_ensemble: + snapshots_pred = gen_data[name][:max_snapshots, 0] + else: + snapshots_pred = gen_data[name][0].unsqueeze(0) + target_for_image = target_data[name][0] # first sample in batch + small_gap = torch.zeros((target_for_image.shape[-2], 2)).to(snapshots_pred.device, dtype=torch.float) + gap = torch.zeros((target_for_image.shape[-2], 4)).to( + snapshots_pred.device, dtype=torch.float + ) # gap between images in wandb (so we can see them separately) + input_for_image = ( + self._inputs_norm[name][0] + if self._inputs_norm is not None and name in self._inputs_norm.keys() + else None + ) + # Select target time + if self.target_time is not None: + snapshots_pred = snapshots_pred[:, self.target_time_in_batch] + target_for_image = target_for_image[self.target_time_in_batch] + if input_for_image is not None: + input_for_image = input_for_image[self.target_time_in_batch] + + # Create image tensors + image_error, image_full_field, image_residual = [], [], [] + for i in range(snapshots_pred.shape[0]): + image_full_field += [snapshots_pred[i]] + image_error += [snapshots_pred[i] - target_for_image] + if input_for_image is not None: + image_residual += [snapshots_pred[i] - input_for_image] + if i == snapshots_pred.shape[0] - 1: + image_full_field += [gap, target_for_image] + if input_for_image is not None: + image_residual += [gap, target_for_image - input_for_image] + else: + image_full_field += [small_gap] + image_residual += [small_gap] + image_error += [small_gap] + + images = {} + images["error"] = torch.cat(image_error, dim=1) + images["full-field"] = torch.cat(image_full_field, dim=1) + if input_for_image is not None: + images["residual"] = torch.cat(image_residual, dim=1) + + for key, data in images.items(): + caption = self._get_caption(key, name, data) + data = np.flip(data.cpu().numpy(), axis=-2) + wandb_image = wandb.Image(data, caption=caption) + image_logs[f"image-{key}/{name}"] = wandb_image + + image_logs = {f"{label}/{key}": image_logs[key] for key in image_logs} + return image_logs + + def _get_caption(self, caption_key: str, name: str, data: torch.Tensor) -> str: + if name in self._metadata: + caption_name = self._metadata[name].long_name + units = self._metadata[name].units + else: + caption_name, units = name, "unknown_units" + caption = self._captions[caption_key].format(name=caption_name, units=units) + caption += f" vmin={data.min():.4g}, vmax={data.max():.4g}." + return caption diff --git a/src/ace_inference/core/aggregator/plotting.py b/src/ace_inference/core/aggregator/plotting.py new file mode 100644 index 0000000..198c48d --- /dev/null +++ b/src/ace_inference/core/aggregator/plotting.py @@ -0,0 +1,33 @@ +from typing import Optional, Tuple + +import matplotlib.pyplot as plt +import numpy as np + + +def get_cmap_limits(data: np.ndarray, diverging=False) -> Tuple[float, float]: + vmin = data.min() + vmax = data.max() + if diverging: + vmax = max(abs(vmin), abs(vmax)) + vmin = -vmax + return vmin, vmax + + +def plot_imshow( + data: np.ndarray, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + cmap: Optional[str] = None, + flip_lat: bool = True, +) -> plt.figure: + """Plot a 2D array using imshow, ensuring figure size is same as array size.""" + if flip_lat: + lat_dim = -2 + data = np.flip(data, axis=lat_dim) + # make figure size (in pixels) be the same as array size + figsize = np.array(data.T.shape) / plt.rcParams["figure.dpi"] + fig = plt.figure(figsize=figsize) + ax = fig.add_axes([0, 0, 1, 1]) + ax.imshow(data, cmap=cmap, vmin=vmin, vmax=vmax) + ax.set_axis_off() + return fig diff --git a/src/ace_inference/core/aggregator/reduced_metrics.py b/src/ace_inference/core/aggregator/reduced_metrics.py new file mode 100644 index 0000000..162a7e1 --- /dev/null +++ b/src/ace_inference/core/aggregator/reduced_metrics.py @@ -0,0 +1,118 @@ +""" +This file contains code for computing metrics of single variables on batches of data, +and aggregating them into a single metric value. The functions here mainly exist +to turn metric functions that may have different APIs into a common API, +so that they can be iterated over and called in the same way in a loop. +""" + +from typing import Literal, Optional, Protocol + +import torch + +from src.ace_inference.core.metrics import Dimension + + +class ReducedMetric(Protocol): + """Used to record a metric value on batches of data (potentially out-of-memory) + and then get the total metric at the end. + """ + + def record(self, target: torch.Tensor, gen: torch.Tensor): + """ + Update metric for a batch of data. + """ + ... + + def get(self) -> torch.Tensor: + """ + Get the total metric value, not divided by number of recorded batches. + """ + ... + + +class AreaWeightedFunction(Protocol): + """ + A function that computes a metric on the true and predicted values, + weighted by area. + """ + + def __call__( + self, + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +class AreaWeightedSingleTargetFunction(Protocol): + """ + A function that computes a metric on a single value, weighted by area. + """ + + def __call__( + self, + tensor: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +def compute_metric_on( + source: Literal["preds", "targets"], metric: AreaWeightedSingleTargetFunction +) -> AreaWeightedFunction: + """Turns a single-target metric function + (computed on only the generated or target data) into a function that takes in + both the generated and target data as arguments, as required for the APIs + which call generic metric functions. + """ + + def metric_wrapper( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: + if source == "preds": + return metric(predicted, weights=weights, dim=dim) + elif source == "targets": + return metric(truth, weights=weights, dim=dim) + + return metric_wrapper + + +class AreaWeightedReducedMetric: + """ + A wrapper around an area-weighted metric function. + """ + + def __init__( + self, + area_weights: torch.Tensor, + device: torch.device, + compute_metric: AreaWeightedFunction, + ): + self._area_weights = area_weights + self._compute_metric = compute_metric + self._total = None + self._device = device + + def record(self, targets: torch.Tensor, preds: torch.Tensor, batch_dim: int = 0, **kwargs): + """Add a batch of data to the metric. + + Args: + targets: Target data. Should have shape [batch, time, height, width]. + preds: Generated data. Should have shape [batch, time, height, width]. + batch_dim: The dimension of the batch axis over which to average the metric. + """ + self._area_weights = self._area_weights.to(targets.device) + new_value = self._compute_metric(targets, preds, weights=self._area_weights, dim=(-2, -1), **kwargs).mean( + dim=batch_dim + ) + if self._total is None: + self._total = torch.zeros_like(new_value, device=targets.device) + self._total += new_value # + + def get(self) -> torch.Tensor: + """Returns the metric.""" + return self._total diff --git a/src/ace_inference/core/aggregator/train.py b/src/ace_inference/core/aggregator/train.py new file mode 100644 index 0000000..5d0614f --- /dev/null +++ b/src/ace_inference/core/aggregator/train.py @@ -0,0 +1,43 @@ +from typing import Mapping + +import torch + +from src.ace_inference.core.device import get_device +from src.ace_inference.core.distributed import Distributed + + +class TrainAggregator: + """ + To use, call `record_batch` on the results of each batch, then call + `get_logs` to get a dictionary of statistics when you're done. + """ + + def __init__(self): + self._n_batches = 0 + self._loss = torch.tensor(0.0, device=get_device()) + + @torch.no_grad() + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + ): + self._loss += loss + self._n_batches += 1 + + @torch.no_grad() + def get_logs(self, label: str): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + logs = {f"{label}/mean/loss": self._loss / self._n_batches} + dist = Distributed.get_instance() + for key in sorted(logs.keys()): + logs[key] = float(dist.reduce_mean(logs[key].detach()).cpu().numpy()) + return logs diff --git a/src/ace_inference/core/constants.py b/src/ace_inference/core/constants.py new file mode 100644 index 0000000..14041e0 --- /dev/null +++ b/src/ace_inference/core/constants.py @@ -0,0 +1,6 @@ +LATENT_HEAT_OF_VAPORIZATION = 2.5e6 # J/kg +GRAVITY = 9.80665 # m/s^2 +TIMESTEP_SECONDS = 6 * 60 * 60 # 6 hours +# following values are used by SHiELD's slab ocean model, and so we follow suit here. +SPECIFIC_HEAT_OF_WATER = 4000.0 # J/kg/K +DENSITY_OF_WATER = 1000.0 # kg/m^3 diff --git a/src/ace_inference/core/corrector.py b/src/ace_inference/core/corrector.py new file mode 100644 index 0000000..c9828b1 --- /dev/null +++ b/src/ace_inference/core/corrector.py @@ -0,0 +1,296 @@ +import dataclasses +from typing import Dict, Literal, Mapping, Optional + +import torch + +from src.ace_inference.core import metrics +from src.ace_inference.core.aggregator.climate_data import ClimateData +from src.ace_inference.core.constants import TIMESTEP_SECONDS +from src.ace_inference.core.data_loading.data_typing import SigmaCoordinates +from src.ace_inference.core.device import get_device + + +@dataclasses.dataclass +class CorrectorConfig: + """ + Configuration for the post-step state corrector. + + conserve_dry_air enforces the constraint that: + + global_dry_air = global_mean( + ps - sum_k((ak_diff + bk_diff * ps) * wat_k) + ) + + in the generated data is equal to its value in the input data. This is done + by adding a globally-constant correction to the surface pressure in each + column. As per-mass values such as mixing ratios of water are unchanged, + this can cause changes in total water or energy. Note all global means here + are area-weighted. + + zero_global_mean_moisture_advection enforces the constraint that: + + global_mean(tendency_of_total_water_path_due_to_advection) = 0 + + in the generated data. This is done by adding a globally-constant correction + to the moisture advection tendency in each column. + + moisture_budget_correction enforces closure of the moisture budget equation: + + tendency_of_total_water_path = ( + evaporation_rate - precipitation_rate + + tendency_of_total_water_path_due_to_advection + ) + + in the generated data, where tendency_of_total_water_path is the difference + between the total water path at the current timestep and the previous + timestep divided by the time difference. This is done by modifying the + precipitation, evaporation, and/or moisture advection tendency fields as + described in the moisture_budget_correction attribute. When + advection tendency is modified, this budget equation is enforced in each + column, while when only precipitation or evaporation are modified, only + the global mean of the budget equation is enforced. + + When enforcing moisture budget closure, we assume the global mean moisture + advection is zero. Therefore zero_global_mean_moisture_advection must be + True if using a moisture_budget_correction option other tha None. + + Attributes: + conserve_dry_air: If True, force the generated data to conserve dry air + by subtracting a constant offset from the surface pressure of each + column. This can cause changes in per-mass values such as total water + or energy. + zero_global_mean_moisture_advection: If True, force the generated data to + have zero global mean moisture advection by subtracting a constant + offset from the moisture advection tendency of each column. + moisture_budget_correction: If not "none", force the generated data to + conserve global or column-local moisture by modifying budget fields. + One of: + - "precipitation": multiply precipitation by a scale factor + to close the global moisture budget + - "evaporation": multiply evaporation by a scale factor + to close the global moisture budget + - "advection_and_precipitation": after applying the "precipitation" + global-mean correction above, we recompute the column-integrated + advective tendency as the budget residual, + ensuring column budget closure. + - "advection_and_evaporation": after applying the "evaporation" + global-mean correction above, we recompute the column-integrated + advective tendency as the budget residual, + ensuring column budget closure. + """ + + conserve_dry_air: bool = False + zero_global_mean_moisture_advection: bool = False + moisture_budget_correction: Optional[ + Literal[ + "precipitation", + "evaporation", + "advection_and_precipitation", + "advection_and_evaporation", + ] + ] = None + + def build(self, area: torch.Tensor, sigma_coordinates: SigmaCoordinates) -> Optional["Corrector"]: + return Corrector(config=self, area=area, sigma_coordinates=sigma_coordinates) + + +class Corrector: + def __init__( + self, + config: CorrectorConfig, + area: torch.Tensor, + sigma_coordinates: SigmaCoordinates, + ): + self._config = config + self._area = area.to(get_device()) + self._sigma_coordinates = sigma_coordinates.to(get_device()) + + def __call__( + self, + input_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + ): + if self._config.conserve_dry_air: + gen_data = _force_conserve_dry_air( + input_data=input_data, + gen_data=gen_data, + area=self._area, + sigma_coordinates=self._sigma_coordinates, + ) + if self._config.zero_global_mean_moisture_advection: + gen_data = _force_zero_global_mean_moisture_advection( + gen_data=gen_data, + area=self._area, + ) + if self._config.moisture_budget_correction is not None: + gen_data = _force_conserve_moisture( + input_data=input_data, + gen_data=gen_data, + area=self._area, + sigma_coordinates=self._sigma_coordinates, + terms_to_modify=self._config.moisture_budget_correction, + ) + return gen_data + + +def _force_conserve_dry_air( + input_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + area: torch.Tensor, + sigma_coordinates: SigmaCoordinates, +) -> Dict[str, torch.Tensor]: + """ + Update the generated data to conserve dry air. + + This is done by adding a constant correction to the dry air pressure of + each column, and may result in changes in per-mass values such as + total water or energy. + + We first compute the target dry air pressure by computing the globally + averaged difference in dry air pressure between the input_data and gen_data, + and then add this offset to the fully-resolved gen_data dry air pressure. + We can then solve for the surface pressure corresponding to this new dry air + pressure. + + We start from the expression for dry air pressure: + + dry_air = ps - sum_k((ak_diff + bk_diff * ps) * wat_k) + + To update the dry air, we compute and update the surface pressure: + + ps = ( + dry_air + sum_k(ak_diff * wat_k) + ) / ( + 1 - sum_k(bk_diff * wat_k) + ) + """ + input = ClimateData(input_data) + if input.surface_pressure is None: + raise ValueError("surface_pressure is required to force dry air conservation") + gen = ClimateData(gen_data) + gen_dry_air = gen.surface_pressure_due_to_dry_air(sigma_coordinates) + global_gen_dry_air = metrics.weighted_mean(gen_dry_air, weights=area, dim=(-2, -1)) + global_target_gen_dry_air = metrics.weighted_mean( + input.surface_pressure_due_to_dry_air(sigma_coordinates), + weights=area, + dim=(-2, -1), + ) + error = global_gen_dry_air - global_target_gen_dry_air + new_gen_dry_air = gen_dry_air - error[..., None, None] + try: + wat = gen.specific_total_water + except KeyError: + raise ValueError("specific_total_water is required for conservation") + ak_diff = sigma_coordinates.ak.diff() + bk_diff = sigma_coordinates.bk.diff() + new_pressure = (new_gen_dry_air + (ak_diff * wat).sum(-1)) / (1 - (bk_diff * wat).sum(-1)) + gen.surface_pressure = new_pressure.to(dtype=input.surface_pressure.dtype) + return gen.data + + +def _force_zero_global_mean_moisture_advection( + gen_data: Mapping[str, torch.Tensor], + area: torch.Tensor, +) -> Dict[str, torch.Tensor]: + """ + Update the generated data so advection conserves moisture. + + Does so by adding a constant offset to the moisture advective tendency. + + Args: + gen_data: The generated data. + area: (n_lat, n_lon) array containing relative gridcell area, in any + units including unitless. + """ + gen = ClimateData(gen_data) + + mean_moisture_advection = metrics.weighted_mean( + gen.tendency_of_total_water_path_due_to_advection, + weights=area, + dim=(-2, -1), + ) + gen.tendency_of_total_water_path_due_to_advection = ( + gen.tendency_of_total_water_path_due_to_advection - mean_moisture_advection[..., None, None] + ) + return gen.data + + +def _force_conserve_moisture( + input_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + area: torch.Tensor, + sigma_coordinates: SigmaCoordinates, + terms_to_modify: Literal[ + "precipitation", + "evaporation", + "advection_and_precipitation", + "advection_and_evaporation", + ], +) -> Dict[str, torch.Tensor]: + """ + Update the generated data to conserve moisture. + + Does so while conserving total dry air in each column. + + Assumes the global mean advective tendency of moisture is zero. This assumption + means any existing global mean advective tendency will be set to zero + if the advective tendency is re-computed. + + Args: + input_data: The input data. + gen_data: The generated data one timestep after the input data. + area: (n_lat, n_lon) array containing relative gridcell area, in any + units including unitless. + sigma_coordinates: The sigma coordinates. + terms_to_modify: Which terms to modify, in addition to modifying surface + pressure to conserve dry air mass. One of: + - "precipitation": modify precipitation only + - "evaporation": modify evaporation only + - "advection_and_precipitation": modify advection and precipitation + - "advection_and_evaporation": modify advection and evaporation + """ + input = ClimateData(input_data) + gen = ClimateData(gen_data) + + gen_total_water_path = gen.total_water_path(sigma_coordinates) + twp_total_tendency = (gen_total_water_path - input.total_water_path(sigma_coordinates)) / TIMESTEP_SECONDS + twp_tendency_global_mean = metrics.weighted_mean(twp_total_tendency, weights=area, dim=(-2, -1)) + evaporation_global_mean = metrics.weighted_mean(gen.evaporation_rate, weights=area, dim=(-2, -1)) + precipitation_global_mean = metrics.weighted_mean(gen.precipitation_rate, weights=area, dim=(-2, -1)) + if terms_to_modify.endswith("precipitation"): + # We want to achieve + # global_mean(twp_total_tendency) = ( + # global_mean(evaporation_rate) + # - global_mean(precipitation_rate) + # ) + # so we modify precipitation_rate to achieve this. Note we have + # assumed the global mean advection tendency is zero. + # First, we find the required global-mean precipitation rate + # new_global_precip_rate = ( + # global_mean(evaporation_rate) + # - global_mean(twp_total_tendency) + # ) + new_precipitation_global_mean = evaporation_global_mean - twp_tendency_global_mean + # Because scalar multiplication commutes with summation, we can + # achieve this by multiplying each gridcell's precipitation rate + # by the ratio of the new global mean to the current global mean. + # new_precip_rate = ( + # new_global_precip_rate / current_global_precip_rate + # ) * current_precip_rate + gen.precipitation_rate = ( + gen.precipitation_rate * (new_precipitation_global_mean / precipitation_global_mean)[..., None, None] + ) + elif terms_to_modify.endswith("evaporation"): + # Derived similarly as for "precipitation" case. + new_evaporation_global_mean = twp_tendency_global_mean + precipitation_global_mean + gen.evaporation_rate = ( + gen.evaporation_rate * (new_evaporation_global_mean / evaporation_global_mean)[..., None, None] + ) + if terms_to_modify.startswith("advection"): + # Having already corrected the global-mean budget, we recompute + # advection based on assumption that the columnwise + # moisture budget closes. Correcting the global mean budget first + # is important to ensure the resulting advection has zero global mean. + new_advection = twp_total_tendency - (gen.evaporation_rate - gen.precipitation_rate) + gen.tendency_of_total_water_path_due_to_advection = new_advection + return gen.data diff --git a/src/ace_inference/core/data_loading/__init__.py b/src/ace_inference/core/data_loading/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ace_inference/core/data_loading/_xarray.py b/src/ace_inference/core/data_loading/_xarray.py new file mode 100644 index 0000000..cb1ece6 --- /dev/null +++ b/src/ace_inference/core/data_loading/_xarray.py @@ -0,0 +1,328 @@ +import logging +import os +from collections import namedtuple +from glob import glob +from typing import Callable, Dict, List, Mapping, Optional, Tuple + +import numpy as np +import torch +import xarray as xr + +from src.ace_inference.core import metrics +from src.ace_inference.core.data_loading.params import XarrayDataParams +from src.ace_inference.core.data_loading.requirements import DataRequirements +from src.ace_inference.core.data_loading.utils import get_lons_and_lats, get_times, load_series_data +from src.ace_inference.core.device import get_device +from src.ace_inference.core.winds import lon_lat_to_xyz + +from .data_typing import ( + Dataset, + HorizontalCoordinates, + SigmaCoordinates, + VariableMetadata, +) + + +VariableNames = namedtuple( + "VariableNames", + ( + "time_dependent_names", + "time_invariant_names", + "static_derived_names", + ), +) + + +def get_sigma_coordinates(ds: xr.Dataset) -> SigmaCoordinates: + """ + Get sigma coordinates from a dataset. + + Assumes that the dataset contains variables named `ak_N` and `bk_N` where + `N` is the level number. The returned tensors are sorted by level number. + + Args: + ds: Dataset to get sigma coordinates from. + """ + ak_mapping = {int(v[3:]): torch.as_tensor(ds[v].values) for v in ds.variables if v.startswith("ak_")} + bk_mapping = {int(v[3:]): torch.as_tensor(ds[v].values) for v in ds.variables if v.startswith("bk_")} + ak_list = [ak_mapping[k] for k in sorted(ak_mapping.keys())] + bk_list = [bk_mapping[k] for k in sorted(bk_mapping.keys())] + + if len(ak_list) == 0 or len(bk_list) == 0: + raise ValueError("Dataset does not contain ak and bk sigma coordinates.") + + if len(ak_list) != len(bk_list): + raise ValueError("Expected same number of ak and bk coordinates, " f"got {len(ak_list)} and {len(bk_list)}.") + + return SigmaCoordinates( + ak=torch.as_tensor(ak_list, device=get_device()), + bk=torch.as_tensor(bk_list, device=get_device()), + ) + + +def get_cumulative_timesteps(paths: List[str]) -> np.ndarray: + """Returns a list of cumulative timesteps for each file in paths.""" + num_timesteps_per_file = [0] + for path in paths: + with xr.open_dataset(path, use_cftime=True) as ds: + num_timesteps_per_file.append(len(ds.time)) + return np.array(num_timesteps_per_file).cumsum() + + +def get_file_local_index(index: int, start_indices: np.ndarray) -> Tuple[int, int]: + """ + Return a tuple of the index of the file containing the time point at `index` + and the index of the time point within that file. + """ + file_index = np.searchsorted(start_indices, index, side="right") - 1 + time_index = index - start_indices[file_index] + return int(file_index), time_index + + +class StaticDerivedData: + names = ("x", "y", "z") + metadata = { + "x": VariableMetadata(units="", long_name="Euclidean x-coordinate"), + "y": VariableMetadata(units="", long_name="Euclidean y-coordinate"), + "z": VariableMetadata(units="", long_name="Euclidean z-coordinate"), + } + + def __init__(self, lons, lats): + """ + Args: + lons: 1D array of longitudes. + lats: 1D array of latitudes. + """ + self._lats = lats + self._lons = lons + self._x: Optional[torch.Tensor] = None + self._y: Optional[torch.Tensor] = None + self._z: Optional[torch.Tensor] = None + + def _get_xyz(self) -> Dict[str, torch.Tensor]: + if self._x is None or self._y is None or self._z is None: + lats, lons = np.broadcast_arrays(self._lats[:, None], self._lons[None, :]) + x, y, z = lon_lat_to_xyz(lons, lats) + self._x = torch.as_tensor(x) + self._y = torch.as_tensor(y) + self._z = torch.as_tensor(z) + return {"x": self._x, "y": self._y, "z": self._z} + + def __getitem__(self, name: str) -> torch.Tensor: + return self._get_xyz()[name] + + +class XarrayDataset(Dataset): + """Handles dataloading over multiple netcdf files using the xarray library. + Assumes that the netcdf filenames are time-ordered.""" + + def __init__( + self, + params: XarrayDataParams, + requirements: DataRequirements, + ): + self.params = params + self.names = requirements.names + self.path = params.data_path + self.engine = "netcdf4" if params.engine is None else params.engine + # assume that filenames include time ordering + self.full_paths = sorted(glob(os.path.join(self.path, "*.nc"))) + + if len(self.full_paths) == 0: + raise ValueError(f"No netCDF files found in '{self.path}'.") + self.full_paths *= params.n_repeats + self.n_steps = requirements.n_timesteps # one input, n_steps - 1 outputs + self._get_files_stats() + first_dataset = xr.open_dataset( + self.full_paths[0], + decode_times=False, + ) + lons, lats = get_lons_and_lats(first_dataset) + self._static_derived_data = StaticDerivedData(lons, lats) + ( + self.time_dependent_names, + self.time_invariant_names, + self.static_derived_names, + ) = self._group_variable_names_by_time_type() + self._area_weights = metrics.spherical_area_weights(lats, len(lons)) + self._sigma_coordinates = get_sigma_coordinates(first_dataset) + self._horizontal_coordinates = HorizontalCoordinates( + lat=torch.as_tensor(lats, device=get_device()), + lon=torch.as_tensor(lons, device=get_device()), + ) + + @property + def horizontal_coordinates(self) -> HorizontalCoordinates: + return self._horizontal_coordinates + + def _get_metadata(self, ds): + result = {} + for name in self.names: + if name in StaticDerivedData.names: + result[name] = StaticDerivedData.metadata[name] + elif hasattr(ds[name], "units") and hasattr(ds[name], "long_name"): + result[name] = VariableMetadata( + units=ds[name].units, + long_name=ds[name].long_name, + ) + self._metadata = result + + def _get_files_stats(self): + logging.info(f"Opening data at {os.path.join(self.path, '*.nc')}") + cum_num_timesteps = get_cumulative_timesteps(self.full_paths) + self.start_indices = cum_num_timesteps[:-1] + self.total_timesteps = cum_num_timesteps[-1] + self._n_initial_conditions = self.total_timesteps - self.n_steps + 1 + del cum_num_timesteps + + ds = self._open_file(0) + self._get_metadata(ds) + + verbose = False + if verbose: + for i in range(len(self.names)): + if self.names[i] in ds.variables: + img_shape = ds[self.names[i]].shape[-2:] + break + else: + raise ValueError(f"None of the requested variables {self.names} are present " f"in the dataset.") + logging.info(f"Found {self._n_initial_conditions} samples.") + logging.info(f"Image shape is {img_shape[0]} x {img_shape[1]}.") + # logging.info(f"Following variables are available: {list(ds.variables)}.") + + def _group_variable_names_by_time_type(self) -> VariableNames: + """Returns lists of time-dependent variable names, time-independent + variable names, and variables which are only present as an initial + condition.""" + time_dependent_names, time_invariant_names, static_derived_names = [], [], [] + # Don't use open_mfdataset here, because it will give time-invariant + # fields a time dimension. We assume that all fields are present in the + # netcdf file corresponding to the first chunk of time. + with xr.open_dataset(self.full_paths[0]) as ds: + for name in self.names: + if name in StaticDerivedData.names: + static_derived_names.append(name) + else: + dims = ds[name].dims + if "time" in dims: + time_dependent_names.append(name) + else: + time_invariant_names.append(name) + return VariableNames( + time_dependent_names, + time_invariant_names, + static_derived_names, + ) + + @property + def area_weights(self) -> torch.Tensor: + return self._area_weights + + @property + def metadata(self) -> Mapping[str, VariableMetadata]: + return self._metadata + + @property + def sigma_coordinates(self) -> SigmaCoordinates: + return self._sigma_coordinates + + def __len__(self): + return self._n_initial_conditions + + def _open_file(self, idx): + return xr.open_dataset( + self.full_paths[idx], + engine=self.engine, + use_cftime=True, + cache=False, + mask_and_scale=False, + ) + + def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], xr.DataArray]: + """Open a time-ordered subset of the files which contain the input with + global index idx and its outputs. Get a starting index in the first file + (input_local_idx) and a final index in the last file (output_local_idx), + returning the time-ordered sequence of observations from input_local_idx + to output_local_idx (inclusive). + + Args: + idx: Index of the sample to retrieve. + + Returns: + Tuple of a sample's data (a mapping from names to data, for use in + training and inference) and its corresponding time coordinates. + """ + time_slice = slice(idx, idx + self.n_steps) + return self.get_sample_by_time_slice(time_slice) + + def get_sample_by_time_slice(self, time_slice: slice) -> Tuple[Dict[str, torch.Tensor], xr.DataArray]: + input_file_idx, input_local_idx = get_file_local_index(time_slice.start, self.start_indices) + output_file_idx, output_local_idx = get_file_local_index(time_slice.stop - 1, self.start_indices) + + # get the sequence of observations + arrays: Dict[str, List[torch.Tensor]] = {} + times_segments: List[xr.DataArray] = [] + idxs = range(input_file_idx, output_file_idx + 1) + total_steps = 0 + for i, file_idx in enumerate(idxs): + ds = self._open_file(file_idx) + start = input_local_idx if i == 0 else 0 + stop = output_local_idx if i == len(idxs) - 1 else len(ds["time"]) - 1 + n_steps = stop - start + 1 + total_steps += n_steps + tensor_dict = load_series_data(start, n_steps, ds, self.time_dependent_names) + for n in self.time_dependent_names: + arrays.setdefault(n, []).append(tensor_dict[n]) + times_segments.append(get_times(ds, start, n_steps)) + ds.close() + del ds + + tensors: Dict[str, torch.Tensor] = {} + for n, tensor_list in arrays.items(): + tensors[n] = torch.cat(tensor_list) + del arrays + times: xr.DataArray = xr.concat(times_segments, dim="time") + + # load time-invariant variables from first dataset + ds = self._open_file(idxs[0]) + for name in self.time_invariant_names: + tensor = torch.as_tensor(ds[name].values) + tensors[name] = tensor.repeat((total_steps, 1, 1)) + + # load static derived variables + for name in self.static_derived_names: + tensor = self._static_derived_data[name] + tensors[name] = tensor.repeat((total_steps, 1, 1)) + + return tensors, times + + +class XarrayDatasetSalva(XarrayDataset): + def __init__( + self, + *args, + forcing_names: List[str] = tuple(), + forcing_packer: Optional[Callable] = None, + forcing_normalizer: Optional[Callable] = None, + min_idx_shift: int = 0, + split_id: Optional[str] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.forcing_names = forcing_names if len(forcing_names) > 0 else None + self.forcing_packer = forcing_packer + self.forcing_normalizer = forcing_normalizer + self.min_idx_shift = min_idx_shift + self.split_id = split_id + + def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], xr.DataArray]: + idx = idx + self.min_idx_shift + tensors, times = super().__getitem__(idx) + + if self.forcing_names is not None: + forcings = {k: tensors.pop(k) for k in list(tensors.keys()) if k in self.forcing_names} + forcings = self.forcing_packer.pack(self.forcing_normalizer.normalize(forcings)) + data = {"dynamics": tensors, "dynamical_condition": forcings} + else: + data = {"dynamics": tensors} + return data diff --git a/src/ace_inference/core/data_loading/data_typing.py b/src/ace_inference/core/data_loading/data_typing.py new file mode 100644 index 0000000..4949e25 --- /dev/null +++ b/src/ace_inference/core/data_loading/data_typing.py @@ -0,0 +1,110 @@ +import abc +import dataclasses +from collections import namedtuple +from typing import Mapping, Optional + +import numpy as np +import torch + + +VariableMetadata = namedtuple("VariableMetadata", ["units", "long_name"]) + + +@dataclasses.dataclass +class SigmaCoordinates: + """ + Defines pressure at interface levels according to the following formula: + p(k) = a(k) + b(k)*ps + + where ps is the surface pressure, a and b are the sigma coordinates. + + Attributes: + ak: a(k) coefficients as a 1-dimensional tensor + bk: b(k) coefficients as a 1-dimensional tensor + """ + + ak: torch.Tensor + bk: torch.Tensor + + @property + def coords(self) -> Mapping[str, np.ndarray]: + return {"ak": self.ak.cpu().numpy(), "bk": self.bk.cpu().numpy()} + + def to(self, device: str) -> "SigmaCoordinates": + return SigmaCoordinates( + ak=self.ak.to(device), + bk=self.bk.to(device), + ) + + def as_dict(self) -> Mapping[str, torch.Tensor]: + return {"ak": self.ak, "bk": self.bk} + + +@dataclasses.dataclass +class HorizontalCoordinates: + """ + Defines a (latitude, longitude) grid. + + Attributes: + lat: 1-dimensional tensor of latitudes + lon: 1-dimensional tensor of longitudes + """ + + lat: torch.Tensor + lon: torch.Tensor + + @property + def coords(self) -> Mapping[str, np.ndarray]: + return {"lat": self.lat.cpu().numpy(), "lon": self.lon.cpu().numpy()} + + +@dataclasses.dataclass +class GriddedData: + """ + Data as required for pytorch training. + + The data is assumed to be gridded, and attributes are included for + performing operations on gridded data. + + Attributes: + loader: torch DataLoader, which returns batches of type + Mapping[str, torch.Tensor] where keys indicate variable name. + Each tensor has shape + [batch_size, time_window_size, n_channels, n_lat, n_lon]. + metadata: Metadata for each variable. + area_weights: Weights for each grid cell, used for computing area-weighted + averages. Has shape [n_lat, n_lon]. + sigma_coordinates: Sigma coordinates for each grid cell, used for computing + pressure levels. + horizontal_coordinates: Lat/lon coordinates for the data. + sampler: Optional sampler for the data loader. Provided to allow support for + distributed training. + """ + + loader: torch.utils.data.DataLoader + metadata: Mapping[str, VariableMetadata] + area_weights: torch.Tensor + sigma_coordinates: SigmaCoordinates + horizontal_coordinates: HorizontalCoordinates + sampler: Optional[torch.utils.data.Sampler] = None + + @property + def coords(self) -> Mapping[str, np.ndarray]: + return { + **self.horizontal_coordinates.coords, + **self.sigma_coordinates.coords, + } + + +class Dataset(torch.utils.data.Dataset, abc.ABC): + @abc.abstractproperty + def metadata(self) -> Mapping[str, VariableMetadata]: ... + + @abc.abstractproperty + def area_weights(self) -> torch.Tensor: ... + + @abc.abstractproperty + def horizontal_coordinates(self) -> HorizontalCoordinates: ... + + @abc.abstractproperty + def sigma_coordinates(self) -> SigmaCoordinates: ... diff --git a/src/ace_inference/core/data_loading/get_loader.py b/src/ace_inference/core/data_loading/get_loader.py new file mode 100644 index 0000000..7930c7a --- /dev/null +++ b/src/ace_inference/core/data_loading/get_loader.py @@ -0,0 +1,119 @@ +from pathlib import Path +from typing import Optional + +import numpy as np +import torch.utils.data +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from src.ace_inference.core.device import using_gpu +from src.ace_inference.core.distributed import Distributed + +from ._xarray import XarrayDataset +from .data_typing import Dataset, GriddedData +from .params import DataLoaderParams +from .requirements import DataRequirements + + +def _all_same(iterable, cmp=lambda x, y: x == y): + it = iter(iterable) + try: + first = next(it) + except StopIteration: + return True + return all(cmp(first, rest) for rest in it) + + +def _get_ensemble_dataset( + params: DataLoaderParams, + requirements: DataRequirements, + window_time_slice: Optional[slice] = None, + sub_paths: Optional[list] = None, + dataset_class: Optional[Dataset] = None, + **kwargs, +) -> Dataset: + """Returns a dataset that is a concatenation of the datasets for each + ensemble member. + """ + dataset_class = dataset_class or XarrayDataset + if sub_paths is not None: + sub_paths = [sub_paths] if isinstance(sub_paths, str) else sub_paths + paths = [str(Path(params.data_path) / sub_path) for sub_path in sub_paths] + else: + # Get all subdirectories of the data path + paths = sorted([str(d) for d in Path(params.data_path).iterdir() if d.is_dir()]) + + datasets, metadatas, sigma_coords = [], [], [] + samples_per_member = params.n_samples // len(paths) if params.n_samples else None + for path in paths: + params_curr_member = DataLoaderParams( + path, params.data_type, params.batch_size, params.num_data_workers, samples_per_member + ) + dataset = dataset_class(params_curr_member, requirements, window_time_slice=window_time_slice, **kwargs) + + datasets.append(dataset) + metadatas.append(dataset.metadata) + sigma_coords.append(dataset.sigma_coordinates) + + if not _all_same(metadatas): + raise ValueError("Metadata for each ensemble member should be the same.") + + ak, bk = list(zip(*[(s.ak.cpu().numpy(), s.bk.cpu().numpy()) for s in sigma_coords])) + if not (_all_same(ak, cmp=np.allclose) and _all_same(bk, cmp=np.allclose)): + raise ValueError("Sigma coordinates for each ensemble member should be the same.") + + ensemble = torch.utils.data.ConcatDataset(datasets) + ensemble.metadata = metadatas[0] # type: ignore + ensemble.area_weights = datasets[0].area_weights # type: ignore + ensemble.sigma_coordinates = datasets[0].sigma_coordinates # type: ignore + ensemble.horizontal_coordinates = datasets[0].horizontal_coordinates # type: ignore + return ensemble + + +def get_data_loader( + params: DataLoaderParams, + train: bool, + requirements: DataRequirements, + window_time_slice: Optional[slice] = None, + dist: Optional[Distributed] = None, +) -> GriddedData: + """ + Args: + params: Parameters for the data loader. + train: Whether to use the training or validation data. + requirements: Data requirements for the model. + window_time_slice: Time slice within each window to use for the data loader, + if given the loader will only return data from this time slice. + By default it will return the full windows. + dist: Distributed object to use for distributed training. + """ + if dist is None: + dist = Distributed.get_instance() + if params.data_type == "xarray": + dataset = XarrayDataset(params, requirements=requirements, window_time_slice=window_time_slice) + elif params.data_type == "ensemble_xarray": + dataset = _get_ensemble_dataset(params, requirements, window_time_slice=window_time_slice) + else: + raise NotImplementedError(f"{params.data_type} does not have an implemented data loader") + + sampler = DistributedSampler(dataset, shuffle=train) if dist.is_distributed() else None + + dataloader = DataLoader( + dataset, + batch_size=dist.local_batch_size(int(params.batch_size)), + num_workers=params.num_data_workers, + shuffle=(sampler is None) and train, + sampler=sampler if train else None, + drop_last=True, + pin_memory=using_gpu(), + persistent_workers=params.num_data_workers > 0, + ) + + return GriddedData( + loader=dataloader, + metadata=dataset.metadata, + area_weights=dataset.area_weights, + sampler=sampler, + sigma_coordinates=dataset.sigma_coordinates, + horizontal_coordinates=dataset.horizontal_coordinates, + ) diff --git a/src/ace_inference/core/data_loading/getters.py b/src/ace_inference/core/data_loading/getters.py new file mode 100644 index 0000000..23e0ce6 --- /dev/null +++ b/src/ace_inference/core/data_loading/getters.py @@ -0,0 +1,173 @@ +import dataclasses +from pathlib import Path +from typing import Optional + +import numpy as np +import torch.utils.data +from torch.utils.data import DataLoader +from torch.utils.data.distributed import DistributedSampler + +from src.ace_inference.core.device import using_gpu +from src.ace_inference.core.distributed import Distributed + +from ._xarray import XarrayDataset +from .data_typing import Dataset, GriddedData +from .inference import InferenceDataLoaderParams, InferenceDataset +from .params import DataLoaderParams, XarrayDataParams +from .requirements import DataRequirements +from .utils import BatchData + + +def _all_same(iterable, cmp=lambda x, y: x == y): + it = iter(iterable) + try: + first = next(it) + except StopIteration: + return True + return all(cmp(first, rest) for rest in it) + + +def _subset_dataset(dataset: Dataset, subset: slice) -> Dataset: + """Returns a subset of the dataset and propagates other properties.""" + indices = range(len(dataset))[subset] + subsetted_dataset = torch.utils.data.Subset(dataset, indices) + subsetted_dataset.metadata = dataset.metadata + subsetted_dataset.area_weights = dataset.area_weights + subsetted_dataset.sigma_coordinates = dataset.sigma_coordinates + subsetted_dataset.horizontal_coordinates = dataset.horizontal_coordinates + return subsetted_dataset + + +def _get_ensemble_dataset( + params: XarrayDataParams, + requirements: DataRequirements, + subset: slice, + sub_paths: Optional[list] = None, + dataset_class: Optional[Dataset] = None, + **kwargs, +) -> Dataset: + """Returns a dataset that is a concatenation of the datasets for each + ensemble member. + """ + dataset_class = dataset_class or XarrayDataset + if sub_paths is not None: + sub_paths = [sub_paths] if isinstance(sub_paths, str) else sub_paths + paths = [str(Path(params.data_path) / sub_path) for sub_path in sub_paths] + else: + # Get all subdirectories of the data path + paths = sorted([str(d) for d in Path(params.data_path).iterdir() if d.is_dir()]) + if len(paths) == 0: + raise ValueError( + f"No directories found in {params.data_path}. " + "Check path and whether you meant to use 'ensemble_xarray' data_type." + ) + datasets, metadatas, sigma_coords = [], [], [] + for path in paths: + params_curr_member = dataclasses.replace(params, data_path=path) + dataset = dataset_class(params_curr_member, requirements, **kwargs) + dataset = _subset_dataset(dataset, subset) + + datasets.append(dataset) + metadatas.append(dataset.metadata) + sigma_coords.append(dataset.sigma_coordinates) + + if not _all_same(metadatas): + raise ValueError("Metadata for each ensemble member should be the same.") + + ak, bk = list(zip(*[(s.ak.cpu().numpy(), s.bk.cpu().numpy()) for s in sigma_coords])) + if not (_all_same(ak, cmp=np.allclose) and _all_same(bk, cmp=np.allclose)): + raise ValueError("Sigma coordinates for each ensemble member should be the same.") + + ensemble = torch.utils.data.ConcatDataset(datasets) + ensemble.metadata = metadatas[0] # type: ignore + ensemble.area_weights = datasets[0].area_weights # type: ignore + ensemble.sigma_coordinates = datasets[0].sigma_coordinates # type: ignore + ensemble.horizontal_coordinates = datasets[0].horizontal_coordinates # type: ignore + return ensemble + + +def get_dataset( + params: DataLoaderParams, requirements: DataRequirements, dataset_class: Optional[Dataset] = None, **kwargs +) -> Dataset: + dataset_class = dataset_class or XarrayDataset + if params.data_type == "xarray": + dataset = dataset_class(params.dataset, requirements, **kwargs) + dataset = _subset_dataset(dataset, params.subset.slice) + elif params.data_type == "ensemble_xarray": + dataset = _get_ensemble_dataset( + params.dataset, requirements, params.subset.slice, dataset_class=dataset_class, **kwargs + ) + else: + raise NotImplementedError(f"{params.data_type} does not have an implemented data loader") + return dataset + + +def get_data_loader( + params: DataLoaderParams, + train: bool, + requirements: DataRequirements, +) -> GriddedData: + """ + Args: + params: Parameters for the data loader. + train: Whether to use the training or validation data. + requirements: Data requirements for the model. + window_time_slice: Time slice within each window to use for the data loader, + if given the loader will only return data from this time slice. + By default it will return the full windows. + """ + dataset = get_dataset(params, requirements) + dist = Distributed.get_instance() + sampler = DistributedSampler(dataset, shuffle=train) if dist.is_distributed() else None + + dataloader = DataLoader( + dataset, + batch_size=dist.local_batch_size(int(params.batch_size)), + num_workers=params.num_data_workers, + shuffle=(sampler is None) and train, + sampler=sampler if train else None, + drop_last=True, + pin_memory=using_gpu(), + collate_fn=BatchData.from_sample_tuples, + ) + + return GriddedData( + loader=dataloader, + metadata=dataset.metadata, + area_weights=dataset.area_weights, + sampler=sampler, + sigma_coordinates=dataset.sigma_coordinates, + horizontal_coordinates=dataset.horizontal_coordinates, + ) + + +def get_inference_data( + config: InferenceDataLoaderParams, + forward_steps_in_memory: int, + requirements: DataRequirements, +) -> GriddedData: + """ + Args: + config: Parameters for the data loader. + forward_steps_in_memory: Number of forward steps to keep in memory at once. + requirements: Data requirements for the model. + + Returns: + A data loader for inference with coordinates and metadata. + """ + dataset = InferenceDataset(config, forward_steps_in_memory, requirements) + # we roll our own batching in InferenceDataset, which is why batch_size=None below + loader = DataLoader( + dataset, + batch_size=None, + num_workers=config.num_data_workers, + shuffle=False, + pin_memory=using_gpu(), + ) + return GriddedData( + loader=loader, + metadata=dataset.metadata, + area_weights=dataset.area_weights, + sigma_coordinates=dataset.sigma_coordinates, + horizontal_coordinates=dataset.horizontal_coordinates, + ) diff --git a/src/ace_inference/core/data_loading/inference.py b/src/ace_inference/core/data_loading/inference.py new file mode 100644 index 0000000..32edb17 --- /dev/null +++ b/src/ace_inference/core/data_loading/inference.py @@ -0,0 +1,175 @@ +import dataclasses +import logging + +import numpy as np +import torch +import xarray as xr + +from src.ace_inference.core.data_loading._xarray import XarrayDataset +from src.ace_inference.core.data_loading.data_typing import HorizontalCoordinates, SigmaCoordinates +from src.ace_inference.core.data_loading.params import XarrayDataParams +from src.ace_inference.core.data_loading.requirements import DataRequirements +from src.ace_inference.core.data_loading.utils import BatchData +from src.ace_inference.core.distributed import Distributed + + +@dataclasses.dataclass +class InferenceInitialConditionIndices: + """ + Configuration of the indices for initial conditions during inference. + """ + + n_initial_conditions: int + first: int = 0 + interval: int = 1 + + def __post_init__(self): + if self.interval < 0: + raise ValueError("interval must be positive") + + def as_indices(self) -> np.ndarray: + stop = self.n_initial_conditions * self.interval + self.first + return np.arange(self.first, stop, self.interval) + + +@dataclasses.dataclass +class InferenceDataLoaderParams: + """ + Configuration for inference data. + + This is like the `DataLoaderParams` class, but with some additional + constraints. During inference, we have only one batch, so the number of + samples directly determines the size of that batch. + + Attributes: + dataset: Parameters to define the dataset. + start_indices: Slice indicating the set of indices to consider for initial + conditions of inference series of data. Values following the initial + condition will still come from the full dataset. + num_data_workers: Number of parallel workers to use for data loading. + """ + + dataset: XarrayDataParams + start_indices: InferenceInitialConditionIndices + num_data_workers: int = 0 + + @property + def n_samples(self) -> int: + return self.start_indices.n_initial_conditions + + +class InferenceDataset(torch.utils.data.Dataset): + def __init__( + self, + params: InferenceDataLoaderParams, + forward_steps_in_memory: int, + requirements: DataRequirements, + ): + dataset = XarrayDataset(params.dataset, requirements=requirements) + self._dataset = dataset + self._sigma_coordinates = dataset.sigma_coordinates + self._metadata = dataset.metadata + self._area_weights = dataset.area_weights + self._horizontal_coordinates = dataset.horizontal_coordinates + self._forward_steps_in_memory = forward_steps_in_memory + self._total_steps = requirements.n_timesteps - 1 + if self._total_steps % self._forward_steps_in_memory != 0: + raise ValueError( + f"Total number of steps ({self._total_steps}) must be divisible by " + f"forward_steps_in_memory ({self._forward_steps_in_memory})." + ) + + # self._dataset._get_files_stats() + dataset_size = max(1, self._dataset.total_timesteps) + # print( + # f"Dataset has {dataset_size}. Total steps: {self._total_steps}. Forward steps in memory: {self._forward_steps_in_memory}." + # f"n_samples={params.n_samples}. ds._n_initial_conditions={dataset._n_initial_conditions}. ds.total_timesteps={dataset.total_timesteps}. ds.n_steps={dataset.n_steps}." + # ) + # How many times to "copy" the data? + if params.dataset.n_repeats is None: + self._n_repeats = self._total_steps // dataset_size + self._dataset_size = dataset_size + else: + self._n_repeats = params.dataset.n_repeats + self._dataset_size = None + + if self._n_repeats > 1: + print( + f"Repeating target data {self._n_repeats} times. Dataset size: {dataset_size}. Total steps: {self._total_steps}." + ) + self.n_samples = params.n_samples # public attribute + self._start_indices = params.start_indices.as_indices() + + def __getitem__(self, index) -> BatchData: + dist = Distributed.get_instance() + # 0 -> 0 1-> 100 2-> 200 3-> 0 + i_start = index * self._forward_steps_in_memory + if self._dataset_size is not None: + i_start = i_start % self._dataset_size + sample_tuples = [] + for i_sample in range(self.n_samples): + # check if sample is one this local rank should process + if i_sample % dist.world_size != dist.rank: + continue + i_window_start = i_start + self._start_indices[i_sample] + i_window_end = i_window_start + self._forward_steps_in_memory + 1 + window_time_slice = slice(i_window_start, i_window_end) + tensors, times = self._dataset.get_sample_by_time_slice(window_time_slice) + if times.shape[0] != self._forward_steps_in_memory + 1: + assert self._n_repeats > 1, f"n_repeats={self._n_repeats}" + # Fill sample with data from the beginning of the dataset + assert ( + index + 1 + ) * self._forward_steps_in_memory % self._dataset_size < i_start, ( + f"Expected {(index + 1) * self._forward_steps_in_memory % self._dataset_size} < {i_start}" + ) + diff = self._forward_steps_in_memory + 1 - times.shape[0] + window_time_slice_start = slice(0, diff) + logging.info( + f"Index {index} with i_start={i_start}. window_time_slice={window_time_slice} " + f"Filling with {diff} time steps from the beginning of the dataset. +1 i_start is" + f"{(index + 1) * self._forward_steps_in_memory % self._dataset_size}." + ) + tensors_start, times_start = self._dataset.get_sample_by_time_slice(window_time_slice_start) + for k in tensors.keys(): + if tensors[k].shape[0] == self._forward_steps_in_memory + 1: + # This occurs for static variables which are copied already + # print(f"Skipping {k} because it has the right shape. tensors[k].shape={tensors[k].shape}.") + continue + tensors[k] = torch.cat([tensors[k], tensors_start[k]], dim=0) + times = xr.concat([times, times_start], dim="time") + else: + pass + # logging.info( + # f"Index {index} with i_start={i_start}. window_time_slice={window_time_slice} " + # f"forward_steps_in_memory={self._forward_steps_in_memory}, total_steps={self._total_steps}," + # f" n_samples={self.n_samples}. start_indices={self._start_indices}. ds_size={self._dataset_size}." + # ) + sample_tuples.append((tensors, times)) + assert times.shape[0] == self._forward_steps_in_memory + 1, ( + f"Expected {self._forward_steps_in_memory + 1} time steps, " + f"got {sample_tuples[-1][1].shape[0]}. sample_tuples[-1][1].shape={sample_tuples[-1][1].shape}" + ) + result = BatchData.from_sample_tuples(sample_tuples) + assert result.times.shape[1] == self._forward_steps_in_memory + 1 + assert result.times.shape[0] == self.n_samples // dist.world_size + return result + + def __len__(self) -> int: + return self._total_steps // self._forward_steps_in_memory + + @property + def sigma_coordinates(self) -> SigmaCoordinates: + return self._sigma_coordinates + + @property + def metadata(self) -> xr.Dataset: + return self._metadata + + @property + def area_weights(self) -> xr.DataArray: + return self._area_weights + + @property + def horizontal_coordinates(self) -> HorizontalCoordinates: + return self._horizontal_coordinates diff --git a/src/ace_inference/core/data_loading/params.py b/src/ace_inference/core/data_loading/params.py new file mode 100644 index 0000000..0448b04 --- /dev/null +++ b/src/ace_inference/core/data_loading/params.py @@ -0,0 +1,77 @@ +import dataclasses +import warnings +from typing import Literal, Optional + + +@dataclasses.dataclass +class Slice: + """ + Configuration of a python `slice` built-in. + + Required because `slice` cannot be initialized directly by dacite. + """ + + start: Optional[int] = None + stop: Optional[int] = None + step: Optional[int] = None + + @property + def slice(self) -> slice: + return slice(self.start, self.stop, self.step) + + +@dataclasses.dataclass +class XarrayDataParams: + """ + Attributes: + data_path: Path to the data. + n_repeats: Number of times to repeat the dataset (in time). + engine: Backend for xarray.open_dataset. Currently supported options + are "netcdf4" (the default) and "h5netcdf". Only valid when using + XarrayDataset. + sub_paths: List of sub-paths to use as mask for globbing files (instead of using all files). + """ + + data_path: str + n_repeats: int = 1 + engine: Optional[Literal["netcdf4", "h5netcdf"]] = None + + +@dataclasses.dataclass +class DataLoaderParams: + """ + Attributes: + dataset: Parameters to define the dataset. + batch_size: Number of samples per batch. + num_data_workers: Number of parallel workers to use for data loading. + data_type: Type of data to load. + subset: Slice defining a subset of the XarrayDataset to load. For + data_type="ensemble_xarray" case this will be applied to each ensemble + member before concatenation. + """ + + dataset: XarrayDataParams + batch_size: int + num_data_workers: int + data_type: Literal["xarray", "ensemble_xarray"] + subset: Slice = dataclasses.field(default_factory=Slice) + n_samples: Optional[int] = None + + def __post_init__(self): + if self.n_samples is not None: + if self.subset.stop is not None: + raise ValueError("Both 'n_samples' and 'subset.stop' are specified. " "Only one of them can be used.") + warnings.warn( + "Specifying 'n_samples' is deprecated. Use 'subset.stop' instead.", + category=DeprecationWarning, + ) + self.subset.stop = self.n_samples + # dist = Distributed.get_instance() + # if self.batch_size % dist.world_size != 0: + # raise ValueError( + # "batch_size must be divisible by the number of parallel " + # f"workers, got {self.batch_size} and {dist.world_size}" + # ) + + if self.dataset.n_repeats != 1 and self.data_type == "ensemble_xarray": + raise ValueError("n_repeats must be 1 when using ensemble_xarray") diff --git a/src/ace_inference/core/data_loading/requirements.py b/src/ace_inference/core/data_loading/requirements.py new file mode 100644 index 0000000..f882b06 --- /dev/null +++ b/src/ace_inference/core/data_loading/requirements.py @@ -0,0 +1,11 @@ +import dataclasses +from typing import List + + +@dataclasses.dataclass +class DataRequirements: + names: List[str] + # TODO: delete these when validation no longer needs them + in_names: List[str] + out_names: List[str] + n_timesteps: int diff --git a/src/ace_inference/core/data_loading/utils.py b/src/ace_inference/core/data_loading/utils.py new file mode 100644 index 0000000..832138b --- /dev/null +++ b/src/ace_inference/core/data_loading/utils.py @@ -0,0 +1,107 @@ +import dataclasses +from typing import List, Mapping, Sequence, Tuple + +import cftime +import numpy as np +import torch +import xarray as xr +from torch.utils.data import default_collate + + +SLICE_NONE = slice(None) + + +def _load_all_variables(ds: xr.Dataset, variables: Sequence[str], time_slice: slice = SLICE_NONE) -> xr.DataArray: + """Load data from a variables into memory. + + This function leverages xarray's lazy loading to load only the time slice + (or chunk[s] for the time slice) of the variables we need. + + Consolidating the dask tasks into a single call of .compute() sped up remote + zarr loads by nearly a factor of 2. + """ + if "time" in ds.dims: + ds = ds.isel(time=time_slice) + return ds[variables].compute() + + +def load_series_data( + idx: int, + n_steps: int, + ds: xr.Dataset, + names: List[str], +): + time_slice = slice(idx, idx + n_steps) + loaded = _load_all_variables(ds, names, time_slice) + arrays = {} + for n in names: + variable = loaded[n].variable + arrays[n] = torch.as_tensor(variable.values) + # arrays[n] = as_broadcasted_tensor(variable, dims, shape) + return arrays + # Old: + # # disable dask threading to avoid warnings + # with dask.config.set(scheduler="synchronous"): + # arrays = {} + # for n in names: + # arr = ds.variables[n][time_slice, :, :] + # arrays[n] = torch.as_tensor(arr.values) + # return arrays + + +def get_lons_and_lats(ds: xr.Dataset) -> Tuple[np.ndarray, np.ndarray]: + if "grid_xt" in ds.variables: + hdims = "grid_xt", "grid_yt" + elif "lon" in ds.variables: + hdims = "lon", "lat" + elif "longitude" in ds.variables: + hdims = "longitude", "latitude" + else: + raise ValueError("Could not identify dataset's horizontal dimensions.") + lons, lats = ds[hdims[0]].values, ds[hdims[1]].values + return np.array(lons), np.array(lats) + + +def get_times(ds: xr.Dataset, start: int, n_steps: int) -> xr.DataArray: + """ + Get the time coordinate segment from the dataset, check that it's a + cftime.datetime object, and return it is a data array (not a coordinate), + so that it can be concatenated with other samples' times. + """ + time_segment = ds["time"][slice(start, start + n_steps)] + assert isinstance(time_segment[0].item(), cftime.datetime), "time must be cftime.datetime." + return time_segment.drop_vars(["time"]) + + +@dataclasses.dataclass +class BatchData: + """A container for the data and time coordinates of a batch. + + Attributes: + data: Data for each variable in each sample, concatenated along samples + to make a batch. To be used directly in training, validation, and + inference. + times: An array of times for each sample in the batch, concatenated along + samples to make a batch. To be used in writing out inference + predictions with time coordinates, not directly in ML. + + """ + + data: Mapping[str, torch.Tensor] + times: xr.DataArray + + @classmethod + def from_sample_tuples( + cls, + samples: Sequence[Tuple[Mapping[str, torch.Tensor], xr.DataArray]], + sample_dim_name: str = "sample", + ) -> "BatchData": + """ + Collate function for use with PyTorch DataLoader. Needed since samples contain + both tensor mapping and xarray time coordinates, the latter of which we do + not want to convert to tensors. + """ + sample_data, sample_times = zip(*samples) + batch_data = default_collate(sample_data) + batch_times = xr.concat(sample_times, dim=sample_dim_name) + return cls(batch_data, batch_times) diff --git a/src/ace_inference/core/device.py b/src/ace_inference/core/device.py new file mode 100644 index 0000000..91e4426 --- /dev/null +++ b/src/ace_inference/core/device.py @@ -0,0 +1,12 @@ +import torch + + +def using_gpu() -> bool: + return get_device().type == "cuda" + + +def get_device() -> torch.device: + if torch.cuda.is_available(): + return torch.device("cuda", torch.cuda.current_device()) + else: + return torch.device("cpu") diff --git a/src/ace_inference/core/dicts.py b/src/ace_inference/core/dicts.py new file mode 100644 index 0000000..cbe1135 --- /dev/null +++ b/src/ace_inference/core/dicts.py @@ -0,0 +1,41 @@ +from typing import Any, Dict + + +def to_flat_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts any nested dictionaries to a flat version with + the nested keys joined with a '.', e.g., {a: {b: 1}} -> + {a.b: 1} + """ + + new_flat = {} + for k, v in d.items(): + if isinstance(v, dict): + sub_d = to_flat_dict(v) + for sk, sv in sub_d.items(): + new_flat[".".join([k, sk])] = sv + else: + new_flat[k] = v + + return new_flat + + +def to_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """ + Converts a flat dictionary with '.' joined keys back into + a nested dictionary, e.g., {a.b: 1} -> {a: {b: 1}} + """ + + new_config: Dict[str, Any] = {} + + for k, v in d.items(): + if "." in k: + sub_keys = k.split(".") + sub_d = new_config + for sk in sub_keys[:-1]: + sub_d = sub_d.setdefault(sk, {}) + sub_d[sub_keys[-1]] = v + else: + new_config[k] = v + + return new_config diff --git a/src/ace_inference/core/distributed.py b/src/ace_inference/core/distributed.py new file mode 100644 index 0000000..1e2661d --- /dev/null +++ b/src/ace_inference/core/distributed.py @@ -0,0 +1,107 @@ +import os +from typing import Optional + +import torch.distributed + +from src.ace_inference.core.device import using_gpu + + +singleton: Optional["Distributed"] = None + + +class Distributed: + """ + A class to represent the distributed concerns for FME training. + + This should generally be initialized first, before any pytorch objects. + This is important because it sets global variables such as the CUDA + device for the local rank, which is used when initializing pytorch objects. + + This class uses the + [Singleton pattern](https://en.wikipedia.org/wiki/Singleton_pattern) and should + be initialized through get_instance. This pattern allows easy access to global + variables without having to pass them around, and lets us put the initialization + for this global state in the same place as the routines that use it. + + Attributes: + world_size: The number of processes in the distributed training job. + rank: The rank of the current process. + """ + + @classmethod + def get_instance(cls) -> "Distributed": + """ + Get the singleton instance of the Distributed class. + """ + global singleton + if singleton is None: + singleton = cls() + return singleton + + def __init__(self): + if torch.distributed.is_available() and not torch.distributed.is_initialized(): + self._distributed = self._init_distributed() + else: + self._distributed = False + + def _init_distributed(self): + if "RANK" in os.environ: # we were executed with torchrun + if using_gpu(): + torch.distributed.init_process_group(backend="nccl", init_method="env://") + else: + torch.distributed.init_process_group(backend="gloo", init_method="env://") + self.world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + if using_gpu(): + torch.cuda.set_device(self.rank) + distributed = True + else: + self.world_size = 1 + self.rank = 0 + distributed = False + return distributed + + def local_batch_size(self, batch_size: int) -> int: + """ + Get the local batch size for the current process. + """ + return batch_size // self.world_size + + def reduce_mean(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Reduce a tensor representing a mean across all processes. + + Whether the tensor represents a mean is important because to reduce a mean, + we must divide by the number of processes. To reduce a sum, we must not. + + Modifies the input tensor in-place as a side effect. + """ + if self._distributed: + torch.distributed.all_reduce(tensor) + return tensor / self.world_size + + def reduce_sum(self, tensor: torch.Tensor) -> torch.Tensor: + """ + Reduce a tensor representing a sum across all processes. + + Whether the tensor represents a mean is important because to reduce a mean, + we must divide by the number of processes. To reduce a sum, we must not. + + Modifies the input tensor in-place as a side effect. + """ + if self._distributed: + torch.distributed.all_reduce(tensor) + return tensor + + def is_root(self) -> bool: + """ + Returns True if this process is the root process. + """ + return self.rank == 0 + + def is_distributed(self) -> bool: + """ + Returns True if this process is running in a distributed context + with more than 1 worker. + """ + return self._distributed and self.world_size > 1 diff --git a/src/ace_inference/core/ema.py b/src/ace_inference/core/ema.py new file mode 100644 index 0000000..5afb12a --- /dev/null +++ b/src/ace_inference/core/ema.py @@ -0,0 +1,143 @@ +""" +Exponential Moving Average (EMA) module + +Copied from https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/ema.py +and modified. + +MIT License + +Copyright (c) 2022 Machine Vision and Learning Group, LMU Munich + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from typing import Iterable, Iterator, List, Protocol, Tuple + +import torch +from torch import nn + + +class HasNamedParameters(Protocol): + def named_parameters(self, recurse: bool = True) -> Iterator[Tuple[str, nn.Parameter]]: ... + + +class EMATracker: + """ + Exponential Moving Average (EMA) tracker. + + This tracks the moving average of the parameters of a model, and has methods + that can be used to temporarily replace the parameters of the model with its EMA. + """ + + def __init__(self, model: HasNamedParameters, decay: float, faster_decay_at_start=True): + """ + Create a new EMA tracker. + + Args: + model: The model whose parameters should be tracked. + decay: The decay rate of the moving average. + faster_decay_at_start: Whether to use the number of updates to determine + the decay rate. If True, the decay rate will be min(decay, (1 + + num_updates) / (10 + num_updates)). If False, the decay rate + will be decay. + """ + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self._module_name_to_ema_name = {} + self.decay = torch.tensor(decay, dtype=torch.float32) + self._faster_decay_at_start = faster_decay_at_start + self.num_updates = torch.tensor(0, dtype=torch.int) + + self._ema_params = {} + + for name, p in model.named_parameters(): + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + ema_name = name.replace(".", "") + self._module_name_to_ema_name.update({name: ema_name}) + self._ema_params[ema_name] = p.clone().detach().data + + self._stored_params: List[nn.Parameter] = [] + + def __call__(self, model: HasNamedParameters): + """ + Update the moving average of the parameters. + + Does not mutate the input, only updates the moving average. + + Args: + model: The model whose parameters should be updated. Should be a model + specified identically to the one passed when this object was + instantiated. + """ + decay = self.decay + + self.num_updates += 1 + if self._faster_decay_at_start: + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + with torch.no_grad(): + module_parameters = dict(model.named_parameters()) + + for key in module_parameters: + if module_parameters[key].requires_grad: + ema_name = self._module_name_to_ema_name[key] + self._ema_params[ema_name] = self._ema_params[ema_name].type_as(module_parameters[key]) + self._ema_params[ema_name].sub_( + (1.0 - decay) * (self._ema_params[ema_name] - module_parameters[key]) + ) + elif key in self._module_name_to_ema_name: + raise ValueError(f"Expected model parameter {key} to require gradient, " "but it does not") + + def copy_to(self, model: HasNamedParameters): + """ + Copy the averaged parameters to the model, overwriting its values. + """ + m_param = dict(model.named_parameters()) + for key in m_param: + if m_param[key].requires_grad: + m_param[key].data.copy_(self._ema_params[self._module_name_to_ema_name[key]].data) + else: + assert key not in self._module_name_to_ema_name + + def store(self, parameters: Iterable[nn.Parameter]): + """ + Save the current parameters for restoring later. + + Args: + parameters: The parameters to be stored for later restoration by `restore` + """ + self._stored_params = [param.clone() for param in parameters] + + def restore(self, parameters: Iterable[nn.Parameter]): + """ + Restore the parameters stored with the `store` method. + + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + + Args: + parameters: The parameters to be updated with the values stored by `store` + """ + for c_param, param in zip(self._stored_params, parameters): + param.data.copy_(c_param.data) diff --git a/src/ace_inference/core/histogram.py b/src/ace_inference/core/histogram.py new file mode 100644 index 0000000..6032023 --- /dev/null +++ b/src/ace_inference/core/histogram.py @@ -0,0 +1,99 @@ +from typing import Optional + +import numpy as np + + +EPSILON = 1.0e-6 + + +class DynamicHistogram: + """ + A histogram that dynamically bins values into a fixed number of bins + of constant width. A separate histogram is defined for each time, + and the same bins are used for all times. + + When a new value is added that goes out of range of the current bins, + bins are doubled in size until that value is within the range of the bins. + """ + + def __init__(self, n_times: int, n_bins: int = 300): + """ + Args: + n_times: Length of time dimension of each sample. + n_bins: Number of bins to use for the histogram. The "effective" + number of bins may be as small as 1/4th this number of bins, + as there may be bins greater than the max or less than the + min value with no samples in them, due to the dynamic resizing + of bins. + """ + self._n_times = n_times + self._n_bins = n_bins + self.bin_edges: Optional[np.ndarray] = None + self.counts = np.zeros((n_times, n_bins), dtype=np.int64) + self._epsilon: float = EPSILON + + def add(self, value: np.ndarray, i_time_start: int = 0): + """ + Add new values to the histogram. + + Args: + value: array of values of shape (n_times, n_values) to add to the histogram + i_time_start: index of the first time to add values to + """ + vmin = np.min(value) + vmax = np.max(value) + if vmin == vmax: + # if all values are the same, add a small amount to vmin and vmax + vmin -= self._epsilon + vmax += self._epsilon + if self.bin_edges is None: + self.bin_edges = np.linspace(vmin, vmax, self._n_bins + 1) + else: + while vmin < self.bin_edges[0]: + self._double_size_left() + while vmax > self.bin_edges[-1]: + self._double_size_right() + i_time_end = i_time_start + value.shape[0] + self.counts[i_time_start:i_time_end, :] += np.apply_along_axis( + lambda arr: np.histogram(arr, bins=self.bin_edges)[0], + axis=1, + arr=value, + ) + + def _double_size_left(self): + """ + Double the sizes of bins, extending the histogram + to the left (further negative). + """ + current_range = self.bin_edges[-1] - self.bin_edges[0] + new_range = 2 * current_range + + new_bin_edges = np.linspace( + self.bin_edges[-1] - new_range, + self.bin_edges[-1], + self._n_bins + 1, + ) + new_counts = np.zeros((self._n_times, self._n_bins), dtype=np.int64) + combined_counts = self.counts[:, ::2] + self.counts[:, 1::2] + new_counts[:, self._n_bins // 2 :] = combined_counts + self.bin_edges = new_bin_edges + self.counts = new_counts + + def _double_size_right(self): + """ + Double the sizes of bins, extending the histogram + to the right (further positive). + """ + current_range = self.bin_edges[-1] - self.bin_edges[0] + new_range = 2 * current_range + + new_bin_edges = np.linspace( + self.bin_edges[0], + self.bin_edges[0] + new_range, + self._n_bins + 1, + ) + new_counts = np.zeros((self._n_times, self._n_bins), dtype=np.int64) + combined_counts = self.counts[:, ::2] + self.counts[:, 1::2] + new_counts[:, : self._n_bins // 2] = combined_counts + self.bin_edges = new_bin_edges + self.counts = new_counts diff --git a/src/ace_inference/core/loss.py b/src/ace_inference/core/loss.py new file mode 100644 index 0000000..eb62710 --- /dev/null +++ b/src/ace_inference/core/loss.py @@ -0,0 +1,255 @@ +import dataclasses +from typing import Any, Dict, List, Literal, Mapping, Optional, Tuple + +import torch + +from src.ace_inference.core.aggregator.climate_data import ClimateData, compute_dry_air_absolute_differences +from src.ace_inference.core.data_loading.data_typing import SigmaCoordinates +from src.ace_inference.core.device import get_device + + +def get_dry_air_nonconservation( + data: Mapping[str, torch.Tensor], + area_weights: torch.Tensor, + sigma_coordinates: SigmaCoordinates, +): + """ + Computes the time-average one-step absolute difference in surface pressure due to + changes in globally integrated dry air. + + Args: + data: A mapping from variable name to tensor of shape + [sample, time, lat, lon], in physical units. specific_total_water in kg/kg + and surface_pressure in Pa must be present. + area_weights: The area of each grid cell as a [lat, lon] tensor, in m^2. + sigma_coordinates: The sigma coordinates of the model. + """ + return compute_dry_air_absolute_differences( + ClimateData(data), area=area_weights, sigma_coordinates=sigma_coordinates + ).mean() + + +class ConservationLoss: + def __init__( + self, + config: "ConservationLossConfig", + area_weights: torch.Tensor, + sigma_coordinates: SigmaCoordinates, + ): + """ + Args: + config: configuration options. + area_weights: The area of each grid cell as a [lat, lon] tensor, in m^2. + sigma_coordinates: The sigma coordinates of the model. + """ + self._config = config + self._area_weights = area_weights.to(get_device()) + self._sigma_coordinates = sigma_coordinates.to(get_device()) + + def __call__(self, gen_data: Mapping[str, torch.Tensor]) -> Tuple[Dict[str, torch.Tensor], torch.Tensor]: + """ + Compute loss and metrics related to conservation. + + Args: + gen_data: A mapping from variable name to tensor of shape + [sample, time, lat, lon], in physical units. + """ + conservation_metrics = {} + loss = torch.tensor(0.0, device=get_device()) + if self._config.dry_air_penalty is not None: + dry_air_loss = self._config.dry_air_penalty * get_dry_air_nonconservation( + gen_data, + area_weights=self._area_weights, + sigma_coordinates=self._sigma_coordinates, + ) + conservation_metrics["dry_air_loss"] = dry_air_loss.detach() + loss += dry_air_loss + return conservation_metrics, loss + + def get_state(self): + return { + "config": dataclasses.asdict(self._config), + "sigma_coordinates": self._sigma_coordinates, + "area_weights": self._area_weights, + } + + @classmethod + def from_state(cls, state) -> "ConservationLoss": + return cls( + config=ConservationLossConfig(**state["config"]), + sigma_coordinates=state["sigma_coordinates"], + area_weights=state["area_weights"], + ) + + +@dataclasses.dataclass +class ConservationLossConfig: + """ + Attributes: + dry_air_penalty: A constant by which to multiply one-step non-conservation + of surface pressure due to dry air in Pa as an L1 loss penalty. By + default, no such loss will be included. + """ + + dry_air_penalty: Optional[float] = None + + def build(self, area_weights: torch.Tensor, sigma_coordinates: SigmaCoordinates) -> ConservationLoss: + return ConservationLoss( + config=self, + area_weights=area_weights, + sigma_coordinates=sigma_coordinates, + ) + + +class LpLoss(torch.nn.Module): + def __init__(self, p=2): + """ + Args: + p: Lp-norm type. For example, p=1 for L1-norm, p=2 for L2-norm. + """ + super(LpLoss, self).__init__() + + if p <= 0: + raise ValueError("Lp-norm type should be positive") + + self.p = p + + def rel(self, x, y): + num_examples = x.size()[0] + + diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) + y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) + + return torch.mean(diff_norms / y_norms) + + def __call__(self, x, y): + return self.rel(x, y) + + +class AreaWeightedMSELoss(torch.nn.Module): + def __init__(self, area: torch.Tensor): + super(AreaWeightedMSELoss, self).__init__() + self._area_weights = area / area.mean() + + def __call__(self, x, y): + return torch.mean((x - y) ** 2 * self._area_weights) + + +class WeightedSum(torch.nn.Module): + """ + A module which applies multiple loss-function modules (taking two inputs) + to the same input and returns a tensor equal to the weighted sum of the + outputs of the modules. + """ + + def __init__(self, modules: List[torch.nn.Module], weights: List[float]): + """ + Args: + modules: A list of modules, each of which takes two tensors and + returns a scalar tensor. + weights: A list of weights to apply to the outputs of the modules. + """ + super().__init__() + if len(modules) != len(weights): + raise ValueError("modules and weights must have the same length") + self._wrapped = modules + self._weights = weights + + def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return sum(w * module(x, y) for w, module in zip(self._weights, self._wrapped)) + + +class GlobalMeanLoss(torch.nn.Module): + """ + A module which computes a loss on the global mean of each sample. + """ + + def __init__(self, area: torch.Tensor, loss: torch.nn.Module): + """ + Args: + area: A tensor of shape (n_lat, n_lon) containing the area of + each grid cell. + loss: A loss function which takes two tensors of shape + (n_samples, n_timesteps, n_channels) and returns a scalar + tensor. + """ + super().__init__() + self.global_mean = GlobalMean(area) + self.loss = loss + + def __call__(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + x = self.global_mean(x) + y = self.global_mean(y) + return self.loss(x, y) + + +class GlobalMean(torch.nn.Module): + def __init__(self, area: torch.Tensor): + """ + Args: + area: A tensor of shape (n_lat, n_lon) containing the area of + each grid cell. + """ + super().__init__() + self.area_weights = area / area.sum() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Args: + x: A tensor of shape (n_samples, n_timesteps, n_channels, n_lat, n_lon) + """ + return (x * self.area_weights[None, None, None, :, :]).sum(dim=(3, 4)) + + +@dataclasses.dataclass +class LossConfig: + """ + A dataclass containing all the information needed to build a loss function, + including the type of the loss function and the data needed to build it. + + Args: + type: the type of the loss function + kwargs: data for a loss function instance of the indicated type + global_mean_type: the type of the loss function to apply to the global + mean of each sample, by default no loss is applied + global_mean_kwargs: data for a loss function instance of the indicated + type to apply to the global mean of each sample + global_mean_weight: the weight to apply to the global mean loss + relative to the main loss + """ + + type: Literal["LpLoss", "MSE", "AreaWeightedMSE"] = "LpLoss" + kwargs: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {}) + global_mean_type: Optional[Literal["LpLoss"]] = None + global_mean_kwargs: Mapping[str, Any] = dataclasses.field(default_factory=lambda: {}) + global_mean_weight: float = 1.0 + + def __post_init__(self): + if self.type not in ("LpLoss", "MSE", "AreaWeightedMSE"): + raise NotImplementedError(self.type) + if self.global_mean_type is not None and self.global_mean_type != "LpLoss": + raise NotImplementedError(self.global_mean_type) + + def build(self, area: torch.Tensor) -> Any: + """ + Args: + area: A tensor of shape (n_lat, n_lon) containing the area of + each grid cell. + """ + area = area.to(get_device()) + if self.type == "LpLoss": + main_loss = LpLoss(**self.kwargs) + elif self.type == "MSE": + main_loss = torch.nn.MSELoss(reduction="mean") + elif self.type == "AreaWeightedMSE": + main_loss = AreaWeightedMSELoss(area) + + if self.global_mean_type is not None: + global_mean_loss = GlobalMeanLoss(area=area, loss=LpLoss(**self.global_mean_kwargs)) + final_loss = WeightedSum( + modules=[main_loss, global_mean_loss], + weights=[1.0, self.global_mean_weight], + ) + else: + final_loss = main_loss + return final_loss.to(device=get_device()) diff --git a/src/ace_inference/core/metrics.py b/src/ace_inference/core/metrics.py new file mode 100644 index 0000000..f499a61 --- /dev/null +++ b/src/ace_inference/core/metrics.py @@ -0,0 +1,367 @@ +from typing import Iterable, Optional, Union + +import numpy as np +import torch +from typing_extensions import TypeAlias + + +Dimension: TypeAlias = Union[int, Iterable[int]] +Array: TypeAlias = Union[np.ndarray, torch.Tensor] + +GRAVITY = 9.80665 # m/s^2 + + +def spherical_area_weights(lats: Array, num_lon: int) -> torch.Tensor: + """Computes area weights given the latitudes of a regular lat-lon grid. + + Args: + lats: tensor of shape (num_lat,) with the latitudes of the cell centers. + num_lon: Number of longitude points. + device: Device to place the tensor on. + + Returns: + a torch.tensor of shape (num_lat, num_lon). + """ + if isinstance(lats, np.ndarray): + lats = torch.from_numpy(lats) + weights = torch.cos(torch.deg2rad(lats)).repeat(num_lon, 1).t() + weights /= weights.sum() + return weights + + +def weighted_mean( + tensor: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + keepdim: bool = False, +) -> torch.Tensor: + """Computes the weighted mean across the specified list of dimensions. + + Args: + tensor: torch.Tensor + weights: Weights to apply to the mean. + dim: Dimensions to compute the mean over. + keepdim: Whether the output tensor has `dim` retained or not. + + Returns: + a tensor of the weighted mean averaged over the specified dimensions `dim`. + """ + if weights is None: + return tensor.mean(dim=dim, keepdim=keepdim) + return (tensor * weights).sum(dim=dim, keepdim=keepdim) / weights.expand(tensor.shape).sum( + dim=dim, keepdim=keepdim + ) + + +def weighted_std( + tensor: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), +) -> torch.Tensor: + """Computes the weighted standard deviation across the specified list of dimensions. + + Computed by first computing the weighted variance, then taking the square root. + + weighted_variance = weighted_mean((tensor - weighted_mean(tensor)) ** 2)) ** 0.5 + + Args: + tensor: torch.Tensor + weights: Weights to apply to the variance. + dim: Dimensions to compute the standard deviation over. + + Returns: + a tensor of the weighted standard deviation over the + specified dimensions `dim`. + """ + if weights is None: + weights = torch.tensor(1.0, device=tensor.device) + + mean = weighted_mean(tensor, weights=weights, dim=dim, keepdim=True) + variance = weighted_mean((tensor - mean) ** 2, weights=weights, dim=dim) + return torch.sqrt(variance) + + +def weighted_mean_bias( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), +) -> torch.Tensor: + """Computes the mean bias across the specified list of dimensions assuming + that the weights are applied to the last dimensions, e.g. the spatial dimensions. + + Args: + truth: torch.Tensor + predicted: torch.Tensor + dim: Dimensions to compute the mean over. + weights: Weights to apply to the mean. + + Returns: + a tensor of the mean biases averaged over the specified dimensions `dim`. + """ + assert truth.shape == predicted.shape, "Truth and predicted should have the same shape." + bias = predicted - truth + return weighted_mean(bias, weights=weights, dim=dim) + + +def root_mean_squared_error( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), +) -> torch.Tensor: + """ + Computes the weighted global RMSE over all variables. Namely, for each variable: + + sqrt((weights * ((xhat - x) ** 2)).mean(dims)) + + If you want to compute the RMSE over the time dimension, then pass in + `truth.mean(time_dim)` and `predicted.mean(time_dim)` and specify `dims=space_dims`. + + Args: + truth: torch.Tensor whose last dimensions are to be weighted + predicted: torch.Tensor whose last dimensions are to be weighted + weights: torch.Tensor to apply to the squared bias. + dim: Dimensions to average over. + + Returns: + a tensor of shape (variable,) of weighted RMSEs. + """ + assert truth.shape == predicted.shape, "Truth and predicted should have the same shape." + sq_bias = torch.square(predicted - truth) + return weighted_mean(sq_bias, weights=weights, dim=dim).sqrt() + + +def ensemble_spread( + ensemble: torch.Tensor, weights: Optional[torch.Tensor] = None, corr_factor: bool = True, dim: Dimension = () +) -> torch.Tensor: + """Square root of the ensemble variance. See Fortuin et al. (2013) for more details.""" + spread = weighted_mean(ensemble.var(dim=0), weights=weights, dim=dim).sqrt() + if corr_factor: + n_mems = ensemble.shape[0] + spread *= ((n_mems + 1) / n_mems) ** 0.5 + return spread + + +def spread_skill_ratio( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), +) -> torch.Tensor: + assert truth.shape == predicted.shape[1:] # ensemble ~ first axis + weighted_rmse = root_mean_squared_error(truth, predicted.mean(dim=0), weights=weights, dim=dim) + weighted_spread = ensemble_spread(predicted, weights=weights, dim=dim) + return weighted_spread / weighted_rmse + + +def weighted_crps( + truth: torch.Tensor, # TRUTH + predicted: torch.Tensor, # FORECAST + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + reduction="mean", + biased: bool = False, +) -> torch.Tensor: + """ + .. Author: Salva Rühling Cachay + + pytorch version of https://github.com/TheClimateCorporation/properscoring/blob/master/properscoring/_crps.py#L187 + + This implementation is based on the identity: + .. math:: + CRPS(F, x) = E_F|X - x| - 1/2 * E_F|X - X'| + where X and X' denote independent random variables drawn from the forecast + distribution F, and E_F denotes the expectation value under F. + + We use the fair, unbiased formulation of the ensemble CRPS, which is better for small ensembles. + Basically, we use n_members * (n_members - 1) instead of n_members**2 to average over the ensemble spread. + See Zamo & Naveau (2018; https://doi.org/10.1007/s11004-017-9709-7) for details. + + Alternative implementation: https://github.com/NVIDIA/modulus/pull/577/files + """ + assert truth.ndim == predicted.ndim - 1, f"observations.shape={truth.shape}, predictions.shape={predicted.shape}" + assert truth.shape == predicted.shape[1:] # ensemble ~ first axis + n_members = predicted.shape[0] + if n_members == 1: + return weighted_mean((predicted - truth).abs(), weights=weights, dim=dim) + + skill = (predicted - truth).abs().mean(dim=0) + # insert new axes so forecasts_diff expands with the array broadcasting + # torch.unsqueeze(predictions, 0) has shape (1, E, ...) + # torch.unsqueeze(predictions, 1) has shape (E, 1, ...) + forecasts_diff = torch.unsqueeze(predicted, 0) - torch.unsqueeze(predicted, 1) + # Forecasts_diff has shape (E, E, ...) + # Old version: score += - 0.5 * forecasts_diff.abs().mean(dim=(0, 1)) + # Using n_members * (n_members - 1) instead of n_members**2 is the fair, unbiased CRPS. Better for small ensembles. + spread = forecasts_diff.abs().sum(dim=(0, 1)) / (n_members * (n_members - 1)) + crps = skill - 0.5 * spread + # score has shape (...) (same as observations) + if reduction == "none": + return crps + assert reduction == "mean", f"Unknown reduction {reduction}" + if weights is not None: # weighted mean + crps = (crps * weights).sum(dim=dim) / weights.expand(crps.shape).sum(dim=dim) + else: + crps = crps.mean(dim=dim) + return crps + + +def gradient_magnitude(tensor: torch.Tensor, dim: Dimension = ()) -> torch.Tensor: + """Compute the magnitude of gradient across the specified dimensions.""" + gradients = torch.gradient(tensor, dim=dim) + return torch.sqrt(sum([g**2 for g in gradients])) + + +def weighted_mean_gradient_magnitude( + tensor: torch.Tensor, weights: Optional[torch.Tensor] = None, dim: Dimension = () +) -> torch.Tensor: + """Compute weighted mean of gradient magnitude across the specified dimensions.""" + return weighted_mean(gradient_magnitude(tensor, dim), weights=weights, dim=dim) + + +def gradient_magnitude_percent_diff( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + is_ensemble_prediction: bool = False, +) -> torch.Tensor: + """Compute the percent difference of the weighted mean gradient magnitude across + the specified dimensions.""" + truth_grad_mag = weighted_mean_gradient_magnitude(truth, weights, dim) + if is_ensemble_prediction: + predicted_grad_mag = 0 + for ens_i, pred in enumerate(predicted): + predicted_grad_mag += weighted_mean_gradient_magnitude(pred, weights, dim) + predicted_grad_mag /= predicted.shape[0] + else: + assert truth.shape == predicted.shape, "Truth and predicted should have the same shape." + predicted_grad_mag = weighted_mean_gradient_magnitude(predicted, weights, dim) + return 100 * (predicted_grad_mag - truth_grad_mag) / truth_grad_mag + + +def rmse_of_time_mean( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + time_dim: Dimension = 0, + spatial_dims: Dimension = (-2, -1), +) -> torch.Tensor: + """Compute the RMSE of the time-average given truth and predicted. + + Args: + truth: truth tensor + predicted: predicted tensor + weights: weights to use for computing spatial RMSE + time_dim: time dimension + spatial_dims: spatial dimensions over which RMSE is calculated + + Returns: + The RMSE between the time-mean of the two input tensors. The time and + spatial dims are reduced. + """ + truth_time_mean = truth.mean(dim=time_dim) + predicted_time_mean = predicted.mean(dim=time_dim) + ret = root_mean_squared_error(truth_time_mean, predicted_time_mean, weights=weights, dim=spatial_dims) + return ret + + +def time_and_global_mean_bias( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + time_dim: Dimension = 0, + spatial_dims: Dimension = (-2, -1), +) -> torch.Tensor: + """Compute the global- and time-mean bias given truth and predicted. + + Args: + truth: truth tensor + predicted: predicted tensor + weights: weights to use for computing the global mean + time_dim: time dimension + spatial_dims: spatial dimensions over which global mean is calculated + + Returns: + The global- and time-mean bias between the predicted and truth tensors. The + time and spatial dims are reduced. + """ + truth_time_mean = truth.mean(dim=time_dim) + predicted_time_mean = predicted.mean(dim=time_dim) + result = weighted_mean(predicted_time_mean - truth_time_mean, weights=weights, dim=spatial_dims) + return result + + +def vertical_integral( + integrand: torch.Tensor, + surface_pressure: torch.Tensor, + sigma_grid_offsets_ak: torch.Tensor, + sigma_grid_offsets_bk: torch.Tensor, +) -> torch.Tensor: + """Computes a vertical integral, namely: + + (1 / g) * ∫ x dp + + where + - g = acceleration due to gravity + - x = integrad + - p = pressure level + + Args: + integrand (lat, lon, vertical_level), (kg/kg) + surface_pressure: (lat, lon), (Pa) + sigma_grid_offsets_ak: Sorted sigma grid offsets ak, (vertical_level + 1,) + sigma_grid_offsets_bk: Sorted sigma grid offsets bk, (vertical_level + 1,) + + Returns: + Vertical integral of the integrand (lat, lon). + """ + ak, bk = sigma_grid_offsets_ak, sigma_grid_offsets_bk + if ak.device != surface_pressure.device: + ak = ak.to(surface_pressure.device) + if bk.device != surface_pressure.device: + bk = bk.to(surface_pressure.device) + if ak.device != integrand.device or ak.device != surface_pressure.device: + raise ValueError( + f"sigma_grid_offsets_ak.device ({ak.device}), " + f"sigma_grid_offsets_bk.device ({bk.device}), " + f"integrand.device ({integrand.device}), " + f"surface_pressure.device ({surface_pressure.device}) must be the same." + ) + pressure_thickness = ((ak + (surface_pressure.unsqueeze(-1) * bk))).diff(dim=-1) # Pa + integral = torch.sum(pressure_thickness * integrand, axis=-1) # type: ignore + return 1 / GRAVITY * integral + + +def surface_pressure_due_to_dry_air( + specific_total_water: torch.Tensor, + surface_pressure: torch.Tensor, + sigma_grid_offsets_ak: torch.Tensor, + sigma_grid_offsets_bk: torch.Tensor, +) -> torch.Tensor: + """Computes the dry air (Pa). + + Args: + specific_total_water (lat, lon, vertical_level), (kg/kg) + surface_pressure: (lat, lon), (Pa) + sigma_grid_offsets_ak: Sorted sigma grid offsets ak, (vertical_level + 1,) + sigma_grid_offsets_bk: Sorted sigma grid offsets bk, (vertical_level + 1,) + + Returns: + Vertically integrated dry air (lat, lon) (Pa) + """ + + num_levels = len(sigma_grid_offsets_ak) - 1 + + if num_levels != len(sigma_grid_offsets_bk) - 1 or num_levels != specific_total_water.shape[-1]: + raise ValueError(("Number of vertical levels in ak, bk, and specific_total_water must" "be the same.")) + + total_water_path = vertical_integral( + specific_total_water, + surface_pressure, + sigma_grid_offsets_ak, + sigma_grid_offsets_bk, + ) + dry_air = surface_pressure - GRAVITY * total_water_path + return dry_air diff --git a/src/ace_inference/core/normalizer.py b/src/ace_inference/core/normalizer.py new file mode 100644 index 0000000..ecd4dfe --- /dev/null +++ b/src/ace_inference/core/normalizer.py @@ -0,0 +1,126 @@ +import dataclasses +from typing import Dict, List, Mapping, Optional + +import netCDF4 +import numpy as np +import torch +import torch.jit + +from src.ace_inference.core.device import get_device + + +@dataclasses.dataclass +class NormalizationConfig: + global_means_path: Optional[str] = None + global_stds_path: Optional[str] = None + exclude_names: Optional[List[str]] = None + means: Mapping[str, float] = dataclasses.field(default_factory=dict) + stds: Mapping[str, float] = dataclasses.field(default_factory=dict) + + def __post_init__(self): + using_path = self.global_means_path is not None and self.global_stds_path is not None + using_explicit = len(self.means) > 0 and len(self.stds) > 0 + if using_path and using_explicit: + raise ValueError("Cannot use both global_means_path and global_stds_path " "and explicit means and stds.") + if not (using_path or using_explicit): + raise ValueError("Must use either global_means_path and global_stds_path " "or explicit means and stds.") + + def build(self, names: List[str]): + if self.exclude_names is not None: + names = list(set(names) - set(self.exclude_names)) + using_path = self.global_means_path is not None and self.global_stds_path is not None + if using_path: + return get_normalizer( + global_means_path=self.global_means_path, + global_stds_path=self.global_stds_path, + names=names, + ) + else: + means = {k: torch.tensor(self.means[k]) for k in names} + stds = {k: torch.tensor(self.stds[k]) for k in names} + return StandardNormalizer(means=means, stds=stds) + + +class FromStateNormalizer: + """ + An alternative to NormalizationConfig which provides a normalizer + initialized from a serializable state. + """ + + def __init__(self, state): + self.state = state + + def build(self, names: List[str]): + return StandardNormalizer.from_state(self.state) + + +class StandardNormalizer: + """ + Responsible for normalizing tensors. + """ + + def __init__( + self, + means: Dict[str, torch.Tensor], + stds: Dict[str, torch.Tensor], + ): + self.means = means + self.stds = stds + + def normalize(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return _normalize(tensors, means=self.means, stds=self.stds) + + def denormalize(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return _denormalize(tensors, means=self.means, stds=self.stds) + + def get_state(self): + """ + Returns state as a serializable data structure. + """ + return { + "means": {k: float(v.cpu().numpy()) for k, v in self.means.items()}, + "stds": {k: float(v.cpu().numpy()) for k, v in self.stds.items()}, + } + + @classmethod + def from_state(self, state) -> "StandardNormalizer": + """ + Loads state from a serializable data structure. + """ + means = {k: torch.tensor(v, device=get_device(), dtype=torch.float) for k, v in state["means"].items()} + stds = {k: torch.tensor(v, device=get_device(), dtype=torch.float) for k, v in state["stds"].items()} + return StandardNormalizer(means=means, stds=stds) + + +@torch.jit.script +def _normalize( + tensors: Dict[str, torch.Tensor], + means: Dict[str, torch.Tensor], + stds: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + return {k: (t - means[k]) / stds[k] if k in means.keys() else t for k, t in tensors.items()} + + +@torch.jit.script +def _denormalize( + tensors: Dict[str, torch.Tensor], + means: Dict[str, torch.Tensor], + stds: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + return {k: t * stds[k] + means[k] if k in means.keys() else t for k, t in tensors.items()} + + +def get_normalizer(global_means_path, global_stds_path, names: List[str]) -> StandardNormalizer: + means = load_Dict_from_netcdf(global_means_path, names) + means = {k: torch.as_tensor(v, dtype=torch.float) for k, v in means.items()} + stds = load_Dict_from_netcdf(global_stds_path, names) + stds = {k: torch.as_tensor(v, dtype=torch.float) for k, v in stds.items()} + return StandardNormalizer(means=means, stds=stds) + + +def load_Dict_from_netcdf(path, names) -> Dict[str, np.ndarray]: + ds = netCDF4.Dataset(path) + ds.set_auto_mask(False) + Dict = {c: ds.variables[c][:] for c in names} + ds.close() + return Dict diff --git a/src/ace_inference/core/ocean.py b/src/ace_inference/core/ocean.py new file mode 100644 index 0000000..e493016 --- /dev/null +++ b/src/ace_inference/core/ocean.py @@ -0,0 +1,146 @@ +import dataclasses +from typing import Dict, List, Optional + +import torch + +from src.ace_inference.core.aggregator.climate_data import ClimateData + +from .constants import DENSITY_OF_WATER, SPECIFIC_HEAT_OF_WATER, TIMESTEP_SECONDS +from .prescriber import Prescriber + + +@dataclasses.dataclass +class SlabOceanConfig: + mixed_layer_depth_name: str + q_flux_name: str + + @property + def names(self) -> List[str]: + return [self.mixed_layer_depth_name, self.q_flux_name] + + +@dataclasses.dataclass +class OceanConfig: + """Configuration for determining sea surface temperature from an ocean model. + + Args: + surface_temperature_name: Name of the sea surface temperature field. + ocean_fraction_name: Name of the ocean fraction field. + interpolate: If True, interpolate between ML-predicted surface temperature and + ocean-predicted surface temperature according to ocean_fraction. If False, + only use ocean-predicted surface temperature where ocean_fraction>=0.5. + slab: If provided, use a slab ocean model to predict surface temperature. + """ + + surface_temperature_name: str + ocean_fraction_name: str + interpolate: bool = False + slab: Optional[SlabOceanConfig] = None + + def build(self, in_names: List[str], out_names: List[str]): + if not (self.surface_temperature_name in in_names and self.surface_temperature_name in out_names): + raise ValueError( + "To use a surface ocean model, the surface temperature must be present" + f" in_names and out_names, but {self.surface_temperature_name} is not." + ) + return Ocean(config=self) + + @property + def names(self) -> List[str]: + names = [self.surface_temperature_name, self.ocean_fraction_name] + if self.slab is not None: + names.extend(self.slab.names) + return list(set(names)) + + +class Ocean: + """Overwrite sea surface temperature with that predicted from some ocean model.""" + + def __init__(self, config: OceanConfig): + """ + Args: + config: Configuration for the surface ocean model. + """ + self.surface_temperature_name = config.surface_temperature_name + self.ocean_fraction_name = config.ocean_fraction_name + self.prescriber = Prescriber( + prescribed_name=config.surface_temperature_name, + mask_name=config.ocean_fraction_name, + mask_value=1, + interpolate=config.interpolate, + ) + if config.slab is None: + self.type = "prescribed" + self._target_names = [ + self.surface_temperature_name, + self.ocean_fraction_name, + ] + else: + self.type = "slab" + self.mixed_layer_depth_name = config.slab.mixed_layer_depth_name + self.q_flux_name = config.slab.q_flux_name + self._target_names = [ + self.ocean_fraction_name, + self.mixed_layer_depth_name, + self.q_flux_name, + ] + + def __call__( + self, + target_data: Dict[str, torch.Tensor], + input_data: Dict[str, torch.Tensor], + gen_data: Dict[str, torch.Tensor], + ) -> Dict[str, torch.Tensor]: + """ + Args: + target_data: Denormalized data that includes mask and forcing data. Assumed + to correspond to the same time step as gen_data. + input_data: Denormalized input data for current step. + gen_data: Denormalized output data for current step. + + Returns: + gen_data with sea surface temperature overwritten by ocean model. + """ + if self.type == "prescribed": + next_step_temperature = target_data[self.surface_temperature_name] + elif self.type == "slab": + temperature_tendency = mixed_layer_temperature_tendency( + ClimateData(gen_data).net_surface_energy_flux_without_frozen_precip, + target_data[self.q_flux_name], + target_data[self.mixed_layer_depth_name], + ) + next_step_temperature = input_data[self.surface_temperature_name] + temperature_tendency * TIMESTEP_SECONDS + else: + raise NotImplementedError(f"Ocean type={self.type} is not implemented") + + return self.prescriber( + target_data, + gen_data, + {self.surface_temperature_name: next_step_temperature}, + ) + + @property + def target_names(self) -> List[str]: + """These are the variables required from the target data.""" + return self._target_names + + +def mixed_layer_temperature_tendency( + f_net: torch.Tensor, + q_flux: torch.Tensor, + depth: torch.Tensor, + density=DENSITY_OF_WATER, + specific_heat=SPECIFIC_HEAT_OF_WATER, +) -> torch.Tensor: + """ + Args: + f_net: Net surface energy flux in W/m^2. + q_flux: Convergence of ocean heat transport in W/m^2. + depth: Mixed layer depth in m. + density (optional): Density of water in kg/m^3. + specific_heat (optional): Specific heat of water in J/kg/K. + + Returns: + Temperature tendency of mixed layer in K/s. + """ + return (f_net + q_flux) / (density * depth * specific_heat) diff --git a/src/ace_inference/core/optimization.py b/src/ace_inference/core/optimization.py new file mode 100644 index 0000000..f358e08 --- /dev/null +++ b/src/ace_inference/core/optimization.py @@ -0,0 +1,190 @@ +import contextlib +import dataclasses +from typing import Any, Literal, Mapping, Optional + +import torch +import torch.cuda.amp as amp +from torch import nn + +from src.ace_inference.core.scheduler import SchedulerConfig + + +class Optimization: + def __init__( + self, + parameters, + optimizer_type: Literal["Adam", "FusedAdam"], + lr: float, + max_epochs: int, + scheduler: SchedulerConfig, + enable_automatic_mixed_precision: bool, + kwargs: Mapping[str, Any], + ): + if optimizer_type == "FusedAdam": + from apex import optimizers + + self.optimizer = optimizers.FusedAdam(parameters, lr=lr, **kwargs) + elif optimizer_type == "Adam": + self.optimizer = torch.optim.Adam(parameters, lr=lr, **kwargs) + else: + raise ValueError(f"Unknown optimizer type: {optimizer_type}") + + if enable_automatic_mixed_precision: + self.gscaler: Optional[amp.GradScaler] = amp.GradScaler() + else: + self.gscaler = None + self.scheduler = scheduler.build(self.optimizer, max_epochs) + + @contextlib.contextmanager + def autocast(self): + with amp.autocast(enabled=self.gscaler is not None): + yield + + @property + def learning_rate(self) -> float: + return self.optimizer.param_groups[0]["lr"] + + def set_mode(self, module: nn.Module): + """ + Sets the mode of the module to train. + """ + module.train() + + def step_scheduler(self, valid_loss: float): + """ + Step the scheduler. + + Args: + valid_loss: The validation loss. Used in schedulers which change the + learning rate based on whether the validation loss is decreasing. + """ + if self.scheduler is not None: + try: + self.scheduler.step(metrics=valid_loss) + except TypeError: + self.scheduler.step() + + def step_weights(self, loss: torch.Tensor): + if self.gscaler is not None: + self.gscaler.scale(loss).backward() + self.gscaler.step(self.optimizer) + else: + loss.backward() + self.optimizer.step() + self.optimizer.zero_grad() + + if self.gscaler is not None: + self.gscaler.update() + + def get_state(self): + """ + Returns state as a serializable data structure. + """ + state = { + "optimizer_state_dict": self.optimizer.state_dict(), + "scheduler_state_dict": self.scheduler.state_dict() if self.scheduler is not None else None, + "gscaler_state_dict": self.gscaler.state_dict() if self.gscaler is not None else None, + } + return state + + def load_state(self, state): + """ + Loads state from a serializable data structure. + """ + self.optimizer.load_state_dict(state["optimizer_state_dict"]) + if self.scheduler is not None: + self.scheduler.load_state_dict(state["scheduler_state_dict"]) + if self.gscaler is not None: + self.gscaler.load_state_dict(state["gscaler_state_dict"]) + + +@dataclasses.dataclass +class DisabledOptimizationConfig: + """ + Configuration for optimization, kept only for backwards compatibility when + loading configuration. Cannot be used to build, will raise an exception. + """ + + optimizer_type: Literal["Adam", "FusedAdam"] = "Adam" + lr: float = 0.001 + kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict) + enable_automatic_mixed_precision: bool = True + scheduler: SchedulerConfig = dataclasses.field(default_factory=lambda: SchedulerConfig()) + + def build(self, parameters, max_epochs: int) -> Optimization: + raise RuntimeError("Cannot build DisabledOptimizationConfig") + + def get_state(self) -> Mapping[str, Any]: + return dataclasses.asdict(self) + + @classmethod + def from_state(cls, state: Mapping[str, Any]) -> "DisabledOptimizationConfig": + return cls(**state) + + +@dataclasses.dataclass +class OptimizationConfig: + """ + Configuration for optimization. + + Attributes: + optimizer_type: The type of optimizer to use. + lr: The learning rate. + kwargs: Additional keyword arguments to pass to the optimizer. + enable_automatic_mixed_precision: Whether to use automatic mixed + precision. + scheduler: The type of scheduler to use. If none is given, no scheduler + will be used. + """ + + optimizer_type: Literal["Adam", "FusedAdam"] = "Adam" + lr: float = 0.001 + kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict) + enable_automatic_mixed_precision: bool = True + scheduler: SchedulerConfig = dataclasses.field(default_factory=lambda: SchedulerConfig()) + + def build(self, parameters, max_epochs: int) -> Optimization: + return Optimization( + parameters=parameters, + optimizer_type=self.optimizer_type, + lr=self.lr, + max_epochs=max_epochs, + scheduler=self.scheduler, + enable_automatic_mixed_precision=self.enable_automatic_mixed_precision, + kwargs=self.kwargs, + ) + + def get_state(self) -> Mapping[str, Any]: + return dataclasses.asdict(self) + + @classmethod + def from_state(cls, state: Mapping[str, Any]) -> "OptimizationConfig": + return cls(**state) + + +class NullOptimization: + @contextlib.contextmanager + def autocast(self): + yield + + @property + def learning_rate(self) -> float: + return float("nan") + + def step_scheduler(self, valid_loss: float): + return + + def step_weights(self, loss: torch.Tensor): + return + + def get_state(self): + return {} + + def load_state(self, state): + return + + def set_mode(self, module: nn.Module): + """ + Sets the mode of the module to eval. + """ + module.eval() diff --git a/src/ace_inference/core/packer.py b/src/ace_inference/core/packer.py new file mode 100644 index 0000000..a3e2c89 --- /dev/null +++ b/src/ace_inference/core/packer.py @@ -0,0 +1,70 @@ +from typing import Dict, List + +import torch +import torch.jit +from tensordict import TensorDict + + +class DataShapesNotUniform(ValueError): + """Indicates that a set of tensors do not all have the same shape.""" + + pass + + +class NoPacker: + def pack(self, tensors: Dict[str, torch.Tensor], axis=0) -> torch.Tensor: + return tensors + + def unpack(self, tensor: torch.Tensor, axis=0) -> Dict[str, torch.Tensor]: + return tensor + + +class Packer: + """ + Responsible for packing tensors into a single tensor. + """ + + def __init__(self, names: List[str], axis=None): + self.names = names + self.axis = axis + + def pack(self, tensors: Dict[str, torch.Tensor], axis=None) -> torch.Tensor: + """ + Packs tensors into a single tensor, concatenated along a new axis + + Args: + tensors: Dict from names to tensors. + axis: index for new concatenation axis. + """ + axis = axis if axis is not None else self.axis + return _pack(tensors, self.names, axis=axis) + + def unpack(self, tensor: torch.Tensor, axis=None) -> TensorDict: + axis = axis if axis is not None else self.axis + # packed shape is tensor.shape with axis removed + packed_shape = list(tensor.shape) + packed_shape.pop(axis) + return TensorDict(_unpack(tensor, self.names, axis=axis), batch_size=packed_shape) + + def get_state(self): + """ + Returns state as a serializable data structure. + """ + return {"names": self.names} + + @classmethod + def from_state(self, state) -> "Packer": + """ + Loads state from a serializable data structure. + """ + return Packer(state["names"]) + + +@torch.jit.script +def _pack(tensors: Dict[str, torch.Tensor], names: List[str], axis: int) -> torch.Tensor: + return torch.cat([tensors[n].unsqueeze(axis) for n in names], dim=axis) + + +@torch.jit.script +def _unpack(tensor: torch.Tensor, names: List[str], axis: int) -> Dict[str, torch.Tensor]: + return {n: tensor.select(axis, index=i) for i, n in enumerate(names)} diff --git a/src/ace_inference/core/parameter_init.py b/src/ace_inference/core/parameter_init.py new file mode 100644 index 0000000..5978328 --- /dev/null +++ b/src/ace_inference/core/parameter_init.py @@ -0,0 +1,115 @@ +import dataclasses +from typing import Any, List, Mapping, Optional + +import torch +from torch import nn + +from src.ace_inference.core.device import get_device +from src.ace_inference.core.wildcard import apply_by_wildcard, wildcard_match + +from .weight_ops import overwrite_weights, strip_leading_module + + +@dataclasses.dataclass +class FrozenParameterConfig: + """ + Configuration for freezing parameters in a model. + + Parameter names can include wildcards, e.g. "encoder.*" will select + all parameters in the encoder, while "encoder.*.bias" will select all + bias parameters in the encoder. All parameters must be specified + in either the include or exclude list, or + an exception will be raised. + + An exception is raised if a parameter is included by both lists. + + Attributes: + include: list of parameter names to freeze (set requires_grad = False) + exclude: list of parameter names to ignore + """ + + include: List[str] = dataclasses.field(default_factory=list) + exclude: List[str] = dataclasses.field(default_factory=list) + + def __post_init__(self): + for pattern in self.include: + if any(wildcard_match(pattern, exclude) for exclude in self.exclude): + raise ValueError( + f"Parameter {pattern} is included in both include " f"{self.include} and exclude {self.exclude}" + ) + for pattern in self.exclude: + if any(wildcard_match(pattern, include) for include in self.include): + raise ValueError( + f"Parameter {pattern} is included in both include " f"{self.include} and exclude {self.exclude}" + ) + + def apply(self, model: nn.Module): + apply_by_wildcard(model, _freeze_weight, self.include, self.exclude) + + +def _freeze_weight(module: nn.Module, name: str): + try: + module.get_parameter(name).requires_grad = False + except AttributeError: # non-parameter state + pass + + +@dataclasses.dataclass +class ParameterInitializationConfig: + """ + A class which applies custom initialization to module parameters. + + Assumes the module weights have already been randomly initialized. + + Supports overwriting the weights of the built model with weights from a + pre-trained model. If the built model has larger weights than the + pre-trained model, only the initial slice of the weights is overwritten. + + Attributes: + weight_path: path to a SingleModuleStepper checkpoint + containing weights to load + exclude_parameters: list of parameter names to exclude from the loaded + weights. Used for example to keep the random initialization for + final layer(s) of a model, and only overwrite the weights for + earlier layers. Takes values like "decoder.2.weight". + frozen_parameters: configuration for freezing parameters in the built model + """ + + weights_path: Optional[str] = None + exclude_parameters: List[str] = dataclasses.field(default_factory=list) + frozen_parameters: FrozenParameterConfig = dataclasses.field( + default_factory=lambda: FrozenParameterConfig(exclude=["*"]) + ) + + def apply(self, module: nn.Module, init_weights: bool) -> nn.Module: + """ + Apply the weight initialization to a module. + + Args: + module: a nn.Module to initialize + init_weights: whether to initialize the weight values + + Returns: + a nn.Module with initialization applied + """ + if init_weights and self.weights_path is not None: + loaded_state_dict = self.get_base_weights() + if loaded_state_dict is not None: + overwrite_weights( + loaded_state_dict, + module, + exclude_parameters=self.exclude_parameters, + ) + self.frozen_parameters.apply(module) + return module + + def get_base_weights(self) -> Optional[Mapping[str, Any]]: + """ + If a weights_path is provided, return the model base weights used for + initialization. + """ + if self.weights_path is not None: + checkpoint = torch.load(self.weights_path, map_location=get_device()) + return strip_leading_module(checkpoint["stepper"]["module"]) + else: + return None diff --git a/src/ace_inference/core/prescriber.py b/src/ace_inference/core/prescriber.py new file mode 100644 index 0000000..692ec6e --- /dev/null +++ b/src/ace_inference/core/prescriber.py @@ -0,0 +1,134 @@ +import dataclasses +from typing import Dict, List + +import torch + + +@dataclasses.dataclass +class PrescriberConfig: + """ + Configuration for overwriting predictions of 'prescribed_name' by target values. + + If interpolate is False, the data is overwritten in the region where + 'mask_name' == 'mask_value' after values are cast to integer. If interpolate + is True, the data is interpolated between the predicted value at 0 and the + target value at 1 based on the mask variable, and it is assumed the mask variable + lies in the range from 0 to 1. + + Attributes: + prescribed_name: Name of the variable to be overwritten. + mask_name: Name of the mask variable. + mask_value: Value of the mask variable in the region to be overwritten. + interpolate: Whether to interpolate linearly between the generated and target + values in the masked region, where 0 means keep the generated values and + 1 means replace completely with the target values. Requires mask_value + be set to 1. + """ + + prescribed_name: str + mask_name: str + mask_value: int + interpolate: bool = False + + def __post_init__(self): + if self.interpolate and self.mask_value != 1: + raise ValueError("Interpolation requires mask_value to be 1, but it is set to " f"{self.mask_value}.") + + def build(self, in_names: List[str], out_names: List[str]): + if not (self.prescribed_name in in_names and self.prescribed_name in out_names): + raise ValueError( + "Variables which are being prescribed in masked regions must be in" + f" in_names and out_names, but {self.prescribed_name} is not." + ) + return Prescriber( + prescribed_name=self.prescribed_name, + mask_name=self.mask_name, + mask_value=self.mask_value, + interpolate=self.interpolate, + ) + + +class Prescriber: + """ + Responsible for overwriting model predictions by target data in masked regions. + """ + + def __init__( + self, + prescribed_name: str, + mask_name: str, + mask_value: int, + interpolate: bool = False, + ): + self.prescribed_name = prescribed_name + self.mask_name = mask_name + self.mask_value = mask_value + self.interpolate = interpolate + + def __call__( + self, + data: Dict[str, torch.Tensor], + gen_norm: Dict[str, torch.Tensor], + target_norm: Dict[str, torch.Tensor], + ): + """ + Args: + data: Dictionary of data containing the mask variable. + gen_norm: Dictionary of generated data. + target_norm: Dictionary of target data. + """ + if self.interpolate: + mask = data[self.mask_name] + # 0 keeps the generated values, 1 replaces completely with the target values + prescribed_gen = mask * target_norm[self.prescribed_name] + (1 - mask) * gen_norm[self.prescribed_name] + else: + # overwrite specified generated variable in given mask region + prescribed_gen = torch.where( + torch.round(data[self.mask_name]).to(int) == self.mask_value, + target_norm[self.prescribed_name], + gen_norm[self.prescribed_name], + ) + gen_norm[self.prescribed_name] = prescribed_gen + return gen_norm + + def get_state(self): + return { + "prescribed_name": self.prescribed_name, + "mask_name": self.mask_name, + "mask_value": self.mask_value, + "interpolate": self.interpolate, + } + + def load_state(self, state): + self.prescribed_name = state["prescribed_name"] + self.mask_name = state["mask_name"] + self.mask_value = state["mask_value"] + interpolate = state.get("interpolate", False) + self.interpolate = interpolate + + @classmethod + def from_state(cls, state) -> "Prescriber": + return Prescriber( + state["prescribed_name"], + state["mask_name"], + state["mask_value"], + state.get("interpolate", False), + ) + + +class NullPrescriber: + """Dummy prescriber that does nothing.""" + + def __call__( + self, + data: Dict[str, torch.Tensor], + gen_norm: Dict[str, torch.Tensor], + target_norm: Dict[str, torch.Tensor], + ): + return gen_norm + + def get_state(self): + return {} + + def load_state(self, state): + return diff --git a/src/ace_inference/core/registry.py b/src/ace_inference/core/registry.py new file mode 100644 index 0000000..369d13f --- /dev/null +++ b/src/ace_inference/core/registry.py @@ -0,0 +1,194 @@ +import dataclasses +from typing import Any, Literal, Mapping, Optional, Protocol, Tuple, Type + +import torch_harmonics as harmonics +from torch import nn + + +class ModuleConfig(Protocol): + """ + A protocol for a class that can build a nn.Module given information about the input + and output channels and the image shape. + + This is a "Config" as in practice it is a dataclass loaded directly from yaml, + allowing us to specify details of the network architecture in a config file. + """ + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ) -> nn.Module: + """ + Build a nn.Module given information about the input and output channels + and the image shape. + + Args: + n_in_channels: number of input channels + n_out_channels: number of output channels + img_shape: last two dimensions of data, corresponding to lat and + lon when using FourCastNet conventions + + Returns: + a nn.Module + """ + ... + + +# this is based on the call signature of SphericalFourierNeuralOperatorNet at +# https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501 +@dataclasses.dataclass +class SphericalFourierNeuralOperatorBuilder(ModuleConfig): + spectral_transform: str = "sht" + filter_type: str = "non-linear" + operator_type: str = "diagonal" + scale_factor: int = 16 + embed_dim: int = 256 + num_layers: int = 12 + num_blocks: int = 16 + hard_thresholding_fraction: float = 1.0 + normalization_layer: str = "instance_norm" + use_mlp: bool = True + activation_function: str = "gelu" + encoder_layers: int = 1 + pos_embed: bool = True + big_skip: bool = True + rank: float = 1.0 + factorization: Optional[str] = None + separable: bool = False + complex_network: bool = True + complex_activation: str = "real" + spectral_layers: int = 1 + checkpointing: int = 0 + data_grid: Literal["legendre-gauss", "equiangular"] = "legendre-gauss" + drop_rate: float = 0.0 + drop_path_rate: float = 0.0 + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ): + from modulus.models.sfno.sfnonet import SphericalFourierNeuralOperatorNet + + sfno_net = SphericalFourierNeuralOperatorNet( + params=self, + in_chans=n_in_channels, + out_chans=n_out_channels, + img_shape=img_shape, + drop_rate=self.drop_rate, + drop_path_rate=self.drop_path_rate, + ) + + # Patch in the grid that our data lies on rather than the one which is + # hard-coded in the modulus codebase [1]. Duplicate the code to compute + # the number of SHT modes determined by hard_thresholding_fraction. Note + # that this does not handle the distributed case which is handled by + # L518 [2] in their codebase. + + # [1] https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py # noqa: E501 + # [2] https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L518 # noqa: E501 + nlat, nlon = img_shape + modes_lat = int(nlat * self.hard_thresholding_fraction) + modes_lon = int((nlon // 2 + 1) * self.hard_thresholding_fraction) + sht = harmonics.RealSHT(nlat, nlon, lmax=modes_lat, mmax=modes_lon, grid=self.data_grid).float() + isht = harmonics.InverseRealSHT(nlat, nlon, lmax=modes_lat, mmax=modes_lon, grid=self.data_grid).float() + + sfno_net.trans_down = sht + sfno_net.itrans_up = isht + + return sfno_net + + +@dataclasses.dataclass +class PreBuiltBuilder(ModuleConfig): + """ + A simple module configuration which returns a pre-defined module. + + Used mainly for testing. + """ + + module: nn.Module + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ) -> nn.Module: + return self.module + + +NET_REGISTRY: Mapping[str, Type[ModuleConfig]] = { + # "afno": AFNONetBuilder, # using short acronym for backwards compatibility + "SphericalFourierNeuralOperatorNet": SphericalFourierNeuralOperatorBuilder, # type: ignore # noqa: E501 + "prebuilt": PreBuiltBuilder, +} + + +@dataclasses.dataclass +class ModuleSelector: + """ + A dataclass containing all the information needed to build a ModuleConfig, + including the type of the ModuleConfig and the data needed to build it. + + This is helpful as ModuleSelector can be serialized and deserialized + without any additional information, whereas to load a ModuleConfig you + would need to know the type of the ModuleConfig being loaded. + + It is also convenient because ModuleSelector is a single class that can be + used to represent any ModuleConfig, whereas ModuleConfig is a protocol + that can be implemented by many different classes. + + Attributes: + type: the type of the ModuleConfig + config: data for a ModuleConfig instance of the indicated type + """ + + type: Literal[ + "afno", + "SphericalFourierNeuralOperatorNet", + "prebuilt", + ] + config: Mapping[str, Any] + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ) -> nn.Module: + """ + Build a nn.Module given information about the input and output channels + and the image shape. + + Args: + n_in_channels: number of input channels + n_out_channels: number of output channels + img_shape: last two dimensions of data, corresponding to lat and + lon when using FourCastNet conventions + + Returns: + a nn.Module + """ + return NET_REGISTRY[self.type](**self.config).build( + n_in_channels=n_in_channels, + n_out_channels=n_out_channels, + img_shape=img_shape, + ) + + def get_state(self) -> Mapping[str, Any]: + """ + Get a dictionary containing all the information needed to build a ModuleConfig. + """ + return {"type": self.type, "config": self.config} + + @classmethod + def from_state(cls, state: Mapping[str, Any]) -> "ModuleSelector": + """ + Create a ModuleSelector from a dictionary containing all the information + needed to build a ModuleConfig. + """ + return cls(**state) diff --git a/src/ace_inference/core/scheduler.py b/src/ace_inference/core/scheduler.py new file mode 100644 index 0000000..81b68ef --- /dev/null +++ b/src/ace_inference/core/scheduler.py @@ -0,0 +1,29 @@ +import dataclasses +from typing import Any, Mapping, Optional + +import torch.optim.lr_scheduler + + +@dataclasses.dataclass +class SchedulerConfig: + """ + Configuration for a scheduler to use during training. + + Attributes: + type: Name of scheduler class from torch.optim.lr_scheduler, + no scheduler is used by default. + kwargs: Keyword arguments to pass to the scheduler constructor. + """ + + type: Optional[str] = None + kwargs: Mapping[str, Any] = dataclasses.field(default_factory=dict) + + def build(self, optimizer) -> Optional[torch.optim.lr_scheduler._LRScheduler]: + """ + Build the scheduler. + """ + if self.type is None: + return None + else: + scheduler_class = getattr(torch.optim.lr_scheduler, self.type) + return scheduler_class(optimizer=optimizer, **self.kwargs) diff --git a/src/ace_inference/core/stepper.py b/src/ace_inference/core/stepper.py new file mode 100644 index 0000000..28cf3e5 --- /dev/null +++ b/src/ace_inference/core/stepper.py @@ -0,0 +1,591 @@ +import dataclasses +import warnings +from typing import ( + Any, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Protocol, + Tuple, + Union, +) + +import dacite +import torch +from torch import nn +from torch.nn.parallel import DistributedDataParallel +from tqdm import tqdm + +from src.ace_inference.core.aggregator.null import NullAggregator +from src.ace_inference.core.corrector import Corrector, CorrectorConfig +from src.ace_inference.core.data_loading.data_typing import SigmaCoordinates +from src.ace_inference.core.data_loading.requirements import DataRequirements +from src.ace_inference.core.device import get_device, using_gpu +from src.ace_inference.core.distributed import Distributed +from src.ace_inference.core.loss import ConservationLoss, ConservationLossConfig, LossConfig +from src.ace_inference.core.normalizer import ( + FromStateNormalizer, + NormalizationConfig, + StandardNormalizer, +) +from src.ace_inference.core.ocean import Ocean, OceanConfig +from src.ace_inference.core.optimization import DisabledOptimizationConfig, NullOptimization, Optimization +from src.ace_inference.core.parameter_init import ParameterInitializationConfig +from src.ace_inference.core.prescriber import PrescriberConfig +from src.ace_inference.core.registry import ModuleSelector +from src.evaluation.aggregators.main import OneStepAggregator +from src.utilities.packer import Packer +from src.utilities.utils import enable_inference_dropout as enable_inference_dropout_func + + +@dataclasses.dataclass +class SingleModuleStepperConfig: + builder: ModuleSelector + in_names: List[str] + out_names: List[str] + normalization: Union[NormalizationConfig, FromStateNormalizer] + parameter_init: ParameterInitializationConfig = dataclasses.field( + default_factory=lambda: ParameterInitializationConfig() + ) + optimization: Optional[DisabledOptimizationConfig] = None + ocean: Optional[OceanConfig] = None + loss: LossConfig = dataclasses.field(default_factory=lambda: LossConfig()) + conserve_dry_air: Optional[bool] = None + corrector: CorrectorConfig = dataclasses.field(default_factory=lambda: CorrectorConfig()) + conservation_loss: ConservationLossConfig = dataclasses.field(default_factory=lambda: ConservationLossConfig()) + prescriber: Optional[PrescriberConfig] = None + enable_inference_dropout: bool = False + + def __post_init__(self): + if self.conserve_dry_air is not None: + warnings.warn( + "conserve_dry_air is deprecated, " "use corrector.conserve_dry_air instead", + category=DeprecationWarning, + ) + self.corrector.conserve_dry_air = self.conserve_dry_air + if self.prescriber is not None: + warnings.warn( + "Directly configuring prescriber is deprecated, " "use 'ocean' option instead.", + category=DeprecationWarning, + ) + if self.ocean is not None: + raise ValueError("Cannot specify both prescriber and ocean.") + self.ocean = OceanConfig( + surface_temperature_name=self.prescriber.prescribed_name, + ocean_fraction_name=self.prescriber.mask_name, + interpolate=self.prescriber.interpolate, + ) + del self.prescriber + + def get_data_requirements(self, n_forward_steps: int) -> DataRequirements: + return DataRequirements( + names=self.all_names, + n_timesteps=n_forward_steps + 1, + ) + + def get_state(self): + return dataclasses.asdict(self) + + def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]: + """ + If the model is being initialized from another model's weights for fine-tuning, + returns those weights. Otherwise, returns None. + + The list mirrors the order of `modules` in the `SingleModuleStepper` class. + """ + base_weights = self.parameter_init.get_base_weights() + if base_weights is not None: + return [base_weights] + else: + return None + + def get_stepper( + self, + img_shape: Tuple[int, int], + area: Optional[torch.Tensor], + sigma_coordinates: SigmaCoordinates, + ): + return SingleModuleStepper( + config=self, + img_shape=img_shape, + area=area, + sigma_coordinates=sigma_coordinates, + ) + + @classmethod + def from_state(cls, state) -> "SingleModuleStepperConfig": + return dacite.from_dict(data_class=cls, data=state, config=dacite.Config(strict=True)) + + @property + def all_names(self): + """Names of all variables required, including auxiliary ones.""" + extra_names = [] + if self.ocean is not None: + extra_names.extend(self.ocean.names) + all_names = list(set(self.in_names).union(self.out_names).union(extra_names)) + return all_names + + @property + def normalize_names(self): + """Names of variables which require normalization. I.e. inputs/outputs.""" + return list(set(self.in_names).union(self.out_names)) + + +@dataclasses.dataclass +class ExistingStepperConfig: + checkpoint_path: str + + def _load_checkpoint(self) -> Mapping[str, Any]: + return torch.load(self.checkpoint_path, map_location=get_device()) + + def get_data_requirements(self, n_forward_steps: int) -> DataRequirements: + return SingleModuleStepperConfig.from_state( + self._load_checkpoint()["stepper"]["config"] + ).get_data_requirements(n_forward_steps) + + def get_base_weights(self) -> Optional[List[Mapping[str, Any]]]: + return SingleModuleStepperConfig.from_state(self._load_checkpoint()["stepper"]["config"]).get_base_weights() + + def get_stepper(self, img_shape, area, sigma_coordinates): + del img_shape # unused + return SingleModuleStepper.from_state( + self._load_checkpoint()["stepper"], + area=area, + sigma_coordinates=sigma_coordinates, + ) + + +class DummyWrapper(nn.Module): + """ + Wrapper class for a single pytorch module, which does nothing. + + Exists so we have an identical module structure to the case where we use + a DistributedDataParallel wrapper. + """ + + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, *args, **kwargs): + return self.module(*args, **kwargs) + + +@dataclasses.dataclass +class SteppedData: + metrics: Dict[str, torch.Tensor] + gen_data: Dict[str, torch.Tensor] + target_data: Dict[str, torch.Tensor] + gen_data_norm: Dict[str, torch.Tensor] + target_data_norm: Dict[str, torch.Tensor] + + def remove_initial_condition(self) -> "SteppedData": + any_key = next(iter(self.gen_data.keys())) + is_ensemble = self.gen_data[any_key].shape != self.target_data[any_key].shape + + def remove_initial_condition_from_gen(tensor: torch.Tensor) -> torch.Tensor: + return tensor[:, :, 1:] if is_ensemble else tensor[:, 1:] + + return SteppedData( + metrics=self.metrics, + gen_data={k: remove_initial_condition_from_gen(v) for k, v in self.gen_data.items()}, + target_data={k: v[:, 1:] for k, v in self.target_data.items()}, + gen_data_norm={k: remove_initial_condition_from_gen(v) for k, v in self.gen_data_norm.items()}, + target_data_norm={k: v[:, 1:] for k, v in self.target_data_norm.items()}, + ) + + def copy(self) -> "SteppedData": + """Creates new dictionaries for the data but with the same tensors.""" + return SteppedData( + metrics=self.metrics, + gen_data={k: v for k, v in self.gen_data.items()}, + target_data={k: v for k, v in self.target_data.items()}, + gen_data_norm={k: v for k, v in self.gen_data_norm.items()}, + target_data_norm={k: v for k, v in self.target_data_norm.items()}, + ) + + # Method to stack a list of stepped data objects together + @staticmethod + def stack(stepped_data_list: List["SteppedData"], dim: int) -> "SteppedData": + return SteppedData( + metrics=None, + gen_data={ + k: torch.stack([sd.gen_data[k] for sd in stepped_data_list], dim=dim) + for k in stepped_data_list[0].gen_data.keys() + }, + target_data={ + k: torch.stack([sd.target_data[k] for sd in stepped_data_list], dim=dim) + for k in stepped_data_list[0].target_data.keys() + }, + gen_data_norm={ + k: torch.stack([sd.gen_data_norm[k] for sd in stepped_data_list], dim=dim) + for k in stepped_data_list[0].gen_data_norm.keys() + }, + target_data_norm={ + k: torch.stack([sd.target_data_norm[k] for sd in stepped_data_list], dim=dim) + for k in stepped_data_list[0].target_data_norm.keys() + }, + ) + + +class SingleModuleStepper: + """ + Stepper class for a single pytorch module. + """ + + def __init__( + self, + config: SingleModuleStepperConfig, + img_shape: Tuple[int, int], + area: torch.Tensor, + sigma_coordinates: SigmaCoordinates, + init_weights: bool = True, + ): + """ + Args: + config: The configuration. + img_shape: Shape of domain as (n_lat, n_lon). + area: (n_lat, n_lon) array containing relative gridcell area, + in any units including unitless. + sigma_coordinates: The sigma coordinates. + init_weights: Whether to initialize the weights. Should pass False if + the weights are about to be overwritten by a checkpoint. + """ + dist = Distributed.get_instance() + n_in_channels = len(config.in_names) + n_out_channels = len(config.out_names) + channel_axis = -3 + self.in_packer = Packer(config.in_names, axis=channel_axis) + self.out_packer = Packer(config.out_names, axis=channel_axis) + self.normalizer = config.normalization.build(config.normalize_names) + if config.ocean is not None: + self.ocean = config.ocean.build(config.in_names, config.out_names) + else: + self.ocean = None + self.module = config.builder.build( + n_in_channels=n_in_channels, + n_out_channels=n_out_channels, + img_shape=img_shape, + ) + self.module = config.parameter_init.apply(self.module, init_weights=init_weights).to(get_device()) + + self._img_shape = img_shape + self._config = config + self._no_optimization = NullOptimization() + + if dist.is_distributed(): + if using_gpu(): + device_ids = [dist.local_rank] + output_device = [dist.local_rank] + else: + device_ids = None + output_device = None + self.module = DistributedDataParallel( + self.module, + device_ids=device_ids, + output_device=output_device, + ) + else: + self.module = DummyWrapper(self.module) + self._is_distributed = dist.is_distributed() + + self.area = area + self.sigma_coordinates = sigma_coordinates.to(get_device()) + self.loss_obj = config.loss.build(self.area) + self._conservation_loss = config.conservation_loss.build( + area_weights=self.area, + sigma_coordinates=self.sigma_coordinates, + ) + self._corrector = config.corrector.build(area=area, sigma_coordinates=sigma_coordinates) + + def get_data_requirements(self, n_forward_steps: int) -> DataRequirements: + return self._config.get_data_requirements(n_forward_steps) + + @property + def modules(self) -> nn.ModuleList: + """ + Returns: + A list of modules being trained. + """ + return nn.ModuleList([self.module]) + + def run_on_batch( + self, + data: Dict[str, torch.Tensor], + optimization: Union[Optimization, NullOptimization], + n_forward_steps: int = 1, + aggregator: Optional[OneStepAggregator] = None, + ) -> SteppedData: + """ + Step the model forward on a batch of data. + + Args: + data: The batch data of shape [n_sample, n_timesteps, n_channels, n_x, n_y]. + optimization: The optimization class to use for updating the module. + Use `NullOptimization` to disable training. + n_forward_steps: The number of timesteps to run the model for. + aggregator: The data aggregator. + + Returns: + The loss, the generated data, the normalized generated data, + and the normalized batch data. + """ + if aggregator is None: + non_none_aggregator: Union[OneStepAggregator, NullAggregator] = NullAggregator() + else: + non_none_aggregator = aggregator + + device = get_device() + device_data = {name: value.to(device, dtype=torch.float) for name, value in data.items()} + return run_on_batch( + data=device_data, + module=self.module, + normalizer=self.normalizer, + in_packer=self.in_packer, + out_packer=self.out_packer, + optimization=optimization, + loss_obj=self.loss_obj, + n_forward_steps=n_forward_steps, + ocean=self.ocean, + aggregator=non_none_aggregator, + corrector=self._corrector, + conservation_loss=self._conservation_loss, + enable_inference_dropout=self.enable_inference_dropout, + ) + + def get_state(self): + """ + Returns: + The state of the stepper. + """ + return { + "module": self.module.state_dict(), + "normalizer": self.normalizer.get_state(), + "img_shape": self._img_shape, + "config": self._config.get_state(), + "area": self.area, + "sigma_coordinates": self.sigma_coordinates.as_dict(), + } + + def load_state(self, state): + """ + Load the state of the stepper. + + Args: + state: The state to load. + """ + if "module" in state: + self.module.load_state_dict(state["module"]) + + @classmethod + def from_state(cls, state, area: torch.Tensor, sigma_coordinates: SigmaCoordinates) -> "SingleModuleStepper": + """ + Load the state of the stepper. + + Args: + state: The state to load. + area: (n_lat, n_lon) array containing relative gridcell area, in any + units including unitless. + sigma_coordinates: The sigma coordinates. + + Returns: + The stepper. + """ + config = {**state["config"]} # make a copy to avoid mutating input + config["normalization"] = FromStateNormalizer(state["normalizer"]) + area = state.get("area", area) + if "sigma_coordinates" in state: + sigma_coordinates = dacite.from_dict( + data_class=SigmaCoordinates, + data=state["sigma_coordinates"], + config=dacite.Config(strict=True), + ) + if "img_shape" in state: + img_shape = state["img_shape"] + else: + # this is for backwards compatibility with old checkpoints + for v in state["data_shapes"].values(): + img_shape = v[-2:] + break + stepper = cls( + config=SingleModuleStepperConfig.from_state(config), + img_shape=img_shape, + area=area, + sigma_coordinates=sigma_coordinates, + # don't need to initialize weights, we're about to load_state + init_weights=False, + ) + stepper.load_state(state) + return stepper + + +class NameAndTimeQueryFunction(Protocol): + def __call__( + self, + names: Iterable[str], + time_index: int, + norm_mode: Literal["norm", "denorm"], + ) -> Dict[str, torch.Tensor]: ... + + +def get_name_and_time_query_fn( + data: Dict[str, torch.Tensor], data_norm: Dict[str, torch.Tensor], time_dim: int +) -> NameAndTimeQueryFunction: + """Construct a function for querying `data` by name and time and whether it + is normalized or not. (Note: that the `names` argument can contain None values + to handle NullPrescriber).""" + + norm_mode_to_data = {"norm": data_norm, "denorm": data} + + def name_and_time_query_fn(names, time_index, norm_mode): + _data = norm_mode_to_data[norm_mode] + query_results = {} + for name in names: + try: + query_results[name] = _data[name].select(dim=time_dim, index=time_index) + except IndexError as err: + raise ValueError(f'tensor "{name}" does not have values at t={time_index}') from err + return query_results + + return name_and_time_query_fn + + +def _pack_data_if_available( + packer: Packer, + data: Dict[str, torch.Tensor], + axis: int, +) -> Optional[torch.Tensor]: + try: + return packer.pack(data, axis=axis) + except ValueError: + return None + + +def run_on_batch( + data: Dict[str, torch.Tensor], + module: nn.Module, + normalizer: StandardNormalizer, + in_packer: Packer, + out_packer: Packer, + optimization: Union[Optimization, NullOptimization], + loss_obj: nn.Module, + ocean: Optional[Ocean], + aggregator: Union[OneStepAggregator, NullAggregator], + corrector: Optional[Corrector], # Optional so we can skip code when unused + conservation_loss: ConservationLoss, + n_forward_steps: int = 1, + enable_inference_dropout: bool = False, +) -> SteppedData: + """ + Run the model on a batch of data. + + The module is assumed to require packed (concatenated into a tensor with + a channel dimension) and normalized data, as provided by the given packer + and normalizer. + + Args: + data: The denormalized batch data. The second dimension of each tensor + should be the time dimension. + module: The module to run. + normalizer: The normalizer. + in_packer: The packer for the input data. + out_packer: The packer for the output data. + optimization: The optimization object. If it is NullOptimization, + then the model is not trained. + loss_obj: The loss object. + ocean: Determines sea surface temperatures. + aggregator: The data aggregator. + corrector: The post-step corrector. + conservation_loss: Computes conservation-related losses, if any. + n_forward_steps: The number of timesteps to run the model for. + + Returns: + The loss, the generated data, the normalized generated data, + and the normalized batch data. The generated data contains + the initial input data as its first timestep. + """ + channel_dim = -3 + time_dim = 1 + full_data_norm = normalizer.normalize(data) + get_input_data = get_name_and_time_query_fn(data, full_data_norm, time_dim) + + full_target_tensor_norm = _pack_data_if_available( + out_packer, + full_data_norm, + channel_dim, + ) + + loss = torch.tensor(0.0, device=get_device()) + metrics = {} + input_data_norm = get_input_data(in_packer.names, time_index=0, norm_mode="norm") + gen_data_norm = [] + optimization.set_mode(module) + if enable_inference_dropout: + enable_inference_dropout_func(module) + tqdm_bar = tqdm(range(n_forward_steps), desc="Horizon") + for step in tqdm_bar: + input_tensor_norm = in_packer.pack(input_data_norm, axis=channel_dim) + + if full_target_tensor_norm is None: + target_tensor_norm: Optional[torch.Tensor] = None + else: + target_tensor_norm = full_target_tensor_norm.select(dim=time_dim, index=step + 1) + + with optimization.autocast(): + gen_tensor_norm = module(input_tensor_norm).to(get_device(), dtype=torch.float) + gen_norm = out_packer.unpack(gen_tensor_norm, axis=channel_dim) + gen_data = normalizer.denormalize(gen_norm) + input_data = normalizer.denormalize(input_data_norm) + if corrector is not None: + gen_data = corrector(input_data, gen_data) + if ocean is not None: + target_data = get_input_data(ocean.target_names, step + 1, "denorm") + gen_data = ocean(target_data, input_data, gen_data) + gen_norm = normalizer.normalize(gen_data) + gen_tensor_norm = out_packer.pack(gen_norm, axis=channel_dim).to(get_device(), dtype=torch.float) + if target_tensor_norm is None: + step_loss = torch.tensor(torch.nan) + else: + step_loss = loss_obj(gen_tensor_norm, target_tensor_norm) + loss += step_loss + metrics[f"loss_step_{step}"] = step_loss.detach() + gen_norm = out_packer.unpack(gen_tensor_norm, axis=channel_dim) + gen_data_norm.append(gen_norm) + # update input data with generated outputs, and forcings for missing outputs + forcing_names = list(set(in_packer.names).difference(gen_norm.keys())) + forcing_data_norm = get_input_data(forcing_names, time_index=step + 1, norm_mode="norm") + input_data_norm = {**forcing_data_norm, **gen_norm} + + # prepend the initial (pre-first-timestep) output data to the generated data + initial = get_input_data(out_packer.names, time_index=0, norm_mode="norm") + gen_data_norm = [initial] + gen_data_norm + gen_data_norm_timeseries = {} + for name in out_packer.names: + gen_data_norm_timeseries[name] = torch.stack([x[name] for x in gen_data_norm], dim=time_dim) + gen_data = normalizer.denormalize(gen_data_norm_timeseries) + + conservation_metrics, conservation_loss = conservation_loss(gen_data) + metrics.update(conservation_metrics) + loss += conservation_loss + + metrics["loss"] = loss.detach() + optimization.step_weights(loss) + + aggregator.record_batch( + float(loss), + target_data=data, + gen_data=gen_data, + target_data_norm=full_data_norm, + gen_data_norm=gen_data_norm_timeseries, + ) + + return SteppedData( + metrics=metrics, + gen_data=gen_data, + target_data=data, + gen_data_norm=gen_data_norm_timeseries, + target_data_norm=full_data_norm, + ) diff --git a/src/ace_inference/core/stepper_multistep.py b/src/ace_inference/core/stepper_multistep.py new file mode 100644 index 0000000..6f90440 --- /dev/null +++ b/src/ace_inference/core/stepper_multistep.py @@ -0,0 +1,463 @@ +import dataclasses +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple, Union + +import dacite +import torch +from torch import nn +from tqdm.auto import tqdm + +from src.ace_inference.core.aggregator.null import NullAggregator +from src.ace_inference.core.device import get_device +from src.ace_inference.core.distributed import Distributed +from src.ace_inference.core.normalizer import ( + NormalizationConfig, + StandardNormalizer, + # FromStateNormalizer, +) +from src.ace_inference.core.prescriber import NullPrescriber, Prescriber, PrescriberConfig +from src.ace_inference.core.stepper import SingleModuleStepper +from src.ace_inference.training.utils.darcy_loss import LpLoss +from src.ace_inference.training.utils.data_requirements import DataRequirements +from src.evaluation.aggregators.main import OneStepAggregator +from src.experiment_types.forecasting_multi_horizon import ( + AbstractMultiHorizonForecastingExperiment, + infer_class_from_ckpt, +) +from src.utilities.packer import Packer +from src.utilities.utils import to_tensordict, update_dict_with_other + +from .optimization import NullOptimization, Optimization +from .stepper import SteppedData, get_name_and_time_query_fn + + +@dataclasses.dataclass +class MultiStepStepperConfig: + in_names: List[str] + out_names: List[str] + prescriber: Optional[PrescriberConfig] = None + data_dir: Optional[str] = None + data_dir_stats: Optional[str] = None + + def get_data_requirements(self, n_forward_steps: int) -> DataRequirements: + return DataRequirements( + names=self.all_names, + in_names=self.in_names, + out_names=self.out_names, + n_timesteps=n_forward_steps + 1, + ) + + def get_stepper( + self, + shapes: Dict[str, Tuple[int, ...]], + max_epochs: int, + ): + return MultiStepStepper( + config=self, + ) + + def get_state(self): + return dataclasses.asdict(self) + + @classmethod + def from_state(cls, state) -> "MultiStepStepperConfig": + return dacite.from_dict(data_class=cls, data=state, config=dacite.Config(strict=True)) + + @property + def all_names(self): + if self.prescriber is not None: + mask_name = [self.prescriber.mask_name] + else: + mask_name = [] + all_names = list(set(self.in_names).union(self.out_names).union(mask_name)) + return all_names + + @property + def normalize_names(self): + return list(set(self.in_names).union(self.out_names)) + + +class MultiStepStepper(SingleModuleStepper): + """ + Stepper class for a single pytorch module. + """ + + channel_axis = -3 + + def __init__( + self, + config: MultiStepStepperConfig, + module: AbstractMultiHorizonForecastingExperiment, + data_shapes: Dict[str, Tuple[int, ...]], + max_epochs: int, + ): + """ + Args: + config: The configuration. + data_shapes: The shapes of the data. + max_epochs: The maximum number of epochs. Used when constructing + certain learning rate schedulers, if applicable. + """ + dist = Distributed.get_instance() + # n_in_channels = len(config.in_names) + # n_out_channels = len(config.out_names) + if "forcing_names" not in config.__dict__: + config.forcing_names = list(set(config.in_names).difference(config.out_names)) + self.init_packers(config.in_names, config.out_names, config.forcing_names) + # self.in_packer = Packer(config.in_names, axis=self.channel_axis) + # self.out_packer = Packer(config.out_names, axis=self.channel_axis) + # in_packer.names = [x for x in in_names if x not in forcing_names] + # forcings_packer = Packer(forcing_names, axis_pack=in_packer.axis_pack, axis_unpack=in_packer.axis_unpack) + # self.forcings_packer = Packer(config.forcing_names, axis=self.channel_axis) + data_dir_stats = config.data_dir_stats or config.data_dir + path_mean = Path(data_dir_stats) / "centering.nc" + path_std = Path(data_dir_stats) / "scaling.nc" + alternative_data_dirs = ["/data/climate-model/fv3gfs"] + if not path_mean.exists(): + for alt_dir in alternative_data_dirs: + path_mean = Path(alt_dir) / "centering.nc" + path_std = Path(alt_dir) / "scaling.nc" + if path_mean.exists(): + break + if not path_mean.exists(): + raise FileNotFoundError( + f"Could not find centering and scaling files in {data_dir_stats} or alternative dirs {alternative_data_dirs}" + ) + + normalization_config = NormalizationConfig(global_means_path=path_mean, global_stds_path=path_std) + self.normalizer = normalization_config.build(config.normalize_names) + + # self.normalizer = get_normalizer(path_mean, path_std, names=config.normalize_names) + if config.prescriber is not None: + self.prescriber = config.prescriber.build(config.in_names, config.out_names) + else: + self.prescriber = NullPrescriber() + self.module = module.to(get_device()) + self.data_shapes = data_shapes + self._config = config + self._max_epochs = max_epochs + self.optimization = NullOptimization() + + self._no_optimization = NullOptimization() + self._is_distributed = dist.is_distributed() + + self.loss_obj = LpLoss() + + def run_on_batch( + self, + data: Dict[str, torch.Tensor], + optimization: Union[Optimization, NullOptimization], + n_forward_steps: int = 1, + aggregator: Optional[OneStepAggregator] = None, + ) -> Tuple[ + float, + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], + ]: + """ + Step the model forward on a batch of data. + + Args: + data: The batch data of shape [n_sample, n_timesteps, n_channels, n_x, n_y]. + optimization: The optimization class to use for updating the module. + Use `NullOptimization` to disable training. + n_forward_steps: The number of timesteps to run the model for. + + Returns: + The loss, the generated data, the normalized generated data, + and the normalized batch data. + """ + if aggregator is None: + non_none_aggregator: Union[OneStepAggregator, NullAggregator] = NullAggregator() + else: + non_none_aggregator = aggregator + + device = get_device() + device_data = {name: value.to(device, dtype=torch.float) for name, value in data.items()} + return run_on_batch_multistep( + data=device_data, + module=self.module, + normalizer=self.normalizer, + in_packer=self.in_packer, + out_packer=self.out_packer, + forcings_packer=self.forcings_packer, + optimization=optimization, + loss_obj=self.loss_obj, + n_forward_steps=n_forward_steps, + prescriber=self.prescriber, + aggregator=non_none_aggregator, + ) + + def load_state(self, state, load_optimizer: bool = True): + """ + Load the state of the stepper. + + Args: + state: The state to load. + load_optimizer: Whether to load the optimizer state. + """ + hparams = state["hyper_parameters"] + hparams_data = hparams["datamodule_config"] + state_dict = state["state_dict"] + # state_dict = {f"module.{k}": v for k, v in state_dict.items()} # add module. to keys if using DummyWrapper + # Reload weights + try: + self.module.load_state_dict(state_dict) + except RuntimeError as e: + raise RuntimeError( + f"Error loading state_dict." + f"\nHyperparameters: {hparams}\nData: {hparams_data}\nself.module={self.module}" + ) from e + + if load_optimizer and "optimization" in state: + self.optimization.load_state(state["optimization"]) + # in_names = hparams_data['in_names'] + hparams_data['forcing_names'] + self.init_packers(hparams_data["in_names"], hparams_data["out_names"], hparams_data["forcing_names"]) + # self.prescriber.load_state(None) + + def init_packers(self, in_names, out_names, forcing_names): + in_names = [x for x in in_names if x not in forcing_names] + self.in_packer = Packer(in_names, axis=self.channel_axis) + self.out_packer = Packer(out_names, axis=self.channel_axis) + self.forcings_packer = Packer(forcing_names, axis=self.channel_axis) + + @classmethod + def from_state(cls, state, load_optimizer: bool = True, overrides: Dict[str, Any] = None) -> "MultiStepStepper": + """ + Load the state of the stepper. + + Args: + state: The state to load. + load_optimizer: Whether to load the optimizer state. + overrides: Key -> value pairs to override the module's hyperparameters (value can be a dict). + + Returns: + The stepper. + """ + overrides = overrides or {} + module_class = infer_class_from_ckpt(ckpt_path=None, state=state) + # print(state['hyper_parameters'].keys()) + actual_hparams, diff_to_default = update_dict_with_other(state["hyper_parameters"], overrides) + module = module_class(**actual_hparams) + # Print the differences between the default and actual hyperparameters + if len(diff_to_default) > 0: + print("---------------- Overriding the following hyperparameters:") + print("|\t" + "\n\t".join(diff_to_default)) + print("----------------------------------------------------------") + # Update the wandb config and save diff_to_default as notes + import wandb + + try: + from omegaconf import OmegaConf + + # Make config a omegaconf.DictConfig + actual_hparams = OmegaConf.create(actual_hparams) + except ImportError: + pass + if wandb.run is not None: + # try: + # wandb.config.update(actual_hparams) + # except TypeError as e: + # print(f"Error updating wandb config: {e}") + # print(f"actual_hparams: {actual_hparams}") + # print(f"diff_to_default: {diff_to_default}\nSkipping... updating hparams to wandb") + wandb.run.notes = " ".join(diff_to_default) + wandb.log({"wandb.notes": wandb.run.notes}, step=0) + elif len(overrides) > 0: + print(f"No differences were found between the default and actual hyperparameters. Overrides: {overrides}") + + data_config = state["hyper_parameters"]["datamodule_config"] + # Salva's runs use separate in_names and forcing_names, for ACE data-loading we just combine them + state["hyper_parameters"]["datamodule_config"]["in_names"] = ( + data_config["in_names"] + data_config["forcing_names"] + ) + config = {} # 'builder': None, 'optimization': None} + for x in ["in_names", "out_names"]: + config[x] = list(data_config[x]) + for y in ["data_dir", "data_dir_stats"]: + config[y] = data_config[y] + + # Build prescriber back from saved config file + prescriber_config = data_config["prescriber"] + prescriber_config.pop("_target_") + config["prescriber"] = PrescriberConfig(**prescriber_config) + stepper = cls( + config=MultiStepStepperConfig.from_state(config), + module=module, + data_shapes=None, # state["data_shapes"], + max_epochs=1000, # training not supported yet + ) + stepper.load_state(state, load_optimizer=load_optimizer) + return stepper + + +def run_on_batch_multistep( + data: Dict[str, torch.Tensor], + module: AbstractMultiHorizonForecastingExperiment, + normalizer: StandardNormalizer, + in_packer: Packer, + out_packer: Packer, + forcings_packer: Packer, + optimization: Union[Optimization, NullOptimization], + loss_obj: nn.Module, + prescriber: Union[Prescriber, NullPrescriber], + aggregator: Union[OneStepAggregator, NullAggregator], + n_forward_steps: int = 1, +) -> SteppedData: + """ + Run the model on a batch of data. + + The module is assumed to require un-packed and normalized data (packing must be handled by the module), + except the forcing data, which is assumed to be packed and normalized. + + Args: + data: The denormalized batch data. The second dimension of each tensor + should be the time dimension. + module: The module to run. + normalizer: The normalizer. + in_packer: The packer for the input data. + out_packer: The packer for the output data. + optimization: The optimization object. If it is NullOptimization, + then the model is not trained. + loss_obj: The loss object. + prescriber: Overwrite an output with target value in specified region. + n_forward_steps: The number of timesteps to run the model for. + + Returns: + The loss, the generated data, the normalized generated data, + and the normalized batch data. The generated data contains + the initial input data as its first timestep. + """ + assert isinstance(prescriber, Prescriber), f"prescriber is not a Prescriber, but {type(prescriber)}" + module_actual = module.module if hasattr(module, "module") else module + horizon_training = module_actual.true_horizon + in_names = in_packer.names.copy() + # forcing_names = list(set(in_packer.names).difference(out_packer.names)) + # in_packer.names = [x for x in in_names if x not in forcing_names] + # forcings_packer = Packer(forcing_names, axis_pack=in_packer.axis_pack, axis_unpack=in_packer.axis_unpack) + # must be negative-indexed, so it works with or without a time dim + channel_dim = -3 + time_dim = 1 + example_shape = data[list(data.keys())[0]].shape + assert len(example_shape) == 4 + assert example_shape[1] == n_forward_steps + 1 + full_data_norm = normalizer.normalize(data) + get_input_data = get_name_and_time_query_fn(data, full_data_norm, time_dim) + + device = get_device() + eval_device = "cpu" + full_target_tensor_norm = out_packer.pack(full_data_norm, axis=channel_dim) + loss = torch.tensor(0.0, device=device) + metrics = {} + input_data_norm = get_input_data(in_packer.names, time_index=0, norm_mode="norm") + forcing_data_norm = get_input_data(forcings_packer.names, time_index=0, norm_mode="norm") + is_imprecise = ( + hasattr(module_actual.model.hparams, "hack_for_imprecise_interpolation") + and module_actual.model.hparams.hack_for_imprecise_interpolation + ) + gen_data_norm = [] + optimization.set_mode(module) + tqdm_bar = tqdm(range(1, n_forward_steps + 1), desc="Horizon") + for total_horizon in tqdm_bar: + # We need to map from the total horizon to the horizon for training, e.g. if train horizon = 3: + # total_horizon = 1 -> horizon_training = 1, total_horizon = 2 -> horizon_training = 2, + # total_horizon = 3 -> horizon_training = 3, total_horizon = 4 -> horizon_training = 1 + # total_horizon = 5 -> horizon_training = 2, etc. + horizon_rel = total_horizon % horizon_training + if horizon_rel == 0: + horizon_rel = horizon_training + input_tensor_norm = in_packer.pack(input_data_norm, axis=channel_dim) # Done inside module + forcing_tensor_norm = forcings_packer.pack(forcing_data_norm, axis=channel_dim) + + target_tensor_norm = full_target_tensor_norm.select(dim=time_dim, index=total_horizon) + with optimization.autocast(): + batch = { + # module_actual.main_data_key_val: to_tensordict(input_data_norm), # uncomment + "dynamics": input_tensor_norm.to(device), + # "dynamical_condition": forcing_tensor_norm.to(device), + } + if is_imprecise: + batch["static_condition"] = forcing_tensor_norm.to(device) + else: + pass + with module_actual.ema_scope(): + with module_actual.inference_dropout_scope(): + results = module_actual.get_preds_at_t_for_batch( + batch, + horizon=horizon_rel, + split="predict", + ensemble=False, + is_autoregressive=total_horizon > horizon_training, + prepare_inputs=False, # already done above + num_predictions=1, # only one prediction at a time (one ensemble member) + ) + predictions_key = f"t{horizon_rel}_preds_normed" + gen_tensor_norm = results[predictions_key] + # gen_tensor_norm = out_packer.pack(results[predictions_key], axis=channel_dim) #if unpacked inside module + step_loss = loss_obj(gen_tensor_norm, target_tensor_norm.to(device)) + loss += step_loss + metrics[f"loss_step_{total_horizon-1}"] = step_loss.detach() + + # Gen_norm will be used as input for the next AR step + gen_norm = out_packer.unpack(gen_tensor_norm, axis=channel_dim) + target_norm = out_packer.unpack(target_tensor_norm, axis=channel_dim) + data_time = {k: v.select(dim=time_dim, index=total_horizon).to(device) for k, v in data.items()} + gen_norm = prescriber(data_time, gen_norm, target_norm.to(device)) + gen_data_norm.append(gen_norm.to(eval_device)) + + if "preds_autoregressive_init_normed" not in results: + autoregressive_init_normed = gen_norm + else: + print("Using autoregressive_init_normed") + autoregressive_init_normed = results["preds_autoregressive_init_normed"] + autoregressive_init_normed = out_packer.unpack(autoregressive_init_normed, axis=channel_dim) + autoregressive_init_normed = prescriber(data_time, autoregressive_init_normed, target_norm.to(device)) + # update input data with generated outputs, and forcings for missing outputs + forcing_data_norm = get_input_data(forcings_packer.names, time_index=total_horizon, norm_mode="norm") + if is_imprecise: + autoregressive_init_normed["HGTsfc"] = input_data_norm["HGTsfc"].to(device) + # input_data_norm = {**forcing_data_norm, **gen_norm} + + # Autoregressive mode: update input data with generated outputs + input_data_norm = autoregressive_init_normed + del data_time + + optimization.step_weights(loss) + # prepend the initial (pre-first-timestep) output data to the generated data + initial = to_tensordict(get_input_data(out_packer.names, time_index=0, norm_mode="norm"), device=eval_device) + gen_data_norm = [initial] + gen_data_norm + # gen_data_norm_timeseries2 = torch.stack(gen_data_norm, dim=time_dim) + gen_data_norm_timeseries = {} + for name in out_packer.names: + gen_data_norm_timeseries[name] = torch.stack([x[name] for x in gen_data_norm], dim=time_dim) + gen_data = normalizer.denormalize(gen_data_norm_timeseries) + + # for name in out_packer.names: + # assert torch.allclose( + # gen_data_norm_timeseries[name], gen_data_norm_timeseries2[name] + # ), f'{name} not equal' + metrics["loss"] = loss.detach() + # shapes = set([v.shape for v in data.values()]) + # if len(shapes) > 1: + # d_to_s = {k: v.shape for k, v in data.items()} + # raise ValueError(f"Shapes of data tensors are not the same: {shapes}. example_shape={example_shape}" + # f"Data to shape: {d_to_s}") + + data = to_tensordict(data, device=eval_device) # Not needed on GPU + full_data_norm = to_tensordict(full_data_norm, device=eval_device) + aggregator.record_batch( + float(loss), + target_data=data, + gen_data=gen_data, + target_data_norm=full_data_norm, + gen_data_norm=gen_data_norm_timeseries, + ) + in_packer.names = in_names + return SteppedData( + metrics=metrics, + gen_data=gen_data, + target_data=data, + gen_data_norm=gen_data_norm_timeseries, + target_data_norm=full_data_norm, + ) diff --git a/src/ace_inference/core/wandb.py b/src/ace_inference/core/wandb.py new file mode 100644 index 0000000..be00efb --- /dev/null +++ b/src/ace_inference/core/wandb.py @@ -0,0 +1,189 @@ +import glob +import os +import shutil +import time +from typing import Any, Mapping, Optional + +import numpy as np +import wandb + +from src.ace_inference.core.distributed import Distributed + + +singleton: Optional["WandB"] = None + + +class DirectInitializationError(RuntimeError): + pass + + +class Histogram(wandb.Histogram): + def __init__( + self, + *args, + direct_access=True, + **kwargs, + ): + if direct_access: + raise DirectInitializationError( + "must initialize from `wandb = WandB.get_instance()`, " + "not directly from `import src.ace_inference.core.wandb`" + ) + super().__init__(*args, **kwargs) + + +Histogram.__doc__ = wandb.Histogram.__doc__ +Histogram.__init__.__doc__ = wandb.Histogram.__init__.__doc__ + + +class Table(wandb.Table): + def __init__( + self, + *args, + direct_access=True, + **kwargs, + ): + if direct_access: + raise DirectInitializationError( + "must initialize from `wandb = WandB.get_instance()`, " + "not directly from `import src.ace_inference.core.wandb`" + ) + super().__init__(*args, **kwargs) + + +Table.__doc__ = wandb.Table.__doc__ +Table.__init__.__doc__ = wandb.Table.__init__.__doc__ + + +class Video(wandb.Video): + def __init__( + self, + *args, + direct_access=True, + **kwargs, + ): + if direct_access: + raise DirectInitializationError( + "must initialize from `wandb = WandB.get_instance()`, " + "not directly from `import src.ace_inference.core.wandb`" + ) + super().__init__(*args, **kwargs) + + +Video.__doc__ = wandb.Video.__doc__ +Video.__init__.__doc__ = wandb.Video.__init__.__doc__ + + +class Image(wandb.Image): + def __init__( + self, + *args, + direct_access=True, + **kwargs, + ): + if direct_access: + raise DirectInitializationError( + "must initialize from `wandb = WandB.get_instance()`, " + "not directly from `import src.ace_inference.core.wandb`" + ) + super().__init__(*args, **kwargs) + + +Image.__doc__ = wandb.Image.__doc__ +Image.__init__.__doc__ = wandb.Image.__init__.__doc__ + + +class WandB: + """ + A singleton class to interface with Weights and Biases (WandB). + """ + + @classmethod + def get_instance(cls) -> "WandB": + """ + Get the singleton instance of the WandB class. + """ + global singleton + if singleton is None: + singleton = cls() + return singleton + + def __init__(self): + self._enabled = False + self._configured = False + + def configure(self, log_to_wandb: bool): + dist = Distributed.get_instance() + self._enabled = log_to_wandb and dist.is_root() + self._configured = True + + def init(self, **kwargs): + """kwargs are passed to wandb.init""" + if not self._configured: + raise RuntimeError("must call WandB.configure before WandB init can be called") + if self._enabled: + wandb.init(**kwargs) + + def watch(self, modules): + if self._enabled: + wandb.watch(modules) + + def log(self, data: Mapping[str, Any], step=None, sleep=None): + if self._enabled: + wandb.log(data, step=step) + if sleep is not None: + time.sleep(sleep) + + def Image(self, data_or_path, *args, **kwargs) -> Image: + if isinstance(data_or_path, np.ndarray): + data_or_path = scale_image(data_or_path) + + return Image(data_or_path, *args, direct_access=False, **kwargs) + + def clean_wandb_dir(self, experiment_dir: str): + # this is necessary because wandb does not remove run media directories + # after a run is synced; see https://github.com/wandb/wandb/issues/3564 + if self._enabled: + wandb.run.finish() # necessary to ensure the run directory is synced + wandb_dir = os.path.join(experiment_dir, "wandb") + remove_media_dirs(wandb_dir) + + def Video(self, *args, **kwargs) -> Video: + return Video(*args, direct_access=False, **kwargs) + + def Table(self, *args, **kwargs) -> Table: + return Table(*args, direct_access=False, **kwargs) + + def Histogram(self, *args, **kwargs) -> Histogram: + return Histogram(*args, direct_access=False, **kwargs) + + @property + def enabled(self) -> bool: + return self._enabled + + +def scale_image( + image_data: np.ndarray, +) -> np.ndarray: + """ + Given an array of scalar data, rescale the data to the range [0, 255]. + """ + data_min = np.nanmin(image_data) + data_max = np.nanmax(image_data) + # video data is brightness values on a 0-255 scale + image_data = 255 * (image_data - data_min) / (data_max - data_min) + image_data = np.minimum(image_data, 255) + image_data = np.maximum(image_data, 0) + image_data[np.isnan(image_data)] = 0 + return image_data + + +def remove_media_dirs(wandb_dir: str, media_dir_pattern: str = "run-*-*/files/media"): + """ + Remove the media directories in the wandb run directories. + """ + glob_pattern = os.path.join(wandb_dir, media_dir_pattern) + media_dirs = glob.glob(glob_pattern) + for media_dir in media_dirs: + if os.path.isdir(media_dir): + shutil.rmtree(media_dir) diff --git a/src/ace_inference/core/weight_ops.py b/src/ace_inference/core/weight_ops.py new file mode 100644 index 0000000..4e0ff23 --- /dev/null +++ b/src/ace_inference/core/weight_ops.py @@ -0,0 +1,166 @@ +import dataclasses +from typing import Any, List, Mapping, Optional + +import torch +from torch import nn + +from .wildcard import apply_by_wildcard, wildcard_match + + +@dataclasses.dataclass +class CopyWeightsConfig: + """ + Configuration for copying weights from a base model to a target model. + + Used during training to overwrite weights after every batch of data, + to have the effect of "freezing" the overwritten weights. When the + target parameters have longer dimensions than the base model, only + the initial slice is overwritten. + + All parameters must be covered by either the include or exclude list, + but not both. + + Args: + include: list of wildcard patterns to overwrite + exclude: list of wildcard patterns to exclude from overwriting + """ + + include: List[str] = dataclasses.field(default_factory=list) + exclude: List[str] = dataclasses.field(default_factory=list) + + def __post_init__(self): + for pattern in self.include: + if any(wildcard_match(pattern, exclude) for exclude in self.exclude): + raise ValueError( + f"Parameter {pattern} is included in both include " f"{self.include} and exclude {self.exclude}" + ) + for pattern in self.exclude: + if any(wildcard_match(pattern, include) for include in self.include): + raise ValueError( + f"Parameter {pattern} is included in both include " f"{self.include} and exclude {self.exclude}" + ) + + @torch.no_grad() + def apply(self, weights: List[Mapping[str, Any]], modules: List[nn.Module]): + """ + Apply base weights to modules according to the include/exclude lists + of this instance. + + In order to "freeze" the weights during training, this must be called after + each time the weights are updated in the training loop. + + Args: + weights: list of base weights to apply + modules: list of modules to apply the weights to + """ + if len(modules) > 1: + # We can support multiple modules by having this configuration take a list + # of include/exclude for each module. Not implemented right now because it + # is not needed, and would make the configuration more confusing for the + # single-module case (especially when it's only ever single-module). + raise NotImplementedError("only one module currently supported") + if len(modules) != len(weights): + raise ValueError("number of modules and weights must match") + for module, weight in zip(modules, weights): + + def func(module, name): + overwrite_weight_initial_slice(module, name, weight[name]) + + apply_by_wildcard(module, func, self.include, self.exclude) + return module + + +def strip_leading_module(state_dict: Mapping[str, Any]) -> Mapping[str, Any]: + """ + Remove the leading "module." from the keys of a state dict. + + This is necessary because SingleModuleStepper wraps the module in either + a DistributedDataParallel layer or DummyWrapper layer, which adds a leading + "module." to the keys of the state dict. + """ + return {k[len("module.") :] if k.startswith("module.") else k: v for k, v in state_dict.items()} + + +def overwrite_weights( + from_state: Mapping[str, Any], + to_module: torch.nn.Module, + exclude_parameters: Optional[List[str]] = None, +): + """ + Overwrite the weights in to_module with the weights in from_state. + + When an axis is larger in to_module than in from_state, only the initial + slice is overwritten. For example, if the from module has a parameter `a` + of shape [10, 10], and the to module has a parameter `a` of shape [20, 10], + then only the first 10 rows of `a` will be overwritten. + + If an axis is larger in from_state than in to_module, an exception is raised. + + Args: + from_state: module state dict containing weights to be copied + to_module: module whose weights will be overwritten + exclude_parameters: list of parameter names to exclude from the loaded + weights. Wildcards can be used, e.g. "decoder.*.weight". + """ + if exclude_parameters is None: + exclude_parameters = [] + from_names = set(from_state.keys()) + to_names = set(to_module.state_dict().keys()) + if not from_names.issubset(to_names): + missing_parameters = from_names - to_names + raise ValueError(f"Dest module is missing parameters {missing_parameters}, " "which is not allowed") + for name in from_names: + if any(wildcard_match(pattern, name) for pattern in exclude_parameters): + continue + from_param = from_state[name] + try: + overwrite_weight_initial_slice(to_module, name, from_param) + except AttributeError: # if state is not a parameter + pass + + +def overwrite_weight_initial_slice(module, name, from_param): + """ + Overwrite the initial slice of a parameter in module with from_param. + + When an axis is larger in the module's param than in from_param, + only the initial slice is overwritten. For example, if the from module + has a parameter `a` of shape [10, 10], and the to module has a parameter + `a` of shape [20, 10], then only the first 10 rows of `a` will be overwritten. + + If an axis is larger in from_param, an exception is raised. + + Args: + module: module whose parameter will be overwritten + name: name of the parameter to be overwritten + from_param: parameter to overwrite with + """ + to_param = module.get_parameter(name) + if len(from_param.shape) != len(to_param.shape): + raise ValueError( + f"Dest parameter {name} has " + f"{len(to_param.shape.shape)} " + "dimensions which needs to be equal to the loaded " + f"parameter dimension {len(from_param.shape)}" + ) + for from_size, to_size in zip(from_param.shape, to_param.shape): + if from_size > to_size: + raise ValueError( + f"Dest parameter has size {to_size} along one of its " + "dimensions which needs to be greater than loaded " + f"parameter size {from_size}" + ) + slices = tuple(slice(0, size) for size in from_param.shape) + with torch.no_grad(): + new_param_data = to_param.data.clone() + new_param_data[slices] = from_param.data + _set_nested_parameter(module, name, new_param_data) + + +def _set_nested_parameter(module, param_name, new_param): + *path, name = param_name.split(".") + for p in path: + module = getattr(module, p) + if not isinstance(new_param, nn.Parameter): + new_param = nn.Parameter(new_param) + setattr(module, name, new_param) diff --git a/src/ace_inference/core/wildcard.py b/src/ace_inference/core/wildcard.py new file mode 100644 index 0000000..b556e5d --- /dev/null +++ b/src/ace_inference/core/wildcard.py @@ -0,0 +1,40 @@ +import re +from typing import Callable, List + +from torch import nn + + +def wildcard_match(pattern: str, name: str) -> bool: + """ + Check if a name matches a wildcard pattern. + + A wildcard pattern can include "*" to match any number of characters. + """ + # use regex + pattern = pattern.replace(".", r"\.") + pattern = pattern.replace("*", ".*") + pattern = f"^{pattern}$" + return bool(re.match(pattern, name)) + + +def apply_by_wildcard( + model: nn.Module, + func: Callable[[nn.Module, str], None], + include: List[str], + exclude: List[str], +): + missing_parameters = [] + for name in model.state_dict().keys(): + if any(wildcard_match(pattern, name) for pattern in include): + if any(wildcard_match(pattern, name) for pattern in exclude): + raise ValueError(f"Parameter {name} is included in both include " f"{include} and exclude {exclude}") + func(model, name) + elif not any(wildcard_match(pattern, name) for pattern in exclude): + missing_parameters.append(name) + if len(missing_parameters) > 0: + raise ValueError( + f"Model has parameters {missing_parameters} which are not " + f"specified in either include {include} " + f"or exclude {exclude}" + ) + return model diff --git a/src/ace_inference/core/winds.py b/src/ace_inference/core/winds.py new file mode 100644 index 0000000..78a55d4 --- /dev/null +++ b/src/ace_inference/core/winds.py @@ -0,0 +1,170 @@ +from typing import Tuple + +import numpy as np + + +def u_v_to_x_y_z_wind( + u: np.ndarray, v: np.ndarray, lat: np.ndarray, lon: np.ndarray +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Converts u and v wind components to x, y, z wind components. + + The x-axis is defined as the vector from the center of the earth to the + intersection of the equator and the prime meridian. The y-axis is defined + as the vector from the center of the earth to the intersection of the + equator and the 90 degree east meridian. The z-axis is defined as the + vector from the center of the earth to the north pole. + + Args: + u: u wind component + v: v wind component + lat: latitude, in degrees + lon: longitude, in degrees + + Returns: + wx: x wind component + wy: y wind component + wz: z wind component + """ + + # for a graphical proof of the equations used here, see + # https://github.com/ai2cm/full-model/pull/355#issuecomment-1729773301 + + # Convert to radians + lat = np.deg2rad(lat) + lon = np.deg2rad(lon) + + # Horizontal winds + # + # Contribution from u + # + # The u component of the wind is aligned with the longitude lines, + # so all of its contributions are to the x and y components of the wind. + # At the prime meridian (which lies on the x axis), the u component points + # parallel to the y-axis, so it contributes only to the y component of the + # wind. As we move eastward, the u component points more and more in the + # negative x direction, so it contributes more and more to the x component + # of the wind. At the 90 degree east meridian, the u component points + # parallel to the x-axis (in the negative direction). + # + # This influence on the x component is captured by multiplying the u component + # by the negative sine of the longitude, which is zero at the prime meridian + # and then becomes negative until reaching -1 at 90 degrees east. + # + # The influence on the y component is captured by multiplying the u component + # by the cosine of the longitude, which is 1 at the prime meridian and then + # decreases until reaching 0 at 90 degrees east. + # + # Contribution from v + # + # The v component of the wind is aligned with the latitude lines, + # with no contribution to the horizontal at the equator and full contribution + # at the poles. This is captured by multiplying the v component by the sine + # of the latitude, which is 0 at the equator and 1 at the poles. + # + # The direction of the horizontal contribution is the vector pointing inwards + # towards the axis of rotation of the Earth. At the prime meridian, this + # is the negative x direction. At the 90 degree east meridian, this is the + # negative y direction. + # + # This influence on the x component is captured by multiplying the v component + # by the negative cosine of the longitude, which is 1 at the prime meridian + # and then decreases until reaching 0 at 90 degrees east. + # + # The influence on the y component is captured by multiplying the v component + # by the negative sine of the longitude, which is 0 at the prime meridian + # and then decreases until reaching -1 at 90 degrees east. + # + # An exact derivation proving this effect is captured by sine and cosine + # can be done graphically. Generally for these kinds of problems on a sphere + # or circle it's always sine and cosine, and the question is which one and + # whether it's positive or negative. + + # Wind in the x-direction: + wx = -u * np.sin(lon) - v * np.sin(lat) * np.cos(lon) + + # Wind in the y-direction: + wy = u * np.cos(lon) - v * np.sin(lat) * np.sin(lon) + + # Wind in the z-direction: + + # As the u-component is along latitude lines, and latitude lines are + # perpendicular to Earth's axis of rotation, u does not appear in wz. + # + # The v-component of the wind is entirely aligned with Earth's axis of rotation + # at the equator, and is entirely perpendicular at the poles, an effect that + # is captured by multiplying the v component by the cosine of the latitude. + wz = v * np.cos(lat) + + return wx, wy, wz + + +def normalize_vector(*vector_components: np.ndarray) -> np.ndarray: + """ + Normalize a vector. + + The vector is assumed to be represented in an orthonormal basis, where + each component is orthogonal to the others. + + Args: + vector_components: components of the vector (e.g. x-, y-, and z-components) + + Returns: + normalized vector, as a numpy array where each component has been + concatenated along a new first dimension + """ + scale = np.divide( + 1.0, + np.sum(np.asarray([item**2.0 for item in vector_components]), axis=0) ** 0.5, + ) + return np.asarray([item * scale for item in vector_components]) + + +def lon_lat_to_xyz(lon, lat): + """ + Convert (lon, lat) to (x, y, z). + + Args: + lon: 2d array of longitudes, in degrees + lat: 2d array of latitudes, in degrees + + Returns: + x: 2d array of x values + y: 2d array of y values + z: 2d array of z values + """ + lat = np.deg2rad(lat) + lon = np.deg2rad(lon) + x = np.cos(lat) * np.cos(lon) + y = np.cos(lat) * np.sin(lon) + z = np.sin(lat) + x, y, z = normalize_vector(x, y, z) + return x, y, z + + +def xyz_to_lon_lat(x, y, z): + """ + Convert (x, y, z) to (lon, lat). + + Args: + x: 2d array of x values + y: 2d array of y values + z: 2d array of z values + + Returns: + lon: 2d array of longitudes, in degrees + lat: 2d array of latitudes, in degrees + """ + x, y, z = normalize_vector(x, y, z) + # double transpose to index last dimension, regardless of number of dimensions + lon = np.zeros_like(x) + nonzero_lon = np.abs(x) + np.abs(y) >= 1.0e-10 + lon[nonzero_lon] = np.arctan2(y[nonzero_lon], x[nonzero_lon]) + negative_lon = lon < 0.0 + while np.any(negative_lon): + lon[negative_lon] += 2 * np.pi + negative_lon = lon < 0.0 + lat = np.arcsin(z) + lat = np.rad2deg(lat) + lon = np.rad2deg(lon) + return lon, lat diff --git a/src/ace_inference/inference/__init__.py b/src/ace_inference/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ace_inference/inference/data_writer/__init__.py b/src/ace_inference/inference/data_writer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ace_inference/inference/data_writer/histograms.py b/src/ace_inference/inference/data_writer/histograms.py new file mode 100644 index 0000000..b6a6cdc --- /dev/null +++ b/src/ace_inference/inference/data_writer/histograms.py @@ -0,0 +1,148 @@ +from pathlib import Path +from typing import Dict, Mapping, Optional + +import numpy as np +import torch +import xarray as xr + +from src.ace_inference.core.data_loading.data_typing import VariableMetadata +from src.ace_inference.core.histogram import DynamicHistogram + + +class _HistogramAggregator: + def __init__(self, n_times: int): + self._prediction_histograms: Optional[Mapping[str, DynamicHistogram]] = None + self._target_histograms: Optional[Mapping[str, DynamicHistogram]] = None + self._n_times = n_times + + def record_batch( + self, + target_data: Dict[str, torch.Tensor], + prediction_data: Dict[str, torch.Tensor], + i_time_start: int, + ): + if self._target_histograms is None: + self._target_histograms = { + var_name: DynamicHistogram(n_times=self._n_times) for var_name in target_data.keys() + } + if self._prediction_histograms is None: + self._prediction_histograms = { + var_name: DynamicHistogram(n_times=self._n_times) for var_name in prediction_data.keys() + } + for var_name, histogram in self._prediction_histograms.items(): + # go from [n_samples, n_timesteps, n_lat, n_lon] to + # [n_timesteps, n_samples, n_lat, n_lon] + # and then reshape to [n_timesteps, n_hist_samples] + n_times = prediction_data[var_name].shape[1] + data = prediction_data[var_name].cpu().numpy().transpose(1, 0, 2, 3).reshape(n_times, -1) + histogram.add(data, i_time_start=i_time_start) + for var_name, histogram in self._target_histograms.items(): + # go from [n_samples, n_timesteps, n_lat, n_lon] to + # [n_timesteps, n_samples, n_height, n_width] + # and then reshape to [n_timesteps, n_hist_samples] + n_times = target_data[var_name].shape[1] + data = target_data[var_name].cpu().numpy().transpose(1, 0, 2, 3).reshape(n_times, -1) + histogram.add(data, i_time_start=i_time_start) + + def get_dataset(self) -> xr.Dataset: + if self._target_histograms is None or self._prediction_histograms is None: + raise RuntimeError("No data has been recorded.") + target_dataset = self._get_single_dataset(self._target_histograms) + prediction_dataset = self._get_single_dataset(self._prediction_histograms) + for missing_target_name in set(prediction_dataset.data_vars) - set(target_dataset.data_vars): + if not missing_target_name.endswith("_bin_edges"): + target_dataset[missing_target_name] = xr.DataArray( + np.zeros_like(prediction_dataset[missing_target_name]), + dims=("time", "bin"), + ) + target_dataset[f"{missing_target_name}_bin_edges"] = prediction_dataset[ + f"{missing_target_name}_bin_edges" + ] + for missing_prediction_name in set(target_dataset.data_vars) - set(prediction_dataset.data_vars): + if not missing_prediction_name.endswith("_bin_edges"): + prediction_dataset[missing_prediction_name] = xr.DataArray( + np.zeros_like(target_dataset[missing_prediction_name]), + dims=("time", "bin"), + ) + prediction_dataset[f"{missing_prediction_name}_bin_edges"] = target_dataset[ + f"{missing_prediction_name}_bin_edges" + ] + ds = xr.concat([target_dataset, prediction_dataset], dim="source") + ds["source"] = ["target", "prediction"] + return ds + + @staticmethod + def _get_single_dataset(histograms: Mapping[str, DynamicHistogram]) -> xr.Dataset: + data = {} + for var_name, histogram in histograms.items(): + data[var_name] = xr.DataArray( + histogram.counts, + dims=("time", "bin"), + ) + data[f"{var_name}_bin_edges"] = xr.DataArray( + histogram.bin_edges, + dims=("bin_edges",), + ) + return xr.Dataset(data) + + +class HistogramDataWriter: + """ + Write [time, bin] histogram data for each variable to a netCDF file. + """ + + def __init__( + self, + path: str, + n_timesteps: int, + metadata: Mapping[str, VariableMetadata], + ): + """ + Args: + path: Path to write netCDF file(s). + n_timesteps: Number of timesteps to write to the file. + metadata: Metadata for each variable to be written to the file. + """ + self.path = path + self._metrics_filename = str(Path(path) / "histograms.nc") + self.metadata = metadata + self._histogram = _HistogramAggregator(n_times=n_timesteps) + + def append_batch( + self, + target: Dict[str, torch.Tensor], + prediction: Dict[str, torch.Tensor], + start_timestep: int, + start_sample: int, + batch_times: xr.DataArray = None, + ): + """ + Append a batch of data to the file. + + Args: + target: Target data. + prediction: Prediction data. + start_timestep: Timestep at which to start writing. + start_sample: Sample at which to start writing. + batch_times: Time coordinates for each sample in the batch. + """ + del start_sample, batch_times + self._histogram.record_batch( + target_data=target, + prediction_data=prediction, + i_time_start=start_timestep, + ) + + def flush(self): + """ + Flush the data to disk. + """ + metric_dataset = self._histogram.get_dataset() + for name in self.metadata: + metric_dataset[f"{name}_bin_edges"].attrs["units"] = self.metadata[name].units + for name in metric_dataset.data_vars: + if not name.endswith("_bin_edges"): + metric_dataset[f"{name}_bin_edges"].attrs["long_name"] = f"{name} bin edges" + metric_dataset[name].attrs["units"] = "count" + metric_dataset[name].attrs["long_name"] = f"{name} histogram" + metric_dataset.to_netcdf(self._metrics_filename) diff --git a/src/ace_inference/inference/data_writer/main.py b/src/ace_inference/inference/data_writer/main.py new file mode 100644 index 0000000..5183bd6 --- /dev/null +++ b/src/ace_inference/inference/data_writer/main.py @@ -0,0 +1,187 @@ +import dataclasses +from typing import Dict, List, Mapping, Optional, Sequence, Union + +import numpy as np +import torch +import xarray as xr + +from src.ace_inference.core.data_loading.data_typing import VariableMetadata + +from .histograms import HistogramDataWriter +from .prediction import PredictionDataWriter +from .time_coarsen import TimeCoarsen, TimeCoarsenConfig +from .video import VideoDataWriter + + +Subwriter = Union[PredictionDataWriter, VideoDataWriter, HistogramDataWriter, TimeCoarsen] + + +@dataclasses.dataclass +class DataWriterConfig: + """ + Configuration for inference data writers. + + Args: + log_extended_video_netcdfs: Whether to enable writing of netCDF files + containing video metrics. + save_prediction_files: Whether to enable writing of netCDF files + containing the predictions. + save_raw_prediction_names: Names of variables to save in the predictions + netcdf file. + time_coarsen: Configuration for time coarsening of written outputs. + """ + + log_extended_video_netcdfs: bool = False + save_prediction_files: bool = True + save_raw_prediction_names: Optional[Sequence[str]] = None + time_coarsen: Optional[TimeCoarsenConfig] = None + + def __post_init__(self): + if not self.save_prediction_files and self.save_raw_prediction_names is not None: + raise ValueError("save_raw_prediction_names provided but save_prediction_files is False") + + def build( + self, + experiment_dir: str, + n_samples: int, + n_timesteps: int, + metadata: Mapping[str, VariableMetadata], + coords: Mapping[str, np.ndarray], + n_ensemble_members: int = 1, + ) -> "DataWriter": + return DataWriter( + path=experiment_dir, + n_samples=n_samples, + n_timesteps=n_timesteps, + metadata=metadata, + coords=coords, + enable_prediction_netcdfs=self.save_prediction_files, + enable_video_netcdfs=self.log_extended_video_netcdfs, + time_coarsen=self.time_coarsen, + n_ensemble_members=n_ensemble_members, + ) + + +class DataWriter: + def __init__( + self, + path: str, + n_samples: int, + n_timesteps: int, + metadata: Mapping[str, VariableMetadata], + coords: Mapping[str, np.ndarray], + enable_prediction_netcdfs: bool, + enable_video_netcdfs: bool, + time_coarsen: Optional[TimeCoarsenConfig] = None, + n_ensemble_members: int = 1, + ): + """ + Args: + path: Path to write netCDF file(s). + n_samples: Number of samples to write to the file. + n_timesteps: Number of timesteps to write to the file. + metadata: Metadata for each variable to be written to the file. + coords: Coordinate data to be written to the file. + enable_prediction_netcdfs: Whether to enable writing of netCDF files + containing the predictions. + enable_video_netcdfs: Whether to enable writing of netCDF files + containing video metrics. + save_names: Names of variables to save in the predictions netcdf file. + time_coarsen: Configuration for time coarsening of written outputs. + """ + self._writers: List[Subwriter] = [] + + if time_coarsen is not None: + n_timesteps = time_coarsen.n_coarsened_timesteps(n_timesteps) + + def _time_coarsen_builder(data_writer: Subwriter) -> Subwriter: + if time_coarsen is not None: + return time_coarsen.build(data_writer) + else: + return data_writer + + if enable_prediction_netcdfs: + self._writers.append( + _time_coarsen_builder( + PredictionDataWriter( + path=path, + n_samples=n_samples, + metadata=metadata, + coords=coords, + ) + ) + ) + if enable_video_netcdfs: + self._writers.append( + _time_coarsen_builder( + VideoDataWriter( + path=path, + n_timesteps=n_timesteps, + metadata=metadata, + coords=coords, + ) + ) + ) + if n_ensemble_members == 1: + self._writers.append( + _time_coarsen_builder( + HistogramDataWriter( + path=path, + n_timesteps=n_timesteps, + metadata=metadata, + ) + ) + ) + + def append_batch( + self, + target: Dict[str, torch.Tensor], + prediction: Dict[str, torch.Tensor], + start_timestep: int, + start_sample: int, + batch_times: xr.DataArray = None, + ): + """ + Append a batch of data to the file. + + Args: + target: Target data. + prediction: Prediction data. + start_timestep: Timestep at which to start writing. + start_sample: Sample at which to start writing. + batch_times: Time coordinates for each sample in the batch. + """ + for writer in self._writers: + writer.append_batch( + target=target, + prediction=prediction, + start_timestep=start_timestep, + start_sample=start_sample, + batch_times=batch_times, + ) + + def flush(self): + """ + Flush the data to disk. + """ + for writer in self._writers: + writer.flush() + + +class NullDataWriter: + """ + Null pattern for DataWriter, which does nothing. + """ + + def __init__(self): + pass + + def append_batch( + self, + target: Dict[str, torch.Tensor], + prediction: Dict[str, torch.Tensor], + start_timestep: int, + start_sample: int, + batch_times: xr.DataArray = None, + ): + pass diff --git a/src/ace_inference/inference/data_writer/prediction.py b/src/ace_inference/inference/data_writer/prediction.py new file mode 100644 index 0000000..0fcc5c5 --- /dev/null +++ b/src/ace_inference/inference/data_writer/prediction.py @@ -0,0 +1,131 @@ +from pathlib import Path +from typing import Dict, Mapping, Optional + +import numpy as np +import torch +import xarray as xr +from netCDF4 import Dataset + +from src.ace_inference.core.data_loading.data_typing import VariableMetadata + + +class PredictionDataWriter: + """ + Write raw prediction data to a netCDF file. + """ + + def __init__( + self, + path: str, + n_samples: int, + metadata: Mapping[str, VariableMetadata], + coords: Mapping[str, np.ndarray], + ): + """ + Args: + filename: Path to write netCDF file(s). + n_samples: Number of samples to write to the file. + n_timesteps: Number of timesteps to write to the file. + metadata: Metadata for each variable to be written to the file. + coords: Coordinate data to be written to the file. + """ + self.path = path + filename = str(Path(path) / "autoregressive_predictions.nc") + self.metadata = metadata + self.coords = coords + self.dataset = Dataset(filename, "w", format="NETCDF4") + self.dataset.createDimension("source", 2) + self.dataset.createDimension("timestep", None) # unlimited dimension + self.dataset.createDimension("sample", n_samples) + self.dataset.createVariable("source", "str", ("source",)) + self.dataset.variables["source"][:] = np.array(["target", "prediction"]) + self._n_lat: Optional[int] = None + self._n_lon: Optional[int] = None + + def append_batch( + self, + target: Dict[str, torch.Tensor], + prediction: Dict[str, torch.Tensor], + start_timestep: int, + start_sample: int, + batch_times: xr.DataArray = None, + ): + """ + Append a batch of data to the file. + + Args: + target: Target data. + prediction: Prediction data. + start_timestep: Timestep at which to start writing. + start_sample: Sample at which to start writing. + """ + if self._n_lat is None: + self._n_lat = target[next(iter(target.keys()))].shape[-2] + self.dataset.createDimension("lat", self._n_lat) + if "lat" in self.coords: + self.dataset.createVariable("lat", "f4", ("lat",)) + self.dataset.variables["lat"][:] = self.coords["lat"] + if self._n_lon is None: + self._n_lon = target[next(iter(target.keys()))].shape[-1] + self.dataset.createDimension("lon", self._n_lon) + if "lon" in self.coords: + self.dataset.createVariable("lon", "f4", ("lon",)) + self.dataset.variables["lon"][:] = self.coords["lon"] + + dims = ("source", "sample", "timestep", "lat", "lon") + for variable_name in set(target.keys()).union(prediction.keys()): + # define the variable if it doesn't exist + if variable_name not in self.dataset.variables: + self.dataset.createVariable( + variable_name, + "f4", + dims, + fill_value=np.nan, + ) + if variable_name in self.metadata: + self.dataset.variables[variable_name].units = self.metadata[variable_name].units + self.dataset.variables[variable_name].long_name = self.metadata[variable_name].long_name + + # Target and prediction may not have the same variables. + # The netCDF contains a "source" dimension for all variables + # and will have NaN for missing data. + if variable_name in target: + target_numpy = target[variable_name].cpu().numpy() + else: + target_numpy = np.full(shape=target[next(iter(target.keys()))].shape, fill_value=np.nan) + if variable_name in prediction: + prediction_numpy = prediction[variable_name].cpu().numpy() + else: + prediction_numpy = np.full( + shape=prediction[next(iter(prediction.keys()))].shape, + fill_value=np.nan, + ) + + # Broadcast the corresponding dimension to match with the + # 'source' dimension of the variable in the netCDF file + target_numpy = np.expand_dims(target_numpy, dims.index("source")) + prediction_numpy = np.expand_dims(prediction_numpy, dims.index("source")) + + n_samples_total = self.dataset.variables[variable_name].shape[1] + if start_sample + target_numpy.shape[1] > n_samples_total: + raise ValueError( + f"Batch size {target_numpy.shape[1]} starting at sample " + f"{start_sample} " + "is too large to fit in the netCDF file with sample " + f"dimension of length {n_samples_total}." + ) + # Append the data to the variables + self.dataset.variables[variable_name][ + :, + start_sample : start_sample + target_numpy.shape[1], + start_timestep : start_timestep + target_numpy.shape[2], + :, + ] = np.concatenate([target_numpy, prediction_numpy], axis=0) + + self.dataset.sync() # Flush the data to disk + + def flush(self): + """ + Flush the data to disk. + """ + self.dataset.sync() diff --git a/src/ace_inference/inference/data_writer/time_coarsen.py b/src/ace_inference/inference/data_writer/time_coarsen.py new file mode 100644 index 0000000..9ed7bb9 --- /dev/null +++ b/src/ace_inference/inference/data_writer/time_coarsen.py @@ -0,0 +1,141 @@ +import dataclasses +from typing import Dict, Protocol, Tuple + +import torch +import xarray as xr + + +TIME_DIM_NAME = "time" +TIME_DIM = 1 # sample, time, lat, lon + + +class _DataWriter(Protocol): + def append_batch( + self, + target: Dict[str, torch.Tensor], + prediction: Dict[str, torch.Tensor], + start_timestep: int, + start_sample: int, + batch_times: xr.DataArray, + ): + pass + + def flush(self): + pass + + +@dataclasses.dataclass +class TimeCoarsenConfig: + """ + Config for inference data time coarsening. + + Args: + coarsen_factor: Factor by which to coarsen in time, an integer 1 or greater. The + resulting time labels will be coarsened to the mean of the original labels. + """ + + def __post_init__(self): + if self.coarsen_factor < 1: + raise ValueError(f"coarsen_factor must be 1 or greater, got {self.coarsen_factor}") + + coarsen_factor: int + + def build(self, data_writer: _DataWriter) -> "TimeCoarsen": + return TimeCoarsen( + data_writer=data_writer, + coarsen_factor=self.coarsen_factor, + ) + + def n_coarsened_timesteps(self, n_timesteps: int) -> int: + """Assumes initial condition is in n_timesteps, and is not coarsened""" + return ((n_timesteps - 1) // self.coarsen_factor) + 1 + + +class TimeCoarsen: + """Wraps a data writer and coarsens its arguments in time before passing them on.""" + + def __init__( + self, + data_writer: _DataWriter, + coarsen_factor: int, + ): + self._data_writer: _DataWriter = data_writer + self._coarsen_factor: int = coarsen_factor + + def append_batch( + self, + target: Dict[str, torch.Tensor], + prediction: Dict[str, torch.Tensor], + start_timestep: int, + start_sample: int, + batch_times: xr.DataArray, + ): + if start_timestep == 0: + # record the initial condition without coarsening + target_inital = tensor_dict_time_select(target, time_slice=slice(None, 1)) + prediction_initial = tensor_dict_time_select(prediction, time_slice=slice(None, 1)) + batch_times_initial = batch_times.isel({TIME_DIM_NAME: slice(None, 1)}) + self._data_writer.append_batch( + target_inital, + prediction_initial, + start_timestep, + start_sample, + batch_times_initial, + ) + # then coarsen the rest of the batch + target = tensor_dict_time_select(target, time_slice=slice(1, None)) + prediction = tensor_dict_time_select(prediction, time_slice=slice(1, None)) + batch_times = batch_times.isel({TIME_DIM_NAME: slice(1, None)}) + start_timestep = 1 + ( + target_coarsened, + prediction_coarsened, + start_timestep, + batch_times_coarsened, + ) = self.coarsen_batch(target, prediction, start_timestep, batch_times) + self._data_writer.append_batch( + target_coarsened, + prediction_coarsened, + start_timestep, + start_sample, + batch_times_coarsened, + ) + + def coarsen_batch( + self, + target: Dict[str, torch.Tensor], + prediction: Dict[str, torch.Tensor], + start_timestep: int, + batch_times: xr.DataArray, + ) -> Tuple[ + Dict[str, torch.Tensor], + Dict[str, torch.Tensor], + int, + xr.DataArray, + ]: + target_coarsened = self._coarsen_tensor_dict(target) + prediction_coarsened = self._coarsen_tensor_dict(prediction) + start_timestep = ((start_timestep - 1) // self._coarsen_factor) + 1 + batch_times_coarsened = batch_times.coarsen({TIME_DIM_NAME: self._coarsen_factor}).mean() + return ( + target_coarsened, + prediction_coarsened, + start_timestep, + batch_times_coarsened, + ) + + def _coarsen_tensor_dict(self, tensor_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + """Coarsen each tensor along a given axis by a given factor.""" + coarsened_tensor_dict = {} + for name, tensor in tensor_dict.items(): + coarsened_tensor_dict[name] = tensor.unfold( + dimension=TIME_DIM, size=self._coarsen_factor, step=self._coarsen_factor + ).mean(dim=-1) + return coarsened_tensor_dict + + def flush(self): + self._data_writer.flush() + + +def tensor_dict_time_select(tensor_dict: Dict[str, torch.Tensor], time_slice: slice): + return {name: tensor[:, time_slice] for name, tensor in tensor_dict.items()} diff --git a/src/ace_inference/inference/data_writer/video.py b/src/ace_inference/inference/data_writer/video.py new file mode 100644 index 0000000..b94bf64 --- /dev/null +++ b/src/ace_inference/inference/data_writer/video.py @@ -0,0 +1,73 @@ +from pathlib import Path +from typing import Dict, Mapping + +import numpy as np +import torch +import xarray as xr + +from src.ace_inference.core.aggregator.inference.video import VideoAggregator +from src.ace_inference.core.data_loading.data_typing import VariableMetadata + + +class VideoDataWriter: + """ + Write [time, lat, lon] metric data to a netCDF file. + """ + + def __init__( + self, + path: str, + n_timesteps: int, + metadata: Mapping[str, VariableMetadata], + coords: Mapping[str, np.ndarray], + ): + """ + Args: + filename: Path to write netCDF file(s). + n_samples: Number of samples to write to the file. + n_timesteps: Number of timesteps to write to the file. + metadata: Metadata for each variable to be written to the file. + coords: Coordinate data to be written to the file. + """ + self.path = path + self._metrics_filename = str(Path(path) / "reduced_autoregressive_predictions.nc") + self.metadata = metadata + self.coords = coords + self._video = VideoAggregator(n_timesteps=n_timesteps, enable_extended_videos=True) + + def append_batch( + self, + target: Dict[str, torch.Tensor], + prediction: Dict[str, torch.Tensor], + start_timestep: int, + start_sample: int, + batch_times: xr.DataArray = None, + ): + """ + Append a batch of data to the file. + + Args: + target: Target data. + prediction: Prediction data. + start_timestep: Timestep at which to start writing. + start_sample: Sample at which to start writing. + """ + self._video.record_batch( + loss=np.nan, + target_data=target, + gen_data=prediction, + i_time_start=start_timestep, + ) + + def flush(self): + """ + Flush the data to disk. + """ + metric_dataset = self._video.get_dataset() + coords = {} + if "lat" in self.coords: + coords["lat"] = self.coords["lat"] + if "lon" in self.coords: + coords["lon"] = self.coords["lon"] + metric_dataset = metric_dataset.assign_coords(coords) + metric_dataset.to_netcdf(self._metrics_filename) diff --git a/src/ace_inference/inference/derived_variables.py b/src/ace_inference/inference/derived_variables.py new file mode 100644 index 0000000..5b23c01 --- /dev/null +++ b/src/ace_inference/inference/derived_variables.py @@ -0,0 +1,132 @@ +import dataclasses +import logging +from typing import Callable, Dict, MutableMapping + +import torch +from toolz import curry + +from src.ace_inference.core import metrics +from src.ace_inference.core.aggregator.climate_data import ClimateData +from src.ace_inference.core.constants import TIMESTEP_SECONDS +from src.ace_inference.core.data_loading.data_typing import SigmaCoordinates +from src.ace_inference.core.stepper import SteppedData + + +@dataclasses.dataclass +class DerivedVariableRegistryEntry: + func: Callable[[ClimateData, SigmaCoordinates], torch.Tensor] + + +_DERIVED_VARIABLE_REGISTRY: MutableMapping[str, DerivedVariableRegistryEntry] = {} + + +@curry +def register( + func: Callable[[ClimateData, SigmaCoordinates], torch.Tensor], +): + """Decorator for registering a function that computes a derived variable.""" + label = func.__name__ + if label in _DERIVED_VARIABLE_REGISTRY: + raise ValueError(f"Function {label} has already been added to registry.") + _DERIVED_VARIABLE_REGISTRY[label] = DerivedVariableRegistryEntry(func=func) + return func + + +@register() +def surface_pressure_due_to_dry_air(data: ClimateData, sigma_coordinates: SigmaCoordinates) -> torch.Tensor: + return metrics.surface_pressure_due_to_dry_air( + data.specific_total_water, + data.surface_pressure, + sigma_coordinates.ak, + sigma_coordinates.bk, + ) + + +@register() +def total_water_path(data: ClimateData, sigma_coordinates: SigmaCoordinates) -> torch.Tensor: + return metrics.vertical_integral( + data.specific_total_water, + data.surface_pressure, + sigma_coordinates.ak, + sigma_coordinates.bk, + ) + + +@register() +def total_water_path_budget_residual(data: ClimateData, sigma_coordinates: SigmaCoordinates): + total_water_path = metrics.vertical_integral( + data.specific_total_water, + data.surface_pressure, + sigma_coordinates.ak, + sigma_coordinates.bk, + ) + twp_total_tendency = (total_water_path[:, 1:] - total_water_path[:, :-1]) / (TIMESTEP_SECONDS) + twp_budget_residual = torch.zeros_like(total_water_path) + # no budget residual on initial step + twp_budget_residual[:, 1:] = twp_total_tendency - ( + data.evaporation_rate[:, 1:] + - data.precipitation_rate[:, 1:] + + data.tendency_of_total_water_path_due_to_advection[:, 1:] + ) + return twp_budget_residual + + +def _compute_derived_variable( + data: Dict[str, torch.Tensor], + sigma_coordinates: SigmaCoordinates, + label: str, + derived_variable: DerivedVariableRegistryEntry, +) -> Dict[str, torch.Tensor]: + """Computes a derived variable and adds it to the given data. + + If the required input data is not available, a warning will be logged and + no change will be made to the data. + + Args: + data: dictionary of data add the derived variable to. + sigma_coordinates: the vertical coordinate. + label: the name of the derived variable. + derived_variable: class indicating required names and function to compute. + + Returns: + A new SteppedData instance with the derived variable added. + + Note: + Derived variables are only computed for the denormalized data in stepped. + """ + if label in data: + raise ValueError( + f"Variable {label} already exists. It is not permitted " + "to overwrite existing variables with derived variables." + ) + new_data = data.copy() + climate_data = ClimateData(data) + try: + output = derived_variable.func(climate_data, sigma_coordinates) + except KeyError as key_error: + logging.warning(f"Could not compute {label} because {key_error} is missing") + else: # if no exception was raised + new_data[label] = output + return new_data + + +def compute_derived_quantities( + data: Dict[str, torch.Tensor], + sigma_coordinates: SigmaCoordinates, + registry: MutableMapping[str, DerivedVariableRegistryEntry] = _DERIVED_VARIABLE_REGISTRY, +) -> Dict[str, torch.Tensor]: + """Computes all derived quantities from the given data.""" + + for label, derived_variable in registry.items(): + data = _compute_derived_variable(data, sigma_coordinates, label, derived_variable) + return data + + +def compute_stepped_derived_quantities( + stepped: SteppedData, + sigma_coordinates: SigmaCoordinates, + registry: MutableMapping[str, DerivedVariableRegistryEntry] = _DERIVED_VARIABLE_REGISTRY, +) -> SteppedData: + stepped.gen_data = compute_derived_quantities(stepped.gen_data, sigma_coordinates, registry) + stepped.target_data = compute_derived_quantities(stepped.target_data, sigma_coordinates, registry) + return stepped diff --git a/src/ace_inference/inference/gcs_utils.py b/src/ace_inference/inference/gcs_utils.py new file mode 100644 index 0000000..fb905dc --- /dev/null +++ b/src/ace_inference/inference/gcs_utils.py @@ -0,0 +1,20 @@ +import logging +import os +import subprocess + +from src.ace_inference.core.distributed import Distributed + + +KEYFILE_VAR = "GOOGLE_APPLICATION_CREDENTIALS" + + +def authenticate(keyfile_var: str = KEYFILE_VAR): + dist = Distributed.get_instance() + if dist.is_root(): + try: + keyfile = os.environ[keyfile_var] + except KeyError: + logging.info("No keyfile found in environment, skipping gcloud authentication.") + else: + output = subprocess.check_output(["gcloud", "auth", "activate-service-account", "--key-file", keyfile]) + logging.info(output.decode("utf-8")) diff --git a/src/ace_inference/inference/inference.py b/src/ace_inference/inference/inference.py new file mode 100755 index 0000000..fbf61ed --- /dev/null +++ b/src/ace_inference/inference/inference.py @@ -0,0 +1,333 @@ +import argparse +import dataclasses +import os +import time +import warnings +from pathlib import Path +from typing import Any, Dict, Optional, Sequence, Union + +import dacite +import torch +import tqdm.auto as tqdm +import wandb +import yaml + +from src.ace_inference.core.aggregator.inference.main import InferenceAggregator +from src.ace_inference.core.aggregator.null import NullAggregator +from src.ace_inference.core.data_loading.data_typing import GriddedData +from src.ace_inference.core.data_loading.getters import get_inference_data +from src.ace_inference.core.data_loading.inference import InferenceDataLoaderParams +from src.ace_inference.core.device import get_device +from src.ace_inference.core.dicts import to_flat_dict +from src.ace_inference.core.stepper import SingleModuleStepper +from src.ace_inference.core.stepper_multistep import MultiStepStepper +from src.ace_inference.core.wandb import WandB +from src.ace_inference.inference import gcs_utils, logging_utils +from src.ace_inference.inference.data_writer.main import DataWriter, DataWriterConfig +from src.ace_inference.inference.loop import run_dataset_inference, run_inference +from src.utilities.checkpointing import local_path_to_absolute_and_download_if_needed +from src.utilities.utils import get_logger +from src.utilities.wandb_api import restore_model_from_wandb_cloud + + +logging = get_logger(__name__) +device = get_device() + + +def load_stepper( + checkpoint_file: str, overrides: Dict[str, Any] = None, area=None +) -> Union[SingleModuleStepper, MultiStepStepper]: + checkpoint_file = local_path_to_absolute_and_download_if_needed(checkpoint_file) + checkpoint = torch.load(checkpoint_file, map_location=device, weights_only=False) + # checkpoint['hyper_parameters'].pop('prediction_mode', None) + # checkpoint["hyper_parameters"]["diffusion_config"]["_target_"] = checkpoint["hyper_parameters"]["diffusion_config"]["_target_"].replace("DYffusionMultiHorizonWithPretrainedInterpolator", "DYffusion") + # checkpoint["hyper_parameters"]["diffusion_config"].pop("is_parametric", None) + # checkpoint["hyper_parameters"]["diffusion_config"].pop("prediction_mode", None) + # checkpoint["hyper_parameters"]["diffusion_config"].pop("use_mean_of_parametric_predictions", None) + # checkpoint["hyper_parameters"]["model_config"]["dropout_filter"] = 0 + # print(f"{checkpoint['hyper_parameters']}") + # torch.save(checkpoint, checkpoint_file) + # torch.save(checkpoint, checkpoint_file.replace('.ckpt', '_cleaned.ckpt')) + epoch = checkpoint["epoch"] + if wandb.run is None: + pass # wandb is not being used + elif "wandb" in checkpoint.keys(): + # wandb.run.group = checkpoint['wandb']['group']: NOT POSSIBLE + wandb.run._run_obj.run_group = checkpoint["wandb"]["group"] + wandb.run.name = checkpoint["wandb"]["name"] + f"-{epoch}epoch" + run_current = wandb.Api().run(wandb.run.path) + run_current.group = checkpoint["wandb"]["group"] + run_current.update() + else: + wandb.run.name = f"{wandb.run.name}-{epoch}epoch" + + ckpt_time = {k: checkpoint[k] for k in ["epoch", "step", "global_step"] if k in checkpoint.keys()} + if wandb.run is not None: + wandb.log(ckpt_time, step=0) + wandb.run.summary.update(ckpt_time) + + # Check if it is Spherical DYffusion model + if ("FV3GFS" in checkpoint_file and "seed" in checkpoint_file) or ".ckpt" in checkpoint_file: + stepper = MultiStepStepper.from_state(checkpoint, load_optimizer=False, overrides=overrides) + else: + assert overrides is None, "Overrides not supported for non-DYffusion models. Please set it to None." + stepper = SingleModuleStepper.from_state(checkpoint["stepper"], area) + return stepper + + +@dataclasses.dataclass +class InferenceConfig: + """ + Configuration for running inference. + + Attributes: + experiment_dir: Directory to save results to. + n_forward_steps: Number of steps to run the model forward for. Must be divisble + by forward_steps_in_memory. + checkpoint_path: Path to stepper checkpoint to load. + logging: configuration for logging. + validation_loader: Configuration for validation data. + prediction_loader: Configuration for prediction data to evaluate. If given, + model evaluation will not run, and instead predictions will be evaluated. + Model checkpoint will still be used to determine inputs and outputs. + log_video: Whether to log videos of the state evolution. + log_extended_video: Whether to log wandb videos of the predictions with + statistical metrics, only done if log_video is True. + log_extended_video_netcdfs: Whether to log videos of the predictions with + statistical metrics as netcdf files. + log_zonal_mean_images: Whether to log zonal-mean images (hovmollers) with a + time dimension. + save_prediction_files: Whether to save the predictions as a netcdf file. + save_raw_prediction_names: Names of variables to save in the predictions + netcdf file. Ignored if save_prediction_files is False. + forward_steps_in_memory: Number of forward steps to complete in memory + at a time, will load one more step for initial condition. + data_writer: Configuration for data writers. + overrides: Overrides for the re-loaded module. E.g. change the sampling behavior. Should be a dict or dict of dicts + """ + + experiment_dir: str + n_forward_steps: int + checkpoint_path: str + logging: logging_utils.LoggingConfig + validation_loader: InferenceDataLoaderParams + prediction_loader: Optional[InferenceDataLoaderParams] = None + n_ensemble_members: int = 1 + wandb_run_path: Optional[str] = None + log_video: bool = True + log_extended_video: bool = False + log_extended_video_netcdfs: Optional[bool] = None + log_zonal_mean_images: bool = True + save_prediction_files: Optional[bool] = None + save_raw_prediction_names: Optional[Sequence[str]] = None + forward_steps_in_memory: int = 1 + overrides: Optional[Dict[str, Any]] = None + data_writer: DataWriterConfig = dataclasses.field(default_factory=lambda: DataWriterConfig()) + compute_metrics: bool = True + + def __post_init__(self): + if self.n_forward_steps % self.forward_steps_in_memory != 0: + raise ValueError( + "n_forward_steps must be divisible by steps_in_memory, " + f"got {self.n_forward_steps} and {self.forward_steps_in_memory}" + ) + deprecated_writer_attrs = { + k: getattr(self, k) + for k in [ + "log_extended_video_netcdfs", + "save_prediction_files", + "save_raw_prediction_names", + ] + if getattr(self, k) is not None + } + for k, v in deprecated_writer_attrs.items(): + warnings.warn( + f"Inference configuration attribute `{k}` is deprecated. " + f"Using its value `{v}`, but please use attribute `data_writer` " + "instead." + ) + setattr(self.data_writer, k, v) + if (self.data_writer.time_coarsen is not None) and ( + self.forward_steps_in_memory % self.data_writer.time_coarsen.coarsen_factor != 0 + ): + raise ValueError( + "forward_steps_in_memory must be divisible by " + f"time_coarsen.coarsen_factor. Got {self.forward_steps_in_memory} " + f"and {self.data_writer.time_coarsen.coarsen_factor}." + ) + + def configure_logging(self, log_filename: str): + self.logging.configure_logging(self.experiment_dir, log_filename) + + def configure_wandb(self, env_vars: Optional[dict] = None, **kwargs): + config = to_flat_dict(dataclasses.asdict(self)) + if "environment" in config: + logging.warning("Not recording env vars since 'environment' is in config.") + elif env_vars is not None: + config["environment"] = env_vars + self.logging.configure_wandb(config=config, resume=False, **kwargs) + + def clean_wandb(self): + self.logging.clean_wandb(self.experiment_dir) + + def load_stepper(self, **kwargs) -> Union[SingleModuleStepper, MultiStepStepper]: + # logging.info(f"Loading trained model checkpoint from {self.checkpoint_path}") + checkpoint_file = self.checkpoint_path + if checkpoint_file.startswith("hf:"): + checkpoint_file = local_path_to_absolute_and_download_if_needed(self.checkpoint_path) + elif os.path.exists(checkpoint_file) and self.wandb_run_path is None: + print(f"Loading checkpoint from local path {checkpoint_file}") + pass + elif self.wandb_run_path is not None: + checkpoint_file = restore_model_from_wandb_cloud(self.wandb_run_path, ckpt_filename=checkpoint_file) + logging.info(f"Restored model ckpt ``{checkpoint_file}`` from wandb run path {self.wandb_run_path}.") + else: + from pathlib import Path + + # List directory contents of every directory in the path, starting from the end + # until we find a directory that exists + path = Path(checkpoint_file) + while not path.exists(): + path = path.parent + print(f"Found {path} to exist. Ls: {os.listdir(path)}") + return load_stepper(checkpoint_file, overrides=self.overrides, **kwargs) + + def get_data_writer(self, data: GriddedData) -> DataWriter: + return self.data_writer.build( + experiment_dir=self.experiment_dir, + n_samples=self.validation_loader.n_samples, + n_timesteps=self.n_forward_steps + 1, + metadata=data.metadata, + coords=data.coords, + n_ensemble_members=self.n_ensemble_members, + ) + + +def main( + yaml_config: str, +): + with open(yaml_config, "r") as f: + data = yaml.safe_load(f) + config = dacite.from_dict( + data_class=InferenceConfig, + data=data, + config=dacite.Config(strict=True), + ) + + if os.path.exists(config.experiment_dir): + # Append a timestamp to the experiment directory to avoid overwriting + config.experiment_dir += f"-{time.strftime('%Y-%m-%d-%H-%M-%S')}" + assert not os.path.exists(config.experiment_dir), f"Experiment directory {config.experiment_dir} already exists." + if not os.path.isdir(config.experiment_dir): + os.makedirs(config.experiment_dir) + with open(os.path.join(config.experiment_dir, "config.yaml"), "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False) + config.configure_logging(log_filename="inference_out.log") + env_vars = logging_utils.retrieve_env_vars() + config.configure_wandb(env_vars=env_vars) + gcs_utils.authenticate() + + torch.backends.cudnn.benchmark = True + + logging_utils.log_versions() + _ = WandB.get_instance() # wandb = WandB.get_instance() + + start_time = time.time() + stepper = config.load_stepper() + logging.info("Loading inference data") + data_requirements = stepper.get_data_requirements(n_forward_steps=config.n_forward_steps) + + data = get_inference_data( + config.validation_loader, + config.forward_steps_in_memory, + data_requirements, + ) + + eval_device = device # 'cuda' #'cpu' if config.n_ensemble_members > 1 else 'cuda' + aggregator = ( + InferenceAggregator( + data.area_weights.to(device), + sigma_coordinates=data.sigma_coordinates, + record_step_20=config.n_forward_steps >= 20, + log_video=config.log_video, + enable_extended_videos=config.log_extended_video, + log_zonal_mean_images=config.log_zonal_mean_images, + n_timesteps=config.n_forward_steps + 1, + metadata=data.metadata, + n_ensemble_members=config.n_ensemble_members, + device=eval_device, + ) + if config.compute_metrics + else NullAggregator() + ) + writer = config.get_data_writer(data) if config.compute_metrics else None + + logging.info("Starting inference") + if config.prediction_loader is not None: + prediction_data = get_inference_data( + config.prediction_loader, + config.forward_steps_in_memory, + data_requirements, + ) + + timers = run_dataset_inference( + aggregator=aggregator, + normalizer=stepper.normalizer, + prediction_data=prediction_data, + target_data=data, + n_forward_steps=config.n_forward_steps, + forward_steps_in_memory=config.forward_steps_in_memory, + writer=writer, + ) + else: + timers = run_inference( + aggregator=aggregator, + writer=writer, + stepper=stepper, + data=data, + n_forward_steps=config.n_forward_steps, + forward_steps_in_memory=config.forward_steps_in_memory, + n_ensemble_members=config.n_ensemble_members, + eval_device=eval_device, + ) + + duration = time.time() - start_time + total_steps = config.n_forward_steps * config.validation_loader.n_samples + total_steps_per_second = total_steps / duration + logging.info(f"Inference duration: {duration:.2f} seconds") + logging.info(f"Total steps per second: {total_steps_per_second:.2f} steps/second") + + step_logs = aggregator.get_inference_logs(label="inference") + tqdm_bar = tqdm.tqdm(step_logs, desc="Logging inference results to wandb") + wandb = WandB.get_instance() + duration_logs = { + "duration_seconds": duration, + "time/inference": duration, + "total_steps_per_second": total_steps_per_second, + } + wandb.log({**timers, **duration_logs}, step=0) + for i, log in enumerate(tqdm_bar): + log["timestep"] = i + wandb.log(log, step=i) + # wandb.log cannot be called more than "a few times per second" + time.sleep(0.005) + writer.flush() + + logging.info("Writing reduced metrics to disk in netcdf format.") + aggregators_to_save = ["time_mean"] # , "zonal_mean"] + for name, ds in aggregator.get_datasets(aggregators_to_save).items(): + coords = {k: v for k, v in data.coords.items() if k in ds.dims} + ds = ds.assign_coords(coords) + ds.to_netcdf(Path(config.experiment_dir) / f"{name}_diagnostics.nc") + + # config.clean_wandb() + return step_logs + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("yaml_config", type=str) + + args = parser.parse_args() + + main(yaml_config=args.yaml_config) diff --git a/src/ace_inference/inference/logging_utils.py b/src/ace_inference/inference/logging_utils.py new file mode 100644 index 0000000..3b12814 --- /dev/null +++ b/src/ace_inference/inference/logging_utils.py @@ -0,0 +1,150 @@ +import contextlib +import dataclasses +import logging +import os +from typing import Any, Dict, Mapping, Optional, Union + +from src.ace_inference.core.distributed import Distributed +from src.ace_inference.core.wandb import WandB + + +ENV_VAR_NAMES = ( + "BEAKER_EXPERIMENT_ID", + "SLURM_JOB_ID", + "SLURM_JOB_USER", + "FME_TRAIN_DIR", + "FME_VALID_DIR", + "FME_STATS_DIR", + "FME_CHECKPOINT_DIR", + "FME_OUTPUT_DIR", + "FME_IMAGE", +) + + +@dataclasses.dataclass +class LoggingConfig: + """ + Configuration for logging. + + Attributes: + project: name of the project in Weights & Biases + entity: name of the entity in Weights & Biases + log_to_screen: whether to log to the screen + log_to_file: whether to log to a file + log_to_wandb: whether to log to Weights & Biases + log_format: format of the log messages + """ + + project: str = "Spherical-DYffusion" + entity: Optional[str] = None # defaults to the user + log_to_screen: bool = True + log_to_file: bool = True + log_to_wandb: bool = True + log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + level: Union[str, int] = logging.INFO + + def __post_init__(self): + self._dist = Distributed.get_instance() + + def configure_logging(self, experiment_dir: str, log_filename: str): + """ + Configure the global `logging` module based on this LoggingConfig. + """ + if self.log_to_screen and self._dist.is_root(): + logging.basicConfig(format=self.log_format, level=self.level) + elif self._dist.is_root(): + logging.basicConfig(level=logging.WARNING) + else: # we are not root + logging.basicConfig(level=logging.ERROR) + logger = logging.getLogger() + if self.log_to_file and self._dist.is_root(): + if not os.path.exists(experiment_dir): + raise ValueError(f"experiment directory {experiment_dir} does not exist, " "cannot log files to it") + log_path = os.path.join(experiment_dir, log_filename) + fh = logging.FileHandler(log_path) + fh.setLevel(self.level) + fh.setFormatter(logging.Formatter(self.log_format)) + logger.addHandler(fh) + + def configure_wandb( + self, + config: Mapping[str, Any], + env_vars: Optional[Mapping[str, Any]] = None, + **kwargs, + ): + config_copy = {**config} + if "environment" in config_copy: + logging.warning( + "Not recording environmental variables since 'environment' key is " "already present in config." + ) + elif env_vars is not None: + config_copy["environment"] = env_vars + # must ensure wandb.configure is called before wandb.init + wandb = WandB.get_instance() + wandb.configure(log_to_wandb=self.log_to_wandb) + wandb.init( + config=config_copy, + project=self.project, + entity=self.entity, + dir=config["experiment_dir"], + **kwargs, + ) + + def clean_wandb(self, experiment_dir: str): + wandb = WandB.get_instance() + wandb.clean_wandb_dir(experiment_dir=experiment_dir) + + +def log_versions(): + import torch + + logging.info("--------------- Versions ---------------") + logging.info("Torch: " + str(torch.__version__)) + logging.info("----------------------------------------") + + +def retrieve_env_vars(names=ENV_VAR_NAMES) -> Dict[str, str]: + """Return a dictionary of specific environmental variables.""" + output = {} + for name in names: + try: + value = os.environ[name] + except KeyError: + pass # logging.warning(f"Environmental variable {name} not found.") + else: + output[name] = value + logging.info(f"Environmental variable {name}={value}.") + return output + + +def log_beaker_url(beaker_id=None): + """Log the Beaker ID and URL for the current experiment. + + beaker_id: The Beaker ID of the experiment. If None, uses the env variable + `BEAKER_EXPERIMENT_ID`. + + Returns the Beaker URL. + """ + if beaker_id is None: + try: + beaker_id = os.environ["BEAKER_EXPERIMENT_ID"] + except KeyError: + logging.warning("Beaker Experiment ID not found.") + return None + + beaker_url = f"https://beaker.org/ex/{beaker_id}" + logging.info(f"Beaker ID: {beaker_id}") + logging.info(f"Beaker URL: {beaker_url}") + return beaker_url + + +@contextlib.contextmanager +def log_level(level): + """Temporarily set the log level of the global logger.""" + logger = logging.getLogger() # presently, data loading uses the root logger + old_level = logger.getEffectiveLevel() + try: + logger.setLevel(level) + yield + finally: + logger.setLevel(old_level) diff --git a/src/ace_inference/inference/loop.py b/src/ace_inference/inference/loop.py new file mode 100644 index 0000000..44896e5 --- /dev/null +++ b/src/ace_inference/inference/loop.py @@ -0,0 +1,326 @@ +from __future__ import annotations + +import logging +import time +from collections import defaultdict +from typing import Any, Dict, Mapping, Optional, Union + +import numpy as np +import torch +import xarray as xr + +from src.ace_inference.core.aggregator.inference.main import InferenceAggregator +from src.ace_inference.core.aggregator.null import NullAggregator +from src.ace_inference.core.data_loading.data_typing import GriddedData +from src.ace_inference.core.device import get_device +from src.ace_inference.core.optimization import NullOptimization +from src.ace_inference.core.stepper import SingleModuleStepper, SteppedData +from src.ace_inference.inference.data_writer.main import DataWriter, NullDataWriter +from src.ace_inference.inference.derived_variables import ( + compute_derived_quantities, + # compute_stepped_derived_quantities, +) +from src.utilities.normalization import StandardNormalizer + + +class WindowStitcher: + """ + Handles stitching together the windows of data from the inference loop. + + For example, handles passing in windows to data writers which combine + them together into a continuous series, and handles storing prognostic + variables from the end of a window to use as the initial condition for + the next window. + """ + + def __init__( + self, + n_forward_steps: int, + writer: Union[DataWriter, NullDataWriter], + is_ensemble: bool = False, + ): + self.i_time = 0 + self.n_forward_steps = n_forward_steps + self.writer = writer + self.is_ensemble = is_ensemble + # tensors have shape [n_sample, n_lat, n_lon] with no time axis + self._initial_condition: Optional[Mapping[str, torch.Tensor]] = None + + def append( + self, + data: Dict[str, torch.tensor], + gen_data: Dict[str, torch.tensor], + batch_times: xr.DataArray, + ) -> None: + """ + Appends a time segment of data to the ensemble batch. + + Args: + data: The reference data for the current time segment, tensors + should have shape [n_sample, n_time, n_lat, n_lon] + gen_data: The generated data for the current time segment, tensors + should have shape [n_sample, n_time, n_lat, n_lon] + batch_times: Time coordinates for each sample in the batch. + """ + tensor_shape = next(data.values().__iter__()).shape + self.writer.append_batch( + target=data, + prediction=gen_data, + start_timestep=self.i_time, + start_sample=0, + batch_times=batch_times, + ) + self.i_time += tensor_shape[1] + if self.i_time < self.n_forward_steps: # only store if needed + # store the end of the time window as + # initial condition for the next segment. + self._initial_condition = {key: value[:, -1] for key, value in data.items()} + self.ensemble_keys = list(gen_data.keys()) + for key, value in gen_data.items(): + self._initial_condition[key] = value[..., -1, :, :].detach().cpu() # 3rd last dimension is time + + for key, value in self._initial_condition.items(): + self._initial_condition[key] = value.detach().cpu() + + def apply_initial_condition( + self, + data: Mapping[str, torch.Tensor], + ensemble_member: int = None, + ): + """ + Applies the last recorded state of the batch as the initial condition for + the next segment of the timeseries. + + Args: + data: The data to apply the initial condition to, tensors should have + shape [n_sample, n_time, n_lat, n_lon] and the first value along + the time axis will be replaced with the last value from the + previous segment. + """ + if self.i_time > self.n_forward_steps: + raise ValueError( + "Cannot apply initial condition after " + "the last segment has been appended, currently at " + f"time index {self.i_time} " + f"with {self.n_forward_steps} max forward steps." + ) + if ensemble_member is not None: + assert self.is_ensemble, "Cannot apply initial condition for ensemble member > 0 if not ensemble" + if self.is_ensemble: + assert ensemble_member is not None, "Must specify ensemble member to apply initial condition for ensemble" + + if self._initial_condition is not None: + for key, value in data.items(): + ic = self._initial_condition[key].to(value.device) + if self.is_ensemble and key in self.ensemble_keys: + ic = ic[ensemble_member, ...] + value[:, 0] = ic + + +def _inference_internal_loop( + stepped: SteppedData, + i_time: int, + aggregator: InferenceAggregator, + stitcher: WindowStitcher, + batch_times: xr.DataArray, +): + """Do operations that need to be done on each time step of the inference loop. + + This function exists to de-duplicate code between run_inference and + run_data_inference.""" + + # for non-initial windows, we want to record only the new data + # and discard the initial sample of the window + if i_time > 0: + stepped = stepped.remove_initial_condition() + batch_times = batch_times.isel(time=slice(1, None)) + i_time_aggregator = i_time + 1 + else: + i_time_aggregator = i_time + # record raw data for the batch, and store the final state + # for the next segment + stitcher.append(stepped.target_data, stepped.gen_data, batch_times) + # record metrics + aggregator.record_batch( + loss=float(stepped.metrics["loss"]), + target_data=stepped.target_data, + gen_data=stepped.gen_data, + target_data_norm=stepped.target_data_norm, + gen_data_norm=stepped.gen_data_norm, + i_time_start=i_time_aggregator, + ) + + +def _to_device(data: Mapping[str, torch.Tensor], device: torch.device) -> Dict[str, Any]: + return {key: value.to(device) for key, value in data.items()} + + +def run_inference( + aggregator: InferenceAggregator, + stepper: SingleModuleStepper, + data: GriddedData, + n_forward_steps: int, + forward_steps_in_memory: int, + n_ensemble_members: int, + eval_device: torch.device | str, + writer: Optional[Union[DataWriter, NullDataWriter]] = None, +) -> Dict[str, float]: + if writer is None: + writer = NullDataWriter() + stitcher = WindowStitcher(n_forward_steps, writer, is_ensemble=n_ensemble_members > 1) + + not_compute_metrics = isinstance(aggregator, NullAggregator) + with torch.no_grad(): + stepper.module.eval() + # We have data batches with long windows, where all data for a + # given batch does not fit into memory at once, so we window it in time + # and run the model on each window in turn. + # + # We process each time window and keep track of the + # final state. We then use this as the initial condition + # for the next time window. + device = get_device() + logging.info(f"Running inference on {n_forward_steps} steps, with {n_ensemble_members} ensemble members") + timers: Dict[str, float] = defaultdict(float) + current_time = time.time() + for i, window_batch_data in enumerate(data.loader): + timers["data_loading"] += time.time() - current_time + current_time = time.time() + i_time = i * forward_steps_in_memory + logging.info( + f"Inference: starting window spanning {i_time}" + f" to {i_time + forward_steps_in_memory} steps, " + f"out of total {n_forward_steps}." + ) + window_data = _to_device(window_batch_data.data, device) + + target_data = compute_derived_quantities(window_data, data.sigma_coordinates) + metrics, gen_data, gen_data_norm = defaultdict(list), [], [] + for ens_mem in range(n_ensemble_members): + print(f"Ensemble member {ens_mem}") + stitcher.apply_initial_condition( + window_data, ensemble_member=ens_mem if n_ensemble_members > 1 else None + ) + stepped = stepper.run_on_batch( + window_data, + NullOptimization(), + n_forward_steps=forward_steps_in_memory, + ) + + if not_compute_metrics: + gen_data.append({key: value.detach() for key, value in stepped.gen_data.items()}) + gen_data_norm.append({key: value.detach() for key, value in stepped.gen_data_norm.items()}) + else: + for k, v in stepped.metrics.items(): + metrics[k].append(float(v.detach().cpu())) + gen_data.append({key: value.detach().cpu() for key, value in stepped.gen_data.items()}) + gen_data_norm.append({key: value.detach().cpu() for key, value in stepped.gen_data_norm.items()}) + + if n_ensemble_members == 1: + stepped = stepped + else: + # Stack the ensemble members into a single tensor (first dimension) + ensemble_dim = 0 + if not_compute_metrics: + metrics = None + else: + metrics = {key: np.mean(value) for key, value in metrics.items()} + gen_data = { + key: torch.stack([value[key] for value in gen_data], dim=ensemble_dim) + for key in gen_data[0].keys() + } + gen_data_norm = { + key: torch.stack([value[key] for value in gen_data_norm], dim=ensemble_dim) + for key in gen_data_norm[0].keys() + } + stepped = SteppedData( + metrics=metrics, + target_data=stepped.target_data, + gen_data=gen_data, + target_data_norm=stepped.target_data_norm, + gen_data_norm=gen_data_norm, + ) + + stepped.target_data = target_data + stepped.gen_data = compute_derived_quantities(stepped.gen_data, data.sigma_coordinates) + stepped.gen_data = _to_device(stepped.gen_data, device) + stepped.gen_data_norm = _to_device(stepped.gen_data_norm, device) + stepped.target_data_norm = _to_device(stepped.target_data_norm, device) + timers["run_on_batch"] += time.time() - current_time + current_time = time.time() + _inference_internal_loop( + stepped, + i_time, + aggregator, + stitcher, + window_batch_data.times, + ) + del stepped + timers["writer_and_aggregator"] += time.time() - current_time + current_time = time.time() + + for name, duration in timers.items(): + print(f"{name} duration: {duration:.2f}s") + return timers + + +def remove_initial_condition(data: Mapping[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return {key: value[:, 1:] for key, value in data.items()} + + +def run_dataset_inference( + aggregator: InferenceAggregator, + normalizer: StandardNormalizer, + prediction_data: GriddedData, + target_data: GriddedData, + n_forward_steps: int, + forward_steps_in_memory: int, + writer: Optional[Union[DataWriter, NullDataWriter]] = None, +) -> Dict[str, float]: + if writer is None: + writer = NullDataWriter() + stitcher = WindowStitcher(n_forward_steps, writer) + + device = get_device() + # We have data batches with long windows, where all data for a + # given batch does not fit into memory at once, so we window it in time + # and run the model on each window in turn. + # + # We process each time window and keep track of the + # final state. We then use this as the initial condition + # for the next time window. + timers: Dict[str, float] = defaultdict(float) + current_time = time.time() + for i, (pred, target) in enumerate(zip(prediction_data.loader, target_data.loader)): + timers["data_loading"] += time.time() - current_time + current_time = time.time() + i_time = i * forward_steps_in_memory + logging.info( + f"Inference: starting window spanning {i_time}" + f" to {i_time + forward_steps_in_memory} steps," + f" out of total {n_forward_steps}." + ) + pred_window_data = _to_device(pred.data, device) + target_window_data = _to_device(target.data, device) + stepped = SteppedData( + {"loss": torch.tensor(float("nan"))}, + pred_window_data, + target_window_data, + normalizer.normalize(pred_window_data), + normalizer.normalize(target_window_data), + ) + stepped = compute_stepped_derived_quantities(stepped, target_data.sigma_coordinates) + timers["run_on_batch"] += time.time() - current_time + current_time = time.time() + _inference_internal_loop( + stepped, + i_time, + aggregator, + stitcher, + target.times, + ) + timers["writer_and_aggregator"] += time.time() - current_time + current_time = time.time() + for name, duration in timers.items(): + logging.info(f"{name} duration: {duration:.2f}s") + return timers diff --git a/src/ace_inference/training/__init__.py b/src/ace_inference/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ace_inference/training/registry.py b/src/ace_inference/training/registry.py new file mode 100644 index 0000000..0ff31a9 --- /dev/null +++ b/src/ace_inference/training/registry.py @@ -0,0 +1,197 @@ +import dataclasses +from typing import Any, Literal, Mapping, Optional, Protocol, Tuple, Type + +import torch_harmonics as harmonics + +# this package is installed in models/FourCastNet +# from fourcastnet.networks.afnonet import AFNONetBuilder +from torch import nn + + +class ModuleConfig(Protocol): + """ + A protocol for a class that can build a nn.Module given information about the input + and output channels and the image shape. + + This is a "Config" as in practice it is a dataclass loaded directly from yaml, + allowing us to specify details of the network architecture in a config file. + """ + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ) -> nn.Module: + """ + Build a nn.Module given information about the input and output channels + and the image shape. + + Args: + n_in_channels: number of input channels + n_out_channels: number of output channels + img_shape: last two dimensions of data, corresponding to lat and + lon when using FourCastNet conventions + + Returns: + a nn.Module + """ + ... + + +# this is based on the call signature of SphericalFourierNeuralOperatorNet at +# https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L292 # noqa: E501 +@dataclasses.dataclass +class SphericalFourierNeuralOperatorBuilder(ModuleConfig): + spectral_transform: str = "sht" + filter_type: str = "non-linear" + operator_type: str = "diagonal" + scale_factor: int = 16 + embed_dim: int = 256 + num_layers: int = 12 + num_blocks: int = 16 + hard_thresholding_fraction: float = 1.0 + normalization_layer: str = "instance_norm" + use_mlp: bool = True + activation_function: str = "gelu" + encoder_layers: int = 1 + pos_embed: bool = True + big_skip: bool = True + rank: float = 1.0 + factorization: Optional[str] = None + separable: bool = False + complex_network: bool = True + complex_activation: str = "real" + spectral_layers: int = 1 + checkpointing: int = 0 + data_grid: Literal["legendre-gauss", "equiangular"] = "legendre-gauss" + drop_rate: float = 0.0 + drop_path_rate: float = 0.0 + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ): + from modulus.models.sfno.sfnonet import SphericalFourierNeuralOperatorNet + + sfno_net = SphericalFourierNeuralOperatorNet( + params=self, + in_chans=n_in_channels, + out_chans=n_out_channels, + img_shape=img_shape, + drop_rate=self.drop_rate, + drop_path_rate=self.drop_path_rate, + ) + + # Patch in the grid that our data lies on rather than the one which is + # hard-coded in the modulus codebase [1]. Duplicate the code to compute + # the number of SHT modes determined by hard_thresholding_fraction. Note + # that this does not handle the distributed case which is handled by + # L518 [2] in their codebase. + + # [1] https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py # noqa: E501 + # [2] https://github.com/NVIDIA/modulus/blob/b8e27c5c4ebc409e53adaba9832138743ede2785/modulus/models/sfno/sfnonet.py#L518 # noqa: E501 + nlat, nlon = img_shape + modes_lat = int(nlat * self.hard_thresholding_fraction) + modes_lon = int((nlon // 2 + 1) * self.hard_thresholding_fraction) + sht = harmonics.RealSHT(nlat, nlon, lmax=modes_lat, mmax=modes_lon, grid=self.data_grid).float() + isht = harmonics.InverseRealSHT(nlat, nlon, lmax=modes_lat, mmax=modes_lon, grid=self.data_grid).float() + + sfno_net.trans_down = sht + sfno_net.itrans_up = isht + + return sfno_net + + +@dataclasses.dataclass +class PreBuiltBuilder(ModuleConfig): + """ + A simple module configuration which returns a pre-defined module. + + Used mainly for testing. + """ + + module: nn.Module + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ) -> nn.Module: + return self.module + + +NET_REGISTRY: Mapping[str, Type[ModuleConfig]] = { + # "afno": AFNONetBuilder, # using short acronym for backwards compatibility + "SphericalFourierNeuralOperatorNet": SphericalFourierNeuralOperatorBuilder, # type: ignore # noqa: E501 + "prebuilt": PreBuiltBuilder, +} + + +@dataclasses.dataclass +class ModuleSelector: + """ + A dataclass containing all the information needed to build a ModuleConfig, + including the type of the ModuleConfig and the data needed to build it. + + This is helpful as ModuleSelector can be serialized and deserialized + without any additional information, whereas to load a ModuleConfig you + would need to know the type of the ModuleConfig being loaded. + + It is also convenient because ModuleSelector is a single class that can be + used to represent any ModuleConfig, whereas ModuleConfig is a protocol + that can be implemented by many different classes. + + Attributes: + type: the type of the ModuleConfig + config: data for a ModuleConfig instance of the indicated type + """ + + type: Literal[ + "afno", + "SphericalFourierNeuralOperatorNet", + "prebuilt", + ] + config: Mapping[str, Any] + + def build( + self, + n_in_channels: int, + n_out_channels: int, + img_shape: Tuple[int, int], + ) -> nn.Module: + """ + Build a nn.Module given information about the input and output channels + and the image shape. + + Args: + n_in_channels: number of input channels + n_out_channels: number of output channels + img_shape: last two dimensions of data, corresponding to lat and + lon when using FourCastNet conventions + + Returns: + a nn.Module + """ + return NET_REGISTRY[self.type](**self.config).build( + n_in_channels=n_in_channels, + n_out_channels=n_out_channels, + img_shape=img_shape, + ) + + def get_state(self) -> Mapping[str, Any]: + """ + Get a dictionary containing all the information needed to build a ModuleConfig. + """ + return {"type": self.type, "config": self.config} + + @classmethod + def from_state(cls, state: Mapping[str, Any]) -> "ModuleSelector": + """ + Create a ModuleSelector from a dictionary containing all the information + needed to build a ModuleConfig. + """ + return cls(**state) diff --git a/src/ace_inference/training/train.py b/src/ace_inference/training/train.py new file mode 100644 index 0000000..834798e --- /dev/null +++ b/src/ace_inference/training/train.py @@ -0,0 +1,418 @@ +# BSD 3-Clause License +# +# Copyright (c) 2022, FourCastNet authors +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# The code was authored by the following people: +# +# Jaideep Pathak - NVIDIA Corporation +# Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory +# Peter Harrington - NERSC, Lawrence Berkeley National Laboratory +# Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory +# Ashesh Chattopadhyay - Rice University +# Morteza Mardani - NVIDIA Corporation +# Thorsten Kurth - NVIDIA Corporation +# David Hall - NVIDIA Corporation +# Zongyi Li - California Institute of Technology, NVIDIA Corporation +# Kamyar Azizzadenesheli - Purdue University +# Pedram Hassanzadeh - Rice University +# Karthik Kashinath - NVIDIA Corporation +# Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation + +import argparse +import contextlib +import dataclasses +import logging +import os +import time +from typing import Optional + +import dacite +import fme +import torch +import yaml +from fme.core.aggregator import InferenceAggregator, OneStepAggregator, TrainAggregator +from fme.core.aggregator.null import NullAggregator +from fme.core.data_loading.getters import get_data_loader, get_inference_data +from fme.core.distributed import Distributed +from fme.core.optimization import NullOptimization +from fme.core.wandb import WandB +from fme.training.inference import run_inference +from fme.training.inference.derived_variables import ( + compute_stepped_derived_quantities, +) +from fme.training.train_config import TrainConfig +from fme.training.utils import gcs_utils, logging_utils + + +class Trainer: + def count_parameters(self): + parameters = 0 + for module in self.stepper.modules: + for parameter in module.parameters(): + if parameter.requires_grad: + parameters += parameter.numel() + return parameters + + def __init__(self, config: TrainConfig): + self.dist = Distributed.get_instance() + if self.dist.is_root(): + if not os.path.isdir(config.experiment_dir): + os.makedirs(config.experiment_dir) + if not os.path.isdir(config.checkpoint_dir): + os.makedirs(config.checkpoint_dir) + self.config = config + + data_requirements = config.stepper.get_data_requirements(n_forward_steps=self.config.n_forward_steps) + logging.info("rank %d, begin data loader init" % self.dist.rank) + self.train_data = get_data_loader( + config.train_loader, + requirements=data_requirements, + train=True, + ) + self.valid_data = get_data_loader( + config.validation_loader, + requirements=data_requirements, + train=False, + ) + logging.info("rank %d, data loader initialized" % self.dist.rank) + for gridded_data, name in zip((self.train_data, self.valid_data), ("train", "valid")): + n_samples = len(gridded_data.loader.dataset) + n_batches = len(gridded_data.loader) + logging.info(f"{name} data: {n_samples} samples, {n_batches} batches") + + self.num_batches_seen = 0 + self.startEpoch = 0 + + self._model_epoch = self.startEpoch + self.num_batches_seen = 0 + + for batch in self.train_data.loader: + shapes = {k: v.shape for k, v in batch.data.items()} + for value in shapes.values(): + img_shape = value[-2:] + break + break + logging.info("Starting model initialization") + self.stepper = config.stepper.get_stepper( + img_shape=img_shape, + area=self.train_data.area_weights, + sigma_coordinates=self.train_data.sigma_coordinates, + ) + self.optimization = config.optimization.build(self.stepper.module.parameters(), config.max_epochs) + self._base_weights = self.config.stepper.get_base_weights() + self._copy_after_batch = config.copy_weights_after_batch + self._no_optimization = NullOptimization() + + if config.resuming: + logging.info("Loading checkpoint %s" % config.latest_checkpoint_path) + self.restore_checkpoint(config.latest_checkpoint_path) + + wandb = WandB.get_instance() + wandb.watch(self.stepper.modules) + + logging.info(f"Number of trainable model parameters: {self.count_parameters()}") + inference_data_requirements = dataclasses.replace(data_requirements) + inference_data_requirements.n_timesteps = config.inference.n_forward_steps + 1 + + self._inference_data = get_inference_data( + config.inference.loader, + config.inference.forward_steps_in_memory, + inference_data_requirements, + ) + + self._ema = self.config.ema.build(self.stepper.modules) + + def switch_off_grad(self, model): + for param in model.parameters(): + param.requires_grad = False + + def train(self): + logging.info("Starting Training Loop...") + + best_valid_loss = torch.inf + best_inference_error = torch.inf + self._model_epoch = self.startEpoch + inference_epochs = list(range(0, self.config.max_epochs))[self.config.inference.epochs.slice] + if self.config.segment_epochs is None: + segment_max_epochs = self.config.max_epochs + else: + segment_max_epochs = min(self.startEpoch + self.config.segment_epochs, self.config.max_epochs) + # "epoch" describes the loop, self._model_epoch describes model weights + # needed so we can describe the loop even after weights are updated + for epoch in range(self.startEpoch, segment_max_epochs): + logging.info(f"Epoch: {epoch+1}") + if self.train_data.sampler is not None: + self.train_data.sampler.set_epoch(epoch) + + start_time = time.time() + logging.info(f"Starting training step on epoch {epoch + 1}") + train_logs = self.train_one_epoch() + train_end = time.time() + logging.info(f"Starting validation step on epoch {epoch + 1}") + valid_logs = self.validate_one_epoch() + valid_end = time.time() + if epoch in inference_epochs: + logging.info(f"Starting inference step on epoch {epoch + 1}") + inference_logs = self.inference_one_epoch() + inference_end: Optional[float] = time.time() + else: + inference_logs = {} + inference_end = None + + train_loss = train_logs["train/mean/loss"] + valid_loss = valid_logs["val/mean/loss"] + inference_error = inference_logs.get("inference/time_mean_norm/rmse/channel_mean", None) + # need to get the learning rate before stepping the scheduler + lr = self.optimization.learning_rate + self.optimization.step_scheduler(valid_loss) + + if self.dist.is_root(): + if self.config.save_checkpoint: + # checkpoint at the end of every epoch + self.save_checkpoint(self.config.latest_checkpoint_path) + if self.config.epoch_checkpoint_enabled(self._model_epoch): + self.save_checkpoint(self.config.epoch_checkpoint_path(self._model_epoch)) + if self.config.validate_using_ema: + best_checkpoint_context = self._ema_context + else: + best_checkpoint_context = contextlib.nullcontext + with best_checkpoint_context(): + if valid_loss <= best_valid_loss: + self.save_checkpoint(self.config.best_checkpoint_path) + best_valid_loss = valid_loss + if inference_error is not None and (inference_error <= best_inference_error): + self.save_checkpoint(self.config.best_inference_checkpoint_path) + best_inference_error = inference_error + with self._ema_context(): + self.save_checkpoint(self.config.ema_checkpoint_path) + + time_elapsed = time.time() - start_time + logging.info(f"Time taken for epoch {epoch + 1} is {time_elapsed} sec") + logging.info(f"Train loss: {train_loss}. Valid loss: {valid_loss}") + + logging.info("Logging to wandb") + all_logs = { + **train_logs, + **valid_logs, + **inference_logs, + **{ + "lr": lr, + "epoch": epoch, + "epoch_train_seconds": train_end - start_time, + "epoch_validation_seconds": valid_end - train_end, + "epoch_total_seconds": time_elapsed, + }, + } + if inference_end is not None: + all_logs["epoch_inference_seconds"] = inference_end - valid_end + wandb = WandB.get_instance() + wandb.log(all_logs, step=self.num_batches_seen) + if segment_max_epochs == self.config.max_epochs: + self.config.clean_wandb() + + def train_one_epoch(self): + """Train for one epoch and return logs from TrainAggregator.""" + wandb = WandB.get_instance() + aggregator = TrainAggregator() + if self.num_batches_seen == 0: + # Before training, log the loss on the first batch. + with torch.no_grad(): + batch = next(iter(self.train_data.loader)) + stepped = self.stepper.run_on_batch( + batch.data, + optimization=self._no_optimization, + n_forward_steps=self.config.n_forward_steps, + ) + + if self.config.log_train_every_n_batches > 0: + with torch.no_grad(): + metrics = { + f"batch_{name}": self.dist.reduce_mean(metric) + for name, metric in sorted(stepped.metrics.items()) + } + wandb.log(metrics, step=self.num_batches_seen) + for batch in self.train_data.loader: + stepped = self.stepper.run_on_batch( + batch.data, + self.optimization, + n_forward_steps=self.config.n_forward_steps, + aggregator=aggregator, + ) + if self._base_weights is not None: + self._copy_after_batch.apply(weights=self._base_weights, modules=self.stepper.modules) + self._ema(model=self.stepper.modules) + self.num_batches_seen += 1 + if ( + self.config.log_train_every_n_batches > 0 + and self.num_batches_seen % self.config.log_train_every_n_batches == 0 + ): + with torch.no_grad(): + metrics = { + f"batch_{name}": self.dist.reduce_mean(metric) + for name, metric in sorted(stepped.metrics.items()) + } + wandb.log(metrics, step=self.num_batches_seen) + self._model_epoch += 1 + + return aggregator.get_logs(label="train") + + @contextlib.contextmanager + def _validation_context(self): + """ + The context for running validation. + + In this context, the stepper uses the EMA model if + `self.config.validate_using_ema` is True. + """ + if self.config.validate_using_ema: + with self._ema_context(): + yield + else: + yield + + @contextlib.contextmanager + def _ema_context(self): + """ + A context where the stepper uses the EMA model. + """ + self._ema.store(parameters=self.stepper.modules.parameters()) + self._ema.copy_to(model=self.stepper.modules) + try: + yield + finally: + self._ema.restore(parameters=self.stepper.modules.parameters()) + + def validate_one_epoch(self): + aggregator = OneStepAggregator( + self.train_data.area_weights.to(fme.get_device()), + self.train_data.sigma_coordinates, + self.train_data.metadata, + ) + + with torch.no_grad(), self._validation_context(): + for batch in self.valid_data.loader: + stepped = self.stepper.run_on_batch( + batch.data, + optimization=NullOptimization(), + n_forward_steps=self.config.n_forward_steps, + aggregator=NullAggregator(), + ) + stepped = compute_stepped_derived_quantities(stepped, self.valid_data.sigma_coordinates) + aggregator.record_batch( + loss=stepped.metrics["loss"], + target_data=stepped.target_data, + gen_data=stepped.gen_data, + target_data_norm=stepped.target_data_norm, + gen_data_norm=stepped.gen_data_norm, + ) + return aggregator.get_logs(label="val") + + def inference_one_epoch(self): + record_step_20 = self.config.inference.n_forward_steps >= 20 + aggregator = InferenceAggregator( + self.train_data.area_weights.to(fme.get_device()), + self.train_data.sigma_coordinates, + record_step_20=record_step_20, + log_video=False, + log_zonal_mean_images=True, + n_timesteps=self.config.inference.n_forward_steps + 1, + enable_extended_videos=False, + metadata=self.train_data.metadata, + ) + with torch.no_grad(), self._validation_context(): + run_inference( + aggregator=aggregator, + stepper=self.stepper, + data=self._inference_data, + n_forward_steps=self.config.inference.n_forward_steps, + forward_steps_in_memory=self.config.inference.forward_steps_in_memory, + ) + logs = aggregator.get_logs(label="inference") + return logs + + def save_checkpoint(self, checkpoint_path): + torch.save( + { + "num_batches_seen": self.num_batches_seen, + "epoch": self._model_epoch, + "stepper": self.stepper.get_state(), + "optimization": self.optimization.get_state(), + }, + checkpoint_path, + ) + + def restore_checkpoint(self, checkpoint_path): + _restore_checkpoint(self, checkpoint_path) + + +def _restore_checkpoint(trainer: Trainer, checkpoint_path): + # separated into a function only to make it easier to mock + checkpoint = torch.load(checkpoint_path, map_location=fme.get_device()) + # restore checkpoint is used for finetuning as well as resuming. + # If finetuning (i.e., not resuming), restore checkpoint + # does not load optimizer state, instead uses config specified lr. + trainer.stepper.load_state(checkpoint["stepper"]) + trainer.optimization.load_state(checkpoint["optimization"]) + trainer.num_batches_seen = checkpoint["num_batches_seen"] + trainer.startEpoch = checkpoint["epoch"] + + +def main(yaml_config: str): + dist = Distributed.get_instance() + if fme.using_gpu(): + torch.backends.cudnn.benchmark = True + with open(yaml_config, "r") as f: + data = yaml.safe_load(f) + train_config: TrainConfig = dacite.from_dict( + data_class=TrainConfig, + data=data, + config=dacite.Config(strict=True), + ) + + if not os.path.isdir(train_config.experiment_dir): + os.makedirs(train_config.experiment_dir) + with open(os.path.join(train_config.experiment_dir, "config.yaml"), "w") as f: + yaml.dump(data, f, default_flow_style=False, sort_keys=False) + train_config.configure_logging(log_filename="out.log") + env_vars = logging_utils.retrieve_env_vars() + gcs_utils.authenticate() + logging_utils.log_versions() + beaker_url = logging_utils.log_beaker_url() + train_config.configure_wandb(env_vars=env_vars, resume=True, notes=beaker_url) + trainer = Trainer(train_config) + trainer.train() + logging.info("DONE ---- rank %d" % dist.rank) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--yaml_config", required=True, type=str) + + args = parser.parse_args() + main(yaml_config=args.yaml_config) diff --git a/src/ace_inference/training/train_config.py b/src/ace_inference/training/train_config.py new file mode 100644 index 0000000..34fe6a4 --- /dev/null +++ b/src/ace_inference/training/train_config.py @@ -0,0 +1,253 @@ +import dataclasses +import logging +import os +import warnings +from typing import Any, Mapping, Optional, Union + +from fme.core import SingleModuleStepperConfig +from fme.core.data_loading.inference import InferenceDataLoaderParams +from fme.core.data_loading.params import DataLoaderParams, Slice +from fme.core.dicts import to_flat_dict +from fme.core.distributed import Distributed +from fme.core.ema import EMATracker +from fme.core.optimization import OptimizationConfig +from fme.core.stepper import ExistingStepperConfig +from fme.core.wandb import WandB +from fme.core.weight_ops import CopyWeightsConfig + + +@dataclasses.dataclass +class LoggingConfig: + project: str = "training" + entity: str = "ai2cm" + log_to_screen: bool = True + log_to_file: bool = True + log_to_wandb: bool = True + log_format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + def __post_init__(self): + self._dist = Distributed.get_instance() + + def configure_logging(self, experiment_dir: str, log_filename: str): + """ + Configure the global `logging` module based on this LoggingConfig. + """ + if self.log_to_screen and self._dist.is_root(): + logging.basicConfig(format=self.log_format, level=logging.INFO) + elif self._dist.is_root(): + logging.basicConfig(level=logging.WARNING) + else: # we are not root + logging.basicConfig(level=logging.ERROR) + logger = logging.getLogger() + if self.log_to_file and self._dist.is_root(): + if not os.path.exists(experiment_dir): + raise ValueError(f"experiment directory {experiment_dir} does not exist, " "cannot log files to it") + log_path = os.path.join(experiment_dir, log_filename) + fh = logging.FileHandler(log_path) + fh.setLevel(logging.INFO) + fh.setFormatter(logging.Formatter(self.log_format)) + logger.addHandler(fh) + + def configure_wandb(self, config: Mapping[str, Any], **kwargs): + # must ensure wandb.configure is called before wandb.init + wandb = WandB.get_instance() + wandb.configure(log_to_wandb=self.log_to_wandb) + wandb.init( + config=config, + project=self.project, + entity=self.entity, + dir=config["experiment_dir"], + **kwargs, + ) + + def clean_wandb(self, experiment_dir: str): + wandb = WandB.get_instance() + wandb.clean_wandb_dir(experiment_dir=experiment_dir) + + +@dataclasses.dataclass +class InlineInferenceConfig: + """ + Attributes: + loader: configuration for the data loader used during inference + n_forward_steps: number of forward steps to take + forward_steps_in_memory: number of forward steps to take before + re-reading data from disk + epochs: epochs on which to run inference, where the first epoch is + defined as epoch 0 (unlike in logs which show epochs as starting + from 1). By default runs inference every epoch. + """ + + loader: InferenceDataLoaderParams + n_forward_steps: int = 2 + forward_steps_in_memory: int = 2 + epochs: Optional[Slice] = None + parallel: Optional[bool] = None + + def __post_init__(self): + if self.epochs is None: + self.epochs = Slice(start=0, stop=None, step=1) + if self.n_forward_steps % self.forward_steps_in_memory != 0: + raise ValueError( + "n_forward_steps must be divisible by steps_in_memory, " + f"got {self.n_forward_steps} and {self.forward_steps_in_memory}" + ) + dist = Distributed.get_instance() + if self.loader.n_samples % dist.world_size != 0: + raise ValueError( + "batch_size must be divisible by the number of parallel " + f"workers, got {self.batch_size} and {dist.world_size}" + ) + if self.parallel is not None: + if self.parallel: + warnings.warn( + ( + "The 'parallel' argument is deprecated and will be ignored. " + "Inline inference is now always performed in parallel. " + "There's no need to specify this argument in future uses " + "of this function." + ), + category=DeprecationWarning, + ) + elif not self.parallel: + raise ValueError("parallel=False is no longer supported") + + +@dataclasses.dataclass +class EMAConfig: + """ + Configuration for exponential moving average of model weights. + + Attributes: + decay: decay rate for the moving average + """ + + decay: float = 0.9999 + + def build(self, model): + return EMATracker(model, decay=self.decay, faster_decay_at_start=True) + + +@dataclasses.dataclass +class TrainConfig: + """ + Configuration for training a model. + + Attributes: + train_loader: configuration for the training data loader + validation_loader: configuration for the validation data loader + stepper: configuration for the stepper + optimization: configuration for the optimization + logging: configuration for logging + max_epochs: total number of epochs to train for + save_checkpoint: whether to save checkpoints + experiment_dir: directory where checkpoints and logs are saved + inference: configuration for inline inference + n_forward_steps: number of forward steps to take gradient over + copy_weights_after_batch: Configuration for copying weights from the + base model to the training model after each batch. This is used + to achieve an effect of freezing model parameters that can freeze + a subset of each weight that comes from a smaller base weight. + This is less efficient than true parameter freezing, but layer + freezing is all-or-nothing for each parameter. By default, no + weights are copied. + checkpoint_save_epochs: how often to save epoch-based checkpoints, + if save_checkpoint is True. If None, checkpoints are only saved + for the most recent epoch and the best epoch. + log_train_every_n_batches: how often to log batch_loss during training + segment_epochs: (optional) exit after training for at most this many epochs + in current job, without exceeding `max_epochs`. Use this if training + must be run in segments, e.g. due to wall clock limit. + """ + + train_loader: DataLoaderParams + validation_loader: DataLoaderParams + stepper: Union[SingleModuleStepperConfig, ExistingStepperConfig] + optimization: OptimizationConfig + logging: LoggingConfig + max_epochs: int + save_checkpoint: bool + experiment_dir: str + inference: InlineInferenceConfig + n_forward_steps: int + copy_weights_after_batch: CopyWeightsConfig = dataclasses.field( + default_factory=lambda: CopyWeightsConfig(exclude=["*"]) + ) + ema: EMAConfig = dataclasses.field(default_factory=lambda: EMAConfig()) + validate_using_ema: bool = False + checkpoint_save_epochs: Optional[Slice] = None + log_train_every_n_batches: int = 100 + segment_epochs: Optional[int] = None + checkpoint_every_n_epochs: Optional[int] = None + + def __post_init__(self): + if self.checkpoint_every_n_epochs is not None: + warnings.warn( + "checkpoint_every_n_epochs is deprecated, use" "checkpoint_save_epochs instead.", + category=DeprecationWarning, + ) + self.checkpoint_save_epochs = Slice( + start=0, + stop=self.max_epochs, + step=self.checkpoint_every_n_epochs, + ) + + @property + def checkpoint_dir(self) -> str: + """ + The directory where checkpoints are saved. + """ + return os.path.join(self.experiment_dir, "training_checkpoints") + + @property + def latest_checkpoint_path(self) -> str: + return os.path.join(self.checkpoint_dir, "ckpt.tar") + + @property + def best_checkpoint_path(self) -> str: + return os.path.join(self.checkpoint_dir, "best_ckpt.tar") + + @property + def best_inference_checkpoint_path(self) -> str: + return os.path.join(self.checkpoint_dir, "best_inference_ckpt.tar") + + @property + def ema_checkpoint_path(self) -> str: + return os.path.join(self.checkpoint_dir, "ema_ckpt.tar") + + def epoch_checkpoint_path(self, epoch: int) -> str: + return os.path.join(self.checkpoint_dir, f"ckpt_{epoch:04d}.tar") + + def epoch_checkpoint_enabled(self, epoch: int) -> bool: + return epoch_checkpoint_enabled(epoch, self.max_epochs, self.checkpoint_save_epochs) + + @property + def resuming(self) -> bool: + checkpoint_file_exists = os.path.isfile(self.latest_checkpoint_path) + resuming = True if checkpoint_file_exists else False + return resuming + + def configure_logging(self, log_filename: str): + self.logging.configure_logging(self.experiment_dir, log_filename) + + def configure_wandb(self, env_vars: Optional[Mapping[str, str]] = None, **kwargs): + config = to_flat_dict(dataclasses.asdict(self)) + if "environment" in config: + logging.warning("Not recording env vars since 'environment' is in config.") + elif env_vars is not None: + config["environment"] = env_vars + self.logging.configure_wandb(config=config, **kwargs) + + def log(self): + logging.info("------------------ Configuration ------------------") + logging.info(str(self)) + logging.info("---------------------------------------------------") + + def clean_wandb(self): + self.logging.clean_wandb(experiment_dir=self.experiment_dir) + + +def epoch_checkpoint_enabled(epoch: int, max_epochs: int, save_epochs: Optional[Slice]) -> bool: + if save_epochs is None: + return False + return epoch in range(max_epochs)[save_epochs.slice] diff --git a/src/ace_inference/training/utils/__init__.py b/src/ace_inference/training/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/ace_inference/training/utils/darcy_loss.py b/src/ace_inference/training/utils/darcy_loss.py new file mode 100644 index 0000000..88af270 --- /dev/null +++ b/src/ace_inference/training/utils/darcy_loss.py @@ -0,0 +1,350 @@ +# MIT License +# +# Copyright (c) 2020 Zongyi Li +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import h5py +import numpy as np +import scipy.io +import torch +import torch.nn as nn + + +################################################# +# +# Utilities +# +################################################# +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + +# reading data +class MatReader(object): + def __init__(self, file_path, to_torch=True, to_cuda=False, to_float=True): + super(MatReader, self).__init__() + + self.to_torch = to_torch + self.to_cuda = to_cuda + self.to_float = to_float + + self.file_path = file_path + + self.data = None + self.old_mat = None + self._load_file() + + def _load_file(self): + try: + self.data = scipy.io.loadmat(self.file_path) + self.old_mat = True + except: # noqa: E722 + self.data = h5py.File(self.file_path) + self.old_mat = False + + def load_file(self, file_path): + self.file_path = file_path + self._load_file() + + def read_field(self, field): + x = self.data[field] + + if not self.old_mat: + x = x[()] + x = np.transpose(x, axes=range(len(x.shape) - 1, -1, -1)) + + if self.to_float: + x = x.astype(np.float32) + + if self.to_torch: + x = torch.from_numpy(x) + + if self.to_cuda: + x = x.cuda() + + return x + + def set_cuda(self, to_cuda): + self.to_cuda = to_cuda + + def set_torch(self, to_torch): + self.to_torch = to_torch + + def set_float(self, to_float): + self.to_float = to_float + + +# normalization, pointwise gaussian +class UnitGaussianNormalizer(object): + def __init__(self, x, eps=0.00001): + super(UnitGaussianNormalizer, self).__init__() + + # x could be in shape of ntrain*n or ntrain*T*n or ntrain*n*T + self.mean = torch.mean(x, 0) + self.std = torch.std(x, 0) + self.eps = eps + + def encode(self, x): + x = (x - self.mean) / (self.std + self.eps) + return x.float() + + def decode(self, x, sample_idx=None): + if sample_idx is None: + std = self.std + self.eps # n + mean = self.mean + else: + if len(self.mean.shape) == len(sample_idx[0].shape): + std = self.std[sample_idx] + self.eps # batch*n + mean = self.mean[sample_idx] + if len(self.mean.shape) > len(sample_idx[0].shape): + std = self.std[:, sample_idx] + self.eps # T*batch*n + mean = self.mean[:, sample_idx] + + # x is in shape of batch*n or T*batch*n + x = (x * std) + mean + return x.float() + + def cuda(self): + self.mean = self.mean.cuda() + self.std = self.std.cuda() + + def cpu(self): + self.mean = self.mean.cpu() + self.std = self.std.cpu() + + +# normalization, Gaussian +class GaussianNormalizer(object): + def __init__(self, x, eps=0.00001): + super(GaussianNormalizer, self).__init__() + + self.mean = torch.mean(x) + self.std = torch.std(x) + self.eps = eps + + def encode(self, x): + x = (x - self.mean) / (self.std + self.eps) + return x + + def decode(self, x, sample_idx=None): + x = (x * (self.std + self.eps)) + self.mean + return x + + def cuda(self): + self.mean = self.mean.cuda() + self.std = self.std.cuda() + + def cpu(self): + self.mean = self.mean.cpu() + self.std = self.std.cpu() + + +# normalization, scaling by range +class RangeNormalizer(object): + def __init__(self, x, low=0.0, high=1.0): + super(RangeNormalizer, self).__init__() + mymin = torch.min(x, 0)[0].view(-1) + mymax = torch.max(x, 0)[0].view(-1) + + self.a = (high - low) / (mymax - mymin) + self.b = -self.a * mymax + high + + def encode(self, x): + s = x.size() + x = x.view(s[0], -1) + x = self.a * x + self.b + x = x.view(s) + return x + + def decode(self, x): + s = x.size() + x = x.view(s[0], -1) + x = (x - self.b) / self.a + x = x.view(s) + return x + + +# loss function with rel/abs Lp loss +class LpLoss(object): + def __init__(self, d=2, p=2, size_average=True, reduction=True): + super(LpLoss, self).__init__() + + # Dimension and Lp-norm type are postive + assert d > 0 and p > 0 + + self.d = d + self.p = p + self.reduction = reduction + self.size_average = size_average + + def abs(self, x, y): + num_examples = x.size()[0] + + # Assume uniform mesh + h = 1.0 / (x.size()[1] - 1.0) + + all_norms = (h ** (self.d / self.p)) * torch.norm( + x.view(num_examples, -1) - y.view(num_examples, -1), self.p, 1 + ) + + if self.reduction: + if self.size_average: + return torch.mean(all_norms) + else: + return torch.sum(all_norms) + + return all_norms + + def rel(self, x, y): + num_examples = x.size()[0] + + diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) + y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) + + if self.reduction: + if self.size_average: + return torch.mean(diff_norms / y_norms) + else: + return torch.sum(diff_norms / y_norms) + + return diff_norms / y_norms + + def __call__(self, x, y): + return self.rel(x, y) + + +# Sobolev norm (HS norm) +# where we also compare the numerical derivatives between the output and target +class HsLoss(object): + def __init__(self, d=2, p=2, k=1, a=None, group=False, size_average=True, reduction=True): + super(HsLoss, self).__init__() + + # Dimension and Lp-norm type are postive + assert d > 0 and p > 0 + + self.d = d + self.p = p + self.k = k + self.balanced = group + self.reduction = reduction + self.size_average = size_average + + if a is None: + a = [ + 1, + ] * k + self.a = a + + def rel(self, x, y): + num_examples = x.size()[0] + diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) + y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) + if self.reduction: + if self.size_average: + return torch.mean(diff_norms / y_norms) + else: + return torch.sum(diff_norms / y_norms) + return diff_norms / y_norms + + def __call__(self, x, y, a=None): + nx = x.size()[1] + ny = x.size()[2] + k = self.k + balanced = self.balanced + a = self.a + x = x.view(x.shape[0], nx, ny, -1) + y = y.view(y.shape[0], nx, ny, -1) + + k_x = ( + torch.cat( + ( + torch.arange(start=0, end=nx // 2, step=1), + torch.arange(start=-nx // 2, end=0, step=1), + ), + 0, + ) + .reshape(nx, 1) + .repeat(1, ny) + ) + k_y = ( + torch.cat( + ( + torch.arange(start=0, end=ny // 2, step=1), + torch.arange(start=-ny // 2, end=0, step=1), + ), + 0, + ) + .reshape(1, ny) + .repeat(nx, 1) + ) + k_x = torch.abs(k_x).reshape(1, nx, ny, 1).to(x.device) + k_y = torch.abs(k_y).reshape(1, nx, ny, 1).to(x.device) + + x = torch.fft.fftn(x, dim=[1, 2]) + y = torch.fft.fftn(y, dim=[1, 2]) + + if not balanced: + weight = 1 + if k >= 1: + weight += a[0] ** 2 * (k_x**2 + k_y**2) + if k >= 2: + weight += a[1] ** 2 * (k_x**4 + 2 * k_x**2 * k_y**2 + k_y**4) + weight = torch.sqrt(weight) + loss = self.rel(x * weight, y * weight) + else: + loss = self.rel(x, y) + if k >= 1: + weight = a[0] * torch.sqrt(k_x**2 + k_y**2) + loss += self.rel(x * weight, y * weight) + if k >= 2: + weight = a[1] * torch.sqrt(k_x**4 + 2 * k_x**2 * k_y**2 + k_y**4) + loss += self.rel(x * weight, y * weight) + loss = loss / (k + 1) + + return loss + + +# A simple feedforward neural network +class DenseNet(torch.nn.Module): + def __init__(self, layers, nonlinearity, out_nonlinearity=None, normalize=False): + super(DenseNet, self).__init__() + + self.n_layers = len(layers) - 1 + + assert self.n_layers >= 1 + + self.layers = nn.ModuleList() + + for j in range(self.n_layers): + self.layers.append(nn.Linear(layers[j], layers[j + 1])) + + if j != self.n_layers - 1: + if normalize: + self.layers.append(nn.BatchNorm1d(layers[j + 1])) + + self.layers.append(nonlinearity()) + + if out_nonlinearity is not None: + self.layers.append(out_nonlinearity()) + + def forward(self, x): + for _, layer in enumerate(self.layers): + x = layer(x) + + return x diff --git a/src/ace_inference/training/utils/data_loader_fv3gfs.py b/src/ace_inference/training/utils/data_loader_fv3gfs.py new file mode 100644 index 0000000..66d06ff --- /dev/null +++ b/src/ace_inference/training/utils/data_loader_fv3gfs.py @@ -0,0 +1,252 @@ +import logging +import os +from typing import Callable, List, Optional + +import netCDF4 +import numpy as np +import torch +from torch.utils.data import Dataset + + +def load_series_data_sequential(idx: int, ds: netCDF4.MFDataset, horizon: int, names: List[str]): + # flip the lat dimension so that it is increasing + arrays = { + n: torch.as_tensor(np.flip(ds.variables[n][idx : idx + horizon + 1, :, :], axis=-2).copy()) for n in names + } + return {"dynamics": arrays} + + +def load_series_data_direct(idx: int, ds: netCDF4.MFDataset, horizon: int, names: List[str]): + # flip the lat dimension so that it is increasing + arrays = { + n: torch.as_tensor( + np.flip( + np.stack([ds.variables[n][idx, :, :], ds.variables[n][idx + horizon, :, :]], axis=0), axis=-2 + ).copy() + ) + for n in names + } + return {"data": arrays} + + +def load_series_data_multistep_randomized( + idx: int, + ds: netCDF4.MFDataset, + horizon: int, + names: List[str], + random_timestep: Optional[int] = None, + is_forcing: bool = False, +): + random_timestep = random_timestep or np.random.randint(1, horizon + 1) # in [1, horizon] + arrays = { + n: torch.as_tensor( + np.flip( + np.stack( + [ + ds.variables[n][idx, :, :], # first step/initial conditions + ds.variables[n][idx + random_timestep, :, :], # random step + ], + axis=0, + ), + axis=-2, + ).copy() + ) + for n in names + } + return {"data": arrays, "random_timestep": torch.as_tensor(random_timestep, dtype=torch.long)} + + +def load_series_data_multistep_interpolation( + idx: int, + ds: netCDF4.MFDataset, + horizon: int, + names: List[str], + random_timestep: Optional[int] = None, + is_forcing: bool = False, +): + # Note that for interpolation, the condition/forcings willl only be used from the interpolation step + random_timestep = random_timestep or np.random.randint(1, horizon) # in [1, horizon - 1] + + def get_time_data(name): + if False: # is_forcing: + return ds.variables[name][idx + random_timestep, :, :] + else: + return np.stack( + [ + ds.variables[name][idx, :, :], # first step/initial conditions + ds.variables[name][idx + random_timestep, :, :], # random step + ds.variables[name][idx + horizon, :, :], # last step + ], + axis=0, + ) + + arrays = {n: torch.as_tensor(np.flip(get_time_data(n), axis=-2).copy()) for n in names} + return {"data": arrays, "random_timestep": torch.as_tensor(random_timestep)} + + +class FV3GFSDataset(Dataset): + def __init__( + self, + path: str, + in_names: List[str], + out_names: List[str], + all_names: List[str], + forcing_names: List[str], + horizon: int, + multistep_strategy: Optional[str] = None, + n_samples: Optional[int] = None, + min_idx_shift: int = 0, + forcing_packer: Optional[Callable] = None, + forcing_normalizer: Optional[Callable] = None, + split_id: Optional[str] = None, + ): + assert n_samples is None or n_samples > 0, f"Invalid n_samples {n_samples}" + assert min_idx_shift >= 0, f"Invalid min_idx_shift {min_idx_shift}" + self.names = all_names + self.in_names = in_names + self.out_names = out_names + self.in_or_out_names = list(set(all_names) - set(forcing_names)) + self.forcing_names = forcing_names if len(forcing_names) > 0 else None + self.forcing_packer = forcing_packer + if self.forcing_packer is not None: + assert self.forcing_packer.axis is not None, f"Forcing packer {self.forcing_packer} must have axis set" + self.forcing_normalizer = forcing_normalizer + self.horizon = horizon + self.n_in_channels = len(self.in_names) + self.n_out_channels = len(self.out_names) + self.multistep_strategy = multistep_strategy + self.path = path + self.full_path = os.path.join(path, "*.nc") + self.split_id = split_id + + self._get_files_stats() + self.min_idx_shift = min_idx_shift # Used to shift the indices to avoid overlap between val & test + if n_samples is not None: + self.n_samples_total = n_samples # Hardcodes max number of samples + + if multistep_strategy != "sequential": + assert horizon > 0, f"Invalid horizon {horizon} for multistep strategy {multistep_strategy}" + + if multistep_strategy == "sequential": + self.load_series_data = load_series_data_sequential + elif multistep_strategy == "random": + self.load_series_data = load_series_data_multistep_randomized + elif multistep_strategy == "interpolation": + self.load_series_data = load_series_data_multistep_interpolation + elif multistep_strategy in [None, "direct"]: + self.load_series_data = load_series_data_direct + else: + raise ValueError(f"Unknown multistep strategy {multistep_strategy}") + + if multistep_strategy == "sequential": + self.main_data_key = "dynamics" + else: + self.main_data_key = "data" + + self.shared_kwargs = dict(horizon=horizon, ds=self.ds) + # print(f'Initialized dataset with {len(self)} samples', horizon, split_id) + + def _get_files_stats(self): + expected_vars = [ + "DLWRFsfc", + "DSWRFsfc", + "DSWRFtoa", + "LHTFLsfc", + "PRATEsfc", + "PRESsfc", + "SHTFLsfc", + "ULWRFsfc", + "ULWRFtoa", + "USWRFsfc", + "USWRFtoa", + "air_temperature_0", + "air_temperature_1", + "air_temperature_2", + "air_temperature_3", + "air_temperature_4", + "air_temperature_5", + "air_temperature_6", + "air_temperature_7", + "eastward_wind_0", + "eastward_wind_1", + "eastward_wind_2", + "eastward_wind_3", + "eastward_wind_4", + "eastward_wind_5", + "eastward_wind_6", + "eastward_wind_7", + "grid_xt", + "grid_yt", + "land_sea_mask", + "northward_wind_0", + "northward_wind_1", + "northward_wind_2", + "northward_wind_3", + "northward_wind_4", + "northward_wind_5", + "northward_wind_6", + "northward_wind_7", + "pressure_thickness_of_atmospheric_layer_0", + "pressure_thickness_of_atmospheric_layer_1", + "pressure_thickness_of_atmospheric_layer_2", + "pressure_thickness_of_atmospheric_layer_3", + "pressure_thickness_of_atmospheric_layer_4", + "pressure_thickness_of_atmospheric_layer_5", + "pressure_thickness_of_atmospheric_layer_6", + "pressure_thickness_of_atmospheric_layer_7", + "specific_total_water_0", + "specific_total_water_1", + "specific_total_water_2", + "specific_total_water_3", + "specific_total_water_4", + "specific_total_water_5", + "specific_total_water_6", + "specific_total_water_7", + "surface_temperature", + "tendency_of_total_water_path", + "tendency_of_total_water_path_due_to_advection", + "time", + "total_water_path", + ] + logging.info(f"Opening data at {self.full_path}") + self.ds = netCDF4.MFDataset(self.full_path) + self.ds.set_auto_mask(False) + # minus one since don't have an output for the last step + self.n_samples_total = len(self.ds.variables["time"][:]) - self.horizon + # provided ERA5 dataloader gets the "wrong" x/y convention (x is lat, y is lon) + # so we follow that convention here for consistency + if "grid_xt" in self.ds.variables: + self.img_shape_x = len(self.ds.variables["grid_yt"][:]) + self.img_shape_y = len(self.ds.variables["grid_xt"][:]) + else: + self.img_shape_x = len(self.ds.variables["lat"][:]) + self.img_shape_y = len(self.ds.variables["lon"][:]) + logging.info(f"Found {self.n_samples_total} samples.") + logging.info(f"Image shape is {self.img_shape_x} x {self.img_shape_y}.") + missing_vars = set(expected_vars) - set(self.ds.variables) + if len(missing_vars) > 0: + raise ValueError(f"Missing variables: {missing_vars}") + elif len(set(self.ds.variables) - set(expected_vars)) > 0: + logging.warning(f"Found unexpected variables: {set(self.ds.variables) - set(expected_vars)}") + # logging.info(f"Following variables are available: {list(self.ds.variables)}.") + + def __len__(self): + return self.n_samples_total + + def __getitem__(self, idx): + idx = idx + self.min_idx_shift # Shift indices to avoid overlap between val & test + data = self.load_series_data(idx=idx, names=self.in_or_out_names, **self.shared_kwargs) + # data_shape = data[list(data.keys())[0]].shape + # print(f'Loaded data with shape {data_shape}') + # data = TensorDict(data, batch_size=data_shape) + if self.forcing_names is not None: + if self.multistep_strategy in ["random", "interpolation"]: + fkwargs = {"random_timestep": data["random_timestep"], "is_forcing": True} + else: + fkwargs = {} + forcings = self.load_series_data(idx=idx, names=self.forcing_names, **self.shared_kwargs, **fkwargs)[ + self.main_data_key + ] + forcings = self.forcing_packer.pack(self.forcing_normalizer.normalize(forcings)) + data["condition"] = forcings + return data diff --git a/src/ace_inference/training/utils/data_loader_multifiles.py b/src/ace_inference/training/utils/data_loader_multifiles.py new file mode 100644 index 0000000..bc384b8 --- /dev/null +++ b/src/ace_inference/training/utils/data_loader_multifiles.py @@ -0,0 +1,174 @@ +# BSD 3-Clause License +# +# Copyright (c) 2022, FourCastNet authors +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# The code was authored by the following people: +# +# Jaideep Pathak - NVIDIA Corporation +# Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory +# Peter Harrington - NERSC, Lawrence Berkeley National Laboratory +# Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory +# Ashesh Chattopadhyay - Rice University +# Morteza Mardani - NVIDIA Corporation +# Thorsten Kurth - NVIDIA Corporation +# David Hall - NVIDIA Corporation +# Zongyi Li - California Institute of Technology, NVIDIA Corporation +# Kamyar Azizzadenesheli - Purdue University +# Pedram Hassanzadeh - Rice University +# Karthik Kashinath - NVIDIA Corporation +# Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation + +import logging +import os +from collections import namedtuple +from typing import List, Mapping + +import netCDF4 +import numpy as np +import torch +from fme.core.device import using_gpu +from fme.core.distributed import Distributed +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.distributed import DistributedSampler + +# import cv2 +from .data_loader_params import DataLoaderParams +from .data_requirements import DataRequirements + + +def get_data_loader( + params: DataLoaderParams, + split: str, + requirements: DataRequirements, +): + assert split in ["train", "validation", "test"], f"Invalid split: {split}" + is_train = split == "train" + dist = Distributed.get_instance() + # TODO: move this default to the DataLoaderParams init + if params.data_type is None: + params.data_type = "ERA5" + if params.data_type == "ERA5": + raise NotImplementedError("ERA5 data loader is not implemented. ") + elif params.data_type in ["FV3GFS", "E3SMV2"]: + data_path = os.path.join(params.data_path, split) + dataset = FV3GFSDataset(params, data_path, requirements=requirements) + if params.num_data_workers > 0: + # netCDF4 __getitem__ fails with + # "RuntimeError: Resource temporarily unavailable" + # if num_data_workers > 0 + # TODO: move this logic to the DataLoaderParams initialization + logging.warning( + f"If data_type=={params.data_type}, must use num_data_workers=0. " + "Got num_data_workers=" + f"{params.num_data_workers}, but it is being set to 0." + ) + params.num_data_workers = 0 + else: + raise NotImplementedError(f"{params.data_type} does not have an implemented data loader") + + sampler = DistributedSampler(dataset, shuffle=is_train) if dist.is_distributed() else None + batch_size = params.batch_size if is_train else params.batch_size_eval + dataloader = DataLoader( + dataset, + batch_size=dist.local_batch_size(int(batch_size)), + num_workers=params.num_data_workers, + shuffle=(sampler is None) and is_train, + sampler=sampler if is_train else None, + drop_last=True, + pin_memory=using_gpu(), + ) + + if is_train: + return dataloader, dataset, sampler + else: + return dataloader, dataset + + +# Old dataset +def load_series_data(idx: int, n_steps: int, ds: netCDF4.MFDataset, names: List[str]): + # flip the lat dimension so that it is increasing + arrays = {n: torch.as_tensor(np.flip(ds.variables[n][idx : idx + n_steps, :, :], axis=-2).copy()) for n in names} + return arrays + + +VariableMetadata = namedtuple("VariableMetadata", ["units", "long_name"]) + + +class FV3GFSDataset(Dataset): + def __init__(self, params: DataLoaderParams, data_path, requirements: DataRequirements): + print("FV3GFSDataset init") + self.params = params + self.in_names = requirements.in_names + self.out_names = requirements.out_names + self.names = requirements.names + self.n_in_channels = len(self.in_names) + self.n_out_channels = len(self.out_names) + self.path = data_path + print("self.path", self.path) + self.full_path = os.path.join(self.path, "*.nc") + self.n_steps = requirements.n_timesteps # one input, one output timestep + print("self.n_steps", self.n_steps) + self._get_files_stats() + if params.n_samples is not None: + self.n_samples_total = params.n_samples + + def _get_files_stats(self): + logging.info(f"Opening data at {self.full_path}") + self.ds = netCDF4.MFDataset(self.full_path) + self.ds.set_auto_mask(False) + # minus one since don't have an output for the last step + self.n_samples_total = len(self.ds.variables["time"][:]) - self.n_steps + 1 + # provided ERA5 dataloader gets the "wrong" x/y convention (x is lat, y is lon) + # so we follow that convention here for consistency + if "grid_xt" in self.ds.variables: + self.img_shape_x = len(self.ds.variables["grid_yt"][:]) + self.img_shape_y = len(self.ds.variables["grid_xt"][:]) + else: + self.img_shape_x = len(self.ds.variables["lat"][:]) + self.img_shape_y = len(self.ds.variables["lon"][:]) + logging.info(f"Found {self.n_samples_total} samples.") + logging.info(f"Image shape is {self.img_shape_x} x {self.img_shape_y}.") + logging.info(f"Following variables are available: {list(self.ds.variables)}.") + + @property + def metadata(self) -> Mapping[str, VariableMetadata]: + result = {} + for name in self.names: + if hasattr(self.ds.variables[name], "units") and hasattr(self.ds.variables[name], "long_name"): + result[name] = VariableMetadata( + units=self.ds.variables[name].units, + long_name=self.ds.variables[name].long_name, + ) + return result + + def __len__(self): + return self.n_samples_total + + def __getitem__(self, idx): + return load_series_data(idx=idx, n_steps=self.n_steps, ds=self.ds, names=self.names) diff --git a/src/ace_inference/training/utils/data_loader_params.py b/src/ace_inference/training/utils/data_loader_params.py new file mode 100644 index 0000000..983afe2 --- /dev/null +++ b/src/ace_inference/training/utils/data_loader_params.py @@ -0,0 +1,40 @@ +import dataclasses +from typing import Optional + + +@dataclasses.dataclass +class DataLoaderParams: + """ + Attributes: + data_path: Path to the data. + data_type: Type of data to load. + horizon: Number of steps to predict into the future. + batch_size: Batch size for training. + batch_size_eval: Batch size for evaluation/validation. + num_data_workers: Number of parallel data workers. + multistep_strategy: Strategy for loading multistep data. Options are: + - "random": Randomly select a step within the horizon to predict. + - "sequential": Return all steps within the horizon. + - None: Return only the last step of the horizon. + n_samples: Number of samples to load, starting at the beginning of the data. + If None, load all samples. + """ + + data_path: str + data_type: str + horizon: int + batch_size: int + batch_size_eval: int + num_data_workers: int + multistep_strategy: Optional[str] = None + n_samples: Optional[int] = None + + def __post_init__(self): + assert self.horizon > 0, f"horizon ({self.horizon}) must be positive" + if self.n_samples is not None and self.batch_size > self.n_samples: + raise ValueError( + f"batch_size ({self.batch_size}) must be less than or equal to " + f"n_samples ({self.n_samples}) or no batches would be produced" + ) + if self.multistep_strategy == "null": + self.multistep_strategy = None diff --git a/src/ace_inference/training/utils/data_requirements.py b/src/ace_inference/training/utils/data_requirements.py new file mode 100644 index 0000000..f882b06 --- /dev/null +++ b/src/ace_inference/training/utils/data_requirements.py @@ -0,0 +1,11 @@ +import dataclasses +from typing import List + + +@dataclasses.dataclass +class DataRequirements: + names: List[str] + # TODO: delete these when validation no longer needs them + in_names: List[str] + out_names: List[str] + n_timesteps: int diff --git a/src/ace_inference/training/utils/img_utils.py b/src/ace_inference/training/utils/img_utils.py new file mode 100644 index 0000000..5603eaf --- /dev/null +++ b/src/ace_inference/training/utils/img_utils.py @@ -0,0 +1,66 @@ +# BSD 3-Clause License +# +# Copyright (c) 2022, FourCastNet authors +# All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +# The code was authored by the following people: +# +# Jaideep Pathak - NVIDIA Corporation +# Shashank Subramanian - NERSC, Lawrence Berkeley National Laboratory +# Peter Harrington - NERSC, Lawrence Berkeley National Laboratory +# Sanjeev Raja - NERSC, Lawrence Berkeley National Laboratory +# Ashesh Chattopadhyay - Rice University +# Morteza Mardani - NVIDIA Corporation +# Thorsten Kurth - NVIDIA Corporation +# David Hall - NVIDIA Corporation +# Zongyi Li - California Institute of Technology, NVIDIA Corporation +# Kamyar Azizzadenesheli - Purdue University +# Pedram Hassanzadeh - Rice University +# Karthik Kashinath - NVIDIA Corporation +# Animashree Anandkumar - California Institute of Technology, NVIDIA Corporation + +import torch.nn as nn +import torch.nn.functional as F + + +class PeriodicPad2d(nn.Module): + """ + pad longitudinal (left-right) circular + and pad latitude (top-bottom) with zeros + """ + + def __init__(self, pad_width): + super(PeriodicPad2d, self).__init__() + self.pad_width = pad_width + + def forward(self, x): + # pad left and right circular + out = F.pad(x, (self.pad_width, self.pad_width, 0, 0), mode="circular") + # pad top and bottom zeros + out = F.pad(out, (0, 0, self.pad_width, self.pad_width), mode="constant", value=0) + return out diff --git a/src/configs/inference/ckpts_from_huggingface_10years.yaml b/src/configs/inference/ckpts_from_huggingface_10years.yaml new file mode 100644 index 0000000..3eaf454 --- /dev/null +++ b/src/configs/inference/ckpts_from_huggingface_10years.yaml @@ -0,0 +1,51 @@ +# Compared to the non-debug version, this file simply runs inference for fewer steps. +experiment_dir: results/spherical-dyffusion +n_forward_steps: 14600 +forward_steps_in_memory: 100 +validation_loader: + # IMPORTANT: Set the correct data path for the validation dataset + dataset: + # =============== Edit this path =============== + data_path: "/data/climate-model/fv3gfs/2023-09-07-vertically-resolved-1deg-fme-ensemble-dataset-netcdfs/validation/ic_0011" + # ============================================= + n_repeats: 1 # Use 10 for 100 year rollout (10 x 10 years) + start_indices: + first: 0 + n_initial_conditions: 1 + num_data_workers: 8 +# The following specifies which (Spherical) DYffusion checkpoint to use. It can be +# a) A local path to a checkpoint file +# b) A huggingface model id (with the prefix "hf:") +checkpoint_path: "hf:salv47/spherical-dyffusion/forecaster-sfno-best-inference_avg_crps.ckpt" +# Override inference parameters and interpolator configuration +overrides: + diffusion_config: + hack_for_imprecise_interpolation: True + # The following two paths are used to load the interpolator model and config + # a) If pre-pending "hf:" to the path, the checkpoint will be downloaded from huggingface. + # b) Otherwise, it will be loaded from the local filesystem. + # c) Set it to null, to try downloading from wandb based on the interpolator_run_id & entity and project below. + interpolator_local_checkpoint_path: "hf:salv47/spherical-dyffusion/interpolator-sfno-best-val_avg_crps.ckpt" + interpolator_local_config_path: "hf:salv47/spherical-dyffusion/interpolator_sfno_paper_v0_hydra_config.yaml" + # As said, alternatively, just mention the entity and project to download the checkpoint from wandb. + interpolator_wandb_kwargs: + entity: null + project: "Spherical-DYffusion" # potentially replace with the correct project name +# More possible overrides: +# interpolator_use_ema: True +# use_cold_sampling_for_last_step: False +# use_cold_sampling_for_intermediate_steps: False +# use_cold_sampling_for_init_of_ar_step: True + +# Logging configuration +logging: + project: "Spherical-DYffusion-inference" # Where to log inference results to + entity: null # Replace with + log_to_screen: true + log_to_wandb: true + log_to_file: true +log_video: false +log_zonal_mean_images: false +data_writer: + # Set below to true to save predictions to xarray files (e.g. for further analysis) + save_prediction_files: false \ No newline at end of file diff --git a/src/configs/inference/ckpts_from_huggingface_debug.yaml b/src/configs/inference/ckpts_from_huggingface_debug.yaml new file mode 100644 index 0000000..5a88a75 --- /dev/null +++ b/src/configs/inference/ckpts_from_huggingface_debug.yaml @@ -0,0 +1,51 @@ +# Compared to the non-debug version, this file simply runs inference for fewer steps. +experiment_dir: results/spherical-dyffusion +n_forward_steps: 100 +forward_steps_in_memory: 50 +validation_loader: + # IMPORTANT: Set the correct data path for the validation dataset + dataset: + # =============== Edit this path =============== + data_path: "/data/climate-model/fv3gfs/2023-09-07-vertically-resolved-1deg-fme-ensemble-dataset-netcdfs/validation/ic_0011" + # ============================================= + n_repeats: 1 # Use 10 for 100 year rollout (10 x 10 years) + start_indices: + first: 0 + n_initial_conditions: 1 + num_data_workers: 8 +# The following specifies which (Spherical) DYffusion checkpoint to use. It can be +# a) A local path to a checkpoint file +# b) A huggingface model id (with the prefix "hf:") +checkpoint_path: "hf:salv47/spherical-dyffusion/forecaster-sfno-best-inference_avg_crps.ckpt" +# Override inference parameters and interpolator configuration +overrides: + diffusion_config: + hack_for_imprecise_interpolation: True + # The following two paths are used to load the interpolator model and config + # a) If pre-pending "hf:" to the path, the checkpoint will be downloaded from huggingface. + # b) Otherwise, it will be loaded from the local filesystem. + # c) Set it to null, to try downloading from wandb based on the interpolator_run_id & entity and project below. + interpolator_local_checkpoint_path: "hf:salv47/spherical-dyffusion/interpolator-sfno-best-val_avg_crps.ckpt" + interpolator_local_config_path: "hf:salv47/spherical-dyffusion/interpolator_sfno_paper_v0_hydra_config.yaml" + # As said, alternatively, just mention the entity and project to download the checkpoint from wandb. + interpolator_wandb_kwargs: + entity: null + project: "Spherical-DYffusion" # potentially replace with the correct project name +# More possible overrides: +# interpolator_use_ema: True +# use_cold_sampling_for_last_step: False +# use_cold_sampling_for_intermediate_steps: False +# use_cold_sampling_for_init_of_ar_step: True + +# Logging configuration +logging: + project: "Spherical-DYffusion-inference" # Where to log inference results to + entity: null # Replace with + log_to_screen: true + log_to_wandb: true + log_to_file: true +log_video: false +log_zonal_mean_images: false +data_writer: + # Set below to true to save predictions to xarray files (e.g. for further analysis) + save_prediction_files: false \ No newline at end of file diff --git a/src/datamodules/__init__.py b/src/datamodules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/datamodules/_dataset_dimensions.py b/src/datamodules/_dataset_dimensions.py new file mode 100644 index 0000000..0ce91df --- /dev/null +++ b/src/datamodules/_dataset_dimensions.py @@ -0,0 +1,27 @@ +from omegaconf import DictConfig + + +def get_dims_of_dataset(datamodule_config: DictConfig): + """Returns the number of features for the given dataset.""" + target = datamodule_config.get("_target_", datamodule_config.get("name")) + conditional_dim = 0 + spatial_dims_out = None + if "fv3gfs" in target: + input_dim = len(datamodule_config.in_names) + output_dim = len(datamodule_config.out_names) + spatial_dims = (180, 360) + conditional_dim = len(datamodule_config.forcing_names) if datamodule_config.forcing_names is not None else 0 + + elif "debug_datamodule" in target: + input_dim = output_dim = datamodule_config.channels + spatial_dims = (datamodule_config.height, datamodule_config.width) + + else: + raise ValueError(f"Unknown dataset: {target}") + return { + "input": input_dim, + "output": output_dim, + "spatial_in": spatial_dims, + "spatial_out": spatial_dims_out if spatial_dims_out is not None else spatial_dims, + "conditional": conditional_dim, + } diff --git a/src/datamodules/abstract_datamodule.py b/src/datamodules/abstract_datamodule.py new file mode 100644 index 0000000..a85c246 --- /dev/null +++ b/src/datamodules/abstract_datamodule.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +import multiprocessing +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import pytorch_lightning as pl +import torch +import xarray as xr +from omegaconf import DictConfig +from pytorch_lightning.utilities.types import EVAL_DATALOADERS +from tensordict import TensorDict +from torch import Tensor +from torch.utils.data import DataLoader, Dataset + +from src.evaluation.aggregators._abstract_aggregator import AbstractAggregator +from src.utilities.utils import get_logger, raise_error_if_invalid_value + + +log = get_logger(__name__) + + +class BaseDataModule(pl.LightningDataModule): + """ + ---------------------------------------------------------------------------------------------------------- + A DataModule implements 5 key methods: + - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) + - setup (things to do on every accelerator in distributed mode) + - train_dataloader (the training dataloader) + - val_dataloader (the validation dataloader(s)) + - test_dataloader (the test dataloader(s)) + + This allows you to share a full dataset without explaining how to download, + split, transform and process the data + + Read the docs: + https://pytorch-lightning.readthedocs.io/en/latest/extensions/datamodules.html + """ + + _data_train: Dataset + _data_val: Union[Dataset, Sequence[Dataset]] + _data_test: Dataset + _data_predict: Dataset + + def __init__( + self, + data_dir: str, + model_config: DictConfig = None, + batch_size: int = 2, + eval_batch_size: int = 64, + num_workers: int = -1, + pin_memory: bool = True, + persistent_workers: bool = False, + prefetch_factor: Optional[int] = None, + multiprocessing_context: Optional[str] = None, + drop_last: bool = False, + shuffle_train_data: bool = True, + debug_mode: bool = False, + verbose: bool = True, + seed_data: int = 43, + batch_size_per_gpu=None, # should be none and be handled outside (directly set in batch_size) + ): + """ + Args: + data_dir (str): A path to the data folder that contains the input and output files. + batch_size (int): Batch size for the training dataloader + eval_batch_size (int): Batch size for the test and validation dataloader's + num_workers (int): Dataloader arg for higher efficiency (usually set to # of CPU cores). + Default: Set to -1 to use all available cores. + pin_memory (bool): Dataloader arg for higher efficiency. Default: True + drop_last (bool): Only for training data loading: Drop the last incomplete batch + when the dataset size is not divisible by the batch size. Default: False + shuffle_train_data (bool): Only for training data loading: Shuffle the training data. + Default: True + verbose (bool): Print the dataset sizes. Default: True + """ + super().__init__() + # The following makes all args available as, e.g., self.hparams.batch_size + self.save_hyperparameters(ignore=["model_config", "verbose"]) + self.model_config = model_config + self.test_batch_size = eval_batch_size # just for testing + assert ( + batch_size_per_gpu is None or batch_size_per_gpu == batch_size + ), f"batch_size_per_gpu should be None, but got {batch_size_per_gpu}. (batch_size={batch_size})" + self._data_train = self._data_val = self._data_test = self._data_predict = None + self._experiment_class_name = None + self._check_args() + + def _check_args(self): + """Check if the arguments are valid.""" + if self.hparams.debug_mode is True: + self.hparams.num_workers = 0 + self.hparams.batch_size = 8 + self.hparams.eval_batch_size = 8 + + if self.hparams.num_workers == 0: + if self.hparams.persistent_workers is True: + log.warning( + "persistent_workers can only be set to True if num_workers > 0. " + "Setting persistent_workers to False." + ) + self.hparams.persistent_workers = False + + @property + def sigma_data(self) -> float: + raise NotImplementedError("Please specify the standard deviation of the training data in the subclass.") + + @property + def experiment_class_name(self) -> str: + if self._experiment_class_name is None: + if self.trainer is not None and hasattr(self.trainer, "lightning_module"): + self._experiment_class_name = self.trainer.lightning_module.__class__.__name__ + return self._experiment_class_name or "unknown" + + def _concat_variables_into_channel_dim(self, data: xr.Dataset, variables: List[str], filename=None) -> np.ndarray: + """Concatenate xarray variables into numpy channel dimension (last).""" + data_all = [] + for var in variables: + # Get the variable from the dataset (as numpy array, by selecting .values) + var_data = data[var].values + # add feature dimension (channel) + var_data = np.expand_dims(var_data, axis=-1) + # add to list of all variables + data_all.append(var_data) + + # Concatenate all the variables into a single array along the last (channel/feature) dimension + dataset = np.concatenate(data_all, axis=-1) + assert dataset.shape[-1] == len(variables), "Number of variables does not match number of channels." + return dataset + + # @rank_zero_only + def print_data_sizes(self, stage: str = None): + """Print the sizes of the data.""" + + if stage in ["fit", None]: + val_size = [len(dv) for dv in self._data_val] if isinstance(self._data_val, list) else len(self._data_val) + log.info(f"Dataset sizes train: {len(self._data_train)}, val: {val_size}") + elif stage == "validate": + val_size = [len(dv) for dv in self._data_val] if isinstance(self._data_val, list) else len(self._data_val) + log.info(f"Dataset validation size: {val_size}") + elif stage in ["test", None]: + log.info(f"Dataset test size: {len(self._data_test)}") + elif stage == "predict": + log.info(f"Dataset predict size: {len(self._data_predict)}") + + @abstractmethod + def setup(self, stage: Optional[str] = None): + """Load data. Set internal variables: self._data_train, self._data_val, self._data_test.""" + raise_error_if_invalid_value(stage, ["fit", "validate", "test", "predict", None], "stage") + + if stage == "fit" or stage is None: + self._data_train = ... # get_tensor_dataset_from_numpy(X_train, Y_train, dataset_id='train') + if stage in ["fit", "validate", None]: + self._data_val = ... # get_tensor_dataset_from_numpy(X_val, Y_val, dataset_id='val') + if stage in ["test", None]: + self._data_test = ... # get_tensor_dataset_from_numpy(X_test, Y_test, dataset_id='test') + if stage in ["predict"]: + self._data_predict = ... + raise NotImplementedError("This class is not implemented yet.") + + @abstractmethod + def get_horizon(self, split: str, dataloader_idx: int = 0) -> int: + """Return the horizon for the given split.""" + return self.hparams.get("horizon", 1) + + def get_horizon_range(self, split: str, dataloader_idx: int = 0) -> List[int]: + """Return the horizon range for the given split.""" + return list(np.arange(1, self.get_horizon(split, dataloader_idx) + 1)) + + @property + def valid_time_range_for_backbone_model(self) -> List[int]: + return self.get_horizon_range("fit") + + def get_epoch_aggregators( + self, + split: str, + is_ensemble: bool, + dataloader_idx: int = 0, + experiment_type: str = None, + device: torch.device = None, + verbose: bool = True, + ) -> Dict[str, AbstractAggregator]: + """Return the epoch aggregators for the given split.""" + return {} + + @property + def num_workers(self) -> int: + if self.hparams.num_workers == -1: + return multiprocessing.cpu_count() + return int(self.hparams.num_workers) + + def _shared_dataloader_kwargs(self) -> dict: + shared_kwargs = dict( + num_workers=self.num_workers, + pin_memory=self.hparams.pin_memory, + persistent_workers=self.hparams.persistent_workers, + ) + if self.hparams.prefetch_factor is not None: + shared_kwargs["prefetch_factor"] = self.hparams.prefetch_factor + if self.hparams.multiprocessing_context is not None: + shared_kwargs["multiprocessing_context"] = self.hparams.multiprocessing_context + return shared_kwargs + + def train_dataloader(self): + return ( + DataLoader( + dataset=self._data_train, + batch_size=self.hparams.batch_size, + drop_last=self.hparams.drop_last, # drop last incomplete batch (only for training) + shuffle=self.hparams.shuffle_train_data, + **self._shared_dataloader_kwargs(), + ) + if self._data_train is not None + else None + ) + + def _shared_eval_dataloader_kwargs(self) -> dict: + return dict(**self._shared_dataloader_kwargs(), shuffle=False) + + def val_dataloader(self): + if self._data_val is None: + return None + elif isinstance(self._data_val, List): + return [ + DataLoader( + dataset=ds_val, + batch_size=self.hparams.eval_batch_size, + **self._shared_eval_dataloader_kwargs(), + ) + for ds_val in self._data_val + ] + else: + return [ + DataLoader( + dataset=self._data_val, + batch_size=self.hparams.eval_batch_size, + **self._shared_eval_dataloader_kwargs(), + ) + ] + + def test_dataloader(self) -> DataLoader: + return ( + DataLoader( + dataset=self._data_test, + batch_size=self.test_batch_size, + **self._shared_eval_dataloader_kwargs(), + ) + if self._data_test is not None + else None + ) + + def predict_dataloader(self) -> EVAL_DATALOADERS: + return ( + DataLoader( + dataset=self._data_predict, + batch_size=self.hparams.eval_batch_size, + **self._shared_eval_dataloader_kwargs(), + ) + if self._data_predict is not None + else None + ) + + def boundary_conditions( + self, + preds: Union[Tensor, TensorDict], + targets: Union[Tensor, TensorDict], + data: Any = None, + metadata: Any = None, + time: float = None, + ) -> Union[Tensor, TensorDict]: + """Return predictions that satisfy the boundary conditions for a given item (batch element).""" + return preds + + def get_boundary_condition_kwargs(self, batch: Any, batch_idx: int, split: str) -> dict: + return dict(t0=0.0, dt=1.0) + + @property + def validation_set_names(self) -> List[str] | None: + """Use for using specific prefix for logging validation metrics.""" + return None diff --git a/src/datamodules/debug_datamodule.py b/src/datamodules/debug_datamodule.py new file mode 100644 index 0000000..b75b98d --- /dev/null +++ b/src/datamodules/debug_datamodule.py @@ -0,0 +1,114 @@ +from __future__ import annotations + +from typing import Dict, Optional + +import torch +from torch.utils.data import Dataset + +from src.datamodules.abstract_datamodule import BaseDataModule +from src.evaluation.one_step.main import OneStepAggregator +from src.utilities.utils import ( + get_logger, +) + + +log = get_logger(__name__) + + +class DebugDataModule(BaseDataModule): + def __init__( + self, + length: int = 100, + channels: int = 2, + height: int = 10, + width: int = 10, + window: int = 1, + horizon: int = 1, + **kwargs, + ): + super().__init__(**kwargs) + self.save_hyperparameters() + + def setup(self, stage: Optional[str] = None): + train_len = int(0.8 * self.hparams.length) + val_len = int(0.1 * self.hparams.length) + test_len = self.hparams.length - train_len - val_len + ds_kwargs = dict( + channels=self.hparams.channels, + height=self.hparams.height, + width=self.hparams.width, + window=self.hparams.window, + horizon=self.hparams.horizon, + ) + self._data_train = DebugDataset(length=train_len, **ds_kwargs) + self._data_val = DebugDataset(length=val_len, **ds_kwargs) + self._data_test = DebugDataset(length=test_len, **ds_kwargs) + + @property + def sigma_data(self) -> float: + return 1.0 + + def get_epoch_aggregators( + self, + split: str, + is_ensemble: bool, + dataloader_idx: int = 0, + experiment_type: str = None, + device: torch.device = None, + verbose: bool = True, + ) -> Dict[str, OneStepAggregator]: + getattr(self, f"_data_{split}") + + split_horizon = self.get_horizon(split, dataloader_idx) + if "interpolation" in experiment_type.lower(): + horizon_range = range(1, split_horizon) + else: + horizon_range = range(1, split_horizon + 1) + + aggr_kwargs = dict(is_ensemble=is_ensemble) + one_step_kwargs = { + **aggr_kwargs, + "record_rmse": True, + "record_normed": False, + "use_snapshot_aggregator": False, + } + aggregators_all = {} + for h in horizon_range: + aggregators_all[f"t{h}"] = OneStepAggregator(name=f"t{h}", verbose=verbose and (h == 1), **one_step_kwargs) + + return aggregators_all + + +class DebugDataset(Dataset): + def __init__( + self, + length: int = 100, + channels: int = 2, + height: int = 10, + width: int = 10, + window: int = 1, + horizon: int = 1, + ): + self.length = length - horizon + self.channels = channels + self.height = height + self.width = width + self.window = window + self.horizon = horizon + # self.data = torch.randn(self.length, self.channels, self.height, self.width) + + def __len__(self): + return self.length + + def __getitem__(self, idx: int): + return dict( + dynamics=torch.randn(self.window + self.horizon, self.channels, self.height, self.width), + dynamical_condition=torch.randn(self.window + self.horizon, 3, self.height, self.width), + static_condition=torch.randn(2, self.height, self.width), + ) + # return dict(dynamics=self.data[idx:idx + self.horizon + self.window]) + # return self.data[idx:idx + self.horizon + self.window] + + @property + def loss_weights_tensor(self) -> Optional[torch.Tensor]: + return torch.randn(self.height, self.width) diff --git a/src/datamodules/fv3gfs_ensemble.py b/src/datamodules/fv3gfs_ensemble.py new file mode 100644 index 0000000..e7ca70d --- /dev/null +++ b/src/datamodules/fv3gfs_ensemble.py @@ -0,0 +1,280 @@ +from __future__ import annotations + +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Union + +import hydra +import torch +from omegaconf import OmegaConf +from tensordict import TensorDict +from torch import Tensor + +from src.ace_inference.core.data_loading._xarray import XarrayDatasetSalva +from src.ace_inference.core.data_loading.getters import get_dataset +from src.ace_inference.core.data_loading.params import DataLoaderParams, XarrayDataParams +from src.ace_inference.core.prescriber import Prescriber +from src.ace_inference.training.utils.data_loader_fv3gfs import FV3GFSDataset +from src.datamodules.abstract_datamodule import BaseDataModule +from src.evaluation.aggregators.main import OneStepAggregator +from src.evaluation.aggregators.time_mean import TimeMeanAggregator +from src.utilities.normalization import get_normalizer +from src.utilities.packer import Packer +from src.utilities.utils import get_logger, raise_error_if_invalid_type, to_torch_and_device + + +log = get_logger(__name__) + + +class FV3GFSEnsembleDataModule(BaseDataModule): + def __init__( + self, + data_dir: str, + in_names: List[str], + out_names: List[str], + forcing_names: List[str], + auxiliary_names: List[str] = None, + window: int = 1, + horizon: int = 1, + prediction_horizon: int = None, # None means use horizon + prediction_horizon_long: int = None, # None means use horizon + prescriber: Optional[Prescriber] = None, + multistep_strategy: Optional[str] = None, + data_dir_stats: Optional[str] = None, + max_train_samples: Optional[int] = None, + max_val_samples: Optional[int] = None, + training_sub_paths: Optional[List[str]] = None, + **kwargs, + ): + raise_error_if_invalid_type(data_dir, possible_types=[str], name="data_dir") + if not os.path.isdir(data_dir): + raise ValueError(f"Data dir={data_dir} not found.") + super().__init__(data_dir=data_dir, **kwargs) + self.save_hyperparameters(ignore=["prescriber"]) + + forcing_names = forcing_names or [] + auxiliary_names = auxiliary_names or [] + data_dir_stats = data_dir_stats or data_dir + path_mean = Path(data_dir_stats) / "centering.nc" + path_std = Path(data_dir_stats) / "scaling.nc" + if not path_mean.exists() or not path_std.exists(): + raise FileNotFoundError(f"Could not find normalization files at ``{path_mean}`` and/or ``{path_std}``") + self.train_dir = Path(data_dir) / "train" + self.validation_dir = Path(data_dir) / "validation" / "ic_0011" + + non_forcing_names = [n for n in self.all_names if n not in forcing_names] + self.normalizer = get_normalizer(path_mean, path_std, names=non_forcing_names) + channel_axis = -3 + self.in_packer = Packer(in_names, axis=channel_axis) + self.out_packer = Packer(out_names, axis=channel_axis) + + self.forcing_normalizer = get_normalizer(path_mean, path_std, names=forcing_names) + self.forcing_packer = Packer(forcing_names, axis=channel_axis) + if prescriber is not None: + if not isinstance(prescriber, Prescriber): + prescriber = hydra.utils.instantiate(prescriber) + log.info(f"Prescribing ``{prescriber.prescribed_name}`` using mask ``{prescriber.mask_name}``") + self.prescriber = prescriber + + if self.hparams.debug_mode: + log.info("Running in debug mode") + self.hparams.training_sub_paths = ["ic_0001"] + self.hparams.max_train_samples = 80 + self.hparams.max_val_samples = 10 + + def _check_args(self): + h = self.hparams.horizon + w = self.hparams.window + assert isinstance(h, list) or h >= 0, f"horizon must be >= 0 or a list, but is {h}" + assert w == 1, f"window must be 1, but is {w}" + + @property + def all_names(self): + forcing_names = self.hparams.forcing_names or [] + aux_names = self.hparams.auxiliary_names or [] + if self.hparams.prescriber is not None: + aux_names = set(aux_names).union([self.hparams.prescriber.mask_name]) + + all_names = list( + set(self.hparams.in_names).union(self.hparams.out_names).union(forcing_names).union(aux_names) + ) + return all_names + + def _create_ds(self, split: str, dataloader_idx: Optional[int] = None, **kwargs) -> Optional[FV3GFSDataset]: + kwargs = kwargs.copy() + kwargs["split_id"] = split + horizon = self.get_horizon(split, dataloader_idx) + n_valid_samples = self.hparams.max_val_samples + n_samples = None + if split == "train": + if self.hparams.max_train_samples is not None: + log.info(f"Limiting training samples to {self.hparams.max_train_samples}") + n_samples = self.hparams.max_train_samples + if self.hparams.training_sub_paths is not None: + sub_paths = self.hparams.training_sub_paths + log.info(f"Limiting training sub-paths to {sub_paths}") + kwargs["sub_paths"] = sub_paths + + elif split == "val" and n_valid_samples is not None: + log.info(f"Limiting validation samples to {n_valid_samples}") + n_samples = n_valid_samples if dataloader_idx in [0, None] else 8 + + elif split == "test" and n_valid_samples is not None: + log.info(f"Limiting test samples to val samples {n_valid_samples}:{n_valid_samples*3}") + n_samples = n_valid_samples * 2 + kwargs["min_idx_shift"] = n_valid_samples # preclude test samples + + elif split == "predict": + n_samples = 1 + + requirements = {k: getattr(self.hparams, k) for k in ["in_names", "out_names"]} + requirements.update( + { + "names": self.all_names, + "n_timesteps": self.hparams.window + horizon, + } + ) + requirements = OmegaConf.create(requirements) + + params = DataLoaderParams( + dataset=XarrayDataParams( + data_path=self.train_dir if split == "train" else self.validation_dir, + n_repeats=1, + engine=None, + ), + batch_size=0, + num_data_workers=self.hparams.num_workers, + data_type="ensemble_xarray" if split == "train" else "xarray", + n_samples=n_samples, + ) + + kwargs_final = dict( + # window_time_slice=kwargs.get("window_time_slice", None), + forcing_names=self.hparams.forcing_names, + forcing_packer=self.forcing_packer, + forcing_normalizer=self.forcing_normalizer, + ) + ds = get_dataset(params, requirements, dataset_class=XarrayDatasetSalva, **kwargs, **kwargs_final) + return ds + + def setup(self, stage: Optional[str] = None): + shared_dset_kwargs = dict() + if stage in (None, "fit"): + self._data_train = self._create_ds(split="train", **shared_dset_kwargs) + + if stage in (None, "fit", "validate"): + self._data_val = [self._create_ds(split="val", **shared_dset_kwargs)] + if self.hparams.horizon > 0 and self.hparams.prediction_horizon_long is not None: + # Add a validation dataloader for running inference + log.info( + f"Adding a validation dataset for inference with horizon {self.hparams.prediction_horizon_long}" + ) + self._data_val += [self._create_ds(split="val", dataloader_idx=1, **shared_dset_kwargs)] + + if stage in (None, "test"): + self._data_test = self._create_ds(split="test", **shared_dset_kwargs) + if stage == "predict": + self._data_predict = self._create_ds(split="predict", **shared_dset_kwargs) + + # Print sizes of the datasets (how many examples) + self.print_data_sizes(stage) + + def boundary_conditions( + self, + preds: Union[Tensor, TensorDict], + targets: Union[Tensor, TensorDict], + data: Any = None, + metadata: Any = None, + time: float = None, + ) -> Union[Tensor, TensorDict]: + """Return predictions that satisfy the boundary conditions for a given item (batch element).""" + if self.prescriber is None: + return super().boundary_conditions(preds, targets, data, time) + else: + return self.prescriber(gen_norm=preds, target_norm=targets, data=data) + + @property + def validation_set_names(self) -> List[str]: + return ["val", "inference"] if len(self._data_val) > 1 else ["val"] + + def get_horizon(self, split: str, dataloader_idx: int = 0) -> int: + if split in ["val", "validate"] and dataloader_idx == 1: + return self.hparams.prediction_horizon_long + assert dataloader_idx in [0, None], f"Invalid dataloader_idx: {dataloader_idx}" + if split in ["predict", "test"]: + return self.hparams.prediction_horizon_long or self.hparams.horizon + elif split in ["val", "validate"]: + return self.hparams.prediction_horizon or self.hparams.horizon + else: + assert split in ["train", "fit"], f"Invalid split: {split}" + return self.hparams.horizon + + def get_epoch_aggregators( + self, + split: str, + is_ensemble: bool, + dataloader_idx: int = 0, + experiment_type: str = None, + device: torch.device = None, + verbose: bool = True, + ) -> Dict[str, OneStepAggregator]: + assert dataloader_idx in [0, 1], f"Invalid dataloader_idx: {dataloader_idx}" + split_ds = getattr(self, f"_data_{split}") + if split == "val" and isinstance(split_ds, list): + split_ds = split_ds[0] # just need it for the area weights + is_inference_val = split == "val" and dataloader_idx == 1 + use_full_rollout = is_inference_val or split == "test" + if "interpolation" in experiment_type.lower(): + horizon_range = range(1, self.hparams.horizon) + else: + split_horizon = self.get_horizon(split, dataloader_idx) + horizon_range = range(1, split_horizon + 1) + + area_weights = to_torch_and_device(split_ds.area_weights, device) + aggr_kwargs = dict(area_weights=area_weights, is_ensemble=is_ensemble) + aggregators = {} + if use_full_rollout or "interpolation" in experiment_type.lower(): + save_snapshots = True + else: + save_snapshots = False + # we want to save at most 10 snapshots, including 1st, 10th, 20th, and last + max_h = horizon_range[-1] + if "interpolation" in experiment_type.lower(): + snapshot_horizons = [1, max_h // 2] + elif max_h <= 10: + snapshot_horizons = [1, max_h] + elif max_h <= 50: + snapshot_horizons = [1, 5, 12, 20, 32, 40, max_h] + elif max_h <= 100: + snapshot_horizons = [1, 12, 20, 40, 60, 80, max_h] + elif max_h <= 200: + snapshot_horizons = [1, 12, 20, 40, 80, 120, max_h] + elif max_h <= 460: + snapshot_horizons = [1, 12, 20, 120, 240, 420, max_h] + elif max_h <= 500: + snapshot_horizons = [1, 12, 20, 120, 240, 420, max_h] + elif max_h <= 1460: + snapshot_horizons = [1, 12, 20, 120, 500, 1460] + elif max_h == 14600: + snapshot_horizons = [40, 120, 240, 360, 420, 500, 1000, 2000, 4000, 8000, 12000, 14600] + else: + snapshot_horizons = [] + snaps_vars = [ + "air_temperature_7_normed", + "specific_total_water_7", + "specific_total_water_7_normed", + "air_temperature_0", + ] + for h in horizon_range: + aggregators[f"t{h}"] = OneStepAggregator( + use_snapshot_aggregator=save_snapshots and h in snapshot_horizons, + snapshot_var_names=snaps_vars, + verbose=verbose and (h == 1), + record_normed=True, + record_abs_values=True, + name=f"t{h}", + **aggr_kwargs, + ) + if use_full_rollout: + aggregators["time_mean"] = TimeMeanAggregator(**aggr_kwargs, name="time_mean") + return aggregators diff --git a/src/dependency_versions_table.py b/src/dependency_versions_table.py new file mode 100644 index 0000000..679617d --- /dev/null +++ b/src/dependency_versions_table.py @@ -0,0 +1,34 @@ +# THIS FILE HAS BEEN AUTOGENERATED. To update: +# 1. modify the `_deps` dict in setup.py +# 2. run `make deps_table_update`` +deps = { + "black": "black", + "boto3": "boto3", + "cachey": "cachey", + "dacite": "dacite", + "dask": "dask", + "einops": "einops", + "h5py": "h5py", + "hf-doc-builder": "hf-doc-builder", + "huggingface_hub": "huggingface_hub", + "hydra-core": "hydra-core", + "isort": "isort", + "netCDF4": "netCDF4", + "numpy": "numpy", + "omegaconf": "omegaconf", + "pytest": "pytest", + "pytorch-lightning": "pytorch-lightning>=2.0", + "rich": "rich", + "ruff": "ruff>=0.0.241", + "regex": "regex", + "requests": "requests", + "tensordict": "tensordict", + "tensorly": "tensorly", + "tensorly-torch": "tensorly-torch", + "torch": "torch>=1.8", + "torch-harmonics": "torch-harmonics", + "transformers": "transformers", + "urllib3": "urllib3", + "wandb": "wandb", + "xarray": "xarray", +} diff --git a/src/diffusion/__init__.py b/src/diffusion/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/diffusion/_base_diffusion.py b/src/diffusion/_base_diffusion.py new file mode 100644 index 0000000..dd99db2 --- /dev/null +++ b/src/diffusion/_base_diffusion.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import inspect +from abc import abstractmethod +from typing import Any + +from src.models._base_model import BaseModel + + +class BaseDiffusion(BaseModel): + def __init__( + self, + model: BaseModel, + timesteps: int, + sampling_timesteps: int = None, + sampling_schedule=None, + **kwargs, + ): + signature = inspect.signature(BaseModel.__init__).parameters + base_kwargs = {k: model.hparams.get(k) for k in signature if k in model.hparams} + base_kwargs.update(kwargs) # override base_kwargs with kwargs + super().__init__(**base_kwargs) + if model is None: + raise ValueError( + "Arg ``model`` is missing..." " Please provide a backbone model for the diffusion model (e.g. a Unet)" + ) + self.save_hyperparameters(ignore=["model"]) + # self.sampling_timesteps = default(sampling_timesteps, timesteps) + self.model = model + + self.spatial_shape_in = model.spatial_shape_in + self.spatial_shape_out = model.spatial_shape_out + self.num_input_channels = model.num_input_channels + self.num_output_channels = model.num_output_channels + self.num_conditional_channels = model.num_conditional_channels + self.num_timesteps = int(timesteps) + + # if hasattr(model, 'example_input_array'): + # self.example_input_array = model.example_input_array + self.model.criterion = None + + @property + def short_description(self) -> str: + name = super().short_description + name += f" (timesteps={self.num_timesteps})" + return name + + def sample(self, condition=None, num_samples=1, **kwargs): + # sample from the model + raise NotImplementedError() + + def predict_forward(self, *inputs, condition=None, metadata: Any = None, **kwargs): + assert len(inputs) == 1, "Only one input tensor is allowed for the forward pass" + inputs = inputs[0] + if inputs is not None and condition is not None: + raise ValueError("Only one of the inputs or condition should be provided. Need to refactor the code.") + elif condition is not None: + raise NotImplementedError("Condition is not implemented yet.") + else: # if inputs is not None: + inital_condition = inputs + + _ = kwargs.pop("lookback", None) # remove the lookback argument + return self.sample(inital_condition, **kwargs) + + @abstractmethod + def p_losses(self, *args, **kwargs): + """Compute the loss for the given targets and condition. + + Args: + targets (Tensor): Target data tensor of shape :math:`(B, C_{out}, *)` + condition (Tensor): Condition data tensor of shape :math:`(B, C_{in}, *)` + t (Tensor): Timestep of shape :math:`(B,)` + """ + raise NotImplementedError(f"Method ``p_losses`` is not implemented for {self.__class__.__name__}!") + + def forward(self, *args, **kwargs): + return self.p_losses(*args, **kwargs) + + def get_loss(self, *args, **kwargs): + raise NotImplementedError(f"Plese implement the ``get_loss`` method for {self.__class__.__name__}!") diff --git a/src/diffusion/dyffusion.py b/src/diffusion/dyffusion.py new file mode 100644 index 0000000..641711f --- /dev/null +++ b/src/diffusion/dyffusion.py @@ -0,0 +1,738 @@ +from __future__ import annotations + +import math +from abc import abstractmethod +from contextlib import ExitStack +from typing import Any, Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +from torch import Tensor, nn +from tqdm.auto import tqdm + +from src.diffusion._base_diffusion import BaseDiffusion +from src.experiment_types.interpolation import InterpolationExperiment +from src.interface import get_checkpoint_from_path_or_wandb +from src.utilities.utils import freeze_model, raise_error_if_invalid_value + + +class BaseDYffusion(BaseDiffusion): + # enable_interpolator_dropout: whether to enable dropout in the interpolator + def __init__( + self, + forward_conditioning: str = "data", + dynamic_cond_from_t: str = "h", # 'h', '0', or 't' + schedule: str = "before_t1_only", + additional_interpolation_steps: int = 0, + additional_interpolation_steps_factor: int = 0, + interpolate_before_t1: bool = True, + sampling_type: str = "cold", # 'cold' or 'naive' + sampling_schedule: Union[List[float], str] = None, + use_cold_sampling_for_intermediate_steps: bool = True, + use_cold_sampling_for_last_step: bool = True, + use_cold_sampling_for_init_of_ar_step: Optional[bool] = None, + time_encoding: str = "discrete", + refine_intermediate_predictions: bool = False, + prediction_timesteps: Optional[Sequence[float]] = None, + enable_interpolator_dropout: Union[bool, str] = True, + interpolator_use_ema: bool = False, + log_every_t: Union[str, int] = None, + reconstruction2_detach_x_last=None, + hack_for_imprecise_interpolation: bool = False, + # hack_for_imprecise_interpolation can be used when accidentally using one input-only variable + # in the list of inputs vs. the list of forcings, which leads to using them twice for the interpolator model, + # which becomes a problem when using DYffusion, since the forecaster does not predict those input-only variables + *args, + **kwargs, + ): + use_cold_sampling_for_init_of_ar_step = ( + use_cold_sampling_for_init_of_ar_step + if use_cold_sampling_for_init_of_ar_step is not None + else use_cold_sampling_for_last_step + ) + super().__init__(*args, **kwargs, sampling_schedule=sampling_schedule) + sampling_schedule = None if sampling_schedule == "None" else sampling_schedule + self.save_hyperparameters(ignore=["model"]) + self.num_timesteps = self.hparams.timesteps + self.use_cold_sampling_for_init_of_ar_step = use_cold_sampling_for_init_of_ar_step + + fcond_options = ["data", "none", "data+noise-v1", "data+noise-v2"] + raise_error_if_invalid_value(forward_conditioning, fcond_options, "forward_conditioning") + + # Add additional interpolation steps to the diffusion steps + # we substract 2 because we don't want to use the interpolator in timesteps outside [1, num_timesteps-1] + horizon = self.num_timesteps # = self.interpolator_horizon + assert horizon > 1, f"horizon must be > 1, but got {horizon}. Please use datamodule.horizon with > 1" + if schedule == "linear": + assert ( + additional_interpolation_steps == 0 + ), "additional_interpolation_steps must be 0 when using linear schedule" + self.additional_interpolation_steps_fac = additional_interpolation_steps_factor + if interpolate_before_t1: + interpolated_steps = horizon - 1 + self.di_to_ti_add = 0 + else: + interpolated_steps = horizon - 2 + self.di_to_ti_add = additional_interpolation_steps_factor + + self.additional_diffusion_steps = additional_interpolation_steps_factor * interpolated_steps + elif schedule == "before_t1_only": + assert ( + additional_interpolation_steps_factor == 0 + ), "additional_interpolation_steps_factor must be 0 when using before_t1_only schedule" + assert interpolate_before_t1, "interpolate_before_t1 must be True when using before_t1_only schedule" + self.additional_diffusion_steps = additional_interpolation_steps + elif schedule == "before_t1_then_linear": + assert ( + interpolate_before_t1 + ), "interpolate_before_t1 must be True when using before_t1_then_linear schedule" + self.additional_interpolation_steps_fac = additional_interpolation_steps_factor + self.additional_diffusion_steps_pre_t1 = additional_interpolation_steps + self.additional_diffusion_steps = ( + additional_interpolation_steps + additional_interpolation_steps_factor * (horizon - 2) + ) + else: + raise ValueError(f"Invalid schedule: {schedule}") + + self.num_timesteps += self.additional_diffusion_steps + d_to_i_step = {d: self.diffusion_step_to_interpolation_step(d) for d in range(1, self.num_timesteps)} + self.dynamical_steps = {d: i_n for d, i_n in d_to_i_step.items() if float(i_n).is_integer()} + self.i_to_diffusion_step = {i_n: d for d, i_n in d_to_i_step.items()} + self.artificial_interpolation_steps = {d: i_n for d, i_n in d_to_i_step.items() if not float(i_n).is_integer()} + # check that float tensors and floats return the same value + for d, i_n in d_to_i_step.items(): + i_n2 = float(self.diffusion_step_to_interpolation_step(torch.tensor(d, dtype=torch.float))) + assert math.isclose( + i_n, i_n2, abs_tol=4e-6 + ), f"float and tensor return different values for diffusion_step_to_interpolation_step({d}): {i_n} != {i_n2}" + # note that self.dynamical_steps does not include t=0, which is always dynamical (but not an output!) + if additional_interpolation_steps_factor > 0 or additional_interpolation_steps > 0: + self.log_text.info( + f"Added {self.additional_diffusion_steps} steps.. total diffusion num_timesteps={self.num_timesteps}. \n" + # f'Mapping diffusion -> interpolation steps: {d_to_i_step}. \n' + f"Diffusion -> Dynamical timesteps: {self.dynamical_steps}." + ) + self.enable_interpolator_dropout = enable_interpolator_dropout + raise_error_if_invalid_value( + enable_interpolator_dropout, + [True, False, "always", "except_dynamical_steps"], + "enable_interpolator_dropout", + ) + if self.hparams.interpolator_use_ema: + self.log_text.info("Using EMA for the interpolator.") + if refine_intermediate_predictions: + self.log_text.info("Enabling refinement of intermediate predictions.") + + # which diffusion steps to take during sampling + self.full_sampling_schedule = list(range(0, self.num_timesteps)) + self.sampling_schedule = sampling_schedule or self.full_sampling_schedule + + @property + def diffusion_steps(self) -> List[int]: + return list(range(0, self.num_timesteps)) + + def diffusion_step_to_interpolation_step(self, diffusion_step: Union[int, Tensor]) -> Union[float, Tensor]: + """ + Convert a diffusion step to an interpolation step + Args: + diffusion_step: the diffusion step (in [0, num_timesteps-1]) + Returns: + the interpolation step + """ + # assert correct range + if torch.is_tensor(diffusion_step): + assert (0 <= diffusion_step).all() and ( + diffusion_step <= self.num_timesteps - 1 + ).all(), f"diffusion_step must be in [0, num_timesteps-1]=[0, {self.num_timesteps - 1}], but got {diffusion_step}" + else: + assert ( + 0 <= diffusion_step <= self.num_timesteps - 1 + ), f"diffusion_step must be in [0, num_timesteps-1]=[0, {self.num_timesteps - 1}], but got {diffusion_step}" + if self.hparams.schedule == "linear": + # self.di_to_ti_add is 0 or 1 + # Self.additional_interpolation_steps_fac is 0 by default (no additional interpolation steps) + i_n = (diffusion_step + self.di_to_ti_add) / (self.additional_interpolation_steps_fac + 1) + elif self.hparams.schedule == "before_t1_only": + # map d_N to h-1, d_N-1 to h-2, ..., d_n to 1, and d_n-1..d_1 uniformly to [0, 1) + # e.g. if h=5, then d_5 -> 4, d_4 -> 3, d_3 -> 2, d_2 -> 1, d_1 -> 0.5 + # or d_6 -> 4, d_5 -> 3, d_4 -> 2, d_3 -> 1, d_2 -> 0.66, d_1 -> 0.33 + # or d_7 -> 4, d_6 -> 3, d_5 -> 2, d_4 -> 1, d_3 -> 0.75, d_2 -> 0.5, d_1 -> 0.25 + if torch.is_tensor(diffusion_step): + i_n = torch.where( + diffusion_step >= self.additional_diffusion_steps + 1, + (diffusion_step - self.additional_diffusion_steps).float(), + diffusion_step / (self.additional_diffusion_steps + 1), + ) + elif diffusion_step >= self.additional_diffusion_steps + 1: + i_n = diffusion_step - self.additional_diffusion_steps + else: + i_n = diffusion_step / (self.additional_diffusion_steps + 1) + elif self.hparams.schedule == "before_t1_then_linear": + if torch.is_tensor(diffusion_step): + i_n = torch.where( + diffusion_step >= self.additional_diffusion_steps_pre_t1 + 1, + 1 + + (diffusion_step - self.additional_diffusion_steps_pre_t1 - 1) + / (self.additional_interpolation_steps_fac + 1), + diffusion_step / (self.additional_diffusion_steps_pre_t1 + 1), + ) + elif diffusion_step >= self.additional_diffusion_steps_pre_t1 + 1: + i_n = 1 + (diffusion_step - self.additional_diffusion_steps_pre_t1 - 1) / ( + self.additional_interpolation_steps_fac + 1 + ) + else: + i_n = diffusion_step / (self.additional_diffusion_steps_pre_t1 + 1) + else: + raise ValueError(f"schedule=``{self.hparams.schedule}`` not supported.") + + return i_n + + def q_sample( + self, + x0, + x_end, + t: Optional[Tensor], + interpolation_time: Optional[Tensor] = None, + batch_mask: Optional[Tensor] = None, + is_artificial_step: bool = True, + **kwargs, + ) -> Tensor: + # q_sample = using model in interpolation mode + # just remember that x_end here refers to t=0 (the initial conditions) + # and x_0 (terminology of diffusion models) refers to t=T, i.e. the last timestep + ipol_handles = [self.interpolator] if hasattr(self, "interpolator") else [self] + if hasattr(self, "interpolator_artificial_steps") and self.interpolator_artificial_steps is not None: + ipol_handles.append(self.interpolator_artificial_steps) + + assert t is None or interpolation_time is None, "Either t or interpolation_time must be None." + t = interpolation_time if t is None else self.diffusion_step_to_interpolation_step(t) # .float() + + # Handle dynamical cond based on logic in src.experimental_types.interpolation.InterpolationExperiment + dynamical_cond = kwargs.pop("dynamical_condition", None) + if dynamical_cond is not None: + kwargs["condition"] = ipol_handles[0].get_dynamical_condition(dynamical_cond, t) + + # Tensorfy t if it is a float/int + if not torch.is_tensor(t): + t = torch.full((x0.shape[0],), t, dtype=torch.float32, device=self.device) + + # Apply mask if necessary on batch dimension + if batch_mask is not None: + x0 = x0[batch_mask] + x_end = x_end[batch_mask] + t = t[batch_mask] + kwargs = {k: v[batch_mask] if isinstance(v, Tensor) else v for k, v in kwargs.items()} + + do_enable = ( + self.training + or self.enable_interpolator_dropout in [True, "always"] + or (self.enable_interpolator_dropout == "except_dynamical_steps" and is_artificial_step) + ) + + with ExitStack() as stack: + # inference_dropout_scope of all handles (enable and disable) is managed by the ExitStack + for ipol in ipol_handles: + stack.enter_context(ipol.inference_dropout_scope(condition=do_enable)) + if self.hparams.interpolator_use_ema: + stack.enter_context(ipol.ema_scope(condition=True)) + + x_ti = self._interpolate(initial_condition=x_end, x_last=x0, t=t, **kwargs) + return x_ti + + @abstractmethod + def _interpolate( + self, + initial_condition: Tensor, + x_last: Tensor, + t: Tensor, + num_predictions: int = 1, + **kwargs, + ): + """This is an internal method. Please use q_sample to access it.""" + raise NotImplementedError(f"``_interpolate`` must be implemented in {self.__class__.__name__}") + + def get_condition( + self, + initial_condition_cond: Optional[Tensor], + x_last: Optional[Tensor], + prediction_type: str, + condition: Optional[Tensor] = None, + shape: Sequence[int] = None, + ) -> Tensor: + if initial_condition_cond is not None and condition is not None: + return torch.cat([initial_condition_cond, condition], dim=1) + elif initial_condition_cond is not None and condition is not None: + return torch.cat([initial_condition_cond, condition], dim=1) + elif initial_condition_cond is not None: + return initial_condition_cond + elif condition is not None: + return condition + else: + return None + + @property + def valid_time_range_for_backbone_model(self) -> List[int]: + diff_steps = list(range(0, self.num_timesteps)) + if self.hparams.time_encoding == "discrete": + valid_time = diff_steps + elif self.hparams.time_encoding == "continuous": + valid_time = list(np.array(diff_steps) / self.num_timesteps) + elif self.hparams.time_encoding == "dynamics": + valid_time = [self.diffusion_step_to_interpolation_step(d) for d in diff_steps] + else: + raise ValueError(f"Invalid time_encoding: {self.hparams.time_encoding}") + return valid_time + + def _predict_last_dynamics(self, condition: Tensor, x_t: Tensor, t: Tensor, **kwargs): + if self.hparams.time_encoding == "discrete": + time = t + elif self.hparams.time_encoding == "continuous": + time = t / self.num_timesteps + elif self.hparams.time_encoding == "dynamics": + time = self.diffusion_step_to_interpolation_step(t) + else: + raise ValueError(f"Invalid time_encoding: {self.hparams.time_encoding}") + + x_last_pred = self.model.predict_forward(x_t, time=time, condition=condition, **kwargs) + return x_last_pred + + def predict_x_last( + self, + initial_condition: Tensor, + x_t: Tensor, + t: Tensor, + **kwargs, + ): + """Predict x_{t+h} given x_t""" + if not torch.is_tensor(t): + assert 0 <= t <= self.num_timesteps - 1, f"Invalid timestep: {t}. {self.num_timesteps=}" + t = torch.full((initial_condition.shape[0],), t, dtype=torch.float32, device=self.device) + else: + assert (0 <= t).all() and (t <= self.num_timesteps - 1).all(), f"Invalid timestep: {t}" + cond_type = self.hparams.forward_conditioning + if cond_type == "data": + forward_inputs = initial_condition + elif cond_type == "none": + forward_inputs = None + elif "data+noise" in cond_type: + # simply use factor t/T to scale the condition and factor (1-t/T) to scale the noise + # this is the same as using a linear combination of the condition and noise + tfactor = t / (self.num_timesteps - 1) # shape: (b,) + tfactor = tfactor.view( + initial_condition.shape[0], *[1] * (initial_condition.ndim - 1) + ) # shape: (b, 1, 1, 1) + if cond_type == "data+noise-v1": + # add noise to the data in a linear combination, s.t. the noise is more important at the beginning (t=0) + # and less important at the end (t=T) + forward_inputs = tfactor * initial_condition + (1 - tfactor) * torch.randn_like(initial_condition) + elif cond_type == "data+noise-v2": + forward_inputs = (1 - tfactor) * initial_condition + tfactor * torch.randn_like(initial_condition) + else: + raise ValueError(f"Invalid forward conditioning type: {cond_type}") + + dynamic_cond = kwargs.pop("dynamical_condition", None) # a window (=1) + horizon (=T) tensor + if dynamic_cond is not None: + assert ( + dynamic_cond.shape[1] == self.num_timesteps + 1 + ), f"{dynamic_cond.shape}[1] != {self.num_timesteps+1}" + if self.hparams.dynamic_cond_from_t == "0": + dynamic_cond = self.slice_time(dynamic_cond, 0) # take from initial conditions timestep + elif self.hparams.dynamic_cond_from_t == "h": + dynamic_cond = self.slice_time(dynamic_cond, -1) # take from last timestep (to predict) + elif self.hparams.dynamic_cond_from_t == "t": + dynamic_cond = self.slice_time(dynamic_cond, t) # take from input timestep + else: + raise ValueError(f"Invalid dynamic_cond_from_t: {self.hparams.dynamic_cond_from_t}") + + forward_inputs = self.get_condition( + initial_condition_cond=forward_inputs, + x_last=None, + prediction_type="forward", + shape=initial_condition.shape, + condition=dynamic_cond, + ) + x_last_pred = self._predict_last_dynamics(x_t=x_t, condition=forward_inputs, t=t, **kwargs) + return x_last_pred + + def slice_time(self, x: Tensor, t: Union[int, Tensor]) -> Tensor: + if torch.is_tensor(t): + b = x.shape[0] + return x[torch.arange(b), t] + return x[:, t] + + @property + def sampling_schedule(self) -> List[Union[int, float]]: + return self._sampling_schedule + + @sampling_schedule.setter + def sampling_schedule(self, schedule: Union[str, List[Union[int, float]]]): + """Set the sampling schedule. At the very minimum, the sampling schedule will go through all dynamical steps. + Notation: + - N: number of diffusion steps + - h: number of dynamical steps + - h_0: first dynamical step + + Options for diffusion sampling schedule trajectories ('': ): + - 'only_dynamics': the diffusion steps corresponding to dynamical steps (this is the minimum) + - 'only_dynamics_plus_discreteINT': add INT discrete non-dynamical steps, uniformly drawn between 0 and h_0 + - 'only_dynamics_plusINT': add INT non-dynamical steps (possibly continuous), uniformly drawn between 0 and h_0 + - 'everyINT': only use every INT-th diffusion step (e.g. 'every2' for every second diffusion step) + - 'firstINT': only use the first INT diffusion steps + - 'firstFLOAT': only use the first FLOAT*N diffusion steps + + """ + schedule_name = schedule + if isinstance(schedule_name, str): + base_schedule = [0] + list(self.dynamical_steps.keys()) # already included: + [self.num_timesteps - 1] + artificial_interpolation_steps = list(self.artificial_interpolation_steps.keys()) + if "only_dynamics" in schedule_name: + schedule = [] # only sample from base_schedule (added below) + + if "only_dynamics_plus" in schedule_name: + # parse schedule 'only_dynamics_plusN' to get N + plus_n = int(schedule_name.replace("only_dynamics_plus", "").replace("_discrete", "")) + # Add N additional steps to the front of the schedule + schedule = list(np.linspace(0, base_schedule[1], plus_n + 1, endpoint=False)) + if "_discrete" in schedule_name: # floor the values + schedule = [int(np.floor(s)) for s in schedule] + else: + assert "only_dynamics" == schedule_name, f"Invalid sampling schedule: {schedule}" + + elif schedule_name.startswith("every"): + # parse schedule 'everyNth' to get N + every_nth = schedule.replace("every", "").replace("th", "").replace("nd", "").replace("rd", "") + every_nth = int(every_nth) + assert 1 <= every_nth <= self.num_timesteps, f"Invalid sampling schedule: {schedule}" + schedule = artificial_interpolation_steps[::every_nth] + + elif schedule.startswith("first"): + # parse schedule 'firstN' to get N + first_n = float(schedule.replace("first", "").replace("v2", "")) + if first_n < 1: + assert 0 < first_n < 1, f"Invalid sampling schedule: {schedule}, must end with number/float > 0" + first_n = int(np.ceil(first_n * len(artificial_interpolation_steps))) + schedule = artificial_interpolation_steps[:first_n] + self.log_text.info(f"Using sampling schedule: {schedule_name} -> (first {first_n} steps)") + else: + assert first_n.is_integer(), f"If first_n >= 1, it must be an integer, but got {first_n}" + assert 1 <= first_n <= self.num_timesteps, f"Invalid sampling schedule: {schedule}" + first_n = int(first_n) + # Simple schedule: sample using first N steps + schedule = artificial_interpolation_steps[:first_n] + else: + raise ValueError(f"Invalid sampling schedule: ``{schedule}``. ") + + # Add dynamic steps to the schedule + schedule += base_schedule + # need to sort in ascending order and remove duplicates + schedule = list(sorted(set(schedule))) + + assert ( + 1 <= schedule[-1] <= self.num_timesteps + ), f"Invalid sampling schedule: {schedule}, must end with number/float <= {self.num_timesteps}" + if schedule[0] != 0: + self.log_text.warning( + f"Sampling schedule {schedule_name} must start at 0. Adding 0 to the beginning of it." + ) + schedule = [0] + schedule + + last = schedule[-1] + if last != self.num_timesteps - 1: + self.log_text.warning("------" * 20) + self.log_text.warning( + f"Are you sure you don't want to sample at the last timestep? (current last timestep: {last})" + ) + self.log_text.warning("------" * 20) + + # check that schedule is monotonically increasing + for i in range(1, len(schedule)): + assert schedule[i] > schedule[i - 1], f"Invalid sampling schedule not monotonically increasing: {schedule}" + + if all(float(s).is_integer() for s in schedule): + schedule = [int(s) for s in schedule] + else: + self.log_text.info(f"Sampling schedule {schedule_name} uses diffusion steps it has not been trained on!") + self._sampling_schedule = schedule + + def sample_loop( + self, + initial_condition, + log_every_t: Optional[Union[str, int]] = None, + num_predictions: int = None, + verbose=True, + **kwargs, + ): + log_every_t = log_every_t or self.hparams.log_every_t + log_every_t = log_every_t if log_every_t != "auto" else 1 + sampling_schedule = self.sampling_schedule + + assert len(initial_condition.shape) == 4, f"condition.shape: {initial_condition.shape} (should be 4D)" + intermediates, xhat_th, dynamics_pred_step = dict(), None, 0 + last_i_n_plus_one = sampling_schedule[-1] + 1 + s_and_snext = zip( + sampling_schedule, + sampling_schedule[1:] + [last_i_n_plus_one], + sampling_schedule[2:] + [last_i_n_plus_one, last_i_n_plus_one + 1], + ) + progress_bar = tqdm(s_and_snext, desc="Sampling", total=len(sampling_schedule), leave=False) + x_s = initial_condition + for s, s_next, s_nnext in progress_bar: + is_first_step = s == 0 + is_last_step = s == self.num_timesteps - 1 + + # Forecast x_{t+h} using x_{s} as input + xhat_th = self.predict_x_last(initial_condition=initial_condition, x_t=x_s, t=s, **kwargs) + + # Are we predicting dynamical time step or an artificial interpolation step? + time_i_n = self.diffusion_step_to_interpolation_step(s_next) if not is_last_step else np.inf + is_dynamics_pred = float(time_i_n).is_integer() or is_last_step + q_sample_kwargs = dict( + x0=xhat_th, + x_end=initial_condition, + is_artificial_step=not is_dynamics_pred, + num_predictions=num_predictions if is_first_step else 1, + ) + if s_next <= self.num_timesteps - 1: + # D(x_s, s-1) + x_interpolated_s_next = self.q_sample(**q_sample_kwargs, t=s_next, **kwargs) + else: + assert is_last_step, f"Invalid s_next: {s_next} (should be <= {self.num_timesteps - 1})" + x_interpolated_s_next = xhat_th # for the last step, we use the final x0_hat prediction + if self.hparams.hack_for_imprecise_interpolation: + x_interpolated_s_next = torch.cat([initial_condition[:, :1], x_interpolated_s_next], dim=1) + + if self.hparams.sampling_type == "cold": + if not self.hparams.use_cold_sampling_for_last_step and is_last_step: + if self.hparams.use_cold_sampling_for_init_of_ar_step: + x_interpolated_s = self.q_sample(**q_sample_kwargs, t=s, **kwargs) + ar_init = x_s + xhat_th - x_interpolated_s + if self.hparams.hack_for_imprecise_interpolation: + ar_init = ar_init[:, 1:] + intermediates["preds_autoregressive_init"] = ar_init + x_s = xhat_th + else: + # D(x_s, s) + x_interpolated_s = self.q_sample(**q_sample_kwargs, t=s, **kwargs) if s > 0 else x_s + # for s = 0, we have x_s_degraded = x_s, so we just directly return x_s_degraded_next + d_i1 = x_interpolated_s_next - x_interpolated_s + if self.hparams.sampling_type == "cold" or s_nnext > self.num_timesteps - 1: + x_s = x_s + d_i1 + elif self.hparams.sampling_type == "naive": + x_s = x_interpolated_s_next + else: + raise ValueError(f"unknown sampling type {self.hparams.sampling_type}") + + dynamics_pred_step = int(time_i_n) if s < self.num_timesteps - 1 else dynamics_pred_step + 1 + if is_dynamics_pred: + if self.hparams.use_cold_sampling_for_intermediate_steps or is_last_step: + preds_t = x_s + else: + assert not self.hparams.use_cold_sampling_for_intermediate_steps and not is_last_step + preds_t = x_interpolated_s_next + if self.hparams.hack_for_imprecise_interpolation: + preds_t = preds_t[:, 1:] + intermediates[f"t{dynamics_pred_step}_preds"] = preds_t # preds + if log_every_t is not None: + intermediates[f"t{dynamics_pred_step}_preds2"] = x_interpolated_s_next + + s1, s2 = s, s # s + 1, next_step # s, next_step + if log_every_t is not None: + intermediates[f"x_{s2}_dmodel"] = x_s # preds + intermediates[f"intermediate_{s1}_x0hat"] = xhat_th + intermediates[f"xipol_{s2}_dmodel"] = x_interpolated_s_next + if self.hparams.sampling_type == "cold": + intermediates[f"xipol_{s1}_dmodel2"] = x_interpolated_s + + if self.hparams.refine_intermediate_predictions: + # Use last prediction of x0 for final prediction of intermediate steps (not the last timestep!) + q_sample_kwargs["x0"] = xhat_th + q_sample_kwargs["is_artificial_step"] = False + dynamical_steps = self.hparams.prediction_timesteps or list(self.dynamical_steps.values()) + dynamical_steps = [i for i in dynamical_steps if i < self.num_timesteps] + for i_n in dynamical_steps: + i_n_for_str = int(i_n) if float(i_n).is_integer() else i_n + assert ( + not float(i_n).is_integer() or f"t{i_n_for_str}_preds" in intermediates + ), f"t{i_n_for_str}_preds not in intermediates" + intermediates[f"t{i_n_for_str}_preds"] = self.q_sample( + **q_sample_kwargs, + t=None, + interpolation_time=i_n, + **kwargs, + ) + if self.hparams.hack_for_imprecise_interpolation: + intermediates[f"t{i_n_for_str}_preds"] = intermediates[f"t{i_n_for_str}_preds"][:, 1:] + if last_i_n_plus_one < self.num_timesteps: + return x_s, intermediates + return xhat_th, intermediates + + @torch.inference_mode() + def sample(self, initial_condition, num_samples=1, **kwargs): + x_0, intermediates = self.sample_loop(initial_condition, **kwargs) + return intermediates + + def predict_forward(self, *inputs, metadata: Any = None, **kwargs): + assert len(inputs) == 1, "Only one input tensor is allowed for the forward pass" + inital_condition = inputs[0] + return self.sample(inital_condition, **kwargs) + + +# -------------------------------------------------------------------------------- +# DYffusion with a pretrained interpolator +# -------------------------------------------------------------------------------- + + +class DYffusion(BaseDYffusion): + """ + DYffusion model with a pretrained interpolator + Args: + interpolator: the interpolator model + lambda_reconstruction: the weight of the reconstruction loss + lambda_reconstruction2: the weight of the reconstruction loss (using the predicted xt_last as feedback) + """ + + def __init__( + self, + interpolator: Optional[nn.Module] = None, + interpolator_run_id: Optional[str] = None, + interpolator_checkpoint_path: Optional[str] = None, + lambda_reconstruction: float = 1.0, + lambda_reconstruction2: float = 0.0, + reconstruction2_detach_x_last: bool = False, + interpolator_local_checkpoint_path: Optional[Union[str, bool]] = True, # if true, search in local path + interpolator_overrides: Optional[List[str]] = None, # a dot list, e.g. ["model.hidden_dims=128"] + interpolator_local_config_path: Optional[str] = None, # If set, load the config from this path + interpolator_wandb_ckpt_filename: Optional[str] = None, + interpolator_wandb_kwargs: Optional[Dict[str, Any]] = None, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.save_hyperparameters(ignore=["interpolator", "model"]) + self.name = self.name or "DYffusion (2stage)" + # Load interpolator and its weights + interpolator_wandb_kwargs = interpolator_wandb_kwargs or {} + interpolator_wandb_kwargs["epoch"] = interpolator_wandb_kwargs.get("epoch", "best") + if interpolator_wandb_ckpt_filename is not None: + assert interpolator_wandb_kwargs.get("ckpt_filename") is None, "ckpt_filename already set" + interpolator_wandb_kwargs["ckpt_filename"] = interpolator_wandb_ckpt_filename + interpolator_overrides = list(interpolator_overrides) if interpolator_overrides is not None else [] + interpolator_overrides.append("model.verbose=False") + self.interpolator: InterpolationExperiment = get_checkpoint_from_path_or_wandb( + interpolator, + model_checkpoint_path=interpolator_local_checkpoint_path, + config_path=interpolator_local_config_path, + wandb_run_id=interpolator_run_id, + wandb_kwargs=interpolator_wandb_kwargs, + model_overrides=interpolator_overrides, + ) + # freeze the interpolator (and set to eval mode) + freeze_model(self.interpolator) + + self.interpolator_window = self.interpolator.window + self.interpolator_horizon = self.interpolator.true_horizon + last_d_to_i_tstep = self.diffusion_step_to_interpolation_step(self.num_timesteps - 1) + if self.interpolator_horizon != last_d_to_i_tstep + 1: + # maybe: automatically set the num_timesteps to the interpolator_horizon + raise ValueError( + f"interpolator horizon {self.interpolator_horizon} must be equal to the " + f"last interpolation step+1=i_N=i_{self.num_timesteps - 1}={last_d_to_i_tstep + 1}" + ) + + def _interpolate( + self, + initial_condition: Tensor, + x_last: Tensor, + t: Tensor, + num_predictions: int = 1, + **kwargs, + ): + # interpolator networks uses time in [1, horizon-1] + assert (0 < t).all() and ( + t < self.interpolator_horizon + ).all(), f"interpolate time must be in (0, {self.interpolator_horizon}), got {t}" + # select condition data to be consistent with the interpolator training data + if self.hparams.hack_for_imprecise_interpolation: + x_last = torch.cat([initial_condition[:, :1], x_last], dim=1) + interpolator_inputs = torch.cat([initial_condition, x_last], dim=1) + interpolator_outputs = self.interpolator.predict_packed(interpolator_inputs, time=t, **kwargs) + interpolator_outputs = interpolator_outputs["preds"] + if self.hparams.hack_for_imprecise_interpolation: + interpolator_outputs = torch.cat([initial_condition[:, :1], interpolator_outputs], dim=1) + return interpolator_outputs + + def p_losses(self, input_dynamics: Tensor, xt_last: Tensor, **kwargs): + r""" + + Args: + input_dynamics: the initial condition data (time = 0) + xt_last: the start/target data (time = horizon) + **kwargs: may include additional args for models, e.g. static conditions + """ + _ = kwargs.pop("verbose", None) # remove verbose from kwargs (unused) + criterion = self.criterion["preds"] + batch_size = input_dynamics.shape[0] + t = torch.randint(0, self.num_timesteps, (batch_size,), device=self.device, dtype=torch.long) + + # x_t is what multi-horizon exp passes as targets, and xt_last is the last timestep of the data dynamics + # check that the time step is valid (between 0 and horizon-1) + # assert torch.all(t >= 0) and torch.all(t <= self.num_timesteps-1), f'invalid time step {t}' + lam1 = self.hparams.lambda_reconstruction + lam2 = self.hparams.lambda_reconstruction2 + + # Create the inputs for the forecasting model + # 1. For t=0, simply use the initial conditions + x_t = input_dynamics.clone() + + # 2. For t>0, we need to interpolate the data using the interpolator + t_nonzero = t > 0 + if t_nonzero.any(): + # sample one interpolation prediction + x_interpolated = self.q_sample( + x_end=input_dynamics, + x0=xt_last, + t=t, + batch_mask=t_nonzero, + num_predictions=1, + **kwargs, + ) + # Now, simply concatenate the inital_conditions for t=0 with the interpolated data for t>0 + x_t[t_nonzero] = x_interpolated.to(x_t.dtype) + # assert torch.all(x_t[t == 0] == condition[t == 0]) + + # Train the forward predictions (i.e. predict xt_last from xt_t) + xt_last_pred = self.predict_x_last(initial_condition=input_dynamics, x_t=x_t, t=t, **kwargs) + loss_forward = criterion(xt_last_pred, xt_last) + + # Train the forward predictions II by emulating one more step of the diffusion process + t2 = t + 1 # t2 is the next time step, between 1 and T + tnot_last = t2 <= self.num_timesteps - 1 # tnot_last is True for t < T + if lam2 > 0 and tnot_last.any(): + # train the predictions using x0 = xlast = forward_pred(condition, t=0) + # x_last_denoised2 = self.predict_x_last(condition=condition, x_t=condition, t=torch.zeros_like(t)) + # simulate the diffusion process for a single step, where the x_last=forward_pred(condition, t) prediction + # is used to get the interpolated x_t+1 = interpolate(condition, x_last, t+1) + x_interpolated2 = self.q_sample( + x_end=input_dynamics, + x0=xt_last_pred.detach() if self.hparams.reconstruction2_detach_x_last else xt_last_pred, + t=t2, + batch_mask=tnot_last, + num_predictions=1, + **kwargs, + ) + x_last_pred2 = self.predict_x_last( + initial_condition=input_dynamics[tnot_last], x_t=x_interpolated2, t=t2[tnot_last], **kwargs + ) + loss_forward2 = criterion(x_last_pred2, xt_last[tnot_last]) + else: + loss_forward2 = 0.0 + + loss = lam1 * loss_forward + lam2 * loss_forward2 + + log_prefix = "train" if self.training else "val" + loss_dict = { + "loss": loss, + f"{log_prefix}/loss_forward": loss_forward, + f"{log_prefix}/loss_forward2": loss_forward2, + } + return loss_dict diff --git a/src/evaluation/__init__.py b/src/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/evaluation/aggregators/__init__.py b/src/evaluation/aggregators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/evaluation/aggregators/_abstract_aggregator.py b/src/evaluation/aggregators/_abstract_aggregator.py new file mode 100644 index 0000000..4cc2294 --- /dev/null +++ b/src/evaluation/aggregators/_abstract_aggregator.py @@ -0,0 +1,68 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Tuple + +import torch +from tensordict import TensorDictBase + +from src.utilities.utils import ellipsis_torch_dict_boolean_tensor, get_logger, to_tensordict + + +class AbstractAggregator(ABC): + def __init__( + self, + is_ensemble: bool = False, + area_weights: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + name: str | None = None, + verbose: bool = True, + ): + self.log_text = get_logger(name=self.__class__.__name__) + + self.mask = mask + if mask is not None: + self.log_text.info(f"{name}: Using mask for evaluation of shape {mask.shape}") if verbose else None + if area_weights is not None: + area_weights = area_weights[mask] + + if area_weights is not None and verbose: + self.log_text.info(f"{name}: Using area weights for evaluation of shape {area_weights.shape}") + self._area_weights = area_weights + self._is_ensemble = is_ensemble + self.name = name + + @abstractmethod + def _record_batch(self, **kwargs) -> None: ... + + def record_batch(self, predictions_mask: Optional[torch.Tensor] = None, **kwargs) -> None: + assert predictions_mask is None, f"Deprecated predictions_mask {predictions_mask}" + if self.mask is not None: + # Apply mask to all tensors + for key, data in kwargs.items(): + # print(f"{key} Shape before ellipsis_torch_dict_boolean_tensor: {data.shape}") + if torch.is_tensor(data): + kwargs[key] = data[..., self.mask] + elif isinstance(data, TensorDictBase): + kwargs[key] = to_tensordict( + {k: ellipsis_torch_dict_boolean_tensor(v, self.mask) for k, v in data.items()}, + find_batch_size_max=True, + ) + else: + raise ValueError(f"Unsupported data type {type(data)}") + # print(f"{key} Shape after ellipsis_torch_dict_boolean_tensor: {kwargs[key].shape}") + + return self._record_batch(**kwargs) + + @torch.inference_mode() + def get_logs(self, prefix: str, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: + prefix = "" if prefix is None else prefix + if self.name is not None and self.name not in prefix: + prefix = f"{prefix}/{self.name}".replace("//", "/").rstrip("/").lstrip("/") + logs_values, logs_media = self._get_logs(**kwargs) + logs_values = {f"{prefix}/{key}": value for key, value in logs_values.items()} + logs_media = {f"{prefix}/{key}": value for key, value in logs_media.items()} + return logs_values, logs_media + + @abstractmethod + def _get_logs(self, epoch: Optional[int] = None) -> Tuple[Dict[str, Any], Dict[str, Any]]: ... diff --git a/src/evaluation/aggregators/main.py b/src/evaluation/aggregators/main.py new file mode 100644 index 0000000..901a31b --- /dev/null +++ b/src/evaluation/aggregators/main.py @@ -0,0 +1,153 @@ +from __future__ import annotations + +from abc import ABC +from typing import Any, Callable, Dict, Mapping, Protocol, Tuple + +import torch + +from src.evaluation.aggregators._abstract_aggregator import AbstractAggregator +from src.evaluation.aggregators.snapshot import SnapshotAggregator +from src.evaluation.aggregators.timestepwise import MeanAggregator +from src.utilities.utils import get_logger + + +log = get_logger(__name__) + + +class _Aggregator(Protocol): + def get_logs(self, label: str) -> Mapping[str, torch.Tensor]: ... + + def record_batch( + self, + loss: float, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + ) -> None: ... + + +class ListAggregator(AbstractAggregator, ABC): + def __init__( + self, + aggregators: list[AbstractAggregator], + **kwargs, + ): + super().__init__(**kwargs) + assert self.name is None, f"ListAggregator {self.name} should not have a name" + assert self._area_weights is None, f"ListAggregator {self.name} should not have area weights" + + self._aggregators = aggregators + for i, aggregator in enumerate(self._aggregators): + assert isinstance(aggregator, AbstractAggregator), f"Aggregator {i} is not an AbstractAggregator" + assert aggregator.name is not None, f"Aggregator {i}: {aggregator} has no name" + + def record_batch(self, **kwargs) -> None: + for aggregator in self._aggregators: + aggregator.record_batch(**kwargs) + + def _record_batch(self, **kwargs) -> None: + raise NotImplementedError("ListAggregator should not be called directly") + + def _get_logs(self, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]: + logs_values = {} + logs_media = {} + for aggregator in self._aggregators: + logs_values_i, logs_media_i = aggregator.get_logs(prefix=None, **kwargs) + logs_values.update(logs_values_i) + logs_media.update(logs_media_i) + return logs_values, logs_media + + +class OneStepAggregator(AbstractAggregator): + """ + Aggregates statistics for the timestep pairs. + + To use, call `record_batch` on the results of each batch, then call + `get_logs` to get a dictionary of statistics when you're done. + """ + + def __init__( + self, + use_snapshot_aggregator: bool = True, + record_normed: bool = False, + record_rmse: bool = True, + record_abs_values: bool = False, # logs absolutes mean and std of preds and targets + snapshot_var_names: list[str] = None, + every_nth_epoch_snapshot: int = 8, + snapshots_preprocess_fn: Callable = None, + **kwargs, + ): + super().__init__(**kwargs) + if use_snapshot_aggregator: + self._snapshot = SnapshotAggregator( + is_ensemble=self._is_ensemble, + var_names=snapshot_var_names, + every_nth_epoch=every_nth_epoch_snapshot, + preprocess_fn=snapshots_preprocess_fn, + ) + else: + self._snapshot = None + + self._mean = MeanAggregator( + area_weights=self._area_weights, + is_ensemble=self._is_ensemble, + record_normed=record_normed, + record_rmse=record_rmse, + record_abs_values=record_abs_values, + ) + self._aggregators: Mapping[str, _Aggregator] = { + "snapshot": self._snapshot, + "mean": self._mean, + } + + @torch.inference_mode() + def _record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + ): + if len(target_data) == 0: + raise ValueError("No data in target_data") + if len(gen_data) == 0: + raise ValueError("No data in gen_data") + + for k, aggregator in self._aggregators.items(): + if aggregator is None: + continue + aggregator.record_batch( + target_data=target_data, + gen_data=gen_data, + target_data_norm=target_data_norm, + gen_data_norm=gen_data_norm, + ) + + @torch.inference_mode() + def _get_logs(self, **kwargs) -> Tuple[Dict[str, float], Dict[str, float]]: + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + """ + try: + logs = self._mean.get_logs(label="", **kwargs) + except ValueError as e: + raise ValueError( + f"Aggregator ``{self.name}`` has problems with mean sub-aggregator.\n" + f"Did you forget to record any batches?" + ) from e + + if self._snapshot is not None: + logs_media = self._snapshot.get_logs(**kwargs) + # logs_media = {f"snapshot/{key}": val for key, val in logs_media.items()} + else: + logs_media = {} + for agg_label, agg in self._aggregators.items(): + if agg is None or agg_label in ["mean", "snapshot"]: + continue + logs.update(agg.get_logs(label=agg_label, **kwargs)) + # logs.update({f"{label}/{key}": float(val) for key, val in agg.get_logs(label=agg_label).items()}) + return logs, logs_media diff --git a/src/evaluation/aggregators/snapshot.py b/src/evaluation/aggregators/snapshot.py new file mode 100644 index 0000000..c3382b8 --- /dev/null +++ b/src/evaluation/aggregators/snapshot.py @@ -0,0 +1,208 @@ +from typing import List, Mapping, Optional + +import matplotlib.pyplot as plt +import torch + + +try: + import seaborn as sns + + # Apply Seaborn styles for enhanced aesthetics + sns.set( + context="talk", style="white", palette="colorblind", font="serif", font_scale=1, rc={"lines.linewidth": 2.5} + ) +except ImportError: + pass + + +class SnapshotAggregator: + """ + An aggregator that records the first sample of the last batch of data. + > The way it works is that it gets called once per batch, but in the end (when using get_logs) + it only returns information based on the last batch. + """ + + _captions = { + "full-field": "{name} one step full field for last samples; (left) generated and (right) target.", # noqa: E501 + "residual": "{name} one step residual for last samples; (left) generated and (right) target.", # noqa: E501 + "error": "{name} one step error (generated - target) for last sample.", + } + + def __init__( + self, + is_ensemble: bool, + target_time: Optional[int] = None, + var_names: Optional[List[str]] = None, + every_nth_epoch: int = 1, + preprocess_fn=None, + ): + self.is_ensemble = is_ensemble + assert target_time is None or target_time > 0 + self.target_time = target_time # account for 0-indexing not needed because initial condition is included + self.target_time_in_batch = None + self.var_names = var_names + self.every_nth_epoch = every_nth_epoch + self.preprocess_fn = preprocess_fn if preprocess_fn is not None else lambda x: x + + @torch.inference_mode() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor], + gen_data_norm: Mapping[str, torch.Tensor], + loss=None, + i_time_start: int = 0, + ): + data_steps = target_data_norm[list(target_data_norm.keys())[0]].shape[1] + if self.target_time is not None: + diff = self.target_time - i_time_start + # target time needs to be in the batch (between i_time_start and i_time_start + data_steps) + if diff < 0 or diff >= data_steps: + return # skip this batch, since it doesn't contain the target time + else: + self.target_time_in_batch = diff + + def to_cpu(x): + return {k: v.cpu() for k, v in x.items()} if isinstance(x, dict) else x.cpu() + + self._target_data = to_cpu(target_data) + self._gen_data = to_cpu(gen_data) + self._target_data_norm = to_cpu(target_data_norm) + self._gen_data_norm = to_cpu(gen_data_norm) + if self.target_time is not None: + assert ( + self.target_time_in_batch <= data_steps + ), f"target_time={self.target_time}, time_in_batch={self.target_time_in_batch} is larger than the number of timesteps in the data={data_steps}!" + + @torch.inference_mode() + def get_logs(self, label: str = "", epoch: int = None): + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + epoch: Current epoch number. + """ + if self.every_nth_epoch > 1 and epoch >= 3 and epoch % self.every_nth_epoch != 0: + return {} + if self.target_time_in_batch is None and self.target_time is not None: + return {} # skip this batch, since it doesn't contain the target time + image_logs = {} + max_snapshots = 2 # 3 + names = self.var_names if self.var_names is not None else self._gen_data_norm.keys() + for name in names: + name_label = name + if "normed" in name: + gen_data = self._gen_data_norm + target_data = self._target_data_norm + name = name.replace("_normed", "") + else: + gen_data = self._gen_data + target_data = self._target_data + + if self.is_ensemble: + snapshots_pred = gen_data[name][:max_snapshots, 0] + else: + snapshots_pred = gen_data[name][0].unsqueeze(0) + target_for_image = target_data[name][0] # first sample in batch + input_for_image = None + # Select target time + if self.target_time is not None: + snapshots_pred = snapshots_pred[:, self.target_time_in_batch] + target_for_image = target_for_image[self.target_time_in_batch] + if input_for_image is not None: + input_for_image = input_for_image[self.target_time_in_batch] + + n_ens_members = snapshots_pred.shape[0] + figsize1 = ((n_ens_members + 1) * 5, 5) + figsize2 = (n_ens_members * 5, 5) + fig_full_field, ax_full_field = plt.subplots( + 1, n_ens_members + 1, figsize=figsize1, sharex=True, sharey=True + ) + fig_error, ax_error = plt.subplots(1, n_ens_members, figsize=figsize2, sharex=True, sharey=True) + ax_error = [ax_error] if n_ens_members == 1 else ax_error + # Compute vmin and vmax + vmin = min(snapshots_pred.min(), target_for_image.min()) + vmax = max(snapshots_pred.max(), target_for_image.max()) + # Plot full field and compute errors. Plot with colorbar using same vmin and vmax (different for error vs full field) + errors = [snapshots_pred[i] - target_for_image for i in range(n_ens_members)] + vmin_error = min([error.min() for error in errors]) + vmax_error = max([error.max() for error in errors]) + if abs(vmin_error) > abs(vmax_error): + vmax_error = -vmin_error + else: + vmin_error = -vmax_error # make sure 0 is in the middle of the colorbar + # Preprocess (e.g. flip) the images so that they are plotted correctly + snapshots_pred = self.preprocess_fn(snapshots_pred.cpu().numpy()) + target_for_image = self.preprocess_fn(target_for_image.cpu().numpy()) + errors = [self.preprocess_fn(error.cpu().numpy()) for error in errors] + + for i in range(n_ens_members): + # Plot full field with colorbar + pcm_ff = ax_full_field[i].imshow(snapshots_pred[i], vmin=vmin, vmax=vmax) + ax_ff_title = f"Generated {i}" if n_ens_members > 1 else "Generated" + ax_full_field[i].set_title(ax_ff_title) + # Plot error with red blue colorbar + pcm_err = ax_error[i].imshow(errors[i], vmin=vmin_error, vmax=vmax_error, cmap="seismic") + ax_error_title = rf"$\hat{{y}}_{i} - y$" if n_ens_members > 1 else r"$\hat{y} - y$" + ax_error[i].set_title(ax_error_title) + + pcm_ff = ax_full_field[-1].imshow(target_for_image, vmin=vmin, vmax=vmax) + ax_full_field[-1].set_title("Target") + # Create colorbar's beneath the images horizontally. To make it less thick: shrink=0.8, pad=0.03, location="bottom", aspect=40, fraction=0.05 + cbar_kwargs = dict(location="bottom", shrink=0.8, pad=0.03, fraction=0.08) + # Defaults are: shrink=1, pad=0.05, fraction=0.15 + fig_full_field.colorbar(pcm_ff, ax=ax_full_field, **cbar_kwargs) + fig_error.colorbar(pcm_err, ax=ax_error, **cbar_kwargs) + # Set titles + # fig_full_field.suptitle(f"{name_label} full field; (left) generated and (right) target.", y=title_y) + # fig_error.suptitle(f"{name_label} error (generated - target).", y=title_y) + # fig_error.suptitle("generated - target", y=title_y) + # Disable ticks + for ax in ax_full_field: + ax.axis("off") + for ax in ax_error: + ax.axis("off") + + # fig_full_field.tight_layout() + # fig_error.tight_layout() + image_logs[f"image-full-field/{name_label}"] = fig_full_field + image_logs[f"image-error/{name_label}"] = fig_error + + # small_gap = torch.zeros((target_for_image.shape[-2], 2)).to(snapshots_pred.device, dtype=torch.float) + # gap = torch.zeros((target_for_image.shape[-2], 4)).to( + # snapshots_pred.device, dtype=torch.float + # ) # gap between images in wandb (so we can see them separately) + # # Create image tensors + # image_error, image_full_field, image_residual = [], [], [] + # for i in range(snapshots_pred.shape[0]): + # image_full_field += [snapshots_pred[i]] + # image_error += [snapshots_pred[i] - target_for_image] + # if input_for_image is not None: + # image_residual += [snapshots_pred[i] - input_for_image] + # if i == snapshots_pred.shape[0] - 1: + # image_full_field += [gap, target_for_image] + # if input_for_image is not None: + # image_residual += [gap, target_for_image - input_for_image] + # else: + # image_full_field += [small_gap] + # image_residual += [small_gap] + # image_error += [small_gap] + + # images = {} + # images["error"] = torch.cat(image_error, dim=1) + # images["full-field"] = torch.cat(image_full_field, dim=1) + # if input_for_image is not None: + # images["residual"] = torch.cat(image_residual, dim=1) + + # for key, data in images.items(): + # caption = self._captions[key].format(name=name) + # caption += f" vmin={data.min():.4g}, vmax={data.max():.4g}." + # data = np.flip(data.cpu().numpy(), axis=-2) + # wandb_image = wandb.Image(data, caption=caption) + # image_logs[f"image-{key}/{name}"] = wandb_image + + label = label + "/" if label else "" + image_logs = {f"{label}{key}": image_logs[key] for key in image_logs} + return image_logs diff --git a/src/evaluation/aggregators/time_mean.py b/src/evaluation/aggregators/time_mean.py new file mode 100644 index 0000000..dad216d --- /dev/null +++ b/src/evaluation/aggregators/time_mean.py @@ -0,0 +1,116 @@ +from typing import Dict, Mapping, Optional, Tuple + +import numpy as np +import torch +import xarray as xr + +from src.evaluation import metrics +from src.evaluation.aggregators._abstract_aggregator import AbstractAggregator +from src.utilities.utils import add + + +def get_gen_shape(gen_data: Mapping[str, torch.Tensor]): + for name in gen_data: + return gen_data[name].shape + + +class TimeMeanAggregator(AbstractAggregator): + """Statistics on the time-mean state. + + This aggregator keeps track of the time-mean state, then computes + statistics on that time-mean state when logs are retrieved. + """ + + _image_captions = { + "bias_map": "{name} time-mean bias (generated - target)", + "gen_map": "{name} time-mean generated", + } + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self._target_data: Optional[Dict[str, torch.Tensor]] = None + self._gen_data: Optional[Dict[str, torch.Tensor]] = None + self._target_data_norm = None + self._gen_data_norm = None + self._n_batches = 0 + + @torch.inference_mode() + def _record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + ): + def add_or_initialize_time_mean( + maybe_dict: Optional[Dict[str, torch.Tensor]], + new_data: Mapping[str, torch.Tensor], + ) -> Mapping[str, torch.Tensor]: + if maybe_dict is None: + d: Dict[str, torch.Tensor] = {name: tensor for name, tensor in new_data.items()} + else: + d = add(maybe_dict, new_data) + return d + + self._target_data = add_or_initialize_time_mean(self._target_data, target_data) + self._gen_data = add_or_initialize_time_mean(self._gen_data, gen_data) + self._n_batches += 1 + + @torch.inference_mode() + def _get_logs(self, **kwargs) -> Tuple[Dict[str, float], Dict[str, float]]: + """ + Returns logs as can be reported to WandB. + """ + if self._n_batches == 0: + raise ValueError("No data recorded.") + area_weights = self._area_weights + logs = {} + # dist = Distributed.get_instance() + for name in self._gen_data.keys(): + gen = self._gen_data[name] / self._n_batches + target = self._target_data[name] / self._n_batches + # gen = dist.reduce_mean(self._gen_data[name] / self._n_batches) + # target = dist.reduce_mean(self._target_data[name] / self._n_batches) + if self._is_ensemble: + gen_ens_mean = gen.mean(dim=0) + logs[f"rmse_member_avg/{name}"] = np.mean( + [ + metrics.root_mean_squared_error(predicted=gen[i], truth=target, weights=area_weights) + .cpu() + .numpy() + for i in range(gen.shape[0]) + ] + ) + logs[f"bias_member_avg/{name}"] = np.mean( + [ + metrics.time_and_global_mean_bias(predicted=gen[i], truth=target, weights=area_weights) + .cpu() + .numpy() + for i in range(gen.shape[0]) + ] + ) + else: + gen_ens_mean = gen + + logs[f"rmse/{name}"] = float( + metrics.root_mean_squared_error(predicted=gen_ens_mean, truth=target, weights=area_weights) + .cpu() + .numpy() + ) + + logs[f"bias/{name}"] = float( + metrics.time_and_global_mean_bias(predicted=gen_ens_mean, truth=target, weights=area_weights) + .cpu() + .numpy() + ) + logs[f"crps/{name}"] = float( + metrics.crps_ensemble(predicted=gen, truth=target, weights=area_weights).cpu().numpy() + ) + return logs, {} + + @torch.inference_mode() + def get_dataset(self, **kwargs) -> xr.Dataset: + logs = self.get_logs(**kwargs) + logs = {key.replace("/", "-"): logs[key] for key in logs} + data_vars = {} + for key, value in logs.items(): + data_vars[key] = xr.DataArray(value) + return xr.Dataset(data_vars=data_vars) diff --git a/src/evaluation/aggregators/timestepwise.py b/src/evaluation/aggregators/timestepwise.py new file mode 100644 index 0000000..583e4c8 --- /dev/null +++ b/src/evaluation/aggregators/timestepwise.py @@ -0,0 +1,214 @@ +from collections import defaultdict +from typing import Dict, Mapping, Optional + +import torch +import xarray as xr +from torch import nn + +from src.evaluation import metrics +from src.evaluation.reduced_metrics import AreaWeightedReducedMetric, ReducedMetric + + +class AbstractMeanMetric: + def __init__(self, device: torch.device): + self._total = torch.tensor(0.0, device=device) + + def get(self) -> torch.Tensor: + return self._total + + +class L1Loss(AbstractMeanMetric): + # Note: NOT area weighted + def record(self, targets: torch.Tensor, preds: torch.Tensor): + self._total += nn.functional.l1_loss(preds, targets) + + +class MeanAggregator: + """ + Aggregator for mean-reduced metrics. + + These are metrics such as means which reduce to a single float for each batch, + and then can be averaged across batches to get a single float for the + entire dataset. This is important because the aggregator uses the mean to combine + metrics across batches and processors. + """ + + def __init__( + self, + area_weights: torch.Tensor, + is_ensemble: bool, + record_normed: bool = False, + record_rmse: bool = True, + record_abs_values: bool = False, + ): + self._area_weights = area_weights + self._n_batches = 0 + self._variable_metrics: Optional[Dict[str, Dict[str, ReducedMetric]]] = None + self.is_ensemble = is_ensemble + self.record_normed = record_normed + self.record_rmse = record_rmse + self.record_abs_values = record_abs_values + if area_weights is None: + self._area_weights_dims = (-3, -2, -1) # None # ( -3, -2, -1)) + elif len(area_weights.shape) == 2: + self._area_weights_dims = (-2, -1) + elif len(area_weights.shape) == 1: + self._area_weights_dims = (-1,) + else: + raise ValueError(f"Area weights must be 1D or 2D tensor, got {area_weights.shape}") + + def _get_variable_metrics(self, gen_data: Mapping[str, torch.Tensor]): + if self._variable_metrics is None: + self._variable_metrics = defaultdict(dict) + if torch.is_tensor(gen_data): + self.device = gen_data.device + gen_data_keys = [""] + else: + self.device = gen_data[list(gen_data.keys())[0]].device # any key will do + gen_data_keys = list(gen_data.keys()) + if self._area_weights is not None: + area_weights = self._area_weights.to(self.device) + else: + area_weights = None + + metric_names = ["l1", "rmse", "bias", "grad_mag_percent_diff"] + if self.is_ensemble: + metric_names += ["ssr", "crps"] + if self.record_normed: + metric_names += [f"{metric}_normed" for metric in metric_names if metric != "l1"] + for i, var_name in enumerate(gen_data_keys): + try: + self._variable_metrics["l1"][var_name] = L1Loss(device=self.device) + except KeyError as e: + if i > 0: + raise e + self._variable_metrics = dict() + for metric in metric_names: + self._variable_metrics[metric] = dict() + self._variable_metrics["l1"][var_name] = L1Loss(device=self.device) + + if self.record_rmse: + mse_metric = ("rmse", metrics.root_mean_squared_error) + else: + mse_metric = ("mse", metrics.mean_squared_error) + metrics_zipped = [ + mse_metric, + ("bias", metrics.weighted_mean_bias), + ("grad_mag_percent_diff", metrics.gradient_magnitude_percent_diff), + ] + if self.record_abs_values: + metrics_zipped += [ + ("mean_gen", metrics.compute_metric_on(source="gen", metric=metrics.weighted_mean)), + ("mean_target", metrics.compute_metric_on(source="target", metric=metrics.weighted_mean)), + ("std_gen", metrics.compute_metric_on(source="gen", metric=metrics.weighted_std)), + ("std_target", metrics.compute_metric_on(source="target", metric=metrics.weighted_std)), + ] + if self.is_ensemble: + metrics_zipped += [("crps", metrics.crps_ensemble)] + metrics_zipped += [("ssr", metrics.spread_skill_ratio)] + + for i, (metric_name, metric) in enumerate(metrics_zipped): + self._variable_metrics[metric_name][var_name] = AreaWeightedReducedMetric( + area_weights=area_weights, + device=self.device, + compute_metric=metric, + dim=self._area_weights_dims, + ) + + if self.record_normed: + for var_name in gen_data_keys: + for i, (metric_name, metric) in enumerate(metrics_zipped): + self._variable_metrics[f"{metric_name}_normed"][var_name] = AreaWeightedReducedMetric( + area_weights=area_weights, + device=self.device, + compute_metric=metric, + dim=self._area_weights_dims, + ) + + return self._variable_metrics + + @torch.inference_mode() + def record_batch( + self, + target_data: Mapping[str, torch.Tensor], + gen_data: Mapping[str, torch.Tensor], + target_data_norm: Mapping[str, torch.Tensor] = None, + gen_data_norm: Mapping[str, torch.Tensor] = None, + ): + variable_metrics = self._get_variable_metrics(gen_data) + is_tensor = torch.is_tensor(gen_data) + if is_tensor: # add dummy key + gen_data = {"": gen_data} + target_data = {"": target_data} + gen_data_norm = {"": gen_data_norm} + target_data_norm = {"": target_data_norm} + + record_normed_list = [True, False] if self.record_normed else [False] + for is_normed in record_normed_list: + if is_normed: + preds_data = gen_data_norm + truth_data = target_data_norm + var_metrics_here = {metric: v for metric, v in variable_metrics.items() if "normed" in metric} + else: + preds_data = gen_data + truth_data = target_data + var_metrics_here = {metric: v for metric, v in variable_metrics.items() if "normed" not in metric} + + for metric in var_metrics_here.keys(): # e.g. l1, weighted_rmse, etc + if "grad_mag" in metric: + kwargs = {"is_ensemble_prediction": self.is_ensemble} + else: + kwargs = {} + + for var_name, var_preds in preds_data.items(): # e.g. temperature, precipitation, etc + if "ssr" in metric or "crps" in metric or "grad_mag" in metric: + preds = var_preds + else: + preds = var_preds.mean(dim=0) if self.is_ensemble else var_preds + + # time_s = time.time() + try: + variable_metrics[metric][var_name].record(targets=truth_data[var_name], preds=preds, **kwargs) + except AssertionError as e: + raise AssertionError(f"Error with {metric=}. {var_name=}, {self.is_ensemble=}") from e + # time.time() - time_s + # print(f"Time taken for {metric} {name} in s: {time_taken:.5f}") + + self._n_batches += 1 + + @torch.inference_mode() + def get_logs(self, label: str = "", epoch: Optional[int] = None) -> Dict[str, float]: + """ + Returns logs as can be reported to WandB. + + Args: + label: Label to prepend to all log keys. + epoch: Current epoch number. + """ + if self._variable_metrics is None or self._n_batches == 0: + raise ValueError(f"No batches have been recorded. n_batches={self._n_batches}") + logs = {} + label = label + "/" if label else "" + for i, metric in enumerate(self._variable_metrics): + for variable, metric_value in self._variable_metrics[metric].items(): + metric_value = metric_value.get() + if metric_value is None: + raise ValueError( + f"{metric=} hasn't been computed for {variable=}. ({label=}, {self._n_batches=}, {i=})" + ) + log_key = f"{label}{metric}/{variable}".rstrip("/") + logs[log_key] = float((metric_value / self._n_batches).detach().item()) + + # for key in sorted(logs.keys()): + # logs[key] = float(logs[key].cpu()) # .numpy() + + return logs + + @torch.inference_mode() + def get_dataset(self, label: str) -> xr.Dataset: + logs = self.get_logs(label=label) + logs = {key.replace("/", "-"): logs[key] for key in logs} + data_vars = {} + for key, value in logs.items(): + data_vars[key] = xr.DataArray(value) + return xr.Dataset(data_vars=data_vars) diff --git a/src/evaluation/metrics.py b/src/evaluation/metrics.py new file mode 100644 index 0000000..e25a44f --- /dev/null +++ b/src/evaluation/metrics.py @@ -0,0 +1,456 @@ +from typing import Iterable, Literal, Optional, Protocol, Union + +import numpy as np +import torch +from torch import Tensor +from typing_extensions import TypeAlias + + +Dimension: TypeAlias = Union[int, Iterable[int]] +Array: TypeAlias = Union[np.ndarray, torch.Tensor] + +GRAVITY = 9.80665 # m/s^2 + + +def spherical_area_weights(lats: Array, num_lon: int, device=None) -> torch.Tensor: + """Computes area weights given the latitudes of a regular lat-lon grid. + + Args: + lats: tensor of shape (num_lat,) with the latitudes of the cell centers. + num_lon: Number of longitude points. + device: Device to place the tensor on. + + Returns a torch.tensor of shape (num_lat, num_lon). + """ + if isinstance(lats, np.ndarray): + lats = torch.from_numpy(lats) + weights = torch.cos(torch.deg2rad(lats)).repeat(num_lon, 1).t() + weights /= weights.sum() + return weights + + +def weighted_mean( + tensor: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + keepdim: bool = False, +) -> torch.Tensor: + """Computes the weighted mean across the specified list of dimensions. + + Args: + tensor: torch.Tensor + weights: Weights to apply to the mean. + dim: Dimensions to compute the mean over. + keepdim: Whether the output tensor has `dim` retained or not. + + Returns: + a tensor of the weighted mean averaged over the specified dimensions `dim`. + """ + if weights is None: + return tensor.mean(dim=dim, keepdim=keepdim) + try: + return (tensor * weights).sum(dim=dim, keepdim=keepdim) / weights.expand(tensor.shape).sum( + dim=dim, keepdim=keepdim + ) + except RuntimeError as e: + raise RuntimeError( + f"Error computing weighted mean. tensor.shape={tensor.shape}, weights.shape={weights.shape}, dim={dim}" + ) from e + + +def weighted_std( + tensor: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), +) -> torch.Tensor: + """Computes the weighted standard deviation across the specified list of dimensions. + + Computed by first computing the weighted variance, then taking the square root. + + weighted_variance = weighted_mean((tensor - weighted_mean(tensor)) ** 2)) ** 0.5 + + Args: + tensor: torch.Tensor + weights: Weights to apply to the variance. + dim: Dimensions to compute the standard deviation over. + + Returns: + a tensor of the weighted standard deviation over the + specified dimensions `dim`. + """ + if weights is None: + weights = torch.tensor(1.0, device=tensor.device) + + mean = weighted_mean(tensor, weights=weights, dim=dim, keepdim=True) + variance = weighted_mean((tensor - mean) ** 2, weights=weights, dim=dim) + return torch.sqrt(variance) + + +def weighted_mean_bias( + truth: Tensor, + predicted: Tensor, + weights: Optional[Tensor] = None, + dim: Dimension = (), +) -> Tensor: + """Computes the mean bias across the specified list of dimensions assuming + that the weights are applied to the last dimensions, e.g. the spatial dimensions. + + Args: + truth: Tensor + predicted: Tensor + dim: Dimensions to compute the mean over. + weights: Weights to apply to the mean. + + Returns a tensor of the mean biases averaged over the specified dimensions `dim`. + """ + assert truth.shape == predicted.shape, "Truth and predicted should have the same shape." + bias = predicted - truth + return weighted_mean(bias, weights=weights, dim=dim) + + +def mean_squared_error( + truth: Tensor, + predicted: Tensor, + weights: Optional[Tensor] = None, + dim: Dimension = (), +) -> Tensor: + """ + Computes the weighted global MSE over all variables. Namely, for each variable: + + sqrt((weights * ((xhat - x) ** 2)).mean(dims)) + + If you want to compute the MSE over the time dimension, then pass in + `truth.mean(time_dim)` and `predicted.mean(time_dim)` and specify `dims=space_dims`. + + Args: + truth: Tensor whose last dimensions are to be weighted + predicted: Tensor whose last dimensions are to be weighted + weights: Tensor to apply to the squared bias. + dim: Dimensions to average over. + + Returns a tensor of shape (variable,) of weighted RMSEs. + """ + assert ( + truth.shape == predicted.shape + ), f"Truth and predicted should have the same shape. But got {truth.shape} and {predicted.shape}." + sq_bias = torch.square(predicted - truth) + return weighted_mean(sq_bias, weights=weights, dim=dim) + + +def root_mean_squared_error( + truth: Tensor, + predicted: Tensor, + weights: Optional[Tensor] = None, + dim: Dimension = (), +) -> Tensor: + """ + Computes the weighted global RMSE over all variables. Namely, for each variable: + + sqrt((weights * ((xhat - x) ** 2)).mean(dims)) + + If you want to compute the RMSE over the time dimension, then pass in + `truth.mean(time_dim)` and `predicted.mean(time_dim)` and specify `dims=space_dims`. + + Args: + truth: Tensor whose last dimensions are to be weighted + predicted: Tensor whose last dimensions are to be weighted + weights: Tensor to apply to the squared bias. + dim: Dimensions to average over. + + Returns a tensor of shape (variable,) of weighted RMSEs. + """ + mse = mean_squared_error(truth, predicted, weights=weights, dim=dim) + return torch.sqrt(mse) + + +def ensemble_spread(predicted: Tensor, weights: Optional[Tensor] = None, dim: Dimension = ()) -> Tensor: + """Compute the spread of the ensemble members. + This is calculated as the square root of the average ensemble variance, + which is different from the standard deviation of the ensemble. + See Fortuin et al. 2013 for more details why the square root of the average ensemble variance is adequate. + Args: + predicted (torch.Tensor): The predictions of the ensemble, of shape (n_member, n_samples, *) + """ + mean_ensemble_variance = weighted_mean(predicted.var(dim=0), weights=weights, dim=dim) + return torch.sqrt(mean_ensemble_variance) + + +def spread_skill_ratio( + truth: Tensor, predicted: Tensor, weights: Optional[Tensor] = None, dim: Dimension = () +) -> Tensor: + """Compute the spread-skill ratio (SSR) of an ensemble of predictions. + The SSR is defined as the ratio of the ensemble spread to the ensemble-mean RMSE. + Args: + predicted (torch.Tensor): The predictions of the ensemble, of shape (n_member, n_samples, *) + truth (torch.Tensor): The targets, of shape (n_samples, *) + weights (torch.Tensor, optional): The weights to apply to the spread. Defaults to None. + dim (Dimension, optional): The dimensions over which to compute the spread. Defaults to (). + """ + assert len(truth.shape) == len(predicted.shape) - 1, f"{truth.shape=} and {predicted.shape=}" + n_mems = predicted.shape[0] + spread = ensemble_spread(predicted, weights=weights, dim=dim) + # calculate skill as ensemble_mean RMSE + rmse = root_mean_squared_error(truth, predicted.mean(dim=0), weights=weights, dim=dim) + # Add correction factor sqrt((M+1)/M); see https://doi.org/10.1175/JHM-D-14-0008.1), important for small ensemble sizes + spread *= ((n_mems + 1) / n_mems) ** 0.5 + return spread / rmse + + +def crps_ensemble( + truth: Tensor, # TRUTH + predicted: Tensor, # FORECAST + weights: Tensor = None, + dim: Union[int, Iterable[int]] = (), + reduction="mean", +) -> Tensor: + """ + .. Author: Salva Rühling Cachay + + pytorch adaptation of https://github.com/TheClimateCorporation/properscoring/blob/master/properscoring/_crps.py#L187 + but implementing the fair, unbiased CRPS as in Zamo & Naveau (2018; https://doi.org/10.1007/s11004-017-9709-7) + + This implementation is based on the identity: + .. math:: + CRPS(F, x) = E_F|X - x| - 1/2 * E_F|X - X'| + where X and X' denote independent random variables drawn from the forecast + distribution F, and E_F denotes the expectation value under F. + + We use the fair, unbiased formulation of the ensemble CRPS, which is particularly important for small ensembles. + Anecdotically, the unbiased CRPS leads to slightly smaller (i.e. "better") values than the biased version. + Basically, we use n_members * (n_members - 1) instead of n_members**2 to average over the ensemble spread. + See Zamo & Naveau (2018; https://doi.org/10.1007/s11004-017-9709-7) for details. + + Alternative implementation: https://github.com/NVIDIA/modulus/pull/577/files + """ + assert truth.ndim == predicted.ndim - 1, f"{truth.shape=}, {predicted.shape=}" + assert truth.shape == predicted.shape[1:] # ensemble ~ first axis + n_members = predicted.shape[0] + skill = (predicted - truth).abs().mean(dim=0) + # insert new axes so forecasts_diff expands with the array broadcasting + # torch.unsqueeze(predictions, 0) has shape (1, E, ...) + # torch.unsqueeze(predictions, 1) has shape (E, 1, ...) + forecasts_diff = torch.unsqueeze(predicted, 0) - torch.unsqueeze(predicted, 1) + # Forecasts_diff has shape (E, E, ...) + # Old version: score += - 0.5 * forecasts_diff.abs().mean(dim=(0, 1)) + # Using n_members * (n_members - 1) instead of n_members**2 is the fair, unbiased CRPS. Better for small ensembles. + spread = forecasts_diff.abs().sum(dim=(0, 1)) / (n_members * (n_members - 1)) + crps = skill - 0.5 * spread + # score has shape (...) (same as observations) + if reduction == "none": + return crps + assert reduction == "mean", f"Unknown reduction {reduction}" + if weights is not None: # weighted mean + crps = (crps * weights).sum(dim=dim) / weights.expand(crps.shape).sum(dim=dim) + else: + crps = crps.mean(dim=dim) + return crps + + +def gradient_magnitude(tensor: Tensor, dim: Dimension = ()) -> Tensor: + """Compute the magnitude of gradient across the specified dimensions.""" + no_singleton_dims = tuple(d for d in dim if tensor.shape[d] > 1) + gradients = torch.gradient( + tensor.squeeze(), dim=no_singleton_dims + ) # squeeze to remove singleton dimensions, which cause errors (edge_order) + grad_magnitude = torch.sqrt(sum([g**2 for g in gradients])) + grad_magnitude = grad_magnitude.reshape(tensor.shape) # restore original shape + return grad_magnitude + + +def weighted_mean_gradient_magnitude(tensor: Tensor, weights: Optional[Tensor] = None, dim: Dimension = ()) -> Tensor: + """Compute weighted mean of gradient magnitude across the specified dimensions.""" + return weighted_mean(gradient_magnitude(tensor, dim), weights=weights, dim=dim) + + +def gradient_magnitude_percent_diff( + truth: Tensor, + predicted: Tensor, + weights: Optional[Tensor] = None, + dim: Dimension = (), + is_ensemble_prediction: bool = False, +) -> Tensor: + """Compute the percent difference of the weighted mean gradient magnitude across + the specified dimensions.""" + truth_grad_mag = weighted_mean_gradient_magnitude(truth, weights, dim) + if is_ensemble_prediction: + predicted_grad_mag = 0 + for ens_i, pred in enumerate(predicted): + predicted_grad_mag += weighted_mean_gradient_magnitude(pred, weights, dim) + predicted_grad_mag /= predicted.shape[0] + else: + assert truth.shape == predicted.shape, "Truth and predicted should have the same shape." + predicted_grad_mag = weighted_mean_gradient_magnitude(predicted, weights, dim) + return 100 * (predicted_grad_mag - truth_grad_mag) / truth_grad_mag + + +def rmse_of_time_mean( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + time_dim: Dimension = 0, + spatial_dims: Dimension = (-2, -1), +) -> torch.Tensor: + """Compute the RMSE of the time-average given truth and predicted. + + Args: + truth: truth tensor + predicted: predicted tensor + weights: weights to use for computing spatial RMSE + time_dim: time dimension + spatial_dims: spatial dimensions over which RMSE is calculated + + Returns: + The RMSE between the time-mean of the two input tensors. The time and + spatial dims are reduced. + """ + truth_time_mean = truth.mean(dim=time_dim) + predicted_time_mean = predicted.mean(dim=time_dim) + ret = root_mean_squared_error(truth_time_mean, predicted_time_mean, weights=weights, dim=spatial_dims) + return ret + + +def time_and_global_mean_bias( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + time_dim: Dimension = 0, + spatial_dims: Dimension = (-2, -1), +) -> torch.Tensor: + """Compute the global- and time-mean bias given truth and predicted. + + Args: + truth: truth tensor + predicted: predicted tensor + weights: weights to use for computing the global mean + time_dim: time dimension + spatial_dims: spatial dimensions over which global mean is calculated + + Returns: + The global- and time-mean bias between the predicted and truth tensors. The + time and spatial dims are reduced. + """ + truth_time_mean = truth.mean(dim=time_dim) + predicted_time_mean = predicted.mean(dim=time_dim) + result = weighted_mean(predicted_time_mean - truth_time_mean, weights=weights, dim=spatial_dims) + return result + + +class AreaWeightedFunction(Protocol): + """ + A function that computes a metric on the true and predicted values, + weighted by area. + """ + + def __call__( + self, + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +class AreaWeightedSingleTargetFunction(Protocol): + """ + A function that computes a metric on a single value, weighted by area. + """ + + def __call__( + self, + tensor: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +def compute_metric_on( + source: Literal["gen", "target"], metric: AreaWeightedSingleTargetFunction +) -> AreaWeightedFunction: + """Turns a single-target metric function + (computed on only the generated or target data) into a function that takes in + both the generated and target data as arguments, as required for the APIs + which call generic metric functions. + """ + + def metric_wrapper( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: + if source == "gen": + return metric(predicted, weights=weights, dim=dim) + elif source == "target": + return metric(truth, weights=weights, dim=dim) + + return metric_wrapper + + +def vertical_integral( + integrand: torch.Tensor, + surface_pressure: torch.Tensor, + sigma_grid_offsets_ak: torch.Tensor, + sigma_grid_offsets_bk: torch.Tensor, +) -> torch.Tensor: + """Computes a vertical integral, namely: + + (1 / g) * ∫ x dp + + where + - g = acceleration due to gravity + - x = integrad + - p = pressure level + + Args: + integrand (lat, lon, vertical_level), (kg/kg) + surface_pressure: (lat, lon), (Pa) + sigma_grid_offsets_ak: Sorted sigma grid offsets ak, (vertical_level + 1,) + sigma_grid_offsets_bk: Sorted sigma grid offsets bk, (vertical_level + 1,) + + Returns: + Vertical integral of the integrand (lat, lon). + """ + ak, bk = sigma_grid_offsets_ak, sigma_grid_offsets_bk + if ak.device != integrand.device or ak.device != surface_pressure.device: + raise ValueError( + f"sigma_grid_offsets_ak.device ({ak.device}), " + f"sigma_grid_offsets_bk.device ({bk.device}), " + f"integrand.device ({integrand.device}), " + f"surface_pressure.device ({surface_pressure.device}) must be the same." + ) + pressure_thickness = ((ak + (surface_pressure.unsqueeze(-1) * bk))).diff(dim=-1) # Pa + integral = torch.sum(pressure_thickness * integrand, axis=-1) # type: ignore + return 1 / GRAVITY * integral + + +def surface_pressure_due_to_dry_air( + specific_total_water: torch.Tensor, + surface_pressure: torch.Tensor, + sigma_grid_offsets_ak: torch.Tensor, + sigma_grid_offsets_bk: torch.Tensor, +) -> torch.Tensor: + """Computes the dry air (Pa). + + Args: + specific_total_water (lat, lon, vertical_level), (kg/kg) + surface_pressure: (lat, lon), (Pa) + sigma_grid_offsets_ak: Sorted sigma grid offsets ak, (vertical_level + 1,) + sigma_grid_offsets_bk: Sorted sigma grid offsets bk, (vertical_level + 1,) + + Returns: + Vertically integrated dry air (lat, lon) (Pa) + """ + + num_levels = len(sigma_grid_offsets_ak) - 1 + + if num_levels != len(sigma_grid_offsets_bk) - 1 or num_levels != specific_total_water.shape[-1]: + raise ValueError(("Number of vertical levels in ak, bk, and specific_total_water must" "be the same.")) + + total_water_path = vertical_integral( + specific_total_water, + surface_pressure, + sigma_grid_offsets_ak, + sigma_grid_offsets_bk, + ) + dry_air = surface_pressure - GRAVITY * total_water_path + return dry_air diff --git a/src/evaluation/reduced_metrics.py b/src/evaluation/reduced_metrics.py new file mode 100644 index 0000000..e6d9346 --- /dev/null +++ b/src/evaluation/reduced_metrics.py @@ -0,0 +1,122 @@ +""" +This file contains code for computing metrics of single variables on batches of data, +and aggregating them into a single metric value. The functions here mainly exist +to turn metric functions that may have different APIs into a common API, +so that they can be iterated over and called in the same way in a loop. +""" + +from typing import Literal, Optional, Protocol + +import torch + +from src.evaluation.metrics import Dimension + + +class ReducedMetric(Protocol): + """Used to record a metric value on batches of data (potentially out-of-memory) + and then get the total metric at the end. + """ + + def record(self, target: torch.Tensor, gen: torch.Tensor): + """ + Update metric for a batch of data. + """ + ... + + def get(self) -> torch.Tensor: + """ + Get the total metric value, not divided by number of recorded batches. + """ + ... + + +class AreaWeightedFunction(Protocol): + """ + A function that computes a metric on the true and predicted values, + weighted by area. + """ + + def __call__( + self, + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +class AreaWeightedSingleTargetFunction(Protocol): + """ + A function that computes a metric on a single value, weighted by area. + """ + + def __call__( + self, + tensor: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: ... + + +def compute_metric_on( + source: Literal["preds", "targets"], metric: AreaWeightedSingleTargetFunction +) -> AreaWeightedFunction: + """Turns a single-target metric function + (computed on only the generated or target data) into a function that takes in + both the generated and target data as arguments, as required for the APIs + which call generic metric functions. + """ + + def metric_wrapper( + truth: torch.Tensor, + predicted: torch.Tensor, + weights: Optional[torch.Tensor] = None, + dim: Dimension = (), + ) -> torch.Tensor: + if source == "preds": + return metric(predicted, weights=weights, dim=dim) + elif source == "targets": + return metric(truth, weights=weights, dim=dim) + + return metric_wrapper + + +class AreaWeightedReducedMetric: + """ + A wrapper around an area-weighted metric function. + """ + + def __init__( + self, + area_weights: Optional[torch.Tensor], + device: torch.device, + compute_metric: AreaWeightedFunction, + dim: Dimension = (-2, -1), + ): + self._area_weights = area_weights.to(device) if area_weights is not None else None + self._compute_metric = compute_metric + self._total = None + self._device = device + self._dim = dim + + def record(self, targets: torch.Tensor, preds: torch.Tensor, batch_dim: int = 0, **kwargs): + """Add a batch of data to the metric. + + Args: + targets: Target data. Should have shape [batch, time, height, width]. + preds: Generated data. Should have shape [batch, time, height, width]. + batch_dim: The dimension of the batch axis over which to average the metric. + """ + # dim=(-2, -1) means average over the two spatial dimensions + # dim=batch_dim works usually too, but some data may have other non-spatial dimensions + new_value = self._compute_metric( + truth=targets, predicted=preds, weights=self._area_weights, dim=self._dim, **kwargs + ).mean(dim=None) + # assert new_value.dim() == 0, f"Expected scalar value, got {new_value}" + if self._total is None: + self._total = torch.zeros_like(new_value, device=targets.device) + self._total += new_value + + def get(self) -> torch.Tensor: + """Returns the metric.""" + return self._total diff --git a/src/experiment_types/__init__.py b/src/experiment_types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/experiment_types/_base_experiment.py b/src/experiment_types/_base_experiment.py new file mode 100644 index 0000000..458ca22 --- /dev/null +++ b/src/experiment_types/_base_experiment.py @@ -0,0 +1,1275 @@ +from __future__ import annotations + +import inspect +import logging +import re +import time +from collections import defaultdict +from contextlib import contextmanager +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Union + +import hydra +import numpy as np +import torch +import wandb +from omegaconf import DictConfig +from pytorch_lightning import LightningModule +from tensordict import TensorDict, TensorDictBase +from torch import Tensor +from torch.optim.lr_scheduler import LambdaLR + +from src.datamodules._dataset_dimensions import get_dims_of_dataset +from src.datamodules.abstract_datamodule import BaseDataModule +from src.models._base_model import BaseModel +from src.models.modules.ema import LitEma +from src.utilities.lr_scheduler import get_scheduler +from src.utilities.utils import ( + AlreadyLoggedError, + concatenate_array_dicts, + get_logger, + print_gpu_memory_usage, + raise_error_if_invalid_value, + rrearrange, + to_DictConfig, + to_tensordict, + torch_to_numpy, +) + + +class BaseExperiment(LightningModule): + r"""This is a template base class, that should be inherited by any stand-alone ML model. + Methods that need to be implemented by your concrete ML model (just as if you would define a :class:`torch.nn.Module`): + - :func:`__init__` + - :func:`forward` + + The other methods may be overridden as needed. + It is recommended to define the attribute + >>> self.example_input_array = torch.randn() # batch dimension can be anything, e.g. 7 + + + .. note:: + Please use the function :func:`predict` at inference time for a given input tensor, as it postprocesses the + raw predictions from the function :func:`raw_predict` (or model.forward or model())! + + Args: + optimizer: DictConfig with the optimizer configuration (e.g. for AdamW) + scheduler: DictConfig with the scheduler configuration (e.g. for CosineAnnealingLR) + monitor (str): The name of the metric to monitor, e.g. 'val/mse' + mode (str): The mode of the monitor. Default: 'min' (lower is better) + use_ema (bool): Whether to use an exponential moving average (EMA) of the model weights during inference. + ema_decay (float): The decay of the EMA. Default: 0.9999 (only used if use_ema=True) + enable_inference_dropout (bool): Whether to enable dropout during inference. Default: False + name (str): optional string with a name for the model + num_predictions (int): The number of predictions to make for each input sample + prediction_inputs_noise (float): The amount of noise to add to the inputs before predicting + log_every_step_up_to (int): Logging is performed at every step up to this number. Default: 1000. + After that, logging interval corresponds to the lightning Trainer's log_every_n_steps parameter (default: 50) + verbose (bool): Whether to print/log or not + + Read the docs regarding LightningModule for more information: + https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html + """ + + CHANNEL_DIM = -3 # assumes 2 spatial dimensions for everything + + def __init__( + self, + model_config: DictConfig, + datamodule_config: DictConfig, + diffusion_config: Optional[DictConfig] = None, + optimizer: Optional[DictConfig] = None, + scheduler: Optional[DictConfig] = None, + monitor: Optional[str] = None, + mode: str = "min", + use_ema: bool = False, + ema_decay: float = 0.9999, + enable_inference_dropout: bool = False, + reset_optimizer: bool = False, + torch_compile: str = None, + num_predictions: int = 1, + num_predictions_in_memory: int = None, + logging_infix: str = "", + prediction_inputs_noise: float = 0.0, + save_predictions_filename: Optional[str] = None, + save_prediction_batches: int = 0, + log_every_step_up_to: int = 1000, + seed: int = None, + name: str = "", + work_dir: str = "", + verbose: bool = True, + ): + super().__init__() + # The following saves all the args that are passed to the constructor to self.hparams + # e.g. access them with self.hparams.monitor + self.save_hyperparameters(ignore=["model_config", "datamodule_config", "diffusion_config", "verbose"]) + # Get a logger + self.log_text = get_logger(name=self.__class__.__name__ if name == "" else name) + self.name = name + self._datamodule = None + self.verbose = verbose + self.logging_infix = logging_infix + if not self.verbose: # turn off info level logging + self.log_text.setLevel(logging.WARN) + + self.model_config = model_config + self.datamodule_config = datamodule_config + self.diffusion_config = diffusion_config + self.num_predictions = num_predictions + self.num_predictions_in_mem = num_predictions_in_memory or num_predictions + assert self.num_predictions_in_mem <= num_predictions, "num_predictions_in_memory must be <= num_predictions" + self.num_prediction_loops = num_predictions // self.num_predictions_in_mem + self.is_diffusion_model = diffusion_config is not None and diffusion_config.get("_target_", None) is not None + self.dims = get_dims_of_dataset(self.datamodule_config) + self._instantiate_auxiliary_modules() + self.model = self.instantiate_model() + + # Compile torch model if needed + raise_error_if_invalid_value(torch_compile, [False, None, "model", "module"], name="torch_compile") + if torch_compile == "model": + self.log_text.info("Compiling the model (but not the LightningModule)...") + self.model = torch.compile(self.model) + + # Initialize the EMA model, if needed + self.use_ema = use_ema + self.update_ema = use_ema + if self.update_ema: + self.model_ema = LitEma(self.model_handle_for_ema, decay=ema_decay) + self.log_text.info(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + if not self.use_ema: + self.log_text.info("Not using EMA.") + + if self.model is not None: + self.model.ema_scope = self.ema_scope + + if enable_inference_dropout: + self.log_text.info("Enabling dropout during inference!") + + # Timing variables to track the training/epoch/validation time + self._start_validation_epoch_time = self._start_test_epoch_time = self._start_epoch_time = None + self.training_step_outputs = [] + self._validation_step_outputs, self._predict_step_outputs = [], [] + self._test_step_outputs = defaultdict(list) + + # Epoch and global step defaults. When only doing inference, the current_epoch of lightning may be 0, so you can set it manually. + self._default_epoch = self._default_global_step = 0 + + # Check that the args/hparams are valid + self._check_args() + + if self.use_ensemble_predictions("val"): + self.log_text.info(f"Using a {num_predictions}-member ensemble for validation.") + + # Example input array, if set + if hasattr(self.model, "example_input_array"): + self.example_input_array = self.model.example_input_array + + if save_predictions_filename is not None: + assert ( + save_prediction_batches == "all" or save_prediction_batches > 0 + ), "save_prediction_batches must be > 0 if save_predictions_filename is set." + + @property + def model_handle_for_ema(self) -> torch.nn.Module: + """Return the model handle that is used for the EMA. By default, this is the model itself. + But it can be overridden in subclasses, e.g. for GANs, where the EMA is only applied to the generator.""" + return self.model + + @property + def current_epoch(self) -> int: + """The current epoch in the ``Trainer``, or 0 if not attached.""" + if self._trainer and self.trainer.current_epoch != 0: + return self.trainer.current_epoch + return self._default_epoch + + @property + def global_step(self) -> int: + """Total training batches seen across all epochs. + + If no Trainer is attached, this propery is 0. + + """ + if self._trainer and self.trainer.global_step != 0: + return self.trainer.global_step + + return self._default_global_step + + # --------------------------------- Interface with model + def actual_spatial_shapes(self, spatial_shape_in: Tuple[int, int], spatial_shape_out: Tuple[int, int]) -> Tuple: + return spatial_shape_in, spatial_shape_out + + def actual_num_input_channels(self, num_input_channels: int) -> int: + return num_input_channels + + def actual_num_output_channels(self, num_output_channels: int) -> int: + return num_output_channels + + @property + def num_conditional_channels(self) -> int: + """The number of channels that are used for conditioning as auxiliary inputs.""" + nc = self.dims.get("conditional", 0) + if self.is_diffusion_model: + d_class = self.diffusion_config.get("_target_").lower() + is_standard_diffusion = "dyffusion" not in d_class + if is_standard_diffusion: + nc += self.window * self.dims["input"] # we use the data from the past window frames as conditioning + else: + fwd_cond = self.diffusion_config.get("forward_conditioning", "").lower() + if fwd_cond == "": + pass # no forward conditioning, i.e. don't add anything + elif fwd_cond == "data|noise": + nc += 2 * self.window * self.dims["input"] + elif fwd_cond in ["none", None]: + pass + else: + nc += self.window * self.dims["input"] + return nc + + @property + def window(self) -> int: + return self.datamodule_config.get("window", 1) + + @property + def horizon(self) -> int: + return self.datamodule_config.get("horizon", 1) + + @property + def inputs_noise(self): + # internally_probabilistic = isinstance(self.model, (GaussianDiffusion, DDPM)) + # return 0 if internally_probabilistic else self.hparams.prediction_inputs_noise + return self.hparams.prediction_inputs_noise + + @property + def datamodule(self) -> BaseDataModule: + if self._datamodule is None: # alt: set in ``on_fit_start`` method + if self._trainer is None: + return None + self._datamodule = self.trainer.datamodule + # Make sure that normalizer means and stds are on same device as model + if hasattr(self._datamodule, "normalizer"): + self.log_text.info(f"Moving normalizer means and stds to same device as model: device={self.device}") + self._datamodule.normalizer.to(self.device) + return self._datamodule + + def _instantiate_auxiliary_modules(self): + """Instantiate auxiliary modules that need to exist before the model is instantiated. + This is necessary because it is not possible to instantiate modules before calling super().__init__(). + """ + pass + + def extra_model_kwargs(self) -> dict: + """Return extra kwargs for the model instantiation.""" + return {} + + def instantiate_model(self, *args, **kwargs) -> BaseModel: + r"""Instantiate the model, e.g. by calling the constructor of the class :class:`BaseModel` or a subclass thereof.""" + spatial_shape_in, spatial_shape_out = self.actual_spatial_shapes( + self.dims["spatial_in"], self.dims["spatial_out"] + ) + in_channels = self.actual_num_input_channels(self.dims["input"]) + out_channels = self.actual_num_output_channels(self.dims["output"]) + cond_channels = self.num_conditional_channels + assert isinstance(in_channels, (int, dict)), f"Expected int, got {type(in_channels)} for in_channels." + assert isinstance(out_channels, (int, dict)), f"Expected int, got {type(out_channels)} for out_channels." + kwargs["datamodule_config"] = self.datamodule_config + model = hydra.utils.instantiate( + self.model_config, + num_input_channels=in_channels, + num_output_channels=out_channels, + num_output_channels_raw=self.dims["output"], + num_conditional_channels=cond_channels, + spatial_shape_in=spatial_shape_in, + spatial_shape_out=spatial_shape_out, + _recursive_=False, + **kwargs, + **self.extra_model_kwargs(), + ) + self.log_text.info( + f"Instantiated model: {model.__class__.__name__}, with" + f" # input/output/conditional channels: {in_channels}, {out_channels}, {cond_channels}" + ) + if self.is_diffusion_model: + model = hydra.utils.instantiate(self.diffusion_config, model=model, _recursive_=False, **kwargs) + self.log_text.info( + f"Instantiated diffusion model: {model.__class__.__name__}, with" + f" #diffusion steps={model.num_timesteps}" + ) + + return model + + def forward(self, *args, **kwargs) -> Any: + y = self.model(*args, **kwargs) + return y + + # --------------------------------- Names + @property + def short_description(self) -> str: + return self.name if self.name else self.__class__.__name__ + + @property + def WANDB_LAST_SEP(self) -> str: + """Used to separate metrics. Base classes may use an additional prefix, e.g. '/ipol/'""" + return "/" + + @property + def validation_set_names(self) -> List[str]: + if hasattr(self.datamodule, "validation_set_names") and self.datamodule.validation_set_names is not None: + return self.datamodule.validation_set_names + elif hasattr(self, "aggregators_val") and self.aggregators_val is not None: + n_aggs = len(self.aggregators_val) + if n_aggs > 1: + self.log_text.warning( + "Datamodule has no attribute ``validation_set_names``. Using default names ``val_{i}``!" + ) + return [f"val_{i}" for i in range(n_aggs)] + return ["val"] + + @property + def test_set_names(self) -> List[str]: + if self._trainer is None: + return ["???"] + if hasattr(self.datamodule, "test_set_names"): + return self.datamodule.test_set_names + return ["test"] + + @property + def prediction_set_name(self) -> str: + return self.datamodule.prediction_set_name if hasattr(self.datamodule, "prediction_set_name") else "predict" + + # --------------------------------- Metrics + def get_epoch_aggregators(self, split: str, dataloader_idx: int = None) -> dict: + """Return a dictionary of epoch aggregators, i.e. functions that aggregate the metrics over the epoch. + The keys are the names of the metrics, the values are the aggregator functions. + """ + assert split in ["val", "test", "predict"], f"Invalid split {split}" + aggregators = self.datamodule.get_epoch_aggregators( + split=split, + dataloader_idx=dataloader_idx, + is_ensemble=self.use_ensemble_predictions(split), + experiment_type=self.__class__.__name__, + device=self.device, + verbose=self.current_epoch == 0, + ) + return aggregators + + def get_dataset_attribute(self, attribute: str, split: str = "train") -> Any: + """Return the attribute of the dataset.""" + split = "train" if split in ["fit", None] else split + if hasattr(self, f"_dataset_{split}_{attribute}"): + # Return the cached attribute + return getattr(self, f"_dataset_{split}_{attribute}") + + if self.datamodule is None: + raise ValueError("Cannot get dataset attribute if datamodule is None. Please set datamodule first.") + + dl = { + "train": self.datamodule.train_dataloader(), + "val": self.datamodule.val_dataloader(), + "test": self.datamodule.test_dataloader(), + "predict": self.datamodule.predict_dataloader(), + }[split] + if dl is None: + return None + + # Try to get the attribute from the dataset + ds = dl.dataset if isinstance(dl, torch.utils.data.DataLoader) else dl[0].dataset + attr_value = getattr(ds, attribute, getattr(ds, f"_{attribute}", None)) + if attr_value is not None: + # Cache the attribute + setattr(self, f"_dataset_{split}_{attribute}", attr_value) + return attr_value + + # --------------------------------- Check arguments for validity + def _check_args(self): + """Check if the arguments are valid.""" + pass + + @contextmanager + def ema_scope(self, context=None, force_non_ema: bool = False, condition: bool = None): + """Context manager to switch to EMA weights.""" + condition = self.use_ema if condition is None else condition + if condition and not force_non_ema: + self.model_ema.store(self.model_handle_for_ema.parameters()) + self.model_ema.copy_to(self.model_handle_for_ema) + if context is not None: + self.log_text.info(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if condition and not force_non_ema: + self.model_ema.restore(self.model_handle_for_ema.parameters()) + if context is not None: + self.log_text.info(f"{context}: Restored training weights") + + @contextmanager + def inference_dropout_scope(self, condition: bool = None, context=None): + """Context manager to switch to inference dropout mode. + Args: + condition (bool, optional): If True, switch to inference dropout mode. If False, switch to training mode. + If None, use the value of self.hparams.enable_inference_dropout. + Important: If not None, self.hparams.enable_inference_dropout is ignored! + context (str, optional): If not None, print this string when switching to inference dropout mode. + """ + condition = self.hparams.enable_inference_dropout if condition is None else condition + if condition: + self.model.enable_inference_dropout() + if context is not None: + self.log_text.info(f"{context}: Switched to enabled inference dropout") + try: + yield None + finally: + if condition: + self.model.disable_inference_dropout() + if context is not None: + self.log_text.info(f"{context}: Switched to disabled inference dropout") + + @contextmanager + def timing_scope(self, context="", no_op=True, precision=2): + """Context manager to measure the time of the code inside the context. (By default, does nothing.) + Args: + context (str, optional): If not None, print time elapsed in this context. + """ + start_time = time.time() if not no_op else None + try: + yield None + finally: + if not no_op: + context = f"``{context}``:" if context else "" + self.log_text.info(f"Elapsed time {context} {time.time() - start_time:.{precision}f}s") + + def normalize_data(self, x: Dict[str, Tensor]) -> TensorDict: + """Normalize the data.""" + # to_tensordict(x) is no-op if x is a tensor + if hasattr(self.datamodule, "normalizer"): + x = self.datamodule.normalizer.normalize(x) + return to_tensordict(x) + + def normalize_batch( + self, batch: Dict[str, Dict[str, Tensor]] | Dict[str, Tensor] | Tensor + ) -> Dict[str, TensorDict] | TensorDict: + """Normalize the batch. If the batch is a nested dictionary, normalize each nested dictionary separately.""" + if torch.is_tensor(batch) or isinstance(next(iter(batch.values())), Tensor): + return self.normalize_data(batch) + elif isinstance(batch, TensorDict): + return TensorDict({k: self.normalize_data(v) for k, v in batch.items()}, batch_size=batch.batch_size) + else: + return {k: self.normalize_data(v) for k, v in batch.items()} + + def denormalize_data(self, x: Dict[str, Tensor]) -> TensorDict: + """Denormalize the data.""" + if hasattr(self.datamodule, "normalizer"): + x = self.datamodule.normalizer.denormalize(x) + return to_tensordict(x) + + def denormalize_batch( + self, x: Dict[str, Dict[str, Tensor]] | Dict[str, Tensor] + ) -> Dict[str, TensorDict] | TensorDict: + if torch.is_tensor(x) or isinstance(next(iter(x.values())), Tensor): + return self.denormalize_data(x) + elif isinstance(x, TensorDict): + return TensorDict({k: self.denormalize_data(v) for k, v in x.items()}, batch_size=x.batch_size) + else: + return {k: self.denormalize_data(v) for k, v in x.items()} + + def predict_packed(self, *inputs: Tensor, **kwargs) -> Dict[str, Tensor]: + # check if model has sample_loop method with argument num_predictions + if ( + hasattr(self.model, "sample_loop") + and "num_predictions" in inspect.signature(self.model.sample_loop).parameters + ): + kwargs["num_predictions"] = self.num_predictions_in_mem + + results = self.model.predict_forward(*inputs, **kwargs) # by default, just call the forward method + if torch.is_tensor(results): + results = {"preds": results} + + return results + + def _predict( + self, + *inputs: Tensor, + num_predictions: Optional[int] = None, + predictions_mask: Optional[Tensor] = None, + **kwargs, + ) -> Dict[str, Tensor]: + """ + This should be the main method to use for making predictions/doing inference. + + Args: + inputs (Tensor): Input data tensor of shape :math:`(B, *, C_{in})`. + This is the same tensor one would use in :func:`forward`. + num_predictions (int, optional): Number of predictions to make. If None, use the default value. + **kwargs: Additional keyword arguments + + Returns: + Dict[str, Tensor]: The model predictions (in a post-processed format), i.e. a dictionary output_var -> output_var_prediction, + where each output_var_prediction is a Tensor of shape :math:`(B, *)` in original-scale (e.g. + in Kelvin for temperature), and non-negativity has been enforced for variables such as precipitation. + + Shapes: + - Input: :math:`(B, *, C_{in})` + - Output: Dict :math:`k_i` -> :math:`v_i`, and each :math:`v_i` has shape :math:`(B, *)` for :math:`i=1,..,C_{out}`, + + where :math:`B` is the batch size, :math:`*` is the spatial dimension(s) of the data, + and :math:`C_{out}` is the number of output features. + """ + base_num_predictions = self.num_predictions + self.num_predictions = num_predictions or base_num_predictions + + # break up inputs and kwargs into batches of size self.num_predictions_in_mem + def split_batch(x, start, end): + if isinstance(x, (Tensor, TensorDict)): + return x[start:end] + return x + + results = defaultdict(list) + # By default, we predict the entire batch at once (i.e. num_prediction_loops=1) + full_batch_size = inputs[0].shape[0] if len(inputs) > 0 else kwargs[list(kwargs.keys())[0]].shape[0] + actual_batch_size = full_batch_size // self.num_predictions # base_num_predictions + assert actual_batch_size > 0, f"{actual_batch_size=}, {full_batch_size=}, {self.num_predictions=}" + inputs_offset_factor = self.num_predictions_in_mem * actual_batch_size + for i in range(self.num_prediction_loops): + start_i, end_i = i * inputs_offset_factor, (i + 1) * inputs_offset_factor + inputs_i = [split_batch(x, start_i, end_i) for x in inputs] + kwargs_i = {k: split_batch(v, start_i, end_i) for k, v in kwargs.items()} + results_i = self.predict_packed(*inputs_i, **kwargs_i) + # log.info(f"results_i: {results_i.keys()}, {results_i['preds'].shape}, inputs_i: {inputs_i[0].shape}") + if predictions_mask is not None: + results_i = {k: v[..., predictions_mask[0, :]] for k, v in results_i.items()} + for k, v in results_i.items(): + results[k].append(v) + # log.info({k: torch.cat(v, dim=0) for k, v in results.items()}["preds2d"].shape) + results = {k: torch.cat(v, dim=0) for k, v in results.items()} + # results = TensorDict(results, batch_size=(full_batch_size,)) + # results = to_tensordict({k: torch.cat(v, dim=0) for k, v in results.items()}, find_batch_size_max=True) + # log.info(results["preds2d"].shape, "after cat") + self.num_predictions = base_num_predictions + results = self.postprocess_predictions(results) + return results + + def postprocess_predictions(self, results: Dict[str, Tensor]) -> Dict[str, Tensor]: + results = self.reshape_predictions(results) + # log.info(results["preds2d"].shape, "after reshape") + results = self.unpack_predictions(results) + for k in list(results.keys()): + if "preds" in k: # Rename the keys from to _normed + results[f"{k}_normed"] = results.pop(k) # unpacked and normalized + + # results['preds_packed'] = packed_preds # packed and normalized + if self.datamodule is not None: + # Unpack and denormalize the predictions. Keys are renamed from _normed to + results.update( + {k.replace("_normed", ""): self.denormalize_batch(v) for k, v in results.items() if "preds" in k} + ) + # for k, v in results.items(): + # print(k, v.shape if torch.is_tensor(v) else v) + return results + + def predict(self, inputs: Union[Tensor, TensorDictBase], **kwargs) -> Dict[str, Tensor]: + """Wrapper around the main predict method, to allow inputs to be a TensorDictBase or a Tensor.""" + if torch.is_tensor(inputs): + return self._predict(inputs, **kwargs) + else: + return self._predict(**inputs, **kwargs) + + def reshape_predictions(self, results: TensorDict) -> TensorDict: + """Reshape and unpack the predictions from the model. This modifies the input dictionary in-place. + Args: + results (Dict[str, Tensor]): The model outputs. Access the predictions via results['preds']. + """ + pred_keys = [k for k in results.keys() if "preds" in k] + preds_shape = results[pred_keys[0]].shape + if preds_shape[0] > 1: + if self.num_predictions > 1 and preds_shape[0] % self.num_predictions == 0: + for k in pred_keys: + results[k] = self._reshape_ensemble_preds(results[k]) + # results = self._reshape_ensemble_preds(results) + return results + + def pack_data(self, data: Dict[str, Tensor], input_or_output: str) -> Tensor: + """Pack the data into a single tensor.""" + if input_or_output == "input": + packer_name = "in_packer" + elif input_or_output == "output": + packer_name = "out_packer" + else: + raise ValueError(f"Unknown input_or_output: {input_or_output}") + if not hasattr(self.datamodule, packer_name): + return torch.tensor(data) if not torch.is_tensor(data) else data + + packer = getattr(self.datamodule, packer_name) + return packer.pack(data) + + def unpack_data( + self, results: Dict[str, Tensor], input_or_output: str, axis=None, func="unpack" + ) -> Dict[str, Tensor]: + """Unpack the predictions from the model. This modifies the input dictionary in-place. + Args: + results (Dict[str, Tensor]): The model outputs. Access the predictions via results['preds']. + input_or_output (str): Whether to unpack the input or output data. + axis (int, optional): The axis along which to unpack the data. Default: None (use the default axis). + """ + # As of now, only keys with ``preds`` in them are unpacked. + if input_or_output == "input": + packer_name = "in_packer" + elif input_or_output == "output": + packer_name = "out_packer" + else: + raise ValueError(f"Unknown input_or_output: {input_or_output}") + if not hasattr(self.datamodule, packer_name): + return results + + packer = getattr(self.datamodule, packer_name) + packer_func = getattr(packer, func) # basically packer.unpack + if torch.is_tensor(results): + results = packer_func(results, axis=axis) + elif "preds" in results.keys(): + results = {**results, "preds": packer_func(results.pop("preds"), axis=axis)} + # results["preds"] = packer.unpack(results["preds"], axis=axis) + elif hasattr(packer, "packer_names") and packer.packer_names == set( + packer.k_to_base_key(k) for k in results.keys() + ): + results = packer_func(results, axis=axis) + else: + for k, v in results.items(): + if "preds" in k: + packer_k = packer[k.replace("preds", "")] if isinstance(packer, dict) else packer + results[k] = packer_k.unpack(v, axis=axis) + else: + raise ValueError(f"Unknown key {k} in results for unpacking.") + return results + + def unpack_predictions(self, results: Dict[str, Tensor], axis=None, **kwargs) -> Dict[str, Tensor]: + return self.unpack_data(results, input_or_output="output", axis=axis, **kwargs) + + def get_target_variants(self, targets: Tensor, is_normalized: bool = False) -> Dict[str, Tensor]: + if is_normalized: + targets_normed = targets + targets_raw = self.denormalize_batch(targets_normed) + else: + targets_raw = targets + targets_normed = self.normalize_batch(targets_raw) + return { + "targets": targets_raw.contiguous(), + "targets_normed": targets_normed.contiguous(), + } + + # --------------------- training with PyTorch Lightning + def on_any_start(self, stage: str = None) -> None: + # Check if model has property ``sigma_data`` and set it to the data's std + if hasattr(self.model, "sigma_data") and getattr(self.model, "_USE_SIGMA_DATA", False): + self.model.sigma_data = self.datamodule.sigma_data + + def on_fit_start(self) -> None: + self.on_any_start(stage="fit") + + def on_validation_start(self) -> None: + self.on_any_start(stage="val") + + def on_test_start(self) -> None: + self.on_any_start(stage="test") + + def on_train_start(self) -> None: + """Log some info about the model/data at the start of training""" + assert "/" in self.WANDB_LAST_SEP, f'Please use a separator that contains a "/" in {self.WANDB_LAST_SEP}' + # Find size of the validation set(s) + ds_val = self.datamodule.val_dataloader() + val_sizes = [len(dl.dataset) for dl in (ds_val if isinstance(ds_val, list) else [ds_val])] + # Compute the effective batch size + # bs * acc * n_gpus + bs = self.datamodule.train_dataloader().batch_size + acc = self.trainer.accumulate_grad_batches + n_gpus = max(1, self.trainer.num_devices) + n_nodes = max(1, self.trainer.num_nodes) + eff_bs = bs * acc * n_gpus * n_nodes + # compute number of steps per epoch + n_steps_per_epoch = len(self.datamodule.train_dataloader()) + n_steps_per_epoch_per_gpu = n_steps_per_epoch / n_gpus + to_log = { + "Parameter count": float(self.model.num_params), + "Training set size": float(len(self.datamodule.train_dataloader().dataset)), + "Validation set size": float(sum(val_sizes)), + "Effective batch size": float(eff_bs), + "Steps per epoch": float(n_steps_per_epoch), + "Steps per epoch per GPU": float(n_steps_per_epoch_per_gpu), + "n_gpus": n_gpus, + "TESTED": False, + } + self.log_dict(to_log, on_step=False, on_epoch=True, prog_bar=False, logger=True) + # provide access to trainer to the model + self.model.trainer = self.trainer + self._n_steps_per_epoch = n_steps_per_epoch + self._n_steps_per_epoch_per_gpu = n_steps_per_epoch_per_gpu + if self.global_step <= self.hparams.log_every_step_up_to: + self._original_log_every_n_steps = self.trainer.log_every_n_steps + self.trainer.log_every_n_steps = 1 + + # Print the world size, rank, and local rank + if self.trainer.world_size > 1: + self.log_text.info( + f"World size: {self.trainer.world_size}, Rank: {self.trainer.global_rank}, Local rank: {self.trainer.local_rank}" + ) + + def on_train_epoch_start(self) -> None: + self._start_epoch_time = time.time() + + def train_step_initial_log_dict(self) -> dict: + return dict() + + @property + def main_data_keys(self) -> List[str]: + return ["dynamics"] + + @property + def main_data_keys_val(self) -> List[str]: + return self.main_data_keys + + @property + def normalize_data_keys_val(self) -> List[str]: + return self.main_data_keys_val # by default, normalize all the main data keys + + @property + def inputs_data_key(self) -> str: + return self.main_data_keys_val[0] + + def get_loss(self, batch: Any) -> Tensor: + r"""Compute the loss for the given batch""" + raise NotImplementedError(f"Please implement the get_loss method for {self.__class__.__name__}") + + def training_step(self, batch: Any, batch_idx: int): + r"""One step of training (backpropagation is done on the loss returned at the end of this function)""" + if self.global_step == self.hparams.log_every_step_up_to: + # Log on rank 0 only + if self.trainer.global_rank == 0: + self.log_text.info(f"Logging every {self._original_log_every_n_steps} steps from now on") + self.trainer.log_every_n_steps = self._original_log_every_n_steps + + time_start = time.time() + train_log_dict = self.train_step_initial_log_dict() + + for main_data_key in self.main_data_keys: + if isinstance(batch[main_data_key], dict): + batch[main_data_key] = {k: to_tensordict(v) for k, v in batch[main_data_key].items()} + batch[main_data_key] = to_tensordict(batch[main_data_key], find_batch_size_max=True) + else: + batch[main_data_key] = to_tensordict(batch[main_data_key]) + batch[main_data_key] = self.normalize_batch(batch[main_data_key]) + + # Compute main loss + loss_output = self.get_loss(batch) # either a scalar or a dict with key 'loss' + if isinstance(loss_output, dict): + self.log_dict( + {k: float(v) for k, v in loss_output.items()}, prog_bar=True, logger=True, on_step=True, on_epoch=True + ) + loss = loss_output.pop("loss") + # train_log_dict.update(loss_output) + else: + loss = loss_output + # Train logs (where on_step=True) will be logged at all steps defined by trainer.log_every_n_steps + self.log("train/loss", float(loss), on_step=True, on_epoch=True, prog_bar=True, logger=True) + + # Count number of zero gradients as diagnostic tool + train_log_dict["n_zero_gradients"] = ( + sum([int(torch.count_nonzero(p.grad == 0)) for p in self.model.get_parameters() if p.grad is not None]) + / self.model.num_params + ) + train_log_dict["time/train/step"] = time.time() - time_start + # train_log_dict["time/train/step_ratio"] = time_per_step / self.trainer.accumulate_grad_batches + + self.log_dict(train_log_dict, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return loss # {"loss": loss} + + def on_train_batch_end(self, *args, **kwargs): + if self.update_ema: + self.model_ema(self.model_handle_for_ema) # update the model EMA + + def on_train_epoch_end(self) -> None: + train_time = time.time() - self._start_epoch_time + self.log_dict({"epoch": float(self.current_epoch), "time/train": train_time}, sync_dist=True) + + # --------------------- evaluation with PyTorch Lightning + def _evaluation_step( + self, + batch: Any, + batch_idx: int, + split: str, + dataloader_idx: int = None, + aggregators: Dict[str, Callable] = None, + **kwargs, + ) -> Dict[str, Tensor]: + """ + One step of evaluation (forward pass, potentially metrics computation, logging, and return of results) + Returns: + results_dict: Dict[str, Tensor], where for each semantically different result, a separate prefix key is used + Then, for each prefix key

, results_dict must contain

_preds and

_targets. + """ + raise NotImplementedError(f"Please implement the _evaluation_step method for {self.__class__.__name__}") + + def evaluation_step(self, batch: Any, batch_idx: int, split: str, **kwargs) -> Dict[str, Tensor]: + # Handle boundary conditions + if "boundary_conditions" in inspect.signature(self._evaluation_step).parameters.keys(): + kwargs["boundary_conditions"] = self.datamodule.boundary_conditions + kwargs.update(self.datamodule.get_boundary_condition_kwargs(batch, batch_idx, split)) + + for k in self.main_data_keys_val: + if isinstance(batch[k], dict): + batch[k] = {k: to_tensordict(v) for k, v in batch[k].items()} + batch[k] = to_tensordict(batch[k], find_batch_size_max=True) + else: + batch[k] = to_tensordict(batch[k]) + + for k in self.normalize_data_keys_val: + if k == "dynamics": + # Store the raw data, if needed for post-processing/using ground truth data + batch[f"raw_{k}"] = batch[k].clone() + + # Normalize data + batch[k] = self.normalize_batch(batch[k]) + + with self.ema_scope(): # use the EMA parameters for the validation step (if using EMA) + with self.inference_dropout_scope(): # Enable dropout during inference + results = self._evaluation_step(batch, batch_idx, split, **kwargs) + + return results + + def get_batch_shape(self, batch: Any) -> Tuple[int, ...]: + """Get the shape of the batch""" + for k in self.main_data_keys + self.main_data_keys_val: + if k in batch.keys(): + if torch.is_tensor(batch[k]): + return batch[k].shape + else: + # add singleton dim for channel + return batch[k].unsqueeze(self.CHANNEL_DIM).shape + raise ValueError(f"Could not find any of the keys {self.main_data_keys=}, {self.main_data_keys_val=}") + + def use_ensemble_predictions(self, split: str) -> bool: + return self.num_predictions > 1 and split in ["val", "test", "predict"] + self.test_set_names + + def use_stacked_ensemble_inputs(self, split: str) -> bool: + return True + + def get_ensemble_inputs( + self, inputs_raw: Optional[Tensor], split: str, add_noise: bool = True, flatten_into_batch_dim: bool = True + ) -> Optional[Tensor]: + """Get the inputs for the ensemble predictions""" + if inputs_raw is None: + return None + elif not self.use_stacked_ensemble_inputs(split): + return inputs_raw # we can sample from the Gaussian distribution directly after the forward pass + elif self.use_ensemble_predictions(split): + # create a batch of inputs for the ensemble predictions + num_predictions = self.num_predictions + if isinstance(inputs_raw, (dict, TensorDictBase)): + inputs = { + k: self.get_ensemble_inputs(v, split, add_noise, flatten_into_batch_dim) + for k, v in inputs_raw.items() + } + if isinstance(inputs_raw, TensorDictBase): + # Transform back to TensorDict + original_bs = inputs_raw.batch_size + inputs = TensorDict(inputs, batch_size=[num_predictions * original_bs[0]] + list(original_bs[1:])) + else: + if isinstance(inputs_raw, Sequence): + inputs = np.array([inputs_raw] * num_predictions) + elif add_noise: + inputs = torch.stack( + [ + inputs_raw + self.inputs_noise * torch.randn_like(inputs_raw) + for _ in range(num_predictions) + ], + dim=0, + ) + else: + inputs = torch.stack([inputs_raw for _ in range(num_predictions)], dim=0) + + if flatten_into_batch_dim: + # flatten num_predictions and batch dimensions + inputs = rrearrange(inputs, "N B ... -> (N B) ...") + else: + inputs = inputs_raw + return inputs + + def _reshape_ensemble_preds(self, results: TensorDict) -> TensorDict: + r""" + Reshape the predictions of an ensemble so that the first dimension is the ensemble dimension, N. + + Args: + results: Model outputs with shape (N * B, ...), where N is the number of ensemble members and B is the batch size. + + Returns: + The reshaped predictions (i.e. each output_var_prediction has shape (N, B, *)). + """ + batch_size = results.shape[0] // self.num_predictions + results = results.reshape(self.num_predictions, batch_size, *results.shape[1:]) + return results + + def _evaluation_get_preds( + self, outputs: List[Any], split: str + ) -> Dict[str, Union[torch.distributions.Normal, np.ndarray]]: + if isinstance(outputs, list) and len(outputs) == 1 and isinstance(outputs[0], list): + outputs = outputs[0] + use_ensemble = self.use_ensemble_predictions(split) + outputs_keys, results = outputs[0].keys(), dict() + for key in outputs_keys: + # print(key, outputs[0][key].keys()) # e.g. t3_preds_normed, ['inputs3d', 'inputs2d'] + batch_axis = 1 if (use_ensemble and "targets" not in key and "true" not in key) else 0 + results[key] = concatenate_array_dicts(outputs, batch_axis, keys=[key])[key] + return results + + def on_validation_epoch_start(self) -> None: + self._start_validation_epoch_time = time.time() + val_loaders = self.datamodule.val_dataloader() + n_val_loaders = len(val_loaders) if isinstance(val_loaders, list) else 1 + self.aggregators_val = [] + for i in range(n_val_loaders): + self.aggregators_val.append(self.get_epoch_aggregators(split="val", dataloader_idx=i)) + + def validation_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None, **kwargs): + kwargs["aggregators"] = self.aggregators_val[dataloader_idx or 0] + results = self.evaluation_step(batch, batch_idx, split="val", dataloader_idx=dataloader_idx, **kwargs) + results = torch_to_numpy(results) + # self._validation_step_outputs.append(results) # uncomment to save all val predictions + return results + + def ensemble_logging_infix(self, split: str) -> str: + """No '/' in front of the infix! But '/' at the end!""" + s = "" if self.logging_infix == "" else f"{self.logging_infix}/".replace("//", "/") + # if self.inputs_noise > 0.0 and split != "val": + # s += f"{self.inputs_noise}eps/" + # s += f"{self.num_predictions}ens_mems{self.WANDB_LAST_SEP}" + s += f"{self.WANDB_LAST_SEP}" + return s + + def on_validation_epoch_end(self) -> None: + # val_outputs = self._evaluation_get_preds(self._validation_step_outputs) + self._validation_step_outputs = [] + val_stats, total_mean_metrics_all = self._on_eval_epoch_end( + "val", + time_start=self._start_validation_epoch_time, + data_split_names=self.validation_set_names, + aggregators=self.aggregators_val, + ) + + # If monitoring is enabled, check that it is one of the monitored metrics + if self.trainer.sanity_checking: + monitors = [self.monitor] + for ckpt_callback in self.trainer.checkpoint_callbacks: + if hasattr(ckpt_callback, "monitor") and ckpt_callback.monitor is not None: + monitors.append(ckpt_callback.monitor) + for monitor in monitors: + assert monitor in val_stats, ( + f"Monitor metric {monitor} not found in {val_stats.keys()}. " + f"\nTotal mean metrics: {total_mean_metrics_all}" + ) + return val_stats + + def _on_eval_epoch_end( + self, + split: str, + time_start: float, + data_split_names: List[str] = None, + aggregators: List[Dict[str, Callable]] = None, + ) -> Tuple[Dict[str, float], List[str]]: + logging_infix = self.ensemble_logging_infix(split=split).rstrip("/") + val_time = time.time() - time_start + split_name = "val" if split == "val" else split + val_stats = { + f"time/{split_name}": val_time, + "num_predictions": self.num_predictions, + "noise_level": self.inputs_noise, + "epoch": float(self.current_epoch), + "global_step": self.global_step, + } + val_media = {"epoch": self.current_epoch, "global_step": self.global_step} + data_split_names = data_split_names or [split] + + total_mean_metrics_all = [] + for prefix, aggregators in zip(data_split_names, aggregators): + label = f"{prefix}/{logging_infix}".rstrip("/") # e.g. "val/5ens_mems" + per_variable_mean_metrics = defaultdict(list) + for agg_name, agg in aggregators.items(): + # if agg.name is None: # does not work when using a listaggregator + # label = f"{label}/{agg_name}" # e.g. "val/5ens_mems/t3" + logs_metrics, logs_media = agg.get_logs(prefix=label, epoch=self.current_epoch) + val_stats.update(logs_metrics) + val_media.update(logs_media) + + if not (agg_name.startswith("t") and len(agg_name) <= 5): # up to t9999 + print(f"Skipping aggregator {agg_name} for mean metrics.") + # Don't use these aggregators for the mean metrics (not temporal) + continue + + # Compute average metrics over all aggregators I + for k, v in logs_metrics.items(): + k_base = k.replace(f"{label}/", "") + k_base = re.sub(r"t\d+/", "", k_base) # remove the /t{t} infix + per_variable_mean_metrics[k_base].append(v) + + # Compute average metrics over all aggregators II + total_mean_metrics = defaultdict(list) + for k, v in per_variable_mean_metrics.items(): + if logging_infix != "": + assert logging_infix not in k, f"Logging infix {logging_infix} found in {k}" + aggs_mean = np.mean(v) + # If there is a "/" separator, remove the variable name into "k_base" stem + # Split k such that variable is dropped e.g. k= global/rmse/z500 and k_base=global/rmse + k_base = "/".join(k.split("/")[:-1]) + val_stats[f"{label}/avg/{k}"] = aggs_mean + total_mean_metrics[f"{label}/avg/{k_base}"].append(aggs_mean) + + # Total mean metrics: ['val/avg/l1', 'val/avg/ssr', 'val/avg/rmse', 'val/avg/bias', 'val/avg/grad_mag_percent_diff', 'val/avg/crps', 'inference/avg/l1', etc...] + # Compute average metrics over all aggregators and variables III + total_mean_metrics = {k: np.mean(v) for k, v in total_mean_metrics.items()} + val_stats.update(total_mean_metrics) + total_mean_metrics_all += list(total_mean_metrics.keys()) + # print(f"Total mean metrics: {total_mean_metrics_all}, 10 values: {dict(list(val_stats.items())[:10])}") + self.log_dict(val_stats, sync_dist=True, prog_bar=False) + # log to experiment + if self.logger is not None and hasattr(self.logger.experiment, "log"): + self.logger.experiment.log(val_media) + return val_stats, total_mean_metrics_all + + def on_test_epoch_start(self) -> None: + self._start_test_epoch_time = time.time() + test_loaders = self.datamodule.test_dataloader() + n_test_loaders = len(test_loaders) if isinstance(test_loaders, list) else 1 + self.aggregators_test = [ + self.get_epoch_aggregators(split="test", dataloader_idx=i) for i in range(n_test_loaders) + ] + test_name = self.test_set_names[0] if len(self.test_set_names) == 1 else "test" + example_metric = f"{test_name}/{self.ensemble_logging_infix(test_name)}avg/crps" + if example_metric in wandb.run.summary.keys(): + raise AlreadyLoggedError(f"Testing for ``{test_name}`` data already done.") + self.log_text.info(f"Starting testing for ``{test_name}`` data.") + + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None, **kwargs): + split = self.test_set_names[0 if dataloader_idx is None else dataloader_idx] + agg = self.aggregators_test[0] if dataloader_idx is None else self.aggregators_test[dataloader_idx] + results = self.evaluation_step( + batch, batch_idx, dataloader_idx=dataloader_idx, split=split, aggregators=agg, **kwargs + ) + results = torch_to_numpy(results) + self._test_step_outputs[split].append(results) + return results + + def on_test_epoch_end(self) -> None: + # for test_split in self._test_step_outputs.keys(): + # self._eval_ensemble_predictions(self._test_step_outputs[test_split], split=test_split) + self._test_step_outputs = defaultdict(list) + self._on_eval_epoch_end( + "test", + time_start=self._start_test_epoch_time, + data_split_names=self.test_set_names, + aggregators=self.aggregators_test, + ) + self.log_dict({"TESTED": True}, prog_bar=False, sync_dist=True) + + # ---------------------------------------------------------------------- Inference + def on_predict_start(self) -> None: + self.on_any_start(stage="predict") + pdls = self.trainer.predict_dataloaders + pdls = [pdls] if isinstance(pdls, torch.utils.data.DataLoader) else pdls + for pdl in pdls: + assert pdl.dataset.dataset_id == "predict", f"dataset_id is not 'predict', but {pdl.dataset.dataset_id}" + + n_preds = self.num_predictions + if n_preds > 1: + self.log_text.info(f"Generating {n_preds} predictions per input with noise level {self.inputs_noise}") + + def on_predict_epoch_start(self) -> None: + if self.inputs_noise > 0: + self.log_text.info(f"Adding noise to inputs with level {self.inputs_noise}") + + def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None, **kwargs): + """Anything returned here, will be returned when calling trainer.predict(model, datamodule).""" + results = dict() + if ( + self.hparams.save_prediction_batches is None + or self.hparams.save_prediction_batches == "all" + or batch_idx < self.hparams.save_prediction_batches + ): + results = self.evaluation_step(batch, batch_idx, split="predict", **kwargs) + results = torch_to_numpy(results) # self._reshape_ensemble_preds(results, split='predict') + # print(f"batch_idx={batch_idx}", results.keys(), type(results), type(results[list(results.keys())[0]])) # where + self._predict_step_outputs.append(results) + + return results + + def on_predict_epoch_end(self): + numpy_results = self._evaluation_get_preds(self._predict_step_outputs, split="predict") + # for k, v in numpy_results.items(): print(k, v.shape) + self._predict_step_outputs = [] + return numpy_results + + # ---------------------------------------------------------------------- Optimizers and scheduler(s) + def _get_optim(self, optim_name: str, model_handle=None, **kwargs): + """ + Method that returns the torch.optim optimizer object. + May be overridden in subclasses to provide custom optimizers. + """ + if optim_name.lower() == "fusedadam": + try: + from apex import optimizers + except ImportError as e: + raise ImportError( + "To use FusedAdam, please install apex. Alternatively, use normal AdamW with ``module.optimizer.name=adamw``" + ) from e + + optimizer = optimizers.FusedAdam # set adam_w_mode=False for Adam (by default: True => AdamW) + elif optim_name.lower() == "adamw": + optimizer = torch.optim.AdamW + elif optim_name.lower() == "adam": + optimizer = torch.optim.Adam + else: + raise ValueError(f"Unknown optimizer type: {optim_name}") + self.log_text.info(f"{optim_name} optim with kwargs: " + str(kwargs)) + model_handle = self if model_handle is None else model_handle + # return optimizer(filter(lambda p: p.requires_grad, model_handle.parameters()), **kwargs) + # Separate parameters that shouldn't be optimized with weight decay + decay = [] + no_decay = [] + kwargs_no_decay = kwargs.copy() + kwargs_no_decay["weight_decay"] = 0 + no_decay_params = {"channel_embed", "pos_embed"} + if hasattr(self.model, "no_weight_decay"): + no_decay_params = no_decay_params.union(set(self.model.no_weight_decay())) + if hasattr(self.model, "model") and hasattr(self.model.model, "no_weight_decay"): + no_decay_params = no_decay_params.union(set(self.model.model.no_weight_decay())) + + no_grad_params = 0 + for name, m in model_handle.named_parameters(): + if not m.requires_grad: # Only use parameters that require gradients + no_grad_params += 1 + continue + elif any(nd in name for nd in no_decay_params): + no_decay.append(m) + else: + decay.append(m) + + total_params_count = len(list(model_handle.parameters())) + if no_grad_params == 0: + self.log_text.info(f"Found {total_params_count} parameters.") + else: + self.log_text.info( + f"Found {total_params_count} parameters, of which {no_grad_params} do not require gradients." + ) + + optim = optimizer( + [ + {"params": decay, **kwargs}, + {"params": no_decay, **kwargs_no_decay}, + ] + ) + return optim + + def configure_optimizers(self): + """Configure optimizers and schedulers""" + if "name" not in to_DictConfig(self.hparams.optimizer).keys(): + self.log_text.info("No optimizer was specified, defaulting to AdamW.") + self.hparams.optimizer.name = "adamw" + + optim_kwargs = {k: v for k, v in self.hparams.optimizer.items() if k not in ["name", "_target_"]} + optimizer = self._get_optim(self.hparams.optimizer.name, **optim_kwargs) + + # Build the scheduler + if self.hparams.scheduler is None: + return optimizer # no scheduler + else: + scheduler_params = to_DictConfig(self.hparams.scheduler) + if "_target_" not in scheduler_params.keys() and "name" not in scheduler_params.keys(): + raise ValueError(f"Please provide a _target_ or ``name`` for module.scheduler={scheduler_params}!") + interval = scheduler_params.pop("interval", "step") + scheduler_target = scheduler_params.get("_target_") + if ( + scheduler_target is not None + and "torch.optim" not in scheduler_target + and ".lr_scheduler." not in scheduler_target + ): + # custom LambdaLR scheduler + scheduler = hydra.utils.instantiate(scheduler_params) + scheduler = { + "scheduler": LambdaLR(optimizer, lr_lambda=scheduler.schedule), + "interval": interval, + "frequency": 1, + } + else: + # To support interval=step, we need to multiply the number of epochs by the number of steps per epoch + if interval == "step": + n_steps_per_machine = len(self.datamodule.train_dataloader()) + + n_steps = int( + n_steps_per_machine + / (self.trainer.num_devices * self.trainer.num_nodes * self.trainer.accumulate_grad_batches) + ) + multiply_ep_keys = ["warmup_epochs", "max_epochs", "T_max"] + for key in multiply_ep_keys: + if key in scheduler_params: + scheduler_params[key] *= n_steps + + if "warmup_epochs" in scheduler_params: + scheduler_params["warmup_steps"] = scheduler_params.pop("warmup_epochs") + if "max_epochs" in scheduler_params: + scheduler_params["max_steps"] = scheduler_params.pop("max_epochs") + # Instantiate scheduler + if scheduler_target is not None: + scheduler = hydra.utils.instantiate(scheduler_params, optimizer=optimizer) + else: + assert scheduler_params.get("name") is not None, "Please provide a name for the scheduler." + scheduler = get_scheduler(optimizer, **scheduler_params) + scheduler = {"scheduler": scheduler, "interval": interval, "frequency": 1} + + if self.hparams.monitor is None: + self.log_text.info(f"No ``monitor`` was specified, defaulting to {self.default_monitor_metric}.") + if not hasattr(self.hparams, "mode") or self.hparams.mode is None: + self.hparams.mode = "min" + + if isinstance(scheduler, dict): + lr_dict = {**scheduler, "monitor": self.monitor} # , 'mode': self.hparams.mode} + else: + lr_dict = {"scheduler": scheduler, "monitor": self.monitor} # , 'mode': self.hparams.mode} + return {"optimizer": optimizer, "lr_scheduler": lr_dict} + + @property + def monitor(self): + return self.hparams.monitor + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if not self.use_ema: + # Remove the model EMA parameters from the state_dict (since unwanted here) + state_dict = {k: v for k, v in state_dict.items() if "model_ema" not in k} + if self.hparams.reset_optimizer: + strict = False # Allow loading of partial state_dicts (e.g. fine-tune new layers) + return super().load_state_dict(state_dict, strict=strict) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Save a model checkpoint with extra info""" + # Save wandb run info, if available + if self.logger is not None and hasattr(self.logger, "experiment") and hasattr(self.logger.experiment, "id"): + checkpoint["wandb"] = { + k: getattr(self.logger.experiment, k) for k in ["id", "name", "group", "project", "entity"] + } + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + """Log the epoch and global step of the loaded checkpoint.""" + if "epoch" in checkpoint.keys(): + self.log_text.info(f"Checkpoint epoch={checkpoint['epoch']}; global_step={checkpoint['global_step']}.") + if self.hparams.reset_optimizer: + self.log_text.warning("Resetting optimizer states.") + checkpoint["optimizer_states"] = [] + checkpoint["lr_schedulers"] = [] + + # Monitor GPU Usage + def print_gpu_memory_usage( + self, + prefix: str = "", + tqdm_bar=None, + add_description: bool = True, + keep_old: bool = False, + empty_cache: bool = False, + ): + """Use this function to print the GPU memory usage (logged or in a tqdm bar). + Use this to narrow down memory leaks, by printing the GPU memory usage before and after a function call + and checking if the available memory is the same or not. + Recommended to use with 'empty_cache=True' to get the most accurate results during debugging. + """ + print_gpu_memory_usage(prefix, tqdm_bar, add_description, keep_old, empty_cache, log_func=self.log_text.info) diff --git a/src/experiment_types/forecasting_multi_horizon.py b/src/experiment_types/forecasting_multi_horizon.py new file mode 100644 index 0000000..5a2b432 --- /dev/null +++ b/src/experiment_types/forecasting_multi_horizon.py @@ -0,0 +1,680 @@ +from __future__ import annotations + +import inspect +import math +from abc import ABC +from functools import partial +from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Tuple, Type + +import numpy as np +import torch +from tensordict import TensorDict +from torch import Tensor +from tqdm.auto import tqdm + +from src.diffusion.dyffusion import BaseDYffusion +from src.experiment_types._base_experiment import BaseExperiment +from src.utilities.utils import ( + multiply_by_scalar, + rrearrange, + split3d_and_merge_variables, + torch_select, + torch_to_numpy, +) + + +class AbstractMultiHorizonForecastingExperiment(BaseExperiment, ABC): + PASS_METADATA_TO_MODEL = True + + def __init__( + self, + autoregressive_steps: int = 0, + prediction_timesteps: Optional[Sequence[float]] = None, + empty_cache_at_autoregressive_step: bool = False, + inference_val_every_n_epochs: int = 1, + return_outputs_at_evaluation: str | bool = "auto", + stack_window_to_channel_dim=True, + **kwargs, + ): + assert autoregressive_steps >= 0, f"Autoregressive steps must be >= 0, but is {autoregressive_steps}" + assert autoregressive_steps == 0, "Autoregressive steps are not yet supported for this experiment type." + self.stack_window_to_channel_dim = stack_window_to_channel_dim + super().__init__(**kwargs) + # The following saves all the args that are passed to the constructor to self.hparams + # e.g. access them with self.hparams.autoregressive_steps + self.save_hyperparameters(ignore=["model"]) + self.USE_TIME_AS_EXTRA_INPUT = False + self._prediction_timesteps = prediction_timesteps + self.hparams.pop("prediction_timesteps", None) + if prediction_timesteps is not None: + self.log_text.info(f"Using prediction timesteps {prediction_timesteps}") + + val_time_range = self.valid_time_range_for_backbone_model + if hasattr(self.model, "set_min_max_time"): + self.model.set_min_max_time(min_time=val_time_range[0], max_time=val_time_range[-1]) + elif hasattr(self.model, "model") and hasattr(self.model.model, "set_min_max_time"): + # For diffusion models + self.model.model.set_min_max_time(min_time=val_time_range[0], max_time=val_time_range[-1]) + + @property + def horizon_range(self) -> List[int]: + return list(np.arange(1, self.horizon + 1)) + + @property + def valid_time_range_for_backbone_model(self) -> List[int]: + return self.horizon_range + + @property + def true_horizon(self) -> int: + return self.horizon + + @property + def horizon_name(self) -> str: + s = f"{self.true_horizon}h" + return s + + @property + def prediction_timesteps(self) -> List[float]: + """By default, we predict the timesteps in the horizon range (i.e. at data resolution)""" + return self._prediction_timesteps or self.horizon_range + + @prediction_timesteps.setter + def prediction_timesteps(self, value: List[float]): + assert max(value) <= self.horizon_range[-1], f"Prediction range {value} exceeds {self.horizon_range=}" + self._prediction_timesteps = value + + def num_autoregressive_steps_for_horizon(self, horizon: int) -> int: + return max(1, math.ceil(horizon / self.true_horizon)) - 1 + + @property + def short_description(self) -> str: + name = super().short_description + name += f" (h={self.horizon_name})" + return name + + def actual_num_input_channels(self, num_input_channels: int) -> int: + # if we use the inputs as conditioning, and use an output-shaped input (e.g. for DDPM), + # we need to use the output channels here! + is_standard_diffusion = self.is_diffusion_model and "dyffusion" not in self.diffusion_config._target_.lower() + is_dyffusion = self.is_diffusion_model and "dyffusion" in self.diffusion_config._target_.lower() + if is_standard_diffusion: + return self.actual_num_output_channels(self.dims["output"]) + elif is_dyffusion: + return num_input_channels # window is used as conditioning + if self.stack_window_to_channel_dim: + return multiply_by_scalar(num_input_channels, self.window) + return num_input_channels + + def get_horizon(self, split: str, dataloader_idx: int = 0) -> int: + if self.datamodule is not None and hasattr(self.datamodule, "get_horizon"): + return self.datamodule.get_horizon(split, dataloader_idx=dataloader_idx) + self.log_text.warning(f"Using default horizon {self.horizon} for split ``{split}``.") + return self.horizon + + @property + def prediction_horizon(self) -> int: + if hasattr(self.datamodule_config, "prediction_horizon") and self.datamodule_config.prediction_horizon: + return self.datamodule_config.prediction_horizon + return self.horizon * (self.hparams.autoregressive_steps + 1) + + # def on_train_start(self) -> None: + # def on_fit_start(self) -> None: + def on_any_start(self, stage: str = None) -> None: + super().on_any_start(stage) + horizon = self.get_horizon(stage) + ar_steps = self.num_autoregressive_steps_for_horizon(horizon) + # max_horizon = horizon * (ar_steps + 1) + self.log_text.info(f"Using {ar_steps} autoregressive steps for stage ``{stage}`` with horizon={horizon}.") + + # --------------------------------- Metrics + def get_epoch_aggregators(self, split: str, dataloader_idx: int = None) -> dict: + assert split in ["val", "test", "predict"], f"Invalid split {split}" + is_inference_val = split == "val" and dataloader_idx == 1 + if is_inference_val and self.current_epoch % self.hparams.inference_val_every_n_epochs != 0: + # Skip inference on validation set for this epoch (for efficiency) + return {} + + return super().get_epoch_aggregators(split, dataloader_idx) + + @torch.inference_mode() # torch.no_grad() + def _evaluation_step( + self, + batch: Any, + batch_idx: int, + split: str, + dataloader_idx: int = None, + return_outputs: bool | str = None, + # "auto", # True = -> "preds" + "targets". False: None "all": all outputs + boundary_conditions: Callable = None, + t0: float = 0.0, + dt: float = 1.0, + aggregators: Dict[str, Callable] = None, + verbose: bool = True, + prediction_horizon: int = None, + ): + # todo: for huge horizons: load full dynamics + dynamics condition on CPU and send to GPU piece by piece + return_dict = dict() + if prediction_horizon is not None: + assert split == "predict", "Prediction horizon only to be used for split='predict'" + else: + prediction_horizon = self.get_horizon(split, dataloader_idx=dataloader_idx) + + return_outputs = return_outputs or self.hparams.return_outputs_at_evaluation + if return_outputs == "auto": + return_outputs = "all" if split == "predict" and prediction_horizon < 1500 else False + no_aggregators = aggregators is None or len(aggregators.keys()) == 0 + if not no_aggregators: + split3d_and_merge_variables_p = ( + partial(split3d_and_merge_variables, level_names=self.datamodule.hparams.pressure_levels) + if hasattr(self.datamodule.hparams, "pressure_levels") + else lambda x: x + ) + + # Get predictions mask if available (applied to preds and targets, e.g. for spatially masked predictions) + predictions_mask = batch.pop("predictions_mask", None) # pop to ensure that it's not used in model + if predictions_mask is not None: + predictions_mask = predictions_mask[0, ...] # e.g. (2, 40, 80) -> (40, 80) + + main_data_raw = batch.pop("raw_dynamics", None) # Unnormalized (raw scale) data, used to compute targets + dynamic_conds = batch.pop("dynamical_condition", None) # will be added back to batch later, piece by piece + # main_batch = batch.copy() + # Compute how many autoregressive steps to complete + if dataloader_idx is not None and dataloader_idx > 0 and no_aggregators: + self.log_text.info(f"No aggregators for {split=} {dataloader_idx=} {self.current_epoch=}") + return {} + else: + assert split in ["val", "test", "predict"] + self.test_set_names, f"Invalid split {split}" + n_outer_loops = self.num_autoregressive_steps_for_horizon(prediction_horizon) + 1 + dyn_any = main_data_raw if main_data_raw is not None else batch["dynamics"] + if dyn_any.shape[1] < prediction_horizon: + raise ValueError(f"Prediction horizon {prediction_horizon} is larger than {dyn_any.shape}[1]") + + # Remove the last part of the dynamics that is not needed for prediction inside the module/model + # dynamics = batch["dynamics"].clone() + batch["dynamics"] = batch["dynamics"][:, : self.window + self.true_horizon, ...] + + if self.is_diffusion_model and split == "val" and dataloader_idx in [0, None]: + # log validation loss + if dynamic_conds is not None: + # first window of dyn. condition + batch["dynamical_condition"] = dynamic_conds[:, : self.window + self.true_horizon] + loss = self.get_loss(batch) + if isinstance(loss, dict): + # add split/ prefix if not already there + log_dict = {f"{split}/{k}" if not k.startswith(split) else k: float(v) for k, v in loss.items()} + elif torch.is_tensor(loss): + log_dict = {f"{split}/loss": float(loss)} + self.log_dict(log_dict, on_step=False, on_epoch=True) + + # Initialize autoregressive loop + autoregressive_inputs = None + total_t = t0 + predicted_range_last = [0.0] + self.prediction_timesteps[:-1] + ar_window_steps_t = self.horizon_range[-self.window :] # autoregressive window steps (all after input window) + pbar = tqdm( + range(n_outer_loops), + desc="Autoregressive Step", + position=0, + leave=True, + disable=not self.verbose or n_outer_loops <= 1, + ) + # Loop over autoregressive steps (to cover timesteps beyond training horizon) + for ar_step in pbar: + self.print_gpu_memory_usage(tqdm_bar=pbar, empty_cache=self.hparams.empty_cache_at_autoregressive_step) + ar_window_steps = [] + # Loop over training horizon + for t_step_last, t_step in zip(predicted_range_last, self.prediction_timesteps): + total_horizon = ar_step * self.true_horizon + t_step + if total_horizon > prediction_horizon: + # May happen if we have a prediction horizon that is not a multiple of the true horizon + break + PREDS_NORMED_K = f"t{t_step}_preds_normed" + PREDS_RAW_K = f"t{t_step}_preds" + pr_kwargs = {} if autoregressive_inputs is None else {"num_predictions": 1} + if dynamic_conds is not None: # self.true_horizon=1 + # ar_step = 0 --> slice(0, H+1), ar_step = 1 --> slice(H, 2H+1), etc. + current_slice = slice(ar_step * self.true_horizon, (ar_step + 1) * self.true_horizon + 1) + batch["dynamical_condition"] = dynamic_conds[:, current_slice] + + results = self.get_preds_at_t_for_batch( + batch, t_step, split, is_autoregressive=ar_step > 0, ensemble=True, **pr_kwargs + ) + total_t += dt * (t_step - t_step_last) # update time, by default this is == dt + + if float(total_horizon).is_integer() and main_data_raw is not None: + target_time = self.window + int(total_horizon) - 1 + targets_tensor_t = main_data_raw[:, target_time, ...] + targets = self.get_target_variants(targets_tensor_t, is_normalized=False) + else: + targets = None + + targets_normed = targets["targets_normed"] if targets is not None else None + targets_raw = targets["targets"] if targets is not None else None + # Apply boundary conditions to predictions, if any + if boundary_conditions is not None: + data_t = main_data_raw[:, target_time, ...] + for k in [PREDS_NORMED_K, "preds_autoregressive_init_normed"]: + if k in results: + results[k] = boundary_conditions( + preds=results[k], + targets=targets_normed, + metadata=batch.get("metadata", None), + data=data_t, + time=total_t, + ) + preds_normed = results.pop(PREDS_NORMED_K) + if return_outputs in [True, "all"]: + return_dict[f"t{total_horizon}_targets_normed"] = torch_to_numpy(targets_normed) + return_dict[f"t{total_horizon}_preds_normed"] = torch_to_numpy(preds_normed) + elif return_outputs == "preds_only": + return_dict[f"t{total_horizon}_preds_normed"] = torch_to_numpy(preds_normed) + + if return_outputs == "all": + return_dict[f"t{total_horizon}_targets"] = torch_to_numpy(targets_raw) + return_dict.update( + {k.replace(f"t{t_step}", f"t{total_horizon}"): torch_to_numpy(v) for k, v in results.items()} + ) # update keys to total horizon (instead of relative horizon of autoregressive step) + + if t_step in ar_window_steps_t: + # if predicted_range == self.horizon_range and window == 1, then this is just the last step :) + # Need to keep the last window steps that are INTEGER steps! + ar_init = results.pop("preds_autoregressive_init_normed", preds_normed) + if self.use_ensemble_predictions(split): + ar_init = rrearrange(ar_init, "N B ... -> (N B) ...") # flatten ensemble dimension + ar_window_steps += [ar_init] # keep t,c,z,h,w + + if not float(total_horizon).is_integer(): + self.log_text.info(f"Skipping non-integer total horizon {total_horizon}") + continue + + if no_aggregators: + continue + + with self.timing_scope(context=f"aggregators_{split}", no_op=True): + assert predictions_mask is None, "Predictions mask not yet supported for aggregators" + pred_data = split3d_and_merge_variables_p(results[PREDS_RAW_K]) + target_data = split3d_and_merge_variables_p(targets_raw) + aggregators[f"t{total_horizon}"].record_batch( + target_data=target_data, + gen_data=pred_data, + target_data_norm=split3d_and_merge_variables_p(targets_normed), + gen_data_norm=split3d_and_merge_variables_p(preds_normed), + predictions_mask=predictions_mask, + ) + if "time_mean" in aggregators: + aggregators["time_mean"].record_batch( + target_data=target_data, gen_data=pred_data, predictions_mask=predictions_mask + ) + del results, targets + + if ar_step < n_outer_loops - 1: # if not last step, then update dynamics + autoregressive_inputs = torch.stack(ar_window_steps, dim=1) # shape (b, window, c, h, w) + if not torch.is_tensor(autoregressive_inputs): + # Rename keys to make clear that these are treated as inputs now + for k in list(autoregressive_inputs.keys()): + autoregressive_inputs[k.replace("preds", "inputs")] = autoregressive_inputs.pop(k) + batch["dynamics"] = autoregressive_inputs + del ar_window_steps + + self.on_autoregressive_loop_end(split, dataloader_idx=dataloader_idx) + return return_dict + + def on_autoregressive_loop_end(self, split: str, dataloader_idx: int = None, **kwargs): + pass + + def test_step(self, batch: Any, batch_idx: int, dataloader_idx: int = None, **kwargs): + return super().test_step(batch, batch_idx, dataloader_idx, **kwargs) + + def on_test_epoch_end(self, **kwargs) -> None: + return super().on_test_epoch_end(**kwargs) + + def get_preds_at_t_for_batch( + self, + batch: Dict[str, Tensor], + horizon: int | float, + split: str, + ensemble: bool = False, + is_autoregressive: bool = False, + prepare_inputs: bool = True, + **kwargs, + ) -> Dict[str, Tensor]: + b, t = batch["dynamics"].shape[0:2] # batch size, time steps + assert 0 < horizon <= self.true_horizon, f"horizon={horizon} must be in [1, {self.true_horizon}]" + + isi1 = isinstance(self, MHDYffusionAbstract) + isi2 = isinstance(self, SimultaneousMultiHorizonForecasting) + cache_preds = isi1 or isi2 + if not cache_preds or horizon == self.prediction_timesteps[0]: + if self.prediction_timesteps != self.horizon_range: + if isi1: + self.model.hparams.prediction_timesteps = [p_h for p_h in self.prediction_timesteps] + # create time tensor full of t_step, with batch size shape + if prepare_inputs: + inputs, extra_kwargs = self.get_inputs_and_extra_kwargs( + batch, time=None, split=split, is_autoregressive=is_autoregressive, ensemble=ensemble + ) + else: + inputs = batch.pop(self.inputs_data_key) + extra_kwargs = batch + + # inputs may be a repeated version of batch["dynamics"] for ensemble predictions + with torch.inference_mode(): + self._current_preds = self.predict(inputs, **extra_kwargs, **kwargs) + # for k, v, in {**self._current_preds, "dynamics": batch["dynamics"]}.items(): + # log.info(f"key={k}, shape={v.shape}, min={v.min()}, max={v.max()}, mean={v.mean()}, std={v.std()}") + + if cache_preds: + # for this model, we can cache the multi-horizon predictions + preds_key = f"t{horizon}_preds" # key for this horizon's predictions + results = {k: self._current_preds.pop(k) for k in list(self._current_preds.keys()) if preds_key in k} + if horizon == self.horizon_range[-1]: + assert all( + ["preds" not in k or "preds_autoregressive_init" in k for k in self._current_preds.keys()] + ), ( + f'{preds_key=} must be the only key containing "preds" in last prediction. ' + f"Got: {list(self._current_preds.keys())}" + ) + results = {**results, **self._current_preds} # add the rest of the results, if any + del self._current_preds + else: + results = {f"t{horizon}_{k}": v for k, v in self._current_preds.items()} + return results + + def get_inputs_from_dynamics(self, dynamics: Tensor | Dict[str, Tensor]) -> Tensor | Dict[str, Tensor]: + return dynamics[:, : self.window, ...] # (b, window, c, lat, lon) at time 0 + + def get_condition_from_dynamica_cond( + self, dynamics: Tensor | Dict[str, Tensor], **kwargs + ) -> Tensor | Dict[str, Tensor]: + dynamics_cond = self.get_inputs_from_dynamics(dynamics) + dynamics_cond = self.transform_inputs(dynamics_cond, **kwargs) + return dynamics_cond + + def transform_inputs( + self, + inputs: Tensor, + time: Tensor = None, + ensemble: bool = True, + stack_window_to_channel_dim: bool = None, + **kwargs, + ) -> Tensor: + if stack_window_to_channel_dim is None: + stack_window_to_channel_dim = self.stack_window_to_channel_dim + if stack_window_to_channel_dim: + inputs = rrearrange(inputs, "b window c ... -> b (window c) ...") + if ensemble: + inputs = self.get_ensemble_inputs(inputs, **kwargs) + return inputs + + def get_extra_model_kwargs( + self, + batch: Dict[str, Tensor], + split: str, + time: Tensor = None, + ensemble: bool = False, + is_autoregressive: bool = False, + ) -> Dict[str, Any]: + extra_kwargs = dict() + ensemble_k = ensemble and not is_autoregressive + if self.USE_TIME_AS_EXTRA_INPUT: + batch["time"] = time + for k, v in batch.items(): + if k == "dynamics": + continue + elif k == "metadata": + if self.PASS_METADATA_TO_MODEL: + extra_kwargs[k] = self.get_ensemble_inputs(v, split=split, add_noise=False) if ensemble_k else v + elif k == "predictions_mask": + extra_kwargs[k] = v[0, ...] # e.g. (2, 40, 80) -> (40, 80) + elif k in ["static_condition", "time", "lookback"]: + # Static features or time: simply add ensemble dimension and done + extra_kwargs[k] = self.get_ensemble_inputs(v, split=split, add_noise=False) if ensemble else v + elif "dynamical_condition" == k: # k in ["condition", "time_varying_condition"]: + # Time-varying features + extra_kwargs[k] = self.get_condition_from_dynamica_cond( + v, split=split, time=time, ensemble=ensemble, add_noise=False + ) + else: + raise ValueError(f"Unsupported key {k} in batch") + return extra_kwargs + + def get_inputs_and_extra_kwargs( + self, + batch: Dict[str, Tensor], + time: Tensor = None, + split: str = None, + ensemble: bool = False, + is_autoregressive: bool = False, + ) -> Tuple[Tensor, Dict[str, Any]]: + inputs = self.get_inputs_from_dynamics(batch["dynamics"]) + ensemble_inputs = ensemble and not is_autoregressive + inputs = self.pack_data(inputs, input_or_output="input") + inputs = self.transform_inputs(inputs, split=split, ensemble=ensemble_inputs) + extra_kwargs = self.get_extra_model_kwargs( + batch, split=split, time=time, ensemble=ensemble, is_autoregressive=is_autoregressive + ) + return inputs, extra_kwargs + + +class MHDYffusionAbstract(AbstractMultiHorizonForecastingExperiment): + PASS_METADATA_TO_MODEL = False + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.diffusion_config is not None, "diffusion config must be set. Use ``diffusion=``!" + assert self.diffusion_config.timesteps == self.horizon, "diffusion timesteps must be equal to horizon" + + +# This class is a subclass of MHDYffusionAbstract for multi-horizon forecasting using diffusion +# models. +class MultiHorizonForecastingDYffusion(MHDYffusionAbstract): + model: BaseDYffusion + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # Problematic when module.torch_compile="model": + # assert isinstance( + # self.model, BaseDYffusion + # ), f"Model must be an instance of BaseDYffusion, but got {type(self.model)}" + if hasattr(self.model, "interpolator"): + # self.log_text.info(f"------------------- Setting num_predictions={self.hparams.num_predictions}") + self.model.interpolator.hparams.num_predictions = self.hparams.num_predictions + self.model.interpolator.num_predictions_in_mem = self.num_predictions_in_mem + + def on_fit_start(self) -> None: + super().on_fit_start() + if hasattr(self.model, "interpolator"): + self.model.interpolator._datamodule = self.datamodule + + @property + def valid_time_range_for_backbone_model(self) -> List[int]: + return self.model.valid_time_range_for_backbone_model + + def get_condition_from_dynamica_cond( + self, dynamics: Tensor | Dict[str, Tensor], **kwargs + ) -> Tensor | Dict[str, Tensor]: + # selection of times will be handled inside src.diffusion.dyffusion + return self.transform_inputs(dynamics, stack_window_to_channel_dim=False, **kwargs) + + def get_loss(self, batch: Any) -> Tensor: + r"""Compute the loss for the given batch.""" + split = "train" if self.training else "val" + dynamics = batch["dynamics"] + x_last = dynamics[:, -1, ...] + x_last = self.pack_data(x_last, input_or_output="output") + inputs, extra_kwargs = self.get_inputs_and_extra_kwargs(batch, split=split, ensemble=False) + + loss = self.model.p_losses(input_dynamics=inputs, xt_last=x_last, **extra_kwargs) + return loss + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + # Skip loading the interpolator state_dict, as its weights are loaded in src.diffusion.dyffusion.__init__ + state_dict = {k: v for k, v in state_dict.items() if not k.startswith("model.interpolator")} + return super().load_state_dict(state_dict, strict=False) + + def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + super().on_save_checkpoint(checkpoint) + # Pop the interpolator state_dict from the checkpoint, as it is not needed + checkpoint["state_dict"] = {k: v for k, v in checkpoint["state_dict"].items() if "model.interpolator" not in k} + + +class AbstractSimultaneousMultiHorizonForecastingModule(AbstractMultiHorizonForecastingExperiment): + _horizon_at_once: int = None + + def __init__(self, horizon_at_once: int = None, autoregressive_loss_weights: Sequence[float] = None, **kwargs): + """Simultaneous multi-horizon forecasting module. + + Args: + horizon_at_once (int, optional): Number of time steps to forecast at once. Defaults to None. + If None, then the full horizon is forecasted at once. + Otherwise, only ``horizon_at_once`` time steps are forecasted at once and trained autoregressively until the full horizon is reached. + """ + super().__init__(**kwargs) + self.autoregressive_train_steps = self.horizon // self.horizon_at_once + if self.autoregressive_train_steps > 1: + self.log_text.info( + f"Training autoregressively for {self.autoregressive_train_steps} steps with horizon_at_once={self.horizon_at_once}" + ) + if autoregressive_loss_weights is None: + autoregressive_loss_weights = [ + 1.0 / self.autoregressive_train_steps for _ in range(self.autoregressive_train_steps) + ] + assert ( + len(autoregressive_loss_weights) == self.autoregressive_train_steps + ), f"Expected {self.autoregressive_train_steps} autoregressive loss weights, but got {len(autoregressive_loss_weights)}" + self.autoregressive_loss_weights = autoregressive_loss_weights + + if self.stack_window_to_channel_dim: + # Need to reshape the predictions to (b, t, c, h, w), where t = num_time_steps predicted + # if self.horizon_at_once > 1: + self.targets_pre_process = partial(rrearrange, pattern="b t c ... -> b (t c) ...", t=self.horizon_at_once) + # else: + # self.targets_pre_process = lambda x: x + self.predictions_post_process = partial( + rrearrange, pattern="b (t c) ... -> b t c ...", t=self.horizon_at_once + ) + else: + self.predictions_post_process = self.targets_pre_process = None + + @property + def horizon_at_once(self) -> int: + if self._horizon_at_once is None: + self._horizon_at_once = self.hparams.horizon_at_once or self.horizon + assert self.horizon % self.horizon_at_once == 0, "horizon must be divisible by horizon_at_once" + return self._horizon_at_once + + @property + def true_horizon(self) -> int: + return self.horizon_at_once + + @property + def horizon_range(self) -> List[int]: + return list(range(1, self.horizon_at_once + 1)) + + def actual_num_output_channels(self, num_output_channels: int) -> int: + num_output_channels = super().actual_num_output_channels(num_output_channels) + if self.stack_window_to_channel_dim: + return multiply_by_scalar(num_output_channels, self.horizon_at_once) + return num_output_channels + + def reshape_predictions(self, results: TensorDict) -> TensorDict: + """Reshape and unpack the predictions from the model. This modifies the input dictionary in-place. + Args: + results (Dict[str, Tensor]): The model outputs. Access the predictions via results['preds']. + """ + # reshape predictions to (b, t, c, h, w), where t = num_time_steps predicted + # ``b`` corresponds to the batch dimension and potentially the ensemble dimension + results["preds"] = self.predictions_post_process(results["preds"]) + # for k in list(results.keys()): + # results[k] = rrearrange(results[k], "b (t c) ... -> b t c ...", t=self.horizon) + # if isinstance(results, TensorDictBase): + # results.batch_size = [*results.batch_size, self.horizon] + return super().reshape_predictions(results) + + def unpack_predictions(self, results: Dict[str, Tensor]) -> Dict[str, Tensor]: + """Unpack the predictions from the model. This modifies the input dictionary in-place. + Args: + results (Dict[str, Tensor]): The model outputs. Access the predictions via results['preds']. + """ + horizon_dim = 1 if self.num_predictions == 1 else 2 # self.CHANNEL_DIM - 1 # == -4 + preds = results.pop("preds") + assert ( + preds.shape[horizon_dim] == self.horizon_at_once + ), f"Expected {preds.shape=} with dim {horizon_dim}={self.horizon_at_once}" + for h in self.horizon_range: + results[f"t{h}_preds"] = torch_select(preds, dim=horizon_dim, index=h - 1) + # th_pred.shape = (E, B, C, H, W); E = ensemble, B = batch, C = channels, H = height, W = width + return super().unpack_predictions(results) + + +class SimultaneousMultiHorizonForecasting(AbstractSimultaneousMultiHorizonForecastingModule): + def __init__(self, timestep_loss_weights: Sequence[float] = None, **kwargs): + super().__init__(**kwargs) + self.save_hyperparameters(ignore=["model", "timestep_loss_weights"]) + + if timestep_loss_weights is None: + timestep_loss_weights = [1.0 / self.horizon_at_once for _ in range(self.horizon_at_once)] + self.timestep_loss_weights = timestep_loss_weights + + def get_loss(self, batch: Any) -> Tensor: + r"""Compute the loss for the given batch.""" + dynamics = batch["dynamics"] + split = "train" if self.training else "val" + inputs, extra_kwargs = self.get_inputs_and_extra_kwargs(batch, split=split, ensemble=False) + + losses = dict(loss=0.0) + for ar_step in range(self.autoregressive_train_steps): + offset_left = self.window + self.horizon_at_once * ar_step + offset_right = self.window + self.horizon_at_once * (ar_step + 1) + targets = dynamics[:, offset_left:offset_right, ...] + targets = self.pack_data(targets, input_or_output="output") + # if self.stack_window_to_channel_dim: + # =========== THE BELOW GIVES TERRIBLE LOSS CURVES ========================== + # DO NOT DO THIS: targets = rrearrange(targets, "b t c ... -> b (t c) ...") | + # =========================================================================== + # targets = self.targets_pre_process(targets) # This will still do it, but only if t > 1 + loss_ar_i, preds = self.model.get_loss( + inputs=inputs, + targets=targets, + return_predictions=True, + predictions_post_process=self.predictions_post_process, + targets_pre_process=self.targets_pre_process, + **extra_kwargs, + ) + if isinstance(loss_ar_i, dict): + losses["loss"] += loss_ar_i.pop("loss") * self.autoregressive_loss_weights[ar_step] + for k, v in loss_ar_i.items(): + k_ar = f"{k}_ar{ar_step}" if ar_step > 0 else k + losses[k_ar] = float(v) + else: + losses["loss"] += loss_ar_i * self.autoregressive_loss_weights[ar_step] + + if ar_step < self.autoregressive_train_steps - 1: + if isinstance(preds, dict): + # log.info(f"inputs.shape={inputs.shape}, preds.shape={preds['preds'].shape}") + inputs = preds.pop("preds") # use the predictions as inputs for the next autoregressive step + for k, v in preds.items(): + # log.info(f"Adding {k} to loss_ar_i, shape={v.shape}, before: {extra_kwargs.get(k).shape}") + extra_kwargs[k] = v # overwrite other kwargs for the next step + else: + inputs = preds + inputs = inputs[:, -self.window :, ...].squeeze(1) # keep only the last window steps + + return losses + + +def infer_class_from_ckpt(ckpt_path: str, state=None) -> Type[AbstractMultiHorizonForecastingExperiment]: + """Infer the experiment class from the checkpoint path.""" + ckpt = torch.load(ckpt_path, map_location="cpu") if state is None else state + module_config = ckpt["hyper_parameters"] + abstract_kwargs = inspect.signature(AbstractMultiHorizonForecastingExperiment).parameters + base_kwargs = {k: v for k, v in module_config.items() if k not in abstract_kwargs} + diffusion_cfg = module_config["diffusion_config"] + if diffusion_cfg is not None: + if "dyffusion" in diffusion_cfg.get("_target_", ""): + return MultiHorizonForecastingDYffusion + return SimultaneousMultiHorizonForecasting + elif "timestep_loss_weights" in base_kwargs.keys(): + return SimultaneousMultiHorizonForecasting + else: + raise ValueError(f"Could not infer class from {ckpt_path=}") diff --git a/src/experiment_types/interpolation.py b/src/experiment_types/interpolation.py new file mode 100644 index 0000000..a377e36 --- /dev/null +++ b/src/experiment_types/interpolation.py @@ -0,0 +1,183 @@ +from typing import Any, Callable, Dict, List, Optional, Union + +import numpy as np +import torch +from torch import Tensor + +from src.experiment_types._base_experiment import BaseExperiment +from src.utilities.utils import ( + rrearrange, +) + + +class InterpolationExperiment(BaseExperiment): + r"""Base class for all interpolation experiments.""" + + def __init__(self, stack_window_to_channel_dim: bool = True, inference_val_every_n_epochs=None, **kwargs): + super().__init__(**kwargs) + if inference_val_every_n_epochs is not None: + self.log_text.warning("``inference_val_every_n_epochs`` will be ignored for interpolation experiments.") + # The following saves all the args that are passed to the constructor to self.hparams + # e.g. access them with self.hparams.hidden_dims + self.save_hyperparameters(ignore=["model"]) + assert self.horizon >= 2, "horizon must be >=2 for interpolation experiments" + if hasattr(self.model, "set_min_max_time"): + self.model.set_min_max_time(min_time=self.horizon_range[0], max_time=self.horizon_range[-1]) + + @property + def horizon_range(self) -> List[int]: + # h = horizon + # We use timesteps w-l+1, ..., w-1, w+h to predict timesteps w, ..., w+h-1 + # interpolate between step t=0 and t=horizon + return list(np.arange(1, self.horizon)) + + @property + def true_horizon(self) -> int: + return self.horizon + + @property + def horizon_name(self) -> str: + s = f"{self.true_horizon}h" + return s + + @property + def short_description(self) -> str: + name = super().short_description + name += f" (h={self.horizon_name})" + return name + + @property + def WANDB_LAST_SEP(self) -> str: + return "/" # /ipol/" + + @property + def num_conditional_channels(self) -> int: + """The number of channels that are used for conditioning as auxiliary inputs.""" + nc = super().num_conditional_channels + factor = self.window + 0 + 0 # num inputs before target + num targets + num inputs after target + return nc * factor + + def actual_num_input_channels(self, num_input_channels: int) -> int: + if self.hparams.stack_window_to_channel_dim: + return num_input_channels * self.window + num_input_channels + return 2 * num_input_channels # inputs and targets are concatenated + + def postprocess_inputs(self, inputs): + inputs = self.pack_data(inputs, input_or_output="input") + if self.hparams.stack_window_to_channel_dim: # and inputs.shape[1] == self.window: + inputs = rrearrange(inputs, "b window c lat lon -> b (window c) lat lon") + return inputs + + @torch.inference_mode() + def _evaluation_step( + self, + batch: Any, + batch_idx: int, + split: str, + dataloader_idx: int = None, + aggregators: Dict[str, Callable] = None, + return_only_preds_and_targets: bool = False, + ): + no_aggregators = aggregators is None or len(aggregators.keys()) == 0 + main_data_raw = batch.pop("raw_dynamics") + dynamics = batch["dynamics"] # dynamics is a (b, t, c, h, w) tensor + + return_dict = dict() + extra_kwargs = {} + dynamical_cond = batch.pop("dynamical_condition", None) + if dynamical_cond is not None: + assert "condition" not in batch, "condition should not be in batch if dynamical_condition is present" + inputs = self.get_evaluation_inputs(dynamics, split=split) + for k, v in batch.items(): + if k != "dynamics": + extra_kwargs[k] = self.get_ensemble_inputs(v, split=split, add_noise=False) + + for t_step in self.horizon_range: + # dynamics[, self.window] is already the first target frame (t_step=1) + target_time = self.window + t_step - 1 + time = torch.full((inputs.shape[0],), t_step, device=self.device, dtype=torch.long) + if dynamical_cond is not None: + extra_kwargs["condition"] = self.get_ensemble_inputs( + self.get_dynamical_condition(dynamical_cond, target_time), split=split, add_noise=False + ) + results = self.predict(inputs, time=time, **extra_kwargs) + preds = results["preds"] + + targets_tensor_t = main_data_raw[:, target_time, ...] + targets = self.get_target_variants(targets_tensor_t, is_normalized=False) + results["targets"] = targets + results = {f"t{t_step}_{k}": v for k, v in results.items()} + + if return_only_preds_and_targets: + return_dict[f"t{t_step}_preds"] = preds + return_dict[f"t{t_step}_targets"] = targets + else: + return_dict = {**return_dict, **results} + + if no_aggregators: + continue + + PREDS_NORMED_K = f"t{t_step}_preds_normed" + PREDS_RAW_K = f"t{t_step}_preds" + targets_normed = targets["targets_normed"] if targets is not None else None + targets_raw = targets["targets"] if targets is not None else None + aggregators[f"t{t_step}"].record_batch( + target_data=targets_raw, + gen_data=results[PREDS_RAW_K], + target_data_norm=targets_normed, + gen_data_norm=results[PREDS_NORMED_K], + ) + + return return_dict + + def get_dynamical_condition( + self, dynamical_condition: Optional[Tensor], target_time: Union[int, Tensor] + ) -> Tensor: + if dynamical_condition is not None: + if isinstance(target_time, int): + return dynamical_condition[:, target_time, ...] + else: + return dynamical_condition[torch.arange(dynamical_condition.shape[0]), target_time.long(), ...] + return None + + def get_inputs_from_dynamics(self, dynamics: Tensor, **kwargs) -> Tensor: + """Get the inputs from the dynamics tensor. + Since we are doing interpolation, this consists of the first window frames plus the last frame. + """ + past_steps = dynamics[:, : self.window, ...] # (b, window, c, lat, lon) at time 0 + last_step = dynamics[:, -1:, ...] # (b, c, lat, lon) at time t=window+horizon + past_steps = self.postprocess_inputs(past_steps) + last_step = self.postprocess_inputs(last_step) + inputs = torch.cat([past_steps, last_step], dim=1) # (b, window*c + c, lat, lon) + return inputs + + def get_evaluation_inputs(self, dynamics: Tensor, split: str, **kwargs) -> Tensor: + inputs = self.get_inputs_from_dynamics(dynamics) + inputs = self.get_ensemble_inputs(inputs, split) + return inputs + + # --------------------------------- Training + def get_loss(self, batch: Any, optimizer_idx: int = 0) -> Tensor: + r"""Compute the loss for the given batch.""" + dynamics = batch["dynamics"] # dynamics is a (b, t, c, h, w) tensor + inputs = self.get_inputs_from_dynamics(dynamics) # (b, c, h, w) at time 0 + b = dynamics.shape[0] + + possible_times = torch.tensor(self.horizon_range, device=self.device, dtype=torch.long) # (h,) + # take random choice of time + t = possible_times[torch.randint(len(possible_times), (b,), device=self.device, dtype=torch.long)] # (b,) + target_time = self.window + t - 1 + # t = torch.randint(start_t, max_t, (b,), device=self.device, dtype=torch.long) # (b,) + targets = dynamics[torch.arange(b), target_time, ...] # (b, c, h, w) + targets = self.pack_data(targets, input_or_output="output") + # We use timesteps w-l+1, ..., w-1, w+h to predict timesteps w, ..., w+h-1 + # so t=0 corresponds to interpolating w, t=1 to w+1, ..., t=h-1 to w+h-1 + + loss = self.model.get_loss( + inputs=inputs, + targets=targets, + condition=self.get_dynamical_condition(batch.pop("dynamical_condition", None), target_time=target_time), + time=t, + **{k: v for k, v in batch.items() if k != "dynamics"}, + ) # function of BaseModel or BaseDiffusion classes + return loss diff --git a/src/interface.py b/src/interface.py new file mode 100644 index 0000000..fe9dfe2 --- /dev/null +++ b/src/interface.py @@ -0,0 +1,313 @@ +from __future__ import annotations + +import os +from typing import Any, Dict, List, Optional, Union + +import hydra +import pytorch_lightning +import torch +from omegaconf import DictConfig, OmegaConf + +from src.datamodules.abstract_datamodule import BaseDataModule +from src.experiment_types._base_experiment import BaseExperiment +from src.utilities.checkpointing import local_path_to_absolute_and_download_if_needed +from src.utilities.utils import ( + get_logger, + rename_state_dict_keys_and_save, +) + + +""" +In this file you can find helper functions to avoid model/data loading and reloading boilerplate code +""" + +log = get_logger(__name__) + + +def get_lightning_module(config: DictConfig, **kwargs) -> BaseExperiment: + r"""Get the ML model, a subclass of :class:`~src.experiment_types._base_experiment.BaseExperiment`, as defined by the key value pairs in ``config.model``. + + Args: + config (DictConfig): A OmegaConf config (e.g. produced by hydra .yaml file parsing) + **kwargs: Any additional keyword arguments for the model class (overrides any key in config, if present) + + Returns: + BaseExperiment: + The lightning module that you can directly use to train with pytorch-lightning + + Examples: + + .. code-block:: python + + from src.utilities.config_utils import get_config_from_hydra_compose_overrides + + config_mlp = get_config_from_hydra_compose_overrides(overrides=['model=mlp']) + mlp_model = get_model(config_mlp) + + # Get a prediction for a (B, S, C) shaped input + random_mlp_input = torch.randn(1, 100, 5) + random_prediction = mlp_model.predict(random_mlp_input) + """ + model = hydra.utils.instantiate( + config.module, + model_config=config.model, + datamodule_config=config.datamodule, + diffusion_config=config.get("diffusion", default_value=None), + _recursive_=False, + **kwargs, + ) + + return model + + +def get_datamodule(config: DictConfig) -> BaseDataModule: + r"""Get the datamodule, as defined by the key value pairs in ``config.datamodule``. A datamodule defines the data-loading logic as well as data related (hyper-)parameters like the batch size, number of workers, etc. + + Args: + config (DictConfig): A OmegaConf config (e.g. produced by hydra .yaml file parsing) + + Returns: + Base_DataModule: + A datamodule that you can directly use to train pytorch-lightning models + + Examples: + + .. code-block:: python + + from src.utilities.config_utils import get_config_from_hydra_compose_overrides + + cfg = get_config_from_hydra_compose_overrides(overrides=['datamodule=icosahedron', 'datamodule.order=5']) + ico_dm = get_datamodule(cfg) + """ + data_module = hydra.utils.instantiate( + config.datamodule, + _recursive_=False, + model_config=config.model, + ) + return data_module + + +def get_model_and_data(config: DictConfig) -> (BaseExperiment, BaseDataModule): + r"""Get the model and datamodule. This is a convenience function that wraps around :meth:`get_model` and :meth:`get_datamodule`. + + Args: + config (DictConfig): A OmegaConf config (e.g. produced by hydra .yaml file parsing) + + Returns: + (BaseExperiment, Base_DataModule): A tuple of (module, datamodule), that you can directly use to train with pytorch-lightning + + Examples: + + .. code-block:: python + + from src.utilities.config_utils import get_config_from_hydra_compose_overrides + + cfg = get_config_from_hydra_compose_overrides(overrides=['datamodule=icosahedron', 'model=mlp']) + mlp_model, icosahedron_data = get_model_and_data(cfg) + + # Use the data from datamodule (its ``train_dataloader()``), to train the model for 10 epochs + trainer = pl.Trainer(max_epochs=10, devices=1) + trainer.fit(model=model, datamodule=icosahedron_data) + + """ + data_module = get_datamodule(config) + model = get_lightning_module(config) + if config.module.get("torch_compile") == "module": + log.info("Compiling LightningModule with torch.compile()...") + model = torch.compile(model) + return model, data_module + + +def reload_model_from_config_and_ckpt( + config: DictConfig, + model_path: str, + device: Optional[torch.device] = None, + also_datamodule: bool = True, + also_ckpt: bool = False, + **kwargs, +) -> Dict[str, Any]: + r"""Load a model as defined by ``config.model`` and reload its weights from ``model_path``. + + Args: + config (DictConfig): The config to use to reload the model + model_path (str): The path to the model checkpoint (its weights) + device (torch.device): The device to load the model on. Defaults to 'cuda' if available, else 'cpu'. + also_datamodule (bool): If True, also reload the datamodule from the config. Defaults to True. + also_ckpt (bool): If True, also returns the checkpoint from ``model_path``. Defaults to False. + + Returns: + BaseModel: The reloaded model if load_datamodule is ``False``, otherwise a tuple of (reloaded-model, datamodule) + + Examples: + + .. code-block:: python + + # If you used wandb to save the model, you can use the following to reload it + from src.utilities.wandb_api import load_hydra_config_from_wandb + + run_path = ENTITY/PROJECT/RUN_ID # wandb run id (you can find it on the wandb URL after runs/, e.g. 1f5ehvll) + config = load_hydra_config_from_wandb(run_path, override_kwargs=['datamodule.num_workers=4', 'trainer.gpus=-1']) + + model, datamodule = reload_model_from_config_and_ckpt(config, model_path, load_datamodule=True) + + # Test the reloaded model + trainer = hydra.utils.instantiate(config.trainer, _recursive_=False) + trainer.test(model=model, datamodule=datamodule) + + """ + model, data_module = get_model_and_data(config) if also_datamodule else (get_lightning_module(config), None) + # Reload model + device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu") + model_state = torch.load(model_path, map_location=device, weights_only=False) + # rename weights (sometimes needed for backwards compatibility) + state_dict = rename_state_dict_keys_and_save(model_state, model_path) + # Reload weights + # remove all keys with model.interpolator prefix + # state_dict = {k: v for k, v in state_dict.items() if not k.startswith("model.interpolator")} + model.load_state_dict(state_dict, strict=False) + + to_return = { + "model": model, + "datamodule": data_module, + "epoch": model_state["epoch"], + "global_step": model_state["global_step"], + "wandb": model_state.get("wandb", None), + } + file_size = os.path.getsize(model_path) + str_to_print = ( + f"Reloaded {model_path}." + f" Epoch={model_state['epoch']}." + f" Global_step={model_state['global_step']}." + f" File size [in MB]: {file_size / 1e6:.2f}" + ) + if model_state.get("wandb") is not None: + str_to_print += f"\nRun ID: {model_state['wandb']['id']}\t Name: {model_state['wandb']['name']}" + log.info(str_to_print) + if also_ckpt: + to_return["ckpt"] = model_state + return to_return + + +def get_checkpoint_from_path_or_wandb( + model_checkpoint: Optional[torch.nn.Module] = None, + model_checkpoint_path: Optional[str] = None, + config_path: Optional[str] = None, + wandb_run_id: Optional[str] = None, + model_name: Optional[str] = "model", + wandb_kwargs: Optional[Dict[str, Any]] = None, + model_overrides: Optional[List[str]] = None, +) -> torch.nn.Module: + if model_checkpoint is not None: + assert model_checkpoint_path is None, "must provide either model_checkpoint or model_checkpoint_path" + assert wandb_run_id is None, "must provide either model_checkpoint or wandb_run_id" + model = model_checkpoint + elif wandb_run_id is not None: + # assert model_checkpoint_path is None, 'must provide either wandb_run_path or model_checkpoint_path' + override_key_value = model_overrides or [] + override_key_value += ["module.verbose=False"] + wandb_kwargs = wandb_kwargs or {} + model = reload_checkpoint_from_wandb( + run_id=wandb_run_id, + also_datamodule=False, + override_key_value=override_key_value, + local_checkpoint_path=model_checkpoint_path, + config_path=config_path, + **wandb_kwargs, + )["model"] + else: + raise ValueError("Provide either model_checkpoint, model_checkpoint_path or wandb_run_id") + return model + + +def reload_checkpoint_from_wandb( + run_id: str, + entity: str = None, + project: str = None, + config_path: Optional[str] = None, + ckpt_filename: Optional[str] = None, + epoch: Union[str, int] = "best", + override_key_value: List[str] = None, + local_checkpoint_path: str = None, + **reload_kwargs, +) -> dict: + """ + Reload model checkpoint based on only the Wandb run ID + + Args: + run_id (str): the wandb run ID (e.g. 2r0l33yc) corresponding to the model to-be-reloaded + entity (str): the wandb entity corresponding to the model to-be-reloaded + project (str): the project entity corresponding to the model to-be-reloaded + config_path (str): the path to the config file to be used to reload the model. + If None, the config is loaded from Wandb + ckpt_filename (str): the filename of the checkpoint to be reloaded (e.g. 'last.ckpt') + epoch (str or int): If 'best', the reloaded model will be the best one stored, if 'last' the latest one stored), + if an int, the reloaded model will be the one save at that epoch (if it was saved, otherwise an error is thrown) + override_key_value: each element is expected to have a "=" in it, like datamodule.num_workers=8 + local_checkpoint_path (str): If not None, the path to the local checkpoint to be reloaded. + """ + import src.utilities.wandb_api as wandb_api + + entity, project = wandb_api.get_entity(entity), project or wandb_api.get_project_train() + run_id = str(run_id).strip() + run_path = f"{entity}/{project}/{run_id}" + + # Reload config + if config_path is not None: + config_path = local_path_to_absolute_and_download_if_needed(config_path) + config = OmegaConf.load(config_path) + config = OmegaConf.unsafe_merge(config, OmegaConf.from_dotlist(override_key_value)) + else: + config = wandb_api.load_hydra_config_from_wandb(run_path, override_key_value=override_key_value) + + # Find or download the checkpoint + ckpt_path = wandb_api.restore_model_from_wandb_cloud( + run_path, + local_checkpoint_path=local_checkpoint_path, + epoch=epoch, + ckpt_filename=ckpt_filename, + throw_error_if_local_not_found=False, + config=config, + ) + + assert os.path.isfile(ckpt_path), f"Could not find {ckpt_path=} in {os.getcwd()}" + assert str(config.logger.wandb.id) == str(run_id), f"{config.logger.wandb.id=} != {run_id=}." + # Instantiate model and reload its weights + try: + reloaded_model_data = reload_model_from_config_and_ckpt(config, ckpt_path, **reload_kwargs) + except RuntimeError as e: + rank = os.environ.get("RANK", None) or os.environ.get("LOCAL_RANK", 0) + raise RuntimeError( + f"[rank: {rank}] You may have changed the model code, making it incompatible with older model " + f"versions. Tried to reload the model ckpt for run.id={run_id} from {ckpt_path}.\n" + f"config.model={config.model}" + ) from e + if reloaded_model_data.get("wandb") is not None: + if reloaded_model_data["wandb"].get("id") != run_id: + raise ValueError(f"run_id={run_id} != state_dict['wandb']['id']={reloaded_model_data['wandb']['id']}") + # config.trainer.resume_from_checkpoint = ckpt_path + # os.remove(ckpt_path) if os.path.exists(ckpt_path) else None # delete the downloaded ckpt + return {**reloaded_model_data, "config": config, "ckpt_path": ckpt_path} + + +def get_simple_trainer(**kwargs) -> pytorch_lightning.Trainer: + devices = kwargs.get("devices", 1 if torch.cuda.is_available() else None) + accelerator = kwargs.get("accelerator", "gpu" if torch.cuda.is_available() else None) + return pytorch_lightning.Trainer( + devices=devices, + accelerator=accelerator, + **kwargs, + ) + + +def run_inference( + module: pytorch_lightning.LightningModule, + datamodule: pytorch_lightning.LightningDataModule, + trainer: pytorch_lightning.Trainer = None, + trainer_kwargs: Dict[str, Any] = None, +): + trainer = trainer or get_simple_trainer(**(trainer_kwargs or {})) + results = trainer.predict(module, datamodule=datamodule) + results = module._evaluation_get_preds(results, split="predict") + if hasattr(datamodule, "numpy_results_to_xr_dataset"): + results = datamodule.numpy_results_to_xr_dataset(results, split="predict") + return results diff --git a/src/losses/__init__.py b/src/losses/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/losses/losses.py b/src/losses/losses.py new file mode 100644 index 0000000..22da061 --- /dev/null +++ b/src/losses/losses.py @@ -0,0 +1,79 @@ +from functools import partial +from typing import Iterable, Optional, Union + +import torch +import torch.nn as nn +from torch import Tensor + +from src.evaluation.metrics import weighted_mean +from src.utilities.utils import get_logger + + +log = get_logger(__name__) + + +class LpLoss(torch.nn.Module): + def __init__( + self, + p=2, + relative: bool = True, + weights: Optional[Tensor] = None, + weighted_dims: Union[int, Iterable[int]] = (), + ): + """ + Args: + p: Lp-norm type. For example, p=1 for L1-norm, p=2 for L2-norm. + relative: If True, compute the relative Lp-norm, i.e. ||x - y||_p / ||y||_p. + """ + super(LpLoss, self).__init__() + + if p <= 0: + raise ValueError("Lp-norm type should be positive") + + self.p = p + self.loss_func = self.rel if relative else self.abs + self.weights = weights + + @property + def weights(self): + return self._weights + + @weights.setter + def weights(self, weights): + self._weights = weights + if weights is not None: + self.mean_func = partial(weighted_mean, weights=weights) + else: + self.mean_func = torch.mean + + def rel(self, x, y): + num_examples = x.size()[0] + diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) + y_norms = torch.norm(y.reshape(num_examples, -1), self.p, 1) + + # print(diff_norms.shape, y_norms.shape, self.mean_func) + return self.mean_func(diff_norms / y_norms) + + def abs(self, x, y): + num_examples = x.size()[0] + diff_norms = torch.norm(x.reshape(num_examples, -1) - y.reshape(num_examples, -1), self.p, 1) + return self.mean_func(diff_norms) + + def __call__(self, x, y): + return self.loss_func(x, y) + + +def get_loss(name, reduction="mean", **kwargs): + """Returns the loss function with the given name.""" + name = name.lower().strip().replace("-", "_") + if name in ["l1", "mae", "mean_absolute_error"]: + loss = nn.L1Loss(reduction=reduction, **kwargs) + elif name in ["l2", "mse", "mean_squared_error"]: + loss = nn.MSELoss(reduction=reduction, **kwargs) + elif name in ["l2_rel"]: + loss = LpLoss(p=2, relative=True, **kwargs) + elif name in ["l1_rel"]: + loss = LpLoss(p=1, relative=True, **kwargs) + else: + raise ValueError(f"Unknown loss function {name}") + return loss diff --git a/src/models/__init__.py b/src/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/_base_model.py b/src/models/_base_model.py new file mode 100644 index 0000000..40c53cf --- /dev/null +++ b/src/models/_base_model.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +import logging +from contextlib import contextmanager +from typing import Any, Callable, Dict, Optional, Sequence, Tuple, Union + +import hydra +import numpy as np +import torch +import xarray as xr +from omegaconf import DictConfig +from pytorch_lightning import LightningModule +from torch import Tensor + +from src.losses.losses import get_loss +from src.utilities.utils import ( + disable_inference_dropout, + enable_inference_dropout, + get_logger, +) + + +class BaseModel(LightningModule): + r"""This is a template base class, that should be inherited by any stand-alone ML model. + Methods that need to be implemented by your concrete ML model (just as if you would define a :class:`torch.nn.Module`): + - :func:`__init__` + - :func:`forward` + + The other methods may be overridden as needed. + It is recommended to define the attribute + >>> self.example_input_array = torch.randn() # batch dimension can be anything, e.g. 7 + + + .. note:: + Please use the function :func:`predict` at inference time for a given input tensor, as it postprocesses the + raw predictions from the function :func:`raw_predict` (or model.forward or model())! + + Args: + name (str): optional string with a name for the model + verbose (bool): Whether to print/log or not + + Read the docs regarding LightningModule for more information: + https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html + """ + + def __init__( + self, + num_input_channels: int = None, + num_output_channels: int = None, + num_output_channels_raw: int = None, # actual channels. output_channels may be larger when stacking dims + num_conditional_channels: int = 0, + spatial_shape_in: Union[Sequence[int], int] = None, + spatial_shape_out: Union[Sequence[int], int] = None, + loss_function: str = "mean_squared_error", + loss_function_weights: Optional[Dict[str, float]] = None, + datamodule_config: Optional[DictConfig] = None, + debug_mode: bool = False, + name: str = "", + verbose: bool = True, + ): + super().__init__() + # The following saves all the args that are passed to the constructor to self.hparams + # e.g. access them with self.hparams.monitor + self.save_hyperparameters(ignore=["verbose", "model"]) + # Get a logger + self.log_text = get_logger(name=self.__class__.__name__ if name == "" else name) + self.name = name + self.verbose = verbose + if not self.verbose: # turn off info level logging + self.log_text.setLevel(logging.WARN) + + self.num_input_channels = num_input_channels + self.num_output_channels = num_output_channels + self.num_output_channels_raw = num_output_channels_raw + self.num_conditional_channels = num_conditional_channels + self.spatial_shape_in = spatial_shape_in + self.spatial_shape_out = spatial_shape_out + self.datamodule_config = datamodule_config + + if loss_function is not None: + # Get the loss function + loss_function_name = ( + loss_function if isinstance(loss_function, str) else loss_function.get("_target_", "").split(".")[-1] + ) + self.loss_function_name = loss_function_name.lower() + self.loss_function_weights = loss_function_weights if loss_function_weights is not None else {} + for k in self.loss_function_weights.keys(): + assert k in ["preds"], f"Invalid loss function key: {k}" + + criterion = self.get_loss_callable() + print_text = ( + f"Criterion: {criterion} with weights: {self.loss_function_weights}" + if loss_function_weights + else f"Criterion: {criterion}" + ) + self.log_text.info(print_text) + # Using a dictionary for the criterion, so that we can have multiple loss functions if needed + if isinstance(criterion, torch.nn.ModuleDict): + self.criterion = criterion + elif isinstance(criterion, dict): + if any(isinstance(v, torch.nn.Module) for v in criterion.values()): + self.criterion = torch.nn.ModuleDict(criterion) + else: + self.criterion = criterion + elif isinstance(criterion, torch.nn.Module): + self.criterion = torch.nn.ModuleDict({"preds": criterion}) + else: + self.criterion = {"preds": criterion} + + self._channel_dim = None + self.ema_scope = None # EMA scope for the model. May be set by the BaseExperiment instance + # self._parent_module = None # BaseExperiment instance (only needed for edge cases) + + @property + def short_description(self) -> str: + return self.name if self.name else self.__class__.__name__ + + def get_parameters(self) -> list: + """Return the parameters for the optimizer.""" + return list(self.parameters()) + + def _get_loss_callable_from_name_or_config(self, loss_function: str, **kwargs): + """Return the loss function""" + if isinstance(loss_function, str): + loss = get_loss(loss_function, **kwargs) + elif isinstance(loss_function, dict): + loss = {k: get_loss(v, **kwargs) for k, v in loss_function.items()} + else: + loss = hydra.utils.instantiate(loss_function) + return loss + + def get_loss_callable(self, reduction: str = "mean", **kwargs): + """Return the loss function""" + loss_function = self.hparams.loss_function + loss = self._get_loss_callable_from_name_or_config(loss_function, reduction=reduction, **kwargs) + return loss + + @property + def num_params(self): + """Returns the number of parameters in the model""" + return sum(p.numel() for p in self.get_parameters() if p.requires_grad) + + @property + def channel_dim(self): + if self._channel_dim is None: + self._channel_dim = 1 + return self._channel_dim + + def evaluation_results_to_xarray(self, results: Dict[str, np.ndarray], **kwargs) -> Dict[str, xr.DataArray]: + """Convert the evaluation results to a xarray dataset""" + raise NotImplementedError(f"Please implement ``evaluation_results_to_xarray`` for {self.__class__.__name__}") + + def forward(self, X: Tensor, condition: Tensor = None, **kwargs): + r"""Standard ML model forward pass (to be implemented by the specific ML model). + + Args: + X (Tensor): Input data tensor of shape :math:`(B, *, C_{in})` + Shapes: + - Input: :math:`(B, *, C_{in})`, + + where :math:`B` is the batch size, :math:`*` is the spatial dimension(s) of the data, + and :math:`C_{in}` is the number of input features/channels. + """ + raise NotImplementedError("Base model is an abstract class!") + + def concat_condition_if_needed(self, inputs: Tensor, condition: Tensor = None, static_condition: Tensor = None): + if self.num_conditional_channels > 0: + # exactly one of condition or static_condition should be not None + if condition is None and static_condition is None: + raise ValueError( + f"condition and static_condition are both None but num_conditional_channels is {self.num_conditional_channels}" + ) + elif condition is not None and static_condition is not None: + condition = torch.cat((condition, static_condition), dim=1) + elif condition is None: + assert static_condition is not None, "condition and static_condition are both None" + condition = static_condition + else: + assert static_condition is None, "condition and static_condition are both not None" + + if hasattr(self, "upsample_condition"): + condition = self.upsample_condition(condition) + try: + # log.info(f"{inputs.shape=}, {condition.shape=}") + x = torch.cat((inputs, condition), dim=1) + except RuntimeError as e: + raise RuntimeError(f"inputs.shape: {inputs.shape}, condition.shape: {condition.shape}") from e + else: + x = inputs + assert condition is None, "condition is not None but num_conditional_channels is 0" + assert static_condition is None, "static_condition is not None but num_conditional_channels is 0" + return x + + def get_loss( + self, + inputs: Tensor, + targets: Tensor, + raw_targets: Tensor = None, + condition: Tensor = None, + metadata: Any = None, + predictions_mask: Optional[Tensor] = None, + # targets_mask: Optional[Tensor] = None, + return_predictions: bool = False, + predictions_post_process: Optional[Callable] = None, + targets_pre_process: Optional[Callable] = None, + **kwargs, + ) -> Union[Tensor, Tuple[Tensor, Tensor]]: + """Get the loss for the given inputs and targets. + + Args: + inputs (Tensor): Input data tensor of shape :math:`(B, *, C_{in})` + targets (Tensor): Target data tensor of shape :math:`(B, *, C_{out})` + raw_targets (Tensor): Raw target data tensor of shape :math:`(B, *, C_{out})` + condition (Tensor): Conditional data tensor of shape :math:`(B, *, C_{cond})` + metadata (Any): Optional metadata + predictions_mask (Tensor): Mask for the predictions, before computing the loss. Default: None (no mask) + return_predictions (bool): Whether to return the predictions or not. Default: False. + Note: this will return all the predictions, not just the masked ones (if any). + """ + + def mask_data(data): + if predictions_mask is not None: + return data[..., predictions_mask] + return data + + # Predict + if torch.is_tensor(inputs): + predictions = self(inputs, condition=condition, **kwargs) + else: + predictions = self(**inputs, condition=condition, **kwargs) + + if torch.is_tensor(predictions): + if predictions_post_process is not None: + predictions = predictions_post_process(predictions) + predictions = mask_data(predictions) + targets = mask_data(targets) + assert ( + predictions.shape == targets.shape + ), f"Be careful: Predictions shape {predictions.shape} != targets shape {targets.shape}. Missing singleton dimensions after batch dim. can be fatal." + loss = self.criterion["preds"](predictions, targets) + assert len(self.loss_function_weights) == 0, "Loss function weights are not supported for this case" + loss_dict = dict(loss=loss) + else: + if predictions_post_process is not None: + # Do post-processing of the predictions (but not other outputs of the model) + predictions["preds"] = predictions_post_process(predictions["preds"]) + loss = 0.0 + loss_dict = dict() + # For example, base_keys = ["preds"] + # With corresponding preds & targets shapes: (B, *, C_out, H, W) + for k in targets.keys(): + base_key = k.replace("inputs", "preds") + loss_weight_k = self.loss_function_weights.get(base_key, 1.0) + predictions_k = mask_data(predictions[base_key]) + targets_k = mask_data(targets[k]) + loss_k = self.criterion[base_key](predictions_k, targets_k) + loss += loss_weight_k * loss_k + loss_dict[f"loss/{base_key}"] = loss_k.item() + loss_dict["loss"] = loss # total loss, used to backpropagate + + if return_predictions: + return loss_dict, predictions + return loss_dict + + def predict_forward(self, *inputs: Tensor, metadata: Any = None, **kwargs): + """Forward pass for prediction. Usually the same as the forward pass, + but can be different for some models (e.g. sampling in probabilistic models). + """ + y = self(*inputs, **kwargs) + return y + + # Auxiliary methods + @contextmanager + def inference_dropout_scope(self, condition: bool, context=None): + assert isinstance(condition, bool), f"Condition must be a boolean, got {condition}" + if condition: + enable_inference_dropout(self) + if context is not None: + self.log_text.info(f"{context}: Switched to enabled inference dropout") + try: + yield None + finally: + if condition: + disable_inference_dropout(self) + if context is not None: + self.log_text.info(f"{context}: Switched to disabled inference dropout") + + def enable_inference_dropout(self): + """Set all dropout layers to training mode""" + enable_inference_dropout(self) + + def disable_inference_dropout(self): + """Set all dropout layers to eval mode""" + disable_inference_dropout(self) + + def register_buffer_dummy(self, name, tensor, **kwargs): + try: + self.register_buffer(name, tensor, **kwargs) + except TypeError: # old pytorch versions do not have the arg 'persistent' + self.register_buffer(name, tensor) diff --git a/src/models/modules/__init__.py b/src/models/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/modules/attention.py b/src/models/modules/attention.py new file mode 100644 index 0000000..a80b3bd --- /dev/null +++ b/src/models/modules/attention.py @@ -0,0 +1,116 @@ +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from src.utilities.utils import default, exists + + +class LinearAttention(nn.Module): + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32, dropout: float = 0.0, rescale: str = "qk"): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Sequential(nn.Dropout(dropout), nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)) + assert rescale in ["qk", "qkv"] + self.rescale = getattr(self, f"rescale_{rescale}") + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + # nn.Sequential( + # nn.Conv2d(hidden_dim, dim, 1), + # nn.Dropout(dropout) + # ) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv) + + q, k, v = self.rescale(q, k, v, h=h, w=w) + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) + + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) + return self.to_out(out) + + def rescale_qk(self, q, k, v, h, w): + q = q * self.scale + k = k.softmax(dim=-1) + return q, k, v + + def rescale_qkv(self, q, k, v, h, w): + q = q.softmax(dim=-2) + q = q * self.scale + k = k.softmax(dim=-1) + v = v / (h * w) + return q, k, v + + +def l2norm(t): + return F.normalize(t, dim=-1) + + +class Attention(nn.Module): + def __init__(self, dim: int, heads: int = 4, dim_head: int = 32, dropout: float = 0.0): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + self.dropout = nn.Dropout(dropout) + + def forward(self, x, pos_bias=None): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv) + + q = q * self.scale + + sim = torch.einsum("b h d i, b h d j -> b h i j", q, k) + # relative positional bias + if exists(pos_bias): + sim = sim + pos_bias + + attn = sim.softmax(dim=-1) + attn = self.dropout(attn) + out = torch.einsum("b h i j, b h d j -> b h i d", attn, v) + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) + return self.to_out(out) + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) diff --git a/src/models/modules/convs.py b/src/models/modules/convs.py new file mode 100644 index 0000000..8d0b316 --- /dev/null +++ b/src/models/modules/convs.py @@ -0,0 +1,30 @@ +from functools import partial + +import torch +import torch.nn.functional as F +from einops import reduce + + +class WeightStandardizedConv2d(torch.nn.Conv2d): + """ + https://arxiv.org/abs/1903.10520 + weight standardization purportedly works synergistically with group normalization + """ + + def forward(self, x): + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + + weight = self.weight + mean = reduce(weight, "o ... -> o 1 1 1", "mean") + var = reduce(weight, "o ... -> o 1 1 1", partial(torch.var, unbiased=False)) + normalized_weight = (weight - mean) * (var + eps).rsqrt() + + return F.conv2d( + x, + normalized_weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + ) diff --git a/src/models/modules/drop_path.py b/src/models/modules/drop_path.py new file mode 100644 index 0000000..043a56a --- /dev/null +++ b/src/models/modules/drop_path.py @@ -0,0 +1,36 @@ +import torch + + +@torch.jit.script +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor: # pragma: no cover + """Drop paths (Stochastic Depth) per sample (when applied in main path of + residual blocks). + This is the same as the DropConnect impl for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in + a separate paper. See discussion: + https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 + We've opted for changing the layer and argument names to 'drop path' rather than + mix DropConnect as a layer name and use 'survival rate' as the argument. + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1.0 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2d ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(torch.nn.Module): + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual + blocks). + """ + + def __init__(self, drop_rate=None): # pragma: no cover + super(DropPath, self).__init__() + self.drop_prob = drop_rate + + def forward(self, x): # pragma: no cover + return drop_path(x, self.drop_prob, self.training) diff --git a/src/models/modules/ema.py b/src/models/modules/ema.py new file mode 100644 index 0000000..a1b96c7 --- /dev/null +++ b/src/models/modules/ema.py @@ -0,0 +1,91 @@ +""" Exponential Moving Average (EMA) module """ + +import torch +from torch import nn + + +class LitEma(nn.Module): + def __init__(self, model, decay=0.9999, use_num_upates=True): + super().__init__() + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + + self.m_name2s_name = {} + self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32)) + self.register_buffer( + "num_updates", torch.tensor(0, dtype=torch.int) if use_num_upates else torch.tensor(-1, dtype=torch.int) + ) + # NOTE: Add any parameters to skip for EMA here. E.g. criterion since it's not used for inference. + self.skip_params = ["criterion"] + for name, p in model.named_parameters(): + if any(skip_param in name for skip_param in self.skip_params): + continue + if p.requires_grad: + # remove as '.'-character is not allowed in buffers + s_name = name.replace(".", "") + self.m_name2s_name.update({name: s_name}) + self.register_buffer(s_name, p.clone().detach().data) + + self.collected_params = [] + + def forward(self, model): + decay = self.decay + + if self.num_updates >= 0: + self.num_updates += 1 + decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates)) + + one_minus_decay = 1.0 - decay + + with torch.inference_mode(): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + + for key in m_param: + if any(skip_param in key for skip_param in self.skip_params): + continue + if m_param[key].requires_grad: + sname = self.m_name2s_name[key] + shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) + shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) + else: + assert key not in self.m_name2s_name + + def copy_to(self, model): + m_param = dict(model.named_parameters()) + shadow_params = dict(self.named_buffers()) + for key in m_param: + if any(skip_param in key for skip_param in self.skip_params): + continue + if m_param[key].requires_grad: + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + else: + if key not in self.m_name2s_name: + pass + else: + # print(f"Expecting {key} not to be in \nself.m_name2s_name={self.m_name2s_name.keys()}\nself.shadow_params={shadow_params.keys()}") + m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) + # assert key not in self.m_name2s_name, f"Expecting {key} not to be in self.m_name2s_name={self.m_name2s_name}" + + def store(self, parameters): + """ + Save the current parameters for restoring later. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. + """ + self.collected_params = [param.clone() for param in parameters] + + def restore(self, parameters): + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. + """ + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) diff --git a/src/models/modules/misc.py b/src/models/modules/misc.py new file mode 100644 index 0000000..168dbab --- /dev/null +++ b/src/models/modules/misc.py @@ -0,0 +1,148 @@ +import math + +import torch +from einops import parse_shape, rearrange +from torch import nn + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + try: + return self.fn(x, *args, **kwargs) + x + except TypeError as e: + raise TypeError(f"Error in Residual forward with {self.fn} and {type(x)}") from e + + +# sinusoidal positional embeddings +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device, dtype=x.dtype) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +class LearnedSinusoidalPosEmb(nn.Module): + """following @crowsonkb 's lead with learned sinusoidal pos emb""" + + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.inference_mode(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class EinopsWrapper(nn.Module): + def __init__(self, module: nn.Module, from_shape: str, to_shape: str): + super().__init__() + assert isinstance(module, nn.Module), f"module must be an instance of nn.Module but got: {type(module)}" + self.module = module + self.from_shape = from_shape + self.to_shape = to_shape + + def forward(self, x: torch.Tensor, *args, **kwargs): + axes_lengths = parse_shape(x, pattern=self.from_shape) + x = rearrange(x, f"{self.from_shape} -> {self.to_shape}") + x = self.module(x, *args, **kwargs) + x = rearrange(x, f"{self.to_shape} -> {self.from_shape}", **axes_lengths) + return x + + +def get_einops_wrapped_module(module, from_shape: str, to_shape: str): + class WrappedModule(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.wrapper = EinopsWrapper(from_shape, to_shape, module(*args, **kwargs)) + + def forward(self, x: torch.Tensor, *args, **kwargs): + return self.wrapper(x, *args, **kwargs) + + return WrappedModule + + +def get_time_embedder(time_dim: int, dim: int, sinusoidal_embedding: str = "true", learned_sinusoidal_dim: int = 16): + if sinusoidal_embedding == "learned": + sinu_pos_emb = LearnedSinusoidalPosEmb(learned_sinusoidal_dim) + pos_emb_dim = learned_sinusoidal_dim + 1 # fourier_dim + elif sinusoidal_embedding == "true": + sinu_pos_emb = SinusoidalPosEmb(dim) + pos_emb_dim = dim + elif sinusoidal_embedding is None: + sinu_pos_emb = nn.Identity() + pos_emb_dim = dim + else: + raise ValueError(f"Unknown sinusoidal embedding type: {sinusoidal_embedding}") + + time_emb_mlp = nn.Sequential( + sinu_pos_emb, nn.Linear(pos_emb_dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim) + ) + return time_emb_mlp diff --git a/src/models/modules/net_norm.py b/src/models/modules/net_norm.py new file mode 100644 index 0000000..726ca59 --- /dev/null +++ b/src/models/modules/net_norm.py @@ -0,0 +1,37 @@ +import torch +from torch import nn + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.g + self.b + + +class PreNorm(nn.Module): + def __init__(self, dim, fn, norm=LayerNorm): + super().__init__() + self.fn = fn + self.norm = norm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + + +class PostNorm(nn.Module): + def __init__(self, dim, fn, norm=LayerNorm): + super().__init__() + self.fn = fn + self.norm = norm(dim) + + def forward(self, x, **kwargs): + x = self.fn(x, **kwargs) + return self.norm(x) diff --git a/src/models/sfno/__init__.py b/src/models/sfno/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/models/sfno/activations.py b/src/models/sfno/activations.py new file mode 100644 index 0000000..ce6aaf1 --- /dev/null +++ b/src/models/sfno/activations.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch import nn + + +class ComplexReLU(nn.Module): + """ + Complex-valued variants of the ReLU activation function + """ + + def __init__(self, negative_slope=0.0, mode="real", bias_shape=None, scale=1.0): + super(ComplexReLU, self).__init__() + + # store parameters + self.mode = mode + if self.mode in ["modulus", "halfplane"]: + if bias_shape is not None: + self.bias = nn.Parameter(scale * torch.ones(bias_shape, dtype=torch.float32)) + else: + self.bias = nn.Parameter(scale * torch.ones((1), dtype=torch.float32)) + else: + self.bias = 0 + + self.negative_slope = negative_slope + self.act = nn.LeakyReLU(negative_slope=negative_slope) + + def forward(self, z: torch.Tensor) -> torch.Tensor: + if self.mode == "cartesian": + zr = torch.view_as_real(z) + za = self.act(zr) + out = torch.view_as_complex(za) + + elif self.mode == "modulus": + zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag)) + out = torch.where(zabs + self.bias > 0, (zabs + self.bias) * z / zabs, 0.0) + # out = self.act(zabs - self.bias) * torch.exp(1.j * z.angle()) + + elif self.mode == "halfplane": + # bias is an angle parameter in this case + modified_angle = torch.angle(z) - self.bias + condition = torch.logical_and((0.0 <= modified_angle), (modified_angle < torch.pi / 2.0)) + out = torch.where(condition, z, self.negative_slope * z) + + elif self.mode == "real": + zr = torch.view_as_real(z) + outr = zr.clone() + outr[..., 0] = self.act(zr[..., 0]) + out = torch.view_as_complex(outr) + + else: + raise NotImplementedError + + return out + + +class ComplexActivation(nn.Module): + """ + A module implementing complex-valued activation functions. + The module supports different modes of operation, depending on how + the complex numbers are treated for the activation function: + - "cartesian": the activation function is applied separately to the + real and imaginary parts of the complex input. + - "modulus": the activation function is applied to the modulus of + the complex input, after adding a learnable bias. + - any other mode: the complex input is returned as-is (identity operation). + """ + + def __init__(self, activation, mode="cartesian", bias_shape=None): + super(ComplexActivation, self).__init__() + + # store parameters + self.mode = mode + if self.mode == "modulus": + if bias_shape is not None: + self.bias = nn.Parameter(torch.zeros(bias_shape, dtype=torch.float32)) + else: + self.bias = nn.Parameter(torch.zeros((1), dtype=torch.float32)) + else: + bias = torch.zeros((1), dtype=torch.float32) + self.register_buffer("bias", bias) + + # real valued activation + self.act = activation + + def forward(self, z: torch.Tensor) -> torch.Tensor: + if self.mode == "cartesian": + zr = torch.view_as_real(z) + za = self.act(zr) + out = torch.view_as_complex(za) + elif self.mode == "modulus": + zabs = torch.sqrt(torch.square(z.real) + torch.square(z.imag)) + out = self.act(zabs + self.bias) * torch.exp(1.0j * z.angle()) + else: + # identity + out = z + + return out diff --git a/src/models/sfno/contractions.py b/src/models/sfno/contractions.py new file mode 100644 index 0000000..d55d125 --- /dev/null +++ b/src/models/sfno/contractions.py @@ -0,0 +1,193 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + + +@torch.jit.script +def compl_mul1d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex-valued multiplication operation between two 1-dimensional + tensors. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bix,io->box", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def compl_muladd1d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs complex multiplication of two 1-dimensional tensors 'a' and 'b', and then + adds a third tensor 'c'. + """ + tmpcc = torch.view_as_complex(compl_mul1d_fwd(a, b)) + cc = torch.view_as_complex(c) + return torch.view_as_real(tmpcc + cc) + + +@torch.jit.script +def compl_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex-valued multiplication operation between two 2-dimensional + tensors. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,io->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def compl_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs complex multiplication of two 2-dimensional tensors 'a' and 'b', and then + adds a third tensor 'c'. + """ + tmpcc = torch.view_as_complex(compl_mul2d_fwd(a, b)) + cc = torch.view_as_complex(c) + return torch.view_as_real(tmpcc + cc) + + +@torch.jit.script # TODO remove +def _contract_localconv_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex local convolution operation between two tensors 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,iox->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script # TODO remove +def _contract_blockconv_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex block convolution operation between two tensors 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bim,imn->bin", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script # TODO remove +def _contractadd_blockconv_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex block convolution operation between two tensors 'a' and 'b', and + then adds a third tensor 'c'. + """ + tmpcc = torch.view_as_complex(_contract_blockconv_fwd(a, b)) + cc = torch.view_as_complex(c) + return torch.view_as_real(tmpcc + cc) + + +# for the experimental layer +@torch.jit.script # TODO remove +def compl_exp_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a 2D complex multiplication operation between two tensors 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,xio->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def compl_exp_muladd2d_fwd( # TODO remove + a: torch.Tensor, b: torch.Tensor, c: torch.Tensor +) -> torch.Tensor: # pragma: no cover + """ + Performs a 2D complex multiplication operation between two tensors 'a' and 'b', + and then adds a third tensor 'c'. + """ + tmpcc = torch.view_as_complex(compl_exp_mul2d_fwd(a, b)) + cc = torch.view_as_complex(c) + return torch.view_as_real(tmpcc + cc) + + +@torch.jit.script +def real_mul2d_fwd(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a 2D real multiplication operation between two tensors 'a' and 'b'. + """ + res = torch.einsum("bixy,io->boxy", a, b) + return res + + +@torch.jit.script +def real_muladd2d_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a 2D real multiplication operation between two tensors 'a' and 'b', and + then adds a third tensor 'c'. + """ + res = real_mul2d_fwd(a, b) + c + return res + + +# new contractions set to replace older ones. We use complex +@torch.jit.script +def _contract_diagonal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex diagonal operation between two tensors 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ioxy->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def _contract_dhconv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex Driscoll-Healy style convolution operation between two tensors + 'a' and 'b'. + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,iox->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def _contract_sep_diagonal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex convolution operation between two tensors 'a' and 'b' + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ixy->boxy", ac, bc) + res = torch.view_as_real(resc) + return res + + +@torch.jit.script +def _contract_sep_dhconv(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: # pragma: no cover + """ + Performs a complex convolution operation between two tensors 'a' and 'b' + """ + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + resc = torch.einsum("bixy,ix->boxy", ac, bc) + res = torch.view_as_real(resc) + return res diff --git a/src/models/sfno/distributed/__init__.py b/src/models/sfno/distributed/__init__.py new file mode 100644 index 0000000..dbfe137 --- /dev/null +++ b/src/models/sfno/distributed/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/src/models/sfno/distributed/comm.py b/src/models/sfno/distributed/comm.py new file mode 100644 index 0000000..c58ced7 --- /dev/null +++ b/src/models/sfno/distributed/comm.py @@ -0,0 +1,314 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime as dt +import logging +import math +import os +from typing import Union + +import numpy as np +import torch +import torch.distributed as dist +from modulus.utils.sfno.logging_utils import disable_logging + + +# dummy placeholders +_COMM_LIST = [] +_COMM_NAMES = {} + + +# world comm +def get_size(comm_id: Union[str, int]) -> int: # pragma: no cover + """Returns the size of a specified communicator.""" + if isinstance(comm_id, int): + cid = comm_id + else: + cid = _COMM_NAMES[comm_id] if (comm_id in _COMM_NAMES) else len(_COMM_LIST) + + if not dist.is_initialized() or (cid >= len(_COMM_LIST)): + return 1 + else: + return dist.get_world_size(group=_COMM_LIST[cid]) + + +def get_rank(comm_id: Union[str, int]) -> int: # pragma: no cover + """Returns the rank of a specified communicator.""" + if isinstance(comm_id, int): + cid = comm_id + else: + cid = _COMM_NAMES[comm_id] if (comm_id in _COMM_NAMES) else len(_COMM_LIST) + + if not dist.is_initialized() or (cid >= len(_COMM_LIST)): + return 0 + else: + return dist.get_rank(group=_COMM_LIST[cid]) + + +def get_group(comm_id: Union[str, int]) -> int: # pragma: no cover + """Returns the group of a specified communicator.""" + if isinstance(comm_id, int): + cid = comm_id + else: + cid = _COMM_NAMES[comm_id] if (comm_id in _COMM_NAMES) else len(_COMM_LIST) + + if not dist.is_initialized() or (cid >= len(_COMM_LIST)): + raise IndexError(f"Error, comm with id {comm_id} not available.") + else: + return _COMM_LIST[cid] + + +# specialized routines for world comms +def get_world_size(): # pragma: no cover + """Returns the world size""" + if not dist.is_initialized(): + return 1 + else: + return dist.get_world_size() + + +def get_world_rank(): # pragma: no cover + """Returns the world rank""" + if not dist.is_initialized(): + return 0 + else: + return dist.get_rank() + + +def get_local_rank(): # pragma: no cover + """Returns the local rank of the current process.""" + if os.getenv("LOCAL_RANK") is not None: + # Use PyTorch env var if available + return int(os.getenv("LOCAL_RANK")) + + if not dist.is_initialized(): + return 0 + else: + return get_world_rank() % torch.cuda.device_count() + + +def get_names(): # pragma: no cover + """Returns the names of all available communicators.""" + return _COMM_NAMES + + +def is_distributed(name: str): # pragma: no cover + """check if distributed.""" + return name in _COMM_NAMES + + +# get +def init(params, verbose=False): # pragma: no cover + """Initialize distributed training.""" + # set up global and local communicator + if params.wireup_info == "env": + world_size = int(os.getenv("WORLD_SIZE", 1)) + world_rank = int(os.getenv("RANK", 0)) + if os.getenv("WORLD_RANK") is not None: + # Use WORLD_RANK if available for backwards compatibility + world_rank = int(os.getenv("WORLD_RANK")) + port = int(os.getenv("MASTER_PORT", 0)) + master_address = os.getenv("MASTER_ADDR") + if os.getenv("MASTER_ADDRESS") is not None: + # Use MASTER_ADDRESS if available for backwards compatibility + master_address = int(os.getenv("MASTER_ADDRESS")) + elif params.wireup_info == "mpi": + import socket + + from mpi4py import MPI + + mpi_comm = MPI.COMM_WORLD.Dup() + world_size = mpi_comm.Get_size() + world_rank = mpi_comm.Get_rank() + my_host = socket.gethostname() + port = 29500 + master_address = None + if world_rank == 0: + master_address_info = socket.getaddrinfo(my_host, port, family=socket.AF_INET, proto=socket.IPPROTO_TCP) + master_address = master_address_info[0][-1][0] + master_address = mpi_comm.bcast(master_address, root=0) + os.environ["MASTER_ADDRESS"] = master_address + os.environ["MASTER_PORT"] = str(port) + else: + raise ValueError(f"Error, wireup-info {params.wireup_info} not supported") + # set local rank to 0 if env var not available + local_rank = int(os.getenv("LOCAL_RANK", 0)) + + if world_size > 1: + with disable_logging(): + if params.wireup_store == "file": + wireup_file_path = os.getenv("WIREUP_FILE_PATH") + wireup_store = dist.FileStore(wireup_file_path, world_size) + elif params.wireup_store == "tcp": + # create tcp store + wireup_store = dist.TCPStore( + host_name=master_address, + port=port, + world_size=world_size, + is_master=(world_rank == 0), + timeout=dt.timedelta(seconds=900), + ) + else: + wireup_store = None + + # initialize process groups + dist.init_process_group( + backend="nccl", + rank=world_rank, + world_size=world_size, + store=wireup_store, + ) + + # get sizes + world_size = get_world_size() + world_rank = get_world_rank() + local_rank = get_local_rank() + + # barrier + dist.barrier(device_ids=[local_rank]) + + # do individual wireup for model parallel comms: + if hasattr(params, "model_parallel_sizes"): + model_parallel_sizes = params.model_parallel_sizes + else: + model_parallel_sizes = [1] + + if hasattr(params, "model_parallel_names"): + model_parallel_names = params.model_parallel_names + else: + model_parallel_names = ["model"] + assert len(model_parallel_names) == len(model_parallel_sizes), "Please specify names for your communicators" + model_parallel_size = math.prod(model_parallel_sizes) + params["model_parallel_size"] = model_parallel_size + + assert ( + world_size % model_parallel_size == 0 + ), "Error, please make sure that the product of model parallel ranks evenly divides the total number of ranks" + + # we set this to be orthogonal to the MP groups + # we can play tricks with the ddp_group later, in case if all the weights are shared + data_parallel_size = world_size // model_parallel_size + + # create orthogonal communicators first + global _COMM_LIST + global _COMM_NAMES + if params.log_to_screen: + logging.info("Starting Wireup") + + if world_size > 1: + # set up the strides: + model_parallel_sizes[::-1] + model_grid = np.reshape(np.arange(0, model_parallel_size), model_parallel_sizes[::-1]) + perm = np.roll(np.arange(0, len(model_parallel_sizes)), 1).tolist() + ranks_lookup = {} + + comm_count = 0 + for mpname in model_parallel_names: + base_group = np.reshape(model_grid, (-1, model_grid.shape[-1])) + model_groups = [] + for goffset in range(0, world_size, model_parallel_size): + model_groups += sorted((goffset + base_group).tolist()) + + if verbose and world_rank == 0: + print(f"Creating comm groups for id {mpname}: {model_groups}") + + for grp in model_groups: + if len(grp) > 1: + tmp_group = dist.new_group(ranks=grp) + if world_rank in grp: + _COMM_LIST.append(tmp_group) + _COMM_NAMES[mpname] = comm_count + comm_count += 1 + ranks_lookup[mpname] = model_groups + + # go for the next step + model_grid = np.transpose(model_grid, perm) + + # now, we create a single communicator for h and w ranks + if (get_size("h") == 1) and (get_size("w") > 1): + if verbose and world_rank == 0: + print(f'Creating comm groups for id spatial: {ranks_lookup["w"]}') + _COMM_LIST.append(get_group("w")) + _COMM_NAMES["spatial"] = comm_count + comm_count += 1 + elif (get_size("h") > 1) and (get_size("w") == 1): + if verbose and world_rank == 0: + print(f'Creating comm groups for id spatial: {ranks_lookup["h"]}') + _COMM_LIST.append(get_group("h")) + _COMM_NAMES["spatial"] = comm_count + comm_count += 1 + elif (get_size("h") > 1) and (get_size("w") > 1): + # fuse the lists: + def merge_ranks(list1, list2): + """Merge ranks""" + coll = list1 + list2 + pooled = [set(subList) for subList in coll] + merging = True + while merging: + merging = False + for i, group in enumerate(pooled): + merged = next((g for g in pooled[i + 1 :] if g.intersection(group)), None) + if not merged: + continue + group.update(merged) + pooled.remove(merged) + merging = True + return [list(x) for x in pooled] + + model_groups = merge_ranks(ranks_lookup["h"], ranks_lookup["w"]) + if verbose and world_rank == 0: + print(f"Creating comm groups for id spatial: {model_groups}") + for grp in model_groups: + tmp_group = dist.new_group(ranks=grp) + if world_rank in grp: + _COMM_LIST.append(tmp_group) + _COMM_NAMES["spatial"] = comm_count + comm_count += 1 + + # now the data and model comm: + model_groups = np.reshape(np.arange(0, world_size), (-1, model_parallel_size)).tolist() + for grp in model_groups: + if len(grp) > 1: + tmp_group = dist.new_group(ranks=grp) + if world_rank in grp: + _COMM_LIST.append(tmp_group) + _COMM_NAMES["model"] = comm_count + comm_count += 1 + + if data_parallel_size == world_size: + if verbose and world_rank == 0: + print(f"Creating comm groups for id data: {[list(range(0, world_size))]}") + + _COMM_LIST.append(None) + _COMM_NAMES["data"] = comm_count + else: + data_groups = [sorted(list(i)) for i in zip(*model_groups)] + + if verbose and world_rank == 0: + print(f"Creating comm groups for id data: {data_groups}") + + for grp in data_groups: + tmp_group = dist.new_group(ranks=grp) + if world_rank in grp: + _COMM_LIST.append(tmp_group) + _COMM_NAMES["data"] = comm_count + + # barrier + if dist.is_initialized(): + dist.barrier(device_ids=[local_rank]) + + if params.log_to_screen: + logging.info("Finished Wireup") + + return diff --git a/src/models/sfno/distributed/helpers.py b/src/models/sfno/distributed/helpers.py new file mode 100644 index 0000000..f564a62 --- /dev/null +++ b/src/models/sfno/distributed/helpers.py @@ -0,0 +1,194 @@ +# ignore_header_test + +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.distributed as dist +import torch.nn.functional as F +from modulus.utils.sfno.distributed import comm + + +def get_memory_format(tensor): # pragma: no cover + """Helper routine to get the memory format""" + if tensor.is_contiguous(memory_format=torch.channels_last): + return torch.channels_last + else: + return torch.contiguous_format + + +def sync_params(model, mode="broadcast"): # pragma: no cover + """Helper routine to ensure shared weights are the same after initialization""" + + non_singleton_group_names = [ + x for x in comm.get_names() if (comm.get_size(x) > 1) and x not in ["data", "model", "spatial"] + ] + + with torch.no_grad(): + # distributed sync step + for param in model.parameters(): + if not hasattr(param, "is_shared_mp"): + param.is_shared_mp = non_singleton_group_names.copy() + + for comm_group in param.is_shared_mp: + if comm.get_size(comm_group) > 1: + if mode == "broadcast": + tlist = [torch.empty_like(param) for x in range(comm.get_size(comm_group))] + tlist[comm.get_rank(comm_group)] = param + # gather all weights in the comm group + dist.all_gather(tlist, param, group=comm.get_group(comm_group)) + # use weight of rank 0 + # important to use copy here otherwise the handle gets detaches from the optimizer + param.copy_(tlist[0]) + elif mode == "mean": + # coalesced = _flatten_dense_tensors(param) + dist.all_reduce( + param, + op=dist.ReduceOp.AVG, + group=comm.get_group(comm_group), + async_op=False, + ) + # param.copy_(coalesced) + else: + raise ValueError(f"Unknown weight synchronization mode {mode}") + + +def pad_helper(tensor, dim, new_size, mode="zero"): # pragma: no cover + """Helper routine to pad a tensor along a given dimension""" + ndim = tensor.ndim + dim = (dim + ndim) % ndim + ndim_pad = ndim - dim + output_shape = [0 for _ in range(2 * ndim_pad)] + orig_size = tensor.shape[dim] + output_shape[1] = new_size - orig_size + tensor_pad = F.pad(tensor, output_shape, mode="constant", value=0.0) + + if mode == "conj": + lhs_slice = [slice(0, x) if idx != dim else slice(orig_size, new_size) for idx, x in enumerate(tensor.shape)] + rhs_slice = [ + slice(0, x) if idx != dim else slice(1, output_shape[1] + 1) for idx, x in enumerate(tensor.shape) + ] + tensor_pad[lhs_slice] = torch.flip(torch.conj(tensor_pad[rhs_slice]), dims=[dim]) + + return tensor_pad + + +def truncate_helper(tensor, dim, new_size): # pragma: no cover + """Helper routine to truncate a tensor along a given dimension""" + input_format = get_memory_format(tensor) + ndim = tensor.ndim + dim = (dim + ndim) % ndim + output_slice = [slice(0, x) if idx != dim else slice(0, new_size) for idx, x in enumerate(tensor.shape)] + tensor_trunc = tensor[output_slice].contiguous(memory_format=input_format) + + return tensor_trunc + + +def split_tensor_along_dim(tensor, dim, num_chunks): # pragma: no cover + """Helper routine to split a tensor along a given dimension""" + assert dim < tensor.dim(), f"Error, tensor dimension is {tensor.dim()} which cannot be split along {dim}" + assert ( + tensor.shape[dim] % num_chunks == 0 + ), f"Error, cannot split dim {dim} evenly. Dim size is \ + {tensor.shape[dim]} and requested numnber of splits is {num_chunks}" + chunk_size = tensor.shape[dim] // num_chunks + tensor_list = torch.split(tensor, chunk_size, dim=dim) + + return tensor_list + + +# distributed primitives +def _transpose(tensor, dim0, dim1, group=None, async_op=False): # pragma: no cover + """Transpose a tensor across model parallel group.""" + # get input format + input_format = get_memory_format(tensor) + + # get comm params + comm_size = dist.get_world_size(group=group) + + # split and local transposition + split_size = tensor.shape[dim0] // comm_size + x_send = [y.contiguous(memory_format=input_format) for y in torch.split(tensor, split_size, dim=dim0)] + x_recv = [torch.empty_like(x_send[0]) for _ in range(comm_size)] + + # global transposition + req = dist.all_to_all(x_recv, x_send, group=group, async_op=async_op) + + return x_recv, req + + +def _reduce(input_, use_fp32=True, group=None): # pragma: no cover + """All-reduce the input tensor across model parallel group.""" + + # Bypass the function if we are using only 1 GPU. + if dist.get_world_size(group=group) == 1: + return input_ + + # All-reduce. + if use_fp32: + dtype = input_.dtype + inputf_ = input_.float() + dist.all_reduce(inputf_, group=group) + input_ = inputf_.to(dtype) + else: + dist.all_reduce(input_, group=group) + + return input_ + + +def _split(input_, dim_, group=None): # pragma: no cover + """Split the tensor along its last dimension and keep the corresponding slice.""" + # get input format + input_format = get_memory_format(input_) + + # Bypass the function if we are using only 1 GPU. + comm_size = dist.get_world_size(group=group) + if comm_size == 1: + return input_ + + # Split along last dimension. + input_list = split_tensor_along_dim(input_, dim_, comm_size) + + # Note: torch.split does not create contiguous tensors by default. + rank = dist.get_rank(group=group) + output = input_list[rank].contiguous(memory_format=input_format) + + return output + + +def _gather(input_, dim_, group=None): # pragma: no cover + """Gather tensors and concatinate along the last dimension.""" + # get input format + input_format = get_memory_format(input_) + + comm_size = dist.get_world_size(group=group) + # Bypass the function if we are using only 1 GPU. + if comm_size == 1: + return input_ + + # sanity checks + assert dim_ < input_.dim(), f"Error, cannot gather along {dim_} for tensor with {input_.dim()} dimensions." + + # Size and dimension. + comm_rank = dist.get_rank(group=group) + + input_ = input_.contiguous(memory_format=input_format) + tensor_list = [torch.empty_like(input_) for _ in range(comm_size)] + tensor_list[comm_rank] = input_ + dist.all_gather(tensor_list, input_, group=group) + + # Note: torch.cat already creates a contiguous tensor. + output = torch.cat(tensor_list, dim=dim_).contiguous(memory_format=input_format) + + return output diff --git a/src/models/sfno/distributed/layer_norm.py b/src/models/sfno/distributed/layer_norm.py new file mode 100644 index 0000000..67c4447 --- /dev/null +++ b/src/models/sfno/distributed/layer_norm.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Tuple + +import torch +import torch.nn as nn + +# for spatial model-parallelism +from modulus.utils.sfno.distributed import comm +from modulus.utils.sfno.distributed.mappings import ( + copy_to_spatial_parallel_region, + gather_from_parallel_region, +) +from torch.cuda import amp + + +class DistributedInstanceNorm2d(nn.Module): + """ + Computes a distributed instance norm using Welford's online algorithm + """ + + def __init__(self, num_features, eps=1e-05, affine=False, device=None, dtype=None): # pragma: no cover + super(DistributedInstanceNorm2d, self).__init__() + + self.eps = eps + self.affine = affine + if self.affine: + self.weight = nn.Parameter(torch.ones(1, num_features, 1, 1)) + self.bias = nn.Parameter(torch.zeros(1, num_features, 1, 1)) + self.weight.is_shared_mp = ["h", "w"] + self.bias.is_shared_mp = ["h", "w"] + + self.gather_mode = "welford" + + @torch.jit.ignore + def _gather_hw(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover + # gather the data over the spatial communicator + xh = gather_from_parallel_region(x, -2, "h") + xw = gather_from_parallel_region(xh, -1, "w") + return xw + + @torch.jit.ignore + def _gather_spatial(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover + # gather the data over the spatial communicator + xs = gather_from_parallel_region(x, -1, "spatial") + return xs + + def _stats_naive(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover + """Computes the statistics in the naive way by first gathering the tensors and then computing them""" + + x = self._gather_hw(x) + var, mean = torch.var_mean(x, dim=(-2, -1), unbiased=False, keepdim=True) + + return var, mean + + def _stats_welford(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: # pragma: no cover + """Computes the statistics locally, then uses the Welford online algorithm to reduce them""" + + var, mean = torch.var_mean(x, dim=(-2, -1), unbiased=False, keepdim=False) + # workaround to not use shapes, as otherwise cuda graphs won't_i_next work + count = torch.ones_like(x[0, 0], requires_grad=False) + count = torch.sum(count, dim=(-2, -1), keepdim=False) + + vars = self._gather_spatial(var.unsqueeze(-1)) + means = self._gather_spatial(mean.unsqueeze(-1)) + counts = self._gather_spatial(count.unsqueeze(-1)) + + m2s = vars * counts + + mean = means[..., 0] + m2 = m2s[..., 0] + count = counts[..., 0] + + # use Welford's algorithm to accumulate them into a single mean and variance + for i in range(1, comm.get_size("spatial")): + delta = means[..., i] - mean + m2 = m2 + m2s[..., i] + delta**2 * count * counts[..., i] / (count + counts[..., i]) + if i == 1: + mean = (mean * count + means[..., i] * counts[..., i]) / (count + counts[..., i]) + else: + mean = mean + delta * counts[..., i] / (count + counts[..., i]) + + # update the current count + count = count + counts[..., i] + + var = m2 / count + + var = var.reshape(1, -1, 1, 1) + mean = mean.reshape(1, -1, 1, 1) + + return var, mean + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover + with amp.autocast(enabled=False): + dtype = x.dtype + x = x.float() + + # start by computing std and mean + if self.gather_mode == "naive": + var, mean = self._stats_naive(x) + elif self.gather_mode == "welford": + var, mean = self._stats_welford(x) + else: + raise ValueError(f"Unknown gather mode {self.gather_mode}") + + # this is absolutely necessary to get the correct graph in the backward pass + mean = copy_to_spatial_parallel_region(mean) + var = copy_to_spatial_parallel_region(var) + + x = x.to(dtype) + mean = mean.to(dtype) + var = var.to(dtype) + + # apply the normalization + x = (x - mean) / torch.sqrt(var + self.eps) + + # affine transform if we use it + if self.affine: + x = self.weight * x + self.bias + + return x diff --git a/src/models/sfno/distributed/layers.py b/src/models/sfno/distributed/layers.py new file mode 100644 index 0000000..d972d70 --- /dev/null +++ b/src/models/sfno/distributed/layers.py @@ -0,0 +1,539 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from modulus.models.sfno.initialization import trunc_normal_ +from modulus.utils.sfno.distributed import comm +from modulus.utils.sfno.distributed.helpers import _transpose + +# matmul parallel +# spatial parallel +from modulus.utils.sfno.distributed.mappings import ( + copy_to_matmul_parallel_region, + gather_from_matmul_parallel_region, + gather_from_spatial_parallel_region, + reduce_from_matmul_parallel_region, + scatter_to_matmul_parallel_region, + scatter_to_spatial_parallel_region, +) + + +class distributed_transpose_w(torch.autograd.Function): + """Distributed transpose""" + + @staticmethod + def forward(ctx, x, dim): # pragma: no cover + xlist, _ = _transpose(x, dim[0], dim[1], group=comm.get_group("w")) + x = torch.cat(xlist, dim=dim[1]) + ctx.dim = dim + return x + + @staticmethod + def backward(ctx, go): # pragma: no cover + dim = ctx.dim + gilist, _ = _transpose(go, dim[1], dim[0], group=comm.get_group("w")) + gi = torch.cat(gilist, dim=dim[0]) + return gi, None + + +class distributed_transpose_h(torch.autograd.Function): + """Distributed transpose""" + + @staticmethod + def forward(ctx, x, dim): # pragma: no cover + xlist, _ = _transpose(x, dim[0], dim[1], group=comm.get_group("h")) + x = torch.cat(xlist, dim=dim[1]) + ctx.dim = dim + return x + + @staticmethod + def backward(ctx, go): # pragma: no cover + dim = ctx.dim + gilist, _ = _transpose(go, dim[1], dim[0], group=comm.get_group("h")) + gi = torch.cat(gilist, dim=dim[0]) + return gi, None + + +class DistributedRealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): # pragma: no cover + super(DistributedRealFFT2, self).__init__() + + # get the comms grid: + self.comm_size_h = comm.get_size("h") + self.comm_size_w = comm.get_size("w") + self.comm_rank_w = comm.get_rank("w") + + # dimensions + self.nlat = nlat + self.nlon = nlon + self.lmax = lmax or self.nlat + self.mmax = mmax or self.nlon // 2 + 1 + + # frequency paddings + ldist = (self.lmax + self.comm_size_h - 1) // self.comm_size_h + self.lpad = ldist * self.comm_size_polar - self.lmax + mdist = (self.mmax + self.comm_size_w - 1) // self.comm_size_w + self.mpad = mdist * self.comm_size_w - self.mmax + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover + # we need to ensure that we can split the channels evenly + assert x.shape[1] % self.comm_size_h == 0 + assert x.shape[1] % self.comm_size_w == 0 + + # h and w is split. First we make w local by transposing into channel dim + if self.comm_size_w > 1: + xt = distributed_transpose_w.apply(x, (1, -1)) + else: + xt = x + + # do first FFT + xtf = torch.fft.rfft(xt, n=self.nlon, dim=-1, norm="ortho") + + # truncate + xtft = xtf[..., : self.mmax] + + # pad the dim to allow for splitting + xtfp = F.pad(xtft, [0, self.mpad], mode="constant") + + # transpose: after this, m is split and c is local + if self.comm_size_w > 1: + y = distributed_transpose_w.apply(xtfp, (-1, 1)) + else: + y = xtfp + + # transpose: after this, c is split and h is local + if self.comm_size_h > 1: + yt = distributed_transpose_h.apply(y, (1, -2)) + else: + yt = y + + # the input data might be padded, make sure to truncate to nlat: + # ytt = yt[..., :self.nlat, :] + + # do second FFT: + yo = torch.fft.fft(yt, n=self.nlat, dim=-2, norm="ortho") + + # pad if required, truncation is implicit + yop = F.pad(yo, [0, 0, 0, self.lpad], mode="constant") + + # transpose: after this, l is split and c is local + if self.comm_size_h > 1: + y = distributed_transpose_h.apply(yop, (-2, 1)) + else: + y = yop + + return y + + +class DistributedInverseRealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): # pragma: no cover + super(DistributedInverseRealFFT2, self).__init__() + + # get the comms grid: + self.comm_size_h = comm.get_size("h") + self.comm_size_w = comm.get_size("w") + self.comm_rank_w = comm.get_rank("w") + + # dimensions + self.nlat = nlat + self.nlon = nlon + self.lmax = lmax or self.nlat + self.mmax = mmax or self.nlon // 2 + 1 + + # spatial paddings + latdist = (self.nlat + self.comm_size_h - 1) // self.comm_size_h + self.latpad = latdist * self.comm_size_h - self.nlat + londist = (self.nlon + self.comm_size_w - 1) // self.comm_size_w + self.lonpad = londist * self.comm_size_w - self.nlon + + # frequency paddings + ldist = (self.lmax + self.comm_size_h - 1) // self.comm_size_h + self.lpad = ldist * self.comm_size_h - self.lmax + mdist = (self.mmax + self.comm_size_w - 1) // self.comm_size_w + self.mpad = mdist * self.comm_size_w - self.mmax + + def forward(self, x: torch.Tensor) -> torch.Tensor: # pragma: no cover + # we need to ensure that we can split the channels evenly + assert x.shape[1] % self.comm_size_h == 0 + assert x.shape[1] % self.comm_size_w == 0 + + # transpose: after that, channels are split, l is local: + if self.comm_size_h > 1: + xt = distributed_transpose_h.apply(x, (1, -2)) + else: + xt = x + + # truncate + xtt = xt[..., : self.lmax, :] + + # do first fft + xf = torch.fft.ifft(xtt, n=self.nlat, dim=-2, norm="ortho") + + # transpose: after this, l is split and channels are local + xfp = F.pad(xf, [0, 0, 0, self.latpad]) + + if self.comm_size_h > 1: + y = distributed_transpose_h.apply(xfp, (-2, 1)) + else: + y = xfp + + # transpose: after this, channels are split and m is local + if self.comm_size_w > 1: + yt = distributed_transpose_w.apply(y, (1, -1)) + else: + yt = y + + # truncate + ytt = yt[..., : self.mmax] + + # apply the inverse (real) FFT + x = torch.fft.irfft(ytt, n=self.nlon, dim=-1, norm="ortho") + + # pad before we transpose back + xp = F.pad(x, [0, self.lonpad]) + + # transpose: after this, m is split and channels are local + if self.comm_size_w > 1: + out = distributed_transpose_w.apply(xp, (-1, 1)) + else: + out = xp + + return out + + +# more complicated layers +class DistributedMLP(nn.Module): + """Distributed MLP layer""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + output_bias=True, + act_layer=nn.GELU, + drop_rate=0.0, + checkpointing=False, + ): # pragma: no cover + super(DistributedMLP, self).__init__() + self.checkpointing = checkpointing + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + # get effective embedding size: + comm_size = comm.get_size("matmul") + assert hidden_features % comm_size == 0, "Error, hidden_features needs to be divisible by matmul_parallel_size" + hidden_features_local = hidden_features // comm_size + + # first set of hp + self.w1 = nn.Parameter(torch.ones(hidden_features_local, in_features, 1, 1)) + self.b1 = nn.Parameter(torch.zeros(hidden_features_local)) + + # second set of hp + self.w2 = nn.Parameter(torch.ones(out_features, hidden_features_local, 1, 1)) + + if output_bias: + self.b2 = nn.Parameter(torch.zeros(out_features)) + + self.act = act_layer() + self.drop = nn.Dropout(drop) if drop_rate > 0.0 else nn.Identity() + + # the weights are shared spatially + self.w1.is_shared_mp = ["h", "w"] + self.b1.is_shared_mp = ["h", "w"] + self.w2.is_shared_mp = ["h", "w"] + if output_bias: + self.b2.is_shared_mp = [ + "matmul", + "h", + "w", + ] # this one is shared between all ranks + + # init weights + self._init_weights() + + def _init_weights(self): # pragma: no cover + trunc_normal_(self.w1, std=0.02) + nn.init.constant_(self.b1, 0.0) + trunc_normal_(self.w2, std=0.02) + if hasattr(self, "b2"): + nn.init.constant_(self.b2, 0.0) + + def fwd(self, x): # pragma: no cover + """Forward function.""" + # we need to prepare paralellism here + # spatial parallelism + x = scatter_to_spatial_parallel_region(x, dim=-1) + + # prepare the matmul parallel part + x = copy_to_matmul_parallel_region(x) + + # do the mlp + x = F.conv2d(x, self.w1, bias=self.b1) + x = self.act(x) + x = self.drop(x) + x = F.conv2d(x, self.w2, bias=None) + x = reduce_from_matmul_parallel_region(x) + if hasattr(self, "b2"): + x = x + torch.reshape(self.b2, (1, -1, 1, 1)) + x = self.drop(x) + + # gather from spatial parallel region + x = gather_from_spatial_parallel_region(x, dim=-1) + + return x + + @torch.jit.ignore + def _checkpoint_forward(self, x): # pragma: no cover + return checkpoint(self.fwd, x) + + def forward(self, x): # pragma: no cover + if self.checkpointing: + return self._checkpoint_forward(x) + else: + return self.fwd(x) + + +class DistributedPatchEmbed(nn.Module): + """Distributed patch embedding layer""" + + def __init__( + self, + img_size=(224, 224), + patch_size=(16, 16), + in_chans=3, + embed_dim=768, + input_is_matmul_parallel=False, + output_is_matmul_parallel=True, + ): # pragma: no cover + super(DistributedPatchEmbed, self).__init__() + + # store params + self.input_parallel = input_is_matmul_parallel + self.output_parallel = output_is_matmul_parallel + + # get comm sizes: + matmul_comm_size = comm.get_size("matmul") + spatial_comm_size = comm.get_size("spatial") + + # compute parameters + assert ( + img_size[1] // patch_size[1] + ) % spatial_comm_size == 0, "Error, make sure that the spatial comm size evenly divides patched W" + num_patches = ((img_size[1] // patch_size[1]) // spatial_comm_size) * (img_size[0] // patch_size[0]) + self.img_size = (img_size[0], img_size[1] // spatial_comm_size) + self.patch_size = patch_size + self.num_patches = num_patches + + # get effective embedding size: + if self.output_parallel: + assert ( + embed_dim % matmul_comm_size == 0 + ), "Error, the embed_dim needs to be divisible by matmul_parallel_size" + out_chans_local = embed_dim // matmul_comm_size + else: + out_chans_local = embed_dim + + # the weights of this layer is shared across spatial parallel ranks + self.proj = nn.Conv2d(in_chans, out_chans_local, kernel_size=patch_size, stride=patch_size) + + # make sure we reduce them across rank + self.proj.weight.is_shared_mp = ["h", "w"] + self.proj.bias.is_shared_mp = ["h", "w"] + + def forward(self, x): # pragma: no cover + if self.input_parallel: + x = gather_from_matmul_parallel_region(x, dim=1) + + if self.output_parallel: + x = copy_to_matmul_parallel_region(x) + + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # new: B, C, H*W + x = self.proj(x).flatten(2) + return x + + +@torch.jit.script +def compl_mul_add_fwd(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pragma: no cover + """complex multiplication and addition""" + tmp = torch.einsum("bkixys,kiot->stbkoxy", a, b) + res = torch.stack([tmp[0, 0, ...] - tmp[1, 1, ...], tmp[1, 0, ...] + tmp[0, 1, ...]], dim=-1) + c + return res + + +@torch.jit.script +def compl_mul_add_fwd_c(a: torch.Tensor, b: torch.Tensor, c: torch.Tensor) -> torch.Tensor: # pragma: no cover + """Performs a complex multiplication and addition operation on three tensors""" + ac = torch.view_as_complex(a) + bc = torch.view_as_complex(b) + cc = torch.view_as_complex(c) + tmp = torch.einsum("bkixy,kio->bkoxy", ac, bc) + res = tmp + cc + return torch.view_as_real(res) + + +class DistributedAFNO2Dv2(nn.Module): + """Distributed AFNO""" + + def __init__( + self, + hidden_size, + num_blocks=8, + sparsity_threshold=0.01, + hard_thresholding_fraction=1, + hidden_size_factor=1, + input_is_matmul_parallel=False, + output_is_matmul_parallel=False, + use_complex_kernels=False, + ): # pragma: no cover + """Distributed AFNO2Dv2""" + super(DistributedAFNO2Dv2, self).__init__() + assert ( + hidden_size % num_blocks == 0 + ), f"hidden_size {hidden_size} should be divisble by num_blocks {num_blocks}" + + # get comm sizes: + matmul_comm_size = comm.get_size("matmul") + self.spatial_comm_size = comm.get_size("spatial") + + # select fft function handles + if self.spatial_comm_size > 1: + self.fft_handle = distributed_rfft2.apply + self.ifft_handle = distributed_irfft2.apply + else: + self.fft_handle = torch.fft.rfft2 + self.ifft_handle = torch.fft.irfft2 + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.num_blocks = num_blocks + assert ( + self.num_blocks % matmul_comm_size == 0 + ), "Error, num_blocks needs to be divisible by matmul_parallel_size" + self.num_blocks_local = self.num_blocks // matmul_comm_size + self.block_size = self.hidden_size // self.num_blocks + self.hard_thresholding_fraction = hard_thresholding_fraction + self.hidden_size_factor = hidden_size_factor + self.scale = 0.02 + self.mult_handle = compl_mul_add_fwd_c if use_complex_kernels else compl_mul_add_fwd + + # model paralellism + self.input_is_matmul_parallel = input_is_matmul_parallel + self.output_is_matmul_parallel = output_is_matmul_parallel + + # new + # these weights need to be synced across all spatial ranks! + self.w1 = nn.Parameter( + self.scale + * torch.randn( + self.num_blocks_local, + self.block_size, + self.block_size * self.hidden_size_factor, + 2, + ) + ) + self.b1 = nn.Parameter( + self.scale + * torch.randn( + self.num_blocks_local, + self.block_size * self.hidden_size_factor, + 1, + 1, + 2, + ) + ) + self.w2 = nn.Parameter( + self.scale + * torch.randn( + self.num_blocks_local, + self.block_size * self.hidden_size_factor, + self.block_size, + 2, + ) + ) + self.b2 = nn.Parameter(self.scale * torch.randn(self.num_blocks_local, self.block_size, 1, 1, 2)) + + # make sure we reduce them across rank + self.w1.is_shared_mp = ["h", "w"] + self.b1.is_shared_mp = ["h", "w"] + self.w2.is_shared_mp = ["h", "w"] + self.b2.is_shared_mp = ["h", "w"] + + def forward(self, x): # pragma: no cover + if not self.input_is_matmul_parallel: + # distribute data + x = scatter_to_matmul_parallel_region(x, dim=1) + + # bias + bias = x + + dtype = x.dtype + x = x.float() + B, C, H, W_local = x.shape + total_modes = H // 2 + 1 + kept_modes = int(total_modes * self.hard_thresholding_fraction) + + H_local = H // self.spatial_comm_size + W = W_local * self.spatial_comm_size + x = self.fft_handle(x, (H, W), (-2, -1), "ortho") + x = x.view(B, self.num_blocks_local, self.block_size, H_local, W // 2 + 1) + + # new + x = torch.view_as_real(x) + o2 = torch.zeros(x.shape, device=x.device) + + o1 = F.relu( + self.mult_handle( + x[ + :, + :, + :, + total_modes - kept_modes : total_modes + kept_modes, + :kept_modes, + :, + ], + self.w1, + self.b1, + ) + ) + o2[:, :, :, total_modes - kept_modes : total_modes + kept_modes, :kept_modes, :] = self.mult_handle( + o1, self.w2, self.b2 + ) + + # finalize + x = F.softshrink(o2, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + x = x.reshape(B, C, H_local, W // 2 + 1) + x = self.ifft_handle(x, (H, W), (-2, -1), "ortho") + x = x.type(dtype) + bias + + # gather + if not self.output_is_matmul_parallel: + x = gather_from_matmul_parallel_region(x, dim=1) + + return x diff --git a/src/models/sfno/distributed/mappings.py b/src/models/sfno/distributed/mappings.py new file mode 100644 index 0000000..3fa6a8c --- /dev/null +++ b/src/models/sfno/distributed/mappings.py @@ -0,0 +1,340 @@ +# ignore_header_test + +# coding=utf-8 +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import torch +import torch.distributed as dist +from modulus.utils.sfno.distributed import comm + +# helper functions +from modulus.utils.sfno.distributed.helpers import _gather, _reduce, _split + +# torch utils +from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors +from torch.nn.parallel import DistributedDataParallel + + +# generalized +class _CopyToParallelRegion(torch.autograd.Function): + """Pass the input to the parallel region.""" + + @staticmethod + def symbolic(graph, input_, comm_id_): # pragma: no cover + """symbolic method""" + return input_ + + @staticmethod + def forward(ctx, input_, comm_id_): # pragma: no cover + ctx.comm_id = comm_id_ + return input_ + + @staticmethod + def backward(ctx, grad_output): + if comm.is_distributed(ctx.comm_id): # pragma: no cover + return _reduce(grad_output, group=comm.get_group(ctx.comm_id)), None + else: + return grad_output, None + + +class _ReduceFromParallelRegion(torch.autograd.Function): + """All-reduce the input from the parallel region.""" + + @staticmethod + def symbolic(graph, input_, comm_id_): # pragma: no cover + """symbolic method""" + if comm.is_distributed(comm_id_): + return _reduce(input_, group=comm.get_group(comm_id_)) + else: + return input_ + + @staticmethod + def forward(ctx, input_, comm_id_): # pragma: no cover + if comm.is_distributed(comm_id_): + return _reduce(input_, group=comm.get_group(comm_id_)) + else: + return input_ + + @staticmethod + def backward(ctx, grad_output): # pragma: no cover + return grad_output, None + + +class _ScatterToParallelRegion(torch.autograd.Function): + """Split the input and keep only the corresponding chuck to the rank.""" + + @staticmethod + def symbolic(graph, input_, dim_, comm_id_): # pragma: no cover + """symbolic method""" + return _split(input_, dim_, group=comm.get_group(comm_id_)) + + @staticmethod + def forward(ctx, input_, dim_, comm_id_): # pragma: no cover + ctx.dim = dim_ + ctx.comm_id = comm_id_ + if comm.is_distributed(comm_id_): + return _split(input_, dim_, group=comm.get_group(comm_id_)) + else: + return input_ + + @staticmethod + def backward(ctx, grad_output): # pragma: no cover + if comm.is_distributed(ctx.comm_id): + return ( + _gather(grad_output, ctx.dim, group=comm.get_group(ctx.comm_id)), + None, + None, + ) + else: + return grad_output, None, None + + +class _GatherFromParallelRegion(torch.autograd.Function): + """Gather the input from parallel region and concatenate.""" + + @staticmethod + def symbolic(graph, input_, dim_, comm_id_): # pragma: no cover + """""" + if comm.is_distributed(comm_id_): + return _gather(input_, dim_, group=comm.get_group(comm_id_)) + else: + return input_ + + @staticmethod + def forward(ctx, input_, dim_, comm_id_): # pragma: no cover + ctx.dim = dim_ + ctx.comm_id = comm_id_ + if comm.is_distributed(comm_id_): + return _gather(input_, dim_, group=comm.get_group(comm_id_)) + else: + return input_ + + @staticmethod + def backward(ctx, grad_output): # pragma: no cover + if comm.is_distributed(ctx.comm_id): + return ( + _split(grad_output, ctx.dim, group=comm.get_group(ctx.comm_id)), + None, + None, + ) + else: + return grad_output, None, None + + +# ----------------- +# Helper functions. +# ----------------- +# matmul parallel +def copy_to_matmul_parallel_region(input_): # pragma: no cover + """copy helper""" + return _CopyToParallelRegion.apply(input_, "matmul") + + +def reduce_from_matmul_parallel_region(input_): # pragma: no cover + """reduce helper""" + return _ReduceFromParallelRegion.apply(input_, "matmul") + + +def scatter_to_matmul_parallel_region(input_, dim): # pragma: no cover + """scatter helper""" + return _ScatterToParallelRegion.apply(input_, dim, "matmul") + + +def gather_from_matmul_parallel_region(input_, dim): # pragma: no cover + """gather helper""" + return _GatherFromParallelRegion.apply(input_, dim, "matmul") + + +# general +def reduce_from_parallel_region(input_, comm_name): # pragma: no cover + """reduce helper""" + return _ReduceFromParallelRegion.apply(input_, comm_name) + + +def scatter_to_parallel_region(input_, dim, comm_name): # pragma: no cover + """scatter helper""" + return _ScatterToParallelRegion.apply(input_, dim, comm_name) + + +def gather_from_parallel_region(input_, dim, comm_name): # pragma: no cover + """gather helper""" + return _GatherFromParallelRegion.apply(input_, dim, comm_name) + + +# def gather_within_matmul_parallel_region(input_, dim): +# return _GatherWithinMatmulParallelRegion.apply(input_, dim, "matmul") + + +# spatial parallel +def copy_to_spatial_parallel_region(input_): # pragma: no cover + """copy helper""" + return _CopyToParallelRegion.apply(input_, "spatial") + + +def scatter_to_spatial_parallel_region(input_, dim): # pragma: no cover + """scatter helper""" + return _ScatterToParallelRegion.apply(input_, dim, "spatial") + + +def gather_from_spatial_parallel_region(input_, dim): # pragma: no cover + """gather helper""" + return _GatherFromParallelRegion.apply(input_, dim, "spatial") + + +# handler for additional gradient reductions +# helper for gradient reduction across channel parallel ranks +def init_gradient_reduction_hooks( + model, + device_ids, + output_device, + bucket_cap_mb=25, + broadcast_buffers=True, + find_unused_parameters=False, + gradient_as_bucket_view=True, + static_graph=False, +): # pragma: no cover + """ + Initialize gradient reduction hooks for a given model. + """ + + # early exit if we are not in a distributed setting: + if not dist.is_initialized(): + return model + + # set this to false in init and then find out if we can use it: + need_hooks = False + ddp_group = comm.get_group("data") + + # this is the trivial case + if comm.get_size("model") == 1: + # the simple case, we can just continue then + ddp_group = None + else: + # check if there are shared weights, otherwise we can skip + non_singleton_group_names = [ + x for x in comm.get_names() if (comm.get_size(x) > 1) and x not in ["data", "model", "spatial"] + ] + num_shared = {x: 0 for x in non_singleton_group_names} + num_parameters = 0 + + # count parameters and reduction groups + for param in model.parameters(): + # if it does not have any annotation, we assume it is shared between all groups + if not hasattr(param, "is_shared_mp"): + param.is_shared_mp = non_singleton_group_names.copy() + + # check remaining groups + for group in non_singleton_group_names: + if group in param.is_shared_mp: + num_shared[group] += 1 + num_parameters += 1 + + # group without data: + num_param_shared_model = [v for k, v in num_shared.items()] + if not num_param_shared_model: + num_shared_model = 0 + else: + num_shared_model = sum(num_param_shared_model) + + # if all parameters are just data shared and not additionally shared orthogonally to that, we can use DDP + if num_shared_model == 0: + ddp_group = None + + elif all([(x == num_parameters) for x in num_param_shared_model]): + # in this case, we just need to register a backward hook to multiply the gradients according to the multiplicity: + print("Setting up gradient hooks to account for shared parameter multiplicity") + for param in model.parameters(): + param.register_hook(lambda grad: grad * float(comm.get_size("model"))) + + ddp_group = None + else: + ddp_group = comm.get_group("data") # double check if this is correct + broadcast_buffers = False + need_hooks = True + + # we can set up DDP and exit here + print("Setting up DDP communication hooks") + model = DistributedDataParallel( + model, + device_ids=device_ids, + output_device=output_device, + bucket_cap_mb=bucket_cap_mb, + broadcast_buffers=broadcast_buffers, + find_unused_parameters=find_unused_parameters, + gradient_as_bucket_view=gradient_as_bucket_view, + static_graph=static_graph, + process_group=ddp_group, + ) + if not need_hooks: + return model + + print("Setting up custom communication hooks") + + # define comm hook: + def reduction_comm_hook( + state: object, bucket: dist.GradBucket + ) -> torch.futures.Future[torch.Tensor]: # pragma: no cover + """reduction comm hook""" + + # allreduce everything first: + buff = bucket.buffer() + + # get future for allreduce + fut = dist.all_reduce(buff, op=dist.ReduceOp.AVG, group=comm.get_group("data"), async_op=True).get_future() + + # get grads for shared weights + params = bucket.parameters() + + def grad_reduction(fut, grads, group): + """reduce remaining gradients""" + coalesced = _flatten_dense_tensors(grads) + dist.all_reduce( + coalesced, + op=dist.ReduceOp.SUM, + group=comm.get_group(group), + async_op=False, + ) + for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)): + buf.copy_(synced) + + return bucket.buffer() + + for group in non_singleton_group_names: + if group == "data": + continue + + grads = [] + for p in params: + if group in p.is_shared_mp: + grads.append(p.grad.data) + + if not grads: + continue + + # append the new reduction functions + fut = fut.then(lambda x: grad_reduction(x, grads=grads, group=group)) + # fut = fut.then(lambda x: grad_copy(x, grads=grads)) + + ## chain it together + # for redfut, copyfut in zip(redfunc, copyfunc): + # fut = fut.then(redfut).then(copyfut) + + return fut + + # register model comm hook + model.register_comm_hook(state=None, hook=reduction_comm_hook) + + return model diff --git a/src/models/sfno/factorizations.py b/src/models/sfno/factorizations.py new file mode 100644 index 0000000..b5c1893 --- /dev/null +++ b/src/models/sfno/factorizations.py @@ -0,0 +1,225 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import tensorly as tl +import torch + + +tl.set_backend("pytorch") +# from tensorly.plugins import use_opt_einsum +# use_opt_einsum('optimal') + +from modulus.models.sfno.contractions import ( + _contract_dhconv, + _contract_diagonal, + _contract_sep_dhconv, + _contract_sep_diagonal, +) +from tltorch.factorized_tensors.core import FactorizedTensor + + +einsum_symbols = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" + + +def _contract_dense(x, weight, separable=False, operator_type="diagonal"): # pragma: no cover + order = tl.ndim(x) + # batch-size, in_channels, x, y... + x_syms = list(einsum_symbols[:order]) + + # in_channels, out_channels, x, y... + weight_syms = list(x_syms[1:]) # no batch-size + + # batch-size, out_channels, x, y... + if separable: + out_syms = [x_syms[0]] + list(weight_syms) + else: + weight_syms.insert(1, einsum_symbols[order]) # outputs + out_syms = list(weight_syms) + out_syms[0] = x_syms[0] + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + weight_syms.insert(-1, einsum_symbols[order + 1]) + out_syms[-1] = weight_syms[-2] + elif operator_type == "dhconv": + weight_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = "".join(x_syms) + "," + "".join(weight_syms) + "->" + "".join(out_syms) + + if not torch.is_tensor(weight): + weight = weight.to_tensor() + + return tl.einsum(eq, x, weight) + + +def _contract_cp(x, cp_weight, separable=False, operator_type="diagonal"): # pragma: no cover + order = tl.ndim(x) + + x_syms = str(einsum_symbols[:order]) + rank_sym = einsum_symbols[order] + out_sym = einsum_symbols[order + 1] + out_syms = list(x_syms) + + if separable: + factor_syms = [einsum_symbols[1] + rank_sym] # in only + else: + out_syms[1] = out_sym + factor_syms = [einsum_symbols[1] + rank_sym, out_sym + rank_sym] # in, out + + factor_syms += [xs + rank_sym for xs in x_syms[2:]] # x, y, ... + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + out_syms[-1] = einsum_symbols[order + 2] + factor_syms += [out_syms[-1] + rank_sym] + elif operator_type == "dhconv": + factor_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = x_syms + "," + rank_sym + "," + ",".join(factor_syms) + "->" + "".join(out_syms) + + return tl.einsum(eq, x, cp_weight.weights, *cp_weight.factors) + + +def _contract_tucker(x, tucker_weight, separable=False, operator_type="diagonal"): # pragma: no cover + order = tl.ndim(x) + + x_syms = str(einsum_symbols[:order]) + out_sym = einsum_symbols[order] + out_syms = list(x_syms) + if separable: + core_syms = einsum_symbols[order + 1 : 2 * order] + # factor_syms = [einsum_symbols[1]+core_syms[0]] #in only + factor_syms = [xs + rs for (xs, rs) in zip(x_syms[1:], core_syms)] # x, y, ... + + else: + core_syms = einsum_symbols[order + 1 : 2 * order + 1] + out_syms[1] = out_sym + factor_syms = [ + einsum_symbols[1] + core_syms[0], + out_sym + core_syms[1], + ] # out, in + factor_syms += [xs + rs for (xs, rs) in zip(x_syms[2:], core_syms[2:])] # x, y, ... + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + raise NotImplementedError(f"Operator type {operator_type} not implemented for Tucker") + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + eq = x_syms + "," + core_syms + "," + ",".join(factor_syms) + "->" + "".join(out_syms) + + return tl.einsum(eq, x, tucker_weight.core, *tucker_weight.factors) + + +def _contract_tt(x, tt_weight, separable=False, operator_type="diagonal"): # pragma: no cover + order = tl.ndim(x) + + x_syms = list(einsum_symbols[:order]) + weight_syms = list(x_syms[1:]) # no batch-size + + if not separable: + weight_syms.insert(1, einsum_symbols[order]) # outputs + out_syms = list(weight_syms) + out_syms[0] = x_syms[0] + else: + out_syms = list(x_syms) + + if operator_type == "diagonal": + pass + elif operator_type == "block-diagonal": + weight_syms.insert(-1, einsum_symbols[order + 1]) + out_syms[-1] = weight_syms[-2] + elif operator_type == "dhconv": + weight_syms.pop() + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + rank_syms = list(einsum_symbols[order + 2 :]) + tt_syms = [] + for i, s in enumerate(weight_syms): + tt_syms.append([rank_syms[i], s, rank_syms[i + 1]]) + eq = "".join(x_syms) + "," + ",".join("".join(f) for f in tt_syms) + "->" + "".join(out_syms) + + return tl.einsum(eq, x, *tt_weight.factors) + + +# jitted PyTorch contractions: +def _contract_dense_pytorch(x, weight, separable=False, operator_type="diagonal"): # pragma: no cover + # to cheat the fused optimizers convert to real here + x = torch.view_as_real(x) + + if separable: + if operator_type == "diagonal": + x = _contract_sep_diagonal(x, weight) + elif operator_type == "dhconv": + x = _contract_sep_dhconv(x, weight) + else: + raise ValueError(f"Unkonw operator type {operator_type}") + else: + if operator_type == "diagonal": + x = _contract_diagonal(x, weight) + elif operator_type == "dhconv": + x = _contract_dhconv(x, weight) + else: + raise ValueError(f"Unkonw operator type {operator_type}") + + # to cheat the fused optimizers convert to real here + x = torch.view_as_complex(x) + return x + + +def get_contract_fun( + weight, implementation="reconstructed", separable=False, operator_type="diagonal" +): # pragma: no cover + """Generic ND implementation of Fourier Spectral Conv contraction + + Parameters + ---------- + weight : tensorly-torch's FactorizedTensor + implementation : {'reconstructed', 'factorized'}, default is 'reconstructed' + whether to reconstruct the weight and do a forward pass (reconstructed) + or contract directly the factors of the factorized weight with the input + (factorized) + + Returns + ------- + function : (x, weight) -> x * weight in Fourier space + """ + if implementation == "reconstructed": + return _contract_dense + elif implementation == "factorized": + if torch.is_tensor(weight): + return _contract_dense_pytorch + elif isinstance(weight, FactorizedTensor): + if weight.name.lower() == "complexdense" or weight.name.lower() == "dense": + return _contract_dense + elif weight.name.lower() == "complextucker": + return _contract_tucker + elif weight.name.lower() == "complextt": + return _contract_tt + elif weight.name.lower() == "complexcp": + return _contract_cp + else: + raise ValueError(f"Got unexpected factorized weight type {weight.name}") + else: + raise ValueError(f"Got unexpected weight type of class {weight.__class__.__name__}") + else: + raise ValueError(f'Got {implementation=}, expected "reconstructed" or "factorized"') diff --git a/src/models/sfno/initialization.py b/src/models/sfno/initialization.py new file mode 100644 index 0000000..72bac13 --- /dev/null +++ b/src/models/sfno/initialization.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import warnings + +import torch + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn( + "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2, + ) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + low = norm_cdf((a - mean) / std) + upp = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * low - 1, 2 * upp - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.0)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0): + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) diff --git a/src/models/sfno/layers.py b/src/models/sfno/layers.py new file mode 100644 index 0000000..a1bd5f8 --- /dev/null +++ b/src/models/sfno/layers.py @@ -0,0 +1,511 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math + +import torch +import torch.fft +import torch.nn as nn +import torch.nn.functional as F +from modulus.models.sfno.activations import * +from modulus.models.sfno.contractions import * +from torch.cuda import amp +from torch.utils.checkpoint import checkpoint +from torch_harmonics import * + + +class PatchEmbed(nn.Module): + """ + Divides the input image into patches and embeds them into a specified dimension + using a convolutional layer. + """ + + def __init__(self, img_size=(224, 224), patch_size=(16, 16), in_chans=3, embed_dim=768): # pragma: no cover + super(PatchEmbed, self).__init__() + num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0]) + self.img_size = img_size + self.patch_size = patch_size + self.num_patches = num_patches + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): # pragma: no cover + # gather input + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + # new: B, C, H*W + x = self.proj(x).flatten(2) + return x + + +class MLP(nn.Module): + """ + Basic CNN with support for gradient checkpointing + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + output_bias=True, + drop_rate=0.0, + checkpointing=0, + ): # pragma: no cover + super(MLP, self).__init__() + self.checkpointing = checkpointing + out_features = out_features or in_features + hidden_features = hidden_features or in_features + + fc1 = nn.Conv2d(in_features, hidden_features, 1, bias=True) + act = act_layer() + fc2 = nn.Conv2d(hidden_features, out_features, 1, bias=output_bias) + if drop_rate > 0.0: + drop = nn.Dropout(drop_rate) + self.fwd = nn.Sequential(fc1, act, drop, fc2, drop) + else: + self.fwd = nn.Sequential(fc1, act, fc2) + + # by default, all weights are shared + + @torch.jit.ignore + def checkpoint_forward(self, x): # pragma: no cover + """Forward method with support for gradient checkpointing""" + return checkpoint(self.fwd, x) + + def forward(self, x): # pragma: no cover + if self.checkpointing >= 2: + return self.checkpoint_forward(x) + else: + return self.fwd(x) + + +class RealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): # pragma: no cover + super(RealFFT2, self).__init__() + + # use local FFT here + self.fft_handle = torch.fft.rfft2 + + self.nlat = nlat + self.nlon = nlon + self.lmax = lmax or self.nlat + self.mmax = mmax or self.nlon // 2 + 1 + + self.truncate = True + if (self.lmax == self.nlat) and (self.mmax == (self.nlon // 2 + 1)): + self.truncate = False + + # self.num_batches = 1 + assert self.lmax % 2 == 0 + + def forward(self, x): # pragma: no cover + y = self.fft_handle(x, (self.nlat, self.nlon), (-2, -1), "ortho") + + if self.truncate: + y = torch.cat( + ( + y[..., : math.ceil(self.lmax / 2), : self.mmax], + y[..., -math.floor(self.lmax / 2) :, : self.mmax], + ), + dim=-2, + ) + + return y + + +class InverseRealFFT2(nn.Module): + """ + Helper routine to wrap FFT similarly to the SHT + """ + + def __init__(self, nlat, nlon, lmax=None, mmax=None): # pragma: no cover + super(InverseRealFFT2, self).__init__() + + # use local FFT here + self.ifft_handle = torch.fft.irfft2 + + self.nlat = nlat + self.nlon = nlon + self.lmax = lmax or self.nlat + self.mmax = mmax or self.nlon // 2 + 1 + + def forward(self, x): # pragma: no cover + out = self.ifft_handle(x, (self.nlat, self.nlon), (-2, -1), "ortho") + + return out + + +class SpectralConv2d(nn.Module): + """ + Spectral Convolution as utilized in + """ + + def __init__( + self, + forward_transform, + inverse_transform, + hidden_size, + sparsity_threshold=0.0, + hard_thresholding_fraction=1, + use_complex_kernels=False, + compression=None, + rank=0, + bias=False, + ): # pragma: no cover + super(SpectralConv2d, self).__init__() + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.hard_thresholding_fraction = hard_thresholding_fraction + self.scale = 1 / hidden_size**2 + self.contract_handle = compl_contract2d_fwd_c if use_complex_kernels else compl_contract2d_fwd + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.output_dims = (self.inverse_transform.nlat, self.inverse_transform.nlon) + modes_lat = self.inverse_transform.lmax + modes_lon = self.inverse_transform.mmax + self.modes_lat = int(modes_lat * self.hard_thresholding_fraction) + self.modes_lon = int(modes_lon * self.hard_thresholding_fraction) + + # new simple linear layer + self.w = nn.Parameter( + self.scale * torch.randn(self.hidden_size, self.hidden_size, self.modes_lat, self.modes_lon, 2) + ) + # optional bias + if bias: + self.b = nn.Parameter(self.scale * torch.randn(1, self.hidden_size, *self.output_dims)) + + def forward(self, x): # pragma: no cover + dtype = x.dtype + # x = x.float() + B, C, H, W = x.shape + + with amp.autocast(enabled=False): + x = x.to(torch.float32) + x = self.forward_transform(x) + x = torch.view_as_real(x) + x = x.to(dtype) + + # do spectral conv + modes = torch.zeros(x.shape, device=x.device) + + # modes[:, :, :self.modes_lat, :self.modes_lon, :] = self.contract_handle(x[:, :, :self.modes_lat, :self.modes_lon, :], self.wh) + # modes[:, :, -self.modes_lat:, :self.modes_lon, :] = self.contract_handle(x[:, :, -self.modes_lat:, :self.modes_lon, :], self.wl) + modes = self.contract_handle(x, self.w) + + # finalize + x = F.softshrink(modes, lambd=self.sparsity_threshold) + x = torch.view_as_complex(x) + + with amp.autocast(enabled=False): + x = x.to(torch.float32) + x = torch.view_as_complex(x) + x = x.contiguous() + x = self.inverse_transform(x) + x = x.to(dtype) + + if hasattr(self, "b"): + x = x + self.b + + return x + + +class SpectralConvS2(nn.Module): + """ + Spectral Convolution as utilized in + """ + + def __init__( + self, + forward_transform, + inverse_transform, + hidden_size, + sparsity_threshold=0.0, + use_complex_kernels=False, + compression=None, + rank=128, + bias=False, + ): # pragma: no cover + super(SpectralConvS2, self).__init__() + + self.hidden_size = hidden_size + self.sparsity_threshold = sparsity_threshold + self.scale = 0.02 + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.modes_lat = self.forward_transform.lmax + self.modes_lon = self.forward_transform.mmax + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + # remember the lower triangular indices + ii, jj = torch.tril_indices(self.modes_lat, self.modes_lon) + self.register_buffer("ii", ii) + self.register_buffer("jj", jj) + + if compression == "tt": + self.rank = rank + # tensortrain coefficients + g1 = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.rank, 2)) + g2 = nn.Parameter(self.scale * torch.randn(self.rank, self.hidden_size, self.rank, 2)) + g3 = nn.Parameter(self.scale * torch.randn(self.rank, len(ii), 2)) + self.w = nn.ParameterList([g1, g2, g3]) + + self.contract_handle = contract_tt # if use_complex_kernels else raise(NotImplementedError) + else: + self.w = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.hidden_size, len(ii), 2)) + self.contract_handle = compl_contract_fwd_c if use_complex_kernels else compl_contract_fwd + + if bias: + self.b = nn.Parameter(self.scale * torch.randn(1, self.hidden_size, *self.output_dims)) + + def forward(self, x): # pragma: no cover + dtype = x.dtype + # x = x.float() + B, C, H, W = x.shape + + with amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.contiguous() + x = self.forward_transform(x) + x = torch.view_as_real(x) + x = x.to(dtype) + + # do spectral conv + modes = torch.zeros(x.shape, device=x.device) + modes[:, :, self.ii, self.jj, :] = self.contract_handle(x[:, :, self.ii, self.jj, :], self.w) + + # finalize + x = F.softshrink(modes, lambd=self.sparsity_threshold) + + with amp.autocast(enabled=False): + x = x.to(torch.float32) + x = torch.view_as_complex(x) + x = self.inverse_transform(x) + x = x.to(dtype) + + if hasattr(self, "b"): + x = x + self.b + + return x + + +class SpectralAttention2d(nn.Module): + """ + 2d Spectral Attention layer + """ + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + sparsity_threshold=0.0, + hidden_size_factor=2, + use_complex_network=True, + use_complex_kernels=False, + complex_activation="real", + bias=False, + spectral_layers=1, + drop_rate=0.0, + ): # pragma: no cover + super(SpectralAttention2d, self).__init__() + + self.embed_dim = embed_dim + self.sparsity_threshold = sparsity_threshold + self.hidden_size = int(hidden_size_factor * self.embed_dim) + self.scale = 0.02 + self.spectral_layers = spectral_layers + self.mul_add_handle = compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd + self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform.forward + self.inverse_transform = inverse_transform.forward + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + # weights + w = [self.scale * torch.randn(self.embed_dim, self.hidden_size, 2)] + # w = [self.scale * torch.randn(self.embed_dim + 2*self.embed_freqs, self.hidden_size, 2)] + # w = [self.scale * torch.randn(self.embed_dim + 4*self.embed_freqs, self.hidden_size, 2)] + for l in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(self.hidden_size, self.hidden_size, 2)) + self.w = nn.ParameterList(w) + + if bias: + self.b = nn.ParameterList( + [self.scale * torch.randn(self.hidden_size, 1, 2) for _ in range(self.spectral_layers)] + ) + + self.wout = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.embed_dim, 2)) + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + self.activation = ComplexReLU(mode=complex_activation, bias_shape=(self.hidden_size, 1, 1)) + + def forward_mlp(self, xr): # pragma: no cover + """forward method for the MLP part of the network""" + for layer in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle(xr, self.w[layer].to(xr.dtype), self.b[layer].to(xr.dtype)) + else: + xr = self.mul_handle(xr, self.w[layer].to(xr.dtype)) + xr = torch.view_as_complex(xr) + xr = self.activation(xr) + xr = self.drop(xr) + xr = torch.view_as_real(xr) + + xr = self.mul_handle(xr, self.wout) + + return xr + + def forward(self, x): # pragma: no cover + dtype = x.dtype + # x = x.to(torch.float32) + + # FWD transform + with amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.contiguous() + x = self.forward_transform(x) + x = torch.view_as_real(x) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + with amp.autocast(enabled=False): + x = torch.view_as_complex(x) + x = x.contiguous() + x = self.inverse_transform(x) + x = x.to(dtype) + + return x + + +class SpectralAttentionS2(nn.Module): + """ + geometrical Spectral Attention layer + """ + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + sparsity_threshold=0.0, + hidden_size_factor=2, + use_complex_network=True, + use_complex_kernels=False, + complex_activation="real", + bias=False, + spectral_layers=1, + drop_rate=0.0, + ): # pragma: no cover + super(SpectralAttentionS2, self).__init__() + + self.embed_dim = embed_dim + self.sparsity_threshold = sparsity_threshold + self.hidden_size = int(hidden_size_factor * self.embed_dim) + self.scale = 0.02 + # self.mul_add_handle = compl_muladd1d_fwd_c if use_complex_kernels else compl_muladd1d_fwd + self.mul_add_handle = compl_muladd2d_fwd_c if use_complex_kernels else compl_muladd2d_fwd + # self.mul_handle = compl_mul1d_fwd_c if use_complex_kernels else compl_mul1d_fwd + self.mul_handle = compl_mul2d_fwd_c if use_complex_kernels else compl_mul2d_fwd + self.spectral_layers = spectral_layers + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform.forward + self.inverse_transform = inverse_transform.forward + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + # weights + w = [self.scale * torch.randn(self.embed_dim, self.hidden_size, 2)] + # w = [self.scale * torch.randn(self.embed_dim + 4*self.embed_freqs, self.hidden_size, 2)] + for layer in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(self.hidden_size, self.hidden_size, 2)) + self.w = nn.ParameterList(w) + + if bias: + self.b = nn.ParameterList( + [self.scale * torch.randn(2 * self.hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)] + ) + + self.wout = nn.Parameter(self.scale * torch.randn(self.hidden_size, self.embed_dim, 2)) + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + self.activation = ComplexReLU(mode=complex_activation, bias_shape=(self.hidden_size, 1, 1)) + + def forward_mlp(self, xr): # pragma: no cover + """forward method for the MLP part of the network""" + for layer in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle(xr, self.w[layer].to(xr.dtype), self.b[layer].to(xr.dtype)) + else: + xr = self.mul_handle(xr, self.w[layer].to(xr.dtype)) + xr = torch.view_as_complex(xr) + xr = self.activation(xr) + xr = self.drop(xr) + xr = torch.view_as_real(xr) + + # final MLP + xr = self.mul_handle(xr, self.wout) + + return xr + + def forward(self, x): # pragma: no cover + dtype = x.dtype + # x = x.to(torch.float32) + + # FWD transform + with amp.autocast(enabled=False): + x = x.to(torch.float32) + x = x.contiguous() + x = self.forward_transform(x) + x = torch.view_as_real(x) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + with amp.autocast(enabled=False): + x = torch.view_as_complex(x) + x = x.contiguous() + x = self.inverse_transform(x) + x = x.to(dtype) + + return x diff --git a/src/models/sfno/preprocessor.py b/src/models/sfno/preprocessor.py new file mode 100644 index 0000000..1494abe --- /dev/null +++ b/src/models/sfno/preprocessor.py @@ -0,0 +1,252 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn + + +# from torch_harmonics import * + + +class Preprocessor2D(nn.Module): + """ + Preprocessing methods to flatten image history, add static features, and + convert the data format from NCHW to NHWC. + """ + + def __init__(self, params, img_size=(720, 1440)): # pragma: no cover + super(Preprocessor2D, self).__init__() + + self.n_history = params.n_history + self.transform_to_nhwc = params.enable_nhwc + + # self.poltor_decomp = params.poltor_decomp + # self.img_size = (params.img_shape_x, params.img_shape_y) if hasattr(params, "img_shape_x") and hasattr(params, "img_shape_y") else img_size + # self.input_grid = "equiangular" + # self.output_grid = "equiangular" + + # process static features + static_features = None + # needed for sharding + start_x = params.img_local_offset_x + end_x = start_x + params.img_local_shape_x + start_y = params.img_local_offset_y + end_y = start_y + params.img_local_shape_y + if params.add_grid: + tx = torch.linspace(0, 1, params.img_shape_x + 1, dtype=torch.float32)[0:-1] + ty = torch.linspace(0, 1, params.img_shape_y + 1, dtype=torch.float32)[0:-1] + + x_grid, y_grid = torch.meshgrid(tx, ty, indexing="ij") + x_grid, y_grid = x_grid.unsqueeze(0).unsqueeze(0), y_grid.unsqueeze(0).unsqueeze(0) + grid = torch.cat([x_grid, y_grid], dim=1) + + # now shard: + grid = grid[:, :, start_x:end_x, start_y:end_y] + + static_features = grid + # self.register_buffer("grid", grid) + + if params.add_orography: + from utils.conditioning_inputs import get_orography + + oro = torch.tensor(get_orography(params.orography_path), dtype=torch.float32) + oro = torch.reshape(oro, (1, 1, oro.shape[0], oro.shape[1])) + + # shard + oro = oro[:, :, start_x:end_x, start_y:end_y] + + if static_features is None: + static_features = oro + else: + static_features = torch.cat([static_features, oro], dim=1) + + if params.add_landmask: + from utils.conditioning_inputs import get_land_mask + + lsm = torch.tensor(get_land_mask(params.landmask_path), dtype=torch.long) + # one hot encode and move channels to front: + lsm = torch.permute(torch.nn.functional.one_hot(lsm), (2, 0, 1)).to(torch.float32) + lsm = torch.reshape(lsm, (1, lsm.shape[0], lsm.shape[1], lsm.shape[2])) + + # shard + lsm = lsm[:, :, start_x:end_x, start_y:end_y] + + if static_features is None: + static_features = lsm + else: + static_features = torch.cat([static_features, lsm], dim=1) + + self.add_static_features = False + if static_features is not None: + self.add_static_features = True + self.register_buffer("static_features", static_features) + + # if self.poltor_decomp: + # assert(hasattr(params, 'wind_channels')) + # wind_channels = torch.as_tensor(params.wind_channels) + # self.register_buffer("wind_channels", wind_channels) + + # self.forward_transform = RealVectorSHT(*self.img_size, grid=self.input_grid).float() + # self.inverse_transform = InverseRealSHT(*self.img_size, grid=self.output_grid).float() + + def _flatten_history(self, x, y): # pragma: no cover + # flatten input + if x.dim() == 5: + b_, t_, c_, h_, w_ = x.shape + x = torch.reshape(x, (b_, t_ * c_, h_, w_)) + + # flatten target + if (y is not None) and (y.dim() == 5): + b_, t_, c_, h_, w_ = y.shape + y = torch.reshape(y, (b_, t_ * c_, h_, w_)) + + return x, y + + def _add_static_features(self, x, y): # pragma: no cover + # we need to replicate the grid for each batch: + static = torch.tile(self.static_features, dims=(x.shape[0], 1, 1, 1)) + x = torch.cat([x, static], dim=1) + return x, y + + def _nchw_to_nhwc(self, x, y): # pragma: no cover + x = x.to(memory_format=torch.channels_last) + if y is not None: + y = y.to(memory_format=torch.channels_last) + + return x, y + + def append_history(self, x1, x2): # pragma: no cover + """ + Appends history to the main input. + Without history, just returns the second tensor (x2). + """ + # + # with grid if requested + if self.n_history == 0: + return x2 + + # if grid is added, strip it off first + if self.add_static_features: + nfeat = self.static_features.shape[1] + x1 = x1[:, :-nfeat, :, :] + + # this is more complicated + if x1.dim() == 4: + b_, c_, h_, w_ = x1.shape + x1 = torch.reshape(x1, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)) + + if x2.dim() == 4: + b_, c_, h_, w_ = x2.shape + x2 = torch.reshape(x2, (b_, 1, c_, h_, w_)) + + # append + res = torch.cat([x1[:, 1:, :, :, :], x2], dim=1) + + # flatten again + b_, t_, c_, h_, w_ = res.shape + res = torch.reshape(res, (b_, t_ * c_, h_, w_)) + + return res + + # def _poltor_decompose(self, x, y): + # b_, c_, h_, w_ = x.shape + # xu = x[:, self.wind_channels, :, :] + # xu = xu.reshape(b_, -1, 2, h_, w_) + # xu = self.inverse_transform(self.forward_transform(xu)) + # xu = xu.reshape(b_, -1, h_, w_) + # x[:, self.wind_channels, :, :] = xu + # return x, y + + # forward method for additional variable fiels in x and y, + # for example zenith angle: + # def forward(self, x, y, xz, yz): + # x = torch.cat([x, xz], dim=2) + # + # return x, y + + def append_channels(self, x, xc): # pragma: no cover + """Appends channels""" + if x.dim() == 4: + b_, c_, h_, w_ = x.shape + x = torch.reshape(x, (b_, (self.n_history + 1), c_ // (self.n_history + 1), h_, w_)) + + xo = torch.cat([x, xc], dim=2) + + if x.dim() == 4: + xo, _ = self._flatten_history(xo, None) + + return xo + + def forward(self, x, y=None, xz=None, yz=None): # pragma: no cover + if xz is not None: + x = self.append_channels(x, xz) + + return self._forward(x, y) + + def _forward(self, x, y): # pragma: no cover + # we always want to flatten the history, even if its a singleton + x, y = self._flatten_history(x, y) + + if self.add_static_features: + x, y = self._add_static_features(x, y) + + # if self.poltor_decomp: + # x, y = self._poltor_decompose(x, y) + + if self.transform_to_nhwc: + x, y = self._nchw_to_nhwc(x, y) + + return x, y + + +def get_preprocessor(params): # pragma: no cover + """Returns the preprocessor module.""" + return Preprocessor2D(params) + + +# class Postprocessor2D(nn.Module): +# def __init__(self, params): +# super(Postprocessor2D, self).__init__() + +# self.poltor_decomp = params.poltor_decomp +# self.img_size = (params.img_shape_x, params.img_shape_y) if hasattr(params, "img_shape_x") and hasattr(params, "img_shape_y") else img_size +# self.input_grid = "equiangular" +# self.output_grid = "equiangular" + +# if self.poltor_decomp: +# assert(hasattr(params, 'wind_channels')) +# wind_channels = torch.as_tensor(params.wind_channels) +# self.register_buffer("wind_channels", wind_channels) + +# self.forward_transform = RealSHT(*self.img_size, grid=self.input_grid).float() +# self.inverse_transform = InverseRealVectorSHT(*self.img_size, grid=self.output_grid).float() + +# def _poltor_recompose(self, x): +# b_, c_, h_, w_ = x.shape +# xu = x[:, self.wind_channels, :, :] +# xu = xu.reshape(b_, -1, 2, h_, w_) +# xu = self.inverse_transform(self.forward_transform(xu)) +# xu = xu.reshape(b_, -1, h_, w_) +# x[:, self.wind_channels, :, :] = xu +# return x + +# def forward(self, x): + +# if self.poltor_decomp: +# x = self._poltor_recompose(x) + +# return x + +# def get_postprocessor(params): +# return Postprocessor2D(params) diff --git a/src/models/sfno/s2convolutions.py b/src/models/sfno/s2convolutions.py new file mode 100644 index 0000000..a23635c --- /dev/null +++ b/src/models/sfno/s2convolutions.py @@ -0,0 +1,548 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# import FactorizedTensor from tensorly for tensorized operations +import tensorly as tl +import torch +import torch.nn as nn +from torch import amp + + +tl.set_backend("pytorch") +# from tensorly.plugins import use_opt_einsum +# use_opt_einsum('optimal') +import torch_harmonics as th +import torch_harmonics.distributed as thd + +# import convenience functions for factorized tensors +from modulus.models.sfno.activations import ComplexReLU + +# for the experimental module +from modulus.models.sfno.contractions import ( + _contract_localconv_fwd, + compl_exp_mul2d_fwd, + compl_exp_muladd2d_fwd, + compl_mul2d_fwd, + compl_muladd2d_fwd, + real_mul2d_fwd, + real_muladd2d_fwd, +) +from modulus.models.sfno.factorizations import get_contract_fun +from tltorch.factorized_tensors.core import FactorizedTensor + + +class SpectralConvS2(nn.Module): + """ + Spectral Convolution according to Driscoll & Healy. Designed for convolutions on + the two-sphere S2 using the Spherical Harmonic Transforms in torch-harmonics, but + supports convolutions on the periodic domain via the RealFFT2 and InverseRealFFT2 + wrappers. + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + scale="auto", + operator_type="diagonal", + rank=0.2, + factorization=None, + separable=False, + decomposition_kwargs=dict(), + bias=False, + use_tensorly=True, + ): # pragma: no cover + super(SpectralConvS2, self).__init__() + + if scale == "auto": + scale = 1 / (in_channels * out_channels) + + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + + self.scale_residual = ( + (self.forward_transform.nlat != self.inverse_transform.nlat) + or (self.forward_transform.nlon != self.inverse_transform.nlon) + or (self.forward_transform.grid != self.inverse_transform.grid) + ) + + # Make sure we are using a Complex Factorized Tensor + if factorization is None: + factorization = "Dense" # No factorization + + if not factorization.lower().startswith("complex"): + factorization = f"Complex{factorization}" + + # remember factorization details + self.operator_type = operator_type + self.rank = rank + self.factorization = factorization + self.separable = separable + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + weight_shape = [in_channels] + + if not self.separable: + weight_shape += [out_channels] + + if isinstance(self.inverse_transform, thd.DistributedInverseRealSHT): + self.modes_lat_local = self.inverse_transform.lmax_local + self.modes_lon_local = self.inverse_transform.mmax_local + self.lpad_local = self.inverse_transform.lpad_local + self.mpad_local = self.inverse_transform.mpad_local + else: + self.modes_lat_local = self.modes_lat + self.modes_lon_local = self.modes_lon + self.lpad = 0 + self.mpad = 0 + + # padded weights + # if self.operator_type == 'diagonal': + # weight_shape += [self.modes_lat_local+self.lpad_local, self.modes_lon_local+self.mpad_local] + # elif self.operator_type == 'dhconv': + # weight_shape += [self.modes_lat_local+self.lpad_local] + # else: + # raise ValueError(f"Unsupported operator type f{self.operator_type}") + + # unpadded weights + if self.operator_type == "diagonal": + weight_shape += [self.modes_lat_local, self.modes_lon_local] + elif self.operator_type == "dhconv": + weight_shape += [self.modes_lat_local] + else: + raise ValueError(f"Unsupported operator type f{self.operator_type}") + + if use_tensorly: + # form weight tensors + self.weight = FactorizedTensor.new( + weight_shape, + rank=self.rank, + factorization=factorization, + fixed_rank_modes=False, + **decomposition_kwargs, + ) + # initialization of weights + self.weight.normal_(0, scale) + else: + assert factorization == "ComplexDense" + self.weight = nn.Parameter(scale * torch.randn(*weight_shape, 2)) + if self.operator_type == "dhconv": + self.weight.is_shared_mp = ["matmul", "w"] + else: + self.weight.is_shared_mp = ["matmul"] + + # get the contraction handle + self._contract = get_contract_fun(self.weight, implementation="factorized", separable=separable) + + if bias: + self.bias = nn.Parameter(scale * torch.zeros(1, out_channels, 1, 1)) + + def forward(self, x): # pragma: no cover + dtype = x.dtype + residual = x + x = x.float() + B, C, H, W = x.shape + + with amp.autocast("cuda", enabled=False): + x = self.forward_transform(x) + if self.scale_residual: + x = x.contiguous() + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + # approach with unpadded weights + xp = torch.zeros_like(x) + xp[..., : self.modes_lat_local, : self.modes_lon_local] = self._contract( + x[..., : self.modes_lat_local, : self.modes_lon_local], + self.weight, + separable=self.separable, + operator_type=self.operator_type, + ) + x = xp.contiguous() + + # # approach with padded weights + # x = self._contract(x, self.weight, separable=self.separable, operator_type=self.operator_type) + # x = x.contiguous() + + with amp.autocast("cuda", enabled=False): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + x = x + self.bias + + x = x.type(dtype) + + return x, residual + + +class LocalConvS2(nn.Module): + """ + S2 Convolution according to Driscoll & Healy + """ + + def __init__( + self, + forward_transform, + inverse_transform, + in_channels, + out_channels, + nradius=120, + scale="auto", + bias=False, + ): # pragma: no cover + super(LocalConvS2, self).__init__() + + if scale == "auto": + scale = 1 / (in_channels * out_channels) + + self.in_channels = in_channels + self.out_channels = out_channels + self.nradius = nradius + + self.forward_transform = forward_transform + self.zonal_transform = th.RealSHT( + forward_transform.nlat, + 1, + lmax=forward_transform.lmax, + mmax=1, + grid=forward_transform.grid, + ).float() + self.inverse_transform = inverse_transform + + self.modes_lat = self.inverse_transform.lmax + self.modes_lon = self.inverse_transform.mmax + self.output_dims = (self.inverse_transform.nlat, self.inverse_transform.nlon) + + assert self.inverse_transform.lmax == self.modes_lat + assert self.inverse_transform.mmax == self.modes_lon + + self.weight = nn.Parameter(scale * torch.randn(in_channels, out_channels, nradius, 1)) + + self._contract = _contract_localconv_fwd + + if bias: + self.bias = nn.Parameter(scale * torch.randn(1, out_channels, *self.output_dims)) + + def forward(self, x): # pragma: no cover + dtype = x.dtype + x = x.float() + B, C, H, W = x.shape + + with amp.autocast("cuda", enabled=False): + f = torch.zeros( + (self.in_channels, self.out_channels, H, 1), + dtype=x.dtype, + device=x.device, + ) + f[..., : self.nradius, :] = self.weight + + x = self.forward_transform(x) + f = self.zonal_transform(f)[..., :, 0] + + x = torch.view_as_real(x) + f = torch.view_as_real(f) + + x = self._contract(x, f) + x = x.contiguous() + + x = torch.view_as_complex(x) + + with amp.autocast("cuda", enabled=False): + x = self.inverse_transform(x) + + if hasattr(self, "bias"): + x = x + self.bias + + x = x.type(dtype) + + return x + + +class SpectralAttentionS2(nn.Module): + """ + Spherical non-linear FNO layer + """ + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + operator_type="diagonal", + sparsity_threshold=0.0, + hidden_size_factor=2, + complex_activation="real", + scale="auto", + bias=False, + spectral_layers=1, + drop_rate=0.0, + ): # pragma: no cover + super(SpectralAttentionS2, self).__init__() + + self.embed_dim = embed_dim + self.sparsity_threshold = sparsity_threshold + self.operator_type = operator_type + self.spectral_layers = spectral_layers + + if scale == "auto": + self.scale = 1 / (embed_dim * embed_dim) + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or ( + self.forward_transform.nlon != self.inverse_transform.nlon + ) + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + hidden_size = int(hidden_size_factor * self.embed_dim) + + if operator_type == "diagonal": + self.mul_add_handle = compl_muladd2d_fwd + self.mul_handle = compl_mul2d_fwd + + # weights + w = [self.scale * torch.randn(self.embed_dim, hidden_size, 2)] + for layer in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(hidden_size, hidden_size, 2)) + self.w = nn.ParameterList(w) + + self.wout = nn.Parameter(self.scale * torch.randn(hidden_size, self.embed_dim, 2)) + + if bias: + self.b = nn.ParameterList( + [self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)] + ) + + self.activations = nn.ModuleList([]) + for layer in range(0, self.spectral_layers): + self.activations.append( + ComplexReLU( + mode=complex_activation, + bias_shape=(hidden_size, 1, 1), + scale=self.scale, + ) + ) + + elif operator_type == "l-dependant": + self.mul_add_handle = compl_exp_muladd2d_fwd + self.mul_handle = compl_exp_mul2d_fwd + + # weights + w = [self.scale * torch.randn(self.modes_lat, self.embed_dim, hidden_size, 2)] + for layer in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(self.modes_lat, hidden_size, hidden_size, 2)) + self.w = nn.ParameterList(w) + + if bias: + self.b = nn.ParameterList( + [self.scale * torch.randn(hidden_size, 1, 1, 2) for _ in range(self.spectral_layers)] + ) + + self.wout = nn.Parameter(self.scale * torch.randn(self.modes_lat, hidden_size, self.embed_dim, 2)) + + self.activations = nn.ModuleList([]) + for layer in range(0, self.spectral_layers): + self.activations.append( + ComplexReLU( + mode=complex_activation, + bias_shape=(hidden_size, 1, 1), + scale=self.scale, + ) + ) + + else: + raise ValueError("Unknown operator type") + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + def forward_mlp(self, x): # pragma: no cover + """forward pass of the MLP""" + B, C, H, W = x.shape + + if self.operator_type == "block-separable": + x = x.permute(0, 3, 1, 2) + + xr = torch.view_as_real(x) + + for layer in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle(xr, self.w[layer], self.b[layer]) + else: + xr = self.mul_handle(xr, self.w[layer]) + xr = torch.view_as_complex(xr) + xr = self.activations[layer](xr) + xr = self.drop(xr) + xr = torch.view_as_real(xr) + + # final MLP + x = self.mul_handle(xr, self.wout) + + x = torch.view_as_complex(x) + + if self.operator_type == "block-separable": + x = x.permute(0, 2, 3, 1) + + return x + + def forward(self, x): # pragma: no cover + dtype = x.dtype + residual = x + x = x.to(torch.float32) + + # FWD transform + with amp.autocast("cuda", enabled=False): + x = self.forward_transform(x) + if self.scale_residual: + x = x.contiguous() + residual = self.inverse_transform(x) + residual = residual.to(dtype) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + x = x.contiguous() + with amp.autocast("cuda", enabled=False): + x = self.inverse_transform(x) + + # cast back to initial precision + x = x.to(dtype) + + return x, residual + + +class RealSpectralAttentionS2(nn.Module): + """ + Non-linear SFNO layer using a real-valued NN instead of a complex one + """ + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + operator_type="diagonal", + sparsity_threshold=0.0, + hidden_size_factor=2, + complex_activation="real", + scale="auto", + bias=False, + spectral_layers=1, + drop_rate=0.0, + ): # pragma: no cover + super(RealSpectralAttentionS2, self).__init__() + + self.embed_dim = embed_dim + self.sparsity_threshold = sparsity_threshold + self.operator_type = operator_type + self.spectral_layers = spectral_layers + + if scale == "auto": + self.scale = 1 / (embed_dim * embed_dim) + + self.modes_lat = forward_transform.lmax + self.modes_lon = forward_transform.mmax + + # only storing the forward handle to be able to call it + self.forward_transform = forward_transform + self.inverse_transform = inverse_transform + + self.scale_residual = (self.forward_transform.nlat != self.inverse_transform.nlat) or ( + self.forward_transform.nlon != self.inverse_transform.nlon + ) + + assert inverse_transform.lmax == self.modes_lat + assert inverse_transform.mmax == self.modes_lon + + hidden_size = int(hidden_size_factor * self.embed_dim * 2) + + self.mul_add_handle = real_muladd2d_fwd + self.mul_handle = real_mul2d_fwd + + # weights + w = [self.scale * torch.randn(2 * self.embed_dim, hidden_size)] + for layer in range(1, self.spectral_layers): + w.append(self.scale * torch.randn(hidden_size, hidden_size)) + self.w = nn.ParameterList(w) + + self.wout = nn.Parameter(self.scale * torch.randn(hidden_size, 2 * self.embed_dim)) + + if bias: + self.b = nn.ParameterList( + [self.scale * torch.randn(hidden_size, 1, 1) for _ in range(self.spectral_layers)] + ) + + self.activations = nn.ModuleList([]) + for layer in range(0, self.spectral_layers): + self.activations.append(nn.ReLU()) + + self.drop = nn.Dropout(drop_rate) if drop_rate > 0.0 else nn.Identity() + + def forward_mlp(self, x): # pragma: no cover + """forward pass of the MLP""" + B, C, H, W = x.shape + + xr = torch.view_as_real(x) + xr = xr.permute(0, 1, 4, 2, 3).reshape(B, C * 2, H, W) + + for layer in range(self.spectral_layers): + if hasattr(self, "b"): + xr = self.mul_add_handle(xr, self.w[layer], self.b[layer]) + else: + xr = self.mul_handle(xr, self.w[layer]) + xr = self.activations[layer](xr) + xr = self.drop(xr) + + # final MLP + xr = self.mul_handle(xr, self.wout) + + xr = xr.reshape(B, C, 2, H, W).permute(0, 1, 3, 4, 2) + + x = torch.view_as_complex(xr) + + return x + + def forward(self, x): # pragma: no cover + dtype = x.dtype + x = x.to(torch.float32) + + # FWD transform + with amp.autocast("cuda", enabled=False): + x = self.forward_transform(x) + + # MLP + x = self.forward_mlp(x) + + # BWD transform + with amp.autocast("cuda", enabled=False): + x = self.inverse_transform(x) + + # cast back to initial precision + x = x.to(dtype) + + return x diff --git a/src/models/sfno/sfnonet.py b/src/models/sfno/sfnonet.py new file mode 100644 index 0000000..189c4d7 --- /dev/null +++ b/src/models/sfno/sfnonet.py @@ -0,0 +1,841 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from functools import partial +from typing import Any, Literal + +import torch +import torch.nn as nn + +# get spectral transforms from torch_harmonics +import torch_harmonics as th +import torch_harmonics.distributed as thd + + +# layer normalization +try: + from apex.normalization import FusedLayerNorm +except ImportError: + from torch.nn import LayerNorm as FusedLayerNorm # type: ignore +from einops import rearrange + +# helpers +from modulus.models.sfno.initialization import trunc_normal_ +from torch.utils.checkpoint import checkpoint + +from src.models._base_model import BaseModel +from src.models.modules.drop_path import DropPath +from src.models.modules.misc import get_time_embedder + +# more distributed stuff +from src.models.sfno.distributed import comm +from src.models.sfno.distributed.layer_norm import DistributedInstanceNorm2d + +# wrap fft, to unify interface to spectral transforms +from src.models.sfno.distributed.layers import ( + DistributedInverseRealFFT2, + DistributedMLP, + DistributedRealFFT2, +) + +# import global convolution and non-linear spectral layers +from src.models.sfno.layers import MLP, RealFFT2, SpectralAttention2d +from src.models.sfno.s2convolutions import SpectralAttentionS2, SpectralConvS2 +from src.utilities.utils import raise_error_if_invalid_value + + +# from src.models.module import Module +# from src.models.meta import ModelMetaData + + +# @dataclass +# class MetaData(ModelMetaData): +# name: str = "SFNO" +# # Optimization +# jit: bool = False +# cuda_graphs: bool = True +# amp_cpu: bool = True +# amp_gpu: bool = True +# torch_fx: bool = False +# # Inference +# onnx: bool = False +# # Physics informed +# func_torch: bool = False +# auto_grad: bool = False + + +class SpectralFilterLayer(nn.Module): + """Spectral filter layer""" + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + filter_type="linear", + operator_type="block-diagonal", + sparsity_threshold=0.0, + use_complex_kernels=True, + hidden_size_factor=1, + rank=1.0, + factorization=None, + separable=False, + complex_network=True, + complex_activation="real", + spectral_layers=1, + drop_rate=0.0, + ): + super(SpectralFilterLayer, self).__init__() + + if filter_type == "non-linear" and ( + isinstance(forward_transform, th.RealSHT) or isinstance(forward_transform, thd.DistributedRealSHT) + ): + self.filter = SpectralAttentionS2( + forward_transform, + inverse_transform, + embed_dim, + operator_type=operator_type, + sparsity_threshold=sparsity_threshold, + hidden_size_factor=hidden_size_factor, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + drop_rate=drop_rate, + bias=False, + ) + + elif filter_type == "non-linear" and ( + isinstance(forward_transform, RealFFT2) or isinstance(forward_transform, DistributedRealFFT2) + ): + self.filter = SpectralAttention2d( + forward_transform, + inverse_transform, + embed_dim, + sparsity_threshold=sparsity_threshold, + hidden_size_factor=hidden_size_factor, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + drop_rate=drop_rate, + bias=False, + ) + + # spectral transform is passed to the module + elif filter_type == "linear" and ( + isinstance(forward_transform, th.RealSHT) or isinstance(forward_transform, thd.DistributedRealSHT) + ): + if drop_rate > 0.0: + print("Dropout is not used for linear filters!") + self.filter = SpectralConvS2( + forward_transform, + inverse_transform, + embed_dim, + embed_dim, + operator_type=operator_type, + rank=rank, + factorization=factorization, + separable=separable, + bias=True, + use_tensorly=False if factorization is None else True, + ) + + else: + raise (NotImplementedError) + + def forward(self, x): + return self.filter(x) + + +class FourierNeuralOperatorBlock(nn.Module): + """Fourier Neural Operator Block""" + + def __init__( + self, + forward_transform, + inverse_transform, + embed_dim, + filter_type="linear", + operator_type="diagonal", + mlp_ratio=2.0, + drop_rate_filter=0.0, + drop_rate_mlp=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=(nn.LayerNorm, nn.LayerNorm), + sparsity_threshold=0.0, + use_complex_kernels=True, + rank=1.0, + factorization=None, + separable=False, + inner_skip="linear", + outer_skip=None, # None, nn.linear or nn.Identity + concat_skip=False, + use_mlp=False, + complex_network=True, + complex_activation="real", + spectral_layers=1, + checkpointing=0, + time_emb_dim: int = None, + time_scale_shift_before_filter: bool = True, + ): + super(FourierNeuralOperatorBlock, self).__init__() + + if (comm.get_size("h") > 1) or (comm.get_size("w") > 1): + self.input_shape_loc = ( + forward_transform.nlat_local, + forward_transform.nlon_local, + ) + self.output_shape_loc = ( + inverse_transform.nlat_local, + inverse_transform.nlon_local, + ) + else: + self.input_shape_loc = (forward_transform.nlat, forward_transform.nlon) + self.output_shape_loc = (inverse_transform.nlat, inverse_transform.nlon) + + # norm layer + self.norm0 = norm_layer[0]() + + # time embedding + if time_emb_dim is not None: + self.time_mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, embed_dim * 2), # 2 for scale and shift + ) + self.time_scale_shift_before_filter = time_scale_shift_before_filter + else: + self.time_mlp = None + self.time_scale_shift_before_filter = False + + # convolution layer + self.filter = SpectralFilterLayer( + forward_transform, + inverse_transform, + embed_dim, + filter_type, + operator_type, + sparsity_threshold, + use_complex_kernels=use_complex_kernels, + hidden_size_factor=mlp_ratio, + rank=rank, + factorization=factorization, + separable=separable, + complex_network=complex_network, + complex_activation=complex_activation, + spectral_layers=spectral_layers, + drop_rate=drop_rate_filter, + ) + + if inner_skip == "linear": + self.inner_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1) + elif inner_skip == "identity": + self.inner_skip = nn.Identity() + + self.concat_skip = concat_skip + + if concat_skip and inner_skip is not None: + self.inner_skip_conv = nn.Conv2d(2 * embed_dim, embed_dim, 1, bias=False) + + if filter_type == "linear" or filter_type == "real linear": + self.act_layer = act_layer() + + # dropout + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + # norm layer + self.norm1 = norm_layer[1]() + + if use_mlp: + MLPH = DistributedMLP if (comm.get_size("matmul") > 1) else MLP + mlp_hidden_dim = int(embed_dim * mlp_ratio) + self.mlp = MLPH( + in_features=embed_dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop_rate=drop_rate_mlp, + checkpointing=checkpointing, + ) + + if outer_skip == "linear": + self.outer_skip = nn.Conv2d(embed_dim, embed_dim, 1, 1) + elif outer_skip == "identity": + self.outer_skip = nn.Identity() + elif outer_skip is None: + self.outer_skip = None + else: + raise NotImplementedError(f"outer_skip={outer_skip} is not implemented") + + if concat_skip and outer_skip is not None: + self.outer_skip_conv = nn.Conv2d(2 * embed_dim, embed_dim, 1, bias=False) + + def time_scale_shift(self, x, time_emb): + assert time_emb is not None, "time_emb is None but time_scale_shift is called" + time_emb = self.time_mlp(time_emb) + time_emb = rearrange(time_emb, "b c -> b c 1 1") + scale, shift = time_emb.chunk(2, dim=1) # split into scale and shift (channel dim) + # shapesL scale/shift: (b, 1, 1, hidden/emb/channel-dim), x: (b, h, w, hidden/emb/channel-dim) + x = x * (scale + 1) + shift + return x + + def forward(self, x, time_emb=None): + x_norm = torch.zeros_like(x) + if x_norm.shape == x.shape: + x_norm = self.norm0(x) + else: + x_norm[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] = self.norm0( + x[..., : self.input_shape_loc[0], : self.input_shape_loc[1]] + ) + + if self.time_scale_shift_before_filter and self.time_mlp is not None: + x_norm = self.time_scale_shift(x_norm, time_emb) + + x, residual = self.filter(x_norm) + + if hasattr(self, "inner_skip"): + if self.concat_skip: + x = torch.cat((x, self.inner_skip(residual)), dim=1) + x = self.inner_skip_conv(x) + else: + x = x + self.inner_skip(residual) + + if hasattr(self, "act_layer"): + x = self.act_layer(x) + + x_norm = torch.zeros_like(x) + if x_norm.shape == x.shape: + x_norm = self.norm1(x) + else: + x_norm[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] = self.norm1( + x[..., : self.output_shape_loc[0], : self.output_shape_loc[1]] + ) + x = x_norm + + if not self.time_scale_shift_before_filter and self.time_mlp is not None: + x = self.time_scale_shift(x, time_emb) + + if hasattr(self, "mlp"): + x = self.mlp(x) + + x = self.drop_path(x) + + if self.outer_skip is not None: + if self.concat_skip: + x = torch.cat((x, self.outer_skip(residual)), dim=1) + x = self.outer_skip_conv(x) + else: + x = x + self.outer_skip(residual) + + return x + + +class SphericalFourierNeuralOperatorNet(BaseModel): + """ + Spherical Fourier Neural Operator Network + + Parameters + ---------- + params : dict + Dictionary of parameters + spectral_transform : str, optional + Type of spectral transformation to use, by default "sht" + filter_type : str, optional + Type of filter to use ('linear', 'non-linear'), by default "non-linear" + operator_type : str, optional + Type of operator to use ('diaginal', 'dhconv'), by default "diagonal" + img_shape : tuple, optional + Shape of the input channels, by default (721, 1440) + scale_factor : int, optional + Scale factor to use, by default 16 + in_chans : int, optional + Number of input channels, by default 2 + out_chans : int, optional + Number of output channels, by default 2 + embed_dim : int, optional + Dimension of the embeddings, by default 256 + num_layers : int, optional + Number of layers in the network, by default 12 + use_mlp : int, optional + Whether to use MLP, by default True + mlp_ratio : int, optional + Ratio of MLP to use, by default 2.0 + activation_function : str, optional + Activation function to use, by default "gelu" + encoder_layers : int, optional + Number of layers in the encoder, by default 1 + pos_embed : bool, optional + Whether to use positional embedding, by default True + dropout : float, optional + Dropout rate, by default 0.0 + drop_path_rate : float, optional + Dropout path rate, by default 0.0 + num_blocks : int, optional + Number of blocks in the network, by default 16 + sparsity_threshold : float, optional + Threshold for sparsity, by default 0.0 + normalization_layer : str, optional + Type of normalization layer to use ("layer_norm", "instance_norm", "none"), by default "instance_norm" + hard_thresholding_fraction : float, optional + Fraction of hard thresholding to apply, by default 1.0 + use_complex_kernels : bool, optional + Whether to use complex kernels, by default True + big_skip : bool, optional + Whether to use big skip connections, by default True + rank : float, optional + Rank of the approximation, by default 1.0 + factorization : Any, optional + Type of factorization to use, by default None + separable : bool, optional + Whether to use separable convolutions, by default False + complex_network : bool, optional + Whether to use a complex network architecture, by default True + complex_activation : str, optional + Type of complex activation function to use, by default "real" + spectral_layers : int, optional + Number of spectral layers, by default 3 + checkpointing : int, optional + Number of checkpointing segments, by default 0 + + Example: + -------- + >>> from modulus.models.sfno.sfnonet import SphericalFourierNeuralOperatorNet as SFNO + >>> model = SFNO( + ... params={}, + ... img_shape=(8, 16), + ... scale_factor=4, + ... in_chans=2, + ... out_chans=2, + ... embed_dim=16, + ... num_layers=2, + ... encoder_layers=1, + ... num_blocks=4, + ... spectral_layers=2, + ... use_mlp=True,) + >>> model(torch.randn(1, 2, 8, 16)).shape + torch.Size([1, 2, 8, 16]) + """ + + def __init__( + self, + params: dict = None, + spectral_transform: str = "sht", + filter_type: str = "linear", + operator_type: str = "diagonal", + # img_shape: Tuple[int] = (721, 1440), + scale_factor: int = 16, + # in_chans: int = 2, + # out_chans: int = 2, + embed_dim: int = 256, + num_layers: int = 12, + use_mlp: int = True, + mlp_ratio: int = 2.0, + activation_function: str = "gelu", + encoder_layers: int = 1, + pos_embed: bool = True, + dropout_filter: float = 0.0, + dropout_mlp: float = 0.0, + pos_emb_dropout: float = 0.0, + drop_path_rate: float = 0.0, + num_blocks: int = 16, + sparsity_threshold: float = 0.0, + normalization_layer: str = "instance_norm", + hard_thresholding_fraction: float = 1.0, + use_complex_kernels: bool = True, + big_skip: bool = True, + rank: float = 1.0, + factorization: Any = None, + separable: bool = False, + complex_network: bool = True, + complex_activation: str = "real", + spectral_layers: int = 3, + checkpointing: int = 0, + with_time_emb: bool = False, + time_dim_mult: int = 2, + time_rescale: bool = False, + time_scale_shift_before_filter: bool = True, + data_grid: Literal["legendre-gauss", "equiangular"] = "equiangular", + **kwargs, + ): + super().__init__(**kwargs) + if self.hparams.debug_mode: + self.log_text.info(f"Using debug mode for SFNO.") + embed_dim = self.hparams.embed_dim = 16 + num_layers = self.hparams.num_layers = 2 + # super(SphericalFourierNeuralOperatorNet, self).__init__(meta=MetaData()) + params = params or {} + self.params = params + self.spectral_transform = ( + params.spectral_transform if hasattr(params, "spectral_transform") else spectral_transform + ) + self.filter_type = params.filter_type if hasattr(params, "filter_type") else filter_type + self.operator_type = params.operator_type if hasattr(params, "operator_type") else operator_type + self.img_shape = ( + (params.img_shape_x, params.img_shape_y) + if hasattr(params, "img_shape_x") and hasattr(params, "img_shape_y") + else self.spatial_shape_in + ) + self.scale_factor = params.scale_factor if hasattr(params, "scale_factor") else scale_factor + self.in_chans = ( + params.N_in_channels + if hasattr(params, "N_in_channels") + else self.num_input_channels + self.num_conditional_channels + ) + self.out_chans = params.N_out_channels if hasattr(params, "N_out_channels") else self.num_output_channels + self.embed_dim = self.num_features = params.embed_dim if hasattr(params, "embed_dim") else embed_dim + self.num_layers = params.num_layers if hasattr(params, "num_layers") else num_layers + self.num_blocks = params.num_blocks if hasattr(params, "num_blocks") else num_blocks + self.hard_thresholding_fraction = ( + params.hard_thresholding_fraction + if hasattr(params, "hard_thresholding_fraction") + else hard_thresholding_fraction + ) + self.normalization_layer = ( + params.normalization_layer if hasattr(params, "normalization_layer") else normalization_layer + ) + self.use_mlp = params.use_mlp if hasattr(params, "use_mlp") else use_mlp + self.activation_function = ( + params.activation_function if hasattr(params, "activation_function") else activation_function + ) + self.encoder_layers = params.encoder_layers if hasattr(params, "encoder_layers") else encoder_layers + self.pos_embed = params.pos_embed if hasattr(params, "pos_embed") else pos_embed + self.big_skip = params.big_skip if hasattr(params, "big_skip") else big_skip + self.rank = params.rank if hasattr(params, "rank") else rank + self.factorization = params.factorization if hasattr(params, "factorization") else factorization + self.separable = params.separable if hasattr(params, "separable") else separable + self.complex_network = params.complex_network if hasattr(params, "complex_network") else complex_network + self.complex_activation = ( + params.complex_activation if hasattr(params, "complex_activation") else complex_activation + ) + self.spectral_layers = params.spectral_layers if hasattr(params, "spectral_layers") else spectral_layers + self.checkpointing = params.checkpointing if hasattr(params, "checkpointing") else checkpointing + # self.pretrain_encoding = params.pretrain_encoding if hasattr(params, "pretrain_encoding") else False + + # compute the downscaled image size + self.h = int(self.img_shape[0] // self.scale_factor) + self.w = int(self.img_shape[1] // self.scale_factor) + + # Compute the maximum frequencies in h and in w + modes_lat = int(self.h * self.hard_thresholding_fraction) + modes_lon = int((self.w // 2 + 1) * self.hard_thresholding_fraction) + + # determine the global padding + img_dist_h = (self.img_shape[0] + comm.get_size("h") - 1) // comm.get_size("h") + img_dist_w = (self.img_shape[1] + comm.get_size("w") - 1) // comm.get_size("w") + self.padding = ( + img_dist_h * comm.get_size("h") - self.img_shape[0], + img_dist_w * comm.get_size("w") - self.img_shape[1], + ) + + # prepare the spectral transforms + if self.spectral_transform == "sht": + sht_handle = th.RealSHT + isht_handle = th.InverseRealSHT + + # parallelism + if (comm.get_size("h") > 1) or (comm.get_size("w") > 1): + polar_group = None if (comm.get_size("h") == 1) else comm.get_group("h") + azimuth_group = None if (comm.get_size("w") == 1) else comm.get_group("w") + thd.init(polar_group, azimuth_group) + sht_handle = thd.DistributedRealSHT + isht_handle = thd.DistributedInverseRealSHT + + # set up + self.trans_down = sht_handle(*self.img_shape, lmax=modes_lat, mmax=modes_lon, grid=data_grid).float() + self.itrans_up = isht_handle(*self.img_shape, lmax=modes_lat, mmax=modes_lon, grid=data_grid).float() + self.trans = sht_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float() + self.itrans = isht_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon, grid="legendre-gauss").float() + + elif self.spectral_transform == "fft": + fft_handle = th.RealFFT2 + ifft_handle = th.InverseRealFFT2 + + # effective image size: + self.img_shape_eff = [ + self.img_shape[0] + self.padding[0], + self.img_shape[1] + self.padding[1], + ] + self.img_shape_loc = [ + self.img_shape_eff[0] // comm.get_size("h"), + self.img_shape_eff[1] // comm.get_size("w"), + ] + + if (comm.get_size("h") > 1) or (comm.get_size("w") > 1): + fft_handle = DistributedRealFFT2 + ifft_handle = DistributedInverseRealFFT2 + + self.trans_down = fft_handle(*self.img_shape_eff, lmax=modes_lat, mmax=modes_lon).float() + self.itrans_up = ifft_handle(*self.img_shape_eff, lmax=modes_lat, mmax=modes_lon).float() + self.trans = fft_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float() + self.itrans = ifft_handle(self.h, self.w, lmax=modes_lat, mmax=modes_lon).float() + else: + raise (ValueError("Unknown spectral transform")) + + # use the SHT/FFT to compute the local, downscaled grid dimensions + if (comm.get_size("h") > 1) or (comm.get_size("w") > 1): + self.img_shape_loc = ( + self.trans_down.nlat_local, + self.trans_down.nlon_local, + ) + self.img_shape_eff = [ + self.trans_down.nlat_local + self.trans_down.nlatpad_local, + self.trans_down.nlon_local + self.trans_down.nlonpad_local, + ] + self.h_loc = self.itrans.nlat_local + self.w_loc = self.itrans.nlon_local + else: + self.img_shape_loc = (self.trans_down.nlat, self.trans_down.nlon) + self.img_shape_eff = (self.trans_down.nlat, self.trans_down.nlon) + self.h_loc = self.itrans.nlat + self.w_loc = self.itrans.nlon + + # determine activation function + if self.activation_function == "relu": + self.activation_function = nn.ReLU + elif self.activation_function == "gelu": + self.activation_function = nn.GELU + elif self.activation_function == "silu": + self.activation_function = nn.SiLU + else: + raise ValueError(f"Unknown activation function {self.activation_function}") + + # encoder + encoder_hidden_dim = self.embed_dim + current_dim = self.in_chans + encoder_modules = [] + for i in range(self.encoder_layers): + encoder_modules.append(nn.Conv2d(current_dim, encoder_hidden_dim, 1, bias=True)) + encoder_modules.append(self.activation_function()) + current_dim = encoder_hidden_dim + encoder_modules.append(nn.Conv2d(current_dim, self.embed_dim, 1, bias=False)) + self.encoder = nn.Sequential(*encoder_modules) + + # dropout + self.pos_drop = nn.Dropout(p=pos_emb_dropout) if pos_emb_dropout > 0.0 else nn.Identity() + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, self.num_layers)] + + # pick norm layer + if self.normalization_layer == "layer_norm": + norm_layer0 = partial( + nn.LayerNorm, + normalized_shape=(self.img_shape_loc[0], self.img_shape_loc[1]), + eps=1e-6, + ) + norm_layer1 = partial(nn.LayerNorm, normalized_shape=(self.h_loc, self.w_loc), eps=1e-6) + elif self.normalization_layer == "instance_norm": + if comm.get_size("spatial") > 1: + norm_layer0 = partial( + DistributedInstanceNorm2d, + num_features=self.embed_dim, + eps=1e-6, + affine=True, + ) + else: + norm_layer0 = partial( + nn.InstanceNorm2d, + num_features=self.embed_dim, + eps=1e-6, + affine=True, + track_running_stats=False, + ) + norm_layer1 = norm_layer0 + elif self.normalization_layer == "none": + norm_layer0 = nn.Identity + norm_layer1 = norm_layer0 + else: + raise NotImplementedError(f"Error, normalization {self.normalization_layer} not implemented.") + # time embedding + self.time_dim = None + self.with_time_emb = with_time_emb + if with_time_emb: + pos_emb_dim = self.embed_dim + sinusoidal_embedding = "true" + + self.time_dim = self.embed_dim * time_dim_mult + self.time_rescale = time_rescale + self.min_time, self.max_time = None, None + self.time_scaler = 1.0 + self.time_shift = 0.0 + self.time_emb_mlp = get_time_embedder(self.time_dim, pos_emb_dim, sinusoidal_embedding) + else: + self.time_rescale = False + + # FNO blocks + self.blocks = nn.ModuleList([]) + for i in range(self.num_layers): + first_layer = i == 0 + last_layer = i == self.num_layers - 1 + + forward_transform = self.trans_down if first_layer else self.trans + inverse_transform = self.itrans_up if last_layer else self.itrans + + inner_skip = "linear" + outer_skip = "identity" + + if first_layer: + norm_layer = (norm_layer0, norm_layer1) + elif last_layer: + norm_layer = (norm_layer1, norm_layer0) + else: + norm_layer = (norm_layer1, norm_layer1) + + filter_type = self.filter_type + + operator_type = self.operator_type + + block = FourierNeuralOperatorBlock( + forward_transform, + inverse_transform, + self.embed_dim, + filter_type=filter_type, + operator_type=operator_type, + mlp_ratio=mlp_ratio, + drop_rate_filter=dropout_filter, + drop_rate_mlp=dropout_mlp, + drop_path=dpr[i], + act_layer=self.activation_function, + norm_layer=norm_layer, + sparsity_threshold=sparsity_threshold, + use_complex_kernels=use_complex_kernels, + inner_skip=inner_skip, + outer_skip=outer_skip, + use_mlp=self.use_mlp, + rank=self.rank, + factorization=self.factorization, + separable=self.separable, + complex_network=self.complex_network, + complex_activation=self.complex_activation, + spectral_layers=self.spectral_layers, + checkpointing=self.checkpointing, + time_emb_dim=self.time_dim, + time_scale_shift_before_filter=time_scale_shift_before_filter, + ) + + self.blocks.append(block) + + self.decoder = self.get_head() + # learned position embedding + if self.pos_embed: + # currently using deliberately a differently shape position embedding + self.pos_embed = nn.Parameter(torch.zeros(1, self.embed_dim, self.img_shape_loc[0], self.img_shape_loc[1])) + # self.pos_embed = nn.Parameter( torch.zeros(1, self.embed_dim, self.img_shape_eff[0], self.img_shape_eff[1]) ) + self.pos_embed.is_shared_mp = ["matmul"] + trunc_normal_(self.pos_embed, std=0.02) + + self.apply(self._init_weights) + + def get_head(self): + decoder_hidden_dim = self.embed_dim + current_dim = self.embed_dim + self.big_skip * self.in_chans + decoder_modules = [] + for i in range(self.encoder_layers): + decoder_modules.append(nn.Conv2d(current_dim, decoder_hidden_dim, 1, bias=True)) + decoder_modules.append(self.activation_function()) + current_dim = decoder_hidden_dim + decoder_modules.append(nn.Conv2d(current_dim, self.out_chans, 1, bias=False)) + decoder = nn.Sequential(*decoder_modules) + return decoder + + def _init_weights(self, m): + """Helper routine for weight initialization""" + if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): # pragma: no cover + """Helper""" + return {"pos_embed", "cls_token"} + + def set_min_max_time(self, min_time: float, max_time: float): + """Use time stats to rescale time input to [0, 1000]. + For example, if min_time = 0 and max_time = 100, then time_scaler = 10.0. + """ + self.min_time, self.max_time = min_time, max_time + if self.time_rescale: + self.time_scaler = 1000.0 / (max_time - min_time) + self.time_shift = -min_time + self.log_text.info( + f"Time rescaling: min_time: {min_time}, max_time: {max_time}, time_scaler: {self.time_scaler}, time_shift: {self.time_shift}" + ) + else: + self.log_text.info(f"Time stats will be checked: min_time: {min_time}, max_time: {max_time}") + + def forward_features(self, x, time=None): + if self.with_time_emb: + assert ( + self.min_time is not None and self.max_time is not None + ), "min_time and max_time must be set before using time embedding" + assert (self.min_time <= time).all() and ( + time <= self.max_time + ).all(), f"time must be in [{self.min_time}, {self.max_time}], but time is {time}" + if self.time_rescale: + time = time * self.time_scaler + self.time_shift + t_repr = self.time_emb_mlp(time) + else: + t_repr = None + + for i, blk in enumerate(self.blocks): + # if x.shape[0] == 0: raise ValueError(f'x.shape[0] == 0. x.shape: {x.shape}, block i: {i}') + if self.checkpointing >= 3: + x = checkpoint(blk, x, time_emb=t_repr) + else: + x = blk(x, time_emb=t_repr) + return x, t_repr + + def forward( + self, inputs, time=None, condition=None, static_condition=None, return_time_emb: bool = False, **kwargs + ): + # print(f"{(inputs.shape if inputs is not None else None)}, {(condition.shape if condition is not None else None)}, {(static_condition.shape if static_condition is not None else None)}") + x = self.concat_condition_if_needed(inputs, condition, static_condition) + # if x.shape[0] == 0: raise ValueError(f'x.shape[0] == 0. x.shape: {x.shape}, inputs.shape: {inputs.shape}, condition.shape: {condition.shape}') + # save big skip + if self.big_skip: + residual = x + + if self.checkpointing >= 1: + x = checkpoint(self.encoder, x) + else: + x = self.encoder(x) + + if hasattr(self, "pos_embed"): + # old way of treating unequally shaped weights + if self.img_shape_loc != tuple(self.img_shape_eff): + print( + f"Warning: using differently shaped position embedding {self.img_shape_loc} vs {self.img_shape_eff}, shape: {self.img_shape}, x.shape: {x.shape}, pos_embed.shape: {self.pos_embed.shape}" + ) + xp = torch.zeros_like(x) + xp[..., : self.img_shape_loc[0], : self.img_shape_loc[1]] = ( + x[..., : self.img_shape_loc[0], : self.img_shape_loc[1]] + self.pos_embed + ) + x = xp + else: + x = x + self.pos_embed + + # maybe clean the padding just in case + x = self.pos_drop(x) + + x, t_repr = self.forward_features(x, time) + + if self.big_skip: + x = torch.cat((x, residual), dim=1) + + if self.checkpointing >= 1: + x = checkpoint(self.decoder, x) + else: + x = self.decoder(x) + + if return_time_emb: + return x, t_repr + return x diff --git a/src/models/unet.py b/src/models/unet.py new file mode 100644 index 0000000..6974902 --- /dev/null +++ b/src/models/unet.py @@ -0,0 +1,376 @@ +from functools import partial +from typing import Optional + +import torch +import torch.nn.functional as F +from einops import rearrange +from torch import nn + +from src.models._base_model import BaseModel +from src.models.modules.attention import Attention, LinearAttention +from src.models.modules.convs import WeightStandardizedConv2d +from src.models.modules.misc import Residual, get_time_embedder +from src.models.modules.net_norm import PreNorm +from src.utilities.utils import default, exists + + +def Upsample(dim, dim_out=None, scale_factor=2): + return nn.Sequential( + nn.Upsample(scale_factor=scale_factor, mode="nearest"), nn.Conv2d(dim, default(dim_out, dim), 3, padding=1) + ) + + +def Downsample(dim, dim_out=None): + return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1) + + +class LayerNorm(nn.Module): + def __init__(self, dim): + super().__init__() + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + + def forward(self, x): + # x is of shape (batch, channels, height, width) + # to use it with (batch, tokens, dim) we need to reshape it to (batch, dim, tokens) + eps = 1e-5 if x.dtype == torch.float32 else 1e-3 + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) * (var + eps).rsqrt() * self.g + + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8, dropout: float = 0.0): + super().__init__() + self.proj = WeightStandardizedConv2d(dim, dim_out, 3, padding=1) + try: + self.norm = nn.GroupNorm(groups, dim_out) + except ValueError as e: + raise ValueError(f"You misspecified the parameter groups={groups} and dim_out={dim_out}") from e + self.act = nn.SiLU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + x = self.dropout(x) + return x + + +class ResnetBlock(nn.Module): + def __init__( + self, + dim, + dim_out, + *, + time_emb_dim=None, + groups=8, + double_conv_layer: bool = True, + dropout1: float = 0.0, + dropout2: float = 0.0, + ): + super().__init__() + self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups, dropout=dropout1) + self.block2 = Block(dim_out, dim_out, groups=groups, dropout=dropout2) if double_conv_layer else nn.Identity() + self.residual_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, "b c -> b c 1 1") + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + + return h + self.residual_conv(x) + + +# model +class Unet(BaseModel): + def __init__( + self, + dim, + init_dim=None, + dim_mults=(1, 2, 4, 8), + resnet_block_groups=8, + with_time_emb: bool = False, + time_dim_mult: int = 2, + block_dropout: float = 0.0, # for second block in resnet block + block_dropout1: float = 0.0, # for first block in resnet block + attn_dropout: float = 0.0, + input_dropout: float = 0.0, + double_conv_layer: bool = True, + learned_variance=False, + learned_sinusoidal_cond=False, + learned_sinusoidal_dim=16, + outer_sample_mode: str = None, # bilinear or nearest + upsample_dims: tuple = None, # (256, 256) or (128, 120) etc. + keep_spatial_dims: bool = False, + init_kernel_size: int = 7, + init_padding: int = 3, + init_stride: int = 1, + num_conditions: int = 0, + dim_head: int = 32, + **kwargs, + ): + super().__init__(**kwargs) + # determine dimensions + assert self.num_input_channels is not None, "Please specify ``num_input_channels`` in the model config." + assert self.num_output_channels is not None, "Please specify ``num_output_channels`` in the model config." + assert ( + self.num_conditional_channels is not None + ), "Please specify ``num_conditional_channels`` in the model config." + # raise_error_if_invalid_value(conditioning_mode, ["concat", "cross_attn"], "conditioning_mode") + if self.hparams.debug_mode: + self.hparams.dim_mults = dim_mults = (1, 1, 1) + self.hparams.dim = dim = 8 + input_channels = self.num_input_channels + self.num_conditional_channels + output_channels = self.num_output_channels or input_channels + self.save_hyperparameters() + + if num_conditions >= 1: + assert ( + self.num_conditional_channels > 0 + ), f"num_conditions is {num_conditions} but num_conditional_channels is {self.num_conditional_channels}" + + init_dim = default(init_dim, dim) + assert (upsample_dims is None and outer_sample_mode is None) or ( + upsample_dims is not None and outer_sample_mode is not None + ), "upsample_dims and outer_sample_mode must be both None or both not None" + # To keep spatial dimensions for uneven spatial sizes, we need to use nearest upsampling + # and then crop the output to the desired size + if outer_sample_mode is not None: + # Upsample (45, 90) to be easier to divide by 2 multiple times + # upsample_dims = (48, 96) + self.upsampler = torch.nn.Upsample(size=tuple(upsample_dims), mode=outer_sample_mode) + else: + self.upsampler = None + + self.init_conv = nn.Conv2d( + input_channels, + init_dim, + init_kernel_size, + padding=init_padding, + stride=init_stride, + ) + self.dropout_input = nn.Dropout(input_dropout) + self.dropout_input_for_residual = nn.Dropout(input_dropout) + + if with_time_emb: + pos_emb_dim = dim + sinusoidal_embedding = "learned" if learned_sinusoidal_cond else "true" + self.time_dim = dim * time_dim_mult + self.time_emb_mlp = get_time_embedder( + self.time_dim, pos_emb_dim, sinusoidal_embedding, learned_sinusoidal_dim + ) + else: + self.time_dim = None + self.time_emb_mlp = None + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial( + ResnetBlock, + groups=resnet_block_groups, + dropout2=block_dropout, + dropout1=block_dropout1, + double_conv_layer=double_conv_layer, + time_emb_dim=self.time_dim, + ) + # layers + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + linear_attn_kwargs = dict(rescale="qkv", dropout=attn_dropout) + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + do_downsample = not is_last and not keep_spatial_dims + # num_heads = dim // dim_head + num_heads, dim_head = 4, 32 + + self.downs.append( + nn.ModuleList( + [ + block_klass(dim_in, dim_in), + block_klass(dim_in, dim_in), + ( + Residual( + PreNorm( + dim_in, + fn=LinearAttention( + dim_in, **linear_attn_kwargs, heads=num_heads, dim_head=dim_head + ), + norm=LayerNorm, + ) + ) + ), + Downsample(dim_in, dim_out) if do_downsample else nn.Conv2d(dim_in, dim_out, 3, padding=1), + ] + ) + ) + + mid_dim = dims[-1] + # num_heads = mid_dim // dim_head + num_heads, dim_head = 4, 32 + self.mid_block1 = block_klass(mid_dim, mid_dim) + self.mid_attn = Residual( + PreNorm( + mid_dim, + fn=Attention(mid_dim, dropout=attn_dropout, heads=num_heads, dim_head=dim_head), + norm=LayerNorm, + ) + ) + self.mid_block2 = block_klass(mid_dim, mid_dim) + + if hasattr(self, "spatial_shape_in") and self.spatial_shape_in is not None: + b, s1, s2 = 1, *self.spatial_shape_in + self.example_input_array = [ + torch.rand(b, self.num_input_channels, s1, s2), + torch.rand(b) if with_time_emb else None, + torch.rand(b, self.num_conditional_channels, s1, s2) if self.num_conditional_channels > 0 else None, + ] + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind == (len(in_out) - 1) + do_upsample = not is_last and not keep_spatial_dims + # num_heads = dim_out // dim_head + num_heads, dim_head = 4, 32 + + self.ups.append( + nn.ModuleList( + [ + block_klass(dim_out + dim_in, dim_out), + block_klass(dim_out + dim_in, dim_out), + ( + Residual( + PreNorm( + dim_out, + fn=LinearAttention( + dim_out, heads=num_heads, dim_head=dim_head, **linear_attn_kwargs + ), + norm=LayerNorm, + ) + ) + ), + Upsample(dim_out, dim_in) if do_upsample else nn.Conv2d(dim_out, dim_in, 3, padding=1), + ] + ) + ) + + default_out_dim = input_channels * (1 if not learned_variance else 2) + self.out_dim = default(output_channels, default_out_dim) + self.final_res_block = block_klass(dim * 2, dim) + self.final_conv = self.get_head() + + def get_head(self): + return nn.Conv2d(self.hparams.dim, self.out_dim, 1) + + def get_block(self, dim_in, dim_out, dropout: Optional[float] = None): + return ResnetBlock( + dim_in, + dim_out, + groups=self.hparams.resnet_block_groups, + dropout1=dropout or self.hparams.block_dropout1, + dropout2=dropout or self.hparams.block_dropout, + time_emb_dim=self.time_dim, + ) + + def forward( + self, + inputs, + time=None, + condition=None, + static_condition=None, + return_time_emb: bool = False, + get_intermediate_shapes: bool = False, + **kwargs, + ): + x = self.concat_condition_if_needed(inputs, condition, static_condition) + + orig_x_shape = x.shape[-2:] + x = self.upsampler(x) if exists(self.upsampler) else x + try: + x = self.init_conv(x) + except RuntimeError as e: + raise RuntimeError( + f"x.shape: {x.shape}, x.dtype: {x.dtype}, init_conv.weight.shape/dtype: {self.init_conv.weight.shape}/{self.init_conv.weight.dtype}" + ) from e + r = self.dropout_input_for_residual(x) if self.hparams.input_dropout > 0 else x.clone() + x = self.dropout_input(x) + + if exists(self.time_emb_mlp): + try: + t = self.time_emb_mlp(time) + except RuntimeError as e: + raise RuntimeError( + f"Error when embedding AdaLN input. {time.shape=}, {time.dtype=}, time_emb_mlp.weight.shape/dtype: {self.time_emb_mlp[1].weight.shape}/{self.time_emb_mlp[1].weight.dtype}" + ) from e + else: + t = None + + h = [] + for i, (block1, block2, attn, downsample) in enumerate(self.downs): + x = block1(x, t) + h.append(x) + + x = block2(x, t) + x = attn(x) + h.append(x) + + x = downsample(x) + x = self.mid_block1(x, t) + x = self.mid_attn(x) + # print(f'mid_attn: {x.shape}') # e.g. [10, 256, 45, 90]) + x = self.mid_block2(x, t) + if get_intermediate_shapes: + return x + + for i, (block1, block2, attn, upsample) in enumerate(self.ups): + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + + x = torch.cat((x, h.pop()), dim=1) + x = block2(x, t) + x = attn(x) + + x = upsample(x) # each i, except for last, halves channels, doubles spatial dims + # print(f"Upsample {i} shape: {x.shape}.") + + x = torch.cat((x, r), dim=1) + if exists(self.upsampler): + # x = F.interpolate(x, orig_x_shape, mode='bilinear', align_corners=False) + x = F.interpolate(x, size=orig_x_shape, mode=self.hparams.outer_sample_mode) + + x = self.final_res_block(x, t) + x = self.final_conv(x) + return_dict = x + if return_time_emb: + return return_dict, t + return return_dict + + +if __name__ == "__main__": + unet = Unet( + dim=64, + num_input_channels=3, + num_output_channels=3, + spatial_shape_in=(45, 90), + upsample_dims=(48, 96), + outer_sample_mode="bilinear", + ) + x = torch.rand(10, 3, 45, 90) + print(unet.print_intermediate_shapes(x)) diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..90f20c3 --- /dev/null +++ b/src/train.py @@ -0,0 +1,196 @@ +import os.path +import signal +import time + +import hydra +import pytorch_lightning as pl +import torch +import wandb +from omegaconf import DictConfig + +import src.utilities.config_utils as cfg_utils +from src.interface import get_model_and_data +from src.utilities.utils import AlreadyLoggedError, divein, get_logger, melk +from src.utilities.wandb_api import get_run_api +from src.utilities.wandb_callbacks import MyWandbLogger + + +log = get_logger(__name__) + + +def run_model(config: DictConfig) -> float: + r""" + This function runs/trains/tests the model. + + .. note:: + It is recommended to call this function by running its underlying script, ``src.train.py``, + as this will enable you to make the best use of the command line integration with Hydra. + For example, you can easily train a UNet for 10 epochs on the CPU with: + + >>> python train.py trainer.max_epochs=10 trainer.accelerator="cpu" model=unet_resnet callbacks=default + + Args: + config: A DictConfig object generated by hydra containing the model, data, callbacks & trainer configuration. + + Returns: + float: the best model score reached while training the model. + E.g. "val/mse", the mean squared error on the validation set. + """ + # Seed for reproducibility + pl.seed_everything(config.seed) + ckpt_path_orig = config.get("ckpt_path") + + # If not resuming training, check if run already exists (with same hyperparameters and seed) + config = cfg_utils.extras(config, if_wandb_run_already_exists="resume") + + # Init Lightning callbacks and loggers (e.g. model checkpointing and Wandb logger) + callbacks = cfg_utils.get_all_instantiable_hydra_modules(config, "callbacks") + loggers = cfg_utils.get_all_instantiable_hydra_modules(config, "logger") + + wandb_id = config.logger.wandb.get("id") if config.get("logger") and hasattr(config.logger, "wandb") else None + uses_wandb = wandb_id is not None + # Get wandb.loggers.WandbLogger instance + if uses_wandb: + wandb_logger = [logger for logger in loggers if isinstance(logger, MyWandbLogger)] + assert len(wandb_logger) == 1, f"Expected exactly one MyWandbLogger, but got {len(wandb_logger)}!" + wandb_logger = wandb_logger[0] + cfg_utils.save_hydra_config_to_wandb(config) + + # Print config. For pretty print, rich package needs to be installed (optional) + if config.get("print_config"): + cfg_utils.print_config(config, fields="all") + + if uses_wandb and config.wandb_status == "resume": + # Reload model checkpoint if needed to resume training + # Set the checkpoint to reload + ckpt_filename = config.get("ckpt_path") + if config.get("eval_mode"): + # Use best model checkpoint for testing + ckpt_filename = ( + ckpt_filename + or os.path.join(config.ckpt_dir.replace("-test", ""), config.logger.wandb.training_id, ckpt_path_orig) + or "best.ckpt" + ) + else: + # Use last model checkpoint for resuming training + if ckpt_filename is None: + ckpt_filename = "last.ckpt" + elif not ckpt_filename.endswith("last.ckpt"): + log.warning(f'Checkpoint used to resume training is not "last.ckpt" but "{ckpt_filename}"!') + + if os.path.exists(ckpt_filename) and str(config.logger.wandb.training_id) in str(ckpt_filename): + # Load model checkpoint from local file. For safety, only do this if the wandb run id is in the path. + ckpt_path = ckpt_filename + log.info(f"Loading model weights from local checkpoint: ``{ckpt_path}``") + else: + # if config.get("eval_mode") and os.path.exists(ckpt_filename): + # print(f"ckpt_filename: {ckpt_filename} exists. ckpt_path=", config.get("ckpt_path")) + # Note: train_run_path can be different from wandb.run.path + train_run_path = config.logger.wandb.get("train_run_path", config.logger.wandb.run_path) + training_run_id = config.logger.wandb.training_id + assert ( + str(training_run_id) in train_run_path + ), f"Training run id {training_run_id} not in run path {train_run_path}" + ckpt_path = f"{training_run_id}-{ckpt_filename}" + if os.path.exists(ckpt_path): + try: + os.remove(ckpt_path) # Re-download the model checkpoint in case it is obsolete version + except FileNotFoundError: + pass # weird but ok + assert not os.path.exists(ckpt_path), f"{ckpt_path=} already exists!" + if True: # not os.path.exists(ckpt_path): + # Download model checkpoint from wandb (using the training run id) + wandb_logger.restore_checkpoint( + ckpt_filename, ckpt_path, run_path=train_run_path, root=os.getcwd(), restore_from="any" + ) + else: + assert config.get("eval_mode") is None, "eval_mode without wandb is not supported" + ckpt_path = None + + # Obtain the instantiated model and data classes from the config + model, datamodule = get_model_and_data(config) + + # Init Lightning trainer + trainer: pl.Trainer = hydra.utils.instantiate(config.trainer, callbacks=callbacks, logger=loggers) + + # Send some parameters from config to be saved by the lightning loggers + cfg_utils.log_hyperparameters( + config=config, model=model, data_module=datamodule, trainer=trainer, callbacks=callbacks + ) + if config.get("eval_mode"): + if config.get("model") is not None: + assert ckpt_path is not None, "ckpt_path must be provided in eval_mode" + # Log epoch and step of the model checkpoint to be tested + ckpt = torch.load(ckpt_path, map_location="cpu") + epoch, global_step = ckpt.get("epoch", -1), ckpt.get("global_step", -1) + step_of_ckpt = {"epoch": epoch, "global_step": global_step, "ckpt_path_orig": ckpt_path_orig} + trainer.logger.log_hyperparams(step_of_ckpt) + model._default_global_step = global_step + model._default_epoch = epoch + + else: + if hasattr(signal, "SIGUSR1"): # Windows does not support signals + signal.signal(signal.SIGUSR1, melk(trainer, config.ckpt_dir)) + signal.signal(signal.SIGUSR2, divein(trainer)) + + def fit(ckpt_filepath=None): + trainer.fit(model, datamodule=datamodule, ckpt_path=ckpt_filepath) + + try: + # Train the model + fit(ckpt_filepath=ckpt_path) + log.info(" ---------------- Training finished successfully ----------------") + except Exception as e: + melk(trainer, config.ckpt_dir)() + raise e + + # Testing: + if config.get("test_after_training") or config.get("eval_mode") == "test": + if config.get("eval_mode") == "test": + test_what = {"ckpt_path": ckpt_path, "model": model} + else: + # Testing after training --> use the best model checkpoint + test_what = {"ckpt_path": "best"} if hasattr(callbacks, "model_checkpoint") else {"model": model} + + try: + trainer.test(datamodule=datamodule, **test_what) + except AlreadyLoggedError as e: + log.warning(f"Test already logged: {e}") + if uses_wandb and config.logger.wandb.get("train_run_path") is not None: + # Set flag to indicate that the model has been tested + train_run_api = get_run_api(run_path=config.logger.wandb.train_run_path) + # train_run_api.summary.get("epoch", -1) + train_run_api.summary["tested"] = True + train_run_api.update() + + if config.get("eval_mode") == "validate": + # Validate using the model + if hasattr(model.hparams, "inference_val_every_n_epochs"): + model.hparams.inference_val_every_n_epochs = 1 + trainer.validate(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + if config.get("eval_mode") == "predict": + # Predict using the model + trainer.predict(model=model, datamodule=datamodule, ckpt_path=ckpt_path) + + if uses_wandb: + try: + time.sleep(5) # Sleep for 5 seconds to not over-print the wandb finish message + wandb.finish() + log.info(" ---------------- Sleeping for 5 seconds to make sure wandb finishes... ----------------") + time.sleep(5) + except (FileNotFoundError, PermissionError) as e: + log.info(f"Wandb finish error:\n{e}") + + # return best score (e.g. validation mse). This is useful when using Hydra+Optuna HP tuning. + return trainer.checkpoint_callback.best_model_score + + +@hydra.main(config_path="configs/", config_name="main_config.yaml", version_base=None) +def main(config: DictConfig) -> float: + """Run/train model based on the config file configs/main_config.yaml (and any command-line overrides).""" + return run_model(config) + + +if __name__ == "__main__": + main() diff --git a/src/utilities/__init__.py b/src/utilities/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utilities/checkpointing.py b/src/utilities/checkpointing.py new file mode 100644 index 0000000..03392f7 --- /dev/null +++ b/src/utilities/checkpointing.py @@ -0,0 +1,154 @@ +import os +import re +from typing import Optional + +import numpy as np +import torch +from huggingface_hub import hf_hub_download +from omegaconf import DictConfig + +from src.utilities.utils import get_logger + + +log = get_logger(__name__) + + +# try: +# torch.serialization.add_safe_globals([ListConfig]) +# except AttributeError: +# log.warning("torch.serialization.add_safe_globals([ListConfig]) not supported in this version of PyTorch") + + +def get_local_ckpt_path( + config: DictConfig, + wandb_run, #: wandb.apis.public.Run, + ckpt_filename: str = "last.ckpt", + throw_error_if_local_not_found: bool = False, +) -> Optional[str]: + potential_dirs = [ + config.ckpt_dir, + os.path.join(config.work_dir.replace("-test", ""), "checkpoints"), + os.path.join(os.getcwd(), "results", "checkpoints"), + ] + for callback_k in config.get("callbacks", {}).keys(): + if "checkpoint" in callback_k and config.callbacks[callback_k] is not None: + potential_dirs.append(config.callbacks[callback_k].dirpath) + + for local_dir in potential_dirs: + log.info(f"Checking {local_dir}. {os.path.exists(local_dir)=}") + if not os.path.exists(local_dir): + continue + if wandb_run.id not in local_dir: + local_dir = os.path.join(local_dir, wandb_run.id) + if not os.path.exists(local_dir): + continue + ckpt_files = [f for f in os.listdir(local_dir) if f.endswith(".ckpt")] + if ckpt_filename == "last.ckpt": + ckpt_files = [f for f in ckpt_files if "last" in f] + if len(ckpt_files) == 0: + continue + elif len(ckpt_files) == 1: + latest_ckpt_file = ckpt_files[0] + else: + # Get their epoch numbers from inside the file + # epochs = [torch.load(os.path.join(local_dir, f), weights_only=True)["epoch"] for f in ckpt_files] + epochs = [torch.load(os.path.join(local_dir, f))["epoch"] for f in ckpt_files] + # Find the ckpt file with the latest epoch + latest_ckpt_file = ckpt_files[np.argmax(epochs)] + log.info( + f"Found multiple last-v.ckpt files. Using the one with the highest epoch: {latest_ckpt_file}. ckpt_to_epoch: {dict(zip(ckpt_files, epochs))}" + ) + return os.path.join(local_dir, latest_ckpt_file) + + elif ckpt_filename in ["earliest_epoch", "latest_epoch", "earliest_epoch_any", "latest_epoch_any"]: + # Find the earliest epoch ckpt file + if ckpt_filename in ["earliest_epoch_any", "latest_epoch_any"]: + ckpt_files = [f for f in ckpt_files if "epoch" in f] + else: + ckpt_files = [f for f in ckpt_files if "epoch" in f and "epochepoch=" not in f] + if len(ckpt_files) == 0: + continue + + # Function to extract the epoch number from the filename + def get_epoch_number(filename): + if "_any" in ckpt_filename: + filename = filename.replace("epochepoch=", "epoch") # Fix for a bug in the filename + match = re.search(r"_epoch(\d+)_", filename) + return int(match.group(1)) + + # Find the ckpt file with the earliest epoch + min_or_max = min if ckpt_filename == "earliest_epoch" else max + earliest_ckpt_file = min_or_max(ckpt_files, key=lambda f: get_epoch_number(f)) + log.info(f"For ckpt_filename={ckpt_filename}, found ckpt file: {earliest_ckpt_file} in {local_dir}") + return os.path.join(local_dir, earliest_ckpt_file) + + ckpt_path = os.path.join(local_dir, ckpt_filename) + if os.path.exists(ckpt_path): + return ckpt_path + else: + log.warning(f"{local_dir} exists but could not find {ckpt_filename=}. Files in dir: {ckpt_files}.") + if ckpt_filename in ["earliest_epoch", "latest_epoch", "earliest_epoch_any", "latest_epoch_any"]: + raise NotImplementedError("Not implemented") + if throw_error_if_local_not_found: + raise FileNotFoundError( + f"Could not find ckpt file {ckpt_filename} in any of the potential dirs: {potential_dirs}" + ) + return None + + +def download_model_from_hf( + repo_id: Optional[str] = None, + filename: Optional[str] = None, + hf_path: Optional[str] = None, + cache_dir: str = "auto", +): + """ + Downloads a model file from Hugging Face Hub + + Args: + repo_id (str): Hugging Face repository ID, e.g. "username/repo_name" + filename (str): Name of the file to download, e.g. "model.pt" + hf_path (str): Path to the model on Hugging Face Hub. "/" + cache_dir (str): Local directory to save the model + + Returns: + str: Path to the downloaded file + """ + if hf_path is not None: + assert repo_id is None and filename is None, "hf_path should be used alone" + # After last / is the filename + repo_id, filename = hf_path.rsplit("/", 1) + else: + assert repo_id is not None and filename is not None, "repo_id and filename should be used together" + + if filename.endswith(".ckpt") or filename.endswith(".pt"): + dtype = "model" + cache_dir = ".cache/models/" if cache_dir == "auto" else cache_dir + elif filename.endswith(".yaml"): + dtype = "config" + cache_dir = ".cache/configs/" if cache_dir == "auto" else cache_dir + else: + dtype = "data" + cache_dir = ".cache/data/" if cache_dir == "auto" else cache_dir + + os.makedirs(cache_dir, exist_ok=True) + + log.info(f"Downloading {dtype} from Hugging Face Hub: {repo_id}/{filename}. Saving to {cache_dir=}") + model_path = hf_hub_download(repo_id=repo_id, filename=filename, cache_dir=cache_dir) + + return model_path + + +def local_path_to_absolute_and_download_if_needed(path: str) -> Optional[str]: + """ + Convert a local path to an absolute path and download the file if it is a Hugging Face Hub path (starts with "hf:") + """ + if path is None: + return None + if path.startswith("hf:"): + # Download ckpt from huggingface hub + # e.g. "hf:salv47/spherical-dyffusion/interpolator-sfno-best-val_avg_crps.ckpt" + hf_path = path.replace("hf:", "") + path = download_model_from_hf(hf_path=hf_path) + path = os.path.abspath(path) + return path diff --git a/src/utilities/config_utils.py b/src/utilities/config_utils.py new file mode 100644 index 0000000..3826ba3 --- /dev/null +++ b/src/utilities/config_utils.py @@ -0,0 +1,916 @@ +import os +import sys +import warnings +from datetime import datetime +from typing import List, Sequence, Union + +import hydra +import omegaconf +import pytorch_lightning as pl +import requests +import torch +import wandb +from hydra.core.global_hydra import GlobalHydra +from omegaconf import DictConfig, OmegaConf, open_dict +from pytorch_lightning.utilities import rank_zero_only + +from src.utilities import wandb_api +from src.utilities.checkpointing import get_local_ckpt_path +from src.utilities.naming import clean_name, get_detailed_name, get_group_name +from src.utilities.utils import get_logger +from src.utilities.wandb_api import get_existing_wandb_group_runs, get_run_api + + +log = get_logger(__name__) + + +@rank_zero_only +def print_config( + config, + fields: Union[str, Sequence[str]] = ( + "datamodule", + "model", + "trainer", + # "callbacks", + # "logger", + "seed", + ), + resolve: bool = True, + rich_style: str = "magenta", + max_width: int = 128, +) -> None: + """Prints content of DictConfig using Rich library and its tree structure (if installed: ``pip install rich``). + + Credits go to: https://github.com/ashleve/lightning-hydra-template + + Args: + config (ConfigDict): Configuration + fields (Sequence[str], optional): Determines which main fields from config will + be printed and in what order. + resolve (bool, optional): Whether to resolve reference fields of DictConfig. + rich_style (str, optional): Style of Rich library to use for printing. E.g "magenta", "bold", "italic", etc. + """ + import importlib + + if not importlib.util.find_spec("rich") or not importlib.util.find_spec("omegaconf"): + # no pretty printing + log.info(OmegaConf.to_yaml(config, resolve=resolve)) + return + import rich.syntax # IMPORTANT to have, otherwise errors are thrown + import rich.tree + + tree = rich.tree.Tree(":gear: CONFIG", style=rich_style, guide_style=rich_style) + if isinstance(fields, str): + if fields.lower() == "all": + fields = config.keys() + else: + fields = [fields] + + for field in fields: + branch = tree.add(field, style=rich_style, guide_style=rich_style) + + config_section = config.get(field) + branch_content = str(config_section) + if isinstance(config_section, DictConfig): + branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) + + branch.add(rich.syntax.Syntax(branch_content, "yaml")) + + console = rich.console.Console(width=max_width) + console.print(tree) + + +def extras( + config: DictConfig, + if_wandb_run_already_exists: str = "resume", + allow_permission_error: bool = False, +) -> DictConfig: + """A couple of optional utilities, controlled by main config file: + - disabling warnings + - easier access to debug mode + - forcing debug friendly configuration + - forcing multi-gpu friendly configuration + - checking if config values are valid + - init wandb if wandb logging is being used + - Merge config with wandb config if resuming a run + + Credits go to: https://github.com/ashleve/lightning-hydra-template + + While this method modifies DictConfig mostly in place, + please make sure to use the returned config as the new config, especially when resuming a run. + + Args: + if_wandb_run_already_exists (str): What to do if wandb run already exists. Wandb logger must be enabled! + Options are: + - 'resume': resume the run + - 'new': create a new run + - 'abort': raise an error if run already exists and abort + allow_permission_error (bool): Whether to allow PermissionError when creating working dir. + """ + USE_WANDB = "logger" in config.keys() and config.logger.get("wandb") and hasattr(config.logger.wandb, "_target_") + if USE_WANDB: + run_api = None + os.environ["WANDB_HTTP_TIMEOUT"] = "200" # Increase timeout for slow connections + wandb_cfg = config.logger.wandb + wandb_api.PROJECT = wandb_cfg.get("project", wandb_api.PROJECT) + wandb_api._ENTITY = config.logger.wandb.entity = wandb_api.get_entity(wandb_cfg.get("entity")) + + if wandb_cfg.get("id") or wandb_cfg.get("resume_run_id"): + wandb_status = "resume" + if wandb_cfg.get("id"): + assert not wandb_cfg.get( + "resume_run_id" + ), "Both wandb.id and wandb.resume_run_id are set. Only one should be set." + resume_run_id = str(wandb_cfg.id) + config.logger.wandb.id = resume_run_id + log.info(f"Resuming experiment with wandb run ID = {resume_run_id}") + else: + assert not wandb_cfg.get( + "id" + ), "Both wandb.id and wandb.resume_run_id are set. Only one should be set." + resume_run_id = str(wandb_cfg.resume_run_id) + config.logger.wandb.id = wandb_api.get_wandb_id_for_run() + log.info( + f"Resuming experiment with wandb run ID = {resume_run_id} on NEW run: ``{config.logger.wandb.id}``" + ) + + run_api = get_run_api( + run_id=resume_run_id, entity=config.logger.wandb.entity, project=config.logger.wandb.project + ) + # Set config wandb keys in case they were none, to the wandb defaults + keys_to_set = [k for k in wandb_cfg.keys() if k != "id"] + for k in keys_to_set: + config.logger.wandb[k] = getattr(run_api, k) if hasattr(run_api, k) else wandb_cfg[k] + if resume_run_id != wandb_cfg.id: + # Give a new name to the run based on config values (will be updated later) + config.logger.wandb.name = None + + else: + if not wandb_cfg.get("group"): # no wandb group has been assigned yet + group_name = get_group_name(config) + # potentially truncate the group name to 128 characters (W&B limit) + if len(group_name) >= 128: + group_name = group_name.replace("-fcond", "").replace("DynTime", "DynT") + group_name = group_name.replace("UNetResNet", "UNetRN") + group_name = group_name.replace("NavierStokes", "NS") + group_name = group_name.replace("DYffusion", "DY2s") + group_name = group_name.replace("SimpleUnet", "sUNet") + group_name = group_name.replace("lRecs_", "lRs_") + if len(group_name) >= 128: + group_name = group_name.replace("_cos_LC10:400", "cosSTD") + group_name = group_name.replace("1-2-2-3-4", "12234") + if len(group_name) >= 128: + group_name = group_name.replace("_cos_LC", "_cL") + if len(group_name) >= 128: + group_name = group_name.replace("Kolmogorov-", "Kolg-") + + if len(group_name) > 128: + raise ValueError(f"Group name is too long, ({len(group_name)} > 128): {group_name}") + config.logger.wandb.group = group_name + group = config.logger.wandb.group + + if if_wandb_run_already_exists in ["abort", "resume"]: + wandb_status = "new" + runs_in_group = get_existing_wandb_group_runs(config, ckpt_must_exist=True, only_best_and_last=False) + if len(runs_in_group) > 0: + log.info(f"Found {len(runs_in_group)} runs for group {group}") + for other_run in runs_in_group: + other_seed = other_run.config.get("seed") + if other_seed is None: + # Name follows Kolmogorov-MH16_ar2_UNetR_EMA_64x1-2-3-4d_54lr_30at30b10b1Dr_14wd_cos_LC10:400_11seed_23h21mJul01_1634972 + try: + split_seed = other_run.name.split("seed_")[0].split("_")[-1] + other_seed = int(split_seed) if split_seed.isdigit() else None + except Exception: + continue + if other_seed is None: + continue + + if int(other_seed) != int(config.seed): + continue + # seeds are the same, so we treat this as a duplicate run + state = other_run.state + if if_wandb_run_already_exists == "abort": + raise RuntimeError( + f"Run with seed {config.seed} already exists in group {group}. State: {state}" + ) + elif if_wandb_run_already_exists == "resume": + wandb_status = "resume" + config.ckpt_path = ( + get_local_ckpt_path(config, other_run, ckpt_filename=config.ckpt_path or "last.ckpt") + or config.ckpt_path + ) # try local ckpt first, otherwise download from wandb or S3 + config.logger.wandb.resume = "allow" # was "allow" but "must" is more clear (?) + config.logger.wandb.id = other_run.id + config.logger.wandb.name = other_run.name + log.info( + f"Resuming run {other_run.id} from group {group}. Seed={other_seed}; State was: ``{state}``" + ) + else: + raise ValueError(f"if_wandb_run_already_exists={if_wandb_run_already_exists} is not supported") + break + elif if_wandb_run_already_exists in [None, "ignore"]: + wandb_status = "resume" + else: + wandb_status = "???" + + if config.logger.wandb.get("id") is None: + # no wandb id has been assigned yet + config.logger.wandb.id = wandb_api.get_wandb_id_for_run() + + elif if_wandb_run_already_exists in ["abort", "resume"]: + wandb_status = "not_used" + log.warning("Not checking if run already exists, since wandb logging is not being used") + + else: + wandb_status = None + + if wandb_status == "resume": + # Reload config from wandb + run_path = f"{config.logger.wandb.entity}/{config.logger.wandb.project}/{config.logger.wandb.id}" + + # NEW CODE: + if run_api is None: + run_api = get_run_api(run_path=run_path) + # original overrides + command line overrides (latter take precedence) + overrides = run_api.metadata["args"] + sys.argv[1:] + GlobalHydra.instance().clear() + with hydra.initialize(version_base=None, config_path="../configs"): + new_config = hydra.compose(config_name="main_config.yaml", overrides=overrides) + with open_dict(new_config): + new_config.logger.wandb = config.logger.wandb + new_config.ckpt_path = config.ckpt_path + config = new_config + + # OLD CODE: + # override_config = get_only_overriden_config(config) + # with open_dict(override_config): + # override_config.logger.wandb.resume = config.logger.wandb.resume + # override_config.ckpt_path = config.ckpt_path + # config = wandb_api.load_hydra_config_from_wandb(run_path, override_config=override_config) + # END OF OLD CODE + with open_dict(config): + config.logger.wandb.run_path = run_path + + if USE_WANDB and not config.logger.wandb.get("name"): # no wandb name has been assigned yet + config.logger.wandb.name = get_detailed_name(config) + # Edit some config values + # Create working dir if it does not exist yet + if config.get("work_dir"): + try: + os.makedirs(name=config.get("work_dir"), exist_ok=True) + except PermissionError as e: + if allow_permission_error: + log.warning(f"PermissionError: {e}") + else: + raise e + + # disable python warnings if + if config.get("ignore_warnings"): + log.info("Disabling python warnings! ") + warnings.filterwarnings("ignore") + + # set if + if config.get("debug_mode"): + log.info("Running in debug mode! ") + if "fast_dev_run" in config.trainer: + config.trainer.fast_dev_run = True + os.environ["HYDRA_FULL_ERROR"] = "1" + os.environ["OC_CAUSE"] = "1" + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + torch.autograd.set_detect_anomaly(True) + with open_dict(config): + config.datamodule.debug_mode = config.datamodule.get("debug_mode", True) + config.model.debug_mode = config.model.get("debug_mode", True) + + # force debugger friendly configuration if + if config.trainer.get("fast_dev_run"): + log.info("Forcing debugger friendly configuration! ") + # Debuggers don't like GPUs or multiprocessing + if config.trainer.get("devices"): + config.trainer.devices = 0 + config.trainer.accelerator = "cpu" + if config.datamodule.get("pin_memory"): + config.datamodule.pin_memory = False + if config.datamodule.get("num_workers"): + config.datamodule.num_workers = 0 + elif config.datamodule.get("num_workers") == -1: + # set num_workers to #CPU cores if + config.datamodule.num_workers = os.cpu_count() + log.info(f"Automatically setting num_workers to {config.datamodule.num_workers} (CPU cores).") + + # force multi-gpu friendly configuration if + strategy = config.trainer.get("strategy", "") + strategy_name = strategy if isinstance(strategy, str) else strategy._target_.lower().split(".")[-1] + if strategy_name.startswith("ddp") or strategy_name.startswith("dp"): + if config.datamodule.get("pin_memory"): + log.info(f"Forcing ddp friendly configuration! ") + config.datamodule.pin_memory = False + + torch_matmul_precision = config.get("torch_matmul_precision", "highest") + if torch_matmul_precision != "highest": + log.info(f"Setting torch matmul precision to ``{torch_matmul_precision}``.") + torch.set_float32_matmul_precision(torch_matmul_precision) + + try: + _ = config.datamodule.get("data_dir") + except omegaconf.errors.InterpolationResolutionError as e: + # Provide more helpful error message for e.g. Windows users where $HOME does not exist by default + raise ValueError( + "Could not resolve ``datamodule.data_dir`` in config. See error message above for details.\n" + " If this is a Windows machine, you may need to set ``data_dir`` to an absolute path, e.g. ``C:/data``.\n" + " You can do so in ``src/configs/datamodule/_base_data_config.yaml`` or with the command line." + ) from e + + if config.module.get("num_predictions", 1) > 1: + monitor = config.module.get("monitor", "") or "" + if "crps" not in monitor and "rmse" in monitor: + # is_ipol_exp = "InterpolationExperiment" in config.module.get("_target_", "") + # new_monitor = "val/" + ("ipol/avg/crps_normed" if is_ipol_exp else "avg/crps_normed") + config.module.monitor = config.module.monitor.replace("rmse", "crps") + log.info(f"Setting {config.module.monitor=} since num_predictions > 1") + + # fix monitor for model_checkpoint and early_stopping callbacks + monitor = config.module.get("monitor", "") or "" + if config.get("callbacks") is not None and monitor: + clbk_ckpt = config.callbacks.get("model_checkpoint", None) + clbk_es = config.callbacks.get("early_stopping", None) + if clbk_ckpt is not None and clbk_ckpt.get("monitor"): + config.callbacks.model_checkpoint.monitor = monitor + if clbk_es is not None and clbk_es.get("monitor"): + config.callbacks.early_stopping.monitor = monitor + + # Set a short name for the model + if config.get("model"): # Some "naive" baselines don't have a model + model_name = config.model.get("name") + if model_name is None or model_name == "": + model_class = config.model.get("_target_") + mixer = config.model.mixer.get("_target_") if config.model.get("mixer") else None + dm_type = config.datamodule.get("_target_") + with open_dict(config): + config.model.name = clean_name(model_class, mixer=mixer, dm_type=dm_type) + + # Detect if using SLURM cluster + if "SLURM_JOB_ID" in os.environ: + with open_dict(config): + config.slurm_job_id = str(os.environ["SLURM_JOB_ID"]) + log.info(f"Detected SLURM job ID: {config.slurm_job_id}") + if "WANDB__SERVICE_WAIT" not in os.environ.keys(): + os.environ["WANDB__SERVICE_WAIT"] = "300" + else: + log.info(f"WANDB__SERVICE_WAIT already set to {os.environ['WANDB__SERVICE_WAIT']}") + + if "SCRIPT_NAME" in os.environ: + script_path = os.environ["SCRIPT_NAME"] + with open_dict(config): + config.script_name = script_path.split("/")[-1] # get only the script name + config.script_path = script_path + log.info(f"Detected script name: {config.script_name}") + + check_config_values(config) + + if USE_WANDB: + with open_dict(config): + config.wandb_status = wandb_status + config.logger.wandb.training_id = config.logger.wandb.get("resume_run_id") or config.logger.wandb.id + if config.logger.wandb.get("resume_run_id"): + train_run_path = ( + f"{config.logger.wandb.entity}/{config.logger.wandb.project}/{config.logger.wandb.training_id}" + ) + config.logger.wandb.train_run_path = wandb_api._TRAINING_RUN_PATH = train_run_path + + if config.get("eval_mode"): + assert config.get("eval_mode") in [ + "test", + "predict", + "validate", + ], f"eval_mode={config.get('eval_mode')} not supported!" + tags = list(config.logger.wandb.tags or []) + # Add command line kwargs to wandb tags. (we remove + or ++ from the kwargs) + config.logger.wandb.tags = tags + [ + cli_arg.replace("+", "") for cli_arg in sys.argv[1:] if "=" in cli_arg and len(cli_arg) <= 64 + ] # wandb tag limit is 64 chars + if config.logger.wandb.get("project_test") is None and config.model is not None: + config.logger.wandb.resume = "must" + config.logger.wandb.training_id = config.logger.wandb.id + wandb_api._PROJECT_TRAIN = config.logger.wandb.project + train_run_path = ( + f"{config.logger.wandb.entity}/{wandb_api._PROJECT_TRAIN}/{config.logger.wandb.training_id}" + ) + with open_dict(config): + _ = config.logger.wandb.pop("project_test") # remove project_test + config.logger.wandb.train_run_path = wandb_api._TRAINING_RUN_PATH = train_run_path + + elif config.logger.wandb.get("project_test") is not None and config.model is None: + # E.g. for non-ML models like climatology + with open_dict(config): + config.logger.wandb.project = wandb_api.PROJECT = config.logger.wandb.pop("project_test") + elif config.logger.wandb.get("project_test") is not None: + config.logger.wandb.resume = "allow" # no resuming likely, since different project + config.logger.wandb.training_id = config.logger.wandb.id + wandb_api._PROJECT_TRAIN = config.logger.wandb.project + train_run_path = ( + f"{config.logger.wandb.entity}/{wandb_api._PROJECT_TRAIN}/{config.logger.wandb.training_id}" + ) + train_run = get_run_api(run_path=train_run_path) + with open_dict(config): + config.logger.wandb.train_run_path = wandb_api._TRAINING_RUN_PATH = train_run_path + config.logger.wandb.project = wandb_api.PROJECT = config.logger.wandb.pop("project_test") + config.effective_batch_size = train_run.config.get("effective_batch_size") + + config.logger.wandb.id = wandb_api.get_wandb_id_for_run() # Set a new run ID (!= training run ID) + # Check if a test run already exists + try: + runs_in_group = get_existing_wandb_group_runs( + config, ckpt_must_exist=False, only_best_and_last=False + ) + except requests.exceptions.HTTPError as e: + # 500 Server Error: Internal Server Error for url: https://api.wandb.ai/graphql + log.warning(f"Error when getting runs in group. Not checking for existing test runs. Error: {e}") + runs_in_group = [] + for other_run in runs_in_group: + if other_run.config.get("seed") is None or config.get("eval_mode") == "predict": + continue + elif other_run.tags != config.logger.wandb.tags: + continue + elif ( + other_run.tags == config.logger.wandb.tags + ): # or all(tag in other_run.tags for tag in config.logger.wandb.tags): + # Should do same to set id and name to other_run.id and other_run.name (?!) + raise ValueError( + f"Test run with same tags already exists: {other_run.id}. Tags: {other_run.tags} vs {config.logger.wandb.tags}" + ) + elif other_run.state == "running": + continue + elif other_run.summary.get("TESTED"): # already ran full test + continue + elif int(other_run.config.get("seed")) == int(config.seed): + # If so, resume it + log.info(f">>>>>> Resuming test run {other_run.id} from group {other_run.group}.") + config.logger.wandb.resume = True + config.logger.wandb.id = other_run.id + config.logger.wandb.name = other_run.name + else: + pass # log.info(f"Seed {other_run.config.get('seed')} != {config.seed}") + + if config.get("wandb_status") == "resume": + # try local ckpt first, otherwise we'll download from wandb or S3 + config.ckpt_path = ( + get_local_ckpt_path(config, run_api, ckpt_filename=config.ckpt_path or "last.ckpt") or config.ckpt_path + ) + return config + + +@rank_zero_only +def init_wandb(**kwargs): + """Initialize wandb with the given kwargs. Only runs on rank 0 in distributed training.""" + wandb.init(**kwargs) + log.info(f"Using wandb project {wandb_api.PROJECT} and entity {wandb_api._ENTITY}") + + +def get_only_overriden_config(config: DictConfig) -> DictConfig: + """ + Get only the config values that are different from the default values in configs/main_config.yaml + + Args: + config: Hydra config object with all the config values. + + Returns: + DictConfig: Hydra config object with only the config values that are different from the default values. + """ + from hydra.core.global_hydra import GlobalHydra + + # OLD code: + GlobalHydra.instance().clear() + with hydra.initialize(version_base=None, config_path="../configs"): + config_default = hydra.compose(config_name="main_config.yaml", overrides=[]) + # config_overriden = hydra.compose(config_name="main_config.yaml", overrides=OmegaConf.from_cli()) + diff = get_difference_between_configs(config_default, config, one_sided=True) + # Merge with explicit CLI args in case they happened to be equal to the default values. + # This is needed because the default values may differ from the ones in a reloaded run config. + args_list = [ + arg.replace("+", "") for arg in sys.argv[1:] + ] # this is plain omegaconf, so we need to remove + or ++ from the kwargs + cli_kwargs = OmegaConf.from_cli(args_list=args_list) + log.info("CLI KWARGS:2", cli_kwargs, "\nDIFF:", OmegaConf.to_yaml(diff)) + diff = OmegaConf.merge(diff, cli_kwargs) + log.info("DIFF+CLI:", OmegaConf.to_yaml(diff)) + log.info( + f"modules... Defaul=\n{OmegaConf.to_yaml(config_default.module)}\n\nDiff=\n{OmegaConf.to_yaml(diff.module)}\n\nConfig=\n{OmegaConf.to_yaml(config.module)}" + ) + return diff + + +def get_difference_between_configs(config1: DictConfig, config2: DictConfig, one_sided: bool = False) -> DictConfig: + """ + Get the difference between two OmegaConf DictConfig objects (potentially use the values of config2). + + Args: + config1: OmegaConf DictConfig object. + config2: OmegaConf DictConfig object. Use the values of this config if they are different from config1. + one_sided: If False, values of config1 are included if they don't exist in config2. If True, they are not. + + Returns: + DictConfig: OmegaConf DictConfig object with only the config values that are different between config1 and config2. + That is, values that are either contained in config1 but not config2, or vice versa, or have different values. + """ + # We can convert the DictConfig to a simple dict, and then use set operations to get the difference + # However, we need to resolve the DictConfig first, otherwise we get a TypeError + config1 = OmegaConf.to_container(config1, resolve=True) + config2 = OmegaConf.to_container(config2, resolve=True) + # Get the difference between the two configs + diff = get_difference_between_dicts_nested(config1, config2, one_sided=one_sided) + # Convert back to DictConfig + diff = OmegaConf.create(diff) + return diff + + +def get_difference_between_dicts_nested(dict1: dict, dict2: dict, one_sided: bool = False) -> dict: + """ + Get the difference between two nested dictionaries (potentially use the values of dict2). + + Args: + dict1: Nested dictionary. + dict2: Nested dictionary. Use the values of this dictionary if they are different from dict1. + one_sided: If False, values of config1 are included if they don't exist in config2. If True, they are not. + + Returns: + dict: Nested dictionary with only the values that are different between dict1 and dict2. + That is, values that are either contained in dict1 but not dict2, or vice versa, or have different values. + """ + if dict1 is None: + return dict2 + if dict2 is None: + return dict1 + if not isinstance(dict1, dict) or not isinstance(dict2, dict): + raise TypeError( + f"dict1 and dict2 must be dictionaries! \nGot {type(dict1)}:{dict1}\n and {type(dict2)}:{dict2}." + ) + # Get the difference between the two dicts + if one_sided: + diff = dict() + else: + diff = {k: dict1[k] for k in set(dict1.keys()) - set(dict2.keys())} # keys in dict1 but not dict2 + diff.update({k: dict2[k] for k in set(dict2.keys()) - set(dict1.keys())}) # keys in dict2 but not dict1 + # Keys in both dicts but with different values (use the values of dict2) + for k in set(dict1.keys()) & set(dict2.keys()): + if dict1[k] != dict2[k]: + # If the value is a dict, recursively get the difference between the nested dicts + diff[k] = dict() if isinstance(dict2[k], dict) else dict2[k] + # Recursively get the difference between the nested dicts + for k in diff: + if isinstance(diff[k], dict): + if isinstance(dict1.get(k), dict) and isinstance(dict2.get(k), dict): + diff[k] = get_difference_between_dicts_nested(dict1.get(k), dict2.get(k), one_sided=one_sided) + else: + diff[k] = dict2[k] + return diff + + +def check_config_values(config: DictConfig): + """Check if config values are valid.""" + with open_dict(config): + if config.get("model", default_value=False): + if "net_normalization" in config.model.keys(): + if config.model.net_normalization is None: + config.model.net_normalization = "none" + config.model.net_normalization = config.model.net_normalization.lower() + + if config.get("diffusion", default_value=False): + # Check that diffusion model has same hparams as the model it is based on + config.model.loss_function = None + for k, v in config.model.items(): + if k in config.diffusion.keys() and k not in ["_target_", "name", "loss_function"]: + assert v == config.diffusion[k], f"Diffusion model and model have different values for {k}!" + + # ipolator_id = config.diffusion.get("interpolator_run_id") + # if ipolator_id is not None: + # get_run_api(ipolator_id) + + scheduler_cfg = config.module.get("scheduler") + if scheduler_cfg and "LambdaWarmUpCosineScheduler" in scheduler_cfg.get("_target_", ""): + # set base LR of optim to 1.0, since we will scale it by the warmup factor + config.module.optimizer.lr = 1.0 + + USE_WANDB = ( + "logger" in config.keys() and config.logger.get("wandb") and hasattr(config.logger.wandb, "_target_") + ) + if USE_WANDB: + config.logger.wandb.id = str(config.logger.wandb.id) # convert to string + if not config.get("eval_mode"): + if config.logger.wandb.get("project_test") is not None: + raise ValueError("You are trying to override the wandb project, but you are not in test mode!") + + if "callbacks" in config and not config.get("eval_mode"): + # Add wandb run ID to model checkpoint dir as a subfolder + for name in config.callbacks.keys(): + if "model_checkpoint" not in name or config.callbacks.get(name) is None: + continue + wandb_model_run_id = config.logger.wandb.get("id") + d = config.callbacks[name].dirpath + if not os.path.exists(d) and not os.path.exists(os.path.dirname(d)): + # Run on different system, use config.work_dir/checkpoints + if os.path.exists(config.work_dir): + log.info(f"Changing dirpath={d} to {config.work_dir}/checkpoints for callback {name}.") + d = os.path.join(config.work_dir, "checkpoints") + + if wandb_model_run_id is not None and wandb_model_run_id not in d: + # Save model checkpoints to special folder // + new_dir = os.path.join(d, wandb_model_run_id) + config.callbacks[name].dirpath = new_dir + try: + os.makedirs(new_dir, exist_ok=True) + except PermissionError as e: + raise PermissionError( + f"PermissionError when creating {new_dir} for callback {name}" + ) from e + log.info(f"Model checkpoints for ``{name}`` will be saved in: {os.path.abspath(new_dir)}") + else: + if config.get("callbacks") and "wandb" in config.callbacks: + raise ValueError("You are trying to use wandb callbacks but you aren't using a wandb logger!") + # log.warning("Model checkpoints will not be saved because you are not using wandb!") + config.save_config_to_wandb = False + + if config.module.get("num_predictions", 1) > 1: + # adapt the evaluation batch size to the number of predictions + bs, ebs = config.datamodule.batch_size, config.datamodule.eval_batch_size + if ebs >= bs: + effective_ebs = ebs * config.module.num_predictions + log.info( + f"Note that the effective evaluation batch size will be multiplied by the number of " + f"predictions={config.module.num_predictions} for a total of {effective_ebs}!" + ) + + # Adjust global batch size, batch size per GPU, and accumulate_grad_batches based on the number of GPUs and nodes + n_gpus = config.trainer.get("devices", 1) + n_nodes = int(config.trainer.get("num_nodes", 1)) + if n_gpus == "auto": + n_gpus = int(torch.cuda.device_count()) + elif isinstance(n_gpus, str) and "," in n_gpus: + n_gpus = len(n_gpus.split(",")) + elif isinstance(n_gpus, Sequence): + n_gpus = len(n_gpus) + world_size = int(n_gpus * n_nodes) if n_gpus > 0 else 1 + + if config.datamodule.get("num_workers") == "auto": + if world_size >= 2: + log.info(f"Setting datamodule.num_workers to the number of GPUs (={world_size})!") + config.datamodule.num_workers = world_size # 1 worker per GPU + elif world_size == 1: + log.info("Setting datamodule.num_workers to 8. This might not be optimal for a single GPU!") + config.datamodule.num_workers = 8 + else: + log.warning("Setting datamodule.num_workers to 0. This might not be optimal for CPU training!") + config.datamodule.num_workers = 0 + + if config.get("eval_mode"): + if config.datamodule.get("batch_size_per_gpu") is not None: + log.warning("Ignoring batch_size_per_gpu in eval mode. Use ``datamodule.eval_batch_size`` instead.") + config.datamodule.batch_size_per_gpu = None + else: + # Set the batch size per GPU, and accumulate_grad_batches based on the number of GPUs and nodes + batch_size = int(config.datamodule.get("batch_size", 1)) # Global batch size + bs_per_gpu_total = batch_size // world_size # effective batch size per GPU + bs_per_gpu = config.datamodule.get("batch_size_per_gpu") + if bs_per_gpu is None or bs_per_gpu > bs_per_gpu_total: + bs_per_gpu = bs_per_gpu_total + acc = bs_per_gpu_total // bs_per_gpu + acc2 = config.trainer.get("accumulate_grad_batches") + assert ( + acc2 in [None, 1] or acc2 == acc + ), f"trainer.accumulate_grad_batches={acc2} must be equal to {acc}! (bs_per_gpu_total={bs_per_gpu_total})" + if acc != acc2: + log.warning( + f"trainer.accumulate_grad_batches={acc2} will be set to {acc} to compensate for the number of GPUs and nodes. (bs_per_gpu_total={bs_per_gpu_total})" + ) + effective_ebs = bs_per_gpu * acc * world_size + if effective_ebs != batch_size: + # Check if within 10% of the batch size + calc_str = ( + f"{effective_ebs}={bs_per_gpu} * {acc} * {world_size} (bs_per_gpu * n_acc_grads * world_size)" + ) + bs_warn_suffix = f"to global batch size {batch_size}! (n_gpus={n_gpus}, n_nodes={n_nodes})" + if abs(effective_ebs - batch_size) > 0.1 * batch_size: + raise ValueError(f"effective batch size {calc_str} must be equal {bs_warn_suffix}") + else: + log.warning(f"effective batch size {calc_str} is not equal {bs_warn_suffix}") + config.n_gpus = n_gpus + config.world_size = world_size + config.effective_batch_size = effective_ebs # * acc * n_gpus + config.datamodule.batch_size = bs_per_gpu + config.trainer.accumulate_grad_batches = acc + config.datamodule.batch_size_per_gpu = None + config.datamodule.pop("batch_size_per_gpu", None) # Remove batch_size_per_gpu from config + + # Check if CUDA is available. If not, switch to CPU. + if not torch.cuda.is_available(): + if config.trainer.get("accelerator") == "gpu": + config.trainer.accelerator = "cpu" + config.trainer.devices = 1 # devices = num_processes for CPU + log.warning( + "CUDA is not available, switching to CPU.\n" + "\tIf you want to use GPU, please re-install pytorch: https://pytorch.org/get-started/locally/." + "\n\tIf you want to use a different accelerator, specify it with ``trainer.accelerator=...``." + ) + + +def get_all_instantiable_hydra_modules(config, module_name: str): + modules = [] + if module_name in config: + for _, module_config in config[module_name].items(): + if module_config is not None and "_target_" in module_config: + if "early_stopping" in module_config.get("_target_"): + diffusion = config.get("diffusion", default_value=False) + monitor = module_config.get("monitor", "") + # If diffusion model: Add _step to the early stopping callback key + if diffusion and "step" not in monitor and "epoch" not in monitor: + module_config.monitor += "_step" + log.info("*** Early stopping monitor changed to: ", module_config.monitor) + log.info("----------------------------------------\n" * 20) + + try: + modules.append(hydra.utils.instantiate(module_config)) + except omegaconf.errors.InterpolationResolutionError as e: + log.warning(f" Hydra could not instantiate {module_config} for module_name={module_name}") + raise e + return modules + + +@rank_zero_only +def log_hyperparameters( + config, + model: pl.LightningModule, + data_module: pl.LightningDataModule, + trainer: pl.Trainer, + callbacks: List[pl.Callback], +) -> None: + """This method controls which parameters from Hydra config are saved by Lightning loggers. + Credits go to: https://github.com/ashleve/lightning-hydra-template + + Additionally saves: + - number of {total, trainable, non-trainable} model parameters + """ + + def copy_and_ignore_keys(dictionary, *keys_to_ignore): + if dictionary is None: + return None + new_dict = dict() + for k in dictionary.keys(): + if k not in keys_to_ignore: + new_dict[k] = dictionary[k] + return new_dict + + log_params = dict() + log_params["start_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S") + if "seed" in config: + log_params["seed"] = config["seed"] + + # Remove redundant keys or those that are not important to know after training -- feel free to edit this! + log_params["datamodule"] = copy_and_ignore_keys(config["datamodule"]) + log_params["model"] = copy_and_ignore_keys(config["model"]) + log_params["exp"] = copy_and_ignore_keys(config["module"], "optimizer", "scheduler") + log_params["trainer"] = copy_and_ignore_keys(config["trainer"]) + # encoder, optims, and scheduler as separate top-level key + if "n_gpus" in config.keys(): + log_params["trainer"]["gpus"] = config["n_gpus"] + log_params["optim"] = copy_and_ignore_keys(config["module"]["optimizer"]) + if "base_lr" in config.keys(): + log_params["optim"]["base_lr"] = config["base_lr"] + if "effective_batch_size" in config.keys(): + log_params["optim"]["effective_batch_size"] = config["effective_batch_size"] + if "diffusion" in config: + log_params["diffusion"] = copy_and_ignore_keys(config["diffusion"]) + log_params["scheduler"] = copy_and_ignore_keys(config["module"].get("scheduler", None)) + if config.get("model"): + # Add a clean name for the model, for easier reading (e.g. src.model.MLP.MLP -> MLP) + model_class = config.model.get("_target_") + mixer = config.model.mixer.get("_target_") if config.model.get("mixer") else None + log_params["model/name_id"] = clean_name(model_class, mixer=mixer) + if config.get("logger"): + log_params["wandb"] = copy_and_ignore_keys(config.logger.get("wandb")) + + if "callbacks" in config: + skip_callbacks = ["summarize_best_val_metric", "learning_rate_logging"] + for k, v in config["callbacks"].items(): + if k in skip_callbacks: + continue + elif k == "model_checkpoint": + log_params[k] = copy_and_ignore_keys(v, "save_top_k") + else: + log_params[k] = copy_and_ignore_keys(v) + + # save number of model parameters + log_params["model/params_total"] = sum(p.numel() for p in model.parameters()) + log_params["model/params_trainable"] = sum(p.numel() for p in model.parameters() if p.requires_grad) + log_params["model/params_not_trainable"] = sum(p.numel() for p in model.parameters() if not p.requires_grad) + log_params["dirs/work_dir"] = config.get("work_dir") + log_params["dirs/ckpt_dir"] = config.get("ckpt_dir") + log_params["dirs/wandb_save_dir"] = ( + config.logger.wandb.get("save_dir") if (config.get("logger") and config.logger.get("wandb")) else None + ) + if "BEAKER_EXPERIMENT_ID" in os.environ: + log_params["beaker"] = { + "experiment_id": os.environ["BEAKER_EXPERIMENT_ID"], + "job_id": os.environ["BEAKER_JOB_ID"], + "task_id": os.environ.get("BEAKER_TASK_ID", None), + } + # Add all values that are not dictionaries + for k, v in config.items(): + if not isinstance(v, dict) and k not in log_params.keys(): + log_params[k] = v + + # send hparams to all loggers (if any logger is used) + if trainer.logger is not None: + log.info("Logging hyperparameters to the PyTorch Lightning loggers.") + trainer.logger.log_hyperparams(log_params) + + # disable logging any more hyperparameters for all loggers + # this is just a trick to prevent trainer from logging hparams of model, + # since we already did that above + # trainer.logger.log_hyperparams = no_op + + +@rank_zero_only +def save_hydra_config_to_wandb(config: DictConfig): + # Save the config to the Wandb cloud (if wandb logging is enabled) + if config.get("save_config_to_wandb"): + filename = "hydra_config.yaml" + # Check if ``filename`` already exists in wandb cloud. If so, append a version number to it. + run_api = get_run_api(run_path=wandb.run.path) + version = 2 + run_api_files = [f.name for f in run_api.files()] + while filename in run_api_files: + filename = f"hydra_config-v{version}.yaml" + version += 1 + + log.info(f"Config will be saved to wandb as {filename} and in wandb.run.dir: {os.path.abspath(wandb.run.dir)}") + # files in wandb.run.dir folder get directly uploaded to wandb + filepath = os.path.join(wandb.run.dir, filename) + with open(filepath, "w") as fp: + OmegaConf.save(config, f=fp.name, resolve=True) + wandb.save(filename) + else: + log.info("Hydra config will NOT be saved to WandB.") + + +def get_config_from_hydra_compose_overrides( + overrides: List[str], + config_path: str = "../configs", + config_name: str = "main_config.yaml", +) -> DictConfig: + """ + Function to get a Hydra config manually based on a default config file and a list of override strings. + This is an alternative to using hydra.main(..) and the command-line for overriding the default config. + + Args: + overrides: A list of strings of the form "key=value" to override the default config with. + config_path: Relative path to the folder where the default config file is located. + config_name: Name of the default config file (.yaml ending). + + Returns: + The resulting config object based on the default config file and the overrides. + + Examples: + + .. code-block:: python + + config = get_config_from_hydra_compose_overrides(overrides=['model=mlp', 'model.optimizer.lr=0.001']) + log.info(f"Lr={config.model.optimizer.lr}, MLP hidden_dims={config.model.hidden_dims}") + """ + from hydra.core.global_hydra import GlobalHydra + + overrides = list(set(overrides)) + if "-m" in overrides: + overrides.remove("-m") # if multiruns flags are mistakenly in overrides + # Not true?!: log.info(f" Initializing Hydra from {os.path.abspath(config_path)}/{config_name}") + GlobalHydra.instance().clear() # clear any previous hydra config + hydra.initialize(config_path=config_path, version_base=None) + try: + config = hydra.compose(config_name=config_name, overrides=overrides) + finally: + GlobalHydra.instance().clear() # always clean up global hydra + return config + + +def get_model_from_hydra_compose_overrides(overrides: List[str]): + """ + Function to get a torch model manually based on a default config file and a list of override strings. + + Args: + overrides: A list of strings of the form "key=value" to override the default config with. + + Returns: + The model instantiated from the resulting config. + + Examples: + + .. code-block:: python + + mlp_model = get_model_from_hydra_compose_overrides(overrides=['model=mlp']) + random_mlp_input = torch.randn(1, 100) + random_prediction = mlp_model(random_mlp_input) + """ + from src.interface import get_lightning_module + + cfg = get_config_from_hydra_compose_overrides(overrides) + return get_lightning_module(cfg) diff --git a/src/utilities/lr_scheduler.py b/src/utilities/lr_scheduler.py new file mode 100644 index 0000000..d38fedc --- /dev/null +++ b/src/utilities/lr_scheduler.py @@ -0,0 +1,201 @@ +""" +From: + https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/ldm/lr_scheduler.py +""" + +import math +import warnings +from typing import List + +import numpy as np +import torch + + +class LambdaWarmUpCosineScheduler: + """ + note: use with a base_lr of 1.0 + """ + + def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): + self.lr_warm_up_steps = warm_up_steps + self.lr_start = lr_start + self.lr_min = lr_min + self.lr_max = lr_max + self.lr_max_decay_steps = max_decay_steps + self.last_lr = 0.0 + self.verbosity_interval = verbosity_interval + + def schedule(self, n, **kwargs): + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") + if n < self.lr_warm_up_steps: + lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start + else: + t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) + t = min(t, 1.0) + lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (1 + np.cos(t * np.pi)) + + self.last_lr = lr + return lr + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaWarmUpCosineScheduler2: + """ + supports repeated iterations, configurable via lists + note: use with a base_lr of 1.0. + """ + + def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): + assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) + self.lr_warm_up_steps = warm_up_steps + self.f_start = f_start + self.f_min = f_min + self.f_max = f_max + self.cycle_lengths = cycle_lengths + self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) + self.last_f = 0.0 + self.verbosity_interval = verbosity_interval + + def find_in_interval(self, n): + interval = 0 + for cl in self.cum_cycles[1:]: + if n <= cl: + return interval + interval += 1 + + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) + t = min(t, 1.0) + f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (1 + np.cos(t * np.pi)) + self.last_f = f + return f + + def __call__(self, n, **kwargs): + return self.schedule(n, **kwargs) + + +class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): + def schedule(self, n, **kwargs): + cycle = self.find_in_interval(n) + n = n - self.cum_cycles[cycle] + if self.verbosity_interval > 0: + if n % self.verbosity_interval == 0: + print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " f"current cycle {cycle}") + + if n < self.lr_warm_up_steps[cycle]: + f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] + self.last_f = f + return f + else: + f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / ( + self.cycle_lengths[cycle] + ) + self.last_f = f + return f + + +class LinearWarmupCosineAnnealingLR(torch.optim.lr_scheduler._LRScheduler): + """Sets the learning rate of each parameter group to follow a linear warmup schedule between + warmup_start_lr and base_lr followed by a cosine annealing schedule between base_lr and + eta_min.""" + + def __init__( + self, + optimizer: torch.optim.Optimizer, + warmup_steps: int, + max_steps: int, + warmup_start_lr: float = 0.0, + eta_min: float = 0.0, + last_epoch: int = -1, + ) -> None: + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_steps (int): Maximum number of iterations for linear warmup + max_steps (int): Maximum number of iterations + warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + """ + self.warmup_epochs = warmup_steps + self.max_epochs = max_steps + self.warmup_start_lr = warmup_start_lr + self.eta_min = eta_min + + super().__init__(optimizer, last_epoch) + + def get_lr(self) -> List[float]: + """Compute learning rate using chainable form of the scheduler.""" + if not self._get_lr_called_within_step: + warnings.warn( + "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", + UserWarning, + ) + + if self.last_epoch == self.warmup_epochs: + return self.base_lrs + if self.last_epoch == 0: + return [self.warmup_start_lr] * len(self.base_lrs) + if self.last_epoch < self.warmup_epochs: + return [ + group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: + return [ + group["lr"] + + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 + for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) + ] + + return [ + (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + / ( + 1 + + math.cos( + math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) + ) + ) + * (group["lr"] - self.eta_min) + + self.eta_min + for group in self.optimizer.param_groups + ] + + def _get_closed_form_lr(self) -> List[float]: + """Called when epoch is passed as a param to the `step` function of the scheduler.""" + if self.last_epoch < self.warmup_epochs: + return [ + self.warmup_start_lr + + self.last_epoch * (base_lr - self.warmup_start_lr) / max(1, self.warmup_epochs - 1) + for base_lr in self.base_lrs + ] + + return [ + self.eta_min + + 0.5 + * (base_lr - self.eta_min) + * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) + for base_lr in self.base_lrs + ] + + +def get_scheduler(optimizer, name, **kwargs): + if name == "cosine": + return torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **kwargs) + if name == "linear_warmup_cosine": + return LinearWarmupCosineAnnealingLR(optimizer, **kwargs) + raise ValueError(f"Unknown scheduler {name}") diff --git a/src/utilities/naming.py b/src/utilities/naming.py new file mode 100644 index 0000000..bfc53ad --- /dev/null +++ b/src/utilities/naming.py @@ -0,0 +1,509 @@ +import time +from typing import Dict, Optional + +from omegaconf import DictConfig + + +def _shared_prefix(config: DictConfig, init_prefix: str = "") -> str: + """This is a prefix for naming the runs for a more agreeable logging.""" + s = init_prefix if isinstance(init_prefix, str) else "" + if not config.get("model"): + return s + # Find mixer type if it is a transformer model (e.g. self-attention or FNO mixing) + kwargs = dict(mixer=config.model.mixer._target_) if config.model.get("mixer") else dict() + s += clean_name(config.model._target_, **kwargs) + return s.lstrip("_") + + +def get_name_for_hydra_config_class(config: DictConfig) -> Optional[str]: + """Will return a string that can describe the class of the (sub-)config.""" + if "name" in config and config.get("name") is not None: + return config.get("name") + elif "_target_" in config: + return config._target_.split(".")[-1] + return None + + +def get_clean_float_name(lr: float) -> str: + """Stringify floats <1 into very short format (use for learning rates, weight-decay etc.)""" + # basically, map Ae-B to AB (if lr<1e-5, else map 0.0001 to 1e-4) + # convert first to scientific notation: + if lr >= 0.1: + return str(lr) + lr_e = f"{lr:.1e}" # 1e-2 -> 1.0e-02, 0.03 -> 3.0e-02 + # now, split at the e into the mantissa and the exponent + lr_a, lr_b = lr_e.split("e-") + # if the decimal point is 0 (e.g 1.0, 3.0, ...), we return a simple string + if lr_a[-1] == "0": + return f"{lr_a[0]}{int(lr_b)}" + else: + return str(lr).replace("e-", "") + + +def remove_float_prefix(string, prefix_name: str = "lr", separator="_") -> str: + # Remove the lr and/or wd substrings like: + # 0.0003lr_0.01wd -> '' + # 0.0003lr -> '' + # 0.0003lr_0.5lrecs_0.01wd -> '0.5lrecs' + # 0.0003lr_0.5lrecs -> '0.5lrecs' + # 0.0003lr_0.5lrecs_0.01wd_0.5lrecs -> '0.5lrecs_0.5lrecs' + if prefix_name not in string: + return string + part1, part2 = string.split(prefix_name) + # split at '_' and keep all but the last part + part1keep = "_".join(part1.split(separator)[:-1]) + return part1keep + part2 + + +def get_loss_name(loss): + if isinstance(loss, str): + loss_name = loss.lower() + elif loss.get("_target_", "").endswith("LpLoss"): + p, is_relative = loss.get("p", 2), loss.get("relative") + loss_name = f"l{p}r" if is_relative else f"l{p}a" + else: + assert loss.get("_target_") is not None, f"Unknown loss ``{loss}``" + loss_name = loss.get("_target_").split(".")[-1].lower().replace("loss_function", "").replace("loss", "") + return loss_name + + +def get_detailed_name(config, add_unique_suffix: bool = True) -> str: + """This is a detailed name for naming the runs for logging.""" + s = config.get("name") + "_" if config.get("name") is not None else "" + hor = config.datamodule.get("horizon", 1) + if ( + hor > 1 + and f"H{hor}" not in s + and f"horizon{hor}" not in s.lower() + and f"h{hor}" not in s.lower() + and f"{hor}h" not in s.lower() + and f"{hor}l" not in s.lower() + ): + print( + f"WARNING: horizon {hor} not in name, but should be!", + s, + config.get("name_suffix"), + ) + s = s[:-1] + f"-MH{hor}_" + + s += str(config.get("name_suffix")) + "_" if config.get("name_suffix") is not None else "" + s += _shared_prefix(config) + "_" + + w = config.datamodule.get("window", 1) + if w > 1: + s += f"{w}w_" + + if config.datamodule.get("train_start_date") is not None: + s += f"{config.datamodule.train_start_date}tst_" + + if config.get("model") is None: + return s.rstrip("_-").lstrip("_-") # for "naive" baselines, e.g. climatology + + use_ema, ema_decay = config.module.get("use_ema", False), config.module.get("ema_decay", 0.9999) + if use_ema: + s += "EMA_" + if ema_decay != 0.9999: + s = s.replace("EMA", f"EMA{config.module.ema_decay}") + + is_diffusion = config.get("diffusion") is not None + if is_diffusion: + if config.diffusion.get("interpolator_run_id"): + int_run_id = config.diffusion.interpolator_run_id + replace = { + "SOME_RUN_ID": "SOME_SIMPLER_ALIAS", + } + int_run_id = replace.get(int_run_id, int_run_id) + s += f"{int_run_id}-ipolID_" + + fcond = config.diffusion.get("forward_conditioning") + if fcond != "none": + s += f"{fcond}-fcond_" if "noise" not in fcond else f"{fcond}_" + + if config.diffusion.get("time_encoding", "dynamics") != "dynamics": + tenc = config.diffusion.get("time_encoding") + if tenc == "continuous": + s += "ContTime_" + elif tenc == "dynamics": + pass # s += "DynT_" + else: + s += f"{config.diffusion.time_encoding}-timeEnc_" + + hdims = config.model.get("hidden_dims") + if hdims is None: + num_L = config.model.get("num_layers") or config.model.get("depth") + if num_L is None: + dim_mults = config.model.get("dim_mults") or config.model.get("channel_mult") + if dim_mults is None: + pass + elif tuple(dim_mults) == (1, 2, 4): + num_L = "3" + else: + num_L = "-".join([str(d) for d in dim_mults]) + + possible_dim_names = ["dim", "hidden_dim", "embed_dim", "hidden_size", "model_channels"] + hdim = None + for name in possible_dim_names: + hdim = config.model.get(name) + if hdim is not None: + break + + if hdim is not None: + hdims = f"{hdim}x{num_L}" if num_L is not None else f"{hdim}" + elif all([h == hdims[0] for h in hdims]): + hdims = f"{hdims[0]}x{len(hdims)}" + else: + hdims = str(hdims) + + s += f"{hdims}d_" if hdims is not None else "" + if config.model.get("mlp_ratio", 4.0) != 4.0: + s += f"{config.model.mlp_ratio}dxMLP_" + + if is_diffusion and config.diffusion.get("loss_function") is not None: + loss = config.diffusion.get("loss_function") + loss = get_loss_name(loss) + if loss not in ["mse", "l2"]: + s += f"{loss.upper()}_" + else: + loss = config.model.get("loss_function") + loss = get_loss_name(loss) + if loss in ["mse", "l2"]: + pass + elif loss in ["l2_rel", "l1_rel"]: + s += f"{loss.upper().replace('_REL', 'rel')}_" + else: + s += f"{loss.upper()}_" + + time_emb = config.model.get("with_time_emb", False) + if time_emb not in [False, True, "scale_shift"]: + s += f"{time_emb}_" + if (isinstance(time_emb, str) and "scale_shift" in time_emb) and not config.model.get( + "time_scale_shift_before_filter" + ): + s += "tSSA_" # time scale shift after filter + + optim = config.module.get("optimizer") + if optim is not None: + if "adamw" not in optim.name.lower(): + s += f"{optim.name.replace('Fused', '').replace('fused', '')}_" + if "fused" in optim.name.lower() or optim.get("fused", False): + s = s[:-1] + "F_" + scheduler_cfg = config.module.get("scheduler") + lr = config.get("base_lr") or optim.get("lr") + s += f"{get_clean_float_name(lr)}lr_" + if scheduler_cfg is not None and "warmup_epochs" in scheduler_cfg: + s += f"LC{scheduler_cfg.warmup_epochs}:{scheduler_cfg.max_epochs}_" + + if is_diffusion: + lam1 = config.diffusion.get("lambda_reconstruction") + lam2 = config.diffusion.get("lambda_reconstruction2") + all_lams = [lam1, lam2] + nonzero_lams = len([1 for lam in all_lams if lam is not None and lam > 0]) + uniform_lams = [ + 1 / nonzero_lams if nonzero_lams > 0 else 0, + 0.33 if nonzero_lams == 3 else 0, + ] + if config.diffusion.get("lambda_reconstruction2", 0) > 0: + if lam1 == lam2: + s += f"{lam1}lRecs_" + else: + s += f"{lam1}-{lam2}lRecs_" + + if config.diffusion.get("reconstruction2_detach_x_last", False): + s += "detX0_" + elif lam1 is not None and lam1 not in uniform_lams: + s += f"{lam1}lRec_" + + dropout = { + "": config.model.get("dropout", 0), + "in": config.model.get("input_dropout", 0), + "pos": config.model.get("pos_emb_dropout", 0), + "at": config.model.get("attn_dropout", 0), + "b": config.model.get("block_dropout", 0), + "b1": config.model.get("block_dropout1", 0), + "ft": config.model.get("dropout_filter", 0), + "mlp": config.model.get("dropout_mlp", 0), + } + any_nonzero = any([d > 0 for d in dropout.values() if d is not None]) + for k, d in dropout.items(): + if d is not None and d > 0: + s += f"{int(d * 100)}{k}" + if any_nonzero: # remove redundant 'Dr_' + s += "Dr_" + + if any_nonzero and is_diffusion and config.diffusion.get("enable_interpolator_dropout", False): + s += "iDr_" # interpolator dropout + + if config.model.get("drop_path_rate", 0) > 0: + s += f"{int(config.model.drop_path_rate * 100)}dpr_" + + if config.module.optimizer.get("weight_decay") and config.module.optimizer.get("weight_decay") > 0: + s += f"{get_clean_float_name(config.module.optimizer.get('weight_decay'))}wd_" + + if config.get("suffix", "") != "": + s += f"{config.get('suffix')}_" + + wandb_cfg = config.get("logger", {}).get("wandb", {}) + if wandb_cfg.get("resume_run_id") and wandb_cfg.get("id", "$") != wandb_cfg.get("resume_run_id", "$"): + s += f"{wandb_cfg.get('resume_run_id')}rID_" + + if add_unique_suffix: + s += f"{config.get('seed')}seed" + s += "_" + time.strftime("%Hh%Mm%b%d") + wandb_id = wandb_cfg.get("id") + if wandb_id is not None: + s += f"_{wandb_id}" + + return s.replace("None", "").rstrip("_-").lstrip("_-") + + +def clean_name(class_name, mixer=None, dm_type=None) -> str: + """This names the model class paths with a more concise name.""" + if "SphericalFourierNeuralOperatorNet" in class_name: + return "SFNO" + elif "unet_simple" in class_name: + s = "SimpleUnet" + elif "Unet" in class_name: + s = "UNetR" + elif "SimpleConvNet" in class_name: + s = "SimpleCNN" + else: + raise ValueError(f"Unknown class name: {class_name}, did you forget to add it to the clean_name function?") + + return s + + +def get_group_name(config) -> str: + """ + This is a group name for wandb logging. + On Wandb, the runs of the same group are averaged out when selecting grouping by `group` + """ + # s = get_name_for_hydra_config_class(config.model) + # s = s or _shared_prefix(config, init_prefix=s) + return get_detailed_name(config, add_unique_suffix=False) + + +def var_names_to_clean_name() -> Dict[str, str]: + """This is a clean name for the variables (e.g. for plotting)""" + var_dict = { + "tas": "Air Temperature", + "psl": "Sea-level Pressure", + "ps": "Surface Pressure", + "pr": "Precipitation", + "sst": "Sea Surface Temperature", + } + return var_dict + + +variable_name_to_metadata = { + "DLWRFsfc": {"units": "W/m**2", "long_name": "surface downward longwave flux"}, + "DSWRFsfc": { + "units": "W/m**2", + "long_name": "averaged surface downward shortwave flux", + }, + "DSWRFtoa": { + "units": "W/m**2", + "long_name": "top of atmos downward shortwave flux", + }, + "GRAUPELsfc": { + "units": "kg/m**2/s", + "long_name": "bucket surface graupel precipitation rate", + }, + "HGTsfc": {"units": "m", "long_name": "surface height"}, + "ICEsfc": { + "units": "kg/m**2/s", + "long_name": "bucket surface ice precipitation rate", + }, + "LHTFLsfc": {"units": "w/m**2", "long_name": "surface latent heat flux"}, + "PRATEsfc": { + "units": "kg/m**2/s", + "long_name": "bucket surface precipitation rate", + }, + "PRESsfc": {"units": "Pa", "long_name": "surface pressure"}, + "SHTFLsfc": {"units": "w/m**2", "long_name": "surface sensible heat flux"}, + "SNOWsfc": { + "units": "kg/m**2/s", + "long_name": "bucket surface snow precipitation rate", + }, + "ULWRFsfc": {"units": "W/m**2", "long_name": "surface upward longwave flux"}, + "ULWRFtoa": {"units": "W/m**2", "long_name": "top of atmos upward longwave flux"}, + "USWRFsfc": { + "units": "W/m**2", + "long_name": "averaged surface upward shortwave flux", + }, + "USWRFtoa": {"units": "W/m**2", "long_name": "top of atmos upward shortwave flux"}, + "air_temperature_0": {"units": "K", "long_name": "temperature level-0"}, + "air_temperature_1": {"units": "K", "long_name": "temperature level-1"}, + "air_temperature_2": {"units": "K", "long_name": "temperature level-2"}, + "air_temperature_3": {"units": "K", "long_name": "temperature level-3"}, + "air_temperature_4": {"units": "K", "long_name": "temperature level-4"}, + "air_temperature_5": {"units": "K", "long_name": "temperature level-5"}, + "air_temperature_6": {"units": "K", "long_name": "temperature level-6"}, + "air_temperature_7": {"units": "K", "long_name": "temperature level-7"}, + "ak_0": {"units": "Pa", "long_name": "ak"}, + "ak_1": {"units": "Pa", "long_name": "ak"}, + "ak_2": {"units": "Pa", "long_name": "ak"}, + "ak_3": {"units": "Pa", "long_name": "ak"}, + "ak_4": {"units": "Pa", "long_name": "ak"}, + "ak_5": {"units": "Pa", "long_name": "ak"}, + "ak_6": {"units": "Pa", "long_name": "ak"}, + "ak_7": {"units": "Pa", "long_name": "ak"}, + "ak_8": {"units": "Pa", "long_name": "ak"}, + "bk_0": {"units": "", "long_name": "bk"}, + "bk_1": {"units": "", "long_name": "bk"}, + "bk_2": {"units": "", "long_name": "bk"}, + "bk_3": {"units": "", "long_name": "bk"}, + "bk_4": {"units": "", "long_name": "bk"}, + "bk_5": {"units": "", "long_name": "bk"}, + "bk_6": {"units": "", "long_name": "bk"}, + "bk_7": {"units": "", "long_name": "bk"}, + "bk_8": {"units": "", "long_name": "bk"}, + "eastward_wind_0": {"units": "m/sec", "long_name": "zonal wind level-0"}, + "eastward_wind_1": {"units": "m/sec", "long_name": "zonal wind level-1"}, + "eastward_wind_2": {"units": "m/sec", "long_name": "zonal wind level-2"}, + "eastward_wind_3": {"units": "m/sec", "long_name": "zonal wind level-3"}, + "eastward_wind_4": {"units": "m/sec", "long_name": "zonal wind level-4"}, + "eastward_wind_5": {"units": "m/sec", "long_name": "zonal wind level-5"}, + "eastward_wind_6": {"units": "m/sec", "long_name": "zonal wind level-6"}, + "eastward_wind_7": {"units": "m/sec", "long_name": "zonal wind level-7"}, + "land_fraction": { + "units": "dimensionless", + "long_name": "fraction of grid cell area occupied by land", + }, + "northward_wind_0": {"units": "m/sec", "long_name": "meridional wind level-0"}, + "northward_wind_1": {"units": "m/sec", "long_name": "meridional wind level-1"}, + "northward_wind_2": {"units": "m/sec", "long_name": "meridional wind level-2"}, + "northward_wind_3": {"units": "m/sec", "long_name": "meridional wind level-3"}, + "northward_wind_4": {"units": "m/sec", "long_name": "meridional wind level-4"}, + "northward_wind_5": {"units": "m/sec", "long_name": "meridional wind level-5"}, + "northward_wind_6": {"units": "m/sec", "long_name": "meridional wind level-6"}, + "northward_wind_7": {"units": "m/sec", "long_name": "meridional wind level-7"}, + "ocean_fraction": { + "units": "dimensionless", + "long_name": "fraction of grid cell area occupied by ocean", + }, + "sea_ice_fraction": { + "units": "dimensionless", + "long_name": "fraction of grid cell area occupied by sea ice", + }, + "soil_moisture": { + "units": "kg/m**2", + "long_name": "total column soil moisture content", + }, + "specific_total_water_0": { + "units": "kg/kg", + "long_name": "specific total water level-0", + }, + "specific_total_water_1": { + "units": "kg/kg", + "long_name": "specific total water level-1", + }, + "specific_total_water_2": { + "units": "kg/kg", + "long_name": "specific total water level-2", + }, + "specific_total_water_3": { + "units": "kg/kg", + "long_name": "specific total water level-3", + }, + "specific_total_water_4": { + "units": "kg/kg", + "long_name": "specific total water level-4", + }, + "specific_total_water_5": { + "units": "kg/kg", + "long_name": "specific total water level-5", + }, + "specific_total_water_6": { + "units": "kg/kg", + "long_name": "specific total water level-6", + }, + "specific_total_water_7": { + "units": "kg/kg", + "long_name": "specific total water level-7", + }, + "surface_temperature": {"units": "K", "long_name": "surface temperature"}, + "tendency_of_total_water_path": { + "units": "kg/m^2/s", + "long_name": "time derivative of total water path", + }, + "tendency_of_total_water_path_due_to_advection": { + "units": "kg/m^2/s", + "long_name": "tendency of total water path due to advection", + }, + "total_water_path": {"units": "kg/m^2", "long_name": "total water path"}, +} + + +def full_variable_name_with_units(variable: str, formatted: bool = True, capitalize: bool = True) -> str: + """This is a full name for the variable (e.g. for plotting)""" + if variable not in variable_name_to_metadata: + return variable + data = variable_name_to_metadata[variable] + long_name = data.get("long_name", variable) + if capitalize: + long_name = long_name.capitalize() + # Make long name bold in latex, and units italic + if formatted is True: + name = long_name.replace("_", " ").replace(" ", "\\ ") + if data["units"] == "": + return f"$\\bf{{{name}}}$" + else: + return f'$\\bf{{{name}}}$ [$\\it{{{data["units"]}}}$]' + elif formatted == "units": + if data["units"] == "": + return f"{long_name}" + else: + return f'{long_name} [$\\it{{{data["units"]}}}$]' + else: + if data["units"] == "": + return f"{long_name}" + else: + return f'{long_name} [{data["units"]}]' + + +def formatted_units(variable: str) -> str: + """This is a full name for the variable (e.g. for plotting)""" + if variable not in variable_name_to_metadata: + return "" + data = variable_name_to_metadata[variable] + return f"[$\\it{{{data['units']}}}$]" + + +def formatted_long_name(variable: str, capitalize: bool = True) -> str: + """This is a full name for the variable (e.g. for plotting)""" + if variable not in variable_name_to_metadata: + return variable + data = variable_name_to_metadata[variable] + long_name = data.get("long_name", variable) + if capitalize: + long_name = long_name.capitalize() + long_name = long_name.replace("_", " ").replace(" ", "\\ ") + return f"$\\bf{{{long_name}}}$" + + +def clean_metric_name(metric: str) -> str: + """This is a clean name for the metrics (e.g. for plotting)""" + metric_dict = { + "mae": "MAE", + "mse": "MSE", + "crps": "CRPS", + "rmse": "RMSE", + "bias": "Bias", + "mape": "MAPE", + "ssr": "Spread / RMSE", + "ssr_abs_dist": "abs(1 - Spread / RMSE)", + "ssr_squared_dist": "(1 - Spread / RMSE)^2", + "nll": "NLL", + "r2": "R2", + "corr": "Correlation", + "corrcoef": "Correlation", + "corr_mem_avg": "Corr. Mem. Avg.", + "corr_spearman": "Spearman Correlation", + "corr_kendall": "Kendall Correlation", + "corr_pearson": "Pearson Correlation", + "grad_mag_percent_diff": "Gradient Mag. % Diff", + } + for k in ["crps", "ssr", "rmse", "grad_mag_percent_diff", "bias"]: + metric_dict[f"weighted_{k}"] = metric_dict[k] + + return metric_dict.get(metric.lower(), metric) diff --git a/src/utilities/normalization.py b/src/utilities/normalization.py new file mode 100644 index 0000000..1f2b382 --- /dev/null +++ b/src/utilities/normalization.py @@ -0,0 +1,117 @@ +from typing import Any, Dict, List + +import torch +import xarray as xr + + +class StandardNormalizer(torch.nn.Module): + """ + Responsible for normalizing tensors. + """ + + def __init__(self, means: Dict[str, torch.Tensor], stds: Dict[str, torch.Tensor], names=None): + super().__init__() + if isinstance(means, dict): + for k in means.keys(): + if means[k].ndim == 1: + # Add singleton dimensions for broadcasting over lat/lon dimensions + means[k] = torch.reshape(means[k], (-1, 1, 1)) + stds[k] = torch.reshape(stds[k], (-1, 1, 1)) + elif means[k].ndim > 1: + raise ValueError(f"Means tensor {k} has more than one dimension!") + # Make sure that means and stds move to the same device + self.means = means + self.stds = stds + self.names = names + if torch.is_tensor(means): + self._normalize = _normalize + self._denormalize = _denormalize + else: + assert isinstance(means, dict), "Means and stds must be either both tensors or both dictionaries!" + self._normalize = _normalize_dict + self._denormalize = _denormalize_dict + + def _apply(self, fn, recurse=True): + super()._apply(fn) # , recurse=recurse) + if isinstance(self.means, dict): + self.means = {k: fn(v) for k, v in self.means.items()} + self.stds = {k: fn(v) for k, v in self.stds.items()} + else: + self.means = fn(self.means) + self.stds = fn(self.stds) + + def normalize(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + return self._normalize(tensors, means=self.means, stds=self.stds) + + def denormalize(self, tensors: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: + if self.names is not None: + assert ( + len(set(tensors.keys()) - set(self.names)) == 0 + ), f"Some keys would not be denormalized: {set(tensors.keys()) - set(self.names)}!" + return self._denormalize(tensors, means=self.means, stds=self.stds) + + +@torch.jit.script +def _normalize_dict( + tensors: Dict[str, torch.Tensor], + means: Dict[str, torch.Tensor], + stds: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + return {k: (t - means[k]) / stds[k] for k, t in tensors.items()} + + +@torch.jit.script +def _denormalize_dict( + tensors: Dict[str, torch.Tensor], + means: Dict[str, torch.Tensor], + stds: Dict[str, torch.Tensor], +) -> Dict[str, torch.Tensor]: + return {k: t * stds[k] + means[k] for k, t in tensors.items()} + + +@torch.jit.script +def _normalize(tensor: torch.Tensor, means: torch.Tensor, stds: torch.Tensor) -> torch.Tensor: + return (tensor - means) / stds + + +@torch.jit.script +def _denormalize(tensor: torch.Tensor, means: torch.Tensor, stds: torch.Tensor) -> torch.Tensor: + return tensor * stds + means + + +def get_normalizer( + global_means_path, global_stds_path, names: List[str], sel: Dict[str, Any] = None, is_2d_flattened=False +) -> StandardNormalizer: + mean_ds = xr.open_dataset(global_means_path) + std_ds = xr.open_dataset(global_stds_path) + if sel is not None: + mean_ds = mean_ds.sel(**sel) + std_ds = std_ds.sel(**sel) + if is_2d_flattened: + means, stds = dict(), dict() + for name in names: + if name in mean_ds.keys(): + means[name] = torch.as_tensor(mean_ds[name].values, dtype=torch.float) + stds[name] = torch.as_tensor(std_ds[name].values, dtype=torch.float) + else: + # Retrieve _ variables + var_name, pressure_level = "_".join(name.split("_")[:-1]), int(name.split("_")[-1]) + try: + means[name] = torch.as_tensor( + mean_ds[var_name].sel(level=pressure_level).values, dtype=torch.float + ) + stds[name] = torch.as_tensor(std_ds[var_name].sel(level=pressure_level).values, dtype=torch.float) + except KeyError as e: + print(mean_ds.coords.values) + raise KeyError( + f"Variable {name} with var_name {var_name} and level ``{pressure_level}`` not found in the dataset!" + ) from e + else: + means = {name: torch.as_tensor(mean_ds[name].values, dtype=torch.float) for name in names} + stds = {name: torch.as_tensor(std_ds[name].values, dtype=torch.float) for name in names} + return StandardNormalizer(means=means, stds=stds, names=names) + + +def load_Dict_from_netcdf(path, names): + ds = xr.open_dataset(path) + return {name: ds[name].values for name in names} diff --git a/src/utilities/packer.py b/src/utilities/packer.py new file mode 100644 index 0000000..4e28b51 --- /dev/null +++ b/src/utilities/packer.py @@ -0,0 +1,77 @@ +from typing import Dict, List + +import torch +import torch.jit +from tensordict import TensorDict + + +class NoPacker: + def pack(self, tensors: Dict[str, torch.Tensor], axis=0) -> torch.Tensor: + return tensors + + def unpack(self, tensor: torch.Tensor, axis=0) -> Dict[str, torch.Tensor]: + return tensor + + +class Packer: + """ + Responsible for packing tensors into a single tensor. + """ + + def __init__(self, names: List[str], axis=None, axis_pack=None, axis_unpack=None): + self.names = names + if axis is not None: + assert axis_pack is None, "Cannot specify both axis and axis_pack" + assert axis_unpack is None, "Cannot specify both axis and axis_unpack" + self.axis_pack = axis + self.axis_unpack = axis + else: + assert axis_pack is not None, "Must specify either axis or axis_pack" + assert axis_unpack is not None, "Must specify either axis or axis_unpack" + self.axis_pack = axis_pack + self.axis_unpack = axis_unpack + + def pack(self, tensors: Dict[str, torch.Tensor], axis=None) -> torch.Tensor: + """ + Packs tensors into a single tensor, concatenated along a new axis + + Args: + tensors: Dict from names to tensors. + axis: index for new concatenation axis. + """ + axis = axis if axis is not None else self.axis_pack + return _pack(tensors, self.names, axis=axis) + + def unpack(self, tensor: torch.Tensor, axis=None) -> TensorDict: + axis = axis if axis is not None else self.axis_unpack + # packed shape is tensor.shape with axis removed + packed_shape = list(tensor.shape) + packed_shape.pop(axis) + return TensorDict(_unpack(tensor, self.names, axis=axis), batch_size=packed_shape) + + def unpack_simple(self, tensor: torch.Tensor, axis=None) -> Dict[str, torch.Tensor]: + axis = axis if axis is not None else self.axis_unpack + return _unpack(tensor, self.names, axis=axis) + + def get_state(self): + """ + Returns state as a serializable data structure. + """ + return {"names": self.names, "axis": self.axis} + + @classmethod + def from_state(self, state) -> "Packer": + """ + Loads state from a serializable data structure. + """ + return Packer(state["names"], state["axis"]) + + +@torch.jit.script +def _pack(tensors: Dict[str, torch.Tensor], names: List[str], axis: int) -> torch.Tensor: + return torch.stack([tensors[n] for n in names], dim=axis) + + +@torch.jit.script +def _unpack(tensor: torch.Tensor, names: List[str], axis: int) -> Dict[str, torch.Tensor]: + return {n: tensor.select(axis, index=i) for i, n in enumerate(names)} diff --git a/src/utilities/s3utils.py b/src/utilities/s3utils.py new file mode 100644 index 0000000..b6df83b --- /dev/null +++ b/src/utilities/s3utils.py @@ -0,0 +1,383 @@ +""" Slightly adapted from Zihao Zhou""" + +import fnmatch +import glob +import os +import time + +import boto3 +from botocore import UNSIGNED +from botocore.client import Config +from botocore.exceptions import ClientError + +from src.utilities.utils import get_logger + + +log = get_logger(__name__) + +S3_ENDPOINT_URL = os.getenv("S3_ENDPOINT_URL") +S3_BUCKET_NAME = os.getenv("S3_BUCKET_NAME") +# if not S3_ENDPOINT_URL and not S3_BUCKET_NAME: +# S3_ENDPOINT_URL = os.environ["S3_ENDPOINT_URL"] = "https://XYZ.edu" +# S3_BUCKET_NAME = os.environ["S3_BUCKET_NAME"] = "my-data-and-results" +checks = ["S3_ENDPOINT_URL", "S3_BUCKET_NAME"] +for check in checks: + if not os.getenv(check): + raise EnvironmentError(f"Please set the {check} environment variable.") + +# Export S3 credentials from ~/.config/s3 +credentials_maybe_dir = os.path.expanduser("~/.config/s3") +if os.path.exists(credentials_maybe_dir): + if os.environ.get("AWS_ACCESS_KEY_ID") is None: + os.environ["AWS_ACCESS_KEY_ID"] = open(f"{credentials_maybe_dir}/access_key_id").read().strip() + if os.environ.get("AWS_SECRET_ACCESS_KEY") is None: + os.environ["AWS_SECRET_ACCESS_KEY"] = open(f"{credentials_maybe_dir}/secret_access_key").read().strip() +# Check if credentials are provided +if os.getenv("AWS_ACCESS_KEY_ID") and os.getenv("AWS_SECRET_ACCESS_KEY"): + config = Config(retries={"max_attempts": 5, "mode": "adaptive"}, max_pool_connections=50) + # Credentials are provided, use them to create the client + s3_client = boto3.client("s3", endpoint_url=S3_ENDPOINT_URL, config=config) +else: + # Credentials are not provided, use anonymous access + log.info("Using anonymous access to S3. Set AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY for authenticated access.") + s3_client = boto3.client("s3", endpoint_url=S3_ENDPOINT_URL, config=Config(signature_version=UNSIGNED)) + + +def get_local_files(s3_path, local_path): + """ + Recursively get local files that match the s3_path pattern in the local_path directory. + """ + wildcard_index = s3_path.find("*") + if wildcard_index == -1: + prefix = os.path.dirname(s3_path) + pattern = os.path.basename(s3_path) + else: + prefix = s3_path[:wildcard_index] + if "/" in prefix: + prefix = os.path.dirname(prefix) + pattern = s3_path[len(prefix) + 1 :] + else: + prefix = "." + pattern = s3_path + + prefix = os.path.normpath(os.path.join(local_path, prefix)) + local_files = glob.glob(prefix + "/**", recursive=True) + + filtered_local_files = [] + for file in local_files: + if os.path.isdir(file): + continue + file = os.path.normpath(file) + if pattern: + if fnmatch.fnmatch(os.path.relpath(file, prefix), pattern): + file = os.path.normpath(file) + filtered_local_files.append(file) + else: + filtered_local_files.append(file) + return filtered_local_files + + +def get_s3_objects(s3_path): + """ + Recursively get all objects in S3 bucket that match the s3_path pattern. + """ + wildcard_index = s3_path.find("*") + if wildcard_index == -1: + prefix = s3_path + pattern = "" + else: + prefix = s3_path[:wildcard_index] + if prefix.endswith("/"): + prefix = prefix[:-1] + pattern = s3_path[len(prefix) + 1 :] + else: + pattern = s3_path[len(prefix) :] + + paginator = s3_client.get_paginator("list_objects_v2") + pages = paginator.paginate(Bucket=S3_BUCKET_NAME, Prefix=prefix) + + filtered_s3_objects = [] + for page in pages: + for obj in page.get("Contents", []): + key = obj["Key"] + if pattern: # only apply fnmatch if there's a pattern to match + if fnmatch.fnmatch(key[len(prefix) :], pattern): + filtered_s3_objects.append(key) + else: + filtered_s3_objects.append(key) + return filtered_s3_objects + + +def download_s3_objects(s3_objects, local_path="./"): + """ + Download specified S3 objects to the local file system. + """ + for s3_key in s3_objects: + # Construct the full local filepath + local_file_path = os.path.join(local_path, s3_key) + + # Create directory if it doesn't exist + local_file_dir = os.path.dirname(local_file_path) + if not os.path.exists(local_file_dir): + os.makedirs(local_file_dir) + + download_s3_object(s3_key, local_file_path) + + +def download_s3_object(s3_file_path, local_file_path: str, throw_error: bool = True): + # Download the file from S3 + try: + if os.path.exists(local_file_path): + log.info(f"File {local_file_path} already exists") + return + # Make sure the directory exists if it has a directory structure + if os.path.dirname(local_file_path) != "": + os.makedirs(os.path.dirname(local_file_path), exist_ok=True) + + # Download the file from S3 + s3_client.download_file(S3_BUCKET_NAME, s3_file_path, local_file_path) + + # Verify file was downloaded successfully + if not os.path.exists(local_file_path): + raise FileNotFoundError(f"Failed to download file to {local_file_path}") + + log.info(f"Downloaded {s3_file_path} to {local_file_path}") + except ClientError as e: + log.warning(f"Failed to download {s3_file_path}: {e}") + # List all files in the directory + try: + s3_objects = get_s3_objects(os.path.dirname(s3_file_path)) + except ClientError as e: + s3_objects = "" + log.warning(f"Failed to list files in directory: {e}") + if throw_error: + raise ValueError(f"File {s3_file_path} not found in S3. Files in directory: {s3_objects}") from e + else: + log.info(f"Files in directory: {s3_objects}") + return s3_objects + + +def download_s3_path(s3_path, local_path="./"): + """ + Download all files in the S3 path to the local file system. + """ + s3_objects = get_s3_objects(s3_path) + download_s3_objects(s3_objects, local_path) + + +def list_s3_objects(s3_path): + """ + List all directories / files in S3 bucket under the given path. + """ + prefix = s3_path.lstrip("/") + + paginator = s3_client.get_paginator("list_objects_v2") + page_iterator = paginator.paginate(Bucket=S3_BUCKET_NAME, Prefix=prefix, Delimiter="/") + objects = [] + directories = [] + + for page in page_iterator: + directories.extend(page.get("CommonPrefixes", [])) + objects.extend(page.get("Contents", [])) + + for d in directories: + log.info(f"Directory: {d['Prefix']}") + + for obj in objects: + log.info(f"File: {obj['Key']}") + + +def delete_local_files(local_files): + """ + Delete local files. + """ + for file in local_files: + os.remove(file) + log.info(f"Deleted {file}") + ## Delete empty folders if any + folder = os.path.dirname(file) + if not os.listdir(folder): + os.rmdir(folder) + log.info(f"Deleted {folder}") + + +def print_folders(files): + """ + Print folders of files, assuming the files are sorted by folder. + """ + last_folder = None + for file in files: + folder = "/".join(file.split("/")[:-1]) + "/" + if folder != last_folder: + log.info(folder) + last_folder = folder + + +def remove_s3_objects(objects_to_delete): + """ + Remove objects in S3. + """ + objects_to_delete = [{"Key": obj} for obj in objects_to_delete] + if objects_to_delete: + s3_client.delete_objects(Bucket=S3_BUCKET_NAME, Delete={"Objects": objects_to_delete}) + for obj in objects_to_delete: + log.info(f"Removed {obj['Key']} from S3") + + +def remove_s3_path(s3_path): + """ + Remove all files in the S3 path. + """ + s3_objects = get_s3_objects(s3_path) + remove_s3_objects(s3_objects) + + +def upload_s3_object(local_file_path, s3_file_path, retry=3, **kwargs): + """ + Upload a single local file to S3. + Args: + local_file_path: The path to the local file. + s3_file_path: The path to the S3 file. If it ends with a "/", the local file will be uploaded with the same name to that directory. + """ + assert os.path.isfile(local_file_path), f"{local_file_path} is not a file" + if s3_file_path.endswith("/"): + s3_file_path += os.path.basename(local_file_path) + else: + # Check that both path's have same file extension + local_file_ext = os.path.splitext(local_file_path)[1] + s3_file_ext = os.path.splitext(s3_file_path)[1] + assert ( + local_file_ext == s3_file_ext + ), f"File extensions do not match: {local_file_ext} != {s3_file_ext}. If you intended s3_filepath to be a directory, append a '/' to the end of it." + + for i in range(retry): + try: + s3_client.upload_file(local_file_path, S3_BUCKET_NAME, s3_file_path, **kwargs) + log.info(f"Uploaded {local_file_path} to {s3_file_path}") + break + except Exception as e: + log.warning( + f"Failed to upload {local_file_path} with S3_BUCKET_NAME={S3_BUCKET_NAME} and s3_file_path={s3_file_path}: {e}" + ) + # sleep for 10 seconds before retrying + time.sleep(5) + if i == retry - 1: + raise e + + +def upload_s3_objects(local_files, local_path="./", s3_path=""): + """ + Upload local files to S3. + """ + for local_file in local_files: + if os.path.isfile(local_file): + s3_key = os.path.relpath(local_file, os.path.dirname(local_path)) + s3_key = os.path.normpath(s3_key) + if s3_path: + s3_key = os.path.join(s3_path, s3_key) + try: + s3_client.head_object(Bucket=S3_BUCKET_NAME, Key=s3_key) + log.info(f"File {s3_key} already exists in S3") + except ClientError: + s3_client.upload_file(local_file, S3_BUCKET_NAME, s3_key) + log.info(f"Uploaded {local_file} to {s3_key}") + else: + log.warning(f"Skipping {local_file} as it is not a file.") + + +def upload_s3_path(s3_path, local_path="./"): + """ + Upload all files in the local path to the S3 path. + """ + local_files = get_local_files(s3_path, local_path) + upload_s3_objects(local_files, local_path) + + +def interactive_list_and_action(s3_path, local_path): + """ + List local and S3 files/folders, then ask the user whether to + - delete local files + - upload local files + - remove s3 files + - download s3 files + """ + if s3_path.endswith("/"): + filetype = "folders" + s3_path += "*" + else: + filetype = "files" + + log.info(f"Local {filetype} matching pattern:") + local_files = get_local_files(s3_path, local_path) + + if filetype == "folders": + print_folders(local_files) + else: + for file in local_files: + log.info(file) + + log.info(f"\nS3 {filetype} matching pattern:") + s3_objects = get_s3_objects(s3_path) + + if filetype == "folders": + print_folders(s3_objects) + else: + for file in s3_objects: + log.info(file) + + action = input("\nChoose an action [delete (local), remove (S3), download, upload, exit]: ").strip().lower() + if action == "delete": + delete_local_files(local_files) + elif action == "upload": + upload_s3_objects(local_files, local_path) + elif action == "remove": + remove_s3_objects(s3_objects) + elif action == "download": + download_s3_objects(s3_objects, local_path) + elif action == "exit": + pass + else: + log.info("Invalid action") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="S3 utils for managing files and folders.") + parser.add_argument("--find", help="Find S3 files", action="store_true") + parser.add_argument("--list", help="List S3 files", action="store_true") + parser.add_argument("--download", help="Download S3 files", action="store_true") + parser.add_argument("--upload", help="Upload local files to S3", action="store_true") + parser.add_argument("--remove", help="Remove S3 files", action="store_true") + parser.add_argument("--delete", help="Delete local files", action="store_true") + parser.add_argument("--interactive", help="Interactive mode", action="store_true") + parser.add_argument("path", help="The S3 or local path pattern", type=str) + + args = parser.parse_args() + + s3_path = args.path + local_path = "./" + + if args.find: + file_type = "folders" if s3_path.endswith("/") else "files" + s3_objects = get_s3_objects(s3_path + "**" if file_type == "folders" else s3_path) + if file_type == "folders": + print_folders(s3_objects) + else: + for obj in s3_objects: + log.info(obj) + elif args.list: + list_s3_objects(s3_path) + elif args.download: + download_s3_path(s3_path, local_path) + elif args.upload: + upload_s3_path(s3_path, local_path) + elif args.remove: + remove_s3_path(s3_path) + elif args.delete: + local_files = get_local_files(s3_path, local_path) + delete_local_files(local_files) + elif args.interactive: + interactive_list_and_action(s3_path, local_path) + else: + parser.print_help() diff --git a/src/utilities/utils.py b/src/utilities/utils.py new file mode 100644 index 0000000..e59f90d --- /dev/null +++ b/src/utilities/utils.py @@ -0,0 +1,967 @@ +""" +Author: Salva Rühling Cachay +""" + +from __future__ import annotations + +import functools +import logging +import os +import random +import re +import subprocess +from difflib import SequenceMatcher +from inspect import isfunction +from types import SimpleNamespace +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import xarray as xr +from einops import rearrange +from omegaconf import DictConfig, OmegaConf +from tensordict import TensorDict, TensorDictBase +from torch import Tensor + +from src.models.modules.drop_path import DropPath + + +def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: + """Initializes multi-GPU-friendly python logger.""" + from pytorch_lightning.utilities import rank_zero_only + + logger = logging.getLogger(name) + logger.setLevel(level) + + # this ensures all logging levels get marked with the rank zero decorator + # otherwise logs would get multiplied for each GPU process in multi-GPU setup + for level in ("debug", "info", "warning", "error", "exception", "fatal", "critical"): + setattr(logger, level, rank_zero_only(getattr(logger, level))) + + return logger + + +log = get_logger(__name__) + + +def no_op(*args, **kwargs): + pass + + +def identity(X, *args, **kwargs): + return X + + +def get_identity_callable(*args, **kwargs) -> Callable: + return identity + + +def exists(x): + return x is not None + + +def default(val, d): + if val is not None: + return val + return d() if isfunction(d) else d + + +distribution_params_to_edit = ["loc", "scale"] + + +def torch_to_numpy(x: Union[Tensor, Dict[str, Tensor]]) -> Union[np.ndarray, Dict[str, np.ndarray]]: + if isinstance(x, Tensor): + return x.detach().cpu().numpy() + elif isinstance(x, TensorDictBase): + return {k: torch_to_numpy(v) for k, v in x.items()} + # return x.detach().cpu() # numpy() not implemented for TensorDict + elif isinstance(x, dict): + return {k: torch_to_numpy(v) for k, v in x.items()} + elif isinstance(x, torch.distributions.Distribution): + # only move the parameters to cpu + for k in distribution_params_to_edit: + if hasattr(x, k): + setattr(x, k, getattr(x, k).detach().cpu()) + return x + else: + return x + + +def numpy_to_torch(x: Union[np.ndarray, Dict[str, np.ndarray]]) -> Union[Tensor, Dict[str, Tensor]]: + if isinstance(x, np.ndarray): + return torch.from_numpy(x) + elif isinstance(x, dict): + return {k: numpy_to_torch(v) for k, v in x.items()} + # if it's a namedtuple, convert each element + elif isinstance(x, tuple) and hasattr(x, "_fields"): + return type(x)(*[numpy_to_torch(v) for v in x]) + elif torch.is_tensor(x): + return x + # if is simple int, float, etc., return as is + elif isinstance(x, (int, float, str)): + return x + else: + raise ValueError(f"Cannot convert {type(x)} to torch.") + + +def to_torch_and_device(x, device): + x = x.values if isinstance(x, (xr.Dataset, xr.DataArray)) else x + x = torch.from_numpy(x) if isinstance(x, np.ndarray) else x + return x.to(device) if x is not None else None + + +def rrearrange( + data: Union[Tensor, torch.distributions.Distribution, TensorDictBase], + pattern: str, + find_batch_size_max: bool = True, + **axes_lengths, +): + """Extend einops.rearrange to work with distributions.""" + if torch.is_tensor(data) or isinstance(data, np.ndarray): + return rearrange(data, pattern, **axes_lengths) + elif isinstance(data, torch.distributions.Distribution): + dist_params = { + k: rearrange(getattr(data, k), pattern, **axes_lengths) + for k in distribution_params_to_edit + if hasattr(data, k) + } + return type(data)(**dist_params) + elif isinstance(data, TensorDictBase): + new_data = {k: rrearrange(v, pattern, **axes_lengths) for k, v in data.items()} + return to_tensordict(new_data, find_batch_size_max=find_batch_size_max) + elif isinstance(data, dict): + return {k: rrearrange(v, pattern, **axes_lengths) for k, v in data.items()} + else: + raise ValueError(f"Cannot rearrange {type(data)}") + + +def multiply_by_scalar(x: Union[Dict[str, Any], Any], scalar: float) -> Union[Dict[str, Any], Any]: + """Multiplies the given scalar to the given scalar or dict.""" + if isinstance(x, dict): + return {k: multiply_by_scalar(v, scalar) for k, v in x.items()} + else: + return x * scalar + + +def add(a, b): + if isinstance(a, (TensorDictBase, dict)): + return {key: add(a[key], b[key]) for key in a.keys()} + else: + return a + b + + +def subtract(a, b): + if isinstance(a, (TensorDictBase, dict)): + return {key: subtract(a[key], b[key]) for key in a.keys()} + else: + return a - b + + +def multiply(a, b): + if isinstance(a, (TensorDictBase, dict)): + return {key: multiply(a[key], b[key]) for key in a.keys()} + else: + return a * b + + +def divide(a, b): + if isinstance(a, (TensorDictBase, dict)): + return {key: divide(a[key], b[key]) for key in a.keys()} + else: + return a / b + + +def torch_select(input: Tensor, dim: int, index: int): + """Extends torch.select to work with distributions.""" + if isinstance(input, torch.distributions.Distribution): + dist_params = { + k: torch.select(getattr(input, k), dim, index) for k in distribution_params_to_edit if hasattr(input, k) + } + return type(input)(**dist_params) + else: + return torch.select(input, dim, index) + + +def ellipsis_torch_dict_boolean_tensor(input_dict: TensorDictBase, mask: Tensor) -> TensorDictBase: + """Ellipsis indexing for TensorDict with boolean mask as replacement for torch_dict[..., mask]""" + if torch.is_tensor(input_dict): + return input_dict[..., mask] + # Simply doing [..., mask] will not work, we need to select with : as many times as the number of dimensions + # in the input tensor (- the length of the mask shape) + mask_len = len(mask.shape) + ellipsis_str = (", :" * (len(input_dict.shape) - mask_len)).lstrip(", ") + output_dict = dict() + for k, v in input_dict.items(): + output_dict[k] = eval(f"v[{ellipsis_str}, mask]") + log.info("shape value1=", list(output_dict.values())[0].shape) + return to_tensordict(output_dict, find_batch_size_max=True) + # TensorDict({k: eval(f"input[k][{ellipsis_str}, mask]") for k in input.keys()}, batch_size=mask.shape) + + +def extract_into_tensor(a, t, x_shape): + """Extracts the values of tensor, a, at the given indices, t. + Then, add dummy dimensions to broadcast to x_shape.""" + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + + def noise(): + return torch.randn(shape, device=device) + + return repeat_noise() if repeat else noise() + + +def get_activation_function(name: str, functional: bool = False, num: int = 1): + """Returns the activation function with the given name.""" + name = name.lower().strip() + + def get_functional(s: str) -> Optional[Callable]: + return { + "softmax": F.softmax, + "relu": F.relu, + "tanh": torch.tanh, + "sigmoid": torch.sigmoid, + "identity": nn.Identity(), + None: None, + "swish": F.silu, + "silu": F.silu, + "elu": F.elu, + "gelu": F.gelu, + "prelu": nn.PReLU(), + }[s] + + def get_nn(s: str) -> Optional[Callable]: + return { + "softmax": nn.Softmax(dim=1), + "relu": nn.ReLU(), + "tanh": nn.Tanh(), + "sigmoid": nn.Sigmoid(), + "identity": nn.Identity(), + "silu": nn.SiLU(), + "elu": nn.ELU(), + "prelu": nn.PReLU(), + "swish": nn.SiLU(), + "gelu": nn.GELU(), + }[s] + + if num == 1: + return get_functional(name) if functional else get_nn(name) + else: + return [get_nn(name) for _ in range(num)] + + +def get_normalization_layer(name, dims, num_groups=None, *args, **kwargs): + """Returns the normalization layer with the given name. + + Args: + name: name of the normalization layer. Must be one of ['batch_norm', 'layer_norm' 'group', 'instance', 'none'] + """ + if not isinstance(name, str) or name.lower() == "none": + return None + elif "batch_norm" == name: + return nn.BatchNorm2d(num_features=dims, *args, **kwargs) + elif "layer_norm" == name: + return nn.LayerNorm(dims, *args, **kwargs) + elif "rms_layer_norm" == name: + from src.utilities.normalization import RMSLayerNorm + + return RMSLayerNorm(dims, *args, **kwargs) + elif "instance" in name: + return nn.InstanceNorm1d(num_features=dims, *args, **kwargs) + elif "group" in name: + if num_groups is None: + # find an appropriate divisor (not robust against weird dims!) + pos_groups = [int(dims / N) for N in range(2, 17) if dims % N == 0] + if len(pos_groups) == 0: + raise NotImplementedError(f"Group norm could not infer the number of groups for dim={dims}") + num_groups = max(pos_groups) + return nn.GroupNorm(num_groups=num_groups, num_channels=dims) + else: + raise ValueError("Unknown normalization name", name) + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + log.info(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def to_dict(obj: Optional[Union[dict, SimpleNamespace]]): + if obj is None: + return dict() + elif isinstance(obj, dict): + return obj + else: + return vars(obj) + + +def to_DictConfig(obj: Optional[Union[List, Dict]]): + """Tries to convert the given object to a DictConfig.""" + if isinstance(obj, DictConfig): + return obj + + if isinstance(obj, list): + try: + dict_config = OmegaConf.from_dotlist(obj) + except ValueError: + dict_config = OmegaConf.create(obj) + + elif isinstance(obj, dict): + dict_config = OmegaConf.create(obj) + + else: + dict_config = OmegaConf.create() # empty + + return dict_config + + +def get_dotted_key_from_dict(d: dict, key: str): + """Returns the value from the given dictionary with the given key, which can be a dotted key.""" + keys = key.split(".") + value = d + for k in keys: + if k not in value: + return None + value = value[k] + return value + + +def keep_dict_or_tensordict(new_dict_like: dict, original: Union[Dict, TensorDictBase]) -> Union[Dict, TensorDictBase]: + """Returns the given object if it is a dict or TensorDict, otherwise returns an empty dict.""" + if isinstance(original, TensorDictBase): + # Return class of original + return type(original)(new_dict_like, batch_size=original.batch_size) + elif isinstance(original, dict): + return new_dict_like + else: + raise ValueError(f"Expected a dict or TensorDict, but got {type(original)}") + + +def replace_substrings(string: str, replacements: Dict[str, str], ignore_case: bool = False): + """ + Given a string and a replacement map, it returns the replaced string. + :param str string: string to execute replacements on + :param dict replacements: replacement dictionary {value to find: value to replace} + :param bool ignore_case: whether the match should be case-insensitive + :rtype: str + """ + if not replacements: + # Edge case that'd produce a funny regex and cause a KeyError + return string + + # If case-insensitive, we need to normalize the old string so that later a replacement + # can be found. For instance with {"HEY": "lol"} we should match and find a replacement for "hey", + # "HEY", "hEy", etc. + if ignore_case: + + def normalize_old(s): + return s.lower() + + re_mode = re.IGNORECASE + + else: + + def normalize_old(s): + return s + + re_mode = 0 + + replacements = {normalize_old(key): val for key, val in replacements.items()} + + # Place longer ones first to keep shorter substrings from matching where the longer ones should take place + # For instance given the replacements {'ab': 'AB', 'abc': 'ABC'} against the string 'hey abc', it should produce + # 'hey ABC' and not 'hey ABc' + rep_sorted = sorted(replacements, key=len, reverse=True) + rep_escaped = map(re.escape, rep_sorted) + + # Create a big OR regex that matches any of the substrings to replace + pattern = re.compile("|".join(rep_escaped), re_mode) + + # For each match, look up the new string in the replacements, being the key the normalized old string + return pattern.sub(lambda match: replacements[normalize_old(match.group(0))], string) + + +##### +# The following two functions extend setattr and getattr to support chained objects, e.g. rsetattr(cfg, optim.lr, 1e-4) +# From https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-subobjects-chained-properties +def rsetattr(obj, attr, val): + pre, _, post = attr.rpartition(".") + return setattr(rgetattr(obj, pre) if pre else obj, post, val) + + +def rgetattr(obj, attr, *args): + def _getattr(obj, attr): + return getattr(obj, attr, *args) + + return functools.reduce(_getattr, [obj] + attr.split(".")) + + +def rhasattr(obj, attr, *args): + def _hasattr(obj, attr): + return hasattr(obj, attr, *args) + + return functools.reduce(_hasattr, [obj] + attr.split(".")) + + +def to_tensordict( + x: Dict[str, torch.Tensor], find_batch_size_max: bool = False, force_same_device: bool = False, device=None +) -> TensorDict: + """Converts a dictionary of tensors to a TensorDict.""" + if torch.is_tensor(x): + return x + elif isinstance(x, np.ndarray): + return torch.from_numpy(x) + any_batch_example = x[list(x.keys())[0]] + device = any_batch_example.device if force_same_device else device + shared_batch_size = any_batch_example.shape + if find_batch_size_max: + # Find maximum number of dimensions that are the same for all tensors + for t in x.values(): + if t.shape[: len(shared_batch_size)] != shared_batch_size: + # Find the maximum number of dimensions that are the same for all tensors + for i, (a, b) in enumerate(zip(t.shape, shared_batch_size)): + if a != b: + shared_batch_size = shared_batch_size[:i] + break + return TensorDict(x, batch_size=shared_batch_size, device=device) + + +# Errors +def raise_error_if_invalid_value(value: Any, possible_values: Sequence[Any], name: str = None): + """Raises an error if the given value (optionally named by `name`) is not one of the possible values.""" + if value not in possible_values: + name = name or (value.__name__ if hasattr(value, "__name__") else "value") + raise ValueError(f"{name} must be one of {possible_values}, but was {value} (type={type(value)})") + return value + + +def raise_error_if_has_attr_with_invalid_value(obj: Any, attr: str, possible_values: Sequence[Any]): + if hasattr(obj, attr): + raise_error_if_invalid_value(getattr(obj, attr), possible_values, name=f"{obj.__class__.__name__}.{attr}") + + +def raise_error_if_invalid_type(value: Any, possible_types: Sequence[Any], name: str = None): + """Raises an error if the given value (optionally named by `name`) is not one of the possible types.""" + if all([not isinstance(value, t) for t in possible_types]): + name = name or (value.__name__ if hasattr(value, "__name__") else "value") + raise ValueError(f"{name} must be an instance of either of {possible_types}, but was {type(value)}") + return value + + +def raise_if_invalid_shape( + value: Union[np.ndarray, Tensor], + expected_shape: Sequence[int] | int, + axis: int = None, + name: str = None, +): + if isinstance(expected_shape, int): + if value.shape[axis] != expected_shape: + name = name or (value.__name__ if hasattr(value, "__name__") else "value") + raise ValueError(f"{name} must have shape {expected_shape} along axis {axis}, but shape={value.shape}") + else: + if value.shape != expected_shape: + name = name or (value.__name__ if hasattr(value, "__name__") else "value") + raise ValueError(f"{name} must have shape {expected_shape}, but was {value.shape}") + + +class AlreadyLoggedError(Exception): + pass + + +# allow checkpointing via USR1 +def melk(trainer, ckptdir: str): + def actual_melk(*args, **kwargs): + # run all checkpoint hooks + if trainer.global_rank == 0: + log.info("Summoning checkpoint.") + # log.info("Is file: last.ckpt ?", os.path.isfile(os.path.join(ckptdir, "last.ckpt"))) + ckpt_path = os.path.join(ckptdir, "last.ckpt") + trainer.save_checkpoint(ckpt_path) + + return actual_melk + + +def divein(trainer): + def actual_divein(*args, **kwargs): + if trainer.global_rank == 0: + import pudb + + pudb.set_trace() + + return actual_divein + + +# Random seed (if not using pytorch-lightning) +def set_seed(seed, device="cuda"): + """ + Sets the random seed for the given device. + If using pytorch-lightning, preferably to use pl.seed_everything(seed) instead. + """ + # setting seeds + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if device != "cpu": + torch.cuda.manual_seed(seed) + + +def auto_gpu_selection( + usage_max: float = 0.2, + mem_max: float = 0.6, + num_gpus: int = 1, + raise_error_if_insufficient_gpus: bool = True, + verbose: bool = False, +): + """Auto set CUDA_VISIBLE_DEVICES for gpu (based on utilization) + + Args: + usage_max: max percentage of GPU memory + mem_max: max percentage of GPU utility + num_gpus: number of GPUs to use + raise_error_if_insufficient_gpus: raise error if no (not enough) GPU is available + """ + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + try: + log_output = str(subprocess.check_output("nvidia-smi", shell=True)).split(r"\n")[6:-1] + except subprocess.CalledProcessError as e: + print( + f"Error with code {e.returncode}. There's likely an issue with nvidia-smi." + f" Returning without setting CUDA_VISIBLE_DEVICES" + ) + return + + # Maximum of GPUS, 8 is enough for most + gpu_to_utilization, gpu_to_mem = dict(), dict() + gpus_available = torch.cuda.device_count() + gpu_to_usage = dict() + for gpu in range(gpus_available): + idx = gpu * 4 + 3 + if idx > log_output.__len__() - 1: + break + inf = log_output[idx].split("|") + if inf.__len__() < 3: + break + + try: + usage = int(inf[3].split("%")[0].strip()) + except ValueError: + print("Error with code. Returning without setting CUDA_VISIBLE_DEVICES") + return + mem_now = int(str(inf[2].split("/")[0]).strip()[:-3]) + mem_all = int(str(inf[2].split("/")[1]).strip()[:-3]) + + gpu_to_usage[gpu] = f"Memory:[{mem_now}/{mem_all}MiB] , GPU-Util:[{usage}%]" + if usage < 100 * usage_max and mem_now < mem_max * mem_all: + gpu_to_utilization[gpu] = usage + gpu_to_mem[gpu] = mem_now + # os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu) + if verbose: + log.info(f"GPU {gpu} is vacant: Memory:[{mem_now}/{mem_all}MiB] , GPU-Util:[{usage}%]") + else: + if verbose: + log.info( + f"GPU {gpu} is busy: Memory:[{mem_now}/{mem_all}MiB] , GPU-Util:[{usage}%] (> {usage_max * 100}%)" + ) + + if len(gpu_to_utilization) >= num_gpus: + least_utilized_gpus = sorted(gpu_to_utilization, key=gpu_to_utilization.get)[:num_gpus] + sorted(gpu_to_mem, key=gpu_to_mem.get)[:num_gpus] + if len(gpu_to_utilization) == 1: + os.environ["CUDA_VISIBLE_DEVICES"] = str(least_utilized_gpus[0]) + else: + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(gpu) for gpu in least_utilized_gpus]) + log.info(f"Set os.environ['CUDA_VISIBLE_DEVICES'] = {os.environ['CUDA_VISIBLE_DEVICES']}") + for gpu in least_utilized_gpus: + log.info(f"Use GPU {gpu} with utilization {gpu_to_usage[gpu]}") + if num_gpus > 1: + log.info(f"Use GPUs {least_utilized_gpus} based on least utilization") + else: + if raise_error_if_insufficient_gpus: + raise ValueError("No vacant GPU") + log.info("\nNo vacant GPU, use CPU instead\n") + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + + +def print_gpu_memory_usage( + prefix: str = "", + tqdm_bar=None, + add_description: bool = True, + keep_old: bool = False, + empty_cache: bool = False, + log_func: Optional[Callable] = None, +): + """Use this function to print the GPU memory usage (logged or in a tqdm bar). + Use this to narrow down memory leaks, by printing the GPU memory usage before and after a function call + and checking if the available memory is the same or not. + Recommended to use with 'empty_cache=True' to get the most accurate results during debugging. + """ + if torch.cuda.is_available(): + if empty_cache: + torch.cuda.empty_cache() + used, allocated = torch.cuda.mem_get_info() + prefix = f"{prefix} GPU mem free/allocated" if add_description else prefix + info_str = f"{prefix} {used / 1e9:.2f}/{allocated / 1e9:.2f}GB" + if tqdm_bar is not None: + if keep_old: + tqdm_bar.set_postfix_str(f"{tqdm_bar.postfix} | {info_str}") + else: + tqdm_bar.set_postfix_str(info_str) + elif log_func is not None: + log_func(info_str) + else: + log.info(info_str) + + +def get_pl_trainer_kwargs_for_evaluation( + trainer_config: DictConfig = None, +) -> (Dict[str, Any], torch.device): + """Get kwargs for pytorch-lightning Trainer for evaluation and select <=1 GPU if available""" + # GPU or not: + if torch.cuda.is_available() and (trainer_config is None or trainer_config.accelerator != "cpu"): + accelerator, devices, reload_to_device = "gpu", 1, torch.device("cuda:0") + auto_gpu_selection(usage_max=0.6, mem_max=0.75, num_gpus=devices) + else: + accelerator, devices, reload_to_device = "cpu", "auto", torch.device("cpu") + return dict(accelerator=accelerator, devices=devices, strategy="auto"), reload_to_device + + +def infer_main_batch_key_from_dataset(dataset: torch.utils.data.Dataset) -> str: + ds = dataset + main_data_key = None + if hasattr(ds, "main_data_key"): + main_data_key = ds.main_data_key + else: + data_example = ds[0] + if isinstance(data_example, dict): + if "dynamics" in data_example: + main_data_key = "dynamics" + elif "data" in data_example: + main_data_key = "data" + else: + raise ValueError(f"Could not determine main_data_key from data_example: {data_example.keys()}") + return main_data_key + + +def rename_state_dict_keys(state_dict: Dict[str, torch.Tensor]) -> (Dict[str, torch.Tensor], bool): + # Missing key(s) in state_dict: "model.downs.0.2.fn.fn.to_qkv.1.weight", "model.downs.1.2.fn.fn.to_qkv.1.weight", + # Unexpected key(s) in state_dict: "model.downs.0.2.fn.fn.to_qkv.weight", "model.downs.1.2.fn.fn.to_qkv.weight", + # rename weights + renamed = False + for k in list(state_dict.keys()): + if "fn.to_qkv.weight" in k and "mid_attn" not in k: + state_dict[k.replace("fn.to_qkv.weight", "fn.to_qkv.1.weight")] = state_dict.pop(k) + renamed = True + + return state_dict, renamed + + +def rename_state_dict_keys_and_save(torch_model_state, ckpt_path: str) -> Dict[str, torch.Tensor]: + """Renames the state dict keys and saves the renamed state dict back to the checkpoint.""" + state_dict, has_been_renamed = rename_state_dict_keys(torch_model_state["state_dict"]) + if has_been_renamed: + # Save the renamed model state + torch_model_state["state_dict"] = state_dict + torch.save(torch_model_state, ckpt_path) + return state_dict + + +def freeze_model(model: nn.Module): + for param in model.parameters(): + param.requires_grad = False + model.eval() # set to eval mode + return model + + +all_dropout_layers = [nn.Dropout, nn.Dropout2d, nn.Dropout3d, nn.AlphaDropout, nn.FeatureAlphaDropout, DropPath] + + +def enable_inference_dropout(model: nn.Module): + """Set all dropout layers to training mode""" + # find all dropout layers + dropout_layers = [m for m in model.modules() if any([isinstance(m, layer) for layer in all_dropout_layers])] + for layer in dropout_layers: + layer.train() + # assert all([layer.training for layer in [m for m in model.modules() if isinstance(m, nn.Dropout)]]) + + +def disable_inference_dropout(model: nn.Module): + """Set all dropout layers to eval mode""" + # find all dropout layers + dropout_layers = [m for m in model.modules() if any([isinstance(m, layer) for layer in all_dropout_layers])] + for layer in dropout_layers: + layer.eval() + + +def find_differences_between_dicts(d1: Dict[str, Any], d2: Dict[str, Any]) -> List[str]: + """Finds any (nested) differences between the two dictionaries.""" + diff = [] + for k, v in d1.items(): + d2_v = d2.get(k) + if not isinstance(v, dict) and d2_v != v: + diff.append(f"{k}: {v} != {d2_v}") + elif isinstance(v, dict): + diff += find_differences_between_dicts(v, d2_v) + return diff + + +def update_dict_with_other(d1: Dict[str, Any], other: Dict[str, Any]): # _and_return_difference + """Updates d1 with other, other can be a dict of dicts with partial updates. + + Returns: + d1: the updated dict + diff: the difference between the original d1 and the updated d1 as a string + + Example: + d1 = {'a': {'b': 1, 'c': 2}, 'x': 99} + other = {'a': {'b': 3}, 'y': 100} + d1, diff = update_dict_with_other(d1, other) + log.info(d1) + # {'a': {'b': 3, 'c': 2}, 'x': 99, 'y': 100} + log.info(diff) + # ['a.b: 1 -> 3', 'y: None -> 100'] + """ + diff = [] + for k, v in other.items(): + if isinstance(v, dict) and d1.get(k) is not None: + d1[k], diff_sub = update_dict_with_other(d1.get(k, {}), v) + diff += [f"{k}.{x}" for x in diff_sub] + else: + if d1.get(k) != v: + diff.append(f"{k}: {d1.get(k, None)} -> {v}") + d1[k] = v + return d1, diff + + +def flatten_dict(dictionary: Dict[Any, Any], save: bool = True) -> Dict[Any, Any]: + """Flattens a nested dict.""" + # The dictionary may consist of dicts or normal values + # If it's a dict, recursively flatten it + # If it's a normal value, return it + flattened = {} + for k, v in dictionary.items(): + if isinstance(v, (dict, TensorDictBase)): + # check that no duplicate keys exist + flattened_v = flatten_dict(v) + if save and len(set(flattened_v.keys()).intersection(set(flattened.keys()))) > 0: + raise ValueError(f"Duplicate keys in flattened dict: {set(flattened_v.keys())}") + flattened.update(flattened_v) + else: + if save and k in flattened: + raise ValueError(f"Duplicate keys in flattened dict: {k}") + flattened[k] = v + return flattened + + +def find_config_differences( + configs: List[Dict[str, Any]], + keys_to_tolerated_percent_diff: Dict[str, float] = None, + sort_by_name: bool = True, +) -> List[List[str]]: + """ + Find and return the differences between multiple nested configurations. + + This function compares each configuration with all others and identifies + keys that have different values across configurations. It returns a list + of differences for each input configuration. + + Args: + configs (List[Dict[str, Any]]): A list of nested configuration dictionaries. + keys_to_tolerated_percent_diff (Dict[str, float], optional): A dictionary mapping keys to maximum tolerated differences. Any keys not included in this dictionary will be compared for exact equality. Defaults to None. + sort_by_name (bool, optional): Whether to sort the output by key names. + Defaults to True. + + Returns: + List[List[str]]: A list containing lists of strings, where each inner list + represents the differences for one configuration in the + format ["key=value", ...]. + + Example: + configs = [ + {"a": 1, "b": {"c": 2, "d": 3}}, + {"a": 1, "b": {"c": 2, "d": 4}}, + {"a": 2, "b": {"c": 2, "d": 3}, "e": 5} + ] + result = find_config_differences(configs) + # Result will be: + # [['a=1', 'b.d=3'], ['b.d=4'], ['a=2', 'e=5']] + """ + + def flatten_dict(d: Dict[str, Any], prefix: str = "") -> Dict[str, Any]: + """ + Recursively flatten a nested dictionary, using dot notation for nested keys. + + Args: + d (Dict[str, Any]): The dictionary to flatten. + prefix (str, optional): The prefix to use for the current level of nesting. + + Returns: + Dict[str, Any]: A flattened version of the input dictionary. + """ + items = [] + for k, v in d.items(): + new_key = f"{prefix}.{k}" if prefix else k + if isinstance(v, dict): + # If the value is a dictionary, recurse with the new key as prefix + items.extend(flatten_dict(v, new_key).items()) + else: + # If the value is not a dictionary, add it to the items list + items.append((new_key, v)) + return dict(items) + + # Flatten all input configurations + flat_configs = [flatten_dict(config) for config in configs] + + # Get all unique keys from all configurations + all_keys = set().union(*flat_configs) + + # Sort keys if sort_by_name is True + if sort_by_name: + all_keys = sorted(all_keys) + + keys_to_tolerated_percent_diff = keys_to_tolerated_percent_diff or dict() + assert all( + [0 <= value <= 1 for value in keys_to_tolerated_percent_diff.values()] + ), "Values in keys_to_tolerated_percent_diff must be between 0 and 1" + differences = [] + for i, config in enumerate(flat_configs): + diff = [] + for key in all_keys: + # Check if the key exists in the current config and has a different value in any other config + value = config.get(key) + if value is not None: + # Check if the key is in the keys_to_percent_diff dictionary + if key in keys_to_tolerated_percent_diff.keys(): + if any( + abs(value - other.get(key, value)) > keys_to_tolerated_percent_diff[key] * abs(value) + for other in flat_configs + ): + diff.append(f"{key}={config[key]}") + # print(f"key={key}, value={value}. diff={[abs(value - other.get(key, value)) for other in flat_configs]}. tol={keys_to_tolerated_percent_diff[key] * abs(value)}") + elif any(config[key] != other.get(key) for other in flat_configs): + diff.append(f"{key}={config[key]}") + differences.append(diff) + + return differences + + +def find_config_differences_return_as_joined_str( + configs: List[Dict[str, Any]], join_with: str = " ", **kwargs +) -> List[str]: + differences_list = find_config_differences(configs, **kwargs) + return [join_with.join(differences) for differences in differences_list] + + +def concatenate_array_dicts( + arrays_dicts: List[Dict[str, np.ndarray]], + axis: int = 0, + keys: List[str] = None, +) -> Dict[str, np.ndarray]: + """Concatenates the given dicts of arrays along the given axis. The dict may be nested. + Args: + arrays_dicts: A list of a + """ + if len(arrays_dicts) == 0: + return dict() + + if keys is None: + # Check that all dicts have the same keys + keys = arrays_dicts[0].keys() + for d in arrays_dicts: + if d.keys() != keys: + raise ValueError(f"Keys of dicts do not match: {d.keys()} != {keys}") + else: + keys = [keys] if isinstance(keys, str) else keys + + # Concatenate the arrays + concatenated = {} + for k in keys: + if isinstance(arrays_dicts[0][k], dict): + concatenated[k] = concatenate_array_dicts([d[k] for d in arrays_dicts], axis=axis) + else: + concatenated[k] = np.concatenate([d[k] for d in arrays_dicts], axis=axis) + return concatenated + + +def get_first_array_in_nested_dict(nested_dict: Dict[str, Any]): + """Returns the first array that is found when descending the hierarchy in the given nested dict.""" + for v in nested_dict.values(): + if isinstance(v, dict): + return get_first_array_in_nested_dict(v) + elif isinstance(v, (np.ndarray, Tensor)): + return v + raise ValueError(f"Could not find any array in the given nested dict: {nested_dict}") + + +def split3d_and_merge_variables(results_dict, level_names) -> Dict[str, Any]: + """""" + if level_names is None: + return results_dict + keys3d = [k for k in results_dict.keys() if "3d" in k] + # results_dicts = dict() + assert len(keys3d) == 1, f"Expected only one 3d key, but got {keys3d}" + for k in keys3d: + data3d = results_dict.pop(k) + results_dict[k] = dict() + for variable in list(data3d.keys()): + var_data = data3d.pop(variable) + n_levels = var_data.shape[-3] + for i in range(n_levels): + new_k = f"{variable}_{level_names[i]}" + results_dict[k][new_k] = torch_select(var_data, dim=-3, index=i) + results_dict = flatten_dict(results_dict) + # results_dicts[k_base] = flatten_dict(results_dict)\ + return results_dict + + +def split3d_and_merge_variables_maybe( + results_dict, result_key: str, multiple_result_keys: List[str], level_names: List[str] +) -> Dict[str, Any]: + if result_key in results_dict.keys(): + return results_dict[result_key] + elif all([k in results_dict.keys() for k in multiple_result_keys]): + return split3d_and_merge_variables({k: results_dict[k] for k in multiple_result_keys}, level_names) + else: + raise ValueError( + f"Could not find any of the given keys in the results_dict: {result_key}, {multiple_result_keys}" + ) + + +def get_common_substrings(strings, min_length): + common_substrings = set() + + # Compare each pair of strings to find common substrings + for i in range(len(strings)): + for j in range(i + 1, len(strings)): + s1, s2 = strings[i], strings[j] + seq_matcher = SequenceMatcher(None, s1, s2) + + for match in seq_matcher.get_matching_blocks(): + if match.size >= min_length: + common_substrings.add(s1[match.a : match.a + match.size]) + + return common_substrings + + +def remove_substrings(strings, substrings): + result = [] + strings_list = [strings] if isinstance(strings, str) else strings + for string in strings_list: + for substring in substrings: + string = string.replace(substring, "") + result.append(string) + if isinstance(strings, str): + return result[0] + return result + + +def remove_common_substrings(strings, min_length: int = 5): + common_substrings = get_common_substrings(strings, min_length) + return remove_substrings(strings, common_substrings) diff --git a/src/utilities/wandb_api.py b/src/utilities/wandb_api.py new file mode 100644 index 0000000..61617de --- /dev/null +++ b/src/utilities/wandb_api.py @@ -0,0 +1,1296 @@ +from __future__ import annotations + +import glob +import logging +import os +from collections import defaultdict +from datetime import datetime, timezone +from typing import Any, Callable, Dict, List, Optional, Sequence, Union + +import numpy as np +import pandas as pd +import requests +import wandb +from omegaconf import DictConfig, OmegaConf +from tqdm.auto import tqdm + +from src.utilities.checkpointing import ( + get_local_ckpt_path, + local_path_to_absolute_and_download_if_needed, +) +from src.utilities.utils import find_config_differences_return_as_joined_str, get_logger + + +# Override this in your project +# ----------------------------------------------------------------------- +_ENTITY = None # Set to your wandb (team or personal) entity +PROJECT = "Spherical-DYffusion" +_PROJECT_TRAIN = None # Set this when using a different project for logging than for reloading checkpoints +_TRAINING_RUN_PATH = None # Set this when using a different run for training than for reloading checkpoints +# ----------------------------------------------------------------------- + +log = get_logger(__name__) + +CACHE = dict() + + +def get_entity(entity: str = None) -> str: + if entity is None: + return _ENTITY or wandb.api.default_entity + return entity + + +def get_project_train(project: str = None) -> str: + if project is None: + return _PROJECT_TRAIN or PROJECT + return project + + +def get_training_run_path(): + return _TRAINING_RUN_PATH + + +def get_api(wandb_api: wandb.Api = None, timeout=100) -> wandb.Api: + if wandb_api is None: + try: + wandb_api = wandb.Api(timeout=timeout) + except wandb.errors.UsageError: + wandb.login() + wandb_api = wandb.Api(timeout=timeout) + return wandb_api + + +def get_api_and_set_entity(entity: str = None, wandb_api: wandb.Api = None) -> wandb.Api: + entity = get_entity(entity) + api = get_api(wandb_api) + api._default_entity = entity + return api + + +def get_run_api( + run_id: str = None, + entity: str = None, + project: str = None, + run_path: str = None, + wandb_api: wandb.Api = None, +) -> wandb.apis.public.Run: + entity, project = get_entity(entity), project or PROJECT + assert run_path is None or run_id is None, "Either run_path or run_id must be None" + assert run_id is None or isinstance(run_id, str), f"run_id must be a string, but is {type(run_id)}: {run_id}" + run_path = run_path or f"{entity}/{project}/{run_id}" + api = get_api_and_set_entity(entity, wandb_api) + return api.run(run_path) + + +def get_project_runs( + entity: str = None, project: str = None, wandb_api: wandb.Api = None, **kwargs +) -> List[wandb.apis.public.Run]: + """Filter with kwarg: filters: Optional[Dict[str, Any]] = None,""" + entity, project = get_entity(entity), project or PROJECT + return get_api_and_set_entity(entity, wandb_api).runs(f"{entity}/{project}", **kwargs) + + +def get_project_groups( + entity: str = None, project: str = None, wandb_api: wandb.Api = None +) -> List[wandb.apis.public.Run]: + runs = get_project_runs(entity, project, wandb_api) + return list(set([run.group for run in runs])) + + +def get_runs_for_group( + group: str, + entity: str = None, + project: str = None, + wandb_api: wandb.Api = None, + filter_dict: Dict[str, Any] = None, + filter_functions: Sequence[Callable] = None, + only_ids: bool = False, + verbose: bool = True, + **kwargs, +) -> Union[List[wandb.apis.public.Run], List[str]]: + """Get all runs for a given group""" + extra_filters = {"group": group} if group is not None else None + filter_wandb_api = get_filter_for_wandb(filter_dict, extra_filters=extra_filters) + group_runs = get_project_runs(entity, project, wandb_api, filters=filter_wandb_api, **kwargs) # {"group": group}) + if len(group_runs) == 0: + if verbose: + pass + # print(f"----> No runs for {group=}! Did you mistype the group name? Entity/project: {entity}/{project}") + elif filter_functions is not None: + n_groups_before = len(group_runs) + filter_functions = [filter_functions] if callable(filter_functions) else list(filter_functions) + group_runs = [run for run in group_runs if all([f(run) for f in filter_functions])] + if len(group_runs) == 0 and len(filter_functions) > 0 and verbose: + print(f"Filter functions filtered out all {n_groups_before} runs for group {group}") + # elif n_groups_before == 0: + # print(f"----> No runs for group {group}!! Did you mistype the group name?") + + if only_ids: + group_runs = [run.id for run in group_runs] + return group_runs + + +def get_runs_for_project(**kwargs): + return get_runs_for_group(group=None, **kwargs) + + +def get_run_apis( + run_id: str = None, + group: str = None, + **kwargs, +) -> List[wandb.apis.public.Run]: + assert run_id is None or group is None, "Either run_id or group must be None" + assert run_id is not None or group is not None, "Either run_id or group must be not None" + assert run_id is None or isinstance(run_id, str), f"run_id must be a string, but is {type(run_id)}: {run_id}" + assert group is None or isinstance(group, str), f"group must be a string, but is {type(group)}: {group}" + if run_id is not None: + return [get_run_api(run_id=run_id, **kwargs)] + else: + return get_runs_for_group(group=group, **kwargs) + + +def get_wandb_id_for_run() -> str: + """Get a unique id for the current run. If on a Slurm cluster, use the job ID, otherwise generate a random id.""" + if "SLURM_JOB_ID" in os.environ: + # we are on a Slurm cluster... using the job ID helps when requeuing jobs to resume the same run + return str(os.environ["SLURM_JOB_ID"]) + else: + # we are not on a Slurm cluster, so just generate a random id + return wandb.sdk.lib.runid.generate_id() + + +def get_runs_for_group_with_any_metric( + wandb_group: str, + options: List[str] | str, + option_to_key: Callable[[str], str] | None = None, + wandb_api=None, + metric: str = "crps", + **wandb_kwargs, +) -> (Optional[List[wandb.apis.public.Run]], str): + """Get all runs for a given group that have any of the given metrics.""" + options = [options] if isinstance(options, str) else options + option_to_key = option_to_key or (lambda x: x) + wandb_kwargs2 = wandb_kwargs.copy() + group_runs, any_metric_key = None, None + tried_options = [] + for s_i, sum_metric in enumerate(options): + any_metric_key = f"{option_to_key(sum_metric)}/{metric}".replace("//", "/") + tried_options.append(any_metric_key) + filter_func = has_summary_metric(any_metric_key) + if "filter_functions" not in wandb_kwargs: + wandb_kwargs2["filter_functions"] = filter_func + elif "filter_functions" in wandb_kwargs and len(options) > 1: + wandb_kwargs2["filter_functions"] = wandb_kwargs["filter_functions"] + [filter_func] + else: + wandb_kwargs2["filter_functions"] = wandb_kwargs["filter_functions"] + group_runs = get_runs_for_group(wandb_group, wandb_api=wandb_api, verbose=False, **wandb_kwargs2) + if len(group_runs) > 0: + break + if len(group_runs) == 0: + logging.warning( + f"No runs found for group {wandb_group}. " + f"Possible splits: {options}.\nFull keys that were tried: {tried_options}" + ) + return None, None + return group_runs, any_metric_key.replace(f"/{metric}", "") + + +def get_wandb_ckpt_name(run_path: str, epoch: Union[str, int] = "best") -> str: + """ + Get the wandb ckpt name for a given run_path and epoch. + Args: + run_path: ENTITY/PROJECT/RUN_ID + epoch: If an int, the ckpt name will be the one for that epoch. + If 'last' ('best') the latest ('best') epoch ckpt will be returned. + + Returns: + The wandb ckpt file-name, that can be used as follows to restore the checkpoint locally: + >>> run_path = "" + >>> ckpt_name = get_wandb_ckpt_name(run_path, epoch) + >>> wandb.restore(ckpt_name, run_path=run_path, replace=True, root=os.getcwd()) + """ + assert epoch in ["best", "last"] or isinstance( + epoch, int + ), f"epoch must be 'best', 'last' or an int, but is {epoch}" + run_api = get_run_api(run_path=run_path) + ckpt_files = [f.name for f in run_api.files() if f.name.endswith(".ckpt")] + if epoch == "best": + if "best.ckpt" in ckpt_files: + ckpt_filename = "best.ckpt" + else: + raise ValueError(f"Could not find best.ckpt in {ckpt_files}") + elif "last.ckpt" in ckpt_files and epoch == "last": + ckpt_filename = "last.ckpt" + else: + if len(ckpt_files) == 0: + raise ValueError(f"Wandb run {run_path} has no checkpoint files (.ckpt) saved in the cloud!") + elif len(ckpt_files) >= 2: + ckpt_epochs = [int(name.replace("epoch", "")[:3]) for name in ckpt_files] + if epoch == "last": + # Use checkpoint of latest epoch if epoch is not specified + max_epoch = max(ckpt_epochs) + ckpt_filename = [name for name in ckpt_files if str(max_epoch) in name][0] + log.info(f"Multiple ckpt files exist: {ckpt_files}. Using latest epoch: {ckpt_filename}") + else: + # Use checkpoint with specified epoch + ckpt_filename = [name for name in ckpt_files if str(epoch) in name] + if len(ckpt_filename) == 0: + raise ValueError(f"There is no ckpt file for epoch={epoch}. Try one of the ones in {ckpt_epochs}!") + ckpt_filename = ckpt_filename[0] + else: + ckpt_filename = ckpt_files[0] + log.warning(f"Only one ckpt file exists: {ckpt_filename}. Using it...") + return ckpt_filename + + +def restore_model_from_wandb_cloud( + run_path: str, + local_checkpoint_path: str = None, + ckpt_filename: str = "best", # was None + throw_error_if_local_not_found: bool = False, + config: DictConfig = None, + **kwargs, +) -> str: + """ + Restore the model from the wandb cloud to local file-system. + Args: + run_path: PROJECT/ENTITY/RUN_ID + local_checkpoint_path: If not None, the model will be restored from this path. + ckpt_filename: If not None, the model will be restored from this filename (in the cloud). + + Returns: + The ckpt filename that can be used to reload the model locally. + """ + if local_checkpoint_path is True: + assert ckpt_filename is not None, "If local_checkpoint_path is True, please specify ckpt_filename" + local_checkpoint_path = get_local_ckpt_path( + config, + wandb_run=get_run_api(run_path=run_path), + ckpt_filename=ckpt_filename, + throw_error_if_local_not_found=throw_error_if_local_not_found, + ) + + entity, project, wandb_id = run_path.split("/") + if isinstance(local_checkpoint_path, (str,)): + ckpt_path = local_path_to_absolute_and_download_if_needed(local_checkpoint_path) + log.info(f"Restoring model from local absolute path: {ckpt_path}") + else: + if ckpt_filename is None: + ckpt_filename = get_wandb_ckpt_name(run_path, **kwargs) + ckpt_filename = ckpt_filename.split("/")[-1] # in case the file contains local dir structure + + expected_ckpt_path = os.path.join(os.getcwd(), ckpt_filename) + if wandb_id in expected_ckpt_path: + ckpt_path = expected_ckpt_path + else: + # rename best_model_fname to add a unique prefix to avoid conflicts with other runs + # (e.g. if the same model is reloaded twice). Replace only filename part of the path, not the dir structure + ckpt_path = os.path.join(os.getcwd(), f"{wandb_id}-{ckpt_filename}") + + ckpt_path_tmp = ckpt_path + if not os.path.exists(ckpt_path): + try: + from src.utilities.s3utils import download_s3_object + + s3_file_path = f"{project}/checkpoints/{wandb_id}/{ckpt_filename}" + download_s3_object(s3_file_path, ckpt_path, throw_error=True) + except Exception: + # IMPORTANT ARGS replace=True: see https://github.com/wandb/client/issues/3247 + ckpt_path_tmp = wandb.restore(ckpt_filename, run_path=run_path, replace=True, root=os.getcwd()).name + assert os.path.abspath(ckpt_path_tmp) == expected_ckpt_path + + # if os.path.exists(ckpt_path): + # # if DDP and multiple processes are restoring the same model, this may happen. check if is_rank_zero + # if ckpt_path != ckpt_path_tmp: + # os.remove(ckpt_path) # remove if one exists from before + if os.path.exists(ckpt_path_tmp): + os.rename(ckpt_path_tmp, ckpt_path) + return ckpt_path + + +def load_hydra_config_from_wandb( + run_path: str | wandb.apis.public.Run, + override_config: Optional[DictConfig] = None, + override_key_value: List[str] = None, + update_config_in_cloud: bool = False, +) -> DictConfig: + """ + Args: + run_path (str): the wandb ENTITY/PROJECT/ID (e.g. ID=2r0l33yc) corresponding to the config to-be-reloaded + override_config (DictConfig): each of its keys will override the corresponding entry loaded from wandb + override_key_value: each element is expected to have a "=" in it, like datamodule.num_workers=8 + update_config_in_cloud: if True, the config in the cloud will be updated with the new overrides + """ + if override_config is not None and override_key_value is not None: + log.warning("Both override_config and override_key_value are not None! ") + if isinstance(run_path, wandb.apis.public.Run): + run = run_path + run_path = "/".join(run.path) + else: + assert isinstance( + run_path, str + ), f"run_path must be a string or wandb.apis.public.Run, but is {type(run_path)}" + run = get_run_api(run_path=run_path) + + override_key_value = override_key_value or [] + if not isinstance(override_key_value, list): + raise ValueError(f"override_key_value must be a list of strings, but has type {type(override_key_value)}") + # copy overrides to new list + overrides = list(override_key_value.copy()) + rank = os.environ.get("RANK", None) or os.environ.get("LOCAL_RANK", 0) + + # Check if hydra_config file exists locally + work_dir = run.config.get("dirs/work_dir", run.config.get("dirs.work_dir", None)) + if work_dir is not None and run.id not in work_dir: + work_dir = os.path.join(os.path.dirname(work_dir), run.id) + is_local = False + if work_dir is not None and os.path.exists(work_dir): + wandb_dir = os.path.join(work_dir, "wandb") + if os.path.exists(os.path.join(wandb_dir, "latest-run")): + wandb_dir = os.path.join(wandb_dir, "latest-run") + # Find hydra_config file in wandb directory or subdirectories using glob + hydra_config_files = glob.glob(f"{wandb_dir}/**/hydra_config*.yaml", recursive=True) + log.info(f"[rank: {rank}] Found {len(hydra_config_files)} files in {wandb_dir}: {hydra_config_files}") + # assert len(hydra_config_files) > 0, f"Could not find any hydra_config file in {wandb_dir}" + is_local = True + # torch.distributed.barrier() + if not is_local or len(hydra_config_files) == 0: + # Find latest hydra_config-v{VERSION}.yaml file in wandb cloud (Skip versions labeled as 'old') + hydra_config_files = [f.name for f in run.files() if "hydra_config" in f.name and "old" not in f.name] + is_local = False + + if len(hydra_config_files) == 0: + raise ValueError(f"Could not find any hydra_config file in wandb run {run_path}") + elif len(hydra_config_files) == 1: + if not is_local: + assert hydra_config_files[0].endswith( + "hydra_config.yaml" + ), f"Only one hydra_config file found: {hydra_config_files}" + else: + hydra_config_files = [f for f in hydra_config_files if "hydra_config-v" in f] + assert len(hydra_config_files) > 0, f"Could not find any hydra_config-v file in wandb run {run_path}" + # Sort by version number (largest is last, earliest are hydra_config.yaml and hydra_config-v1.yaml), + hydra_config_files = sorted(hydra_config_files, key=lambda x: int(x.split("-v")[-1].split(".")[0])) + + hydra_config_file = hydra_config_files[-1] + if not hydra_config_file.endswith("hydra_config.yaml"): + log.info(f" Reloading from hydra config file: {hydra_config_file}") + + # Download from wandb cloud + if is_local or (os.path.exists(hydra_config_file) and rank not in ["0", 0]): + pass + else: + wandb_restore_kwargs = dict(run_path=run_path, replace=True, root=os.getcwd()) + wandb.restore(hydra_config_file, **wandb_restore_kwargs) + + # remove overrides of the form k=v, where k has no dot in it. We don't support this. + overrides = [o for o in overrides if "=" in o and "." in o.split("=")[0]] + if len(overrides) != len(override_key_value): + diff = set(overrides) - set(override_key_value) + log.warning(f"The following overrides were removed because they are not in the form k=v: {diff}") + + overrides += [ + f"logger.wandb.id={run.id}", + f"logger.wandb.entity={run.entity}", + f"logger.wandb.project={run.project}", + f"logger.wandb.tags={run.tags}", + f"logger.wandb.group={run.group}", + ] + config = OmegaConf.load(hydra_config_file) + overrides = OmegaConf.from_dotlist(overrides) + config = OmegaConf.unsafe_merge(config, overrides) + + if override_config is not None: + for k, v in override_config.items(): + if k in ["model", "trainer"] and isinstance(v, str): + override_config.pop(k) # remove key from override_config + log.warning(f"Key {k} is a string, but it should be a DictConfig. Ignoring it.") + # override config with override_config (which needs to be the second argument of OmegaConf.merge) + config = OmegaConf.unsafe_merge(config, override_config) # unsafe_merge since override_config is not needed + + if not is_local: + os.remove(hydra_config_file) if os.path.exists(hydra_config_file) else None + os.remove(f"../../{hydra_config_file}") if os.path.exists(f"../../{hydra_config_file}") else None + + if run.id != config.logger.wandb.id and run.id in config.logger.wandb.name: + config.logger.wandb.id = run.id + assert str(config.logger.wandb.id) == str( + run.id + ), f"{config.logger.wandb.id=} != {run.id=}. \nFull Hydra config: {config}" + if update_config_in_cloud: + with open("hydra_config.yaml", "w") as fp: + OmegaConf.save(config, f=fp.name, resolve=True) + run.upload_file("hydra_config.yaml", root=".") + os.remove("hydra_config.yaml") + return config + + +def does_any_ckpt_file_exist(wandb_run: wandb.apis.public.Run, only_best_and_last: bool = True, local_dir: str = None): + """ + Check if any checkpoint file exists in the wandb run. + Args: + wandb_run: the wandb run to check + only_best_and_last: if True, only checks for 'best.ckpt' and 'last.ckpt' files, otherwise checks for all ckpt files + Setting to true may speed up the check, since it will stop as soon as it finds one of the two files. + """ + if local_dir is not None: + if wandb_run.id not in local_dir: + local_dir = os.path.join(local_dir, wandb_run.id) + if os.path.exists(local_dir): + ckpt_files = [f for f in os.listdir(local_dir) if f.endswith(".ckpt")] + if len(ckpt_files) > 0: + return True + + names = ["last.ckpt", "best.ckpt"] if only_best_and_last else None + if "checkpoint/in_s3" in wandb_run.summary.keys(): + return True # Using S3 storage + + return len([1 for f in wandb_run.files(names=names) if f.name.endswith(".ckpt")]) > 0 + + +def get_existing_wandb_group_runs( + config: DictConfig, ckpt_must_exist: bool = False, **kwargs +) -> List[wandb.apis.public.Run]: + if config.get("logger", None) is None or config.logger.get("wandb", None) is None: + return [] + wandb_cfg = config.logger.wandb + runs_in_group = get_runs_for_group(wandb_cfg.group, entity=wandb_cfg.entity, project=wandb_cfg.project) + try: + _ = len(runs_in_group) + except ValueError: # happens if project does not exist + return [] + if ckpt_must_exist: + local_dir = config.ckpt_dir + runs_in_group = [run for run in runs_in_group if does_any_ckpt_file_exist(run, **kwargs, local_dir=local_dir)] + return runs_in_group + # other_seeds = [run.config.get('seed') for run in other_runs] + # if config.seed in other_seeds: + # state = runs_in_group[other_seeds.index(config.seed)].state + # log.info(f"Found a run (state={state}) with the same seed (={this_seed}) in group {group}.") + # return True + # return False + + +def reupload_run_history(run): + """ + This function can be called when for weird reasons your logged metrics do not appear in run.summary. + All metrics for each epoch (assumes that a key epoch=i for each epoch i was logged jointly with the metrics), + will be reuploaded to the wandb run summary. + """ + summary = {} + for row in run.scan_history(): + if "epoch" not in row.keys() or any(["gradients/" in k for k in row.keys()]): + continue + summary.update(row) + run.summary.update(summary) + + +##################################################################### +# +# Pre-filtering of wandb runs +# +def has_finished(run): + return run.state == "finished" + + +def not_running(run): + return run.state != "running" + + +def has_final_metric(run) -> bool: + return "test/mse" in run.summary.keys() and "test/mse" in run.summary.keys() + + +def has_run_id(run_ids: str | List[str]) -> Callable: + if isinstance(run_ids, str): + run_ids = [run_ids] + return lambda run: any([run.id == rid for rid in run_ids]) + + +def contains_in_run_name(name: str) -> Callable: + return lambda run: name in run.name + + +def has_summary_metric(metric_name: str, check_non_nan: bool = False) -> Callable: + metric_name = metric_name.replace("//", "/") + + def has_metric(run): + return metric_name in run.summary.keys() # or metric_name in run.summary_metrics.keys() + + def has_metric_non_nan(run): + value = run.summary.get(metric_name) + try: + return value is not None and value not in {"NaN", "Infinity"} and not np.isnan(value) + except Exception as e: + raise ValueError( + f"Error when checking metric {metric_name} in run {run.id}. Summary value: {value}, type: {type(value)}" + ) from e + + return has_metric_non_nan if check_non_nan else has_metric + + +def has_summary_metric_any(metric_names: List[str], check_non_nan: bool = False) -> Callable: + metric_names = [m.replace("//", "/") for m in metric_names] + + def has_metric(run): + return any([m in run.summary.keys() for m in metric_names]) + + def has_metric_non_nan(run): + return any([m in run.summary.keys() and not np.isnan(run.summary[m]) for m in metric_names]) + + return has_metric_non_nan if check_non_nan else has_metric + + +def has_summary_metric_lower_than(metric_name: str, lower_than: float) -> Callable: + metric_name = metric_name.replace("//", "/") + return lambda run: metric_name in run.summary.keys() and run.summary[metric_name] < lower_than + + +def has_summary_metric_greater_than(metric_name: str, greater_than: float) -> Callable: + metric_name = metric_name.replace("//", "/") + return lambda run: metric_name in run.summary.keys() and run.summary[metric_name] > greater_than + + +def has_minimum_runtime(min_minutes: float = 10.0) -> Callable: + return lambda run: run.summary.get("_runtime", 0) > min_minutes * 60 + + +def has_minimum_epoch(min_epoch: int = 10) -> Callable: + def has_min_epoch(run): + hist = run.history(keys=["epoch"]) + return len(hist) > 0 and max(hist["epoch"]) > min_epoch + + return has_min_epoch + + +def has_minimum_epoch_simple(min_epoch: int = 10) -> Callable: + def has_min_epoch(run): + return run.summary.get("epoch", 0) > min_epoch + + return has_min_epoch + + +def has_maximum_epoch_simple(max_epoch: int = 10) -> Callable: + def has_min_epoch(run): + return run.summary.get("epoch", np.inf) < max_epoch + + return has_min_epoch + + +def has_keys(keys: Union[str, List[str]]) -> Callable: + keys = [keys] if isinstance(keys, str) else keys + return lambda run: all([(k in run.summary.keys() or k in run.config.keys()) for k in keys]) + + +def hasnt_keys(keys: Union[str, List[str]]) -> Callable: + keys = [keys] if isinstance(keys, str) else keys + return lambda run: all([(k not in run.summary.keys() and k not in run.config.keys()) for k in keys]) + + +def has_max_metric_value(metric: str = "test/MERRA2/mse_epoch", max_metric_value: float = 1.0) -> Callable: + return lambda run: run.summary[metric] <= max_metric_value + + +def has_tags(tags: Union[str, List[str]]) -> Callable: + if isinstance(tags, str): + tags = [tags] + return lambda run: all([tag in run.tags for tag in tags]) + + +def hasnt_tags(tags: Union[str, List[str]]) -> Callable: + if isinstance(tags, str): + tags = [tags] + return lambda run: all([tag not in run.tags for tag in tags]) + + +def hyperparams_list_api(replace_dot_and_slashes: bool = False, **hyperparams) -> Dict[str, Any]: + filter_dict_for_api = {} + for hyperparam, value in hyperparams.items(): + if replace_dot_and_slashes: + if "/" in hyperparam: + hyperparam = hyperparam.replace("/", ".") + else: + hyperparam = hyperparam.replace(".", "/") + if ( + "config." not in hyperparam + and "summary." not in hyperparam + and "summary_metrics." not in hyperparam + and hyperparam != "tags" + ): + # Automatically add config. prefix if not present + hyperparam = f"config.{hyperparam}" + filter_dict_for_api[hyperparam] = value + return filter_dict_for_api + + +def has_config_values(**hyperparams) -> Callable: + return lambda run: all( + hyperparam in run.config and run.config[hyperparam] == value for hyperparam, value in hyperparams.items() + ) + + +def larger_than(**kwargs) -> Callable: + return lambda run: all( + hasattr(run.config, hyperparam) and value > run.config[hyperparam] for hyperparam, value in kwargs.items() + ) + + +def lower_than(**kwargs) -> Callable: + return lambda run: all( + hasattr(run.config, hyperparam) and value < run.config[hyperparam] for hyperparam, value in kwargs.items() + ) + + +str_to_run_pre_filter = { + "has_finished": has_finished, + "has_final_metric": has_final_metric, +} + + +def get_filter_functions_helper(epoch: int = None, finished: bool = True, config_values=None): + filter_functions = [] + if finished: + filter_functions.append(has_finished) + if epoch is not None: + filter_functions += [contains_in_run_name(f"{epoch}epoch")] + if config_values is not None: + filter_functions.append(has_config_values(**config_values)) + return filter_functions + + +##################################################################### +# +# Post-filtering of wandb runs (usually when you need to compare runs) +# + + +def non_unique_cols_dropper(df: pd.DataFrame) -> pd.DataFrame: + nunique = df.nunique() + cols_to_drop = nunique[nunique == 1].index + df = df.drop(cols_to_drop, axis=1) + return df + + +def groupby( + df: pd.DataFrame, + group_by: Union[str, List[str]] = "seed", + metrics: List[str] = "val/mse_epoch", + keep_columns: List[str] = "model/name", +) -> pd.DataFrame: + """ + Args: + df: pandas DataFrame to be grouped + group_by: str or list of str defining the columns to group by + metrics: list of metrics to compute the group mean and std over + keep_columns: list of columns to keep in the resulting grouped DataFrame + Returns: + A dataframe grouped by `group_by` with columns + `metric`/mean and `metric`/std for each metric passed in `metrics` and all columns in `keep_columns` remain intact. + """ + if isinstance(metrics, str): + metrics = [metrics] + if isinstance(keep_columns, str): + keep_columns = [keep_columns] + if isinstance(group_by, str): + group_by = [group_by] + + grouped_df = df.groupby(group_by, as_index=False, dropna=False) + agg_metrics = {m: ["mean", "std"] for m in metrics} + agg_remain_intact = {c: "first" for c in keep_columns} + # cols = [group_by] + keep_columns + metrics + ['id'] + stats = grouped_df.agg({**agg_metrics, **agg_remain_intact}) + stats.columns = [(f"{c[0]}/{c[1]}" if c[1] in ["mean", "std"] else c[0]) for c in stats.columns] + for m in metrics: + stats[f"{m}/std"].fillna(value=0, inplace=True) + + return stats + + +str_to_run_post_filter = { + "unique_columns": non_unique_cols_dropper, +} + + +def get_wandb_filters_dict_list_from_list(filters_list) -> dict: + if filters_list is None: + filters_list = [] + elif not isinstance(filters_list, list): + filters_list: List[Union[Callable, str]] = [filters_list] + filters_wandb = [] # dict() + for f in filters_list: + if isinstance(f, str): + f = str_to_run_pre_filter[f.lower()] + filters_wandb.append(f) + # filters_wandb = {**filters_wandb, **f} + return filters_wandb + + +def get_run_ids_for_hyperparams(hyperparams: dict, **kwargs) -> List[str]: + runs = wandb_project_run_filtered(hyperparams, **kwargs) + run_ids = [run.id for run in runs] + return run_ids + + +def get_filter_for_wandb( + filter_dict: Dict[str, Any], extra_filters: Dict[str, Any] = None, robust: bool = True +) -> Dict[str, Any]: + if filter_dict is None: + return {} if extra_filters is None else extra_filters + if "$and" not in filter_dict and "$or" not in filter_dict: + filter_wandb_api = hyperparams_list_api(**filter_dict) + if isinstance(extra_filters, dict): + filter_wandb_api = {**filter_wandb_api, **extra_filters} + + if robust: + filter_wandb_api_v2 = hyperparams_list_api(replace_dot_and_slashes=True, **filter_dict) + if isinstance(extra_filters, dict): + filter_wandb_api_v2 = {**filter_wandb_api_v2, **extra_filters} + filter_wandb_api = {"$or": [filter_wandb_api, filter_wandb_api_v2]} + else: + filter_wandb_api = {"$and": [filter_wandb_api]} # MongoDB query lang + else: + filter_wandb_api = filter_dict + return filter_wandb_api + + +def wandb_project_run_filtered( + hyperparam_filter: Dict[str, Any] = None, + extra_filters: Dict[str, Any] = None, + filter_functions: Sequence[Callable] = None, + order="-created_at", + aggregate_into_groups: bool = False, + entity: str = None, + project: str = None, + wandb_api=None, + verbose: bool = True, + robust: bool = True, +) -> Union[List[wandb.apis.public.Run], Dict[str, List[wandb.apis.public.Run]]]: + """ + Args: + hyperparam_filter: a dict str -> value, e.g. {'model/name': 'mlp', 'datamodule/exp_type': 'pristine'} + filter_functions: A set of callable functions that take a wandb run and return a boolean (True/False) so that + any run with one or more return values being False is discarded/filtered out + robust: If True, the hyperparam_filter will be applied in two ways: as is and with dots replaced by slashes + (within an OR query). This is useful when wandb keys are stored in different formats. + + Note: + For more complex/logical filters, see https://www.mongodb.com/docs/manual/reference/operator/query/ + """ + entity = get_entity(entity) + project = project or PROJECT + extra_filters = extra_filters or dict() + filter_functions = filter_functions or [] + if not isinstance(filter_functions, list): + filter_functions = [filter_functions] + filter_functions = [(f if callable(f) else str_to_run_pre_filter[f.lower()]) for f in filter_functions] + + hyperparam_filter = hyperparam_filter or dict() + api = get_api(wandb_api) + + if "group" in hyperparam_filter.keys() and "group" not in extra_filters.keys(): + hyperparam_filter = hyperparam_filter.copy() + extra_filters = {**extra_filters, "group": hyperparam_filter.pop("group")} + filter_wandb_api = get_filter_for_wandb(hyperparam_filter, extra_filters=extra_filters, robust=robust) + + runs_start = api.runs(f"{entity}/{project}", filters=filter_wandb_api, per_page=100, order=order) + + if len(filter_functions) > 0: + runs = [] + for i, run in enumerate(runs_start): + if all(f(run) for f in filter_functions): + runs.append(run) + else: + runs = runs_start + + if verbose: + log.info(f"#Filtered runs = {len(runs)}, (wandb API filtered {len(runs_start)})") + if len(runs) == 0: + log.warning( + f" No runs found for given filters: {filter_wandb_api} in {entity}/{project}" + f"\n #Runs before post-filtering with {filter_functions}: {len(runs_start)}" + ) + else: + log.info(f"Found {len(runs)} runs!") + + if aggregate_into_groups: + groups = defaultdict(list) + for run in runs: + groups[run.group].append(run) + return groups + return runs + + +def get_ordered_runs_with_config_diff( + order: str = None, + metric: str = None, + lower_is_better: bool = True, + top_k: int = 5, + every_k: int = 1, + return_metric_value: bool = True, + exclude_sub_dicts=None, + replace_epoch_by_million_imgs: bool = True, + verbose=True, + **kwargs, +) -> Dict[str, wandb.apis.public.Run]: + """ + Get the top k runs with the largest configuration differences. + """ + if order is None: + assert metric is not None, "One of order or metric must be specified" + order = f"summary_metrics.{metric}" + order = f"+{order}" if lower_is_better else f"-{order}" + # For some reason, the below does not filter out runs with None values properly, so we do it manually below in post-processing + if "extra_filters" not in kwargs: + kwargs["extra_filters"] = dict() + kwargs["extra_filters"][f"summary_metrics.{metric}"] = {"$exists": True} + # kwargs["hyperparam_filter"] = {**kwargs["hyperparam_filter"], f"summary_metrics.{metric}": {"$ne": None}, f"summary_metrics.{metric}": {"$exists": True}} + # kwargs["hyperparam_filter"] = {**kwargs["hyperparam_filter"], f"summary_metrics.{metric}": {"$ne": None}} + # # # kwargs["hyperparam_filter"] = {**kwargs["hyperparam_filter"], f"summary_metrics.{metric}": {"$ne": None}, f"summary_metrics.{metric}": {"$exists": True}} + else: + assert metric is None, "Only one of order or metric must be specified" + + kwargs.pop("extra_filters", None) + log.info(f"order={order}", kwargs) + runs = wandb_project_run_filtered(order=order, verbose=verbose, **kwargs) + if metric is not None: + runs = [run for run in runs if run.summary_metrics.get(metric) is not None] + if len(runs) == 0: + return {} + # if len(runs) < 2: + # return {"runs": runs} + best_runs = runs[: top_k * every_k : every_k] + + exclude_sub_dicts = exclude_sub_dicts or ( + "logger", + "wandb", + "callbacks", + "model_checkpoint", + "model_checkpoint_t8", + "module", + "early_stopping", + "scheduler", + "diffusion_config", + "model_config", + "datamodule_config", + "++datamodule", + "++diffusion", + "++logger", + "scheduler@module", + "dirs", + "slurm_job_id", + "start_time", + "ckpt_path", + "n_gpus", + "world_size", + "pin_memory", + "num_workers", + "regression_overrides", + "regression_run_id", + "eval_batch_size", + "batch_size", + "inference_val_every_n_epochs", + "save_prediction_batches", + "ema_decay", + "downsampling_method", + "enable_inference_dropout", + "regression_inference_dropout", + ) + exclude_nested = { + "datamodule": ["eval_batch_size", "batch_size", "lookback_window", "pin_memory", "num_workers"], + "exp": ["inference_val_every_n_epochs"], + "trainer": ["num_sanity_val_steps", "gpus"], + "optim": ["effective_batch_size"], + } + # Define the keys for which we tolerate a certain percentage difference instead of an exact match + keys_to_tolerated_percent_diff = { + "model/params_not_trainable": 0.07, # 7% difference + "model/params_trainable": 0.07, # 7% difference + "model/params_total": 0.07, # 7% difference + } + if replace_epoch_by_million_imgs: + exclude_sub_dicts = list(exclude_sub_dicts) + ["epoch", "global_step"] + + exclude_sub_dicts = set(exclude_sub_dicts) + # Get all unique config differences + best_configs = [] + for run in best_runs: + config = run.config + config = {k: v for k, v in config.items() if k not in exclude_sub_dicts} + if exclude_sub_dicts: + try: + config["Mimg."] = run.config.get("global_step", 0) * run.config.get("effective_batch_size", 0) / 1e6 + except TypeError: + log.warning(f"Could not compute Mimg for run {run.id} (name={run.name})") + for key, sub_keys in exclude_nested.items(): + if key in config: + for sub_key in sub_keys: + if sub_key in config[key]: + del config[key][sub_key] + + best_configs.append(config) + + diffs = find_config_differences_return_as_joined_str( + best_configs, sort_by_name=True, keys_to_tolerated_percent_diff=keys_to_tolerated_percent_diff + ) + diff_to_run = dict() + for run, diff in zip(best_runs, diffs): + v1 = run.summary.get(metric) + if v1 is None: + if verbose: + log.warning(f"Skipping run_id={run.id} because metric {metric} is not available") + continue + if diff in diff_to_run: + other_run = diff_to_run[diff][0] if return_metric_value else diff_to_run[diff] + if metric is not None: + v2 = other_run.summary.get(metric) + # close if first 3 digits are the same + max_diff = 1 if "rel" in metric else 1e-3 # if relative in %, then allow up to 1% + assert ( + abs(v1 - v2) < max_diff + ), f"Metric values are not the same even though the diff is the same: r1.id={run.id}, r2.id={other_run.id}; v1={v1:.5f}, v2={v2:.5f}" + if verbose: + log.warning(f"Duplicate diff found: {diff}, skipping run_id={run.id}. Keeping run_id={other_run.id}") + continue + if return_metric_value: + diff_to_run[diff] = (run, v1) + else: + diff_to_run[diff] = run + return diff_to_run + + +def runs_to_df(runs, metrics, skip_hps=None, baseline_run=None, aggregate_by="crps"): + skip_hps = skip_hps or {} + # List to store our data + data = [] + + for run in tqdm(runs, desc="Processing runs"): + diffusion_cfg = run.config.get("diffusion") + heun = run.config.get("diffusion.heun") or diffusion_cfg.get("heun") + step = run.config.get("diffusion.step") or diffusion_cfg.get("step") + s_churn = run.config.get("diffusion.S_churn") or diffusion_cfg.get("S_churn") + if heun is None or step is None or s_churn is None: + log.info( + f"Skipping run {run.id} due to missing hyperparameters. heun={heun}, step={step}, s_churn={s_churn}. config={run.config}" + ) + continue + + hps = {"heun": heun, "step": step, "churn": s_churn} + hps.update(diffusion_cfg) + hps.update(run.config.get("exp", {})) + hps["with_time_emb"] = run.config.get("model", {}).get("with_time_emb") + hps["horizon"] = run.config.get("datamodule", {}).get("horizon") + hps["lookback_window"] = run.config.get("datamodule", {}).get("lookback_window") + hps["ema_decay"] = run.config.get("exp", {}).get("ema_decay") + hps["max_epochs"] = run.config.get("trainer", {}).get("max_epochs") + hps["run_id"] = run.id + run_metrics = {k: run.summary[k] for k in metrics} + data.append({**hps, **run_metrics}) + + # Create DataFrame + df = pd.DataFrame(data) + for key, values in skip_hps.items(): + # drop rows with specific values + df = df[~df[key].isin(values)] + for metric in metrics: + df[metric] = pd.to_numeric(df[metric], errors="coerce") + df[metric] = df[metric].replace([np.inf, -np.inf], np.nan) + if baseline_run is not None: + # Generate aggregated relative metrics + aggregate_rel_metric = f"aggregated_relative_{aggregate_by}" + df[aggregate_rel_metric] = 0.0 + for metric in metrics: + if aggregate_by in metric: + base_value = baseline_run.summary.get(metric) + assert base_value is not None and isinstance( + base_value, float + ), f"Could not find baseline value for metric {metric}" + # to avoid TypeError: unsupported operand type(s) for -: 'str' and 'float' + # we need to convert the metric to float (coerce will convert non-numeric values to NaN) + df[metric] = pd.to_numeric(df[metric], errors="coerce") + df[aggregate_rel_metric] += 100 * (df[metric] - base_value) / base_value + df[aggregate_rel_metric] /= len(metrics) + return df + + +def get_runs_df( + get_metrics: bool = True, + hyperparam_filter: dict = None, + run_pre_filters: Union[str, List[Union[Callable, str]]] = "has_finished", + run_post_filters: Union[str, List[str]] = None, + verbose: int = 1, + make_hashable_df: bool = False, + **kwargs, +) -> pd.DataFrame: + """ + + get_metrics: + run_pre_filters: + run_post_filters: + verbose: 0, 1, or 2, where 0 = no output at all, 1 is a bit verbose + """ + if run_post_filters is None: + run_post_filters = [] + elif not isinstance(run_post_filters, list): + run_post_filters: List[Union[Callable, str]] = [run_post_filters] + run_post_filters = [(f if callable(f) else str_to_run_post_filter[f.lower()]) for f in run_post_filters] + + # Project is specified by + runs = wandb_project_run_filtered(hyperparam_filter, run_pre_filters, **kwargs) + summary_list = [] + config_list = [] + group_list = [] + name_list = [] + tag_list = [] + id_list = [] + for i, run in enumerate(runs): + if i % 50 == 0: + log.info(f"Going after run {i}") + # if i > 100: break + # run.summary are the output key/values like accuracy. + # We call ._json_dict to omit large files + if "model/_target_" not in run.config.keys(): + if verbose >= 1: + print(f"Run {run.name} filtered out, as model/_target_ not in run.config.") + continue + + id_list.append(str(run.id)) + tag_list.append(str(run.tags)) + if get_metrics: + summary_list.append(run.summary._json_dict) + # run.config is the hyperparameters + config_list.append({k: v for k, v in run.config.items() if k not in run.summary.keys()}) + else: + config_list.append(run.config) + + # run.name is the name of the run. + name_list.append(run.name) + group_list.append(run.group) + + summary_df = pd.DataFrame.from_records(summary_list) + config_df = pd.DataFrame.from_records(config_list) + name_df = pd.DataFrame({"name": name_list, "id": id_list, "tags": tag_list}) + group_df = pd.DataFrame({"group": group_list}) + all_df = pd.concat([name_df, config_df, summary_df, group_df], axis=1) + + cols = [c for c in all_df.columns if not c.startswith("gradients/") and c != "graph_0"] + all_df = all_df[cols] + if all_df.empty: + raise ValueError("Empty DF!") + for post_filter in run_post_filters: + all_df = post_filter(all_df) + all_df = clean_hparams(all_df) + if make_hashable_df: + all_df = all_df.applymap(lambda x: tuple(x) if isinstance(x, list) else x) + + return all_df + + +def clean_hparams(df: pd.DataFrame): + # Replace string representation of nan with real nan + df.replace("NaN", np.nan, inplace=True) + # df = df.where(pd.notnull(df), None).fillna(value=np.nan) + + # Combine/unify columns of optim/scheduler which might be present in stored params more than once + combine_cols = [col for col in df.columns if col.startswith("model/optim") or col.startswith("model/scheduler")] + for col in combine_cols: + new_col = col.replace("model/", "").replace("optimizer", "optim") + if not hasattr(df, new_col): + continue + getattr(df, new_col).fillna(getattr(df, col), inplace=True) + del df[col] + + return df + + +def get_datetime_of_run(run: wandb.apis.public.Run, to_local_timezone: bool = True) -> datetime: + """Get datetime of a run""" + dt_str = run.createdAt # a str like '2023-03-09T08:20:25' + dt_utc = datetime.strptime(dt_str, "%Y-%m-%dT%H:%M:%S").replace(tzinfo=timezone.utc) + if to_local_timezone: + return dt_utc.astimezone(tz=None) + else: + return dt_utc + return datetime.fromtimestamp(run.summary["_timestamp"]) + + +def get_unique_groups_for_run_ids(run_ids: Sequence[str], wandb_api: wandb.Api = None, **kwargs) -> List[str]: + """Get unique groups for a list of run ids""" + api = get_api(wandb_api) + groups = [] + for run_id in run_ids: + run = get_run_api(run_id, wandb_api=api, **kwargs) + groups.append(run.group) + return list(set(groups)) + + +def get_unique_groups_for_hyperparam_filter( + hyperparam_filter: dict, + filter_functions: str | List[Union[Callable, str]] = None, + **kwargs, # 'has_finished' +) -> List[str]: + """Get unique groups for a hyperparam filter + + Args: + hyperparam_filter: dict of hyperparam filters. + filter_functions: list of filter functions to apply to runs before getting groups. + + Examples: + Use hyperparam_filter={'datamodule/horizon': 1, 'model/dim': 128} to get all runs with horizon=1 and dim=128 + or {'datamodule/horizon': 1, 'diffusion/timesteps': {'$gte': 10}} for horizon=1 and timesteps >= 10 + """ + runs = wandb_project_run_filtered(hyperparam_filter, filter_functions=filter_functions, **kwargs) + groups = [run.group for run in runs] + return list(set(groups)) + + +def add_summary_metrics( + run_id: str, + metric_keys: Union[str, List[str]], + metric_values: Union[float, List[float]], + wandb_api: wandb.apis.public.Api = None, + override: bool = False, + **kwargs, +): + """ + Add a metric to the summary of a run. + """ + wandb_api = get_api(wandb_api) + run = get_run_api(run_id, wandb_api=wandb_api, **kwargs) + metric_keys = [metric_keys] if isinstance(metric_keys, str) else metric_keys + metric_values = [metric_values] if isinstance(metric_values, float) else metric_values + assert len(metric_keys) == len( + metric_values + ), f"metric_keys and metric_values must have same length, but got {len(metric_keys)} and {len(metric_values)}" + + for key, value in zip(metric_keys, metric_values): + if key in run.summary.keys() and not override: + print(f"Metric {key} already present in run {run_id}, skipping.") + return + run.summary[key] = value + run.summary.update() + + +def metrics_of_runs_to_arrays( + runs: Sequence[wandb.apis.public.Run], + metrics: Sequence[str], + columns: Sequence[Any], + column_to_wandb_key: Callable[[Any], str] | Callable[[Any], List[str]], + dropna_rows: bool = True, +) -> Dict[str, np.ndarray]: + """Convert metrics of runs to arrays + + Args: + runs (list): list of wandb runs (will be the rows of the arrays) + metrics (list): list of metrics (one array will be created for each metric) + columns (list): list of columns (will be the columns of the arrays) + column_to_wandb_key (Callable): function to convert a given column to a wandb key (without metric suffix) + If it returns a list of keys, the first one will be used to get the metric (if present). + """ + + def column_to_wandb_key_with_metric(wandb_key_stem, metric: str): + if metric not in wandb_key_stem: + wandb_key_stem = f"{wandb_key_stem}/{metric}" + return wandb_key_stem.replace("//", "/") + + def get_summary_metric(run: wandb.apis.public.Run, metric: str, column: Any): + wandb_keys = column_to_wandb_key(column) + wandb_keys = [wandb_keys] if isinstance(wandb_keys, str) else wandb_keys + for wandb_key_stem in wandb_keys: + wandb_key = column_to_wandb_key_with_metric(wandb_key_stem, metric) + if wandb_key in run.summary.keys(): + return run.summary[wandb_key] + return np.nan + + nrows, ncols = len(runs), len(columns) + arrays = {m: np.zeros((nrows, ncols)) for m in metrics} + for r_i, run in enumerate(runs): + if ( + run.project != "DYffusion" + and np.isnan(get_summary_metric(run, metrics[0], columns[0])) + and "None" not in column_to_wandb_key(None) + ): + full_metric_names = [column_to_wandb_key_with_metric(column_to_wandb_key(None), m) for m in metrics] + run_metrics = get_summary_metrics_from_history(run, full_metric_names, robust=False) + for m, fm in zip(metrics, full_metric_names): + assert len(run_metrics[fm]) >= ncols, f"Expected {ncols} columns, got {len(run_metrics[fm])}" + if len(run_metrics[fm]) > ncols: + run_metrics[fm] = run_metrics[fm][ncols] + else: + arrays[m][r_i, :] = run_metrics[fm] + else: + for m in metrics: + arrays[m][r_i, :] = [get_summary_metric(run, m, c) for c in columns] + if dropna_rows: + for m in metrics: + arrays[m] = arrays[m][~np.isnan(arrays[m]).any(axis=1)] + return arrays + + +def get_summary_metrics_from_history(run, metrics: Sequence[str], robust: bool = False): + """Get summary metrics from history""" + history = run.history(keys=metrics, pandas=True) if not robust else run.scan_history(keys=metrics) + # history has one column per metric, one row per step, we want to return one numpy array per metric + if robust: + return {m: history[m].to_numpy() for m in metrics} + else: + return {m: history[m].to_numpy() for m in metrics} + + +def add_time_average_new( + run, relative_run=None, times=range(1, 61), target_metric_prefix="test-wx/25ens_mems", metric="crps", cache=True +): + """Add time average of a metric to a run.""" + times = list(times) + + time_avg_metric_name = f"time_avg_{times[0]}_{times[-1]}" + if relative_run is not None: + time_avg_metric_name = f"{time_avg_metric_name}_rel_{relative_run.id}" + + time_avg_metric_name = f"{target_metric_prefix}/{time_avg_metric_name}/{metric}" + if time_avg_metric_name in run.summary and run.summary[time_avg_metric_name] is not None: + return False + if relative_run is not None and f"{time_avg_metric_name}_rel_{relative_run.id}" in run.summary: + # Wrongly keyed, fix it + run.summary[time_avg_metric_name] = run.summary[f"{time_avg_metric_name}_rel_{relative_run.id}"] + del run.summary[f"{time_avg_metric_name}_rel_{relative_run.id}"] + return True + + metrics = np.array([run.summary.get(f"{target_metric_prefix}/t{t}/{metric}") for t in times], dtype=np.float32) + assert all([m is not None for m in metrics]), metrics + if any(np.isnan(metrics)): + print(f"NaN found in {run.name} (id={run.id})") + run.summary[time_avg_metric_name] = np.nan + return True + + if relative_run is not None: + cached_key = f"{relative_run.id}_{times[0]}_{times[-1]}" + if cache and cached_key not in CACHE: + base_metrics = np.array([relative_run.summary.get(f"{target_metric_prefix}/t{t}/{metric}") for t in times]) + assert all([m is not None for m in base_metrics]), base_metrics + CACHE[cached_key] = base_metrics + + base_metrics = CACHE[cached_key] + # Compute relative metric (in %) + metrics = (metrics - base_metrics) / base_metrics * 100 + + try: + time_avg = np.mean(metrics) + except Exception as e: + print(f"Error: {e}. Run: {run.name}. ID: {run.id}") + print(f"Metrics: {metrics}") + return False + run.summary[time_avg_metric_name] = time_avg + return time_avg + + +def wandb_run_summary_update(wandb_run: wandb.apis.public.Run): + try: + wandb_run.summary.update() + except wandb.errors.CommError: + logging.warning("Could not update wandb summary") + # except requests.exceptions.HTTPError or requests.exceptions.ConnectionError: + except (requests.exceptions.HTTPError, requests.exceptions.ConnectionError): + # try again + wandb_run.summary.update() + except TypeError: + pass # wandb_run.update() diff --git a/src/utilities/wandb_callbacks.py b/src/utilities/wandb_callbacks.py new file mode 100644 index 0000000..e9c8926 --- /dev/null +++ b/src/utilities/wandb_callbacks.py @@ -0,0 +1,355 @@ +from __future__ import annotations + +import os +import shutil +import time +import traceback +from typing import Sequence + +import pytorch_lightning as pl +import wandb +from pytorch_lightning import Callback, Trainer +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.loggers import WandbLogger +from pytorch_lightning.utilities import rank_zero_only + +from src.utilities.utils import get_logger + + +log = get_logger(__name__) + + +class WatchModel(Callback): + """ + Make wandb watch model at the beginning of the run. + This will log the gradients of the model (as a histogram for each or some weights updates). + """ + + def __init__(self, log: str = "gradients", log_freq: int = 100): + self.log_type = log + self.log_freq = log_freq + self.has_logged = False + + @rank_zero_only + def on_train_start(self, trainer, pl_module): + logger: WandbLogger = get_wandb_logger(trainer=trainer) + if not self.has_logged: + try: + logger.watch(model=pl_module, log=self.log_type, log_freq=self.log_freq, log_graph=True) + except TypeError: + log.info( + f" Pytorch-lightning/Wandb version seems to be too old to support 'log_graph' arg in wandb.watch(.)" + f" Wandb version={wandb.__version__}" + ) + wandb.watch(models=pl_module, log=self.log_type, log_freq=self.log_freq) # , log_graph=True) + self.has_logged = True + + @rank_zero_only + def on_any_non_train_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: + if not self.has_logged: + log.info("WatchModel callback has not been called yet. Calling it now & setting log_freq=0") + self.log_freq = 0 + # self.log_type = "parameters" + self.on_train_start(trainer, pl_module) + # log.info(f"wandb.run.hook_handles: {wandb.run._torch_history._hook_handles.keys()}") + param_hook_handle = wandb.run._torch_history._hook_handles.get("parameters/") # a RemovableHandle + # log.info(param_hook_handle.hooks_dict_ref()) + param_hook = param_hook_handle.hooks_dict_ref()[param_hook_handle.id] + # param_hook(pl_module, None, None) + param_hook(pl_module, None, None) + self.has_logged = True + + @rank_zero_only + def on_validation_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: + return self.on_any_non_train_start(trainer, pl_module) + + @rank_zero_only + def on_test_start(self, trainer: Trainer, pl_module: pl.LightningModule) -> None: + return self.on_any_non_train_start(trainer, pl_module) + + +class MyWandbLogger(pl.loggers.WandbLogger): + """Same as pl.WandbLogger, but also saves the last checkpoint as 'last.ckpt' and uploads it to wandb.""" + + def __init__( + self, + save_last_ckpt: bool = True, + save_best_ckpt: bool = True, + save_to_wandb: bool = True, + save_to_s3_bucket: bool = False, + s3_endpoint_url: str = None, + s3_bucket_name: str = None, + log_code: bool = True, + run_path: str = None, + train_run_path: str = None, + training_id: str = None, + resume_run_id: str = None, + **kwargs, + ): + """If using S3, set save_to_s3_bucket=True and provide your AWS_ACCESS_KEY_ID and AWS_SECRET_ACCESS_KEY as + environment variables.""" + try: + super().__init__(**kwargs) + except Exception as e: + log.info("-" * 100) + log.warning(f"Failed to initialize WandbLogger. Error: {e}") + service_wait = os.getenv("WANDB__SERVICE_WAIT", 300) + new_service_wait = int(service_wait) + 600 + os.environ["WANDB__SERVICE_WAIT"] = str(new_service_wait) + wandb_version = wandb.__version__ + log.info(f"Increasing WANDB__SERVICE_WAIT to {new_service_wait} seconds (wandb version={wandb_version}).") + log.info("-" * 100) + # Sleep for 30sec and try again + time.sleep(30) + super().__init__(**kwargs) + _ = self.experiment # Force initialize wandb run (same as wandb.init) + if hasattr(self.experiment.project, "lower") and self.experiment.project.lower() == "debug": + save_best_ckpt = False + save_last_ckpt = False + save_to_s3_bucket = False + save_to_wandb = False + log.info("Wandb: Running in debug mode. Disabling saving checkpoints.") + if log_code: + + def exclude_codes(path, root): + if path.endswith("mcvd") or path.endswith("sfno") or path.endswith("schedulers"): + return True # exclude these directories + exclude_subdirs = [ + "plotting", + ] + if any([subdir in path for subdir in exclude_subdirs]): + return True + return False + + code_dir = os.path.join(os.getcwd(), "src") # "../../src" + self.experiment.log_code(code_dir, exclude_fn=exclude_codes) # # saves python files in src/ to wandb + + self.save_last_ckpt = save_last_ckpt + self.save_best_ckpt = save_best_ckpt + self._hash_of_best_ckpts = dict() + if save_best_ckpt or save_last_ckpt: + assert save_to_wandb or save_to_s3_bucket, "You must save to either wandb or S3 bucket." + self.save_to_wandb = save_to_wandb + self.save_to_s3_bucket = save_to_s3_bucket + if save_to_s3_bucket: + if s3_endpoint_url is not None: + if os.getenv("S3_ENDPOINT_URL") is not None: + assert ( + os.getenv("S3_ENDPOINT_URL") == s3_endpoint_url + ), "S3_ENDPOINT_URL environment variable mismatch." + os.environ["S3_ENDPOINT_URL"] = s3_endpoint_url + if s3_bucket_name is not None: + if os.getenv("S3_BUCKET_NAME") is not None: + assert ( + os.getenv("S3_BUCKET_NAME") == s3_bucket_name + ), "S3_BUCKET_NAME environment variable mismatch." + + os.environ["S3_BUCKET_NAME"] = s3_bucket_name + + from src.utilities.s3utils import download_s3_object, upload_s3_object + + # Save S3 checkpoints to //checkpoints/ + self.s3_checkpoint_dir = f"{self._project}/checkpoints/{self._id}/" + self.download_s3_object = download_s3_object + self.upload_s3_object = upload_s3_object + self.summary_update( + { + "checkpoint/s3_dir": self.s3_checkpoint_dir, + "checkpoint/s3_endpoint_url": s3_endpoint_url, + "checkpoint/s3_bucket_name": s3_bucket_name, + } + ) + self.upload_wandb_files_to_s3() + + @rank_zero_only + def summary_update(self, summary_dict: dict): + self.experiment.summary.update(summary_dict) + + @rank_zero_only + def upload_wandb_files_to_s3(self): + if not self.save_to_s3_bucket: + return + import boto3 + + # Upload all files in wandb.run.dir to S3 bucket (e.g. hydra config files) + dir_to_upload = wandb.run.dir + if os.path.exists(dir_to_upload): + for file in os.listdir(dir_to_upload): + s3_file_path = os.path.join(f"{self._project}/configs/{self._id}", file) + try: + self.upload_s3_object(os.path.join(dir_to_upload, file), s3_file_path, retry=3) + except boto3.exceptions.S3UploadFailedError as e: + if "hydra_config" not in file: + log.error(f"Failed to upload {file} to S3 bucket. Skipping.") + else: + raise e + # log.info(f"Uploaded {file} to S3 bucket as {s3_file_path}.") + else: + log.warning(f"Directory {dir_to_upload} does not exist. Skipping uploading to S3 bucket.") + + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + super().after_save_checkpoint(checkpoint_callback) + self.save_last(checkpoint_callback) + self.save_best(checkpoint_callback) + + @rank_zero_only + def save_last(self, ckpt_cbk): + if not self.save_last_ckpt: + return + if isinstance(ckpt_cbk, Sequence): + ckpt_cbk = [c for c in ckpt_cbk if c.last_model_path] + if len(ckpt_cbk) == 0: + raise Exception("No checkpoint callback has a last_model_path attribute. Ckpt callback is: {ckpt_cbk}") + ckpt_cbk = ckpt_cbk[0] + + last_ckpt = ckpt_cbk.last_model_path + if self.save_last and last_ckpt: + hash_last_ckpt = hash(open(last_ckpt, "rb").read()) + if hash_last_ckpt == self._hash_of_best_ckpts.get("LAST_CKPT", None): + return + self._hash_of_best_ckpts["LAST_CKPT"] = hash_last_ckpt + if self.save_to_wandb: + self.experiment.save(last_ckpt) + if self.save_to_s3_bucket: + # Upload to S3 bucket + self.upload_s3_object(local_file_path=last_ckpt, s3_file_path=self.s3_checkpoint_dir) + try: + self.experiment.summary.update({"checkpoint/in_s3": True}) + except Exception as e: + log.error(f"Failed to update wandb summary. Error: {e}") + + self.experiment.summary.update({"checkpoint/last_filepath": last_ckpt}) + + @rank_zero_only + def save_best(self, ckpt_cbk): + if not self.save_best_ckpt: + return + # Save best model + if not isinstance(ckpt_cbk, Sequence): + ckpt_cbk = [ckpt_cbk] + + for ckpt_cbk in ckpt_cbk: + best_ckpt = ckpt_cbk.best_model_path + if not best_ckpt or not os.path.isfile(best_ckpt): + continue + # Check if the best ckpt content has changed since last time it was uploaded + hash_best_ckpt = hash(open(best_ckpt, "rb").read()) + unique_key_for_callback = f"{ckpt_cbk.monitor}" + if hash_best_ckpt == self._hash_of_best_ckpts.get(unique_key_for_callback, None): + continue + self._hash_of_best_ckpts[unique_key_for_callback] = hash_best_ckpt + # copy best ckpt to a file called 'best.ckpt' and upload it to wandb + monitor = ckpt_cbk.monitor.replace("/", "_") if ckpt_cbk.monitor is not None else "MONITOR_NOT_SET" + fname_cloud = f"best-{monitor}.ckpt" + shutil.copyfile(best_ckpt, fname_cloud) + if self.save_to_wandb: + self.experiment.save(fname_cloud) + # log.info(f"Wandb: Saved best ckpt '{best_ckpt}' as '{fname_wandb}'.") + # log.info(f"Saved best ckpt to the wandb cloud as '{fname_wandb}'.") + if self.save_to_s3_bucket: + # Upload to S3 bucket + self.upload_s3_object(local_file_path=fname_cloud, s3_file_path=self.s3_checkpoint_dir) + try: + self.experiment.summary.update({"checkpoint/in_s3": True}) + except Exception as e: + log.error(f"Failed to update wandb summary. Error: {e}") + + self.experiment.summary.update({f"checkpoint/best_filepath_{monitor}": best_ckpt}) + + def restore_checkpoint( + self, + ckpt_filename: str, + local_file_path: str, + run_path: str = None, + root: str = None, + restore_from: str = None, + ): + """Restore a checkpoint from cloud to local file path. + + Args: + ckpt_filename: The name of the checkpoint file. + local_file_path: The local file path to save the checkpoint to. + run_path: The path to the wandb run where the checkpoint is stored (or in corresponding S3 bucket). + root: The root directory to save the checkpoint, if using wandb restore. + restore_from: The source to restore the checkpoint from. Can be 's3', 'wandb' or 'any'. + + Note: + If save_to_s3_bucket is True, the checkpoint will be downloaded from the S3 bucket. + Otherwise, if save_to_wandb is True, the checkpoint will be downloaded from wandb. + + """ + if run_path is None: + run_path = f"{self.experiment.entity}/{self._project}/{self._id}" + entity, project, run_id = run_path.split("/") + + if restore_from is None: + if self.save_to_s3_bucket: + restore_from = "s3" + elif self.save_to_wandb: + restore_from = "wandb" + else: + raise RuntimeError( + "Cannot restore checkpoint since neither save_to_wandb nor save_to_s3_bucket is True." + "Alternatively, set ``restore_from`` to 's3' or 'wandb' or 'any' to restore from either source." + ) + assert restore_from in ["s3", "wandb", "any"], f"Invalid value for restore_from: {restore_from}" + + ckpt_filename = os.path.basename(ckpt_filename) # remove any path prefix (e.g. local dir) + if restore_from in ["s3", "any"]: + s3_file_path = f"{project}/checkpoints/{run_id}/{ckpt_filename}" + retries = 3 + for i in range(retries): + try: + self.download_s3_object(s3_file_path, local_file_path) + return local_file_path + except Exception as e1: + if i == retries - 1: + log.error( + f"Attempt {i}: Failed to download checkpoint from S3 bucket path ``{s3_file_path}``." + f"Error: {e1}.\n{traceback.format_exc()}" + ) + e1_str = traceback.format_exc() + else: + log.warning( + f"Attempt {i}: Failed to download checkpoint from S3 bucket path ``{s3_file_path}``. Retrying..." + ) + + if restore_from in ["wandb", "any"]: + root = root or os.getcwd() + # Download model checkpoint from wandb + try: + ckpt_path_tmp = wandb.restore(ckpt_filename, run_path=run_path, replace=True, root=root).name + os.rename(ckpt_path_tmp, local_file_path) + return local_file_path + except Exception as e2: + log.error(f"Failed to download checkpoint from wandb run {run_path}. Error: {e2}") + e2_str = traceback.format_exc() + + s3_path_str = f"S3 bucket path: {s3_file_path}" + wb_path_str = f"wandb run path: {run_path}" + if restore_from == "any": + suffix = f"S3 bucket or wandb. {s3_path_str} and {wb_path_str}.\nError S3: {e1_str}\n:Error WB: {e2_str}" + elif restore_from == "s3": + suffix = f"S3 bucket. {s3_path_str}. Error: {e1_str}" + elif restore_from == "wandb": + suffix = f"Wandb run. {wb_path_str}. Error: {e2_str}" + raise RuntimeError(f"Failed to restore checkpoint from {suffix}.") + + +def get_wandb_logger(trainer: Trainer) -> WandbLogger | MyWandbLogger: + """Safely get Weights&Biases logger from Trainer.""" + + if trainer.fast_dev_run: + raise Exception( + "Cannot use wandb callbacks since pytorch lightning disables loggers in `fast_dev_run=true` mode." + ) + + if isinstance(trainer.logger, (WandbLogger, MyWandbLogger)): + return trainer.logger + + if isinstance(trainer.loggers, list): + for logger in trainer.loggers: + if isinstance(logger, (WandbLogger, MyWandbLogger)): + return logger + + raise Exception("You are using wandb related callback, but WandbLogger was not found for some reason...") diff --git a/utils/check_copies.py b/utils/check_copies.py new file mode 100644 index 0000000..294cc27 --- /dev/null +++ b/utils/check_copies.py @@ -0,0 +1,203 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import glob +import os +import re + +import black +from doc_builder.style_doc import style_docstrings_in_code + + +# All paths are set with the intent you should run this script from the root of the repo with the command +# python utils/check_copies.py +SRC_PATH = "src" +REPO_PATH = "." + + +def _should_continue(line, indent): + return line.startswith(indent) or len(line) <= 1 or re.search(r"^\s*\)(\s*->.*:|:)\s*$", line) is not None + + +def find_code_in_dyffusion(object_name): + """Find and return the code source code of `object_name`.""" + parts = object_name.split(".") + i = 0 + + # First let's find the module where our object lives. + module = parts[i] + while i < len(parts) and not os.path.isfile(os.path.join(SRC_PATH, f"{module}.py")): + i += 1 + if i < len(parts): + module = os.path.join(module, parts[i]) + if i >= len(parts): + raise ValueError(f"`object_name` should begin with the name of a module of dyffusion but got {object_name}.") + + with open(os.path.join(SRC_PATH, f"{module}.py"), "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + + # Now let's find the class / func in the code! + indent = "" + line_index = 0 + for name in parts[i + 1 :]: + while ( + line_index < len(lines) and re.search(rf"^{indent}(class|def)\s+{name}(\(|\:)", lines[line_index]) is None + ): + line_index += 1 + indent += " " + line_index += 1 + + if line_index >= len(lines): + raise ValueError(f" {object_name} does not match any function or class in {module}.") + + # We found the beginning of the class / func, now let's find the end (when the indent diminishes). + start_index = line_index + while line_index < len(lines) and _should_continue(lines[line_index], indent): + line_index += 1 + # Clean up empty lines at the end (if any). + while len(lines[line_index - 1]) <= 1: + line_index -= 1 + + code_lines = lines[start_index:line_index] + return "".join(code_lines) + + +_re_copy_warning = re.compile(r"^(\s*)#\s*Copied from\s+dyffusion\.(\S+\.\S+)\s*($|\S.*$)") +_re_replace_pattern = re.compile(r"^\s*(\S+)->(\S+)(\s+.*|$)") +_re_fill_pattern = re.compile(r"]*>") + + +def get_indent(code): + lines = code.split("\n") + idx = 0 + while idx < len(lines) and len(lines[idx]) == 0: + idx += 1 + if idx < len(lines): + return re.search(r"^(\s*)\S", lines[idx]).groups()[0] + return "" + + +def blackify(code): + """ + Applies the black part of our `make style` command to `code`. + """ + has_indent = len(get_indent(code)) > 0 + if has_indent: + code = f"class Bla:\n{code}" + mode = black.Mode(target_versions={black.TargetVersion.PY37}, line_length=119, preview=True) + result = black.format_str(code, mode=mode) + result, _ = style_docstrings_in_code(result) + return result[len("class Bla:\n") :] if has_indent else result + + +def is_copy_consistent(filename, overwrite=False): + """ + Check if the code commented as a copy in `filename` matches the original. + Return the differences or overwrites the content depending on `overwrite`. + """ + with open(filename, "r", encoding="utf-8", newline="\n") as f: + lines = f.readlines() + diffs = [] + line_index = 0 + # Not a for loop cause `lines` is going to change (if `overwrite=True`). + while line_index < len(lines): + search = _re_copy_warning.search(lines[line_index]) + if search is None: + line_index += 1 + continue + + # There is some copied code here, let's retrieve the original. + indent, object_name, replace_pattern = search.groups() + theoretical_code = find_code_in_dyffusion(object_name) + theoretical_indent = get_indent(theoretical_code) + + start_index = line_index + 1 if indent == theoretical_indent else line_index + 2 + indent = theoretical_indent + line_index = start_index + + # Loop to check the observed code, stop when indentation diminishes or if we see a End copy comment. + should_continue = True + while line_index < len(lines) and should_continue: + line_index += 1 + if line_index >= len(lines): + break + line = lines[line_index] + should_continue = _should_continue(line, indent) and re.search(f"^{indent}# End copy", line) is None + # Clean up empty lines at the end (if any). + while len(lines[line_index - 1]) <= 1: + line_index -= 1 + + observed_code_lines = lines[start_index:line_index] + observed_code = "".join(observed_code_lines) + + # Remove any nested `Copied from` comments to avoid circular copies + theoretical_code = [line for line in theoretical_code.split("\n") if _re_copy_warning.search(line) is None] + theoretical_code = "\n".join(theoretical_code) + + # Before comparing, use the `replace_pattern` on the original code. + if len(replace_pattern) > 0: + patterns = replace_pattern.replace("with", "").split(",") + patterns = [_re_replace_pattern.search(p) for p in patterns] + for pattern in patterns: + if pattern is None: + continue + obj1, obj2, option = pattern.groups() + theoretical_code = re.sub(obj1, obj2, theoretical_code) + if option.strip() == "all-casing": + theoretical_code = re.sub(obj1.lower(), obj2.lower(), theoretical_code) + theoretical_code = re.sub(obj1.upper(), obj2.upper(), theoretical_code) + + # Blackify after replacement. To be able to do that, we need the header (class or function definition) + # from the previous line + theoretical_code = blackify(lines[start_index - 1] + theoretical_code) + theoretical_code = theoretical_code[len(lines[start_index - 1]) :] + + # Test for a diff and act accordingly. + if observed_code != theoretical_code: + diffs.append([object_name, start_index]) + if overwrite: + lines = lines[:start_index] + [theoretical_code] + lines[line_index:] + line_index = start_index + 1 + + if overwrite and len(diffs) > 0: + # Warn the user a file has been modified. + print(f"Detected changes, rewriting {filename}.") + with open(filename, "w", encoding="utf-8", newline="\n") as f: + f.writelines(lines) + return diffs + + +def check_copies(overwrite: bool = False): + all_files = glob.glob(os.path.join(SRC_PATH, "**/*.py"), recursive=True) + diffs = [] + for filename in all_files: + new_diffs = is_copy_consistent(filename, overwrite) + diffs += [f"- {filename}: copy does not match {d[0]} at line {d[1]}" for d in new_diffs] + if not overwrite and len(diffs) > 0: + diff = "\n".join(diffs) + raise Exception( + "Found the following copy inconsistencies:\n" + + diff + + "\nRun `make fix-copies` or `python utils/check_copies.py --fix_and_overwrite` to fix them." + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--fix_and_overwrite", action="store_true", help="Whether to fix inconsistencies.") + args = parser.parse_args() + + check_copies(args.fix_and_overwrite) diff --git a/utils/get_modified_files.py b/utils/get_modified_files.py new file mode 100644 index 0000000..650c61c --- /dev/null +++ b/utils/get_modified_files.py @@ -0,0 +1,34 @@ +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# this script reports modified .py files under the desired list of top-level sub-dirs passed as a list of arguments, e.g.: +# python ./utils/get_modified_files.py utils src tests examples +# +# it uses git to find the forking point and which files were modified - i.e. files not under git won't be considered +# since the output of this script is fed into Makefile commands it doesn't print a newline after the results + +import re +import subprocess +import sys + + +fork_point_sha = subprocess.check_output("git merge-base main HEAD".split()).decode("utf-8") +modified_files = subprocess.check_output(f"git diff --name-only {fork_point_sha}".split()).decode("utf-8").split() + +joined_dirs = "|".join(sys.argv[1:]) +regex = re.compile(rf"^({joined_dirs}).*?\.py$") + +relevant_modified_files = [x for x in modified_files if regex.match(x)] +print(" ".join(relevant_modified_files), end="") diff --git a/utils/release.py b/utils/release.py new file mode 100644 index 0000000..71ef40b --- /dev/null +++ b/utils/release.py @@ -0,0 +1,134 @@ +# coding=utf-8 +# Copyright 2021 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import re + +import packaging.version + + +PATH_TO_EXAMPLES = "examples/" +REPLACE_PATTERNS = { + "examples": (re.compile(r'^check_min_version\("[^"]+"\)\s*$', re.MULTILINE), 'check_min_version("VERSION")\n'), + "init": (re.compile(r'^__version__\s+=\s+"([^"]+)"\s*$', re.MULTILINE), '__version__ = "VERSION"\n'), + "setup": (re.compile(r'^(\s*)version\s*=\s*"[^"]+",', re.MULTILINE), r'\1version="VERSION",'), + "doc": (re.compile(r'^(\s*)release\s*=\s*"[^"]+"$', re.MULTILINE), 'release = "VERSION"\n'), +} +REPLACE_FILES = { + "init": "src/__init__.py", + "setup": "setup.py", +} +README_FILE = "README.md" + + +def update_version_in_file(fname, version, pattern): + """Update the version in one file using a specific pattern.""" + with open(fname, "r", encoding="utf-8", newline="\n") as f: + code = f.read() + re_pattern, replace = REPLACE_PATTERNS[pattern] + replace = replace.replace("VERSION", version) + code = re_pattern.sub(replace, code) + with open(fname, "w", encoding="utf-8", newline="\n") as f: + f.write(code) + + +def update_version_in_examples(version): + """Update the version in all examples files.""" + for folder, directories, fnames in os.walk(PATH_TO_EXAMPLES): + # Removing some of the folders with non-actively maintained examples from the walk + if "research_projects" in directories: + directories.remove("research_projects") + if "legacy" in directories: + directories.remove("legacy") + for fname in fnames: + if fname.endswith(".py"): + update_version_in_file(os.path.join(folder, fname), version, pattern="examples") + + +def global_version_update(version, patch=False): + """Update the version in all needed files.""" + for pattern, fname in REPLACE_FILES.items(): + update_version_in_file(fname, version, pattern) + if not patch: + update_version_in_examples(version) + + +def get_version(): + """Reads the current version in the __init__.""" + with open(REPLACE_FILES["init"], "r") as f: + code = f.read() + default_version = REPLACE_PATTERNS["init"][0].search(code).groups()[0] + return packaging.version.parse(default_version) + + +def pre_release_work(patch=False): + """Do all the necessary pre-release steps.""" + # First let's get the default version: base version if we are in dev, bump minor otherwise. + default_version = get_version() + if patch and default_version.is_devrelease: + raise ValueError("Can't create a patch version from the dev branch, checkout a released version!") + if default_version.is_devrelease: + default_version = default_version.base_version + elif patch: + default_version = f"{default_version.major}.{default_version.minor}.{default_version.micro + 1}" + else: + default_version = f"{default_version.major}.{default_version.minor + 1}.0" + + # Now let's ask nicely if that's the right one. + version = input(f"Which version are you releasing? [{default_version}]") + if len(version) == 0: + version = default_version + + print(f"Updating version to {version}.") + global_version_update(version, patch=patch) + + +# if not patch: +# print("Cleaning main README, don't forget to run `make fix-copies`.") +# clean_main_ref_in_model_list() + + +def post_release_work(): + """Do all the necesarry post-release steps.""" + # First let's get the current version + current_version = get_version() + dev_version = f"{current_version.major}.{current_version.minor + 1}.0.dev0" + current_version = current_version.base_version + + # Check with the user we got that right. + version = input(f"Which version are we developing now? [{dev_version}]") + if len(version) == 0: + version = dev_version + + print(f"Updating version to {version}.") + global_version_update(version) + + +# print("Cleaning main README, don't forget to run `make fix-copies`.") +# clean_main_ref_in_model_list() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--post_release", action="store_true", help="Whether this is pre or post release.") + parser.add_argument("--patch", action="store_true", help="Whether or not this is a patch release.") + args = parser.parse_args() + if not args.post_release: + pre_release_work(patch=args.patch) + elif args.patch: + print("Nothing to do after a patch :-)") + else: + post_release_work()