Skip to content

Commit

Permalink
Allow to_hetero_with_bases to be applied on static graphs (#8247)
Browse files Browse the repository at this point in the history
  • Loading branch information
rusty1s authored Oct 22, 2023
1 parent 405ef2c commit b17555e
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 13 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added support for `to_hetero_with_bases` on static graphs ([#8247](https://github.com/pyg-team/pytorch_geometric/pull/8247))
- Added the `RCDD` dataset ([#8196](https://github.com/pyg-team/pytorch_geometric/pull/8196))
- Added distributed `GAT + ogbn-products` example targeting XPU device ([#8032](https://github.com/pyg-team/pytorch_geometric/pull/8032))
- Added the option to skip explanations of certain message passing layers via `conv.explain = False` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))
Expand Down
6 changes: 2 additions & 4 deletions test/nn/test_to_hetero_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,10 +520,8 @@ def test_to_hetero_on_static_graphs():
'author': torch.randn(4, 100, 16),
}
edge_index_dict = {
('paper', 'written_by', 'author'):
torch.randint(100, (2, 200), dtype=torch.long),
('author', 'writes', 'paper'):
torch.randint(100, (2, 200), dtype=torch.long),
('paper', 'written_by', 'author'): torch.randint(100, (2, 200)),
('author', 'writes', 'paper'): torch.randint(100, (2, 200)),
}

metadata = list(x_dict.keys()), list(edge_index_dict.keys())
Expand Down
21 changes: 21 additions & 0 deletions test/nn/test_to_hetero_with_bases_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,24 @@ def test_to_hetero_with_bases_validate():

with pytest.warns(UserWarning, match="letters, numbers and underscores"):
model = to_hetero_with_bases(model, metadata, num_bases=4, debug=False)


def test_to_hetero_with_bases_on_static_graphs():
    """`to_hetero_with_bases` should accept inputs carrying a leading
    (static) batch dimension and preserve it in every output tensor."""
    node_types = ['paper', 'author']
    # Shape (batch=4, num_nodes=100, channels=16) per node type:
    x_dict = {node_type: torch.randn(4, 100, 16) for node_type in node_types}
    edge_index_dict = {
        ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)),
        ('author', 'writes', 'paper'): torch.randint(100, (2, 200)),
    }

    metadata = (list(x_dict.keys()), list(edge_index_dict.keys()))
    model = to_hetero_with_bases(Net4(), metadata, num_bases=4,
                                 in_channels={'x0': 16}, debug=False)

    out_dict = model(x_dict, edge_index_dict)

    # One output per node type, batch dimension untouched:
    assert len(out_dict) == 2
    for node_type in node_types:
        assert out_dict[node_type].size() == (4, 100, 32)
18 changes: 9 additions & 9 deletions torch_geometric/nn/to_hetero_with_bases_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,15 +324,15 @@ def __init__(self, module: MessagePassing, num_relations: int,
# to a materialization of messages.
def hook(module, inputs, output):
    """Forward hook: scale each per-edge message by its relation weight.

    `module._edge_type` is expected to hold one relation index per
    original edge (attached to the module before the forward pass —
    set elsewhere in this class; confirm against the caller).
    Messages are indexed along dimension ``-2`` so that an optional
    leading (static) batch dimension is left untouched.
    """
    assert isinstance(module._edge_type, Tensor)
    if module._edge_type.size(0) != output.size(-2):
        # BUG FIX: report the dimension that was actually compared
        # (`-2`, the message dimension) — `size(0)` would print the
        # batch size for batched/static inputs.
        raise ValueError(
            f"Number of messages ({output.size(-2)}) does not match "
            f"with the number of original edges "
            f"({module._edge_type.size(0)}). Does your message "
            f"passing layer create additional self-loops? Try to "
            f"remove them via 'add_self_loops=False'")
    # One scalar weight per edge, gathered from the flattened
    # per-relation weights, then reshaped to broadcast over any
    # leading batch dimensions and the trailing feature dimension:
    weight = module.edge_type_weight.view(-1)[module._edge_type]
    weight = weight.view([1] * (output.dim() - 2) + [-1, 1])
    return weight * output

params = list(module.parameters())
Expand Down Expand Up @@ -415,7 +415,7 @@ def get_node_offset_dict(
out: Dict[NodeType, int] = {}
for key in type2id.keys():
out[key] = cumsum
cumsum += input_dict[key].size(0)
cumsum += input_dict[key].size(-2)
return out


Expand All @@ -433,7 +433,7 @@ def get_edge_offset_dict(
elif value.dtype == torch.long and value.size(0) == 2:
cumsum += value.size(-1)
else:
cumsum += value.size(0)
cumsum += value.size(-2)
return out


Expand All @@ -458,7 +458,7 @@ def get_edge_type(
out = torch.full((value.nnz(), ), i, dtype=torch.long,
device=value.device())
else:
out = value.new_full((value.size(0), ), i, dtype=torch.long)
out = value.new_full((value.size(-2), ), i, dtype=torch.long)
outs.append(out)

return outs[0] if len(outs) == 1 else torch.cat(outs, dim=0)
Expand All @@ -474,7 +474,7 @@ def group_node_placeholder(input_dict: Dict[NodeType, Tensor],
type2id: Dict[NodeType, int]) -> Tensor:

inputs = [input_dict[key] for key in type2id.keys()]
return inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=0)
return inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=-2)


def group_edge_placeholder(
Expand Down Expand Up @@ -528,7 +528,7 @@ def group_edge_placeholder(
return torch.stack([row, col], dim=0)

else:
return torch.cat(inputs, dim=0)
return torch.cat(inputs, dim=-2)


###############################################################################
Expand All @@ -542,9 +542,9 @@ def split_output(
offset_dict: Union[Dict[NodeType, int], Dict[EdgeType, int]],
) -> Union[Dict[NodeType, Tensor], Dict[EdgeType, Tensor]]:

cumsums = list(offset_dict.values()) + [output.size(0)]
cumsums = list(offset_dict.values()) + [output.size(-2)]
sizes = [cumsums[i + 1] - cumsums[i] for i in range(len(offset_dict))]
outputs = output.split(sizes)
outputs = output.split(sizes, dim=-2)
return {key: output for key, output in zip(offset_dict, outputs)}


Expand Down

0 comments on commit b17555e

Please sign in to comment.