Update distribute_pyg.rst
rewrite partitioning part
JakubPietrakIntel authored Jan 18, 2024
1 parent d31e17c commit faa458e
Showing 1 changed file with 34 additions and 6 deletions: docs/source/tutorial/distribute_pyg.rst
Distributed Training for PyG
=================================================

In real-life applications, graphs often consist of billions of nodes that cannot fit into a single system's memory. This is when distributed training comes in handy. By distributing partitions of the large graph across a cluster of CPUs, one can deploy synchronized model training on the whole dataset at once, making use of `PyTorch Distributed Data Parallel (DDP) <https://pytorch.org/docs/stable/notes/ddp.html>`_ training. The architecture seamlessly distributes graph neural network training across multiple nodes by integrating `Remote Procedure Call (RPC) <https://pytorch.org/docs/stable/rpc.html>`_ for efficient sampling and retrieval of non-local features into standard DDP model training. This distributed training implementation doesn't require any additional packages to be installed on top of the default PyG stack. In the future, the solution will also be available for Intel's GPUs.

Key Advantages
--------------------------------------
.. (TODO: add links)
#. Balanced graph partitioning with METIS for large graph databases, using ``Partitioner``
#. DDP for model training in conjunction with RPC for remote sampling and feature calls, using TCP and the ``gloo`` backend tailored to CPU-based sampling, which enhances the efficiency and scalability of the training process.
The purpose of this manual is to guide you through the most important steps of deploying your distributed training application. For the code examples, please refer to:

* `partition_graph.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/distributed/pyg/partition_graph.py>`_ for graph partitioning
* `distributed_cpu.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/distributed/pyg/distributed_cpu.py>`_ for end-to-end GNN (GraphSAGE) model training with homogeneous or heterogeneous data


1. Graph Partitioning
--------------------------------------

The first step for distributed training is to split the graph into multiple smaller partitions, which can then be loaded into the nodes of the cluster. This is a pre-processing step that only needs to be done once, as the resulting ``.pt`` files can be reused. The ``Partitioner``, built on top of ``ClusterData``, uses the ``pyg-lib`` implementation of the METIS algorithm (`pyg_lib.partition <https://pyg-lib.readthedocs.io/en/latest/modules/partition.html>`_) to perform graph partitioning efficiently, even on very large graphs. By default, METIS tries to balance the number of nodes of each type in each partition while minimizing the number of edges between partitions. This guarantees that each partition provides access to all neighboring local vertices, enabling samplers to perform local computations without inter-node communication. With this partitioning approach, every edge receives a distinct assignment, although certain vertices may be replicated; the vertices shared between partitions are the so-called "halo nodes".
Please note that METIS requires an undirected, homogeneous graph as input, but ``Partitioner`` performs the necessary processing steps to partition heterogeneous data objects with correct distribution and indexing.

.. DGL metis figure goes here
The provided example script `partition_graph.py <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/distributed/pyg/partition_graph.py>`_ demonstrates partitioning for the homogeneous ``ogbn-products`` and ``Reddit`` datasets, as well as for the heterogeneous ``ogbn-mag`` and ``MovieLens`` datasets.
The ``Partitioner`` can also process temporal attributes of the nodes, as demonstrated by the ``MovieLens`` dataset partitioning.

**Important note:** Since METIS is non-deterministic, the resulting partitions differ between runs. To perform training, make sure that every node of the cluster has access to the same partitioned dataset. Use a shared drive or remote storage, e.g. a Docker volume, or manually copy the dataset to each node of the cluster.
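
As a concrete illustration, the following minimal sketch shows how such a two-part split of ``ogbn-products`` could be generated, following the approach of the example script (``Partitioner`` and its ``generate_partition()`` method live in ``torch_geometric.distributed``; exact argument names may differ between releases, and the ``ogb`` package is assumed for dataset loading):

.. code-block:: python

    from ogb.nodeproppred import PygNodePropPredDataset

    from torch_geometric.distributed import Partitioner

    # Load the homogeneous ogbn-products dataset (requires the `ogb` package):
    dataset = PygNodePropPredDataset('ogbn-products', root='./data/products')
    data = dataset[0]

    # Split the graph into two partitions and save them to disk. The output
    # directory can then be shared with (or copied to) every node of the cluster:
    partitioner = Partitioner(data, num_parts=2, root='./partitions/ogbn-products')
    partitioner.generate_partition()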

The result of partitioning, for a two-part split of the homogeneous ``ogbn-products`` dataset, is as follows:

#. ogbn-products-labels:

   * ``label.pt``: target node/edge labels

#. ogbn-products-partitions:

   * ``edge_map.pt``: mapping (partition book) between edge ID and partition ID
   * ``node_map.pt``: mapping (partition book) between node ID and partition ID
   * ``META.json``: ``{"num_parts": 2, "is_hetero": false, "node_types": null, "edge_types": null, "is_sorted": true}``
   * ``part0``: partition 0

     * ``graph.pt``: graph topology
     * ``node_feats.pt``: node features
     * ``edge_feats.pt``: edge features (if present)

   * ``part1``: partition 1

     * ...

#. ogbn-products-train-partitions:

   * ``partition0.pt``: training node indices for partition 0
   * ``partition1.pt``: training node indices for partition 1

#. ogbn-products-test-partitions:

   * ``partition0.pt``: test node indices for partition 0
   * ``partition1.pt``: test node indices for partition 1
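
Since the generated files are plain PyTorch tensors and JSON metadata, a partition can be quickly inspected before training. The following is a minimal sketch, assuming the directory layout listed above and the output ``root`` used in the earlier sketch:

.. code-block:: python

    import json
    import os.path as osp

    import torch

    root = './partitions/ogbn-products/ogbn-products-partitions'

    # Global metadata describing the split:
    with open(osp.join(root, 'META.json')) as f:
        meta = json.load(f)
    print(meta['num_parts'], meta['is_hetero'])

    # Node-to-partition assignment (the "partition book"):
    node_map = torch.load(osp.join(root, 'node_map.pt'))

    # Topology and node features of the first partition:
    graph = torch.load(osp.join(root, 'part0', 'graph.pt'))
    node_feats = torch.load(osp.join(root, 'part0', 'node_feats.pt'))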

There are two partitioning examples in the latest PyG `examples/distributed/pyg <https://github.com/pyg-team/pytorch_geometric/edit/master/examples/distributed/pyg>`__ folder, covering the homogeneous and heterogeneous cases. Here, we use ``ogbn-products`` as the homogeneous dataset and ``ogbn-mag`` as the heterogeneous dataset to demonstrate how to partition them into two parts for distributed training.
The complete partitioning script can be found `here <https://github.com/pyg-team/pytorch_geometric/edit/master/examples/distributed/pyg/partition_graph.py>`__ for the homogeneous case and `here <https://github.com/pyg-team/pytorch_geometric/blob/master/examples/distributed/pyg/partition_hetero_graph.py>`__ for the heterogeneous case.
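
For heterogeneous data, the workflow looks the same from the user's perspective. The following sketch for ``ogbn-mag`` assumes the ``OGB_MAG`` dataset class and the same ``Partitioner`` interface as above:

.. code-block:: python

    from torch_geometric.datasets import OGB_MAG
    from torch_geometric.distributed import Partitioner

    # Load the heterogeneous ogbn-mag dataset as a HeteroData object:
    data = OGB_MAG(root='./data/mag', preprocess='metapath2vec')[0]

    # The same interface handles heterogeneous graphs: the Partitioner takes
    # care of the conversion to the homogeneous format required by METIS and
    # of the correct indexing of node and edge types in each partition:
    partitioner = Partitioner(data, num_parts=2, root='./partitions/ogbn-mag')
    partitioner.generate_partition()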

In distributed training, each node in the cluster holds a partition of the graph. Before the training starts, the graph dataset therefore needs to be split into multiple partitions, each of which corresponds to a specific training node.

1.1 Partitioning the graph
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
