Skip to content

Commit

Permalink
Add dense computation graph for AddRandomWalkPE (#8431)
Browse files Browse the repository at this point in the history
Fixes #8427
  • Loading branch information
rusty1s authored Nov 23, 2023
1 parent e82c357 commit 0385b0d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 7 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added dense computation for `AddRandomWalkPE` ([#8431](https://github.com/pyg-team/pytorch_geometric/pull/8431))
- Added a tutorial for point cloud processing ([#8015](https://github.com/pyg-team/pytorch_geometric/pull/8015))
- Added `fsspec` as file system backend ([#8379](https://github.com/pyg-team/pytorch_geometric/pull/8379), [#8426](https://github.com/pyg-team/pytorch_geometric/pull/8426))
- Added support for floating-point average degree numbers in `FakeDataset` and `FakeHeteroDataset` ([#8404](https://github.com/pyg-team/pytorch_geometric/pull/8404))
Expand Down
27 changes: 21 additions & 6 deletions torch_geometric/transforms/add_positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import torch
from torch import Tensor

import torch_geometric.typing
from torch_geometric.data import Data
Expand All @@ -10,6 +11,7 @@
from torch_geometric.utils import (
get_laplacian,
get_self_loop_attr,
is_torch_sparse_tensor,
scatter,
to_edge_index,
to_scipy_sparse_matrix,
Expand All @@ -18,8 +20,11 @@
)


def add_node_attr(data: Data, value: Any,
attr_name: Optional[str] = None) -> Data:
def add_node_attr(
data: Data,
value: Any,
attr_name: Optional[str] = None,
) -> Data:
# TODO Move to `BaseTransform`.
if attr_name is None:
if 'x' in data:
Expand Down Expand Up @@ -138,17 +143,27 @@ def forward(self, data: Data) -> Data:
value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row]
value = 1.0 / value

if torch_geometric.typing.WITH_WINDOWS:
if N <= 2_000: # Dense code path for faster computation:
adj = torch.zeros((N, N), device=row.device)
adj[row, col] = value
loop_index = torch.arange(N, device=row.device)
elif torch_geometric.typing.WITH_WINDOWS:
adj = to_torch_coo_tensor(data.edge_index, value, size=data.size())
else:
adj = to_torch_csr_tensor(data.edge_index, value, size=data.size())

def get_pe(out: Tensor) -> Tensor:
if is_torch_sparse_tensor(out):
return get_self_loop_attr(*to_edge_index(out), num_nodes=N)
return out[loop_index, loop_index]

out = adj
pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)]
pe_list = [get_pe(out)]
for _ in range(self.walk_length - 1):
out = out @ adj
pe_list.append(get_self_loop_attr(*to_edge_index(out), N))
pe = torch.stack(pe_list, dim=-1)
pe_list.append(get_pe(out))

pe = torch.stack(pe_list, dim=-1)
data = add_node_attr(data, pe, attr_name=self.attr_name)

return data
2 changes: 1 addition & 1 deletion torch_geometric/utils/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,7 +402,7 @@ def get_self_loop_attr(edge_index: Tensor, edge_attr: OptTensor = None,
if edge_attr is not None:
loop_attr = edge_attr[loop_mask]
else: # A vector of ones:
loop_attr = torch.ones_like(loop_index, dtype=torch.float)
loop_attr = torch.ones(loop_index.numel(), device=edge_index.device)

num_nodes = maybe_num_nodes(edge_index, num_nodes)
full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:])
Expand Down

0 comments on commit 0385b0d

Please sign in to comment.