Skip to content

Commit

Permalink
Remove filtering of node/edge types in trim_to_layer (#9021)
Browse files Browse the repository at this point in the history
This is not safe in most cases, since filtering out an empty edge type
may cause node features to be dropped unexpectedly.

Fixes #9015
  • Loading branch information
rusty1s authored Mar 5, 2024
1 parent 0d30e89 commit af0f5f4
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 53 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Remove filtering of node/edge types in `trim_to_layer` functionality ([#9021](https://github.com/pyg-team/pytorch_geometric/pull/9021))
- Default to `scatter` operations in `MessagePassing` in case `torch.use_deterministic_algorithms` is not set ([#9009](https://github.com/pyg-team/pytorch_geometric/pull/9009))
- Made `MessagePassing` interface thread-safe ([#9001](https://github.com/pyg-team/pytorch_geometric/pull/9001))
- Breaking Change: Added support for `EdgeIndex` in `cugraph` GNN layers ([#8938](https://github.com/pyg-team/pytorch_geometric/pull/8937))
Expand Down
36 changes: 0 additions & 36 deletions test/utils/test_trim_to_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,39 +197,3 @@ def test_trim_to_layer_with_neighbor_loader():
assert out2.size() == (2, 16)

assert torch.allclose(out1, out2, atol=1e-6)


def test_trim_to_layer_filtering():
    # Exercises `trim_to_layer` on a heterogeneous input in which one edge
    # type ('paper', 'has_topic', 'field_of_study') has zero edges sampled in
    # its first hop, and checks that this edge type is absent from the
    # returned edge index dictionary while all node types are kept.
    writes = ('author', 'writes', 'paper')
    has_topic = ('paper', 'has_topic', 'field_of_study')

    feats = {
        'paper': torch.rand((13, 128)),
        'author': torch.rand((5, 128)),
        'field_of_study': torch.rand((6, 128)),
    }
    edges = {
        writes: torch.tensor([[0, 1, 2, 3, 4], [0, 0, 1, 2, 2]]),
        has_topic: torch.tensor([[6, 7, 8, 9], [0, 0, 1, 1]]),
    }
    nodes_per_hop = {
        'paper': [1, 2, 10],
        'author': [0, 2, 3],
        'field_of_study': [0, 2, 4],
    }
    edges_per_hop = {
        writes: [2, 3],
        has_topic: [0, 4],
    }

    out_x, out_edges, _ = trim_to_layer(
        layer=1,
        num_sampled_nodes_per_hop=nodes_per_hop,
        num_sampled_edges_per_hop=edges_per_hop,
        x=feats,
        edge_index=edges,
    )

    # Only the 'writes' edge type is expected to survive the trimming.
    assert list(out_edges.keys()) == [writes]
    assert torch.equal(out_edges[writes], torch.tensor([[0, 1], [0, 0]]))

    # Expected number of remaining feature rows per node type.
    expected_rows = {'paper': 3, 'author': 2, 'field_of_study': 2}
    for node_type, num_rows in expected_rows.items():
        assert out_x[node_type].size() == (num_rows, 128)
18 changes: 1 addition & 17 deletions torch_geometric/utils/_trim_to_layer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import copy
from typing import Any, Dict, List, Optional, Tuple, Union, overload
from typing import Dict, List, Optional, Tuple, Union, overload

import torch
from torch import Tensor
Expand All @@ -17,18 +16,6 @@
)


def filter_empty_entries(
        input_dict: Dict[Any, Tensor]) -> Dict[Any, Tensor]:
    r"""Removes empty tensors from a dictionary. This avoids unnecessary
    computation when some node/edge types are non-reachable after trimming.

    The input dictionary is left untouched; a new dictionary is returned that
    holds the very same tensor objects (no copies) for every entry whose
    tensor contains at least one element.

    .. note::
        A filtered entry loses its key entirely. Callers that later index the
        result by node or edge type must therefore be able to handle missing
        keys.

    Args:
        input_dict (Dict[Any, torch.Tensor]): The dictionary to filter.

    Returns:
        Dict[Any, torch.Tensor]: The entries of :obj:`input_dict` whose
        tensors have a non-zero number of elements, in their original order.
    """
    return {
        key: value
        for key, value in input_dict.items() if value.numel() != 0
    }


@overload
def trim_to_layer(
layer: int,
Expand Down Expand Up @@ -96,7 +83,6 @@ def trim_to_layer(
k: trim_feat(v, layer, num_sampled_nodes_per_hop[k])
for k, v in x.items()
}
x = filter_empty_entries(x)

assert isinstance(edge_index, dict)
edge_index = {
Expand All @@ -110,15 +96,13 @@ def trim_to_layer(
)
for k, v in edge_index.items()
}
edge_index = filter_empty_entries(edge_index)

if edge_attr is not None:
assert isinstance(edge_attr, dict)
edge_attr = {
k: trim_feat(v, layer, num_sampled_edges_per_hop[k])
for k, v in edge_attr.items()
}
edge_attr = filter_empty_entries(edge_attr)

return x, edge_index, edge_attr

Expand Down

0 comments on commit af0f5f4

Please sign in to comment.