Skip to content

Commit

Permalink
[Feature] Enable activation checkpoint offloading (#722)
Browse files Browse the repository at this point in the history
* Add act offloading option

* Add changelog

* Turn off offloading when ac is not enabled

* formatting

---------

Co-authored-by: Mohammad Amin Nabian <[email protected]>
  • Loading branch information
chang-l and mnabian authored Nov 25, 2024
1 parent 7f739f7 commit a5d3b5b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 9 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- The XAeroNet model.
- Incorporated CorrDiff-GEFS-HRRR model into CorrDiff, with lead-time aware SongUNet and
cross entropy loss.
- Option to offload checkpoints to further reduce memory usage
- Added StormCast model training and simple inference to examples

### Changed
Expand Down
46 changes: 37 additions & 9 deletions modulus/models/meshgraphnet/meshgraphnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext

import torch
import torch.nn as nn
from torch import Tensor
Expand Down Expand Up @@ -98,6 +100,8 @@ class MeshGraphNet(Module):
Whether to replace concat+MLP with MLP+idx+sum
num_processor_checkpoint_segments: int, optional
Number of processor segments for gradient checkpointing, by default 0 (checkpointing disabled)
checkpoint_offloading: bool, optional
Whether to offload the checkpointing to the CPU, by default False
Example
-------
Expand Down Expand Up @@ -138,6 +142,7 @@ def __init__(
aggregation: str = "sum",
do_concat_trick: bool = False,
num_processor_checkpoint_segments: int = 0,
checkpoint_offloading: bool = False,
recompute_activation: bool = False,
):
super().__init__(meta=MetaData())
Expand Down Expand Up @@ -184,6 +189,7 @@ def __init__(
activation_fn=activation_fn,
do_concat_trick=do_concat_trick,
num_processor_checkpoint_segments=num_processor_checkpoint_segments,
checkpoint_offloading=checkpoint_offloading,
)

def forward(
Expand Down Expand Up @@ -215,10 +221,14 @@ def __init__(
activation_fn: nn.Module = nn.ReLU(),
do_concat_trick: bool = False,
num_processor_checkpoint_segments: int = 0,
checkpoint_offloading: bool = False,
):
super().__init__()
self.processor_size = processor_size
self.num_processor_checkpoint_segments = num_processor_checkpoint_segments
self.checkpoint_offloading = (
checkpoint_offloading if (num_processor_checkpoint_segments > 0) else False
)

edge_block_invars = (
input_dim_node,
Expand Down Expand Up @@ -254,6 +264,23 @@ def __init__(
self.processor_layers = nn.ModuleList(layers)
self.num_processor_layers = len(self.processor_layers)
self.set_checkpoint_segments(self.num_processor_checkpoint_segments)
self.set_checkpoint_offload_ctx(self.checkpoint_offloading)

def set_checkpoint_offload_ctx(self, enabled: bool):
    """
    Select the context manager stored in ``self.checkpoint_offload_ctx``.

    Parameters
    ----------
    enabled : bool
        When truthy, the stored context is
        ``torch.autograd.graph.save_on_cpu(pin_memory=True)``; otherwise
        it is a do-nothing ``nullcontext()``.
    """
    # Build the context once here and keep it on the instance for later use.
    offload_ctx = (
        torch.autograd.graph.save_on_cpu(pin_memory=True)
        if enabled
        else nullcontext()
    )
    self.checkpoint_offload_ctx = offload_ctx

def set_checkpoint_segments(self, checkpoint_segments: int):
"""
Expand Down Expand Up @@ -326,14 +353,15 @@ def forward(
edge_features: Tensor,
graph: Union[DGLGraph, List[DGLGraph], CuGraphCSC],
) -> Tensor:
for segment_start, segment_end in self.checkpoint_segments:
edge_features, node_features = self.checkpoint_fn(
self.run_function(segment_start, segment_end),
node_features,
edge_features,
graph,
use_reentrant=False,
preserve_rng_state=False,
)
with self.checkpoint_offload_ctx:
for segment_start, segment_end in self.checkpoint_segments:
edge_features, node_features = self.checkpoint_fn(
self.run_function(segment_start, segment_end),
node_features,
edge_features,
graph,
use_reentrant=False,
preserve_rng_state=False,
)

return node_features

0 comments on commit a5d3b5b

Please sign in to comment.