Add a way to configure the required process group structure for a model #153

Closed

akshaysubr
Collaborator

Modulus Pull Request

Description

This PR adds a ProcessGroupConfig class to define the configuration of a model's parallel process group structure as a tree. Each node of the tree is of type ProcessGroupNode.

Once the process group config structure (i.e., the tree structure) is set, it is sufficient to set only the size of each leaf process group. The size of every parent group is then computed automatically as the product of the leaf group sizes in that parent node's sub-tree.
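To make the size rule concrete, here is a minimal sketch of the bottom-up product computation. The `Node` class and the `leaf_names` and `set_sizes` helpers are hypothetical stand-ins written for illustration; they are not the `ProcessGroupNode`/`ProcessGroupConfig` implementation from this PR, which is built on treelib.

```python
from dataclasses import dataclass, field
from math import prod
from typing import Dict, List, Optional


@dataclass
class Node:
    """Hypothetical stand-in for a process group node (not the PR's class)."""
    name: str
    size: Optional[int] = None
    children: List["Node"] = field(default_factory=list)


def leaf_names(node: Node) -> List[str]:
    """Return the names of all leaf groups below `node`, in insertion order."""
    if not node.children:
        return [node.name]
    names: List[str] = []
    for child in node.children:
        names.extend(leaf_names(child))
    return names


def set_sizes(node: Node, leaf_sizes: Dict[str, int]) -> int:
    """Fill in sizes bottom-up: leaves take the given size,
    parents take the product of their children's sizes."""
    if not node.children:
        node.size = leaf_sizes[node.name]
    else:
        node.size = prod(set_sizes(child, leaf_sizes) for child in node.children)
    return node.size


root = Node(
    "model_parallel",
    children=[Node("spatial_parallel"), Node("channel_parallel")],
)
set_sizes(root, {"spatial_parallel": 2, "channel_parallel": 4})
print(leaf_names(root))  # ['spatial_parallel', 'channel_parallel']
print(root.size)         # 8
```

Only the leaf sizes are supplied by the user; every interior size follows from the tree shape.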

The intent is to use this config structure as follows (which will require some more additions to DistributedManager and changes to distributed model implementations):

class MyModel(Module):
    def __init__(self):
        pass

    @staticmethod
    def get_model_parallel_info() -> ProcessGroupConfig:
        mp = ProcessGroupNode("model_parallel", orthogonal_group="data_parallel")
        config = ProcessGroupConfig(mp)
        config.add_node(ProcessGroupNode("spatial_parallel"), parent=mp)
        config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel")
        return config


pg_config = MyModel.get_model_parallel_info()
group_sizes = {"channel_parallel": 4, "spatial_parallel": 2}
pg_config.set_leaf_group_sizes(group_sizes)

print(pg_config.leaf_groups())  # ['spatial_parallel', 'channel_parallel']
print(pg_config.get_node("model_parallel").size)  # 8

DistributedManager.create_process_groups(pg_config, topology_mapping)
manager = DistributedManager()

model = MyModel()
model = DDP(model, process_group=manager.group("data_parallel"))

Closes: #142

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

This PR adds treelib as a dependency to manage the process group config tree.

@akshaysubr akshaysubr changed the title from "142 fea process group config" to "Add a way to configure the required process group structure for a model" Sep 12, 2023
Collaborator

@dallasfoster dallasfoster left a comment

The only thing that is not quite clear is what the purpose and algebra of the orthogonal_group is. There doesn't appear to be anything in this MR that sets relationships or interactions with any orthogonal group. Can I have more than 1 orthogonal group?

@akshaysubr
Collaborator Author

@dallasfoster The main purpose of an orthogonal group is to perform communication/reductions along the dimension orthogonal to a process group. For example, if you have 32 GPUs and a model parallel size of 8, then you have 4 model parallel groups that you do data parallel training over. When you reduce the gradients, you only want an all_reduce across the GPUs that share the same model parallel rank (8 independent all-reductions across 4 GPUs each). That all_reduce runs along the process group orthogonal to the model_parallel process group, which we're calling the data_parallel group. That's what the last line of the code snippet in the PR description does:

model = DDP(model, process_group=manager.group("data_parallel"))

Hope this helps explain why we need orthogonal process groups. There might be other instances where a model needs to perform reductions in a dimension that is orthogonal to a specific group. I don't know of any specific use case for this though, so there is less of a justification for having orthogonal groups for non-root groups maybe.
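The 32-GPU example above can be spelled out as rank lists. This is only a sketch: it assumes each model_parallel group occupies a contiguous block of ranks, and the variable names are chosen for illustration rather than taken from DistributedManager.

```python
world_size = 32
model_parallel_size = 8
num_replicas = world_size // model_parallel_size  # 4 model parallel groups

# Assumption: each model_parallel group is a contiguous block of ranks.
model_parallel_groups = [
    list(range(i * model_parallel_size, (i + 1) * model_parallel_size))
    for i in range(num_replicas)
]

# The orthogonal (data_parallel) groups collect the ranks that share the
# same model parallel rank across all blocks.
data_parallel_groups = [
    list(range(mp_rank, world_size, model_parallel_size))
    for mp_rank in range(model_parallel_size)
]

print(len(model_parallel_groups), len(model_parallel_groups[0]))  # 4 8
print(len(data_parallel_groups), len(data_parallel_groups[0]))    # 8 4
print(data_parallel_groups[0])                                    # [0, 8, 16, 24]
```

In an actual job, each rank list would typically be turned into a communicator with `torch.distributed.new_group(ranks=...)`, and the gradient all_reduce would run over the data_parallel list containing the local rank.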

return self.__str__()


class ProcessGroupConfig:
Collaborator

@NickGeneva NickGeneva Sep 13, 2023

I really like this, Akshay. I'm mainly curious about how the distributed manager will know which groups to make internode vs. intranode.

For example, in the example you gave there are three groups ("channel_parallel", "spatial_parallel", and "data_parallel"). How would we know to stick the first two on the same node?

Collaborator Author

Currently there isn't such a way. In your example, you'd have a model_parallel group with data_parallel being the orthogonal group. model_parallel would have channel_parallel and spatial_parallel child groups. When we create the process groups, we first create the model_parallel group which would use a contiguous set of ranks. So if the total model_parallel group size was 8, that would be on a single node. See #188 for the specifics of how this works.

But I want to emphasize that this is just the logical group config. It is up to the DistributedManager to actually instantiate this group config, and that's where this kind of topology optimization needs to come in. Right now the topology mapping is hard coded, but there is no reason we cannot add the ability for a user to specify a custom topology mapping when passing this config to the DistributedManager.

Contributor

I think the MR looks good, but Nick is right: there needs to be some notion of what is close and what is not, which basically means an order of initialization. For example, if the leaf node list is ordered, then it should be clear.
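One way to read the "ordered leaf list" suggestion is that the position of a leaf in the list fixes its rank stride, so the first leaf gets neighbouring ranks. The sketch below assumes that convention (first leaf is innermost) and treats data_parallel as the outermost dimension purely for illustration; the `leaf_strides` and `group_ranks` helpers are hypothetical and do not describe what #188 implements.

```python
from typing import Dict, List, Tuple


def leaf_strides(ordered_leaves: List[Tuple[str, int]]) -> Dict[str, int]:
    """Rank stride of each leaf group; the first leaf in the list is innermost."""
    strides: Dict[str, int] = {}
    stride = 1
    for name, size in ordered_leaves:
        strides[name] = stride
        stride *= size
    return strides


def group_ranks(rank: int, name: str, ordered_leaves: List[Tuple[str, int]]) -> List[int]:
    """Ranks that belong to the same `name` group as `rank`."""
    stride = leaf_strides(ordered_leaves)[name]
    size = dict(ordered_leaves)[name]
    coord = (rank // stride) % size  # position of `rank` inside its group
    base = rank - coord * stride     # first rank of that group
    return [base + i * stride for i in range(size)]


# Assumed ordering: closest group first.
leaves = [("channel_parallel", 4), ("spatial_parallel", 2), ("data_parallel", 4)]

print(leaf_strides(leaves))
# {'channel_parallel': 1, 'spatial_parallel': 4, 'data_parallel': 8}
print(group_ranks(5, "channel_parallel", leaves))  # [4, 5, 6, 7]
print(group_ranks(5, "spatial_parallel", leaves))  # [1, 5]
print(group_ranks(5, "data_parallel", leaves))     # [5, 13, 21, 29]
```

With 8 GPUs per node and ranks assigned node by node, this ordering keeps channel_parallel and spatial_parallel traffic inside a node and sends only data_parallel traffic across nodes.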

@NickGeneva NickGeneva added the 4 - In Review Currently Under Review label Sep 14, 2023
@dallasfoster
Collaborator

The main purpose of having an orthogonal group is to perform communication/reductions along the orthogonal dimension of a process group.

I understand the intent of the so-called "orthogonal process group", I suppose my concern/confusion is around the definition, rules, and algebra of this group as set forth in this MR because it doesn't appear to have any. I would expect some behavior or rules enforced to ensure the "orthogonality" of the passed process group to ProcessGroupNode but instead it just accepts a str, which is not entirely meaningful. Will these rules be enforced by the distributed manager? If so I would like just a sketch or outline of what they would look like.

@akshaysubr
Collaborator Author

@dallasfoster #188 adds the DistributedManager utility that creates a process group and its orthogonal process group. Hopefully this makes things clearer.

One issue, though: right now an orthogonal process group is orthogonal in the scope of the entire group.WORLD, not within the scope of the parent group. I'm not sure yet which is preferable. The only real use case for an orthogonal process group currently is DDP, and the two are equivalent for that.
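The difference between the two scopes can be sketched with plain rank arithmetic. This assumes groups are aligned, contiguous blocks of ranks and reuses the sizes from the PR description (world of 32, model_parallel of 8, channel_parallel of 4); `orthogonal_group` is a hypothetical helper, not a function from #188.

```python
from typing import List


def orthogonal_group(rank: int, group_size: int, scope_size: int) -> List[int]:
    """Ranks inside `rank`'s scope block that hold the same position within
    their own contiguous group of `group_size` ranks.

    Assumes scope_size is a multiple of group_size and blocks are aligned.
    """
    scope_start = (rank // scope_size) * scope_size
    offset = rank % group_size
    return list(range(scope_start + offset, scope_start + scope_size, group_size))


WORLD, MODEL_PARALLEL, CHANNEL_PARALLEL = 32, 8, 4

# Root group: the parent of model_parallel is WORLD itself, so both scopes
# coincide. This is the group DDP uses.
print(orthogonal_group(5, MODEL_PARALLEL, WORLD))           # [5, 13, 21, 29]

# Non-root group: the two scopes give different answers.
print(orthogonal_group(5, CHANNEL_PARALLEL, WORLD))         # [1, 5, 9, 13, 17, 21, 25, 29]
print(orthogonal_group(5, CHANNEL_PARALLEL, MODEL_PARALLEL))  # [1, 5]
```

For the root group the choice of scope makes no difference, which matches the DDP case; the two definitions only diverge for child groups such as channel_parallel.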

…ig.py to process_group.py

Signed-off-by: Akshay Subramaniam <[email protected]>
Collaborator

@dallasfoster dallasfoster left a comment

Small comment, otherwise looks good to me given #188.

@stadlmax
Collaborator

What's the current status?

@akshaysubr
Collaborator Author

Since this PR and #188 are so tightly coupled, I'm consolidating these changes in #188 and closing this PR.

Development

Successfully merging this pull request may close these issues.

🚀[FEA]: Add a way to configure model parallel process groups
5 participants