From abf438a8b0700137bbd725ae06c79beb05a6a077 Mon Sep 17 00:00:00 2001 From: Jintang Li Date: Sun, 22 Oct 2023 15:54:49 +0800 Subject: [PATCH 1/3] [nit] Remove unnecessary codes in `AddRandomWalkPE` (#8245) A follow-up of #8225. We will do `_coalesced_` in `to_edge_index`. --- torch_geometric/transforms/add_positional_encoding.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/torch_geometric/transforms/add_positional_encoding.py b/torch_geometric/transforms/add_positional_encoding.py index db646be5a79d..6739479840a4 100644 --- a/torch_geometric/transforms/add_positional_encoding.py +++ b/torch_geometric/transforms/add_positional_encoding.py @@ -147,8 +147,6 @@ def forward(self, data: Data) -> Data: pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)] for _ in range(self.walk_length - 1): out = out @ adj - if out.layout == torch.sparse_coo: - out = out._coalesced_(True) pe_list.append(get_self_loop_attr(*to_edge_index(out), N)) pe = torch.stack(pe_list, dim=-1) From 405ef2cd74212ee061a67cb454d99994f34a95c9 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sun, 22 Oct 2023 00:55:14 -0700 Subject: [PATCH 2/3] Fix `AttentionExplainer` for `AttentiveFP` (#8244) --- CHANGELOG.md | 1 + .../algorithm/test_attention_explainer.py | 23 ++++++++++++++++++- .../explain/algorithm/attention_explainer.py | 3 ++- torch_geometric/nn/models/attentive_fp.py | 17 +++++++++----- 4 files changed, 36 insertions(+), 8 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 30853e8434c6..afea1dfbef34 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `AttentionExplainer` usage within `AttentiveFP` ([#8244](https://github.com/pyg-team/pytorch_geometric/pull/8244)) - Fixed `load_from_state_dict` in lazy `Linear` modules ([#8242](https://github.com/pyg-team/pytorch_geometric/pull/8242)) - Fixed pre-trained `DimeNet++` performance on `QM9` ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239)) - Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216)) diff --git a/test/explain/algorithm/test_attention_explainer.py b/test/explain/algorithm/test_attention_explainer.py index f5ab16d18d73..e4479367d5b6 100644 --- a/test/explain/algorithm/test_attention_explainer.py +++ b/test/explain/algorithm/test_attention_explainer.py @@ -3,7 +3,7 @@ from torch_geometric.explain import AttentionExplainer, Explainer from torch_geometric.explain.config import ExplanationType, MaskType -from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv +from torch_geometric.nn import AttentiveFP, GATConv, GATv2Conv, TransformerConv class AttentionGNN(torch.nn.Module): @@ -25,6 +25,8 @@ def forward(self, x, edge_index): [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], ]) +edge_attr = torch.randn(edge_index.size(1), 5) +batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2]) @pytest.mark.parametrize('index', [None, 2, torch.arange(3)]) @@ -61,3 +63,22 @@ def test_attention_explainer_supports(explanation_type, node_mask_type): return_type='raw', ), ) + + +def test_attention_explainer_attentive_fp(check_explanation): + model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2) + + explainer = Explainer( + model=model, + algorithm=AttentionExplainer(), + explanation_type='model', + edge_mask_type='object', + model_config=dict( + mode='binary_classification', + task_level='node', + return_type='raw', + ), + ) + + explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch) + check_explanation(explanation, None, explainer.edge_mask_type) diff --git a/torch_geometric/explain/algorithm/attention_explainer.py b/torch_geometric/explain/algorithm/attention_explainer.py index 6db4905eaae2..69f2ebeb17c2 100644 --- a/torch_geometric/explain/algorithm/attention_explainer.py +++ b/torch_geometric/explain/algorithm/attention_explainer.py @@ -58,7 +58,8 @@ def hook(module, msg_kwargs, out): hook_handles = [] for module in model.modules(): # Register message forward hooks: - if isinstance(module, MessagePassing): + if (isinstance(module, MessagePassing) + and module.explain is not False): hook_handles.append(module.register_message_forward_hook(hook)) model(x, edge_index, **kwargs) diff --git a/torch_geometric/nn/models/attentive_fp.py b/torch_geometric/nn/models/attentive_fp.py index 7af1d19274ad..43a664fc3b68 100644 --- a/torch_geometric/nn/models/attentive_fp.py +++ b/torch_geometric/nn/models/attentive_fp.py @@ -41,15 +41,17 @@ def reset_parameters(self): zeros(self.bias) def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor: - # propagate_type: (x: Tensor, edge_attr: Tensor) - out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None) + # edge_updater_type: (x: Tensor, edge_attr: Tensor) + alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr) + + # propagate_type: (x: Tensor, alpha: Tensor) + out = self.propagate(edge_index, x=x, alpha=alpha, size=None) out = out + self.bias return out - def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor, - index: Tensor, ptr: OptTensor, - size_i: Optional[int]) -> Tensor: - + def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor, + index: Tensor, ptr: OptTensor, + size_i: Optional[int]) -> Tensor: x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1))) alpha_j = (x_j @ self.att_l.t()).squeeze(-1) alpha_i = (x_i @ self.att_r.t()).squeeze(-1) @@ -57,6 +59,9 @@ def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor, alpha = F.leaky_relu_(alpha) alpha = softmax(alpha, index, ptr, size_i) alpha = F.dropout(alpha, p=self.dropout, training=self.training) + return alpha + + def message(self, x_j: Tensor, alpha: Tensor) -> Tensor: return self.lin2(x_j) * alpha.unsqueeze(-1) From b17555e5b91e00259a904b902d1774184f1e09f2 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Sun, 22 Oct 2023 01:13:17 -0700 Subject: [PATCH 3/3] Allow `to_hetero_with_bases` to be applied on static graphs (#8247) --- CHANGELOG.md | 1 + test/nn/test_to_hetero_transformer.py | 6 ++---- .../test_to_hetero_with_bases_transformer.py | 21 +++++++++++++++++++ .../nn/to_hetero_with_bases_transformer.py | 18 ++++++++-------- 4 files changed, 33 insertions(+), 13 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index afea1dfbef34..22598fb75595 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added support for `to_hetero_with_bases` on static graphs ([#8247](https://github.com/pyg-team/pytorch_geometric/pull/8247)) - Added the `RCDD` dataset ([#8196](https://github.com/pyg-team/pytorch_geometric/pull/8196)) - Added distributed `GAT + ogbn-products` example targeting XPU device ([#8032](https://github.com/pyg-team/pytorch_geometric/pull/8032)) - Added the option to skip explanations of certain message passing layers via `conv.explain = False` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216)) diff --git a/test/nn/test_to_hetero_transformer.py b/test/nn/test_to_hetero_transformer.py index 6bf1c269b529..2ea69c840bd3 100644 --- a/test/nn/test_to_hetero_transformer.py +++ b/test/nn/test_to_hetero_transformer.py @@ -520,10 +520,8 @@ def test_to_hetero_on_static_graphs(): 'author': torch.randn(4, 100, 16), } edge_index_dict = { - ('paper', 'written_by', 'author'): - torch.randint(100, (2, 200), dtype=torch.long), - ('author', 'writes', 'paper'): - torch.randint(100, (2, 200), dtype=torch.long), + ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)), + ('author', 'writes', 'paper'): torch.randint(100, (2, 200)), } metadata = list(x_dict.keys()), list(edge_index_dict.keys()) diff --git a/test/nn/test_to_hetero_with_bases_transformer.py b/test/nn/test_to_hetero_with_bases_transformer.py index 6f40f0cb20a9..3736ad6c6e85 100644 --- a/test/nn/test_to_hetero_with_bases_transformer.py +++ b/test/nn/test_to_hetero_with_bases_transformer.py @@ -293,3 +293,24 @@ def test_to_hetero_with_bases_validate(): with pytest.warns(UserWarning, match="letters, numbers and underscores"): model = to_hetero_with_bases(model, metadata, num_bases=4, debug=False) + + +def test_to_hetero_with_bases_on_static_graphs(): + x_dict = { + 'paper': torch.randn(4, 100, 16), + 'author': torch.randn(4, 100, 16), + } + edge_index_dict = { + ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)), + ('author', 'writes', 'paper'): torch.randint(100, (2, 200)), + } + + metadata = list(x_dict.keys()), list(edge_index_dict.keys()) + model = to_hetero_with_bases(Net4(), metadata, num_bases=4, + in_channels={'x0': 16}, debug=False) + + out_dict = model(x_dict, edge_index_dict) + + assert len(out_dict) == 2 + assert out_dict['paper'].size() == (4, 100, 32) + assert out_dict['author'].size() == (4, 100, 32) diff --git a/torch_geometric/nn/to_hetero_with_bases_transformer.py b/torch_geometric/nn/to_hetero_with_bases_transformer.py index 01410d9675a1..cb787f364515 100644 --- a/torch_geometric/nn/to_hetero_with_bases_transformer.py +++ b/torch_geometric/nn/to_hetero_with_bases_transformer.py @@ -324,7 +324,7 @@ def __init__(self, module: MessagePassing, num_relations: int, # to a materialization of messages. def hook(module, inputs, output): assert isinstance(module._edge_type, Tensor) - if module._edge_type.size(0) != output.size(0): + if module._edge_type.size(0) != output.size(-2): raise ValueError( f"Number of messages ({output.size(0)}) does not match " f"with the number of original edges " @@ -332,7 +332,7 @@ def hook(module, inputs, output): f"passing layer create additional self-loops? Try to " f"remove them via 'add_self_loops=False'") weight = module.edge_type_weight.view(-1)[module._edge_type] - weight = weight.view([-1] + [1] * (output.dim() - 1)) + weight = weight.view([1] * (output.dim() - 2) + [-1, 1]) return weight * output params = list(module.parameters()) @@ -415,7 +415,7 @@ def get_node_offset_dict( out: Dict[NodeType, int] = {} for key in type2id.keys(): out[key] = cumsum - cumsum += input_dict[key].size(0) + cumsum += input_dict[key].size(-2) return out @@ -433,7 +433,7 @@ def get_edge_offset_dict( elif value.dtype == torch.long and value.size(0) == 2: cumsum += value.size(-1) else: - cumsum += value.size(0) + cumsum += value.size(-2) return out @@ -458,7 +458,7 @@ def get_edge_type( out = torch.full((value.nnz(), ), i, dtype=torch.long, device=value.device()) else: - out = value.new_full((value.size(0), ), i, dtype=torch.long) + out = value.new_full((value.size(-2), ), i, dtype=torch.long) outs.append(out) return outs[0] if len(outs) == 1 else torch.cat(outs, dim=0) @@ -474,7 +474,7 @@ def group_node_placeholder(input_dict: Dict[NodeType, Tensor], type2id: Dict[NodeType, int]) -> Tensor: inputs = [input_dict[key] for key in type2id.keys()] - return inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=0) + return inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=-2) def group_edge_placeholder( @@ -528,7 +528,7 @@ def group_edge_placeholder( return torch.stack([row, col], dim=0) else: - return torch.cat(inputs, dim=0) + return torch.cat(inputs, dim=-2) ############################################################################### @@ -542,9 +542,9 @@ def split_output( offset_dict: Union[Dict[NodeType, int], Dict[EdgeType, int]], ) -> Union[Dict[NodeType, Tensor], Dict[EdgeType, Tensor]]: - cumsums = list(offset_dict.values()) + [output.size(0)] + cumsums = list(offset_dict.values()) + [output.size(-2)] sizes = [cumsums[i + 1] - cumsums[i] for i in range(len(offset_dict))] - outputs = output.split(sizes) + outputs = output.split(sizes, dim=-2) return {key: output for key, output in zip(offset_dict, outputs)}