Skip to content

Commit

Permalink
update inductive split
Browse files Browse the repository at this point in the history
  • Loading branch information
ogawayuto committed Oct 22, 2023
1 parent d6f19e0 commit 1e6dc8e
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions torch_geometric/utils/inductive_train_test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,20 @@ def inductive_train_test_split(data, train_node, test_node,
device = data.edge_index.device

if isinstance(train_node, (list, tuple)):
train_node = torch.tensor(train_node, dtype=torch.long, device=device)
if isinstance(train_node[0], bool):
train_node = torch.tensor(train_node, dtype=torch.bool,
device=device)
else:
train_node = torch.tensor(train_node, dtype=torch.long,
device=device)

if isinstance(test_node, (list, tuple)):
test_node = torch.tensor(test_node, dtype=torch.long, device=device)
if isinstance(test_node[0], bool):
test_node = torch.tensor(test_node, dtype=torch.bool,
device=device)
else:
test_node = torch.tensor(test_node, dtype=torch.long,
device=device)

if train_node.dtype != torch.bool:
num_nodes = maybe_num_nodes(data.edge_index, data.num_nodes)
Expand Down

0 comments on commit 1e6dc8e

Please sign in to comment.