Commit
Merge branch 'master' into inductive_train_test_split
ogawayuto committed Oct 22, 2023
2 parents ab76792 + ca35214 commit 1b2cf57
Showing 8 changed files with 71 additions and 24 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -6,7 +6,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
## [2.5.0] - 2023-MM-DD

### Added
- Added the `utils.inductive_train_test_split()` and `utils.split_graph` to split graph according to given subsets of nodes([#8243](https://github.com/pyg-team/pytorch_geometric/pull/8243))

- Added support for `to_hetero_with_bases` on static graphs ([#8247](https://github.com/pyg-team/pytorch_geometric/pull/8247))
- Added the `utils.inductive_train_test_split()` and `utils.split_graph()` to split graph according to given two subsets of nodes([#8243](https://github.com/pyg-team/pytorch_geometric/pull/8243))
@@ -23,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed `AttentionExplainer` usage within `AttentiveFP` ([#8244](https://github.com/pyg-team/pytorch_geometric/pull/8244))
- Fixed `load_from_state_dict` in lazy `Linear` modules ([#8242](https://github.com/pyg-team/pytorch_geometric/pull/8242))
- Fixed pre-trained `DimeNet++` performance on `QM9` ([#8239](https://github.com/pyg-team/pytorch_geometric/pull/8239))
- Fixed `GNNExplainer` usage within `AttentiveFP` ([#8216](https://github.com/pyg-team/pytorch_geometric/pull/8216))
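Context for the new `### Added` entry above: an inductive split restricts the train and test graphs to edges whose endpoints lie entirely inside the respective node subset. Below is a minimal sketch of that idea using the long-standing `utils.subgraph()` helper; the exact signatures of the new `inductive_train_test_split()` and `split_graph()` utilities are defined in #8243 and are not shown in this diff.

import torch
from torch_geometric.utils import subgraph

edge_index = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7],
                           [1, 2, 3, 4, 5, 6, 7, 0]])
train_nodes = torch.tensor([0, 1, 2, 3])   # first given subset of nodes
test_nodes = torch.tensor([4, 5, 6, 7])    # second given subset of nodes

# Keep only edges with both endpoints inside a subset and relabel the nodes,
# so each split forms a self-contained graph for inductive training/evaluation.
train_edge_index, _ = subgraph(train_nodes, edge_index, relabel_nodes=True, num_nodes=8)
test_edge_index, _ = subgraph(test_nodes, edge_index, relabel_nodes=True, num_nodes=8)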
23 changes: 22 additions & 1 deletion test/explain/algorithm/test_attention_explainer.py
@@ -3,7 +3,7 @@

from torch_geometric.explain import AttentionExplainer, Explainer
from torch_geometric.explain.config import ExplanationType, MaskType
from torch_geometric.nn import GATConv, GATv2Conv, TransformerConv
from torch_geometric.nn import AttentiveFP, GATConv, GATv2Conv, TransformerConv


class AttentionGNN(torch.nn.Module):
@@ -25,6 +25,8 @@ def forward(self, x, edge_index):
[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7],
[1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6],
])
edge_attr = torch.randn(edge_index.size(1), 5)
batch = torch.tensor([0, 0, 0, 1, 1, 2, 2, 2])


@pytest.mark.parametrize('index', [None, 2, torch.arange(3)])
@@ -61,3 +63,22 @@ def test_attention_explainer_supports(explanation_type, node_mask_type):
return_type='raw',
),
)


def test_attention_explainer_attentive_fp(check_explanation):
model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2)

explainer = Explainer(
model=model,
algorithm=AttentionExplainer(),
explanation_type='model',
edge_mask_type='object',
model_config=dict(
mode='binary_classification',
task_level='node',
return_type='raw',
),
)

explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch)
check_explanation(explanation, None, explainer.edge_mask_type)
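With `edge_mask_type='object'`, the returned `Explanation` carries the aggregated attention scores as an edge-level mask. A short sketch of reading it back (one scalar per edge is the expected shape for the graph defined above):

explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch)
edge_scores = explanation.edge_mask      # expected shape: [14], one score per edge
top_edges = edge_scores.topk(3).indices  # e.g. the three most attended edges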
6 changes: 2 additions & 4 deletions test/nn/test_to_hetero_transformer.py
@@ -520,10 +520,8 @@ def test_to_hetero_on_static_graphs():
'author': torch.randn(4, 100, 16),
}
edge_index_dict = {
('paper', 'written_by', 'author'):
torch.randint(100, (2, 200), dtype=torch.long),
('author', 'writes', 'paper'):
torch.randint(100, (2, 200), dtype=torch.long),
('paper', 'written_by', 'author'): torch.randint(100, (2, 200)),
('author', 'writes', 'paper'): torch.randint(100, (2, 200)),
}

metadata = list(x_dict.keys()), list(edge_index_dict.keys())
21 changes: 21 additions & 0 deletions test/nn/test_to_hetero_with_bases_transformer.py
@@ -293,3 +293,24 @@ def test_to_hetero_with_bases_validate():

with pytest.warns(UserWarning, match="letters, numbers and underscores"):
model = to_hetero_with_bases(model, metadata, num_bases=4, debug=False)


def test_to_hetero_with_bases_on_static_graphs():
x_dict = {
'paper': torch.randn(4, 100, 16),
'author': torch.randn(4, 100, 16),
}
edge_index_dict = {
('paper', 'written_by', 'author'): torch.randint(100, (2, 200)),
('author', 'writes', 'paper'): torch.randint(100, (2, 200)),
}

metadata = list(x_dict.keys()), list(edge_index_dict.keys())
model = to_hetero_with_bases(Net4(), metadata, num_bases=4,
in_channels={'x0': 16}, debug=False)

out_dict = model(x_dict, edge_index_dict)

assert len(out_dict) == 2
assert out_dict['paper'].size() == (4, 100, 32)
assert out_dict['author'].size() == (4, 100, 32)
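For orientation, here is a minimal, self-contained sketch of what the new static-graph test exercises, assuming a plain SAGEConv-based module in place of the Net4 fixture (Net4 is defined earlier in the test file and is not shown in this diff). In the static-graph setting, node features carry a leading batch dimension while a single edge_index is shared across the batch.

import torch
from torch_geometric.nn import SAGEConv, to_hetero_with_bases

class GNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SAGEConv((16, 16), 32)

    def forward(self, x, edge_index):
        return self.conv1(x, edge_index).relu()

metadata = (['paper', 'author'],
            [('paper', 'written_by', 'author'), ('author', 'writes', 'paper')])
model = to_hetero_with_bases(GNN(), metadata, num_bases=4, in_channels={'x': 16})

x_dict = {  # static graphs: [batch_size, num_nodes, channels] node features
    'paper': torch.randn(4, 100, 16),
    'author': torch.randn(4, 100, 16),
}
edge_index_dict = {
    ('paper', 'written_by', 'author'): torch.randint(100, (2, 200)),
    ('author', 'writes', 'paper'): torch.randint(100, (2, 200)),
}
out_dict = model(x_dict, edge_index_dict)  # e.g. out_dict['paper']: [4, 100, 32]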
3 changes: 2 additions & 1 deletion torch_geometric/explain/algorithm/attention_explainer.py
@@ -58,7 +58,8 @@ def hook(module, msg_kwargs, out):

hook_handles = []
for module in model.modules(): # Register message forward hooks:
if isinstance(module, MessagePassing):
if (isinstance(module, MessagePassing)
and module.explain is not False):
hook_handles.append(module.register_message_forward_hook(hook))

model(x, edge_index, **kwargs)
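The added `module.explain is not False` check lets `AttentionExplainer` skip any `MessagePassing` module whose `explain` flag has been explicitly switched off, which is how a composite model such as `AttentiveFP` can exclude internal layers that do not produce meaningful per-edge attention. A minimal sketch of opting one layer out; the `explain` attribute on `MessagePassing` is the same flag the hook above inspects:

import torch
from torch_geometric.explain import AttentionExplainer, Explainer
from torch_geometric.nn import GATConv

class TwoLayerGAT(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GATConv(3, 16)
        self.conv2 = GATConv(16, 7)

    def forward(self, x, edge_index):
        return self.conv2(self.conv1(x, edge_index).relu(), edge_index)

model = TwoLayerGAT()
model.conv2.explain = False  # this layer's attention is now ignored by the hook

explainer = Explainer(
    model=model,
    algorithm=AttentionExplainer(),
    explanation_type='model',
    edge_mask_type='object',
    model_config=dict(mode='multiclass_classification', task_level='node',
                      return_type='raw'),
)
# explainer(x, edge_index) now aggregates attention from conv1 only.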
17 changes: 11 additions & 6 deletions torch_geometric/nn/models/attentive_fp.py
@@ -41,22 +41,27 @@ def reset_parameters(self):
zeros(self.bias)

def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor) -> Tensor:
# propagate_type: (x: Tensor, edge_attr: Tensor)
out = self.propagate(edge_index, x=x, edge_attr=edge_attr, size=None)
# edge_updater_type: (x: Tensor, edge_attr: Tensor)
alpha = self.edge_updater(edge_index, x=x, edge_attr=edge_attr)

# propagate_type: (x: Tensor, alpha: Tensor)
out = self.propagate(edge_index, x=x, alpha=alpha, size=None)
out = out + self.bias
return out

def message(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tensor:

def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Tensor,
index: Tensor, ptr: OptTensor,
size_i: Optional[int]) -> Tensor:
x_j = F.leaky_relu_(self.lin1(torch.cat([x_j, edge_attr], dim=-1)))
alpha_j = (x_j @ self.att_l.t()).squeeze(-1)
alpha_i = (x_i @ self.att_r.t()).squeeze(-1)
alpha = alpha_j + alpha_i
alpha = F.leaky_relu_(alpha)
alpha = softmax(alpha, index, ptr, size_i)
alpha = F.dropout(alpha, p=self.dropout, training=self.training)
return alpha

def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
return self.lin2(x_j) * alpha.unsqueeze(-1)


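The refactor above moves the attention computation out of `message()` and into `edge_update()`, so the coefficients are produced by `edge_updater()` and then handed to `propagate()` as an explicit `alpha` keyword argument, which is exactly what the `AttentionExplainer` message hook looks for. A minimal sketch of the same two-phase pattern in a standalone `MessagePassing` layer (simplified, not the actual `GATEConv` from `attentive_fp.py`):

import torch
from torch import Tensor
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import softmax

class AttnConv(MessagePassing):
    def __init__(self, in_channels: int, out_channels: int):
        super().__init__(aggr='add')
        self.lin = torch.nn.Linear(in_channels, out_channels)
        self.att = torch.nn.Linear(2 * in_channels, 1)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # Phase 1: compute one attention coefficient per edge.
        alpha = self.edge_updater(edge_index, x=x)
        # Phase 2: aggregate messages weighted by those coefficients.
        return self.propagate(edge_index, x=x, alpha=alpha)

    def edge_update(self, x_i: Tensor, x_j: Tensor, index: Tensor) -> Tensor:
        alpha = self.att(torch.cat([x_i, x_j], dim=-1)).squeeze(-1)
        return softmax(alpha, index)

    def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
        return self.lin(x_j) * alpha.unsqueeze(-1)

conv = AttnConv(16, 32)
out = conv(torch.randn(10, 16), torch.randint(10, (2, 40)))  # shape: [10, 32]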
18 changes: 9 additions & 9 deletions torch_geometric/nn/to_hetero_with_bases_transformer.py
@@ -324,15 +324,15 @@ def __init__(self, module: MessagePassing, num_relations: int,
# to a materialization of messages.
def hook(module, inputs, output):
assert isinstance(module._edge_type, Tensor)
if module._edge_type.size(0) != output.size(0):
if module._edge_type.size(0) != output.size(-2):
raise ValueError(
f"Number of messages ({output.size(0)}) does not match "
f"with the number of original edges "
f"({module._edge_type.size(0)}). Does your message "
f"passing layer create additional self-loops? Try to "
f"remove them via 'add_self_loops=False'")
weight = module.edge_type_weight.view(-1)[module._edge_type]
weight = weight.view([-1] + [1] * (output.dim() - 1))
weight = weight.view([1] * (output.dim() - 2) + [-1, 1])
return weight * output

params = list(module.parameters())
@@ -415,7 +415,7 @@ def get_node_offset_dict(
out: Dict[NodeType, int] = {}
for key in type2id.keys():
out[key] = cumsum
cumsum += input_dict[key].size(0)
cumsum += input_dict[key].size(-2)
return out


@@ -433,7 +433,7 @@ def get_edge_offset_dict(
elif value.dtype == torch.long and value.size(0) == 2:
cumsum += value.size(-1)
else:
cumsum += value.size(0)
cumsum += value.size(-2)
return out


@@ -458,7 +458,7 @@ def get_edge_type(
out = torch.full((value.nnz(), ), i, dtype=torch.long,
device=value.device())
else:
out = value.new_full((value.size(0), ), i, dtype=torch.long)
out = value.new_full((value.size(-2), ), i, dtype=torch.long)
outs.append(out)

return outs[0] if len(outs) == 1 else torch.cat(outs, dim=0)
@@ -474,7 +474,7 @@ def group_node_placeholder(input_dict: Dict[NodeType, Tensor],
type2id: Dict[NodeType, int]) -> Tensor:

inputs = [input_dict[key] for key in type2id.keys()]
return inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=0)
return inputs[0] if len(inputs) == 1 else torch.cat(inputs, dim=-2)


def group_edge_placeholder(
@@ -528,7 +528,7 @@ def group_edge_placeholder(
return torch.stack([row, col], dim=0)

else:
return torch.cat(inputs, dim=0)
return torch.cat(inputs, dim=-2)


###############################################################################
@@ -542,9 +542,9 @@ def split_output(
offset_dict: Union[Dict[NodeType, int], Dict[EdgeType, int]],
) -> Union[Dict[NodeType, Tensor], Dict[EdgeType, Tensor]]:

cumsums = list(offset_dict.values()) + [output.size(0)]
cumsums = list(offset_dict.values()) + [output.size(-2)]
sizes = [cumsums[i + 1] - cumsums[i] for i in range(len(offset_dict))]
outputs = output.split(sizes)
outputs = output.split(sizes, dim=-2)
return {key: output for key, output in zip(offset_dict, outputs)}


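The recurring change in this file from `size(0)` / `dim=0` to `size(-2)` / `dim=-2` is what enables static graphs: node and message tensors may now carry a leading batch dimension (e.g. `[batch_size, num_nodes, channels]`), so the node/edge dimension is addressed from the back. A small, illustrative shape sketch of the broadcasting the updated hook performs, with shapes chosen to match the static-graph tests above:

import torch

batch_size, num_edges, channels = 4, 200, 32
num_relations = 3

output = torch.randn(batch_size, num_edges, channels)    # messages on a static graph
edge_type = torch.randint(num_relations, (num_edges, ))   # relation id of each original edge
edge_type_weight = torch.randn(1, num_relations)          # per-relation weight of one basis

# Mirrors the updated hook: check against the edge dimension counted from the
# back, then reshape the per-edge weights so they broadcast over the leading
# batch dimension instead of being applied along dim 0.
assert edge_type.size(0) == output.size(-2)
weight = edge_type_weight.view(-1)[edge_type]              # shape: [200]
weight = weight.view([1] * (output.dim() - 2) + [-1, 1])   # shape: [1, 200, 1]
out = weight * output                                      # shape: [4, 200, 32]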
2 changes: 0 additions & 2 deletions torch_geometric/transforms/add_positional_encoding.py
@@ -147,8 +147,6 @@ def forward(self, data: Data) -> Data:
pe_list = [get_self_loop_attr(*to_edge_index(out), num_nodes=N)]
for _ in range(self.walk_length - 1):
out = out @ adj
if out.layout == torch.sparse_coo:
out = out._coalesced_(True)
pe_list.append(get_self_loop_attr(*to_edge_index(out), N))
pe = torch.stack(pe_list, dim=-1)

