
[FEA] hook-based support for distributed but shared parameter #243

Merged
merged 9 commits on Jan 4, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

 - Distributed process group configuration mechanism.
 - DistributedManager utility to instantiate process groups based on the
   process group config.
+- Helper functions to facilitate distributed training with shared parameters.
 - Brain anomaly detection example.
 - Updated Frechet Inception Distance to use Wasserstein 2-norm with improved
2 changes: 1 addition & 1 deletion modulus/distributed/__init__.py
@@ -16,4 +16,4 @@
from .autograd import all_gather_v, gather_v, indexed_all_to_all_v, scatter_v
from .config import ProcessGroupConfig, ProcessGroupNode
from .manager import DistributedManager
-from .utils import gather_loss
+from .utils import mark_module_as_shared, reduce_loss, unmark_module_as_shared
4 changes: 2 additions & 2 deletions modulus/distributed/autograd.py
@@ -19,8 +19,8 @@
import torch.distributed as dist

from .utils import (
+    all_gather_v_bwd_wrapper,
     all_gather_v_wrapper,
-    all_reduce_v_wrapper,
gather_v_wrapper,
indexed_all_to_all_v_wrapper,
indexed_all_to_all_v_wrapper_bwd,
@@ -68,7 +68,7 @@ def backward(ctx, grad_output: torch.Tensor):  # pragma: no cover
needs_grad = ctx.needs_input_grad[0]

if needs_grad:
-            grad_tensor = all_reduce_v_wrapper(
+            grad_tensor = all_gather_v_bwd_wrapper(
grad_output,
ctx.sizes,
dim=ctx.dim,
116 changes: 99 additions & 17 deletions modulus/distributed/utils.py
@@ -17,6 +17,7 @@

import torch
import torch.distributed as dist
+import torch.nn as nn
import torch.nn.functional as F

from .manager import DistributedManager
@@ -89,17 +90,17 @@ def split_tensor_along_dim(tensor, dim, num_chunks):


@torch.no_grad()
-def gather_loss(loss: float, dst_rank: int = 0, mean: bool = True):  # pragma: no cover
-    """Gathers loss from all processes to one for logging
+def reduce_loss(loss: float, dst_rank: int = 0, mean: bool = True):  # pragma: no cover
+    """Reduces loss from all processes to the destination rank for logging.

     Parameters
     ----------
     loss : float
         loss value
     dst_rank : int, Optional
-        destination rank to gather to, by default 0
+        destination rank to reduce to, by default 0.
     mean : bool, Optional
-        Calculate the mean of the losses gathered, by default True
+        Calculate the mean of the losses gathered, by default True.

     Raises
     ------
@@ -108,29 +109,21 @@ def gather_loss(loss: float, dst_rank: int = 0, mean: bool = True):  # pragma: no cover
     """
     if not DistributedManager.is_initialized():
         raise Exception(
-            "Distributed manager should be initialized when using gather_loss"
+            "Distributed manager should be initialized when using reduce_loss"
         )

     distmng = DistributedManager()
-    loss = torch.Tensor([loss])
+    loss = torch.Tensor([loss]).to(distmng.device)

     # For serial runs, just return the current loss!
     if distmng.world_size == 1:
         return float(loss)

-    # Gather using PyTorch distributed function
-    gather_list = None
-    if distmng.rank == dst_rank:
-        gather_list = [
-            torch.zeros(1).to(distmng.device) for i in range(distmng.world_size)
-        ]
-    dist.gather(loss.to(distmng.device), gather_list, dst_rank)
+    op = torch.distributed.ReduceOp.SUM if not mean else torch.distributed.ReduceOp.AVG
+    torch.distributed.reduce(loss, dst_rank, op, group=None)

     # Return loss if dst_rank, None otherwise
     if distmng.rank == dst_rank:
-        loss = torch.sum(torch.cat(gather_list))
-        if mean:
-            loss = loss / distmng.world_size
         return float(loss.cpu())
     else:
         return None
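For reference (not part of the diff), a minimal usage sketch of the renamed helper; it assumes the usual `DistributedManager.initialize()` entry point and a launch via `torchrun` that sets the standard rank/world-size environment variables, and uses a placeholder loss value.

```python
import torch

from modulus.distributed import DistributedManager, reduce_loss

# assumes torchrun (or similar) has set RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT
DistributedManager.initialize()
dm = DistributedManager()

# placeholder per-rank loss value
local_loss = float(torch.rand(1))

# average the per-rank losses onto rank 0 for logging; other ranks receive None
# (mean=True maps to ReduceOp.AVG, which is typically only available with NCCL)
avg_loss = reduce_loss(local_loss, dst_rank=0, mean=True)
if dm.rank == 0:
    print(f"mean loss over {dm.world_size} ranks: {avg_loss:.4f}")
```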
@@ -286,7 +279,7 @@ def all_gather_v_wrapper(
return output


-def all_reduce_v_wrapper(
+def all_gather_v_bwd_wrapper(
tensor: torch.Tensor,
sizes: List[int],
dim: int = 0,
@@ -670,3 +663,92 @@ def indexed_all_to_all_v_wrapper_bwd(
out = out.to(tensor.dtype)

return out


def mark_module_as_shared(
module: nn.Module,
process_group: Optional[str],
recurse: bool = True,
use_fp32_reduction: bool = True,
) -> nn.Module:
"""
Helper function to mark parameters of a module as being shared
across ranks by attaching gradient hooks to the corresponding tensors.

Parameters
----------
module : nn.Module
PyTorch module which is to be marked as having shared parameters.
process_group : str | None
str indicating process_group which contains ranks across which
the module's parameters are shared. If passed as None, will default
to the world group.
recurse : bool, default=True
Flag indicating whether the module's parameters are traversed in
a recursive fashion, i.e. whether sub-modules are also considered
as having shared parameters.
use_fp32_reduction : bool, default=True
Flag indicating whether the reduction for accumulating gradients
will be done in FP32 or the native datatype.
"""

group = DistributedManager().group(process_group)
handle_key = "_shared_weight_dist_hook"

def hook(grad: torch.Tensor) -> torch.Tensor:
# the documentation states that
# "The hook should not modify its argument, but it can optionally return a new gradient
# which will be used in place of grad."
# as all_reduce is an in-place operation, need to copy gradient
grad = _reduce(grad.clone(), group=group, use_fp32=use_fp32_reduction)
return grad

def hook_post_accum(param: torch.Tensor) -> None:
# the documentation states that
# "Note that, unlike other autograd hooks, this hook operates on the tensor that requires grad
# and not the grad itself. The hook can in-place modify and access its Tensor argument,
# including its .grad field."
param.grad = _reduce(param.grad, group=group, use_fp32=use_fp32_reduction)

for name, param in module.named_parameters(recurse=recurse):
error_msg = f"Parameter {name} already marked as having shared weights, can't mark it again!"
if hasattr(param, handle_key):
raise RuntimeError(error_msg)
if torch.__version__ < (2, 1):
handle = param.register_hook(hook)
else:
handle = param.register_post_accumulate_grad_hook(hook_post_accum)
setattr(param, handle_key, handle)

return module
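As an aside (not part of the diff), a sketch of how this hook could be used; it assumes `DistributedManager` has been initialized and treats the toy module and shapes as placeholders. Passing `process_group=None` targets the world group, as the docstring states; a group name previously registered with `DistributedManager` could be passed instead.

```python
import torch
import torch.nn as nn

from modulus.distributed import DistributedManager, mark_module_as_shared

DistributedManager.initialize()
dm = DistributedManager()

# toy module whose parameters are kept replicated (shared) on every rank of the group
model = nn.Linear(16, 16).to(dm.device)

# None marks the parameters as shared across the world group; on PyTorch >= 2.1 this
# registers post-accumulate-grad hooks, otherwise plain gradient hooks
mark_module_as_shared(model, process_group=None)

# every backward pass now all-reduces the parameter gradients over that group, so the
# replicated weight copies stay consistent after the optimizer step
out = model(torch.randn(4, 16, device=dm.device))
out.sum().backward()
```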


def unmark_module_as_shared(
module: nn.Module,
recurse: bool = True,
) -> nn.Module:
"""
Helper function to unmark parameters of a module as being shared
across ranks by removing attached gradient hooks.

Parameters
----------
module : nn.Module
PyTorch module which is to be unmarked as having shared parameters.
recurse : bool, default=True
Flag indicating whether the module's parameters are traversed in
a recursive fashion, i.e. whether sub-modules are also considered
as having shared parameters.
"""
handle_key = "_shared_weight_dist_hook"
for name, param in module.named_parameters(recurse=recurse):
error_msg = (
f"Parameter {name} NOT marked as having shared weights, can't unmark it!"
)
if not hasattr(param, handle_key):
raise RuntimeError(error_msg)
handle = getattr(param, handle_key)
handle.remove()
delattr(param, handle_key)

return module
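Continuing the sketch above, the hooks can be removed again once they are no longer needed, for example before saving or exporting the module:

```python
from modulus.distributed import unmark_module_as_shared

# removes the registered gradient hooks and the bookkeeping attribute from every
# parameter (recursing into sub-modules by default)
unmark_module_as_shared(model, recurse=True)
```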