Skip to content

Commit

Permalink
Add DistNeighborSampler (#7974)
Browse files Browse the repository at this point in the history
This code belongs to the part of the whole distributed training for PyG.

`DistNeighborSampler` leverages the `NeighborSampler` class from
`pytorch_geometric` and the `neighbor_sample` function from `pyg-lib`.
However, due to the fact that in case of distributed training it is
required to synchronise the results between machines after each layer,
the part of the code responsible for sampling was implemented in python.

Added suport for the following sampling methods:
- node, edge, negative, disjoint, temporal

**TODOs:**

- [x] finish hetero part
- [x] subgraph sampling

**This PR should be merged together with other distributed PRs:**
pyg-lib: [#246](pyg-team/pyg-lib#246),
[#252](pyg-team/pyg-lib#252)
GraphStore\FeatureStore:
#8083
DistLoaders:
1.  #8079
2.  #8080
3.  #8085

---------

Co-authored-by: JakubPietrakIntel <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: ZhengHongming888 <[email protected]>
Co-authored-by: Jakub Pietrak <[email protected]>
Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
6 people authored Oct 9, 2023
1 parent 89f4873 commit f71ead8
Show file tree
Hide file tree
Showing 6 changed files with 900 additions and 45 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a CPU-based and GPU-based `map_index` implementation ([#7493](https://github.com/pyg-team/pytorch_geometric/pull/7493), [#7764](https://github.com/pyg-team/pytorch_geometric/pull/7764) [#7765](https://github.com/pyg-team/pytorch_geometric/pull/7765))
- Added the `AmazonBook` heterogeneous dataset ([#7483](https://github.com/pyg-team/pytorch_geometric/pull/7483))
- Added hierarchical heterogeneous GraphSAGE example on OGB-MAG ([#7425](https://github.com/pyg-team/pytorch_geometric/pull/7425))
- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671), [#7846](https://github.com/pyg-team/pytorch_geometric/pull/7846), [#7715](https://github.com/pyg-team/pytorch_geometric/pull/7715))
- Added the `torch_geometric.distributed` package ([#7451](https://github.com/pyg-team/pytorch_geometric/pull/7451), [#7452](https://github.com/pyg-team/pytorch_geometric/pull/7452)), [#7482](https://github.com/pyg-team/pytorch_geometric/pull/7482), [#7502](https://github.com/pyg-team/pytorch_geometric/pull/7502), [#7628](https://github.com/pyg-team/pytorch_geometric/pull/7628), [#7671](https://github.com/pyg-team/pytorch_geometric/pull/7671), [#7846](https://github.com/pyg-team/pytorch_geometric/pull/7846), [#7715](https://github.com/pyg-team/pytorch_geometric/pull/7715), [#7974](https://github.com/pyg-team/pytorch_geometric/pull/7974))
- Added the `GDELTLite` dataset ([#7442](https://github.com/pyg-team/pytorch_geometric/pull/7442))
- Added the `approx_knn` function for approximated nearest neighbor search ([#7421](https://github.com/pyg-team/pytorch_geometric/pull/7421))
- Added the `IGMCDataset` ([#7441](https://github.com/pyg-team/pytorch_geometric/pull/7441))
Expand Down
35 changes: 35 additions & 0 deletions test/distributed/test_dist_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import torch

from torch_geometric.distributed.utils import remove_duplicates
from torch_geometric.sampler import SamplerOutput


def test_remove_duplicates():
node = torch.tensor([0, 1, 2, 3])
out_node = torch.tensor([0, 4, 1, 5, 1, 6, 2, 7, 3, 8])

out = SamplerOutput(out_node, None, None, None)

src, node, _, _ = remove_duplicates(out, node)

assert src.tolist() == [4, 5, 6, 7, 8]
assert node.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8]


def test_remove_duplicates_disjoint():
node = torch.tensor([0, 1, 2, 3])
batch = torch.tensor([0, 1, 2, 3])

out_node = torch.tensor([0, 4, 1, 5, 1, 6, 2, 6, 7, 3, 8])
out_batch = torch.tensor([0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3])

out = SamplerOutput(out_node, None, None, None, out_batch)

src, node, src_batch, batch = remove_duplicates(out, node, batch,
disjoint=True)

assert src.tolist() == [4, 5, 6, 7, 8]
assert node.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 8]

assert src_batch.tolist() == [0, 1, 2, 3, 3]
assert batch.tolist() == [0, 1, 2, 3, 0, 1, 2, 3, 3]
Loading

0 comments on commit f71ead8

Please sign in to comment.