🚀[FEA]: Add a wrapper class for shared tensors in model parallel implementations #235

Closed
akshaysubr opened this issue Nov 16, 2023 · 3 comments · Fixed by #243

Labels: ? - Needs Triage · distributed · enhancement

Comments

@akshaysubr
Collaborator

Is this a new feature, an improvement, or a change to existing functionality?

New Feature

How would you describe the priority of this feature request

Critical (currently preventing usage)

Please provide a clear description of the problem you would like to solve.

#153 recommends using DDP with a "data_parallel" process group that is orthogonal to the "model_parallel" process group across which the weights of a model are distributed. This performs the gradient all-reduce only along the "data_parallel" process group, which is the correct behavior for distributed weights.
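
For reference, a minimal sketch of that setup with plain torch.distributed and DDP; the group layout, model_parallel_size, and the model variable are illustrative assumptions rather than Modulus APIs:

import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# Assumes the script is launched with torchrun so that init can read the
# rank/world size from the environment
dist.init_process_group(backend="nccl")
world_size = dist.get_world_size()
rank = dist.get_rank()

model_parallel_size = 2  # illustrative; world_size must be divisible by it
data_parallel_size = world_size // model_parallel_size

# Build orthogonal subgroups: every process calls new_group for all groups
# and keeps only the ones it belongs to
data_parallel_group = None
model_parallel_group = None
for i in range(model_parallel_size):
    ranks = list(range(i, world_size, model_parallel_size))
    group = dist.new_group(ranks)
    if rank in ranks:
        data_parallel_group = group
for i in range(data_parallel_size):
    ranks = list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
    group = dist.new_group(ranks)
    if rank in ranks:
        model_parallel_group = group

# model is assumed to be the (model-parallel) module built elsewhere;
# DDP then all-reduces gradients only across the data_parallel group
ddp_model = DDP(model, process_group=data_parallel_group)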

However, in many cases, even with model parallelism, some weights of a model are shared across the "model_parallel" group. For these cases, the gradient reduction should also happen along the "model_parallel" process group. This can be done in two ways:

  1. Registering a custom communication hook that performs the additional reduction along the "model_parallel" group for parameters that are shared. See the SFNO implementation for specifics; a rough sketch of such a hook is shown after this list.
  2. Adding custom ops to the model implementation that are no-ops in the forward pass but do the gradient reduction across the shared process group in the backward pass.
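
A hedged sketch of option 1 using DDP's register_comm_hook API (this is not the SFNO implementation; the state tuple and the shared_params set are assumptions made for illustration):

import torch.distributed as dist

def shared_weight_comm_hook(state, bucket):
    # state is an assumed container holding the data-parallel group, the
    # model-parallel group, and the set of parameters that are replicated
    # (shared) across the model-parallel group
    dp_group, mp_group, shared_params = state

    # Default DDP behavior: average the flattened gradient bucket across the
    # data_parallel group
    buf = bucket.buffer()
    buf.div_(dist.get_world_size(group=dp_group))
    fut = dist.all_reduce(buf, group=dp_group, async_op=True).get_future()

    def reduce_shared(fut):
        # Additionally reduce the gradients of shared parameters across the
        # model_parallel group; distributed (non-shared) weights are untouched.
        # Depending on convention, a further division by the model-parallel
        # world size may be needed here.
        for grad, param in zip(bucket.gradients(), bucket.parameters()):
            if param in shared_params:
                dist.all_reduce(grad, group=mp_group)
        return bucket.buffer()

    return fut.then(reduce_shared)

# Registration, assuming ddp_model was built with process_group=data_parallel_group:
# ddp_model.register_comm_hook(
#     state=(data_parallel_group, model_parallel_group, shared_params),
#     hook=shared_weight_comm_hook,
# )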

In both cases, the model implementation needs some way to indicate which tensors are shared and across which groups. Ideally, whatever this mechanism ends up being would also automatically set up the gradient hooks or the backward reduction op. One proposed solution is a class wrapper that can also be used as a decorator. Something like

m = SharedModel(nn.Linear(20, 30))

and use it as usual

output = m(input)

A rough sketch of the SharedModel wrapper class implementation:

from typing import Union

import torch
from torch import nn

import modulus
from modulus.distributed import DistributedManager
from modulus.distributed.utils import _reduce


class SharedModelBackwardReduction(torch.autograd.Function):
    """No-op in the forward pass; gradient reduction across the shared group in the backward pass."""

    @staticmethod
    def symbolic(graph, input_, group_):
        return input_

    @staticmethod
    def forward(ctx, input_, group_):
        ctx.group = group_
        return input_

    @staticmethod
    def backward(ctx, grad_output):
        # Reduce the gradient across the shared process group; no gradient is
        # returned for the group argument itself
        return (
            _reduce(grad_output, group=DistributedManager().group(ctx.group)),
            None,
        )


class SharedModel(nn.Module):
    def __init__(self, model: Union[nn.Module, modulus.Module], group=None):
        super().__init__()
        self.group = group
        self.model = model

    def forward(self, *args, **kwargs):
        # Set up forward no-op for backward reduction. Note: in a full
        # implementation, the wrapped tensors returned by apply() would have
        # to be the parameters actually used inside self.model's forward;
        # this loop is only a sketch of the intent.
        for param in self.model.parameters():
            SharedModelBackwardReduction.apply(param, self.group)

        return self.model(*args, **kwargs)

Describe any alternatives you have considered

No response

@akshaysubr added the enhancement, ? - Needs Triage, and distributed labels on Nov 16, 2023
@stadlmax
Collaborator

Given that option 1 is more cumbersome, I would be in favor of the more transparent solution proposed in option 2, as discussed offline.

With the current approach, we have gone in the direction of "distributed-first" anyway: distributed (non-shared) weights are explicitly treated as such. This is unlike, e.g., the approach PyTorch takes in its experimental tensor parallelism API (https://pytorch.org/docs/stable/distributed.tensor.parallel.html). Hence, suggestion 2 should also fit more easily into the current paradigm in Modulus.

@bonevbs

bonevbs commented Dec 11, 2023

I would like to point out that, either way, the user will need to explicitly define not just the shared weights but also the distributed ops. While this does not require custom reduction logic, it would certainly require custom forward logic.

With this being the case, I prefer a verbose approach that allows annotating tensors directly, with the annotations then handled automatically by the reduction hooks. This is very much in line with what PyTorch does (the user defines the forward ops, and the backward is handled automatically). It has the added benefit of making the code more readable, as the forward pass only contains logic for the (already complicated, distributed) forward computation.

@stadlmax
Collaborator

Feel free to have a look at #243.

For now, this would rely on the post_accumulate_grad_hook of shared parameters. The other option would be a custom comm_hook, which would then handle the tensor annotations.
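
For illustration only, a hedged sketch of what such a hook-based mechanism could look like; the mark_shared helper and the shared_process_group attribute are hypothetical and not the API used in #243, and register_post_accumulate_grad_hook requires a recent PyTorch release:

import torch
import torch.distributed as dist


def mark_shared(param: torch.nn.Parameter, group) -> torch.nn.Parameter:
    # Hypothetical annotation: tag a parameter as shared across a process group
    param.shared_process_group = group
    return param


def register_shared_grad_hooks(module: torch.nn.Module):
    # After gradients are accumulated, all-reduce the gradient of every
    # annotated parameter across its shared (model-parallel) process group
    for param in module.parameters():
        group = getattr(param, "shared_process_group", None)
        if group is not None:
            def hook(p, group=group):
                dist.all_reduce(p.grad, group=group)
            param.register_post_accumulate_grad_hook(hook)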
