Skip to content

Commit

Permalink
Add dim argument to torch.cross calls (#8918)
Browse files Browse the repository at this point in the history
### What does this PR do ?

Add the `dim` argument to the `torch.cross` tensor function to suppress
user warnings.

> [!NOTE]
> Using `torch.cross` without specifying the dim arg is deprecated.
Please either pass the dim explicitly or simply use
`torch.linalg.cross`.

### What are the modifications ?

I simply added the dimension where the size of the tensor is 3 (by
default), which is at `dim=1`. In some cases the argument was already
there, I added it when appropriate.

### Versions

- torch: 2.2.0

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
3 people authored Feb 16, 2024
1 parent cc04932 commit 3ddd11d
Show file tree
Hide file tree
Showing 5 changed files with 15 additions and 6 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Added the `dim` arg to `torch.cross` calls ([#8918](https://github.com/pyg-team/pytorch_geometric/pull/8918))

### Deprecated

### Fixed
Expand Down
4 changes: 2 additions & 2 deletions torch_geometric/nn/models/dimenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -699,11 +699,11 @@ def forward(
if isinstance(self, DimeNetPlusPlus):
pos_jk, pos_ij = pos[idx_j] - pos[idx_k], pos[idx_i] - pos[idx_j]
a = (pos_ij * pos_jk).sum(dim=-1)
b = torch.cross(pos_ij, pos_jk).norm(dim=-1)
b = torch.cross(pos_ij, pos_jk, dim=1).norm(dim=-1)
elif isinstance(self, DimeNet):
pos_ji, pos_ki = pos[idx_j] - pos[idx_i], pos[idx_k] - pos[idx_i]
a = (pos_ji * pos_ki).sum(dim=-1)
b = torch.cross(pos_ji, pos_ki).norm(dim=-1)
b = torch.cross(pos_ji, pos_ki, dim=1).norm(dim=-1)
angle = torch.atan2(b, a)

rbf = self.rbf(dist)
Expand Down
2 changes: 1 addition & 1 deletion torch_geometric/transforms/generate_mesh_normals.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def forward(self, data: Data) -> Data:

vec1 = pos[face[1]] - pos[face[0]]
vec2 = pos[face[2]] - pos[face[0]]
face_norm = F.normalize(vec1.cross(vec2), p=2, dim=-1) # [F, 3]
face_norm = F.normalize(vec1.cross(vec2, dim=1), p=2, dim=-1) # [F, 3]

face_norm = face_norm.repeat(3, 1)
idx = face.view(-1)
Expand Down
8 changes: 6 additions & 2 deletions torch_geometric/transforms/sample_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ def forward(self, data: Data) -> Data:
pos_max = pos.abs().max()
pos = pos / pos_max

area = (pos[face[1]] - pos[face[0]]).cross(pos[face[2]] - pos[face[0]])
area = (pos[face[1]] - pos[face[0]]).cross(
pos[face[2]] - pos[face[0]],
dim=1,
)
area = area.norm(p=2, dim=1).abs() / 2

prob = area / area.sum()
Expand All @@ -52,7 +55,8 @@ def forward(self, data: Data) -> Data:
vec2 = pos[face[2]] - pos[face[0]]

if self.include_normals:
data.normal = torch.nn.functional.normalize(vec1.cross(vec2), p=2)
data.normal = torch.nn.functional.normalize(
vec1.cross(vec2, dim=1), p=2)

pos_sampled = pos[face[0]]
pos_sampled += frac[:, :1] * vec1
Expand Down
5 changes: 4 additions & 1 deletion torch_geometric/utils/geodesic.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,10 @@ def geodesic_distance( # noqa: D417
max_distance = float('inf') if max_distance is None else max_distance

if norm:
area = (pos[face[1]] - pos[face[0]]).cross(pos[face[2]] - pos[face[0]])
area = (pos[face[1]] - pos[face[0]]).cross(
pos[face[2]] - pos[face[0]],
dim=1,
)
scale = float((area.norm(p=2, dim=1) / 2).sum().sqrt())
else:
scale = 1.0
Expand Down

0 comments on commit 3ddd11d

Please sign in to comment.