Add a way to configure the required process group structure for a model #153

Closed
259 changes: 259 additions & 0 deletions modulus/distributed/config.py
@@ -0,0 +1,259 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union

from treelib import Tree


class ProcessGroupNode:
"""
Class to store the attributes of a distributed process group

Attributes
----------
name : str
Name of the process group
size : Optional[int]
Optional, size of the process group
orthogonal_group : Optional[str]
Optional, name of an orthogonal process group to create
"""

def __init__(
self,
name: str,
size: Optional[int] = None,
orthogonal_group: Optional[str] = None,
):
"""
Constructor for the ProcessGroupNode class

Parameters
----------
name : str
Name of the process group
size : Optional[int]
Optional, size of the process group
orthogonal_group : Optional[str]
Optional, name of an orthogonal process group to create
"""
self.name = name
self.size = size
self.orthogonal_group = orthogonal_group

def __str__(self):
"""
String representation of the process group node

Returns
-------
str
String representation of the process group node
"""
return (
"ProcessGroupNode("
f"name={self.name}, "
f"size={self.size}, "
f"orthogonal_group={self.orthogonal_group})"
)

def __repr__(self):
"""
String representation of the process group node

Returns
-------
str
String representation of the process group node
"""
return self.__str__()


class ProcessGroupConfig:
NickGeneva (Collaborator) commented on Sep 13, 2023:
I really like this, Akshay. I'm mainly curious how the distributed manager will know which groups to make internode vs. intranode.

For example, in the example you gave there are three groups ("channel_parallel", "spatial_parallel", and "data_parallel"). How would we know to put the first two on the same node?

akshaysubr (Collaborator, author) replied:

Currently there isn't such a way. In your example, you'd have a model_parallel group with data_parallel being the orthogonal group. model_parallel would have channel_parallel and spatial_parallel child groups. When we create the process groups, we first create the model_parallel group which would use a contiguous set of ranks. So if the total model_parallel group size was 8, that would be on a single node. See #188 for the specifics of how this works.

But I want to emphasize that this is just the logical group config. It is up to the DistributedManager to actually instantiate this group config, and that is where this kind of topology optimization needs to come in. Right now it is a hard-coded topology mapping, but there is no reason we cannot add the ability for a user to specify a custom topology mapping when passing this config to the DistributedManager.
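A minimal sketch of the contiguous-rank layout described above (the world size of 16 and model_parallel size of 8 here are illustrative assumptions, not values from the PR):

    # With contiguous rank blocks, a model_parallel group of size 8 lands on a
    # single 8-GPU node.
    world_size = 16
    mp_size = 8
    model_parallel_groups = [
        list(range(start, start + mp_size)) for start in range(0, world_size, mp_size)
    ]  # [[0, 1, ..., 7], [8, 9, ..., 15]]

    # The orthogonal data_parallel groups stride across those blocks, pairing
    # rank i on one node with rank i + mp_size on the next.
    data_parallel_groups = [
        list(range(rank, world_size, mp_size)) for rank in range(mp_size)
    ]  # [[0, 8], [1, 9], ..., [7, 15]]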

A contributor commented:

I think the MR looks good, but Nick is right: there needs to be some notion of what is close and what is not, i.e., an order of initialization. For example, if the leaf node list were ordered, the initialization order would be clear.

"""
Class to define the configuration of a model's parallel process group structure as a
tree. Each node of the tree is of type `ProcessGroupNode`.

Once the process group config structure (i.e., the tree structure) is set, it is
sufficient to set only the sizes for each leaf process group. Then, the size of
every parent group can be automatically computed as the product reduction of the
sub-tree of that parent group node.

Examples
--------
>>> from modulus.distributed.config import ProcessGroupNode, ProcessGroupConfig
>>>
>>> # Create model parallel group with data parallel as the orthogonal group
>>> mp = ProcessGroupNode("model_parallel", orthogonal_group="data_parallel")
>>>
>>> # Create the process group config with the highest level process group
>>> pg_config = ProcessGroupConfig(mp)
>>>
>>> # Create spatial and channel parallel sub-groups
>>> pg_config.add_node(ProcessGroupNode("spatial_parallel"), parent=mp)
>>> pg_config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel")
>>>
>>> pg_config.leaf_groups()
['spatial_parallel', 'channel_parallel']
>>>
>>> # Set leaf group sizes
>>> group_sizes = {"channel_parallel": 3, "spatial_parallel": 2}
>>> pg_config.set_leaf_group_sizes(group_sizes)  # Updates all parent group sizes too
>>> pg_config.get_node("model_parallel").size
6
"""

def __init__(self, node: ProcessGroupNode):
"""
Constructor for the ProcessGroupConfig class

Parameters
----------
node : ProcessGroupNode
Root node of the tree, typically 'model_parallel'.
Note: it is generally recommended to always set the orthogonal_group for
the 'model_parallel' group to 'data_parallel' to aid with distributed
data parallel training.
"""
self.root = node
self.root_id = node.name
self.tree = Tree()
self.tree.create_node(node.name, node.name, data=node)

def add_node(self, node: ProcessGroupNode, parent: Union[str, ProcessGroupNode]):
"""
Add a node to the process group config

Parameters
----------
node : ProcessGroupNode
The new node to be added to the config
parent : Union[str, ProcessGroupNode]
Parent node of the node to be added. Should already be in the config.
If str, it is the name of the parent node. Otherwise, the parent
ProcessGroupNode itself.
"""
if isinstance(parent, ProcessGroupNode):
parent = parent.name
self.tree.create_node(node.name, node.name, data=node, parent=parent)

def get_node(self, name: str) -> ProcessGroupNode:
"""
Method to get the node given the name of the node

Parameters
----------
name : str
Name of the node to retrieve

Returns
-------
ProcessGroupNode
Node with the given name from the config
"""
return self.tree.get_node(name).data

def update_parent_sizes(self, verbose: bool = False) -> int:
"""
Method to update parent node sizes after setting the sizes for each leaf node

Parameters
----------
verbose : bool
If True, print a message each time a parent node size was updated

Returns
-------
int
Size of the root node
"""
return tree_product_reduction(self.tree, self.root_id, verbose=verbose)

def leaf_groups(self) -> List[str]:
"""
Get a list of all leaf group names

Returns
-------
List[str]
List of all leaf node names
"""
return [n.identifier for n in self.tree.leaves()]

def set_leaf_group_sizes(
self, group_sizes: Dict[str, int], update_parent_sizes: bool = True
):
"""
Set process group sizes for all leaf groups

Parameters
----------
group_sizes : Dict[str, int]
Dictionary with a mapping of each leaf group name to its size
update_parent_sizes : bool
Update all parent group sizes based on the leaf group if True
If False, only set the leaf group sizes.
"""
for name, size in group_sizes.items():
    assert self.tree.contains(
        name
    ), f"Process group {name} is not in this process group config"
    node = self.tree.get_node(name)
    assert node.is_leaf(), f"Process group {name} is not a leaf group"
node.data.size = size

if update_parent_sizes:
self.update_parent_sizes()


def tree_product_reduction(tree: Tree, node_id: str, verbose: bool = False) -> int:
"""
Function to traverse a tree and compute the product reduction of
the sub-tree for each node starting from `node_id`
"""
children = tree.children(node_id)
node = tree.get_node(node_id)
if not children:
assert node.data.size is not None, "Leaf nodes should have a valid size set"
return node.data.size

product = 1

for child in children:
product *= tree_product_reduction(tree, child.identifier, verbose=verbose)

if node.data.size != product:
if verbose:
print(
"Updating size of node "
f"{node.data.name} from {node.data.size} to {product}"
)
node.data.size = product

return product


def find_leaf_nodes(tree: Tree, node_id: str) -> List[str]:
"""
Function to find all leaf node identifiers of `tree` starting from `node_id`
"""
if not tree.children(node_id):
return [node_id]
else:
leaf_nodes = []
for child_id in tree.children(node_id):
leaf_nodes.extend(find_leaf_nodes(tree, child_id.identifier))
return leaf_nodes
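For reference, a minimal sketch of what tree_product_reduction computes, using only the APIs added in this file plus treelib; the node names and sizes below are made up for illustration:

    from treelib import Tree

    from modulus.distributed.config import ProcessGroupNode, tree_product_reduction

    # Build a two-leaf tree by hand: the root has leaves "a" (size 3) and "b" (size 2).
    tree = Tree()
    tree.create_node("root", "root", data=ProcessGroupNode("root"))
    tree.create_node("a", "a", parent="root", data=ProcessGroupNode("a", size=3))
    tree.create_node("b", "b", parent="root", data=ProcessGroupNode("b", size=2))

    # The product reduction sets the root's size to 3 * 2 and returns it.
    assert tree_product_reduction(tree, "root") == 6
    assert tree.get_node("root").data.size == 6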
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -23,7 +23,8 @@ dependencies = [
"torch>=1.12",
"xarray>=2023.1.0",
"s3fs>=2023.5.0",
"scikit-learn>=1.0.2"
"scikit-learn>=1.0.2",
"treelib>=1.2.5"
]
classifiers = [
"Programming Language :: Python :: 3",
38 changes: 38 additions & 0 deletions test/distributed/test_config.py
@@ -0,0 +1,38 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from modulus.distributed.config import ProcessGroupNode, ProcessGroupConfig


def test_config():
# Create model parallel group with data parallel as the orthogonal group
mp = ProcessGroupNode("model_parallel", orthogonal_group="data_parallel")

# Create the process group config with the highest level process group
pg_config = ProcessGroupConfig(mp)

# Create spatial and channel parallel sub-groups
pg_config.add_node(ProcessGroupNode("spatial_parallel"), parent=mp)
pg_config.add_node(ProcessGroupNode("channel_parallel"), parent="model_parallel")

# Now check that the leaf nodes are correct
assert sorted(pg_config.leaf_groups()) == ["channel_parallel", "spatial_parallel"]

# Set leaf group sizes
group_sizes = {"channel_parallel": 3, "spatial_parallel": 2}
pg_config.set_leaf_group_sizes(group_sizes) # Updates all parent group sizes too

assert (
pg_config.get_node("model_parallel").size == 6
), "Incorrect size for parent node"
1 change: 1 addition & 0 deletions test/distributed/test_utils.py
@@ -16,6 +16,7 @@
from modulus.distributed import DistributedManager
from modulus.distributed import gather_loss


# TODO: Need to figure out how to test parallel set up
def test_gather_loss():
os.environ["MASTER_ADDR"] = "localhost"