diff --git a/CHANGELOG.md b/CHANGELOG.md index b0487d98cd84..5e01393d1cab 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Disable device asserts during `torch_geometric.compile` ([#8220](https://github.com/pyg-team/pytorch_geometric/pull/8220)) + ### Deprecated ### Fixed diff --git a/test/nn/models/test_basic_gnn.py b/test/nn/models/test_basic_gnn.py index bc2c8859e940..3262f5ed261d 100644 --- a/test/nn/models/test_basic_gnn.py +++ b/test/nn/models/test_basic_gnn.py @@ -1,5 +1,6 @@ import os import os.path as osp +import random import sys import warnings @@ -388,11 +389,24 @@ def test_basic_gnn_cache(): parser = argparse.ArgumentParser() parser.add_argument('--device', type=str, default='cuda') parser.add_argument('--backward', action='store_true') + parser.add_argument('--dynamic', action='store_true') args = parser.parse_args() - num_nodes, num_edges = 10_000, 200_000 - x = torch.randn(num_nodes, 64, device=args.device) - edge_index = torch.randint(num_nodes, (2, num_edges), device=args.device) + if args.dynamic: + min_num_nodes, max_num_nodes = 10_000, 15_000 + min_num_edges, max_num_edges = 200_000, 300_000 + else: + min_num_nodes, max_num_nodes = 10_000, 10_000 + min_num_edges, max_num_edges = 200_000, 200_000 + + def gen_args(): + N = random.randint(min_num_nodes, max_num_nodes) + E = random.randint(min_num_edges, max_num_edges) + + x = torch.randn(N, 64, device=args.device) + edge_index = torch.randint(N, (2, E), device=args.device) + + return x, edge_index for Model in [GCN, GraphSAGE, GIN, EdgeCNN]: print(f'Model: {Model.__name__}') @@ -403,7 +417,7 @@ def test_basic_gnn_cache(): benchmark( funcs=[model, compiled_model], func_names=['Vanilla', 'Compiled'], - args=(x, edge_index), + args=gen_args, num_steps=50 if args.device == 'cpu' else 500, num_warmups=10 if args.device == 'cpu' else 100, backward=args.backward, diff --git a/torch_geometric/compile.py b/torch_geometric/compile.py index 6fcfc1966a79..3aad89993aa5 100644 --- a/torch_geometric/compile.py +++ b/torch_geometric/compile.py @@ -51,6 +51,9 @@ def compile(model: Optional[Callable] = None, *args, **kwargs) -> Callable: jittable instances (see :meth:`torch_geometric.nn.conv.MessagePassing.jittable`) + 3. disables generation of device asserts during fused gather/scatter calls + to avoid performance impacts + .. note:: Without these adjustments, :meth:`torch.compile` may currently fail to correctly optimize your :pyg:`PyG` model. @@ -89,6 +92,9 @@ def fn(model: Callable) -> Callable: # Replace instances of `MessagePassing` by their jittable version: model = to_jittable(model) + # Do not generate device asserts which may slow down model execution: + torch._inductor.config.triton.assert_indirect_indexing = False + # Finally, run `torch.compile` to create an optimized version: out = torch.compile(model, *args, **kwargs) diff --git a/torch_geometric/profile/benchmark.py b/torch_geometric/profile/benchmark.py index d1dfe09b904c..1c5e7a229881 100644 --- a/torch_geometric/profile/benchmark.py +++ b/torch_geometric/profile/benchmark.py @@ -38,6 +38,8 @@ def benchmark( args ((Any, ) or [(Any, )]): The arguments to pass to the functions. Can be a list of arguments for each function in :obj:`funcs` in case their headers differ. + Alternatively, you can pass in functions that generate arguments + on-the-fly (e.g., useful for benchmarking models on various sizes). num_steps (int): The number of steps to run the benchmark. func_names ([str], optional): The names of the functions. If not given, will try to infer the name from the function itself. @@ -69,7 +71,7 @@ def benchmark( f"'func_names' (got {len(func_names)}) must be equal") # Zero-copy `args` for each function (if necessary): - args_list = [args] * len(funcs) if isinstance(args, tuple) else args + args_list = [args] * len(funcs) if not isinstance(args, list) else args iterator = zip(funcs, args_list, func_names) if progress_bar: @@ -77,9 +79,10 @@ def benchmark( iterator = tqdm(iterator, total=len(funcs)) ts: List[List[str]] = [] - for func, args, name in iterator: + for func, inputs, name in iterator: t_forward = t_backward = 0 for i in range(num_warmups + num_steps): + args = inputs() if callable(inputs) else inputs args = require_grad(args, backward) if torch.cuda.is_available():