Skip to content

Commit

Permalink
torch.compile benchmark for HeteroConv; Allow device conversions …
Browse files Browse the repository at this point in the history
…of datasets (#8402)

```
+----------+-----------+------------+---------+
| Name     | Forward   | Backward   | Total   |
|----------+-----------+------------+---------|
| Vanilla  | 2.0331s   | 0.5175s    | 2.5505s |
| Compiled | 0.6899s   | 0.6946s    | 1.3844s |
+----------+-----------+------------+---------+
```
  • Loading branch information
rusty1s authored Nov 19, 2023
1 parent 5afd075 commit ea2ab70
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for device conversions of `InMemoryDataset` ([#8402](https://github.com/pyg-team/pytorch_geometric/pull/8402))
- Added support for edge-level temporal sampling in `NeighborLoader` and `LinkNeighborLoader` ([#8372](https://github.com/pyg-team/pytorch_geometric/pull/8372))
- Added support for `torch.compile` in `ModuleDict` and `ParameterDict` ([#8363](https://github.com/pyg-team/pytorch_geometric/pull/8363))
- Added `force_reload` option to `Dataset` and `InMemoryDataset` to reload datasets ([#8352](https://github.com/pyg-team/pytorch_geometric/pull/8352), [#8357](https://github.com/pyg-team/pytorch_geometric/pull/8357))
Expand Down
65 changes: 65 additions & 0 deletions test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import random

import pytest
import torch

import torch_geometric
from torch_geometric.data import HeteroData
from torch_geometric.datasets import FakeHeteroDataset
from torch_geometric.nn import (
GATConv,
GCN2Conv,
Expand All @@ -12,6 +15,7 @@
MessagePassing,
SAGEConv,
)
from torch_geometric.profile import benchmark
from torch_geometric.testing import (
disableExtensions,
get_random_edge_index,
Expand Down Expand Up @@ -205,3 +209,64 @@ def test_compile_hetero_conv_graph_breaks(device):
assert len(out) == len(expected)
for key in expected.keys():
assert torch.allclose(out[key], expected[key], atol=1e-6)


if __name__ == '__main__':
    # Stand-alone benchmark entry point: times an eager `HeteroConv` stack
    # against its `torch_geometric.compile`-d counterpart.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument('--device', type=str, default='cuda')
    cli.add_argument('--backward', action='store_true')
    args = cli.parse_args()

    dataset = FakeHeteroDataset(num_graphs=10).to(args.device)

    def gen_args():
        """Return the feature and edge-index dicts of a random graph."""
        idx = random.randrange(len(dataset))
        sample = dataset[idx]
        return sample.x_dict, sample.edge_index_dict

    class HeteroGNN(torch.nn.Module):
        """Stack of `HeteroConv` layers followed by a linear read-out."""
        def __init__(self, channels: int = 32, num_layers: int = 2):
            super().__init__()
            self.convs = torch.nn.ModuleList()

            # Input layer: one `SAGEConv` per edge type, sized from the
            # dataset's per-node-type feature dimensions.
            input_convs = {}
            for edge_type in dataset[0].edge_types:
                input_convs[edge_type] = SAGEConv(
                    in_channels=(
                        dataset.num_features[edge_type[0]],
                        dataset.num_features[edge_type[-1]],
                    ),
                    out_channels=channels,
                )
            self.convs.append(HeteroConv(input_convs))

            # Remaining layers operate on `channels`-sized features only.
            for _ in range(num_layers - 1):
                hidden_convs = {}
                for edge_type in dataset[0].edge_types:
                    hidden_convs[edge_type] = SAGEConv(
                        (channels, channels),
                        channels,
                    )
                self.convs.append(HeteroConv(hidden_convs))

            self.lin = Linear(channels, 1)

        def forward(self, x_dict, edge_index_dict):
            for layer in self.convs:
                x_dict = layer(x_dict, edge_index_dict)
                x_dict = {name: feat.relu() for name, feat in x_dict.items()}
            # Only the 'v0' node type is fed to the read-out layer.
            return self.lin(x_dict['v0'])

    model = HeteroGNN().to(args.device)
    compiled_model = torch_geometric.compile(model)

    # Use a lighter schedule when benchmarking on CPU.
    on_cpu = args.device == 'cpu'
    benchmark(
        funcs=[model, compiled_model],
        func_names=['Vanilla', 'Compiled'],
        args=gen_args,
        num_steps=50 if on_cpu else 500,
        num_warmups=10 if on_cpu else 100,
        backward=args.backward,
    )
25 changes: 25 additions & 0 deletions torch_geometric/data/in_memory_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,31 @@ def __getattr__(self, key: str) -> Any:
raise AttributeError(f"'{self.__class__.__name__}' object has no "
f"attribute '{key}'")

def to(self, device: Union[int, str]) -> 'InMemoryDataset':
    r"""Performs device conversion of the whole dataset.

    The conversion is delegated to ``self._data.to(device)``, whose return
    value is discarded; the dataset itself is returned so calls can be
    chained.

    Raises:
        ValueError: If ``self._indices`` is set (the dataset only references
            a subset of examples) or if ``self._data_list`` is set (the
            data is already cached).
    """
    is_subset = self._indices is not None
    if is_subset:
        message = ("The given 'InMemoryDataset' only references a "
                   "subset of examples of the full dataset")
        raise ValueError(message)

    is_cached = self._data_list is not None
    if is_cached:
        raise ValueError("The data of the dataset is already cached")

    storage = self._data
    storage.to(device)
    return self

def cpu(self, *args: str) -> 'InMemoryDataset':
    r"""Moves the dataset to CPU memory.

    Any positional arguments are accepted but ignored; the call is
    forwarded to ``self.to`` with a CPU :class:`torch.device`.
    """
    cpu_device = torch.device('cpu')
    return self.to(cpu_device)

def cuda(
    self,
    device: Optional[Union[int, str]] = None,
) -> 'InMemoryDataset':
    r"""Moves the dataset to CUDA memory.

    Args:
        device (int or str, optional): The target CUDA device. An integer is
            treated as a device index and expanded to ``"cuda:<index>"``;
            :obj:`None` selects ``"cuda"``; a string is forwarded unchanged.
            (default: :obj:`None`)

    Returns:
        Whatever ``self.to`` returns for the resolved device.
    """
    if isinstance(device, int):
        # Interpolate the given index; the previous code formatted the
        # builtin `int` type and produced "cuda:<class 'int'>".
        device = f'cuda:{device}'
    elif device is None:
        device = 'cuda'
    return self.to(device)


def nested_iter(node: Union[Mapping, Sequence]) -> Iterable:
if isinstance(node, Mapping):
Expand Down

0 comments on commit ea2ab70

Please sign in to comment.