Skip to content

Commit

Permalink
Merge branch 'master' into multinode_multigpu_tutorial
Browse files Browse the repository at this point in the history
  • Loading branch information
puririshi98 authored Oct 9, 2023
2 parents 137f730 + 4d280f5 commit e392a65
Show file tree
Hide file tree
Showing 6 changed files with 102 additions and 54 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Fixed `HeteroConv` for layers that have a non-default argument order, *e.g.*, `GCN2Conv` ([#8166](https://github.com/pyg-team/pytorch_geometric/pull/8166))
- Handle reserved keywords as keys in `ModuleDict` and `ParameterDict` ([#8163](https://github.com/pyg-team/pytorch_geometric/pull/8163))
- Updated the examples and tutorials to account for `torch.compile(dynamic=True)` in PyTorch 2.1.0 ([#8145](https://github.com/pyg-team/pytorch_geometric/pull/8145))
- Enabled dense eigenvalue computation in `AddLaplacianEigenvectorPE` for small-scale graphs ([#8143](https://github.com/pyg-team/pytorch_geometric/pull/8143))
Expand Down
10 changes: 4 additions & 6 deletions test/explain/algorithm/test_explain_algorithm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,16 +53,14 @@ def test_set_clear_mask(hetero_data):
set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict)
for edge_type in hetero_data.edge_types:
# Check that masks are correctly set:
str_edge_type = '__'.join(edge_type)
assert torch.allclose(model.conv1.convs[str_edge_type]._edge_mask,
assert torch.allclose(model.conv1.convs[edge_type]._edge_mask,
edge_mask_dict[edge_type])
assert model.conv1.convs[str_edge_type].explain
assert model.conv1.convs[edge_type].explain

clear_masks(model)
for edge_type in hetero_data.edge_types:
str_edge_type = '__'.join(edge_type)
assert model.conv1.convs[str_edge_type]._edge_mask is None
assert not model.conv1.convs[str_edge_type].explain
assert model.conv1.convs[edge_type]._edge_mask is None
assert not model.conv1.convs[edge_type].explain

model = to_hetero(GraphSAGE(), hetero_data.metadata(), debug=False)

Expand Down
79 changes: 59 additions & 20 deletions test/nn/conv/test_hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch_geometric.data import HeteroData
from torch_geometric.nn import (
GATConv,
GCN2Conv,
GCNConv,
HeteroConv,
Linear,
Expand Down Expand Up @@ -35,24 +36,55 @@ def test_hetero_conv(aggr):
SAGEConv((-1, -1), 64),
('paper', 'to', 'author'):
GATConv((-1, -1), 64, edge_dim=3, add_self_loops=False),
}, aggr=aggr)
},
aggr=aggr,
)

assert len(list(conv.parameters())) > 0
assert str(conv) == 'HeteroConv(num_relations=3)'

out = conv(data.x_dict, data.edge_index_dict, data.edge_attr_dict,
edge_weight_dict=data.edge_weight_dict)
out_dict = conv(
data.x_dict,
data.edge_index_dict,
data.edge_attr_dict,
edge_weight_dict=data.edge_weight_dict,
)

assert len(out) == 2
assert len(out_dict) == 2
if aggr == 'cat':
assert out['paper'].size() == (50, 128)
assert out['author'].size() == (30, 64)
assert out_dict['paper'].size() == (50, 128)
assert out_dict['author'].size() == (30, 64)
elif aggr is not None:
assert out['paper'].size() == (50, 64)
assert out['author'].size() == (30, 64)
assert out_dict['paper'].size() == (50, 64)
assert out_dict['author'].size() == (30, 64)
else:
assert out['paper'].size() == (50, 2, 64)
assert out['author'].size() == (30, 1, 64)
assert out_dict['paper'].size() == (50, 2, 64)
assert out_dict['author'].size() == (30, 1, 64)


def test_gcn2_hetero_conv():
data = HeteroData()
data['paper'].x = torch.randn(50, 32)
data['author'].x = torch.randn(30, 64)
data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200)
data['author', 'author'].edge_index = get_random_edge_index(30, 30, 100)
data['paper', 'paper'].edge_weight = torch.rand(200)

conv = HeteroConv({
('paper', 'to', 'paper'): GCN2Conv(32, alpha=0.1),
('author', 'to', 'author'): GCN2Conv(64, alpha=0.2),
})

out_dict = conv(
data.x_dict,
data.x_dict,
data.edge_index_dict,
edge_weight_dict=data.edge_weight_dict,
)

assert len(out_dict) == 2
assert out_dict['paper'].size() == (50, 32)
assert out_dict['author'].size() == (30, 64)


class CustomConv(MessagePassing):
Expand Down Expand Up @@ -81,11 +113,15 @@ def test_hetero_conv_with_custom_conv():

conv = HeteroConv({key: CustomConv(64) for key in data.edge_types})
# Test node `args_dict` and `kwargs_dict` with `y_dict` and `z_dict`:
out = conv(data.x_dict, data.edge_index_dict, data.y_dict,
z_dict=data.z_dict)
assert len(out) == 2
assert out['paper'].size() == (50, 64)
assert out['author'].size() == (30, 64)
out_dict = conv(
data.x_dict,
data.edge_index_dict,
data.y_dict,
z_dict=data.z_dict,
)
assert len(out_dict) == 2
assert out_dict['paper'].size() == (50, 64)
assert out_dict['author'].size() == (30, 64)


class MessagePassingLoops(MessagePassing):
Expand Down Expand Up @@ -122,9 +158,12 @@ def test_hetero_conv_with_dot_syntax_node_types():
assert len(list(conv.parameters())) > 0
assert str(conv) == 'HeteroConv(num_relations=3)'

out = conv(data.x_dict, data.edge_index_dict,
edge_weight_dict=data.edge_weight_dict)
out_dict = conv(
data.x_dict,
data.edge_index_dict,
edge_weight_dict=data.edge_weight_dict,
)

assert len(out) == 2
assert out['src.paper'].size() == (50, 64)
assert out['author'].size() == (30, 64)
assert len(out_dict) == 2
assert out_dict['src.paper'].size() == (50, 64)
assert out_dict['author'].size() == (30, 64)
22 changes: 13 additions & 9 deletions torch_geometric/explain/algorithm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,19 @@ def set_hetero_masks(
for module in model.modules():
if isinstance(module, torch.nn.ModuleDict):
for edge_type in mask_dict.keys():
# TODO (jinu) Use common function get `str_edge_type`.
str_edge_type = '__'.join(edge_type)
if str_edge_type in module:
set_masks(
module[str_edge_type],
mask_dict[edge_type],
edge_index_dict[edge_type],
apply_sigmoid=apply_sigmoid,
)
if edge_type in module:
edge_level_module = module[edge_type]
elif '__'.join(edge_type) in module:
edge_level_module = module['__'.join(edge_type)]
else:
continue

set_masks(
edge_level_module,
mask_dict[edge_type],
edge_index_dict[edge_type],
apply_sigmoid=apply_sigmoid,
)


def clear_masks(model: torch.nn.Module):
Expand Down
42 changes: 23 additions & 19 deletions torch_geometric/nn/conv/hetero_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.module_dict import ModuleDict
from torch_geometric.typing import Adj, EdgeType, NodeType
from torch_geometric.typing import EdgeType, NodeType
from torch_geometric.utils.hetero import check_add_self_loops


Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(
f"passing as they do not occur as destination type in any "
f"edge type. This may lead to unexpected behavior.")

self.convs = ModuleDict({'__'.join(k): v for k, v in convs.items()})
self.convs = ModuleDict(convs)
self.aggr = aggr

def reset_parameters(self):
Expand All @@ -90,8 +90,6 @@ def reset_parameters(self):

def forward(
self,
x_dict: Dict[NodeType, Tensor],
edge_index_dict: Dict[EdgeType, Adj],
*args_dict,
**kwargs_dict,
) -> Dict[NodeType, Tensor]:
Expand All @@ -117,42 +115,48 @@ def forward(
:obj:`edge_attr_dict = { edge_type: edge_attr }`.
"""
out_dict = defaultdict(list)
for edge_type, edge_index in edge_index_dict.items():

for edge_type in self.convs.keys():
src, rel, dst = edge_type

str_edge_type = '__'.join(edge_type)
if str_edge_type not in self.convs:
continue
has_edge_level_arg = False

args = []
for value_dict in args_dict:
if edge_type in value_dict:
has_edge_level_arg = True
args.append(value_dict[edge_type])
elif src == dst and src in value_dict:
args.append(value_dict[src])
elif src in value_dict or dst in value_dict:
args.append(
(value_dict.get(src, None), value_dict.get(dst, None)))
args.append((
value_dict.get(src, None),
value_dict.get(dst, None),
))

kwargs = {}
for arg, value_dict in kwargs_dict.items():
if not arg.endswith('_dict'):
raise ValueError(
f"Keyword arguments in '{self.__class__.__name__}' "
f"need to end with '_dict' (got '{arg}')")

arg = arg[:-5] # `{*}_dict`
if edge_type in value_dict:
has_edge_level_arg = True
kwargs[arg] = value_dict[edge_type]
elif src == dst and src in value_dict:
kwargs[arg] = value_dict[src]
elif src in value_dict or dst in value_dict:
kwargs[arg] = (value_dict.get(src, None),
value_dict.get(dst, None))
kwargs[arg] = (
value_dict.get(src, None),
value_dict.get(dst, None),
)

conv = self.convs[str_edge_type]

if src == dst:
out = conv(x_dict[src], edge_index, *args, **kwargs)
else:
out = conv((x_dict[src], x_dict[dst]), edge_index, *args,
**kwargs)
if not has_edge_level_arg:
continue

out = self.convs[edge_type](*args, **kwargs)
out_dict[dst].append(out)

for key, value in out_dict.items():
Expand Down
2 changes: 2 additions & 0 deletions torch_geometric/nn/conv/utils/cheatsheet.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def supports_lazy_initialization(cls: str) -> bool:


def processes_heterogeneous_graphs(cls: str) -> bool:
if 'hetero' in cls.lower():
return True
cls = importlib.import_module('torch_geometric.nn.conv').__dict__[cls]
signature = inspect.signature(cls.forward)
return 'edge_index_dict' in str(signature) or 'edge_type' in str(signature)
Expand Down

0 comments on commit e392a65

Please sign in to comment.