
Fix AttentionExplainer for AttentiveFP (#8244)
rusty1s authored Oct 22, 2023
1 parent abf438a commit 405ef2c
Showing 4 changed files with 36 additions and 8 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `AttentionExplainer` usage within `AttentiveFP` ([#8244](https://github.com/pyg-team/pytorch_geometric/pull/8244))
- Fixed `load_from_state_dict` in lazy `Linear` modules ([#8242](https://github.com/pyg-team/pytorch_geometric/pull/8242))
- Fixed pre-trained `DimeNet++` performance on `QM9` ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239))
- Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))
23 changes: 22 additions & 1 deletion test/explain/algorithm/test_attention_explainer.py
@@ -3,7 +3,7 @@

from torch_geometric.explain import AttentionExplainer, Explainer
from torch_geometric.explain.config import ExplanationType, MaskType
from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv
from torch_geometric.nn import AttentiveFP, GATConv, GATv2Conv, TransformerConv


class AttentionGNN(torch.nn.Module):
@@ -25,6 +25,8 @@ def forward(self, x, edge_index):
[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
])
edge_attr = torch.randn(edge_index.size(1), 5)
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])


@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])
@@ -61,3 +63,22 @@ def test_attention_explainer_supports(explanation_type, node_mask_type):
return_type='raw',
),
)


def test_attention_explainer_attentive_fp(check_explanation):
model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2)

explainer = Explainer(
model=model,
algorithm=AttentionExplainer(),
explanation_type='model',
edge_mask_type='object',
model_config=dict(
mode='binary_classification',
task_level='node',
return_type='raw',
),
)

explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch)
check_explanation(explanation, None, explainer.edge_mask_type)
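
A minimal usage sketch of what the new test exercises, including inspection of the returned explanation (a sketch, not part of the commit; the toy inputs mirror the test above and the AttentiveFP hyperparameters are illustrative):

import torch

from torch_geometric.explain import AttentionExplainer, Explainer
from torch_geometric.nn import AttentiveFP

# Toy graph: 8 nodes with 3 features, 14 directed edges with 5 edge features,
# split into 3 molecules via `batch`.
x = torch.randn(8, 3)
edge_index = torch.tensor([
    [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
    [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
])
edge_attr = torch.randn(edge_index.size(1), 5)
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])

model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2)

explainer = Explainer(
    model=model,
    algorithm=AttentionExplainer(),
    explanation_type='model',
    edge_mask_type='object',
    model_config=dict(
        mode='binary_classification',
        task_level='node',
        return_type='raw',
    ),
)

explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch)
# One attention-derived score per edge of the input graph:
print(explanation.edge_mask.shape)  # expected: torch.Size([14])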
3 changes: 2 additions & 1 deletion torch_geometric/explain/algorithm/attention_explainer.py
@@ -58,7 +58,8 @@ def hook(module, msg_kwargs, out):

hook_handles = []
for module in model.modules(): # Register message forward hooks:
if isinstance(module, MessagePassing):
if (isinstance(module, MessagePassing)
and module.explain is not False):
hook_handles.append(module.register_message_forward_hook(hook))

model(x, edge_index, **kwargs)
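
The `module.explain is not False` guard added above lets a model opt individual MessagePassing sub-modules out of attention collection by flagging them with `explain = False`, which is presumably how AttentiveFP excludes the convolution that runs on its molecule-level graph (cf. #8216). A hedged sketch of that pattern with a hypothetical model (layer names and sizes are made up; the default value of `explain` is assumed to be unset rather than False):

import torch

from torch_geometric.nn import GATConv, MessagePassing


class TwoLevelGNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(3, 16)
        self.conv2 = GATConv(16, 16)
        # Hypothetical layer operating on an auxiliary graph whose attention
        # coefficients are not aligned with the input `edge_index`, so it is
        # opted out of explanation hooks:
        self.aux_conv = GATConv(16, 16)
        self.aux_conv.explain = False


model = TwoLevelGNN()
hooked = [
    module for module in model.modules()
    if isinstance(module, MessagePassing) and module.explain is not False
]
print(hooked)  # conv1 and conv2, but not aux_conv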
17 changes: 11 additions & 6 deletions torch_geometric/nn/models/attentive_fp.py
@@ -41,22 +41,27 @@ def reset_parameters(self):
zeros(self.bias)

def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:
# propagate_type: (x: Tensor, edge_attr: Tensor)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
# edge_updater_type: (x: Tensor, edge_attr: Tensor)
alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)

# propagate_type: (x: Tensor, alpha: Tensor)
out = self.propagate(edge_index, x=x, alpha=alpha, size=None)
out = out + self.bias
return out

def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tensor:

def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tensor:
x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1)))
alpha_j = (x_j @ self.att_l.t()).squeeze(-1)
alpha_i = (x_i @ self.att_r.t()).squeeze(-1)
alpha = alpha_j + alpha_i
alpha = F.leaky_relu_(alpha)
alpha = softmax(alpha, index, ptr, size_i)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
return alpha

def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
return self.lin2(x_j) * alpha.unsqueeze(-1)


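
The GATEConv change above splits attention computation into two phases: edge_update() (invoked via edge_updater()) produces the per-edge coefficients, and message() merely consumes the precomputed alpha, so the coefficients reach the message call as an explicit keyword argument that AttentionExplainer's message forward hooks can observe. A self-contained sketch of the same two-phase pattern (an illustrative toy layer, not the library's GATEConv):

from typing import Optional

import torch
import torch.nn.functional as F
from torch import Tensor

from torch_geometric.nn import MessagePassing
from torch_geometric.typing import OptTensor
from torch_geometric.utils import softmax


class ToyAttentionConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='add')
        self.att = torch.nn.Parameter(torch.empty(1, in_channels))
        self.lin = torch.nn.Linear(in_channels, out_channels)
        torch.nn.init.xavier_uniform_(self.att)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # Phase 1: compute per-edge attention coefficients.
        alpha = self.edge_updater(edge_index, x=x)
        # Phase 2: aggregate neighbor messages weighted by those coefficients.
        return self.propagate(edge_index, x=x, alpha=alpha)

    def edge_update(self, x_j: Tensor, x_i: Tensor, index: Tensor,
                    ptr: OptTensor, size_i: Optional[int]) -> Tensor:
        alpha = ((x_j + x_i) * self.att).sum(dim=-1)
        return softmax(F.leaky_relu(alpha), index, ptr, size_i)

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        return self.lin(x_j) * alpha.unsqueeze(-1)


x = torch.randn(4, 3)
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])
out = ToyAttentionConv(3, 8)(x, edge_index)
print(out.shape)  # torch.Size([4, 8])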
