diff --git a/torch_geometric/utils/inductive_train_test_split.py b/torch_geometric/utils/inductive_train_test_split.py index faf059edde04..cac1754117ff 100644 --- a/torch_geometric/utils/inductive_train_test_split.py +++ b/torch_geometric/utils/inductive_train_test_split.py @@ -208,10 +208,20 @@ def inductive_train_test_split(data, train_node, test_node, device = data.edge_index.device if isinstance(train_node, (list, tuple)): - train_node = torch.tensor(train_node, dtype=torch.long, device=device) + if isinstance(train_node[0], bool): + train_node = torch.tensor(train_node, dtype=torch.bool, + device=device) + else: + train_node = torch.tensor(train_node, dtype=torch.long, + device=device) if isinstance(test_node, (list, tuple)): - test_node = torch.tensor(test_node, dtype=torch.long, device=device) + if isinstance(test_node[0], bool): + test_node = torch.tensor(test_node, dtype=torch.bool, + device=device) + else: + test_node = torch.tensor(test_node, dtype=torch.long, + device=device) if train_node.dtype != torch.bool: num_nodes = maybe_num_nodes(data.edge_index, data.num_nodes)