-
Notifications
You must be signed in to change notification settings - Fork 3.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Include all synthetic data referenced in
GNNExplainer
(#8736)
A smaller PR for #8704 ### Include grid motifs and tree-based datasets in GNNExplainer. - add torch_geometric/datasets/graph_generator/tree_graph.py - add torch_geometric/datasets/motif_generator/grid.py - add node label in torch_geometric/datasets/motif_generator/cycle.py (according to bug report #8509, it is better to directly add node label since we know what the label should be, rather than be captured at `else` sentense in line 127 of `./torch_geometric/datasets/explainer_dataset.py`) --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: rusty1s <[email protected]>
- Loading branch information
1 parent
3fb6ff1
commit 5230418
Showing
8 changed files
with
171 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
import pytest | ||
|
||
from torch_geometric.datasets.graph_generator import TreeGraph | ||
|
||
|
||
@pytest.mark.parametrize('undirected', [False, True]) | ||
def test_tree_graph(undirected): | ||
graph_generator = TreeGraph(depth=2, branch=2, undirected=undirected) | ||
assert str(graph_generator) == (f'TreeGraph(depth=2, branch=2, ' | ||
f'undirected={undirected})') | ||
|
||
data = graph_generator() | ||
assert len(data) == 3 | ||
assert data.num_nodes == 7 | ||
assert data.depth.tolist() == [0, 1, 1, 2, 2, 2, 2] | ||
if not undirected: | ||
assert data.edge_index.tolist() == [ | ||
[0, 0, 1, 1, 2, 2], | ||
[1, 2, 3, 4, 5, 6], | ||
] | ||
else: | ||
assert data.edge_index.tolist() == [ | ||
[0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6], | ||
[1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2], | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
from torch_geometric.datasets.motif_generator import GridMotif | ||
|
||
|
||
def test_grid_motif(): | ||
motif_generator = GridMotif() | ||
assert str(motif_generator) == 'GridMotif()' | ||
|
||
motif = motif_generator() | ||
assert len(motif) == 3 | ||
assert motif.num_nodes == 9 | ||
assert motif.num_edges == 24 | ||
assert motif.edge_index.size() == (2, 24) | ||
assert motif.edge_index.min() == 0 | ||
assert motif.edge_index.max() == 8 | ||
assert motif.y.size() == (9, ) | ||
assert motif.y.min() == 0 | ||
assert motif.y.max() == 2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from typing import List, Optional, Tuple | ||
|
||
import torch | ||
from torch import Tensor | ||
|
||
from torch_geometric.data import Data | ||
from torch_geometric.datasets.graph_generator import GraphGenerator | ||
from torch_geometric.utils import to_undirected | ||
|
||
|
||
def tree( | ||
depth: int, | ||
branch: int = 2, | ||
undirected: bool = False, | ||
device: Optional[torch.device] = None, | ||
) -> Tuple[Tensor, Tensor]: | ||
"""Generates a tree graph with the given depth and branch size, along with | ||
node-level depth indicators. | ||
Args: | ||
depth (int): The depth of the tree. | ||
branch (int, optional): The branch size of the tree. | ||
(default: :obj:`2`) | ||
undirected (bool, optional): If set to :obj:`True`, the tree graph will | ||
be undirected. (default: :obj:`False`) | ||
device (torch.device, optional): The desired device of the returned | ||
tensors. (default: :obj:`None`) | ||
""" | ||
edges: List[Tuple[int, int]] = [] | ||
depths: List[int] = [0] | ||
|
||
def add_edges(node: int, current_depth: int) -> None: | ||
node_count = len(depths) | ||
|
||
if current_depth < depth: | ||
for i in range(branch): | ||
edges.append((node, node_count + i)) | ||
depths.append(current_depth + 1) | ||
|
||
for i in range(branch): | ||
add_edges(node=node_count + i, current_depth=current_depth + 1) | ||
|
||
add_edges(node=0, current_depth=0) | ||
|
||
edge_index = torch.tensor(edges, device=device).t().contiguous() | ||
if undirected: | ||
edge_index = to_undirected(edge_index, num_nodes=len(depths)) | ||
|
||
return edge_index, torch.tensor(depths, device=device) | ||
|
||
|
||
class TreeGraph(GraphGenerator): | ||
r"""Generates tree graphs. | ||
Args: | ||
depth (int): The depth of the tree. | ||
branch (int, optional): The branch size of the tree. | ||
(default: :obj:`2`) | ||
undirected (bool, optional): If set to :obj:`True`, the tree graph will | ||
be undirected. (default: :obj:`False`) | ||
""" | ||
def __init__( | ||
self, | ||
depth: int, | ||
branch: int = 2, | ||
undirected: bool = False, | ||
) -> None: | ||
super().__init__() | ||
self.depth = depth | ||
self.branch = branch | ||
self.undirected = undirected | ||
|
||
def __call__(self) -> Data: | ||
edge_index, depth = tree(self.depth, self.branch, self.undirected) | ||
num_nodes = depth.numel() | ||
return Data(edge_index=edge_index, depth=depth, num_nodes=num_nodes) | ||
|
||
def __repr__(self) -> str: | ||
return (f'{self.__class__.__name__}(depth={self.depth}, ' | ||
f'branch={self.branch}, undirected={self.undirected})') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import torch | ||
|
||
from torch_geometric.data import Data | ||
from torch_geometric.datasets.motif_generator import CustomMotif | ||
|
||
|
||
class GridMotif(CustomMotif): | ||
r"""Generates the grid-structured motif from the | ||
`"GNNExplainer: Generating Explanations for Graph Neural Networks" | ||
<https://arxiv.org/abs/1903.03894>`__ paper. | ||
""" | ||
def __init__(self) -> None: | ||
edge_indices = [ | ||
[0, 1], | ||
[0, 3], | ||
[1, 4], | ||
[3, 4], | ||
[1, 2], | ||
[2, 5], | ||
[4, 5], | ||
[3, 6], | ||
[6, 7], | ||
[4, 7], | ||
[5, 8], | ||
[7, 8], | ||
[1, 0], | ||
[3, 0], | ||
[4, 1], | ||
[4, 3], | ||
[2, 1], | ||
[5, 2], | ||
[5, 4], | ||
[6, 3], | ||
[7, 6], | ||
[7, 4], | ||
[8, 5], | ||
[8, 7], | ||
] | ||
structure = Data( | ||
num_nodes=9, | ||
edge_index=torch.tensor(edge_indices).t().contiguous(), | ||
y=torch.tensor([0, 1, 0, 1, 2, 1, 0, 1, 0]), | ||
) | ||
super().__init__(structure) |