From 870179fd80b816e1a3c93d9e6a5bbef874f76426 Mon Sep 17 00:00:00 2001
From: Rishi Puri
Date: Wed, 27 Mar 2024 08:46:58 -0700
Subject: [PATCH] Improvements for Papers100m single gpu and single node multi gpu examples (Cugraph, GATConv, better default hyperparams, eval on all ranks) (#8173)

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Akihiro Nitta
---
 CHANGELOG.md                                 |   1 +
 examples/README.md                           |   1 +
 examples/multi_gpu/README.md                 |   3 +-
 examples/multi_gpu/papers100m_gcn.py         | 239 ++++++++-----
 examples/multi_gpu/papers100m_gcn_cugraph.py | 333 +++++++++++++++++++
 examples/ogbn_papers_100m.py                 | 112 ++++---
 examples/ogbn_papers_100m_cugraph.py         | 163 +++++++++
 7 files changed, 721 insertions(+), 131 deletions(-)
 create mode 100644 examples/multi_gpu/papers100m_gcn_cugraph.py
 create mode 100644 examples/ogbn_papers_100m_cugraph.py

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 523118102a7b..badfc3d8a8b0 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Added support for cuGraph data loading and `GAT` in single node Papers100m examples ([#8173](https://github.com/pyg-team/pytorch_geometric/pull/8173))
 - Added the `VariancePreservingAggregation` (VPA) ([#9075](https://github.com/pyg-team/pytorch_geometric/pull/9075))
 - Added option to pass custom `from_smiles` functionality to `PCQM4Mv2` and `MoleculeNet` ([#9073](https://github.com/pyg-team/pytorch_geometric/pull/9073))
 - Added `group_cat` functionality ([#9029](https://github.com/pyg-team/pytorch_geometric/pull/9029))

diff --git a/examples/README.md b/examples/README.md
index 25fe7a7cb66d..336b3d816a82 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -12,6 +12,7 @@ For examples on [Open Graph Benchmark](https://ogb.stanford.edu/) datasets, see

 - [`ogbn_products_sage.py`](./ogbn_products_sage.py) and [`ogbn_products_gat.py`](./ogbn_products_gat.py) show how to train [`GraphSAGE`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GraphSAGE.html) and [`GAT`](https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.models.GAT.html) models on the `ogbn-products` dataset.
 - [`ogbn_proteins_deepgcn.py`](./ogbn_proteins_deepgcn.py) is an example to showcase how to train deep GNNs on the `ogbn-proteins` dataset.
 - [`ogbn_papers_100m.py`](./ogbn_papers_100m.py) is an example for training a GNN on the large-scale `ogbn-papers100m` dataset, containing approximately 1.6B edges.
+- [`ogbn_papers_100m_cugraph.py`](./ogbn_papers_100m_cugraph.py) shows how to accelerate the `ogbn-papers100m` workflow using [CuGraph](https://github.com/rapidsai/cugraph).

 For examples on using `torch.compile`, see the examples under [`examples/compile`](./compile).

diff --git a/examples/multi_gpu/README.md b/examples/multi_gpu/README.md
index ac55bdd12acd..91defd65875d 100644
--- a/examples/multi_gpu/README.md
+++ b/examples/multi_gpu/README.md
@@ -8,7 +8,8 @@
 | [`distributed_sampling.py`](./distributed_sampling.py) | single-node | Example for training GNNs on a homogeneous graph with neighbor sampling. |
 | [`distributed_sampling_multinode.py`](./distributed_sampling_multinode.py) | multi-node | Example for training GNNs on a homogeneous graph with neighbor sampling on multiple nodes.
| | [`distributed_sampling_multinode.sbatch`](./distributed_sampling_multinode.sbatch) | multi-node | Example for submitting a training job to a Slurm cluster using [`distributed_sampling_multi_node.py`](./distributed_sampling_multinode.py). | -| [`papers100m_gcn.py`](./papers100m_gcn.py) | single-node | Example for training GNNs on a homogeneous graph. | +| [`papers100m_gcn.py`](./papers100m_gcn.py) | single-node | Example for training GNNs on the `ogbn-papers100M` homogeneous graph w/ ~1.6B edges. | +| [`papers100m_gcn_cugraph.py`](./papers100m_gcn_cugraph.py%60) | single-node | Example for accelerating GNN training on `ogbn-papers100M` using [CuGraph](...). | | [`papers100m_gcn_multinode.py`](./papers100m_gcn_multinode.py) | multi-node | Example for training GNNs on a homogeneous graph on multiple nodes. | | [`mag240m_graphsage.py`](./mag240m_graphsage.py) | single-node | Example for training GNNs on a large heterogeneous graph. | | [`taobao.py`](./taobao.py) | single-node | Example for training link prediction GNNs on a heterogeneous graph. | diff --git a/examples/multi_gpu/papers100m_gcn.py b/examples/multi_gpu/papers100m_gcn.py index 89233ae1344a..e7f6fe0dcf4b 100644 --- a/examples/multi_gpu/papers100m_gcn.py +++ b/examples/multi_gpu/papers100m_gcn.py @@ -1,4 +1,6 @@ +import argparse import os +import tempfile import time import torch @@ -7,136 +9,189 @@ import torch.nn.functional as F from ogb.nodeproppred import PygNodePropPredDataset from torch.nn.parallel import DistributedDataParallel +from torchmetrics import Accuracy +import torch_geometric from torch_geometric.loader import NeighborLoader -from torch_geometric.nn import GCNConv -def get_num_workers(world_size: int) -> int: - num_workers = None +def get_num_workers(world_size): + num_work = None if hasattr(os, "sched_getaffinity"): try: - num_workers = len(os.sched_getaffinity(0)) // (2 * world_size) + num_work = len(os.sched_getaffinity(0)) / (2 * world_size) except Exception: pass - if num_workers is None: - num_workers = os.cpu_count() // (2 * world_size) - return num_workers + if num_work is None: + num_work = os.cpu_count() / (2 * world_size) + return int(num_work) -class GCN(torch.nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels): - super().__init__() - self.conv1 = GCNConv(in_channels, hidden_channels) - self.conv2 = GCNConv(hidden_channels, out_channels) +def run_train(rank, data, world_size, model, epochs, batch_size, fan_out, + split_idx, num_classes, wall_clock_start, tempdir=None, + num_layers=3): - def forward(self, x, edge_index=None): - x = F.dropout(x, p=0.5, training=self.training) - x = self.conv1(x, edge_index).relu() - x = F.dropout(x, p=0.5, training=self.training) - x = self.conv2(x, edge_index) - return x - - -def run(rank, world_size, data, split_idx, model): + # init pytorch worker os.environ['MASTER_ADDR'] = 'localhost' os.environ['MASTER_PORT'] = '12355' dist.init_process_group('nccl', rank=rank, world_size=world_size) - split_idx['train'] = split_idx['train'].split( - split_idx['train'].size(0) // world_size, - dim=0, - )[rank].clone() - - model = DistributedDataParallel(model.to(rank), device_ids=[rank]) - optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + if world_size > 1: + split_idx['train'] = split_idx['train'].split( + split_idx['train'].size(0) // world_size, dim=0)[rank].clone() + split_idx['valid'] = split_idx['valid'].split( + split_idx['valid'].size(0) // world_size, dim=0)[rank].clone() + split_idx['test'] = split_idx['test'].split( + 
split_idx['test'].size(0) // world_size, dim=0)[rank].clone() + model = model.to(rank) + model = DistributedDataParallel(model, device_ids=[rank]) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, + weight_decay=0.0005) kwargs = dict( - data=data, - batch_size=128, - num_workers=get_num_workers(world_size), - num_neighbors=[50, 50], + num_neighbors=[fan_out] * num_layers, + batch_size=batch_size, ) - train_loader = NeighborLoader( - input_nodes=split_idx['train'], - shuffle=True, - **kwargs, - ) - if rank == 0: - val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs) - test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs) - - val_steps = 1000 - warmup_steps = 100 + num_work = get_num_workers(world_size) + train_loader = NeighborLoader(data, input_nodes=split_idx['train'], + num_workers=num_work, shuffle=True, + drop_last=True, **kwargs) + val_loader = NeighborLoader(data, input_nodes=split_idx['valid'], + num_workers=num_work, **kwargs) + test_loader = NeighborLoader(data, input_nodes=split_idx['test'], + num_workers=num_work, **kwargs) + + eval_steps = 1000 + warmup_steps = 20 + acc = Accuracy(task="multiclass", num_classes=num_classes).to(rank) + dist.barrier() + torch.cuda.synchronize() if rank == 0: + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time) =", prep_time, + "seconds") print("Beginning training...") - - for epoch in range(1, 4): - model.train() + for epoch in range(epochs): for i, batch in enumerate(train_loader): if i == warmup_steps: + torch.cuda.synchronize() start = time.time() batch = batch.to(rank) + batch_size = batch.num_sampled_nodes[0] + batch.y = batch.y.to(torch.long) optimizer.zero_grad() - y = batch.y[:batch.batch_size].view(-1).to(torch.long) - out = model(batch.x, batch.edge_index)[:batch.batch_size] - loss = F.cross_entropy(out, y) + out = model(batch.x, batch.edge_index) + loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size]) loss.backward() optimizer.step() - if rank == 0 and i % 10 == 0: - print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}') - + print("Epoch: " + str(epoch) + ", Iteration: " + str(i) + + ", Loss: " + str(loss)) + nb = i + 1.0 + dist.barrier() + torch.cuda.synchronize() if rank == 0: - sec_per_iter = (time.time() - start) / (i - warmup_steps) - print(f"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter") - - model.eval() - total_correct = total_examples = 0 + print("Average Training Iteration Time:", + (time.time() - start) / (nb - warmup_steps), "s/iter") + with torch.no_grad(): for i, batch in enumerate(val_loader): - if i >= val_steps: + if i >= eval_steps: break - if i == warmup_steps: - start = time.time() batch = batch.to(rank) - with torch.no_grad(): - out = model(batch.x, batch.edge_index)[:batch.batch_size] - pred = out.argmax(dim=-1) - y = batch.y[:batch.batch_size].view(-1).to(torch.long) - - total_correct += int((pred == y).sum()) - total_examples += y.size(0) - - print(f"Val Acc: {total_correct / total_examples:.4f}") - sec_per_iter = (time.time() - start) / (i - warmup_steps) - print(f"Avg Inference Iteration Time: {sec_per_iter:.6f} s/iter") - - if rank == 0: - model.eval() - total_correct = total_examples = 0 + batch_size = batch.num_sampled_nodes[0] + + batch.y = batch.y.to(torch.long) + out = model(batch.x, batch.edge_index) + acc_i = acc( # noqa + out[:batch_size].softmax(dim=-1), batch.y[:batch_size]) + acc_sum = acc.compute() + if rank == 0: + print(f"Validation Accuracy: {acc_sum * 
100.0:.4f}%", ) + dist.barrier() + acc.reset() + + with torch.no_grad(): for i, batch in enumerate(test_loader): batch = batch.to(rank) - with torch.no_grad(): - out = model(batch.x, batch.edge_index)[:batch.batch_size] - pred = out.argmax(dim=-1) - y = batch.y[:batch.batch_size].view(-1).to(torch.long) + batch_size = batch.num_sampled_nodes[0] - total_correct += int((pred == y).sum()) - total_examples += y.size(0) - print(f"Test Acc: {total_correct / total_examples:.4f}") + batch.y = batch.y.to(torch.long) + out = model(batch.x, batch.edge_index) + acc_i = acc( # noqa + out[:batch_size].softmax(dim=-1), batch.y[:batch_size]) + acc_sum = acc.compute() + if rank == 0: + print(f"Test Accuracy: {acc_sum * 100.0:.4f}%", ) + dist.barrier() + acc.reset() + if rank == 0: + total_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total Program Runtime (total_time) =", total_time, "seconds") + print("total_time - prep_time =", total_time - prep_time, "seconds") if __name__ == '__main__': - dataset = PygNodePropPredDataset(name='ogbn-papers100M') - split_idx = dataset.get_idx_split() - model = GCN(dataset.num_features, 64, dataset.num_classes) - world_size = torch.cuda.device_count() - print('Let\'s use', world_size, 'GPUs!') - mp.spawn( - run, - args=(world_size, dataset[0], split_idx, model), - nprocs=world_size, - join=True, + parser = argparse.ArgumentParser() + parser.add_argument('--hidden_channels', type=int, default=256) + parser.add_argument('--num_layers', type=int, default=2) + parser.add_argument('--lr', type=float, default=0.001) + parser.add_argument('--epochs', type=int, default=20) + parser.add_argument('--batch_size', type=int, default=1024) + parser.add_argument('--fan_out', type=int, default=30) + parser.add_argument( + "--use_gat_conv", + action='store_true', + help="Whether or not to use GATConv. (Defaults to using GCNConv)", ) + parser.add_argument( + "--n_gat_conv_heads", + type=int, + default=4, + help="If using GATConv, number of attention heads to use", + ) + parser.add_argument( + "--n_devices", type=int, default=-1, + help="1-8 to use that many GPUs. 
Defaults to all available GPUs") + + args = parser.parse_args() + wall_clock_start = time.perf_counter() + + dataset = PygNodePropPredDataset(name='ogbn-papers100M', + root='/datasets/ogb_datasets') + split_idx = dataset.get_idx_split() + data = dataset[0] + data.y = data.y.reshape(-1) + if args.use_gat_conv: + model = torch_geometric.nn.models.GAT(dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes, + heads=args.n_gat_conv_heads) + else: + model = torch_geometric.nn.models.GCN( + dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes, + ) + + print("Data =", data) + if args.n_devices == -1: + world_size = torch.cuda.device_count() + else: + world_size = args.n_devices + print('Let\'s use', world_size, 'GPUs!') + with tempfile.TemporaryDirectory() as tempdir: + if world_size > 1: + mp.spawn( + run_train, + args=(data, world_size, model, args.epochs, args.batch_size, + args.fan_out, split_idx, dataset.num_classes, + wall_clock_start, tempdir, args.num_layers), + nprocs=world_size, join=True) + else: + run_train(0, data, world_size, model, args.epochs, args.batch_size, + args.fan_out, split_idx, dataset.num_classes, + wall_clock_start, tempdir, args.num_layers) diff --git a/examples/multi_gpu/papers100m_gcn_cugraph.py b/examples/multi_gpu/papers100m_gcn_cugraph.py new file mode 100644 index 000000000000..0473da463683 --- /dev/null +++ b/examples/multi_gpu/papers100m_gcn_cugraph.py @@ -0,0 +1,333 @@ +import argparse +import os +import tempfile +import time + +import numpy as np +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn.functional as F +from ogb.nodeproppred import PygNodePropPredDataset +from torch.distributed.algorithms.join import Join +from torch.nn.parallel import DistributedDataParallel +from torchmetrics import Accuracy + +import torch_geometric + +# Allow computation on objects that are larger than GPU memory +# https://docs.rapids.ai/api/cudf/stable/developer_guide/library_design/#spilling-to-host-memory +os.environ['CUDF_SPILL'] = '1' + +# Ensures that a CUDA context is not created on import of rapids. 
+# Allows pytorch to create the context instead +os.environ['RAPIDS_NO_INITIALIZE'] = '1' + + +def start_dask_cluster(): + from cugraph.testing.mg_utils import enable_spilling + from dask_cuda import LocalCUDACluster + + cluster = LocalCUDACluster( + protocol="tcp", + rmm_pool_size=None, + memory_limit=None, + rmm_async=True, + ) + + from dask.distributed import Client + client = Client(cluster) + client.wait_for_workers(n_workers=len(cluster.workers)) + client.run(enable_spilling) + + print("Dask Cluster Setup Complete") + return client, cluster + + +def shutdown_dask_client(client): + from cugraph.dask.comms import comms as Comms + Comms.destroy() + client.close() + + +def pyg_num_work(world_size): + num_work = None + if hasattr(os, "sched_getaffinity"): + try: + num_work = len(os.sched_getaffinity(0)) / (2 * world_size) + except Exception: + pass + if num_work is None: + num_work = os.cpu_count() / (2 * world_size) + return int(num_work) + + +def init_pytorch_worker(rank, world_size): + import rmm + if rank > 0: + rmm.reinitialize(devices=rank) + + import cupy + cupy.cuda.Device(rank).use() + from rmm.allocators.cupy import rmm_cupy_allocator + cupy.cuda.set_allocator(rmm_cupy_allocator) + + from cugraph.testing.mg_utils import enable_spilling + enable_spilling() + + torch.cuda.set_device(rank) + + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + dist.init_process_group('nccl', rank=rank, world_size=world_size) + + +def run_train(rank, data, world_size, model, epochs, batch_size, fan_out, + split_idx, num_classes, wall_clock_start, tempdir=None, + num_layers=3): + + init_pytorch_worker( + rank, + world_size, + ) + + if rank == 0: + client, cluster = start_dask_cluster() + from cugraph.dask.comms import comms as Comms + Comms.initialize(p2p=True) + model = model.to(rank) + model = DistributedDataParallel(model, device_ids=[rank]) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01, + weight_decay=0.0005) + + kwargs = dict( + num_neighbors=[fan_out] * num_layers, + batch_size=batch_size, + ) + # Set Up Neighbor Loading + import cugraph + from cugraph_pyg.data import CuGraphStore + from cugraph_pyg.loader import BulkSampleLoader + + # define the edges of the Graph + G = {("N", "E", "N"): data.edge_index} + # define the number of nodes in Graph + N = {"N": data.num_nodes} + # initialize feature store + fs = cugraph.gnn.FeatureStore(backend="torch") + # store node features as x + fs.add_data(data.x, "N", "x") + # store node labels as y + fs.add_data(data.y, "N", "y") + dist.barrier() + + if rank == 0: + print("Rank 0 creating its cugraph store and \ + initializing distributed graph") + cugraph_store = CuGraphStore(fs, G, N, multi_gpu=True) + print("Distributed graph initialization complete.") + + if rank != 0: + print(f"Rank {rank} waiting for distributed graph initialization") + dist.barrier() + + if rank != 0: + print(f"Rank {rank} proceeding with store creation") + cugraph_store = CuGraphStore(fs, { + k: len(v) + for k, v in G.items() + }, N, multi_gpu=False) + print(f"Rank {rank} created store") + dist.barrier() + + if rank == 0: + # Direct cuGraph to sample offline prior to the training loop + # Sampling will occur in parallel but will be initiated on rank 0 + for epoch in range(epochs): + train_path = os.path.join(tempdir, f'samples_{epoch}') + os.mkdir(train_path) + BulkSampleLoader(cugraph_store, cugraph_store, + input_nodes=split_idx['train'], + directory=train_path, shuffle=True, + drop_last=True, **kwargs) + + print('validation', 
len(split_idx['valid'])) + eval_path = os.path.join(tempdir, f'samples_eval_{epoch}') + BulkSampleLoader(cugraph_store, cugraph_store, + input_nodes=split_idx['valid'], + directory=eval_path, **kwargs) + + print('test', len(split_idx['test'])) + test_path = os.path.join(tempdir, 'samples_test') + BulkSampleLoader(cugraph_store, cugraph_store, + input_nodes=split_idx['test'], directory=test_path, + **kwargs) + + dist.barrier() + + eval_steps = 1000 + warmup_steps = 20 + acc = Accuracy(task="multiclass", num_classes=num_classes).to(rank) + dist.barrier() + torch.cuda.synchronize() + if rank == 0: + prep_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total time before training begins (prep_time) =", prep_time, + "seconds") + print("Beginning training...") + for epoch in range(epochs): + train_path = os.path.join(tempdir, f'samples_{epoch}') + + input_files = np.array_split(np.array(os.listdir(train_path)), + world_size)[rank] + + train_loader = BulkSampleLoader(cugraph_store, cugraph_store, + directory=train_path, + input_files=input_files) + with Join([model], divide_by_initial_world_size=False): + for i, batch in enumerate(train_loader): + if i == warmup_steps: + torch.cuda.synchronize() + start = time.time() + batch = batch.to(rank) + + batch = batch.to_homogeneous() + batch_size = batch.num_sampled_nodes[0] + + batch.y = batch.y.to(torch.long) + optimizer.zero_grad() + out = model(batch.x, batch.edge_index) + loss = F.cross_entropy(out[:batch_size], batch.y[:batch_size]) + loss.backward() + optimizer.step() + if rank == 0 and i % 10 == 0: + print("Epoch: " + str(epoch) + ", Iteration: " + str(i) + + ", Loss: " + str(loss)) + nb = i + 1.0 + dist.barrier() + torch.cuda.synchronize() + if rank == 0: + print("Average Training Iteration Time:", + (time.time() - start) / (nb - warmup_steps), "s/iter") + eval_path = os.path.join(tempdir, f'samples_eval_{epoch}') + + input_files = np.array(os.listdir(eval_path)) + + eval_loader = BulkSampleLoader(cugraph_store, cugraph_store, + directory=eval_path, + input_files=input_files) + with Join([model], divide_by_initial_world_size=False): + with torch.no_grad(): + for i, batch in enumerate(eval_loader): + if i >= eval_steps: + break + + batch = batch.to(rank) + batch = batch.to_homogeneous() + batch_size = batch.num_sampled_nodes[0] + + batch.y = batch.y.to(torch.long) + out = model.module(batch.x, batch.edge_index) + acc_i = acc( # noqa + out[:batch_size].softmax(dim=-1), batch.y[:batch_size]) + acc_sum = acc.compute() + if rank == 0: + print(f"Validation Accuracy: {acc_sum * 100.0:.4f}%", ) + dist.barrier() + + with Join([model], divide_by_initial_world_size=False): + test_path = os.path.join(tempdir, 'samples_test') + + input_files = np.array(os.listdir(test_path)) + + test_loader = BulkSampleLoader(cugraph_store, cugraph_store, + directory=test_path, + input_files=input_files) + with torch.no_grad(): + for i, batch in enumerate(test_loader): + batch = batch.to(rank) + batch = batch.to_homogeneous() + batch_size = batch.num_sampled_nodes[0] + + batch.y = batch.y.to(torch.long) + out = model.module(batch.x, batch.edge_index) + acc_i = acc( # noqa + out[:batch_size].softmax(dim=-1), batch.y[:batch_size]) + acc_sum = acc.compute() + if rank == 0: + print(f"Test Accuracy: {acc_sum * 100.0:.4f}%", ) + dist.barrier() + + import gc + del cugraph_store + gc.collect() + shutdown_dask_client(client) + dist.barrier() + if rank == 0: + total_time = round(time.perf_counter() - wall_clock_start, 2) + print("Total Program Runtime (total_time) =", 
total_time, "seconds") + print("total_time - prep_time =", total_time - prep_time, "seconds") + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument('--hidden_channels', type=int, default=256) + parser.add_argument('--num_layers', type=int, default=2) + parser.add_argument('--lr', type=float, default=0.001) + parser.add_argument('--epochs', type=int, default=20) + parser.add_argument('--batch_size', type=int, default=1024) + parser.add_argument('--fan_out', type=int, default=30) + parser.add_argument( + "--use_gat_conv", + action='store_true', + help="Whether or not to use GATConv. (Defaults to using GCNConv)", + ) + parser.add_argument( + "--n_gat_conv_heads", + type=int, + default=4, + help="If using GATConv, number of attention heads to use", + ) + parser.add_argument( + "--n_devices", type=int, default=-1, + help="1-8 to use that many GPUs. Defaults to all available GPUs") + + args = parser.parse_args() + wall_clock_start = time.perf_counter() + + dataset = PygNodePropPredDataset(name='ogbn-papers100M', + root='/datasets/ogb_datasets') + split_idx = dataset.get_idx_split() + data = dataset[0] + data.y = data.y.reshape(-1) + if args.use_gat_conv: + model = torch_geometric.nn.models.GAT(dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes, + heads=args.n_gat_conv_heads) + else: + model = torch_geometric.nn.models.GCN(dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes) + + print("Data =", data) + if args.n_devices == -1: + world_size = torch.cuda.device_count() + else: + world_size = args.n_devices + print('Let\'s use', world_size, 'GPUs!') + with tempfile.TemporaryDirectory() as tempdir: + if world_size > 1: + mp.spawn( + run_train, + args=(data, world_size, model, args.epochs, args.batch_size, + args.fan_out, split_idx, dataset.num_classes, + wall_clock_start, tempdir, args.num_layers), + nprocs=world_size, join=True) + else: + run_train(0, data, world_size, model, args.epochs, args.batch_size, + args.fan_out, split_idx, dataset.num_classes, + wall_clock_start, tempdir, args.num_layers) diff --git a/examples/ogbn_papers_100m.py b/examples/ogbn_papers_100m.py index 68812d08d99d..56e55119ad49 100644 --- a/examples/ogbn_papers_100m.py +++ b/examples/ogbn_papers_100m.py @@ -1,3 +1,4 @@ +import argparse import os import time from typing import Optional @@ -6,12 +7,33 @@ import torch.nn.functional as F from ogb.nodeproppred import PygNodePropPredDataset +import torch_geometric from torch_geometric.loader import NeighborLoader -from torch_geometric.nn import GCNConv +parser = argparse.ArgumentParser() +parser.add_argument('--hidden_channels', type=int, default=256) +parser.add_argument('--num_layers', type=int, default=2) +parser.add_argument('--lr', type=float, default=0.001) +parser.add_argument('--epochs', type=int, default=20) +parser.add_argument('--batch_size', type=int, default=1024) +parser.add_argument('--fan_out', type=int, default=30) +parser.add_argument( + "--use_gat_conv", + action='store_true', + help="Wether or not to use GATConv. 
(Defaults to using GCNConv)", +) +parser.add_argument( + "--n_gat_conv_heads", + type=int, + default=4, + help="If using GATConv, number of attention heads to use", +) +args = parser.parse_args() +wall_clock_start = time.perf_counter() device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') -dataset = PygNodePropPredDataset(name='ogbn-papers100M') +dataset = PygNodePropPredDataset(name='ogbn-papers100M', + root='/datasets/ogb_datasets') split_idx = dataset.get_idx_split() @@ -23,51 +45,58 @@ def get_num_workers() -> int: kwargs = dict( - data=dataset[0], - num_neighbors=[50, 50], - batch_size=128, - num_workers=get_num_workers(), + num_neighbors=[args.fan_out] * args.num_layers, + batch_size=args.batch_size, ) -train_loader = NeighborLoader(input_nodes=split_idx['train'], shuffle=True, - **kwargs) -val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs) -test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs) - - -class GCN(torch.nn.Module): - def __init__(self, in_channels, hidden_channels, out_channels): - super().__init__() - self.conv1 = GCNConv(in_channels, hidden_channels) - self.conv2 = GCNConv(hidden_channels, out_channels) - - def forward(self, x, edge_index): - x = F.dropout(x, p=0.5, training=self.training) - x = self.conv1(x, edge_index).relu() - x = F.dropout(x, p=0.5, training=self.training) - x = self.conv2(x, edge_index) - return x - - -model = GCN(dataset.num_features, 64, dataset.num_classes).to(device) -optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=0.0005) +# Set Up Neighbor Loading +data = dataset[0] +num_work = get_num_workers() +train_loader = NeighborLoader(data=data, input_nodes=split_idx['train'], + num_workers=num_work, drop_last=True, + shuffle=False, **kwargs) +val_loader = NeighborLoader(data=data, input_nodes=split_idx['valid'], + num_workers=num_work, **kwargs) +test_loader = NeighborLoader(data=data, input_nodes=split_idx['test'], + num_workers=num_work, **kwargs) + +if args.use_gat_conv: + model = torch_geometric.nn.models.GAT( + dataset.num_features, args.hidden_channels, args.num_layers, + dataset.num_classes, heads=args.n_gat_conv_heads).to(device) +else: + model = torch_geometric.nn.models.GCN( + dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes, + ).to(device) + +optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, + weight_decay=0.0005) + +warmup_steps = 20 def train(): model.train() - for i, batch in enumerate(train_loader): - start = time.perf_counter() + if i == warmup_steps: + torch.cuda.synchronize() + start_avg_time = time.perf_counter() batch = batch.to(device) optimizer.zero_grad() - out = model(batch.x, batch.edge_index)[:batch.batch_size] - y = batch.y[:batch.batch_size].view(-1).to(torch.long) + batch_size = batch.num_sampled_nodes[0] + out = model(batch.x, batch.edge_index)[:batch_size] + y = batch.y[:batch_size].view(-1).to(torch.long) loss = F.cross_entropy(out, y) loss.backward() optimizer.step() if i % 10 == 0: - print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}, ' - f's/iter: {time.perf_counter() - start:.6f}') + print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}') + torch.cuda.synchronize() + print(f'Average Training Iteration Time (s/iter): \ + {(time.perf_counter() - start_avg_time)/(i-warmup_steps):.6f}') @torch.no_grad() @@ -78,11 +107,11 @@ def test(loader: NeighborLoader, val_steps: Optional[int] = None): for i, batch in enumerate(loader): if val_steps is not None and i >= val_steps: break - batch 
= batch.to(device) - out = model(batch.x, batch.edge_index)[:batch.batch_size] + batch_size = batch.num_sampled_nodes[0] + out = model(batch.x, batch.edge_index)[:batch_size] pred = out.argmax(dim=-1) - y = batch.y[:batch.batch_size].view(-1).to(torch.long) + y = batch.y[:batch_size].view(-1).to(torch.long) total_correct += int((pred == y).sum()) total_examples += y.size(0) @@ -90,10 +119,17 @@ def test(loader: NeighborLoader, val_steps: Optional[int] = None): return total_correct / total_examples -for epoch in range(1, 4): +torch.cuda.synchronize() +prep_time = round(time.perf_counter() - wall_clock_start, 2) +print("Total time before training begins (prep_time)=", prep_time, "seconds") +print("Beginning training...") +for epoch in range(1, 1 + args.epochs): train() val_acc = test(val_loader, val_steps=100) print(f'Val Acc: ~{val_acc:.4f}') test_acc = test(test_loader) print(f'Test Acc: {test_acc:.4f}') +total_time = round(time.perf_counter() - wall_clock_start, 2) +print("Total Program Runtime (total_time) =", total_time, "seconds") +print("total_time - prep_time =", total_time - prep_time, "seconds") diff --git a/examples/ogbn_papers_100m_cugraph.py b/examples/ogbn_papers_100m_cugraph.py new file mode 100644 index 000000000000..c9df4e07e597 --- /dev/null +++ b/examples/ogbn_papers_100m_cugraph.py @@ -0,0 +1,163 @@ +import argparse +import os +import time +from typing import Optional + +import cupy +import rmm +import torch +from rmm.allocators.cupy import rmm_cupy_allocator +from rmm.allocators.torch import rmm_torch_allocator + +# Must change allocators immediately upon import +# or else other imports will cause memory to be +# allocated and prevent changing the allocator +rmm.reinitialize(devices=[0], pool_allocator=True, managed_memory=True) +cupy.cuda.set_allocator(rmm_cupy_allocator) +torch.cuda.memory.change_current_allocator(rmm_torch_allocator) + +import cugraph # noqa +import torch.nn.functional as F # noqa +from cugraph.testing.mg_utils import enable_spilling # noqa +from cugraph_pyg.data import CuGraphStore # noqa +from cugraph_pyg.loader import CuGraphNeighborLoader # noqa + +import torch_geometric # noqa +from torch_geometric.loader import NeighborLoader # noqa + +parser = argparse.ArgumentParser() +parser.add_argument('--hidden_channels', type=int, default=256) +parser.add_argument('--num_layers', type=int, default=2) +parser.add_argument('--lr', type=float, default=0.001) +parser.add_argument('--epochs', type=int, default=20) +parser.add_argument('--batch_size', type=int, default=1024) +parser.add_argument('--fan_out', type=int, default=30) +parser.add_argument( + "--use_gat_conv", + action='store_true', + help="Wether or not to use GATConv. 
(Defaults to using GCNConv)", +) +parser.add_argument( + "--n_gat_conv_heads", + type=int, + default=4, + help="If using GATConv, number of attention heads to use", +) +args = parser.parse_args() +wall_clock_start = time.perf_counter() + + +def get_num_workers() -> int: + try: + return len(os.sched_getaffinity(0)) // 2 + except Exception: + return os.cpu_count() // 2 + + +kwargs = dict( + num_neighbors=[args.fan_out] * args.num_layers, + batch_size=args.batch_size, +) +# Set Up Neighbor Loading +device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + +enable_spilling() + +from ogb.nodeproppred import PygNodePropPredDataset # noqa + +dataset = PygNodePropPredDataset(name='ogbn-papers100M', + root='/datasets/ogb_datasets') +split_idx = dataset.get_idx_split() +data = dataset[0] + +G = {("N", "E", "N"): data.edge_index} +N = {"N": data.num_nodes} +fs = cugraph.gnn.FeatureStore(backend="torch") +fs.add_data(data.x, "N", "x") +fs.add_data(data.y, "N", "y") +cugraph_store = CuGraphStore(fs, G, N) +train_loader = CuGraphNeighborLoader(cugraph_store, + input_nodes=split_idx['train'], + shuffle=True, drop_last=True, **kwargs) +val_loader = CuGraphNeighborLoader(cugraph_store, + input_nodes=split_idx['valid'], **kwargs) +test_loader = CuGraphNeighborLoader(cugraph_store, + input_nodes=split_idx['test'], **kwargs) + +if args.use_gat_conv: + model = torch_geometric.nn.models.GAT( + dataset.num_features, args.hidden_channels, args.num_layers, + dataset.num_classes, heads=args.n_gat_conv_heads).to(device) +else: + model = torch_geometric.nn.models.GCN( + dataset.num_features, + args.hidden_channels, + args.num_layers, + dataset.num_classes, + ).to(device) + +optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, + weight_decay=0.0005) + +warmup_steps = 20 + + +def train(): + model.train() + for i, batch in enumerate(train_loader): + batch = batch.to_homogeneous() + + if i == warmup_steps: + torch.cuda.synchronize() + start_avg_time = time.perf_counter() + batch = batch.to(device) + optimizer.zero_grad() + batch_size = batch.num_sampled_nodes[0] + out = model(batch.x, batch.edge_index)[:batch_size] + y = batch.y[:batch_size].view(-1).to(torch.long) + loss = F.cross_entropy(out, y) + loss.backward() + optimizer.step() + + if i % 10 == 0: + print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}') + torch.cuda.synchronize() + print(f'Average Training Iteration Time (s/iter): \ + {(time.perf_counter() - start_avg_time)/(i-warmup_steps):.6f}') + + +@torch.no_grad() +def test(loader: NeighborLoader, val_steps: Optional[int] = None): + model.eval() + + total_correct = total_examples = 0 + for i, batch in enumerate(loader): + if val_steps is not None and i >= val_steps: + break + batch = batch.to_homogeneous() + batch = batch.to(device) + batch_size = batch.num_sampled_nodes[0] + out = model(batch.x, batch.edge_index)[:batch_size] + pred = out.argmax(dim=-1) + y = batch.y[:batch_size].view(-1).to(torch.long) + + total_correct += int((pred == y).sum()) + total_examples += y.size(0) + + return total_correct / total_examples + + +torch.cuda.synchronize() +prep_time = round(time.perf_counter() - wall_clock_start, 2) +print("Total time before training begins (prep_time)=", prep_time, "seconds") +print("Beginning training...") +for epoch in range(1, 1 + args.epochs): + train() + val_acc = test(val_loader, val_steps=100) + print(f'Val Acc: ~{val_acc:.4f}') + +test_acc = test(test_loader) +print(f'Test Acc: {test_acc:.4f}') +total_time = round(time.perf_counter() - wall_clock_start, 2) 
+print("Total Program Runtime (total_time) =", total_time, "seconds") +print("total_time - prep_time =", total_time - prep_time, "seconds")