Skip to content

Commit

Permalink
Fix GraphMaskExplainer for deep GNNs (#8401)
Browse files Browse the repository at this point in the history
The current GraphMask explainer gave me an error for a model with > 2
layers. If you take the GCN Node Classification task in the example
file, and modify the GNN


https://github.com/pyg-team/pytorch_geometric/blob/cf24b4bcb4e825537ba08d8fc5f31073e2cd84c7/examples/explain/graphmask_explainer.py#L19-L29

to

```python
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(dataset.num_features, 16)
        self.conv2 = GCNConv(16, 16)
        self.conv3 = GCNConv(16, dataset.num_classes)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index).relu()
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index).relu()
        x = F.dropout(x, training=self.training)
        x = self.conv3(x, edge_index)
        return F.log_softmax(x, dim=1)
```

And run the example, I get the following output:

```
Train explainer for node(s) tensor([5]) with layer 2: 100%|████████████████████████████████| 1/1 [00:01<00:00,  1.14s/it]
Train explainer for node(s) tensor([5]) with layer 1: 100%|████████████████████████████████| 1/1 [00:01<00:00,  1.15s/it]
Train explainer for node(s) tensor([5]) with layer 0: 100%|████████████████████████████████| 1/1 [00:01<00:00,  1.34s/it]
Explain:  67%|███████████████████████████████████████████████████▎                         | 2/3 [00:00<00:00, 12.54it/s]Traceback (most recent call last):
  File "/local/scratch/ga384/pyg/ex.py", line 101, in <module>
    explanation = explainer(data.x, data.edge_index, index=node_index)
  File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/explainer.py", line 204, in __call__
    explanation = self.algorithm(
  File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/algorithm/graphmask_explainer.py", line 133, in forward
    edge_mask = self._explain(model, index=index)
  File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/algorithm/graphmask_explainer.py", line 526, in _explain
    sampling_weights = F.pad(
RuntimeError: Padding length too large
```

I don't think the padding is necessary here: it looks like `edge_weight`
is just accumulating the results per layer, given this line:


https://github.com/pyg-team/pytorch_geometric/blob/cf24b4bcb4e825537ba08d8fc5f31073e2cd84c7/torch_geometric/explain/algorithm/graphmask_explainer.py#L541-L542

This PR therefore removes the padding. With that change, my GCN works.

---------

Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
GuyAglionby and rusty1s authored Nov 19, 2023
1 parent 24a661f commit 5afd075
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 6 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `GraphMaskExplainer` for GNNs with more than two layers ([#8401](https://github.com/pyg-team/pytorch_geometric/pull/8401))
- Breaking Change: Properly initialize modules in `GATConv` depending on whether the input is bipartite or non-bipartite ([#8397](https://github.com/pyg-team/pytorch_geometric/pull/8397))
- Fixed `input_id` computation in `NeighborLoader` in case a `mask` is given ([#8312](https://github.com/pyg-team/pytorch_geometric/pull/8312))
- Respect current device when deep-copying `Linear` layers ([#8311](https://github.com/pyg-team/pytorch_geometric/pull/8311))
Expand Down
6 changes: 0 additions & 6 deletions torch_geometric/explain/algorithm/graphmask_explainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,12 +526,6 @@ def _explain(
if i == 0:
edge_weight = sampling_weights
else:
if edge_weight.size(-1) != sampling_weights.size(-1):
sampling_weights = F.pad(
input=sampling_weights,
pad=(0, edge_weight.size(-1) -
sampling_weights.size(-1), 0, 0),
mode='constant', value=0)
edge_weight = torch.cat((edge_weight, sampling_weights), 0)
if self.log:
pbar.update(1)
Expand Down

0 comments on commit 5afd075

Please sign in to comment.