Skip to content

Commit

Permalink
Extend torch.compile benchmark to dynamic=True; Add `assert_indir…
Browse files Browse the repository at this point in the history
…ect_indexing = False` (#8220)
  • Loading branch information
rusty1s authored Oct 18, 2023
1 parent ee30973 commit 71a7fee
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220))

### Deprecated

### Fixed
Expand Down
22 changes: 18 additions & 4 deletions test/nn/models/test_basic_gnn.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import os.path as osp
import random
import sys
import warnings

Expand Down Expand Up @@ -388,11 +389,24 @@ def test_basic_gnn_cache():
parser = argparse.ArgumentParser()
parser.add_argument('--device', type=str, default='cuda')
parser.add_argument('--backward', action='store_true')
parser.add_argument('--dynamic', action='store_true')
args = parser.parse_args()

num_nodes, num_edges = 10_000, 200_000
x = torch.randn(num_nodes, 64, device=args.device)
edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device)
if args.dynamic:
min_num_nodes, max_num_nodes = 10_000, 15_000
min_num_edges, max_num_edges = 200_000, 300_000
else:
min_num_nodes, max_num_nodes = 10_000, 10_000
min_num_edges, max_num_edges = 200_000, 200_000

def gen_args():
N = random.randint(min_num_nodes, max_num_nodes)
E = random.randint(min_num_edges, max_num_edges)

x = torch.randn(N, 64, device=args.device)
edge_index = torch.randint(N, (2, E), device=args.device)

return x, edge_index

for Model in [GCN, GraphSAGE, GIN, EdgeCNN]:
print(f'Model: {Model.__name__}')
Expand All @@ -403,7 +417,7 @@ def test_basic_gnn_cache():
benchmark(
funcs=[model, compiled_model],
func_names=['Vanilla', 'Compiled'],
args=(x, edge_index),
args=gen_args,
num_steps=50 if args.device == 'cpu' else 500,
num_warmups=10 if args.device == 'cpu' else 100,
backward=args.backward,
Expand Down
6 changes: 6 additions & 0 deletions torch_geometric/compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def compile(model: Optional[Callable] = None, *args, **kwargs) -> Callable:
jittable instances
(see :meth:`torch_geometric.nn.conv.MessagePassing.jittable`)
3. disables generation of device asserts during fused gather/scatter calls
to avoid performance impacts
.. note::
Without these adjustments, :meth:`torch.compile` may currently fail to
correctly optimize your :pyg:`PyG` model.
Expand Down Expand Up @@ -89,6 +92,9 @@ def fn(model: Callable) -> Callable:
# Replace instances of `MessagePassing` by their jittable version:
model = to_jittable(model)

# Do not generate device asserts which may slow down model execution:
torch._inductor.config.triton.assert_indirect_indexing = False

# Finally, run `torch.compile` to create an optimized version:
out = torch.compile(model, *args, **kwargs)

Expand Down
7 changes: 5 additions & 2 deletions torch_geometric/profile/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def benchmark(
args ((Any, ) or [(Any, )]): The arguments to pass to the functions.
Can be a list of arguments for each function in :obj:`funcs` in
case their headers differ.
Alternatively, you can pass in functions that generate arguments
on-the-fly (e.g., useful for benchmarking models on various sizes).
num_steps (int): The number of steps to run the benchmark.
func_names ([str], optional): The names of the functions. If not given,
will try to infer the name from the function itself.
Expand Down Expand Up @@ -69,17 +71,18 @@ def benchmark(
f"'func_names' (got {len(func_names)}) must be equal")

# Zero-copy `args` for each function (if necessary):
args_list = [args] * len(funcs) if isinstance(args, tuple) else args
args_list = [args] * len(funcs) if not isinstance(args, list) else args

iterator = zip(funcs, args_list, func_names)
if progress_bar:
from tqdm import tqdm
iterator = tqdm(iterator, total=len(funcs))

ts: List[List[str]] = []
for func, args, name in iterator:
for func, inputs, name in iterator:
t_forward = t_backward = 0
for i in range(num_warmups + num_steps):
args = inputs() if callable(inputs) else inputs
args = require_grad(args, backward)

if torch.cuda.is_available():
Expand Down

0 comments on commit 71a7fee

Please sign in to comment.