-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathsplit_dataset.py
57 lines (43 loc) · 1.47 KB
/
split_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import argparse
import glob
import os
import random
from pathlib import Path
from loguru import logger
# Fixed seed so the shuffled train/val split is reproducible across runs.
random.seed(1234567)


def _relative_stem(img_path):
    """Return ``"<parent dir>/<file stem>"`` for *img_path* (suffix removed)."""
    path = Path(img_path)
    # Use the real stem rather than splitting on ".png": a split would also
    # cut at ".png" occurring inside the directory name or earlier in the
    # file name.
    return os.path.join(path.parent.name, path.stem)


def _write_split(list_path, img_paths):
    """Write one ``_relative_stem`` entry per line of *img_paths* to *list_path*."""
    with open(list_path, "w", encoding="utf-8") as file:
        file.writelines(f"{_relative_stem(p)}\n" for p in img_paths)


def run(args):
    """Split the PNG images under ``<data_root>/img/*/`` into train/val lists.

    Images are collected, sorted (so the result does not depend on the
    filesystem's directory-listing order), shuffled with the module-level
    fixed seed, and split at ``int(len(images) * train_ratio)``. Each entry
    is written as ``<subdir>/<file stem>`` (no ``.png`` suffix), one per line,
    to ``<data_root>/train.txt`` and ``<data_root>/val.txt``.

    Args:
        args: Namespace with ``data_root`` (str; ``~`` is expanded) and
            ``train_ratio`` (float, expected in [0, 1]).

    Raises:
        ValueError: if ``train_ratio`` is outside [0, 1] (or is NaN).
    """
    data_root = os.path.expanduser(args.data_root)
    ratio = args.train_ratio
    if not 0.0 <= ratio <= 1.0:
        raise ValueError(f"train_ratio must be in [0, 1], got {ratio}")

    data_path = os.path.join(data_root, "img", "*", "*.png")
    # glob returns files in arbitrary, filesystem-dependent order; sort before
    # the seeded shuffle so the same split is produced on every machine.
    img_list = sorted(glob.glob(data_path))
    if not img_list:
        logger.warning(f"No images matched {data_path}")
    random.shuffle(img_list)

    train_size = int(len(img_list) * ratio)
    train_text_path = os.path.join(data_root, "train.txt")
    val_text_path = os.path.join(data_root, "val.txt")
    _write_split(train_text_path, img_list[:train_size])
    _write_split(val_text_path, img_list[train_size:])

    logger.info(f"TRAIN LABEL: {train_text_path}")
    logger.info(f"VAL LABEL: {val_text_path}")
if __name__ == "__main__":
    # Command-line entry point: parse the options, log them, then write the
    # train/val split files.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--data_root",
        default="~/datasets/doc3d",
        type=str,
        help="Data path to load data",
    )
    cli.add_argument(
        "--train_ratio",
        default=0.8,
        type=float,
        help="Ratio of training data",
    )
    parsed = cli.parse_args()
    logger.info(parsed)
    run(parsed)