diff --git a/CHANGELOG.md b/CHANGELOG.md index 2bea8d940096..d7dd7402066d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Improvements to multi-node `ogbn-papers100m` default hyperparameters and adding evaluation on all ranks ([#8823](https://github.com/pyg-team/pytorch_geometric/pull/8823)) - Changed distributed sampler and loader tests to correctly report failures in subprocesses to `pytest` ([#8978](https://github.com/pyg-team/pytorch_geometric/pull/8978)) - Remove filtering of node/edge types in `trim_to_layer` functionality ([#9021](https://github.com/pyg-team/pytorch_geometric/pull/9021)) - Default to `scatter` operations in `MessagePassing` in case `torch.use_deterministic_algorithms` is not set ([#9009](https://github.com/pyg-team/pytorch_geometric/pull/9009)) diff --git a/examples/multi_gpu/papers100m_gcn_multinode.py b/examples/multi_gpu/papers100m_gcn_multinode.py index f827700ac73c..af434b4d2ef7 100644 --- a/examples/multi_gpu/papers100m_gcn_multinode.py +++ b/examples/multi_gpu/papers100m_gcn_multinode.py @@ -1,6 +1,6 @@ """Multi-node multi-GPU example on ogbn-papers100m. -To run: +Example way to run using srun: srun -l -N --ntasks-per-node= \ --container-name=cont --container-image= \ --container-mounts=/ogb-papers100m/:/workspace/dataset @@ -8,15 +8,17 @@ """ import os import time +from typing import Optional import torch import torch.distributed as dist import torch.nn.functional as F from ogb.nodeproppred import PygNodePropPredDataset from torch.nn.parallel import DistributedDataParallel +from torchmetrics import Accuracy from torch_geometric.loader import NeighborLoader -from torch_geometric.nn import GCNConv +from torch_geometric.nn import GCN def get_num_workers() -> int: @@ -31,21 +33,7 @@ def get_num_workers() -> int: return num_workers -class GCN(torch.nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels): - super().__init__() - self.conv1 = GCNConv(in_channels, hidden_channels) - self.conv2 = GCNConv(hidden_channels, out_channels) - - def forward(self, x, edge_index): - x = F.dropout(x, p=0.5, training=self.training) - x = self.conv1(x, edge_index).relu() - x = F.dropout(x, p=0.5, training=self.training) - x = self.conv2(x, edge_index) - return x - - -def run(world_size, data, split_idx, model): +def run(world_size, data, split_idx, model, acc, wall_clock_start): local_id = int(os.environ['LOCAL_RANK']) rank = torch.distributed.get_rank() torch.cuda.set_device(local_id) @@ -54,38 +42,48 @@ def run(world_size, data, split_idx, model): print(f'Using {nprocs} GPUs...') split_idx['train'] = split_idx['train'].split( - split_idx['train'].size(0) // world_size, - dim=0, - )[rank].clone() + split_idx['train'].size(0) // world_size, dim=0)[rank].clone() + split_idx['valid'] = split_idx['valid'].split( + split_idx['valid'].size(0) // world_size, dim=0)[rank].clone() + split_idx['test'] = split_idx['test'].split( + split_idx['test'].size(0) // world_size, dim=0)[rank].clone() model = DistributedDataParallel(model.to(device), device_ids=[local_id]) - optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + optimizer = torch.optim.Adam(model.parameters(), lr=0.001, + weight_decay=5e-4) kwargs = dict( data=data, - batch_size=128, + batch_size=1024, num_workers=get_num_workers(), - num_neighbors=[50, 50], + num_neighbors=[30, 30], ) train_loader = NeighborLoader( input_nodes=split_idx['train'], shuffle=True, + drop_last=True, **kwargs, ) - if rank == 0: - val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs) - test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs) + val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs) + test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs) val_steps = 1000 warmup_steps = 100 + acc = acc.to(device) + dist.barrier() + torch.cuda.synchronize() if rank == 0: + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time)=", prep_time, + "seconds") print("Beginning training...") - for epoch in range(1, 4): + for epoch in range(1, 21): model.train() for i, batch in enumerate(train_loader): if i == warmup_steps: + torch.cuda.synchronize() start = time.time() batch = batch.to(device) optimizer.zero_grad() @@ -98,53 +96,56 @@ def run(world_size, data, split_idx, model): if rank == 0 and i % 10 == 0: print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}') + dist.barrier() + torch.cuda.synchronize() if rank == 0: - sec_per_iter = (time.time() - start) / (i - warmup_steps) + sec_per_iter = (time.time() - start) / (i + 1 - warmup_steps) print(f"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter") + @torch.no_grad() + def test(loader: NeighborLoader, num_steps: Optional[int] = None): model.eval() - total_correct = total_examples = 0 - for i, batch in enumerate(val_loader): - if i >= val_steps: + for j, batch in enumerate(loader): + if num_steps is not None and j >= num_steps: break - if i == warmup_steps: - start = time.time() - batch = batch.to(device) - with torch.no_grad(): - out = model(batch.x, batch.edge_index)[:batch.batch_size] - pred = out.argmax(dim=-1) + out = model(batch.x, batch.edge_index)[:batch.batch_size] y = batch.y[:batch.batch_size].view(-1).to(torch.long) + acc(out, y) + acc_sum = acc.compute() + return acc_sum - total_correct += int((pred == y).sum()) - total_examples += y.size(0) + eval_acc = test(val_loader, num_steps=val_steps) + if rank == 0: + print(f"Val Accuracy: {eval_acc:.4f}%", ) - print(f"Val Acc: {total_correct / total_examples:.4f}") - sec_per_iter = (time.time() - start) / (i - warmup_steps) - print(f"Avg Inference Iteration Time: {sec_per_iter:.6f} s/iter") + acc.reset() + dist.barrier() + test_acc = test(test_loader) if rank == 0: - model.eval() - total_correct = total_examples = 0 - for i, batch in enumerate(test_loader): - batch = batch.to(device) - with torch.no_grad(): - out = model(batch.x, batch.edge_index)[:batch.batch_size] - pred = out.argmax(dim=-1) - y = batch.y[:batch.batch_size].view(-1).to(torch.long) + print(f"Test Accuracy: {test_acc:.4f}%", ) - total_correct += int((pred == y).sum()) - total_examples += y.size(0) - print(f"Test Acc: {total_correct / total_examples:.4f}") + dist.barrier() + acc.reset() + torch.cuda.synchronize() + + if rank == 0: + total_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total Program Runtime (total_time) =", total_time, "seconds") + print("total_time - prep_time =", total_time - prep_time, "seconds") if __name__ == '__main__': + wall_clock_start = time.perf_counter() # Setup multi-node: torch.distributed.init_process_group("nccl") nprocs = dist.get_world_size() assert dist.is_initialized(), "Distributed cluster not initialized" dataset = PygNodePropPredDataset(name='ogbn-papers100M') split_idx = dataset.get_idx_split() - model = GCN(dataset.num_features, 64, dataset.num_classes) - - run(nprocs, dataset[0], split_idx, model) + model = GCN(dataset.num_features, 256, 2, dataset.num_classes) + acc = Accuracy(task="multiclass", num_classes=dataset.num_classes) + data = dataset[0] + data.y = data.y.reshape(-1) + run(nprocs, data, split_idx, model, acc, wall_clock_start)