diff --git a/CHANGELOG.md b/CHANGELOG.md index a3f4fe46baca..4e71ac0e5c2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added `TreeGraph` and `GridMotif` generators ([#8736](https://github.com/pyg-team/pytorch_geometric/pull/8736)) - Added an example for edge-level temporal sampling on a heterogenous graph ([#8383](https://github.com/pyg-team/pytorch_geometric/pull/8383)) - Added the `num_graphs` option to the `StochasticBlockModelDataset` ([#8648](https://github.com/pyg-team/pytorch_geometric/pull/8648)) - Added noise scheduler utility for diffusion based graph generative models ([#8347](https://github.com/pyg-team/pytorch_geometric/pull/8347)) diff --git a/test/datasets/graph_generator/test_tree_graph.py b/test/datasets/graph_generator/test_tree_graph.py new file mode 100644 index 000000000000..d32ee0adb35a --- /dev/null +++ b/test/datasets/graph_generator/test_tree_graph.py @@ -0,0 +1,25 @@ +import pytest + +from torch_geometric.datasets.graph_generator import TreeGraph + + +@pytest.mark.parametrize('undirected', [False, True]) +def test_tree_graph(undirected): + graph_generator = TreeGraph(depth=2, branch=2, undirected=undirected) + assert str(graph_generator) == (f'TreeGraph(depth=2, branch=2, ' + f'undirected={undirected})') + + data = graph_generator() + assert len(data) == 3 + assert data.num_nodes == 7 + assert data.depth.tolist() == [0, 1, 1, 2, 2, 2, 2] + if not undirected: + assert data.edge_index.tolist() == [ + [0, 0, 1, 1, 2, 2], + [1, 2, 3, 4, 5, 6], + ] + else: + assert data.edge_index.tolist() == [ + [0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6], + [1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2], + ] diff --git a/test/datasets/motif_generator/test_grid_motif.py b/test/datasets/motif_generator/test_grid_motif.py new file mode 100644 index 000000000000..84271a094812 --- /dev/null +++ b/test/datasets/motif_generator/test_grid_motif.py @@ -0,0 +1,17 @@ +from torch_geometric.datasets.motif_generator import GridMotif + + +def test_grid_motif(): + motif_generator = GridMotif() + assert str(motif_generator) == 'GridMotif()' + + motif = motif_generator() + assert len(motif) == 3 + assert motif.num_nodes == 9 + assert motif.num_edges == 24 + assert motif.edge_index.size() == (2, 24) + assert motif.edge_index.min() == 0 + assert motif.edge_index.max() == 8 + assert motif.y.size() == (9, ) + assert motif.y.min() == 0 + assert motif.y.max() == 2 diff --git a/torch_geometric/datasets/graph_generator/__init__.py b/torch_geometric/datasets/graph_generator/__init__.py index a6e64e8ac89c..65298bb92fd0 100644 --- a/torch_geometric/datasets/graph_generator/__init__.py +++ b/torch_geometric/datasets/graph_generator/__init__.py @@ -2,10 +2,12 @@ from .ba_graph import BAGraph from .er_graph import ERGraph from .grid_graph import GridGraph +from .tree_graph import TreeGraph __all__ = classes = [ 'GraphGenerator', 'BAGraph', 'ERGraph', 'GridGraph', + 'TreeGraph', ] diff --git a/torch_geometric/datasets/graph_generator/tree_graph.py b/torch_geometric/datasets/graph_generator/tree_graph.py new file mode 100644 index 000000000000..6af40858ed7f --- /dev/null +++ b/torch_geometric/datasets/graph_generator/tree_graph.py @@ -0,0 +1,80 @@ +from typing import List, Optional, Tuple + +import torch +from torch import Tensor + +from torch_geometric.data import Data +from torch_geometric.datasets.graph_generator import GraphGenerator +from torch_geometric.utils import to_undirected + + +def tree( + depth: int, + branch: int = 2, + undirected: bool = False, + device: Optional[torch.device] = None, +) -> Tuple[Tensor, Tensor]: + """Generates a tree graph with the given depth and branch size, along with + node-level depth indicators. + + Args: + depth (int): The depth of the tree. + branch (int, optional): The branch size of the tree. + (default: :obj:`2`) + undirected (bool, optional): If set to :obj:`True`, the tree graph will + be undirected. (default: :obj:`False`) + device (torch.device, optional): The desired device of the returned + tensors. (default: :obj:`None`) + """ + edges: List[Tuple[int, int]] = [] + depths: List[int] = [0] + + def add_edges(node: int, current_depth: int) -> None: + node_count = len(depths) + + if current_depth < depth: + for i in range(branch): + edges.append((node, node_count + i)) + depths.append(current_depth + 1) + + for i in range(branch): + add_edges(node=node_count + i, current_depth=current_depth + 1) + + add_edges(node=0, current_depth=0) + + edge_index = torch.tensor(edges, device=device).t().contiguous() + if undirected: + edge_index = to_undirected(edge_index, num_nodes=len(depths)) + + return edge_index, torch.tensor(depths, device=device) + + +class TreeGraph(GraphGenerator): + r"""Generates tree graphs. + + Args: + depth (int): The depth of the tree. + branch (int, optional): The branch size of the tree. + (default: :obj:`2`) + undirected (bool, optional): If set to :obj:`True`, the tree graph will + be undirected. (default: :obj:`False`) + """ + def __init__( + self, + depth: int, + branch: int = 2, + undirected: bool = False, + ) -> None: + super().__init__() + self.depth = depth + self.branch = branch + self.undirected = undirected + + def __call__(self) -> Data: + edge_index, depth = tree(self.depth, self.branch, self.undirected) + num_nodes = depth.numel() + return Data(edge_index=edge_index, depth=depth, num_nodes=num_nodes) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(depth={self.depth}, ' + f'branch={self.branch}, undirected={self.undirected})') diff --git a/torch_geometric/datasets/motif_generator/__init__.py b/torch_geometric/datasets/motif_generator/__init__.py index b97b639468a1..ca42aaa5431e 100644 --- a/torch_geometric/datasets/motif_generator/__init__.py +++ b/torch_geometric/datasets/motif_generator/__init__.py @@ -2,10 +2,12 @@ from .custom import CustomMotif from .house import HouseMotif from .cycle import CycleMotif +from .grid import GridMotif __all__ = classes = [ 'MotifGenerator', 'CustomMotif', 'HouseMotif', 'CycleMotif', + 'GridMotif', ] diff --git a/torch_geometric/datasets/motif_generator/cycle.py b/torch_geometric/datasets/motif_generator/cycle.py index 1bb3c6f17ae6..78eaf5542b7f 100644 --- a/torch_geometric/datasets/motif_generator/cycle.py +++ b/torch_geometric/datasets/motif_generator/cycle.py @@ -24,7 +24,6 @@ def __init__(self, num_nodes: int): num_nodes=num_nodes, edge_index=torch.stack([row, col], dim=0), ) - super().__init__(structure) def __repr__(self) -> str: diff --git a/torch_geometric/datasets/motif_generator/grid.py b/torch_geometric/datasets/motif_generator/grid.py new file mode 100644 index 000000000000..997627e30728 --- /dev/null +++ b/torch_geometric/datasets/motif_generator/grid.py @@ -0,0 +1,44 @@ +import torch + +from torch_geometric.data import Data +from torch_geometric.datasets.motif_generator import CustomMotif + + +class GridMotif(CustomMotif): + r"""Generates the grid-structured motif from the + `"GNNExplainer: Generating Explanations for Graph Neural Networks" + `__ paper. + """ + def __init__(self) -> None: + edge_indices = [ + [0, 1], + [0, 3], + [1, 4], + [3, 4], + [1, 2], + [2, 5], + [4, 5], + [3, 6], + [6, 7], + [4, 7], + [5, 8], + [7, 8], + [1, 0], + [3, 0], + [4, 1], + [4, 3], + [2, 1], + [5, 2], + [5, 4], + [6, 3], + [7, 6], + [7, 4], + [8, 5], + [8, 7], + ] + structure = Data( + num_nodes=9, + edge_index=torch.tensor(edge_indices).t().contiguous(), + y=torch.tensor([0, 1, 0, 1, 2, 1, 0, 1, 0]), + ) + super().__init__(structure)