Commit 870179f
Improvements for Papers100m single GPU and single node multi GPU examples (cuGraph, GATConv, better default hyperparameters, eval on all ranks) (#8173)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta <[email protected]>
3 people authored Mar 27, 2024
1 parent 08eb6b9 commit 870179f
Showing 7 changed files with 721 additions and 131 deletions.
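
The headline change is that the Papers100m examples can now switch between GCN- and GAT-based models via a `--use_gat_conv` flag. As a minimal sketch of that selection logic, using the high-level wrappers in `torch_geometric.nn.models` (the feature and class counts of `ogbn-papers100M` and the default hyperparameter values are hard-coded here purely for illustration):

```python
import torch
import torch_geometric


def build_model(use_gat_conv: bool) -> torch.nn.Module:
    # ogbn-papers100M has 128 input features and 172 classes; hidden size,
    # depth, and head count mirror the new defaults (256, 2, 4).
    if use_gat_conv:
        # `heads` is forwarded to every underlying GATConv layer.
        return torch_geometric.nn.models.GAT(128, 256, num_layers=2,
                                             out_channels=172, heads=4)
    return torch_geometric.nn.models.GCN(128, 256, num_layers=2,
                                         out_channels=172)
```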
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

+- Added support for cuGraph data loading and `GAT` in single node Papers100m examples ([#8173](https://github.com/pyg-team/pytorch_geometric/pull/8173))
- Added the `VariancePreservingAggregation` (VPA) ([#9075](https://github.com/pyg-team/pytorch_geometric/pull/9075))
- Added option to pass custom `from_smiles` functionality to `PCQM4Mv2` and `MoleculeNet` ([#9073](https://github.com/pyg-team/pytorch_geometric/pull/9073))
- Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029))
1 change: 1 addition & 0 deletions examples/README.md
@@ -12,6 +12,7 @@ For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see
- [`ogbn_products_sage.py`](./ogbn_products_sage.py) and [`ogbn_products_gat.py`](./ogbn_products_gat.py) show how to train [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) and [`GAT`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GAT.html) models on the `ogbn-products` dataset.
- [`ogbn_proteins_deepgcn.py`](./ogbn_proteins_deepgcn.py) is an example to showcase how to train deep GNNs on the `ogbn-proteins` dataset.
- [`ogbn_papers_100m.py`](./ogbn_papers_100m.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset, containing approximately 1.6B edges.
+- [`ogbn_papers_100m_cugraph.py`](./ogbn_papers_100m_cugraph.py) shows how to accelerate the `ogbn-papers100m` workflow using [CuGraph](https://github.com/rapidsai/cugraph); a sketch of the baseline pipeline it accelerates follows below.
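
For orientation, the stock CPU-side sampling pipeline that the cuGraph example accelerates looks roughly like the sketch below. It uses only the standard `NeighborLoader` API; the fan-out and batch size are illustrative placeholders, not values taken from the example.

```python
from ogb.nodeproppred import PygNodePropPredDataset

from torch_geometric.loader import NeighborLoader

# Download/load ogbn-papers100M (~111M nodes, ~1.6B edges).
dataset = PygNodePropPredDataset(name='ogbn-papers100M')
split_idx = dataset.get_idx_split()
data = dataset[0]

# CPU-based neighbor sampling; the cuGraph variant moves this work onto
# the GPU while keeping the same mini-batch training structure.
train_loader = NeighborLoader(
    data,
    input_nodes=split_idx['train'],
    num_neighbors=[30, 30],  # two-hop fan-out, illustrative values
    batch_size=1024,
    shuffle=True,
)

batch = next(iter(train_loader))
# The first `batch.batch_size` rows of `batch.x` are the seed nodes.
```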

For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).

3 changes: 2 additions & 1 deletion examples/multi_gpu/README.md
@@ -8,7 +8,8 @@
| [`distributed_sampling.py`](./distributed_sampling.py) | single-node | Example for training GNNs on a homogeneous graph with neighbor sampling. |
| [`distributed_sampling_multinode.py`](./distributed_sampling_multinode.py) | multi-node | Example for training GNNs on a homogeneous graph with neighbor sampling on multiple nodes. |
| [`distributed_sampling_multinode.sbatch`](./distributed_sampling_multinode.sbatch) | multi-node | Example for submitting a training job to a Slurm cluster using [`distributed_sampling_multi_node.py`](./distributed_sampling_multinode.py). |
-| [`papers100m_gcn.py`](./papers100m_gcn.py) | single-node | Example for training GNNs on a homogeneous graph. |
+| [`papers100m_gcn.py`](./papers100m_gcn.py) | single-node | Example for training GNNs on the `ogbn-papers100M` homogeneous graph w/ ~1.6B edges (see the launch sketch after the table). |
+| [`papers100m_gcn_cugraph.py`](./papers100m_gcn_cugraph.py) | single-node | Example for accelerating GNN training on `ogbn-papers100M` using [CuGraph](https://github.com/rapidsai/cugraph). |
| [`papers100m_gcn_multinode.py`](./papers100m_gcn_multinode.py) | multi-node | Example for training GNNs on a homogeneous graph on multiple nodes. |
| [`mag240m_graphsage.py`](./mag240m_graphsage.py) | single-node | Example for training GNNs on a large heterogeneous graph. |
| [`taobao.py`](./taobao.py) | single-node | Example for training link prediction GNNs on a heterogeneous graph. |
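
The single-node multi-GPU examples above share one launch pattern: spawn one worker process per GPU, initialize a NCCL process group, and hand each rank an even shard of the training seed nodes. A condensed sketch of that pattern, with the worker body and the index tensor as placeholders:

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def run(rank: int, world_size: int, train_idx: torch.Tensor):
    # One process per GPU; all workers rendezvous over localhost.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    # Each rank trains only on its shard of the seed nodes.
    train_idx = train_idx.split(
        train_idx.size(0) // world_size, dim=0)[rank].clone()
    # ... build loaders and the DDP-wrapped model, then train ...
    dist.destroy_process_group()


if __name__ == '__main__':
    world_size = torch.cuda.device_count()
    train_idx = torch.arange(1_000_000)  # placeholder seed-node indices
    mp.spawn(run, args=(world_size, train_idx), nprocs=world_size, join=True)
```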
239 changes: 147 additions & 92 deletions examples/multi_gpu/papers100m_gcn.py
@@ -1,4 +1,6 @@
+import argparse
import os
+import tempfile
import time

import torch
@@ -7,136 +9,189 @@
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel
+from torchmetrics import Accuracy

+import torch_geometric
from torch_geometric.loader import NeighborLoader
-from torch_geometric.nn import GCNConv


-def get_num_workers(world_size: int) -> int:
-    num_workers = None
+def get_num_workers(world_size):
+    num_work = None
    if hasattr(os, "sched_getaffinity"):
        try:
-            num_workers = len(os.sched_getaffinity(0)) // (2 * world_size)
+            num_work = len(os.sched_getaffinity(0)) / (2 * world_size)
        except Exception:
            pass
-    if num_workers is None:
-        num_workers = os.cpu_count() // (2 * world_size)
-    return num_workers
+    if num_work is None:
+        num_work = os.cpu_count() / (2 * world_size)
+    return int(num_work)


-class GCN(torch.nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels):
-        super().__init__()
-        self.conv1 = GCNConv(in_channels, hidden_channels)
-        self.conv2 = GCNConv(hidden_channels, out_channels)
-
-    def forward(self, x, edge_index=None):
-        x = F.dropout(x, p=0.5, training=self.training)
-        x = self.conv1(x, edge_index).relu()
-        x = F.dropout(x, p=0.5, training=self.training)
-        x = self.conv2(x, edge_index)
-        return x
-
-
-def run(rank, world_size, data, split_idx, model):
+def run_train(rank, data, world_size, model, epochs, batch_size, fan_out,
+              split_idx, num_classes, wall_clock_start, tempdir=None,
+              num_layers=3):
    # init pytorch worker
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    dist.init_process_group('nccl', rank=rank, world_size=world_size)

-    split_idx['train'] = split_idx['train'].split(
-        split_idx['train'].size(0) // world_size,
-        dim=0,
-    )[rank].clone()
-
-    model = DistributedDataParallel(model.to(rank), device_ids=[rank])
-    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
+    if world_size > 1:
+        split_idx['train'] = split_idx['train'].split(
+            split_idx['train'].size(0) // world_size, dim=0)[rank].clone()
+        split_idx['valid'] = split_idx['valid'].split(
+            split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()
+        split_idx['test'] = split_idx['test'].split(
+            split_idx['test'].size(0) // world_size, dim=0)[rank].clone()
+    model = model.to(rank)
+    model = DistributedDataParallel(model, device_ids=[rank])
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.01,
+                                 weight_decay=0.0005)

    kwargs = dict(
-        data=data,
-        batch_size=128,
-        num_workers=get_num_workers(world_size),
-        num_neighbors=[50, 50],
+        num_neighbors=[fan_out] * num_layers,
+        batch_size=batch_size,
    )
-    train_loader = NeighborLoader(
-        input_nodes=split_idx['train'],
-        shuffle=True,
-        **kwargs,
-    )
-    if rank == 0:
-        val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
-        test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)
-
-    val_steps = 1000
-    warmup_steps = 100
+    num_work = get_num_workers(world_size)
+    train_loader = NeighborLoader(data, input_nodes=split_idx['train'],
+                                  num_workers=num_work, shuffle=True,
+                                  drop_last=True, **kwargs)
+    val_loader = NeighborLoader(data, input_nodes=split_idx['valid'],
+                                num_workers=num_work, **kwargs)
+    test_loader = NeighborLoader(data, input_nodes=split_idx['test'],
+                                 num_workers=num_work, **kwargs)

+    eval_steps = 1000
+    warmup_steps = 20
+    acc = Accuracy(task="multiclass", num_classes=num_classes).to(rank)
    dist.barrier()
    torch.cuda.synchronize()
    if rank == 0:
+        prep_time = round(time.perf_counter() - wall_clock_start, 2)
+        print("Total time before training begins (prep_time) =", prep_time,
+              "seconds")
        print("Beginning training...")

-    for epoch in range(1, 4):
-        model.train()
+    for epoch in range(epochs):
        for i, batch in enumerate(train_loader):
            if i == warmup_steps:
                torch.cuda.synchronize()
                start = time.time()
            batch = batch.to(rank)
+            batch_size = batch.num_sampled_nodes[0]
+            batch.y = batch.y.to(torch.long)
            optimizer.zero_grad()
-            y = batch.y[:batch.batch_size].view(-1).to(torch.long)
-            out = model(batch.x, batch.edge_index)[:batch.batch_size]
-            loss = F.cross_entropy(out, y)
+            out = model(batch.x, batch.edge_index)
+            loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size])
            loss.backward()
            optimizer.step()
            if rank == 0 and i % 10 == 0:
-                print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')
+                print("Epoch: " + str(epoch) + ", Iteration: " + str(i) +
+                      ", Loss: " + str(loss))
+        nb = i + 1.0
        dist.barrier()
        torch.cuda.synchronize()
        if rank == 0:
-            sec_per_iter = (time.time() - start) / (i - warmup_steps)
-            print(f"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter")
-
-            model.eval()
-            total_correct = total_examples = 0
+            print("Average Training Iteration Time:",
+                  (time.time() - start) / (nb - warmup_steps), "s/iter")
+        with torch.no_grad():
            for i, batch in enumerate(val_loader):
-                if i >= val_steps:
+                if i >= eval_steps:
                    break
                if i == warmup_steps:
                    start = time.time()
                batch = batch.to(rank)
-                with torch.no_grad():
-                    out = model(batch.x, batch.edge_index)[:batch.batch_size]
-                    pred = out.argmax(dim=-1)
-                    y = batch.y[:batch.batch_size].view(-1).to(torch.long)
-
-                    total_correct += int((pred == y).sum())
-                    total_examples += y.size(0)
-
-            print(f"Val Acc: {total_correct / total_examples:.4f}")
-            sec_per_iter = (time.time() - start) / (i - warmup_steps)
-            print(f"Avg Inference Iteration Time: {sec_per_iter:.6f} s/iter")
+                batch_size = batch.num_sampled_nodes[0]
+                batch.y = batch.y.to(torch.long)
+                out = model(batch.x, batch.edge_index)
+                acc_i = acc(  # noqa
+                    out[:batch_size].softmax(dim=-1), batch.y[:batch_size])
+        acc_sum = acc.compute()
+        if rank == 0:
+            print(f"Validation Accuracy: {acc_sum * 100.0:.4f}%", )
+        dist.barrier()
+        acc.reset()

-    if rank == 0:
-        model.eval()
-        total_correct = total_examples = 0
+    with torch.no_grad():
        for i, batch in enumerate(test_loader):
            batch = batch.to(rank)
-            with torch.no_grad():
-                out = model(batch.x, batch.edge_index)[:batch.batch_size]
-                pred = out.argmax(dim=-1)
-                y = batch.y[:batch.batch_size].view(-1).to(torch.long)
-
-                total_correct += int((pred == y).sum())
-                total_examples += y.size(0)
-        print(f"Test Acc: {total_correct / total_examples:.4f}")
+            batch_size = batch.num_sampled_nodes[0]
+            batch.y = batch.y.to(torch.long)
+            out = model(batch.x, batch.edge_index)
+            acc_i = acc(  # noqa
+                out[:batch_size].softmax(dim=-1), batch.y[:batch_size])
+    acc_sum = acc.compute()
+    if rank == 0:
+        print(f"Test Accuracy: {acc_sum * 100.0:.4f}%", )
+    dist.barrier()
+    acc.reset()
+    if rank == 0:
+        total_time = round(time.perf_counter() - wall_clock_start, 2)
+        print("Total Program Runtime (total_time) =", total_time, "seconds")
+        print("total_time - prep_time =", total_time - prep_time, "seconds")


if __name__ == '__main__':
-    dataset = PygNodePropPredDataset(name='ogbn-papers100M')
-    split_idx = dataset.get_idx_split()
-    model = GCN(dataset.num_features, 64, dataset.num_classes)
-
-    world_size = torch.cuda.device_count()
-    print('Let\'s use', world_size, 'GPUs!')
-    mp.spawn(
-        run,
-        args=(world_size, dataset[0], split_idx, model),
-        nprocs=world_size,
-        join=True,
-    )
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--hidden_channels', type=int, default=256)
+    parser.add_argument('--num_layers', type=int, default=2)
+    parser.add_argument('--lr', type=float, default=0.001)
+    parser.add_argument('--epochs', type=int, default=20)
+    parser.add_argument('--batch_size', type=int, default=1024)
+    parser.add_argument('--fan_out', type=int, default=30)
+    parser.add_argument(
+        "--use_gat_conv",
+        action='store_true',
+        help="Whether or not to use GATConv. (Defaults to using GCNConv)",
+    )
+    parser.add_argument(
+        "--n_gat_conv_heads",
+        type=int,
+        default=4,
+        help="If using GATConv, number of attention heads to use",
+    )
+    parser.add_argument(
+        "--n_devices", type=int, default=-1,
+        help="1-8 to use that many GPUs. Defaults to all available GPUs")
+
+    args = parser.parse_args()
+    wall_clock_start = time.perf_counter()
+
+    dataset = PygNodePropPredDataset(name='ogbn-papers100M',
+                                     root='/datasets/ogb_datasets')
+    split_idx = dataset.get_idx_split()
+    data = dataset[0]
+    data.y = data.y.reshape(-1)
+    if args.use_gat_conv:
+        model = torch_geometric.nn.models.GAT(dataset.num_features,
+                                              args.hidden_channels,
+                                              args.num_layers,
+                                              dataset.num_classes,
+                                              heads=args.n_gat_conv_heads)
+    else:
+        model = torch_geometric.nn.models.GCN(
+            dataset.num_features,
+            args.hidden_channels,
+            args.num_layers,
+            dataset.num_classes,
+        )
+
+    print("Data =", data)
+    if args.n_devices == -1:
+        world_size = torch.cuda.device_count()
+    else:
+        world_size = args.n_devices
+    print('Let\'s use', world_size, 'GPUs!')
+    with tempfile.TemporaryDirectory() as tempdir:
+        if world_size > 1:
+            mp.spawn(
+                run_train,
+                args=(data, world_size, model, args.epochs, args.batch_size,
+                      args.fan_out, split_idx, dataset.num_classes,
+                      wall_clock_start, tempdir, args.num_layers),
+                nprocs=world_size, join=True)
+        else:
+            run_train(0, data, world_size, model, args.epochs, args.batch_size,
+                      args.fan_out, split_idx, dataset.num_classes,
+                      wall_clock_start, tempdir, args.num_layers)
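
A closing note on the "eval on all ranks" part of this change: each rank scores only its shard of the validation/test indices, and `torchmetrics` synchronizes metric state across processes when `compute()` is called under an initialized `torch.distributed` process group, so every rank observes the global accuracy. A minimal sketch of that pattern, with `model` and `loader` standing in for the real ones:

```python
import torch
from torchmetrics import Accuracy


@torch.no_grad()
def evaluate(model, loader, num_classes: int, rank: int) -> float:
    # Each rank updates a local metric on its own shard of the data...
    acc = Accuracy(task="multiclass", num_classes=num_classes).to(rank)
    model.eval()
    for batch in loader:
        batch = batch.to(rank)
        out = model(batch.x, batch.edge_index)
        n = batch.batch_size  # number of seed nodes in the mini-batch
        acc.update(out[:n].softmax(dim=-1), batch.y[:n].view(-1).long())
    # ...and compute() all-reduces the state across ranks, so the
    # returned accuracy covers the full evaluation set.
    global_acc = float(acc.compute())
    acc.reset()
    return global_acc
```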