diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ecc97adb613..8ac45e0b0de5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed +- Fixed `GraphMaskExplainer` for GNNs with more than two layers ([#8401](https://github.com/pyg-team/pytorch_geometric/pull/8401)) - Breaking Change: Properly initialize modules in `GATConv` depending on whether the input is bipartite or non-bipartite ([#8397](https://github.com/pyg-team/pytorch_geometric/pull/8397)) - Fixed `input_id` computation in `NeighborLoader` in case a `mask` is given ([#8312](https://github.com/pyg-team/pytorch_geometric/pull/8312)) - Respect current device when deep-copying `Linear` layers ([#8311](https://github.com/pyg-team/pytorch_geometric/pull/8311)) diff --git a/torch_geometric/explain/algorithm/graphmask_explainer.py b/torch_geometric/explain/algorithm/graphmask_explainer.py index 912026ce6c56..089fe7659a6e 100644 --- a/torch_geometric/explain/algorithm/graphmask_explainer.py +++ b/torch_geometric/explain/algorithm/graphmask_explainer.py @@ -526,12 +526,6 @@ def _explain( if i == 0: edge_weight = sampling_weights else: - if edge_weight.size(-1) != sampling_weights.size(-1): - sampling_weights = F.pad( - input=sampling_weights, - pad=(0, edge_weight.size(-1) - - sampling_weights.size(-1), 0, 0), - mode='constant', value=0) edge_weight = torch.cat((edge_weight, sampling_weights), 0) if self.log: pbar.update(1)