StormCast training code improvements #738

Merged · 26 commits · Dec 16, 2024

Commits
d08ab3f
adding stormcast raw files
Nov 12, 2024
f2fdad8
major cleanup, refactor and consolidation
pzharrington Nov 14, 2024
999b9d7
More cleanup and init readme
pzharrington Nov 15, 2024
145f01d
port command line args to standard argparse
pzharrington Nov 15, 2024
2c161d8
remove unused network and loss wrappers
pzharrington Nov 15, 2024
546752a
add torchrun instructions
pzharrington Nov 15, 2024
c69c84b
drop dnnlib utils
pzharrington Nov 15, 2024
6babca4
use Modulus DistributedManager, streamline cmd args
pzharrington Nov 15, 2024
bc21616
Use standard torch checkpoints instead of pickles
pzharrington Nov 16, 2024
ca3fbe8
Standardize model configs and channel selection across training and i…
pzharrington Nov 16, 2024
cc3919a
checkpoint format standardization for train/inference
pzharrington Nov 18, 2024
9cd1c1a
finalize additional deps
pzharrington Nov 18, 2024
33cb8a3
format and linting
pzharrington Nov 18, 2024
a9d3c12
drop docker and update changelog
pzharrington Nov 19, 2024
181dfec
Address feedback
pzharrington Nov 19, 2024
473a11c
add variables to readme, rename network types
pzharrington Nov 19, 2024
bf8e416
swap stormcast to modulus nn and loss defs
pzharrington Nov 22, 2024
820b96b
Swap to modulus checkpoint save and load utils
pzharrington Nov 25, 2024
4e820b8
Swap to modulus networks/losses, use modulus checkpointing and logging
pzharrington Nov 27, 2024
59b9610
add power spectrum to modulus metrics, remove unused utils
pzharrington Dec 3, 2024
cbcd9ff
Readme update and unit tests
pzharrington Dec 5, 2024
1704738
Merge branch 'main' into stormcast-nn-consol
pzharrington Dec 6, 2024
03f5ba4
drop unused files
pzharrington Dec 9, 2024
f18f862
drop unused diffusions files
pzharrington Dec 9, 2024
68cca1d
Merge branch 'main' into stormcast-nn-consol
pzharrington Dec 13, 2024
6875b38
update changelog
pzharrington Dec 13, 2024
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -12,6 +12,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

### Changed

- Refactored StormCast training example

### Deprecated

### Removed
89 changes: 52 additions & 37 deletions examples/generative/stormcast/README.md
@@ -1,8 +1,6 @@
<!-- markdownlint-disable -->
## StormCast: Kilometer-Scale Convection Allowing Model Emulation using Generative Diffusion Modeling

## Problem overview

Convection-allowing models (CAMs) are essential tools for forecasting severe thunderstorms and
@@ -18,78 +16,95 @@ accuracy, demonstrating ability to replicate storm dynamics, observed radar reflectivity, and
atmospheric structure via deep learning-based CAM emulation. StormCast enables high-resolution ML-driven
regional weather forecasting and climate risk analysis.


<p align="center">
<img src="../../../docs/img/stormcast_rollout.gif"/>
</p>

The design of StormCast relies on two neural networks:
1. A regression model, which provides a deterministic estimate of the next HRRR timestep given the previous timestep's HRRR and background ERA5 states
2. A diffusion model, which receives the previous HRRR timestep along with the regression model's estimate and produces a correction to that estimate, yielding a final high-quality prediction of the next high-resolution atmospheric state.

Much like other data-driven weather models, StormCast can make longer forecasts (more than one timestep) during inference by feeding its predictions back into the model as input for the next step (autoregressive rollout). The regression and diffusion components are trained separately (with the diffusion model training requiring a regression model as a prerequisite), then coupled together in inference. Note that in the above description we specifically name HRRR and ERA5 as the regional high-resolution and global coarse-resolution data sources/targets, respectively, but the StormCast setting should generalize to any regional/global coupling of interest.
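The coupling between the two networks during rollout can be sketched as follows (a minimal illustration; the function and variable names are hypothetical, not the actual Modulus API):

```python
# Minimal sketch of StormCast's two-network prediction step and
# autoregressive rollout. All names here are hypothetical.

def stormcast_step(hrrr_prev, era5_background, regression_net, diffusion_net, sampler):
    # 1. Deterministic estimate of the next HRRR state
    mean_estimate = regression_net(hrrr_prev, era5_background)
    # 2. Diffusion model samples a correction, conditioned on the previous
    #    HRRR state and the regression estimate
    correction = sampler(diffusion_net, condition=(hrrr_prev, mean_estimate))
    return mean_estimate + correction

def rollout(hrrr_init, era5_sequence, regression_net, diffusion_net, sampler, n_steps=12):
    # Feed each prediction back in as the input for the next step
    state, forecast = hrrr_init, []
    for era5_background in era5_sequence[:n_steps]:
        state = stormcast_step(state, era5_background, regression_net, diffusion_net, sampler)
        forecast.append(state)
    return forecast
```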



## Getting started

### Preliminaries
Start by installing Modulus (if not already installed) and copying this folder (`examples/generative/stormcast`) to a system with a GPU available. Also, prepare a combined HRRR/ERA5 dataset in the form specified in `utils/data_loader_hrrr_era5.py` (**Note: subsequent versions of this example will include more detailed dataset preparation instructions**).
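For reference, a typical setup might look like this (the PyPI package name is an assumption; consult the Modulus install documentation for your environment):

```bash
# Package name assumed; see the Modulus documentation for install options
pip install nvidia-modulus
# Copy the example to a working area on a GPU machine
cp -r examples/generative/stormcast /path/to/workdir
```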

### Configuration basics

StormCast training is handled by `train.py`, configured using [hydra](https://hydra.cc/docs/intro/) based on the contents of the `config` directory. Hydra allows for YAML-based, modular, and hierarchical configuration management, and supports command-line overrides for quick testing and experimentation. The `config` directory includes the following subdirectories (a sketch of the layout follows the list below):
- `dataset`: specifies the resolution, number of variables, and other parameters of the dataset
- `model`: specifies the model type and model-specific hyperparameters
- `sampler`: specifies hyperparameters used in the sampling process for diffusion models
- `training`: specifies training-specific hyperparameters and settings like checkpoint/log frequency and where to save training outputs
- `inference`: specifies inference-specific settings, like which initial condition to run, which model checkpoints to use, etc.
- `hydra`: specifies basic hydra settings, like where to store outputs (based on the training or inference outputs directories)
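A sketch of that layout (file names inferred from the default configs referenced elsewhere in this PR, so treat it as illustrative):

```
config/
├── regression.yaml             # top-level config: train the regression model
├── diffusion.yaml              # top-level config: train the diffusion model
├── stormcast_inference.yaml    # top-level config: run inference
├── dataset/hrrr_era5.yaml
├── model/stormcast.yaml
├── sampler/edm_deterministic.yaml
├── training/default.yaml
├── inference/
└── hydra/default.yaml
```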

Also in the `config` directory are several top-level configs which show how to train a `regression` model or `diffusion` model, and run inference (`stormcast_inference`). One can select any of these by specifying it as the config name at the command line (e.g., `--config-name=regression`); optionally, one can also override any specific items of interest via command-line arguments, e.g.:
```bash
python train.py --config-name regression training.batch_size=4
```

More extensive configuration changes can be made by creating a new top-level configuration file similar to `regression` or `diffusion`. See `diffusion.yaml` for an example of how to specify a top-level config that uses the default configuration settings with additional custom modifications added on top.
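For instance, a hypothetical custom top-level config (say, `config/my_diffusion.yaml`) that raises the batch size might look like the following, mirroring the structure of `diffusion.yaml` shown later in this PR:

```yaml
# my_diffusion.yaml -- hypothetical example, modeled on diffusion.yaml
defaults:
  - dataset/hrrr_era5
  - model/stormcast
  - training/default
  - sampler/edm_deterministic
  - hydra/default
  - _self_

model:
  use_regression_net: True
  regression_weights: "stormcast_checkpoints/regression/StormCastUNet.0.0.mdlus"

training:
  loss: 'edm'
  batch_size: 4  # custom override on top of the defaults
```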
Note that any diffusion model you train will need a pretrained regression model, so there are two config items that must be defined to train a diffusion model:
1. `model.use_regression_net = True`
2. `model.regression_weights` set to the path of a Modulus (`.mdlus`) checkpoint with model weights for the regression model. These checkpoints are saved in the checkpoints directory during training.

Once again, the reference `diffusion.yaml` top-level config shows an example of how to specify these settings.

At runtime, hydra will parse the config subdirectory and command-line overrides into a runtime configuration object `cfg`, which has all settings accessible via both attribute and dictionary-like interfaces. For example, the total training batch size can be accessed either as `cfg.training.batch_size` or `cfg['training']['batch_size']`.
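A minimal sketch of how `train.py` might receive this object, using standard hydra conventions (the decorator arguments are illustrative):

```python
import hydra
from omegaconf import DictConfig


@hydra.main(version_base=None, config_path="config", config_name="regression")
def main(cfg: DictConfig) -> None:
    # Attribute-style and dictionary-style access are equivalent
    assert cfg.training.batch_size == cfg["training"]["batch_size"]
    print(f"Training batch size: {cfg.training.batch_size}")


if __name__ == "__main__":
    main()
```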

### Training the regression model
To train the StormCast regression model, we simply specify the example `regression` config and an optional name for the training experiment. On a single GPU machine, for example, run:
```bash
python train.py --config-name regression training.experiment_name=regression
```

This will initialize the training experiment and launch the main training loop, which is defined in `utils/trainer.py`. Outputs (training logs, checkpoints, etc.) will be saved to a directory specified by the following `training` config items:
```yaml
training.outdir: 'rundir' # Root path under which to save training outputs
training.experiment_name: 'stormcast' # Name for the training experiment
training.run_id: '0' # Unique ID to use for this training run
training.rundir: ./${training.outdir}/${training.experiment_name}/${training.run_id} # Path where experiment outputs will be saved
```

The `training.run_id` setting can be used to distinguish between different runs of the same configuration. The final training output directory is constructed by composing the `training.outdir` root path (default `rundir`), the `training.experiment_name`, and the `training.run_id`.
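With the defaults above, two runs of the same experiment would land in a layout roughly like this (illustrative; only the checkpoints subdirectory is explicitly documented by this example):

```
rundir/                    # training.outdir
└── stormcast/             # training.experiment_name
    ├── 0/                 # training.run_id
    │   └── checkpoints/   # .mdlus checkpoints saved during training
    └── 1/
```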

### Training the diffusion model

The method for launching a diffusion model training looks almost identical; we just change the configuration name. However, since we need a pre-trained regression model for diffusion training, the specified config must include the settings mentioned above in [Configuration basics](#configuration-basics) to provide network weights for the regression model. With that, launching diffusion training looks something like:
```bash
python train.py --config-name diffusion training.experiment_name=diffusion
```

Note that the full training pipeline for StormCast is fairly lengthy, requiring about 120 hours on 64 NVIDIA H100 GPUs. However, more lightweight trainings can still produce decent models if the diffusion model is not trained for as long.

Both regression and diffusion training can be easily distributed with data parallelism via `torchrun` or other launchers (e.g., SLURM `srun`). One just needs to ensure the configuration being run has a large enough batch size to be distributed over the number of available GPUs/processes. The example `regression` and `diffusion` configs use a batch size of 1 for simplicity, but new configs can easily be added [as described above](#configuration-basics). For example, distributed training over 8 GPUs on one node would look something like:
```bash
torchrun --standalone --nnodes=1 --nproc_per_node=8 train.py --config-name <your_distributed_training_config>
```
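On a SLURM cluster, an equivalent launch might look like the sketch below; the resource flags are placeholders for your cluster's setup, and it assumes the usual SLURM environment variables that Modulus's `DistributedManager` can detect:

```bash
# Hypothetical 2-node x 8-GPU launch; adapt flags to your cluster
srun --nodes=2 --ntasks-per-node=8 --gpus-per-node=8 \
    python train.py --config-name <your_distributed_training_config>
```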


### Inference

A simple demonstrative inference script is given in `inference.py`, which is also configured using hydra, in a manner similar to training. The reference `stormcast_inference` config shows an example inference config, which looks largely the same as a training config, except that the output directory is now controlled by the `inference` settings rather than the `training` ones:
```yaml
inference.outdir: 'rundir' # Root path under which to save inference outputs
inference.experiment_name: 'stormcast-inference' # Name for the inference experiment being run
inference.run_id: '0' # Unique identifier for the inference run
inference.rundir: ./${inference.outdir}/${inference.experiment_name}/${inference.run_id} # Path where experiment outputs will be saved
```

To run inference, simply do:

```bash
python inference.py --config-name <your_inference_config>
```

This will load the regression and diffusion models from directories specified by `inference.regression_checkpoint` and `inference.diffusion_checkpoint`, respectively; each of these should be a path to a Modulus checkpoint (`.mdlus` file) from your training runs. The `inference.py` script will use these models to run a forecast and save the outputs as a `zarr` file, along with a few plots saved as `png` files. We also recommend bringing your checkpoints to [earth2studio](https://github.com/NVIDIA/earth2studio)
for further analysis and visualization.
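Once a forecast has been saved, the `zarr` output can be inspected with standard tools; for example (the file name and location within the run directory are assumptions):

```python
import xarray as xr

# Hypothetical output path; adjust to your inference.rundir and file name
ds = xr.open_zarr("rundir/stormcast-inference/0/forecast.zarr")
print(ds)  # dimensions, coordinates, and forecast variables
```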


## Dataset
@@ -133,7 +148,7 @@ A custom dataset object is defined in `utils/data_loader_hrrr_era5.py`, which loads

## Logging

These scripts use Weights & Biases for experiment tracking, which can be enabled by setting `training.log_to_wandb=True`. Academic accounts are free to create at [wandb.ai](https://wandb.ai/).
Once you have an account set up, you can adjust `entity` and `project` in `train.py` to the appropriate names for your `wandb` workspace.


70 changes: 70 additions & 0 deletions examples/generative/stormcast/config/dataset/hrrr_era5.yaml
@@ -0,0 +1,70 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Main dataset
location: 'data' # Path to the dataset
conus_dataset_name: 'hrrr_v3' # Version name for the dataset
hrrr_stats: 'stats_v3_2019_2021' # Summary stats name for the dataset

# Domain
hrrr_img_size: [512, 640] # Image dimensions of the HRRR region of interest
boundary_padding_pixels: 0 # set this to 0 for no padding of ERA5 beyond HRRR domain,
# 32 for 32 pixels of padding in each direction, etc.

# Temporal selection
dt: 1 # Timestep between samples (in multiples of the base HRRR 1hr timestep)
train_years: [2018, 2019, 2020, 2021] # Years to use for training
valid_years: [2022] # Years to use for validation

# Variable selection
invariants: ["lsm", "orog"] # Invariant quantities to include
input_channels: 'all' #'all' or list of channels to condition on
diffusion_channels: "all" #'all' or list of channels to condition on
exclude_channels: # Dataset channels to exclude from inputs/predictions
- u35
- u40
- v35
- v40
- t35
- t40
- q35
- q40
- w1
- w2
- w3
- w4
- w5
- w6
- w7
- w8
- w9
- w10
- w11
- w13
- w15
- w20
- w25
- w30
- w35
- w40
- p25
- p30
- p35
- p40
- z35
- z40
- tcwv
- vil
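Any of the fields above can also be overridden from the command line via hydra rather than editing the file; for example (the path is a placeholder):

```bash
python train.py --config-name regression dataset.location=/path/to/hrrr_era5
```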
45 changes: 45 additions & 0 deletions examples/generative/stormcast/config/diffusion.yaml
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Defaults
defaults:

# Dataset
- dataset/hrrr_era5

# Model
- model/stormcast

# Training
- training/default

# Sampler
- sampler/edm_deterministic

# Hydra
- hydra/default

- _self_

# Diffusion model specific changes
model:
use_regression_net: True
regression_weights: "stormcast_checkpoints/regression/StormCastUNet.0.0.mdlus"
previous_step_conditioning: True
spatial_pos_embed: True

training:
loss: 'edm'
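The model-specific settings here could equally be supplied as command-line overrides instead of a new file; for example (the checkpoint path is a placeholder for one produced by your regression run):

```bash
python train.py --config-name diffusion \
    model.regression_weights=rundir/regression/0/checkpoints/StormCastUNet.0.0.mdlus
```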
18 changes: 18 additions & 0 deletions examples/generative/stormcast/config/hydra/default.yaml
@@ -0,0 +1,18 @@
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
# SPDX-FileCopyrightText: All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

run:
dir: ${training.outdir}/${training.experiment_name}/${training.run_id}
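Since hydra's run directory interpolates the same `training` settings, overriding any component moves the hydra outputs along with the training outputs; for example (the run ID value is arbitrary):

```bash
# Logs and outputs both land under rundir/stormcast/warmup
python train.py --config-name regression training.run_id=warmup
```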