Skip to content

Commit

Permalink
Include all synthetic data referenced in GNNExplainer (#8736)
Browse files Browse the repository at this point in the history
A smaller PR for #8704
### Include grid motifs and tree-based datasets in GNNExplainer.
- add torch_geometric/datasets/graph_generator/tree_graph.py
- add torch_geometric/datasets/motif_generator/grid.py
- add node label in torch_geometric/datasets/motif_generator/cycle.py
(according to bug report
#8509, it is better
to directly add node label since we know what the label should be,
rather than be captured at `else` sentense in line 127 of
`./torch_geometric/datasets/explainer_dataset.py`)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: rusty1s <[email protected]>
  • Loading branch information
3 people authored Jan 7, 2024
1 parent 3fb6ff1 commit 5230418
Show file tree
Hide file tree
Showing 8 changed files with 171 additions and 1 deletion.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added `TreeGraph` and `GridMotif` generators ([#8736](https://github.com/pyg-team/pytorch_geometric/pull/8736))
- Added an example for edge-level temporal sampling on a heterogenous graph ([#8383](https://github.com/pyg-team/pytorch_geometric/pull/8383))
- Added the `num_graphs` option to the `StochasticBlockModelDataset` ([#8648](https://github.com/pyg-team/pytorch_geometric/pull/8648))
- Added noise scheduler utility for diffusion based graph generative models ([#8347](https://github.com/pyg-team/pytorch_geometric/pull/8347))
Expand Down
25 changes: 25 additions & 0 deletions test/datasets/graph_generator/test_tree_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest

from torch_geometric.datasets.graph_generator import TreeGraph


@pytest.mark.parametrize('undirected', [False, True])
def test_tree_graph(undirected):
graph_generator = TreeGraph(depth=2, branch=2, undirected=undirected)
assert str(graph_generator) == (f'TreeGraph(depth=2, branch=2, '
f'undirected={undirected})')

data = graph_generator()
assert len(data) == 3
assert data.num_nodes == 7
assert data.depth.tolist() == [0, 1, 1, 2, 2, 2, 2]
if not undirected:
assert data.edge_index.tolist() == [
[0, 0, 1, 1, 2, 2],
[1, 2, 3, 4, 5, 6],
]
else:
assert data.edge_index.tolist() == [
[0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6],
[1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2],
]
17 changes: 17 additions & 0 deletions test/datasets/motif_generator/test_grid_motif.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from torch_geometric.datasets.motif_generator import GridMotif


def test_grid_motif():
motif_generator = GridMotif()
assert str(motif_generator) == 'GridMotif()'

motif = motif_generator()
assert len(motif) == 3
assert motif.num_nodes == 9
assert motif.num_edges == 24
assert motif.edge_index.size() == (2, 24)
assert motif.edge_index.min() == 0
assert motif.edge_index.max() == 8
assert motif.y.size() == (9, )
assert motif.y.min() == 0
assert motif.y.max() == 2
2 changes: 2 additions & 0 deletions torch_geometric/datasets/graph_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from .ba_graph import BAGraph
from .er_graph import ERGraph
from .grid_graph import GridGraph
from .tree_graph import TreeGraph

__all__ = classes = [
'GraphGenerator',
'BAGraph',
'ERGraph',
'GridGraph',
'TreeGraph',
]
80 changes: 80 additions & 0 deletions torch_geometric/datasets/graph_generator/tree_graph.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import List, Optional, Tuple

import torch
from torch import Tensor

from torch_geometric.data import Data
from torch_geometric.datasets.graph_generator import GraphGenerator
from torch_geometric.utils import to_undirected


def tree(
depth: int,
branch: int = 2,
undirected: bool = False,
device: Optional[torch.device] = None,
) -> Tuple[Tensor, Tensor]:
"""Generates a tree graph with the given depth and branch size, along with
node-level depth indicators.
Args:
depth (int): The depth of the tree.
branch (int, optional): The branch size of the tree.
(default: :obj:`2`)
undirected (bool, optional): If set to :obj:`True`, the tree graph will
be undirected. (default: :obj:`False`)
device (torch.device, optional): The desired device of the returned
tensors. (default: :obj:`None`)
"""
edges: List[Tuple[int, int]] = []
depths: List[int] = [0]

def add_edges(node: int, current_depth: int) -> None:
node_count = len(depths)

if current_depth < depth:
for i in range(branch):
edges.append((node, node_count + i))
depths.append(current_depth + 1)

for i in range(branch):
add_edges(node=node_count + i, current_depth=current_depth + 1)

add_edges(node=0, current_depth=0)

edge_index = torch.tensor(edges, device=device).t().contiguous()
if undirected:
edge_index = to_undirected(edge_index, num_nodes=len(depths))

return edge_index, torch.tensor(depths, device=device)


class TreeGraph(GraphGenerator):
r"""Generates tree graphs.
Args:
depth (int): The depth of the tree.
branch (int, optional): The branch size of the tree.
(default: :obj:`2`)
undirected (bool, optional): If set to :obj:`True`, the tree graph will
be undirected. (default: :obj:`False`)
"""
def __init__(
self,
depth: int,
branch: int = 2,
undirected: bool = False,
) -> None:
super().__init__()
self.depth = depth
self.branch = branch
self.undirected = undirected

def __call__(self) -> Data:
edge_index, depth = tree(self.depth, self.branch, self.undirected)
num_nodes = depth.numel()
return Data(edge_index=edge_index, depth=depth, num_nodes=num_nodes)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}(depth={self.depth}, '
f'branch={self.branch}, undirected={self.undirected})')
2 changes: 2 additions & 0 deletions torch_geometric/datasets/motif_generator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@
from .custom import CustomMotif
from .house import HouseMotif
from .cycle import CycleMotif
from .grid import GridMotif

__all__ = classes = [
'MotifGenerator',
'CustomMotif',
'HouseMotif',
'CycleMotif',
'GridMotif',
]
1 change: 0 additions & 1 deletion torch_geometric/datasets/motif_generator/cycle.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ def __init__(self, num_nodes: int):
num_nodes=num_nodes,
edge_index=torch.stack([row, col], dim=0),
)

super().__init__(structure)

def __repr__(self) -> str:
Expand Down
44 changes: 44 additions & 0 deletions torch_geometric/datasets/motif_generator/grid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import torch

from torch_geometric.data import Data
from torch_geometric.datasets.motif_generator import CustomMotif


class GridMotif(CustomMotif):
r"""Generates the grid-structured motif from the
`"GNNExplainer: Generating Explanations for Graph Neural Networks"
<https://arxiv.org/abs/1903.03894>`__ paper.
"""
def __init__(self) -> None:
edge_indices = [
[0, 1],
[0, 3],
[1, 4],
[3, 4],
[1, 2],
[2, 5],
[4, 5],
[3, 6],
[6, 7],
[4, 7],
[5, 8],
[7, 8],
[1, 0],
[3, 0],
[4, 1],
[4, 3],
[2, 1],
[5, 2],
[5, 4],
[6, 3],
[7, 6],
[7, 4],
[8, 5],
[8, 7],
]
structure = Data(
num_nodes=9,
edge_index=torch.tensor(edge_indices).t().contiguous(),
y=torch.tensor([0, 1, 0, 1, 2, 1, 0, 1, 0]),
)
super().__init__(structure)

0 comments on commit 5230418

Please sign in to comment.