diff --git a/CHANGELOG.md b/CHANGELOG.md
index 30853e8434c6..afea1dfbef34 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed `AttentionExplainer` usage within `AttentiveFP` ([#8244](https://github.com/pyg-team/pytorch_geometric/pull/8244))
 - Fixed `load_from_state_dict` in lazy `Linear` modules ([#8242](https://github.com/pyg-team/pytorch_geometric/pull/8242))
 - Fixed pre-trained `DimeNet++` performance on `QM9` ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239))
 - Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))
diff --git a/test/explain/algorithm/test_attention_explainer.py b/test/explain/algorithm/test_attention_explainer.py
index f5ab16d18d73..e4479367d5b6 100644
--- a/test/explain/algorithm/test_attention_explainer.py
+++ b/test/explain/algorithm/test_attention_explainer.py
@@ -3,7 +3,7 @@
 
 from torch_geometric.explain import AttentionExplainer, Explainer
 from torch_geometric.explain.config import ExplanationType, MaskType
-from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv
+from torch_geometric.nn import AttentiveFP, GATConv, GATv2Conv, TransformerConv
 
 
 class AttentionGNN(torch.nn.Module):
@@ -25,6 +25,8 @@ def forward(self, x, edge_index):
     [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
     [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
 ])
+edge_attr = torch.randn(edge_index.size(1), 5)
+batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])
 
 
 @pytest.mark.parametrize('index', [None, 2, torch.arange(3)])
@@ -61,3 +63,22 @@ def test_attention_explainer_supports(explanation_type, node_mask_type):
             return_type='raw',
         ),
     )
+
+
+def test_attention_explainer_attentive_fp(check_explanation):
+    model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2)
+
+    explainer = Explainer(
+        model=model,
+        algorithm=AttentionExplainer(),
+        explanation_type='model',
+        edge_mask_type='object',
+        model_config=dict(
+            mode='binary_classification',
+            task_level='node',
+            return_type='raw',
+        ),
+    )
+
+    explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch)
+    check_explanation(explanation, None, explainer.edge_mask_type)
diff --git a/torch_geometric/explain/algorithm/attention_explainer.py b/torch_geometric/explain/algorithm/attention_explainer.py
index 6db4905eaae2..69f2ebeb17c2 100644
--- a/torch_geometric/explain/algorithm/attention_explainer.py
+++ b/torch_geometric/explain/algorithm/attention_explainer.py
@@ -58,7 +58,8 @@ def hook(module, msg_kwargs, out):
 
         hook_handles = []
         for module in model.modules():  # Register message forward hooks:
-            if isinstance(module, MessagePassing):
+            if (isinstance(module, MessagePassing)
+                    and module.explain is not False):
                 hook_handles.append(module.register_message_forward_hook(hook))
 
         model(x, edge_index, **kwargs)
diff --git a/torch_geometric/nn/models/attentive_fp.py b/torch_geometric/nn/models/attentive_fp.py
index 7af1d19274ad..43a664fc3b68 100644
--- a/torch_geometric/nn/models/attentive_fp.py
+++ b/torch_geometric/nn/models/attentive_fp.py
@@ -41,15 +41,17 @@ def reset_parameters(self):
         zeros(self.bias)
 
     def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:
-        # propagate_type: (x: Tensor, edge_attr: Tensor)
-        out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
+        # edge_updater_type: (x: Tensor, edge_attr: Tensor)
+        alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)
+
+        # propagate_type: (x: Tensor, alpha: Tensor)
+        out = self.propagate(edge_index, x=x, alpha=alpha, size=None)
         out = out + self.bias
         return out
 
-    def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
-                index: Tensor, ptr: OptTensor,
-                size_i: Optional[int]) -> Tensor:
-
+    def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
+                    index: Tensor, ptr: OptTensor,
+                    size_i: Optional[int]) -> Tensor:
         x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1)))
         alpha_j = (x_j @ self.att_l.t()).squeeze(-1)
         alpha_i = (x_i @ self.att_r.t()).squeeze(-1)
@@ -57,6 +59,9 @@ def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
         alpha = F.leaky_relu_(alpha)
         alpha = softmax(alpha, index, ptr, size_i)
         alpha = F.dropout(alpha, p=self.dropout, training=self.training)
+        return alpha
+
+    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
         return self.lin2(x_j) * alpha.unsqueeze(-1)
 
 
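
For context, here is a minimal standalone sketch (not part of the PR) of the code path the new test exercises. It mirrors the test above; the node feature matrix `x = torch.randn(8, 3)` is inferred from the 8-node `edge_index` and the `AttentiveFP(3, ...)` constructor rather than taken from the diff, and `explanation.edge_mask` is the per-edge attribution that `Explainer` populates when `edge_mask_type='object'`.

import torch

from torch_geometric.explain import AttentionExplainer, Explainer
from torch_geometric.nn import AttentiveFP

# Toy graph mirroring the new test: 8 nodes with 3 features, 14 directed
# edges with 5-dimensional edge features, grouped into 3 molecules.
x = torch.randn(8, 3)
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
    [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
])
edge_attr = torch.randn(edge_index.size(1), 5)
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])

model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2)

explainer = Explainer(
    model=model,
    algorithm=AttentionExplainer(),
    explanation_type='model',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='raw',
    ),
)

# `edge_attr` and `batch` are forwarded to `AttentiveFP.forward()`; the
# explainer hooks the model's attention layers and aggregates their
# `alpha` coefficients into a per-edge mask.
explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch)
print(explanation.edge_mask)  # One aggregated attention score per edge.

The `module.explain is not False` guard matters here because `AttentiveFP` marks its molecule-level pooling convolution as non-explainable, so its attention coefficients (which do not live on the input `edge_index`) are now skipped when collecting attention scores.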