Skip to content

Commit

Permalink
[Transforms] Fix sparse-sparse matrix multiplication support on Windo…
Browse files Browse the repository at this point in the history
…ws (Part 2) (#8225)

Address #8219

---------

Co-authored-by: Matthias Fey <[email protected]>
  • Loading branch information
EdisonLeeeee and rusty1s authored Oct 19, 2023
1 parent d7e338d commit 5a7d501
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 1 deletion.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))
- Fixed `to_networkx(to_undirected=True)` in case the input graph is not undirected ([#8204](https://github.com/pyg-team/pytorch_geometric/pull/8204))
- Fixed sparse-sparse matrix multiplication support on Windows in `TwoHop` and `AddRandomWalkPE` transformations ([#8197](https://github.com/pyg-team/pytorch_geometric/pull/8197))
- Fixed sparse-sparse matrix multiplication support on Windows in `TwoHop` and `AddRandomWalkPE` transformations ([#8197](https://github.com/pyg-team/pytorch_geometric/pull/8197), [#8225](https://github.com/pyg-team/pytorch_geometric/pull/8225))

### Removed

Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/transforms/add_positional_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ def forward(self, data: Data) -> Data:
pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)]
for _ in range(self.walk_length - 1):
out = out @ adj
if out.layout == torch.sparse_coo:
out = out._coalesced_(True)
pe_list.append(get_self_loop_attr(*to_edge_index(out), N))
pe = torch.stack(pe_list, dim=-1)

Expand Down
1 change: 1 addition & 0 deletions torch_geometric/utils/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ def to_edge_index(adj: Union[Tensor, SparseTensor]) -> Tuple[Tensor, Tensor]:
return torch.stack([row, col], dim=0).long(), value

if adj.layout == torch.sparse_coo:
adj = adj._coalesced_(True)
return adj.indices().detach().long(), adj.values()

if adj.layout == torch.sparse_csr:
Expand Down

0 comments on commit 5a7d501

Please sign in to comment.