Improvements to multinode papers100m default hyperparams and adding eval on all ranks #8823

Merged: 61 commits from improve-multinode-example into master, Mar 12, 2024
Changes from 43 commits

Commits (61)
650542a
Improvements to multinode papers100m default hyperparams and adding e…
puririshi98 Jan 25, 2024
98ba40e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
deec969
Update CHANGELOG.md
puririshi98 Jan 25, 2024
ab0fad4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2024
729a524
cleanup
puririshi98 Jan 26, 2024
a152835
fixing
puririshi98 Jan 26, 2024
281de73
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 26, 2024
acce4de
graphsage
puririshi98 Jan 26, 2024
07c5781
back to GCN
puririshi98 Jan 26, 2024
2d37cc7
specify download location
puririshi98 Jan 26, 2024
d2c851b
better hyperparams
puririshi98 Jan 29, 2024
f39c0cc
Merge branch 'master' into improve-multinode-example
puririshi98 Jan 29, 2024
f167890
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 29, 2024
7bc1ecc
adding cuda sync
puririshi98 Jan 29, 2024
87cd881
cuda sync
puririshi98 Jan 29, 2024
0ffdccd
new hyperparams
puririshi98 Jan 30, 2024
60b6db4
cuda syncs for timing
puririshi98 Jan 31, 2024
f1894a5
better timing
puririshi98 Jan 31, 2024
2e60980
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2024
b23f8a6
clean up
puririshi98 Feb 5, 2024
42dcac0
cleaning
puririshi98 Feb 5, 2024
907af56
better timing
puririshi98 Feb 5, 2024
de555d5
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 6, 2024
db9c123
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2024
9e776eb
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 7, 2024
d98352c
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 8, 2024
2868327
fix
puririshi98 Feb 15, 2024
9c90504
fix for eval
puririshi98 Feb 16, 2024
69c942a
fixing copypaste from SNMG
puririshi98 Feb 16, 2024
4650c91
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 16, 2024
e5bbd30
final cleanup, its running well now
puririshi98 Feb 16, 2024
82cc227
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 20, 2024
c1ea4d6
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 21, 2024
10939f3
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 21, 2024
8aec67c
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 23, 2024
303338d
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 26, 2024
7b329f3
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 27, 2024
469295c
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 28, 2024
054eb5c
Merge branch 'master' into improve-multinode-example
puririshi98 Feb 29, 2024
983208d
Merge branch 'master' into improve-multinode-example
puririshi98 Mar 1, 2024
1228066
Merge branch 'master' into improve-multinode-example
puririshi98 Mar 1, 2024
b42d29a
Merge branch 'master' into improve-multinode-example
puririshi98 Mar 4, 2024
e9fdad1
Merge branch 'master' into improve-multinode-example
puririshi98 Mar 5, 2024
23179c3
using acc.compute to align with mag240m and single node papers100m ex…
puririshi98 Mar 7, 2024
a36248f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
e0b2301
Update papers100m_gcn_multinode.py
puririshi98 Mar 7, 2024
b7fe928
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
7c38068
cleaning
puririshi98 Mar 7, 2024
8eaa4b9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
55b4f71
Update papers100m_gcn_multinode.py
puririshi98 Mar 7, 2024
3e1d879
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 7, 2024
a994b62
Update papers100m_gcn_multinode.py
puririshi98 Mar 7, 2024
3a91d03
Update papers100m_gcn_multinode.py
puririshi98 Mar 7, 2024
8ebf762
Update papers100m_gcn_multinode.py
puririshi98 Mar 8, 2024
e10a6ee
Update papers100m_gcn_multinode.py
puririshi98 Mar 8, 2024
3b43632
Merge branch 'master' into improve-multinode-example
puririshi98 Mar 8, 2024
ad4497c
Merge branch 'master' into improve-multinode-example
puririshi98 Mar 11, 2024
8723521
update
rusty1s Mar 12, 2024
eff9ef6
update
rusty1s Mar 12, 2024
ede89ac
update
rusty1s Mar 12, 2024
e518abc
Merge branch 'master' into improve-multinode-example
rusty1s Mar 12, 2024
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Improvements to multinode papers100m default hyperparams and adding eval on all ranks ([#8823](https://github.com/pyg-team/pytorch_geometric/pull/8823))
- Added support for `EdgeIndex` in `MessagePassing` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
- Added support for `torch.compile` in combination with `EdgeIndex` ([#9007](https://github.com/pyg-team/pytorch_geometric/pull/9007))
- Added a `ogbn-mag240m` example ([#8249](https://github.com/pyg-team/pytorch_geometric/pull/8249))
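The "adding eval on all ranks" item above replaces a rank-0-only evaluation loop with per-rank partial results that are combined via torch.distributed.all_reduce, as shown in the file diff below. A minimal sketch of that aggregation pattern, assuming the process group is already initialized and each rank has computed its own partial sums (the helper name is illustrative, not part of the PR):

    import torch
    import torch.distributed as dist

    def all_ranks_accuracy(local_acc_sum: float, local_num_batches: int,
                           device: torch.device) -> float:
        # Each rank contributes its own partial sums; the all-reduces make
        # the global totals identical on every rank.
        acc_sum = torch.tensor(local_acc_sum, dtype=torch.float32, device=device)
        num_batches = torch.tensor(float(local_num_batches), dtype=torch.float32,
                                   device=device)
        dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(num_batches, op=dist.ReduceOp.SUM)
        return (acc_sum / num_batches).item()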
150 changes: 84 additions & 66 deletions examples/multi_gpu/papers100m_gcn_multinode.py
@@ -1,6 +1,6 @@
"""Multi-node multi-GPU example on ogbn-papers100m.

To run:
Example way to run using srun:
srun -l -N<num_nodes> --ntasks-per-node=<ngpu_per_node> \
--container-name=cont --container-image=<image_url> \
--container-mounts=/ogb-papers100m/:/workspace/dataset
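Each task launched by the srun template above becomes one training process. Inside the script, every process reads its LOCAL_RANK from the environment (the launcher, or a wrapper such as torchrun, is assumed to export it), pins the matching GPU, and joins the NCCL process group. A minimal sketch of that per-process setup, with an illustrative helper name; the example itself does the equivalent in run() and __main__:

    import os

    import torch
    import torch.distributed as dist

    def setup_distributed() -> int:
        # LOCAL_RANK is assumed to be exported for every process by the launcher.
        local_id = int(os.environ['LOCAL_RANK'])
        torch.cuda.set_device(local_id)  # one GPU per process
        dist.init_process_group('nccl')  # rendezvous across all nodes
        return local_id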
@@ -14,9 +14,10 @@
import torch.nn.functional as F
from ogb.nodeproppred import PygNodePropPredDataset
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import Accuracy

from torch_geometric.loader import NeighborLoader
from torch_geometric.nn import GCNConv
from torch_geometric.nn.models import GCN


def get_num_workers() -> int:
@@ -31,21 +32,7 @@ def get_num_workers() -> int:
return num_workers


class GCN(torch.nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels):
super().__init__()
self.conv1 = GCNConv(in_channels, hidden_channels)
self.conv2 = GCNConv(hidden_channels, out_channels)

def forward(self, x, edge_index):
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv1(x, edge_index).relu()
x = F.dropout(x, p=0.5, training=self.training)
x = self.conv2(x, edge_index)
return x


def run(world_size, data, split_idx, model):
def run(world_size, data, split_idx, model, acc, wall_clock_start):
local_id = int(os.environ['LOCAL_RANK'])
rank = torch.distributed.get_rank()
torch.cuda.set_device(local_id)
@@ -54,97 +41,128 @@ def run(world_size, data, split_idx, model):
print(f'Using {nprocs} GPUs...')

split_idx['train'] = split_idx['train'].split(
split_idx['train'].size(0) // world_size,
dim=0,
)[rank].clone()
split_idx['train'].size(0) // world_size, dim=0)[rank].clone()
split_idx['valid'] = split_idx['valid'].split(
split_idx['valid'].size(0) // world_size, dim=0)[rank].clone()
split_idx['test'] = split_idx['test'].split(
split_idx['test'].size(0) // world_size, dim=0)[rank].clone()

model = DistributedDataParallel(model.to(device), device_ids=[local_id])
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001,
weight_decay=5e-4)

kwargs = dict(
data=data,
batch_size=128,
batch_size=1024,
num_workers=get_num_workers(),
num_neighbors=[50, 50],
num_neighbors=[30, 30],
)

train_loader = NeighborLoader(
input_nodes=split_idx['train'],
shuffle=True,
**kwargs,
)
if rank == 0:
val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)
val_loader = NeighborLoader(input_nodes=split_idx['valid'], **kwargs)
test_loader = NeighborLoader(input_nodes=split_idx['test'], **kwargs)

val_steps = 1000
warmup_steps = 100
acc = acc.to(device)
dist.barrier()
torch.cuda.synchronize()
if rank == 0:
prep_time = round(time.perf_counter() - wall_clock_start, 2)
print("Total time before training begins (prep_time)=", prep_time,
"seconds")
print("Beginning training...")

for epoch in range(1, 4):
for epoch in range(1, 21):
model.train()
for i, batch in enumerate(train_loader):
if i == warmup_steps:
torch.cuda.synchronize()
start = time.time()
batch = batch.to(device)
batch_size = batch.batch_size
optimizer.zero_grad()
y = batch.y[:batch.batch_size].view(-1).to(torch.long)
out = model(batch.x, batch.edge_index)[:batch.batch_size]
y = batch.y[:batch_size].view(-1).to(torch.long)
out = model(batch.x, batch.edge_index)[:batch_size]
loss = F.cross_entropy(out, y)
loss.backward()
optimizer.step()

if rank == 0 and i % 10 == 0:
print(f'Epoch: {epoch:02d}, Iteration: {i}, Loss: {loss:.4f}')

dist.barrier()
torch.cuda.synchronize()
num_batches = i + 1.0
if rank == 0:
sec_per_iter = (time.time() - start) / (i - warmup_steps)
sec_per_iter = (time.time() - start) / (num_batches - warmup_steps)
print(f"Avg Training Iteration Time: {sec_per_iter:.6f} s/iter")

model.eval()
total_correct = total_examples = 0
for i, batch in enumerate(val_loader):
if i >= val_steps:
break
if i == warmup_steps:
start = time.time()

batch = batch.to(device)
with torch.no_grad():
out = model(batch.x, batch.edge_index)[:batch.batch_size]
pred = out.argmax(dim=-1)
y = batch.y[:batch.batch_size].view(-1).to(torch.long)

total_correct += int((pred == y).sum())
total_examples += y.size(0)

print(f"Val Acc: {total_correct / total_examples:.4f}")
sec_per_iter = (time.time() - start) / (i - warmup_steps)
print(f"Avg Inference Iteration Time: {sec_per_iter:.6f} s/iter")

if rank == 0:
model.eval()
total_correct = total_examples = 0
for i, batch in enumerate(test_loader):
acc_sum = 0.0
for i, batch in enumerate(val_loader):
if i >= val_steps:
break
if i == warmup_steps:
torch.cuda.synchronize()
start = time.time()

batch = batch.to(device)
batch_size = batch.batch_size
with torch.no_grad():
out = model(batch.x, batch.edge_index)[:batch.batch_size]
pred = out.argmax(dim=-1)
y = batch.y[:batch.batch_size].view(-1).to(torch.long)

total_correct += int((pred == y).sum())
total_examples += y.size(0)
print(f"Test Acc: {total_correct / total_examples:.4f}")
out = model(batch.x, batch.edge_index)[:batch_size]
acc_sum += acc(out[:batch_size].softmax(dim=-1),
batch.y[:batch_size])
acc_sum = torch.tensor(float(acc_sum), dtype=torch.float32,
device=device)
dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
num_batches = torch.tensor(float(i + 1), dtype=torch.float32,
device=acc_sum.device)
dist.all_reduce(num_batches, op=dist.ReduceOp.SUM)
torch.cuda.synchronize()
if rank == 0:
print(
f"Validation Accuracy: {acc_sum/(num_batches) * 100.0:.4f}%", )
sec_per_iter = (time.time() - start) / (num_batches - warmup_steps)
print(f"Avg Inference Iteration Time: {sec_per_iter:.6f} s/iter")
dist.barrier()

model.eval()
acc_sum = 0.0
for i, batch in enumerate(test_loader):
batch = batch.to(device)
batch_size = batch.batch_size
with torch.no_grad():
out = model(batch.x, batch.edge_index)[:batch_size]
acc_sum += acc(out[:batch_size].softmax(dim=-1), batch.y[:batch_size])
acc_sum = torch.tensor(float(acc_sum), dtype=torch.float32, device=device)
dist.all_reduce(acc_sum, op=dist.ReduceOp.SUM)
num_batches = torch.tensor(float(i + 1), dtype=torch.float32,
device=acc_sum.device)
dist.all_reduce(num_batches, op=dist.ReduceOp.SUM)
if rank == 0:
print(f"Test Accuracy: {acc_sum/(num_batches) * 100.0:.4f}%", )
dist.barrier()
if rank == 0:
total_time = round(time.perf_counter() - wall_clock_start, 2)
print("Total Program Runtime (total_time) =", total_time, "seconds")
print("total_time - prep_time =", total_time - prep_time, "seconds")


if __name__ == '__main__':
wall_clock_start = time.perf_counter()
# Setup multi-node:
torch.distributed.init_process_group("nccl")
nprocs = dist.get_world_size()
assert dist.is_initialized(), "Distributed cluster not initialized"
dataset = PygNodePropPredDataset(name='ogbn-papers100M')
dataset = PygNodePropPredDataset(name='ogbn-papers100M',
root='/datasets/ogb_datasets')
split_idx = dataset.get_idx_split()
model = GCN(dataset.num_features, 64, dataset.num_classes)

run(nprocs, dataset[0], split_idx, model)
model = GCN(dataset.num_features, 256, 2, dataset.num_classes)
acc = Accuracy(task="multiclass", num_classes=dataset.num_classes)
data = dataset[0]
data.y = data.y.reshape(-1)
run(nprocs, data, split_idx, model, acc, wall_clock_start)
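For reference, the torchmetrics Accuracy object passed into run() accumulates state across calls; a later commit in this PR switches the reporting to acc.compute(), which reads out that accumulated state. A small standalone sketch of both usages, with an illustrative batch size and class count (the example itself uses dataset.num_classes):

    import torch
    from torchmetrics import Accuracy

    num_classes = 10  # illustrative only
    acc = Accuracy(task='multiclass', num_classes=num_classes)

    logits = torch.randn(8, num_classes)       # model outputs for one batch
    target = torch.randint(num_classes, (8,))  # ground-truth labels

    batch_acc = acc(logits.softmax(dim=-1), target)  # per-batch accuracy; also updates state
    overall = acc.compute()                          # accuracy over all batches seen so far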