From 0385b0d7a411b96c65a90d0a333adfca9e7dabf9 Mon Sep 17 00:00:00 2001 From: Matthias Fey Date: Thu, 23 Nov 2023 20:35:54 +0100 Subject: [PATCH] Add dense computation graph for `AddRandomWalkPE` (#8431) Fixes https://github.com/pyg-team/pytorch_geometric/issues/8427 --- CHANGELOG.md | 1 + .../transforms/add_positional_encoding.py | 27 ++++++++++++++----- torch_geometric/utils/loop.py | 2 +- 3 files changed, 23 insertions(+), 7 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f4ec45d86a29..e51705924953 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added dense computation for `AddRandomWalkPE` ([#8431](https://github.com/pyg-team/pytorch_geometric/pull/8431)) - Added a tutorial for point cloud processing ([#8015](https://github.com/pyg-team/pytorch_geometric/pull/8015)) - Added `fsspec` as file system backend ([#8379](https://github.com/pyg-team/pytorch_geometric/pull/8379), [#8426](https://github.com/pyg-team/pytorch_geometric/pull/8426)) - Added support for floating-point average degree numbers in `FakeDataset` and `FakeHeteroDataset` ([#8404](https://github.com/pyg-team/pytorch_geometric/pull/8404)) diff --git a/torch_geometric/transforms/add_positional_encoding.py b/torch_geometric/transforms/add_positional_encoding.py index 6739479840a4..095fa6635299 100644 --- a/torch_geometric/transforms/add_positional_encoding.py +++ b/torch_geometric/transforms/add_positional_encoding.py @@ -2,6 +2,7 @@ import numpy as np import torch +from torch import Tensor import torch_geometric.typing from torch_geometric.data import Data @@ -10,6 +11,7 @@ from torch_geometric.utils import ( get_laplacian, get_self_loop_attr, + is_torch_sparse_tensor, scatter, to_edge_index, to_scipy_sparse_matrix, @@ -18,8 +20,11 @@ ) -def add_node_attr(data: Data, value: Any, - attr_name: Optional[str] = None) -> Data: +def add_node_attr( + data: Data, + value: Any, + attr_name: Optional[str] = None, +) -> Data: # TODO Move to `BaseTransform`. if attr_name is None: if 'x' in data: @@ -138,17 +143,27 @@ def forward(self, data: Data) -> Data: value = scatter(value, row, dim_size=N, reduce='sum').clamp(min=1)[row] value = 1.0 / value - if torch_geometric.typing.WITH_WINDOWS: + if N <= 2_000: # Dense code path for faster computation: + adj = torch.zeros((N, N), device=row.device) + adj[row, col] = value + loop_index = torch.arange(N, device=row.device) + elif torch_geometric.typing.WITH_WINDOWS: adj = to_torch_coo_tensor(data.edge_index, value, size=data.size()) else: adj = to_torch_csr_tensor(data.edge_index, value, size=data.size()) + def get_pe(out: Tensor) -> Tensor: + if is_torch_sparse_tensor(out): + return get_self_loop_attr(*to_edge_index(out), num_nodes=N) + return out[loop_index, loop_index] + out = adj - pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)] + pe_list = [get_pe(out)] for _ in range(self.walk_length - 1): out = out @ adj - pe_list.append(get_self_loop_attr(*to_edge_index(out), N)) - pe = torch.stack(pe_list, dim=-1) + pe_list.append(get_pe(out)) + pe = torch.stack(pe_list, dim=-1) data = add_node_attr(data, pe, attr_name=self.attr_name) + return data diff --git a/torch_geometric/utils/loop.py b/torch_geometric/utils/loop.py index 4a9c69376a4f..6a8c0e2df78a 100644 --- a/torch_geometric/utils/loop.py +++ b/torch_geometric/utils/loop.py @@ -402,7 +402,7 @@ def get_self_loop_attr(edge_index: Tensor, edge_attr: OptTensor = None, if edge_attr is not None: loop_attr = edge_attr[loop_mask] else: # A vector of ones: - loop_attr = torch.ones_like(loop_index, dtype=torch.float) + loop_attr = torch.ones(loop_index.numel(), device=edge_index.device) num_nodes = maybe_num_nodes(edge_index, num_nodes) full_loop_attr = loop_attr.new_zeros((num_nodes, ) + loop_attr.size()[1:])