From 1e6dc8e2cdc589d3e152a6cb2215303d93966a07 Mon Sep 17 00:00:00 2001 From: ogawayuto Date: Sun, 22 Oct 2023 13:12:16 +0900 Subject: [PATCH] update inductive split --- .../utils/inductive_train_test_split.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/torch_geometric/utils/inductive_train_test_split.py b/torch_geometric/utils/inductive_train_test_split.py index faf059edde04..cac1754117ff 100644 --- a/torch_geometric/utils/inductive_train_test_split.py +++ b/torch_geometric/utils/inductive_train_test_split.py @@ -208,10 +208,20 @@ def inductive_train_test_split(data, train_node, test_node, device = data.edge_index.device if isinstance(train_node, (list, tuple)): - train_node = torch.tensor(train_node, dtype=torch.long, device=device) + if isinstance(train_node[0], bool): + train_node = torch.tensor(train_node, dtype=torch.bool, + device=device) + else: + train_node = torch.tensor(train_node, dtype=torch.long, + device=device) if isinstance(test_node, (list, tuple)): - test_node = torch.tensor(test_node, dtype=torch.long, device=device) + if isinstance(test_node[0], bool): + test_node = torch.tensor(test_node, dtype=torch.bool, + device=device) + else: + test_node = torch.tensor(test_node, dtype=torch.long, + device=device) if train_node.dtype != torch.bool: num_nodes = maybe_num_nodes(data.edge_index, data.num_nodes)