Skip to content

Commit

Permalink
Add multi GPU training benchmarks (#8288)
Browse files Browse the repository at this point in the history
Benchmarks currently cover only homogeneous cases. All tested
configurations show improvement in loss over epochs.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
DamianSzwichtenberg and pre-commit-ci[bot] authored Nov 7, 2023
1 parent a2b77e6 commit c2137ad
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added a multi-GPU training benchmark for XPU devices ([#8288](https://github.com/pyg-team/pytorch_geometric/pull/8288))
- Support MRR computation in `KGEModel.test()` ([#8298](https://github.com/pyg-team/pytorch_geometric/pull/8298))
- Added an example for model parallelism (`examples/multi_gpu/model_parallel.py`) ([#8309](https://github.com/pyg-team/pytorch_geometric/pull/8309))
- Added a tutorial for multi-node multi-GPU training with pure PyTorch ([#8071](https://github.com/pyg-team/pytorch_geometric/pull/8071))
Expand Down
16 changes: 16 additions & 0 deletions benchmark/multi_gpu/training/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Training Benchmark

## Environment setup

Optional, XPU only:
```
install intel_extension_for_pytorch
install oneccl_bindings_for_pytorch
```

## Running benchmark

Run the benchmark, e.g., assuming you have 2 GPUs:
```
mpirun -np 2 python training_benchmark.py --dataset ogbn-products --model edge_cnn --num-epochs 3
```
302 changes: 302 additions & 0 deletions benchmark/multi_gpu/training/training_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,302 @@
import argparse
import ast
import os
from time import perf_counter
from typing import Any, Tuple, Union

import intel_extension_for_pytorch as ipex
import oneccl_bindings_for_pytorch # noqa
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

from benchmark.utils import get_dataset, get_model, get_split_masks, test
from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import PNAConv

# Maps each supported dataset to the model names that can be benchmarked
# on it.  Keys are the values accepted by the `--dataset` CLI option.
supported_sets = {
    # TODO (DamianSzwichtenberg): cover heterogeneous cases
    # 'ogbn-mag': ['rgat', 'rgcn'],
    'ogbn-products': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'],
    'Reddit': ['edge_cnn', 'gat', 'gcn', 'pna', 'sage'],
}

# Maps a device name to a zero-argument callable reporting whether that
# device is usable in the current environment (checked at the top of `run`).
device_conditions = {
    'xpu': (lambda: torch.xpu.is_available()),
}


def train_homo(model: Any, loader: NeighborLoader, optimizer: torch.optim.Adam,
               device: torch.device) -> torch.Tensor:
    """Run one training epoch on a homogeneous graph.

    Iterates over ``loader``, computing cross-entropy only on the seed
    nodes of each mini-batch, and returns the loss of the final batch.
    """
    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index)
        # Only the first `batch_size` rows are seed nodes with labels.
        seed_count = batch.batch_size
        loss = F.cross_entropy(logits[:seed_count], batch.y[:seed_count])
        loss.backward()
        optimizer.step()

    return loss


def train_hetero(model: Any, loader: NeighborLoader,
                 optimizer: torch.optim.Adam,
                 device: torch.device) -> torch.Tensor:
    """Run one training epoch on a heterogeneous graph.

    Mirrors :func:`train_homo`, but feeds the model node-feature and
    edge-index dictionaries and reads labels from the ``'paper'`` node
    type.  Returns the loss of the final batch.
    """
    for batch in loader:
        optimizer.zero_grad()
        batch = batch.to(device)
        logits = model(batch.x_dict, batch.edge_index_dict)
        paper = batch['paper']
        # Restrict the loss to the labeled seed nodes of this batch.
        seed_count = paper.batch_size
        loss = F.cross_entropy(logits['paper'][:seed_count],
                               paper.y[:seed_count])
        loss.backward()
        optimizer.step()

    return loss


def maybe_synchronize(device: str):
    """Wait for all queued kernels on the device to finish.

    A no-op unless the benchmark targets an available XPU, so timing
    measurements bracket the actual device work.
    """
    if device != 'xpu':
        return
    if torch.xpu.is_available():
        torch.xpu.synchronize()


def create_mask_per_rank(
        global_mask: Union[torch.Tensor,
                           Tuple[str,
                                 torch.Tensor]], rank: int, world_size: int,
        hetero: bool = False) -> Union[torch.Tensor, Tuple[str, torch.Tensor]]:
    """Shard a boolean split mask so each rank sees a disjoint subset.

    The selected node indices are split into equally sized chunks of
    ``len(selected) // world_size``; this rank keeps only its own chunk.
    For heterogeneous input, ``global_mask`` is a ``(node_type, mask)``
    pair and the same pair shape is returned.
    """
    source_mask = global_mask[-1] if hetero else global_mask
    selected = source_mask.nonzero().reshape(-1)
    shard_size = selected.size(0) // world_size
    shard = selected.split(shard_size, dim=0)[rank].clone()
    local_mask = torch.zeros_like(source_mask)
    local_mask[shard] = True

    return (global_mask[0], local_mask) if hetero else local_mask


def run(rank: int, world_size: int, args: argparse.Namespace):
    """Execute the distributed training benchmark on a single rank.

    Args:
        rank: This process' rank within the distributed group.
        world_size: Total number of processes in the group.
        args: Parsed command-line options (see ``__main__``).

    Raises:
        RuntimeError: If the requested device is unavailable, the
            dataset + model combination is unsupported, or the
            ``num_neighbors`` list does not match ``num_layers``.
    """
    if not device_conditions[args.device]():
        raise RuntimeError(f'{args.device.upper()} is not available')

    device = torch.device(f'{args.device}:{rank}')

    if rank == 0:
        print('BENCHMARK STARTS')
        print(f'Running on {args.device.upper()}')

    assert args.dataset in supported_sets.keys(
    ), f"Dataset {args.dataset} isn't supported."
    if rank == 0:
        print(f'Dataset: {args.dataset}')

    data, num_classes = get_dataset(args.dataset, args.root)
    hetero = True if args.dataset == 'ogbn-mag' else False
    mask, val_mask, test_mask = get_split_masks(data, args.dataset)
    # Each rank trains on a disjoint shard of the training nodes.
    mask = create_mask_per_rank(mask, rank, world_size, hetero)
    degree = None

    inputs_channels = data[
        'paper'].num_features if args.dataset == 'ogbn-mag' \
        else data.num_features

    if args.model not in supported_sets[args.dataset]:
        # NOTE: a space was missing between the two string fragments.
        err_msg = (f'Configuration of {args.dataset} + {args.model} '
                   'not supported')
        raise RuntimeError(err_msg)
    if rank == 0:
        print(f'Training bench for {args.model}:')

    num_nodes = int(mask[-1].sum()) if hetero else int(mask.sum())
    num_neighbors = args.num_neighbors

    # Normalize `num_neighbors` to exactly one entry per GNN layer.
    if type(num_neighbors) is list:
        if len(num_neighbors) == 1:
            num_neighbors = num_neighbors * args.num_layers
    elif type(num_neighbors) is int:
        num_neighbors = [num_neighbors] * args.num_layers

    if len(num_neighbors) != args.num_layers:
        # Previously the message was built but never raised, silently
        # accepting a mismatched configuration; the second fragment was
        # also missing its f-prefix.
        raise RuntimeError(f'num_neighbors={num_neighbors} length != num of '
                           f'layers={args.num_layers}')

    kwargs = {
        'num_neighbors': num_neighbors,
        'batch_size': args.batch_size,
        'num_workers': args.num_workers,
    }
    subgraph_loader = NeighborLoader(
        data,
        input_nodes=mask,
        sampler=None,
        **kwargs,
    )
    # Validation/test loaders are only needed by rank 0, which performs
    # the evaluation.
    if rank == 0 and args.evaluate:
        val_loader = NeighborLoader(
            data,
            input_nodes=val_mask,
            sampler=None,
            **kwargs,
        )
        test_loader = NeighborLoader(
            data,
            input_nodes=test_mask,
            sampler=None,
            **kwargs,
        )

    if rank == 0:
        print('----------------------------------------------')
        print(
            f'Batch size={args.batch_size}, '
            f'Layers amount={args.num_layers}, '
            f'Num_neighbors={num_neighbors}, '
            f'Hidden features size={args.num_hidden_channels}', flush=True)

    params = {
        'inputs_channels': inputs_channels,
        'hidden_channels': args.num_hidden_channels,
        'output_channels': num_classes,
        'num_heads': args.num_heads,
        'num_layers': args.num_layers,
    }

    # PNA requires the in-degree histogram of the training subgraphs; each
    # rank computes it over its own loader.
    if args.model == 'pna' and degree is None:
        degree = PNAConv.get_degree_histogram(subgraph_loader)
        print(f'Rank: {rank}, calculated degree for {args.dataset}.',
              flush=True)
        params['degree'] = degree
    dist.barrier()

    # Fixed seed so every rank builds identical initial weights.
    torch.manual_seed(12345)
    model = get_model(args.model, params,
                      metadata=data.metadata() if hetero else None)
    model = model.to(device)
    if hetero:
        # TODO (DamianSzwichtenberg):
        # Provide fix for:
        # RuntimeError: Modules with uninitialized parameters can't be used
        # with `DistributedDataParallel`. Run a dummy forward pass to correctly
        # initialize the modules.
        pass
    model = DDP(model, device_ids=[device])
    model.train()

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    if args.device == 'xpu':
        model, optimizer = ipex.optimize(model, optimizer=optimizer)

    train = train_hetero if hetero else train_homo

    # Synchronize all ranks so the timer brackets training work only.
    maybe_synchronize(args.device)
    dist.barrier()
    if rank == 0:
        beg = perf_counter()

    for epoch in range(args.num_epochs):
        loss = train(
            model,
            subgraph_loader,
            optimizer,
            device,
        )

        dist.barrier()

        if rank == 0:
            print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}', flush=True)

        if rank == 0 and args.evaluate:
            # In evaluate, throughput and
            # latency are not accurate.
            val_acc = test(model, val_loader, device, hetero,
                           progress_bar=False)
            print(f'Val Accuracy: {val_acc:.4f}')

        dist.barrier()

    maybe_synchronize(args.device)
    dist.barrier()
    if rank == 0:
        end = perf_counter()
        duration = end - beg

    if rank == 0 and args.evaluate:
        test_acc = test(model, test_loader, device, hetero, progress_bar=False)
        print(f'Test Accuracy: {test_acc:.4f}')

    dist.barrier()

    if rank == 0:
        # Report aggregate throughput/latency across all ranks.
        num_nodes_total = num_nodes * world_size
        duration_per_epoch = duration / args.num_epochs
        throughput = num_nodes_total / duration_per_epoch
        latency = duration_per_epoch / num_nodes_total * 1000
        print(f'Time: {duration_per_epoch:.4f}s')
        print(f'Throughput: {throughput:.3f} samples/s')
        print(f'Latency: {latency:.3f} ms', flush=True)

    dist.destroy_process_group()


def get_dist_params() -> Tuple[int, int, str]:
    """Derive distributed-training parameters from the environment.

    Prefers MPI-provided ``PMI_RANK``/``PMI_SIZE`` (set by ``mpirun``);
    otherwise falls back to torchrun-style ``RANK``/``WORLD_SIZE``.
    Also exports ``MASTER_ADDR``/``MASTER_PORT``/``RANK``/``WORLD_SIZE``
    for libraries that read them from the environment.

    Returns:
        Tuple of ``(rank, world_size, init_method)``.
    """
    master_addr = "127.0.0.1"
    master_port = "29500"
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

    mpi_rank = int(os.environ.get("PMI_RANK", -1))
    mpi_world_size = int(os.environ.get("PMI_SIZE", -1))
    # Environment values are strings: cast the fallback path to int so the
    # declared return type holds (previously a str rank/world_size could be
    # returned and passed to `dist.init_process_group`).
    rank = mpi_rank if mpi_world_size > 0 else int(os.environ.get("RANK", 0))
    world_size = (mpi_world_size if mpi_world_size > 0 else int(
        os.environ.get("WORLD_SIZE", 1)))

    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

    init_method = f"tcp://{master_addr}:{master_port}"

    return rank, world_size, init_method


if __name__ == '__main__':
    # Resolve distributed parameters from the environment (MPI or
    # torchrun-style variables) and join the process group via Intel oneCCL.
    rank, world_size, init_method = get_dist_params()
    dist.init_process_group(backend="ccl", init_method=init_method,
                            world_size=world_size, rank=rank)

    # Command-line interface; `add` is a shorthand for `add_argument`.
    argparser = argparse.ArgumentParser(
        'GNN distributed (DDP) training benchmark')
    add = argparser.add_argument

    add('--device', choices=['xpu'], default='xpu',
        help='Device to run benchmark on')
    add('--dataset', choices=['ogbn-mag', 'ogbn-products', 'Reddit'],
        default='Reddit', type=str)
    add('--model',
        choices=['edge_cnn', 'gat', 'gcn', 'pna', 'rgat', 'rgcn',
                 'sage'], default='sage', type=str)
    add('--root', default='../../data', type=str,
        help='relative path to look for the datasets')
    add('--batch-size', default=4096, type=int)
    add('--num-layers', default=3, type=int)
    add('--num-hidden-channels', default=128, type=int)
    add('--num-heads', default=2, type=int,
        help='number of hidden attention heads, applies only for gat and rgat')
    # `ast.literal_eval` lets users pass either an int or a Python list
    # literal, e.g. `--num-neighbors [10,5,5]`.
    add('--num-neighbors', default=[10], type=ast.literal_eval,
        help='number of neighbors to sample per layer')
    add('--num-workers', default=0, type=int)
    add('--num-epochs', default=1, type=int)
    add('--evaluate', action='store_true')

    args = argparser.parse_args()

    run(rank, world_size, args)

0 comments on commit c2137ad

Please sign in to comment.