🚀[FEA]: Add a wrapper class for shared tensors in model parallel implementations #235
Labels

- ? - Needs Triage (Need team to review and classify)
- distributed (Distributed and model parallel tools)
- enhancement (New feature or request)

Is this a new feature, an improvement, or a change to existing functionality?

New Feature

How would you describe the priority of this feature request?

Critical (currently preventing usage)
Please provide a clear description of the problem you would like to solve.
#153 recommends using DDP with a "data_parallel" process group that is orthogonal to the "model_parallel" process group across which the weights of a model are distributed. This will perform the gradient all-reduce only along the "data_parallel" process group, which is the right behavior for distributed weights.

However, in many cases, even with model parallelism, some weights of a model are shared across the "model_parallel" group. For these cases, the gradient reduction should also happen along the "model_parallel" process group. This can be done in two ways: by registering gradient hooks on the shared parameters, or by adding an explicit reduction op in the backward pass; in either case the gradients are all-reduced along the "model_parallel" group for the parameters that are shared. See the SFNO implementation for specifics.

In both cases, the model implementation somehow needs to indicate which tensors are shared and across which groups. It would be ideal if whatever this mechanism ends up being also automatically sets up the gradient hooks or the backward reduction op. One proposed solution is to create a class wrapper that can also be used as a decorator. Something like:
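(The `SharedModel` name, its arguments, and the `model_parallel_group` handle below are placeholders for the proposed API, not anything that exists today.)

```python
import torch
import torch.nn as nn

# model_parallel_group: the torch.distributed process group spanning the
# ranks that share this model's weights (placeholder; in practice this would
# come from the distributed utilities set up as recommended in #153).
@SharedModel(
    shared_params=["pos_embed", "norm.weight", "norm.bias"],
    group=lambda: model_parallel_group,
)
class MyParallelNet(nn.Module):
    def __init__(self, dim):
        super().__init__()
        # replicated (shared) across the "model_parallel" group
        self.pos_embed = nn.Parameter(torch.zeros(dim))
        self.norm = nn.LayerNorm(dim)
        # imagine the weights of this layer are sharded across "model_parallel"
        self.linear = nn.Linear(dim, dim)

    def forward(self, x):
        return self.linear(self.norm(x + self.pos_embed))
```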
and then use the model as usual:
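(Again only a sketch; `data_parallel_group` is assumed to be the orthogonal process group recommended in #153.)

```python
import torch
from torch.nn.parallel import DistributedDataParallel

model = MyParallelNet(dim=256).cuda()

# DDP reduces gradients only over the orthogonal "data_parallel" group, as
# recommended in #153; the hooks registered by SharedModel additionally
# all-reduce the shared-parameter gradients over "model_parallel".
model = DistributedDataParallel(model, process_group=data_parallel_group)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 256, device="cuda")
loss = model(x).sum()
loss.backward()
optimizer.step()
```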
A rough sketch of the `SharedModel` wrapper class implementation:
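(Minimal sketch: it registers per-parameter gradient hooks. How the group is obtained and whether the reduction should be a sum or an average are left open.)

```python
import torch.distributed as dist


class SharedModel:
    """Class wrapper / decorator that marks parameters of an nn.Module
    subclass as shared across a process group and registers gradient hooks
    that all-reduce their gradients along that group during backward."""

    def __init__(self, shared_params, group):
        # shared_params: parameter names as reported by named_parameters()
        # group: a torch.distributed ProcessGroup, or a zero-argument callable
        #        returning one, so it can be resolved lazily at model creation
        self.shared_params = set(shared_params)
        self.group = group

    def __call__(self, module_cls):
        shared_params = self.shared_params
        group = self.group

        class SharedModule(module_cls):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                pg = group() if callable(group) else group
                for name, param in self.named_parameters():
                    if name in shared_params and param.requires_grad:
                        param.register_hook(_make_allreduce_hook(pg))

        SharedModule.__name__ = module_cls.__name__
        return SharedModule


def _make_allreduce_hook(pg):
    def hook(grad):
        # sum the gradient contributions from every rank in the group;
        # depending on how the shared parameter enters the computation,
        # an average may be the appropriate reduction instead
        grad = grad.clone()
        dist.all_reduce(grad, op=dist.ReduceOp.SUM, group=pg)
        return grad

    return hook
```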
Describe any alternatives you have considered

No response