Skip to content

Commit

Permalink
Fix Pre-trained DimeNet++ performance on QM9 (#8239)
Browse files Browse the repository at this point in the history
Fix #4698
```
Target: 00, MAE: 0.02975 ± 0.05869
Target: 01, MAE: 0.04322 ± 0.15740
Target: 02, MAE: 24.43286 ± 35.56775
Target: 03, MAE: 19.42164 ± 31.41084
Target: 05, MAE: 0.28941 ± 0.66366
Target: 06, MAE: 1.21997 ± 2.13234
Target: 07, MAE: 6.15220 ± 13.26018
Target: 08, MAE: 6.20371 ± 12.61268
Target: 09, MAE: 6.48553 ± 14.85413
Target: 10, MAE: 7.41481 ± 14.50145
Target: 11, MAE: 0.02268 ± 0.02604
```
  • Loading branch information
xnuohz authored Oct 21, 2023
1 parent 4553ca8 commit 9632694
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 5 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed Pre-trained `DimeNet++` performance on QM9 ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239))
- Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))
- Fixed `to_networkx(to_undirected=True)` in case the input graph is not undirected ([#8204](https://github.com/pyg-team/pytorch_geometric/pull/8204))
- Fixed sparse-sparse matrix multiplication support on Windows in `TwoHop` and `AddRandomWalkPE` transformations ([#8197](https://github.com/pyg-team/pytorch_geometric/pull/8197), [#8225](https://github.com/pyg-team/pytorch_geometric/pull/8225))
Expand Down
14 changes: 9 additions & 5 deletions torch_geometric/nn/models/dimenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ def __init__(
])
self.lin = Linear(hidden_channels, hidden_channels)
self.layers_after_skip = torch.nn.ModuleList([
ResidualLayer(hidden_channels, act) for _ in range(num_before_skip)
ResidualLayer(hidden_channels, act) for _ in range(num_after_skip)
])

self.reset_parameters()
Expand Down Expand Up @@ -695,10 +695,14 @@ def forward(
dist = (pos[i] - pos[j]).pow(2).sum(dim=-1).sqrt()

# Calculate angles.
pos_i = pos[idx_i]
pos_ji, pos_ki = pos[idx_j] - pos_i, pos[idx_k] - pos_i
a = (pos_ji * pos_ki).sum(dim=-1)
b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
if isinstance(self, DimeNetPlusPlus):
pos_jk, pos_ij = pos[idx_j] - pos[idx_k], pos[idx_i] - pos[idx_j]
a = (pos_ij * pos_jk).sum(dim=-1)
b = torch.cross(pos_ij, pos_jk).norm(dim=-1)
elif isinstance(self, DimeNet):
pos_ji, pos_ki = pos[idx_j] - pos[idx_i], pos[idx_k] - pos[idx_i]
a = (pos_ji * pos_ki).sum(dim=-1)
b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
angle = torch.atan2(b, a)

rbf = self.rbf(dist)
Expand Down

0 comments on commit 9632694

Please sign in to comment.