Skip to content

Commit

Permalink
added test case and modify inductive split
Browse files Browse the repository at this point in the history
  • Loading branch information
ogawayuto committed Oct 22, 2023
1 parent 9a6d549 commit d6f19e0
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 1 deletion.
14 changes: 14 additions & 0 deletions test/utils/test_inductive_train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,12 @@ def test_inductive_train_test_split():
[3, 4], bridge=False)
output3, output4 = inductive_train_test_split(input_graph, [0, 1, 2],
[3, 4])
output5, output6 = inductive_train_test_split(
input_graph,
[True, True, True, False, False],
[False, False, False, True, True],
bridge=False,
)

assert torch.equal(output1.edge_index, expected_output1.edge_index)
assert torch.equal(output1.edge_attr, expected_output1.edge_attr)
Expand All @@ -102,6 +108,14 @@ def test_inductive_train_test_split():
expected_output3.bridge_edge_index)
assert torch.equal(output4.bridge_edge_attr,
expected_output3.bridge_edge_attr)
assert torch.equal(output5.edge_index, expected_output1.edge_index)
assert torch.equal(output5.edge_attr, expected_output1.edge_attr)
assert torch.equal(output5.x, expected_output1.x)
assert torch.equal(output5.y, expected_output1.y)
assert torch.equal(output6.edge_index, expected_output2.edge_index)
assert torch.equal(output6.edge_attr, expected_output2.edge_attr)
assert torch.equal(output6.x, expected_output2.x)
assert torch.equal(output6.y, expected_output2.y)

with pytest.raises(Exception) as e:
inductive_train_test_split(input_graph, [0, 1, 2], [1, 3, 4])
Expand Down
19 changes: 18 additions & 1 deletion torch_geometric/utils/inductive_train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,24 @@ def inductive_train_test_split(data, train_node, test_node,
bridge_edge_attr=torch.tensor([[4], [5], [8], [9], [10], [11]]),
))
>>> inductive_train_test_split(input_graph, [0, 1, 2],
[3, 4], bridge=False)
... [3, 4], bridge=False)
(Data(
edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
edge_attr=torch.tensor([[0], [1], [2], [3]]),
x=torch.tensor([[0], [1], [2]]),
y=torch.tensor([0, 1, 2]),
),
Data(
edge_index=torch.tensor([[3, 4], [4, 3]]),
edge_attr=torch.tensor([[6], [7]]),
x=torch.tensor([[3], [4]]),
y=torch.tensor([3, 4]),
))
>>> inductive_train_test_split( input_graph,
... [True, True, True, False, False],
... [False, False, False, True, True],
... bridge=False,
... )
(Data(
edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]),
edge_attr=torch.tensor([[0], [1], [2], [3]]),
Expand Down

0 comments on commit d6f19e0

Please sign in to comment.