Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fix
GraphMaskExplainer
for deep GNNs (#8401)
The current GraphMask explainer gave me an error for a model with > 2 layers. If you take the GCN Node Classification task in the example file, and modify the GNN https://github.com/pyg-team/pytorch_geometric/blob/cf24b4bcb4e825537ba08d8fc5f31073e2cd84c7/examples/explain/graphmask_explainer.py#L19-L29 to ```python class GCN(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_features, 16) self.conv2 = GCNConv(16, 16) self.conv3 = GCNConv(16, dataset.num_classes) def forward(self, x, edge_index): x = self.conv1(x, edge_index).relu() x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index).relu() x = F.dropout(x, training=self.training) x = self.conv3(x, edge_index) return F.log_softmax(x, dim=1) ``` And run the example, I get the following output: ``` Train explainer for node(s) tensor([5]) with layer 2: 100%|████████████████████████████████| 1/1 [00:01<00:00, 1.14s/it] Train explainer for node(s) tensor([5]) with layer 1: 100%|████████████████████████████████| 1/1 [00:01<00:00, 1.15s/it] Train explainer for node(s) tensor([5]) with layer 0: 100%|████████████████████████████████| 1/1 [00:01<00:00, 1.34s/it] Explain: 67%|███████████████████████████████████████████████████▎ | 2/3 [00:00<00:00, 12.54it/s]Traceback (most recent call last): File "/local/scratch/ga384/pyg/ex.py", line 101, in <module> explanation = explainer(data.x, data.edge_index, index=node_index) File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/explainer.py", line 204, in __call__ explanation = self.algorithm( File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl return forward_call(*input, **kwargs) File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/algorithm/graphmask_explainer.py", line 133, in forward edge_mask = self._explain(model, index=index) File "/local/scratch/ga384/miniconda3/envs/pytorch-geo/lib/python3.10/site-packages/torch_geometric/explain/algorithm/graphmask_explainer.py", line 526, in _explain sampling_weights = F.pad( RuntimeError: Padding length too large ``` I don't think the padding is necessary here: it looks like `edge_weight` is just accumulating the results per layer, given this line: https://github.com/pyg-team/pytorch_geometric/blob/cf24b4bcb4e825537ba08d8fc5f31073e2cd84c7/torch_geometric/explain/algorithm/graphmask_explainer.py#L541-L542 This PR therefore removes the padding. With that change, my GCN works. --------- Co-authored-by: rusty1s <[email protected]>
- Loading branch information