From 4d280f5635d2b021eea415c043f479f8c78bb042 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Tue, 10 Oct 2023 00:56:51 +0200 Subject: [PATCH] Fix `HeteroConv` for layers that have a non-default argument order (#8166) Fixes #8150 --- CHANGELOG.md | 1 + .../algorithm/test_explain_algorithm_utils.py | 10 +-- test/nn/conv/test_hetero_conv.py | 79 ++++++++++++++----- torch_geometric/explain/algorithm/utils.py | 22 +++--- torch_geometric/nn/conv/hetero_conv.py | 42 +++++----- torch_geometric/nn/conv/utils/cheatsheet.py | 2 + 6 files changed, 102 insertions(+), 54 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 335fcc2bacfb..f85cca8c6eb8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -118,6 +118,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Fixed `HeteroConv` for layers that have a non-default argument order, *e.g.*, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166)) - Handle reserved keywords as keys in `ModuleDict` and `ParameterDict` ([#8163](https://github.com/pyg-team/pytorch_geometric/pull/8163)) - Updated the examples and tutorials to account for `torch.compile(dynamic=True)` in PyTorch 2.1.0 ([#8145](https://github.com/pyg-team/pytorch_geometric/pull/8145)) - Enabled dense eigenvalue computation in `AddLaplacianEigenvectorPE` for small-scale graphs ([#8143](https://github.com/pyg-team/pytorch_geometric/pull/8143)) diff --git a/test/explain/algorithm/test_explain_algorithm_utils.py b/test/explain/algorithm/test_explain_algorithm_utils.py index 71a376889234..4502eb788329 100644 --- a/test/explain/algorithm/test_explain_algorithm_utils.py +++ b/test/explain/algorithm/test_explain_algorithm_utils.py @@ -53,16 +53,14 @@ def test_set_clear_mask(hetero_data): set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict) for edge_type in hetero_data.edge_types: # Check that masks are correctly set: - str_edge_type = '__'.join(edge_type) - assert torch.allclose(model.conv1.convs[str_edge_type]._edge_mask, + assert torch.allclose(model.conv1.convs[edge_type]._edge_mask, edge_mask_dict[edge_type]) - assert model.conv1.convs[str_edge_type].explain + assert model.conv1.convs[edge_type].explain clear_masks(model) for edge_type in hetero_data.edge_types: - str_edge_type = '__'.join(edge_type) - assert model.conv1.convs[str_edge_type]._edge_mask is None - assert not model.conv1.convs[str_edge_type].explain + assert model.conv1.convs[edge_type]._edge_mask is None + assert not model.conv1.convs[edge_type].explain model = to_hetero(GraphSAGE(), hetero_data.metadata(), debug=False) diff --git a/test/nn/conv/test_hetero_conv.py b/test/nn/conv/test_hetero_conv.py index ad38a767ee04..7eb43f48507d 100644 --- a/test/nn/conv/test_hetero_conv.py +++ b/test/nn/conv/test_hetero_conv.py @@ -4,6 +4,7 @@ from torch_geometric.data import HeteroData from torch_geometric.nn import ( GATConv, + GCN2Conv, GCNConv, HeteroConv, Linear, @@ -35,24 +36,55 @@ def test_hetero_conv(aggr): SAGEConv((-1, -1), 64), ('paper', 'to', 'author'): GATConv((-1, -1), 64, edge_dim=3, add_self_loops=False), - }, aggr=aggr) + }, + aggr=aggr, + ) assert len(list(conv.parameters())) > 0 assert str(conv) == 'HeteroConv(num_relations=3)' - out = conv(data.x_dict, data.edge_index_dict, data.edge_attr_dict, - edge_weight_dict=data.edge_weight_dict) + out_dict = conv( + data.x_dict, + data.edge_index_dict, + data.edge_attr_dict, + edge_weight_dict=data.edge_weight_dict, + ) - assert len(out) == 2 + assert len(out_dict) == 2 if aggr == 'cat': - assert out['paper'].size() == (50, 128) - assert out['author'].size() == (30, 64) + assert out_dict['paper'].size() == (50, 128) + assert out_dict['author'].size() == (30, 64) elif aggr is not None: - assert out['paper'].size() == (50, 64) - assert out['author'].size() == (30, 64) + assert out_dict['paper'].size() == (50, 64) + assert out_dict['author'].size() == (30, 64) else: - assert out['paper'].size() == (50, 2, 64) - assert out['author'].size() == (30, 1, 64) + assert out_dict['paper'].size() == (50, 2, 64) + assert out_dict['author'].size() == (30, 1, 64) + + +def test_gcn2_hetero_conv(): + data = HeteroData() + data['paper'].x = torch.randn(50, 32) + data['author'].x = torch.randn(30, 64) + data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) + data['author', 'author'].edge_index = get_random_edge_index(30, 30, 100) + data['paper', 'paper'].edge_weight = torch.rand(200) + + conv = HeteroConv({ + ('paper', 'to', 'paper'): GCN2Conv(32, alpha=0.1), + ('author', 'to', 'author'): GCN2Conv(64, alpha=0.2), + }) + + out_dict = conv( + data.x_dict, + data.x_dict, + data.edge_index_dict, + edge_weight_dict=data.edge_weight_dict, + ) + + assert len(out_dict) == 2 + assert out_dict['paper'].size() == (50, 32) + assert out_dict['author'].size() == (30, 64) class CustomConv(MessagePassing): @@ -81,11 +113,15 @@ def test_hetero_conv_with_custom_conv(): conv = HeteroConv({key: CustomConv(64) for key in data.edge_types}) # Test node `args_dict` and `kwargs_dict` with `y_dict` and `z_dict`: - out = conv(data.x_dict, data.edge_index_dict, data.y_dict, - z_dict=data.z_dict) - assert len(out) == 2 - assert out['paper'].size() == (50, 64) - assert out['author'].size() == (30, 64) + out_dict = conv( + data.x_dict, + data.edge_index_dict, + data.y_dict, + z_dict=data.z_dict, + ) + assert len(out_dict) == 2 + assert out_dict['paper'].size() == (50, 64) + assert out_dict['author'].size() == (30, 64) class MessagePassingLoops(MessagePassing): @@ -122,9 +158,12 @@ def test_hetero_conv_with_dot_syntax_node_types(): assert len(list(conv.parameters())) > 0 assert str(conv) == 'HeteroConv(num_relations=3)' - out = conv(data.x_dict, data.edge_index_dict, - edge_weight_dict=data.edge_weight_dict) + out_dict = conv( + data.x_dict, + data.edge_index_dict, + edge_weight_dict=data.edge_weight_dict, + ) - assert len(out) == 2 - assert out['src.paper'].size() == (50, 64) - assert out['author'].size() == (30, 64) + assert len(out_dict) == 2 + assert out_dict['src.paper'].size() == (50, 64) + assert out_dict['author'].size() == (30, 64) diff --git a/torch_geometric/explain/algorithm/utils.py b/torch_geometric/explain/algorithm/utils.py index 0d181f35f110..fc2f1502a021 100644 --- a/torch_geometric/explain/algorithm/utils.py +++ b/torch_geometric/explain/algorithm/utils.py @@ -45,15 +45,19 @@ def set_hetero_masks( for module in model.modules(): if isinstance(module, torch.nn.ModuleDict): for edge_type in mask_dict.keys(): - # TODO (jinu) Use common function get `str_edge_type`. - str_edge_type = '__'.join(edge_type) - if str_edge_type in module: - set_masks( - module[str_edge_type], - mask_dict[edge_type], - edge_index_dict[edge_type], - apply_sigmoid=apply_sigmoid, - ) + if edge_type in module: + edge_level_module = module[edge_type] + elif '__'.join(edge_type) in module: + edge_level_module = module['__'.join(edge_type)] + else: + continue + + set_masks( + edge_level_module, + mask_dict[edge_type], + edge_index_dict[edge_type], + apply_sigmoid=apply_sigmoid, + ) def clear_masks(model: torch.nn.Module): diff --git a/torch_geometric/nn/conv/hetero_conv.py b/torch_geometric/nn/conv/hetero_conv.py index e1d4d2a9d26e..82cbe14d1965 100644 --- a/torch_geometric/nn/conv/hetero_conv.py +++ b/torch_geometric/nn/conv/hetero_conv.py @@ -7,7 +7,7 @@ from torch_geometric.nn.conv import MessagePassing from torch_geometric.nn.module_dict import ModuleDict -from torch_geometric.typing import Adj, EdgeType, NodeType +from torch_geometric.typing import EdgeType, NodeType from torch_geometric.utils.hetero import check_add_self_loops @@ -80,7 +80,7 @@ def __init__( f"passing as they do not occur as destination type in any " f"edge type. This may lead to unexpected behavior.") - self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()}) + self.convs = ModuleDict(convs) self.aggr = aggr def reset_parameters(self): @@ -90,8 +90,6 @@ def reset_parameters(self): def forward( self, - x_dict: Dict[NodeType, Tensor], - edge_index_dict: Dict[EdgeType, Adj], *args_dict, **kwargs_dict, ) -> Dict[NodeType, Tensor]: @@ -117,42 +115,48 @@ def forward( :obj:`edge_attr_dict = { edge_type: edge_attr }`. """ out_dict = defaultdict(list) - for edge_type, edge_index in edge_index_dict.items(): + + for edge_type in self.convs.keys(): src, rel, dst = edge_type - str_edge_type = '__'.join(edge_type) - if str_edge_type not in self.convs: - continue + has_edge_level_arg = False args = [] for value_dict in args_dict: if edge_type in value_dict: + has_edge_level_arg = True args.append(value_dict[edge_type]) elif src == dst and src in value_dict: args.append(value_dict[src]) elif src in value_dict or dst in value_dict: - args.append( - (value_dict.get(src, None), value_dict.get(dst, None))) + args.append(( + value_dict.get(src, None), + value_dict.get(dst, None), + )) kwargs = {} for arg, value_dict in kwargs_dict.items(): + if not arg.endswith('_dict'): + raise ValueError( + f"Keyword arguments in '{self.__class__.__name__}' " + f"need to end with '_dict' (got '{arg}')") + arg = arg[:-5] # `{*}_dict` if edge_type in value_dict: + has_edge_level_arg = True kwargs[arg] = value_dict[edge_type] elif src == dst and src in value_dict: kwargs[arg] = value_dict[src] elif src in value_dict or dst in value_dict: - kwargs[arg] = (value_dict.get(src, None), - value_dict.get(dst, None)) + kwargs[arg] = ( + value_dict.get(src, None), + value_dict.get(dst, None), + ) - conv = self.convs[str_edge_type] - - if src == dst: - out = conv(x_dict[src], edge_index, *args, **kwargs) - else: - out = conv((x_dict[src], x_dict[dst]), edge_index, *args, - **kwargs) + if not has_edge_level_arg: + continue + out = self.convs[edge_type](*args, **kwargs) out_dict[dst].append(out) for key, value in out_dict.items(): diff --git a/torch_geometric/nn/conv/utils/cheatsheet.py b/torch_geometric/nn/conv/utils/cheatsheet.py index 4586a32d6067..9ba9eea72a1b 100644 --- a/torch_geometric/nn/conv/utils/cheatsheet.py +++ b/torch_geometric/nn/conv/utils/cheatsheet.py @@ -53,6 +53,8 @@ def supports_lazy_initialization(cls: str) -> bool: def processes_heterogeneous_graphs(cls: str) -> bool: + if 'hetero' in cls.lower(): + return True cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls] signature = inspect.signature(cls.forward) return 'edge_index_dict' in str(signature) or 'edge_type' in str(signature)