-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdatasets.py
104 lines (82 loc) · 3.02 KB
/
datasets.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import os
import numpy as np
from PIL import Image
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import my_utils
# Training dataset
class WheatDataset_training(Dataset):
    """Wheat-head detection training dataset.

    Expects ``df`` with an ``image_name`` column and a ``BoxesString`` column
    of ';'-separated ``'xmin ymin xmax ymax'`` boxes, or the literal string
    ``'no_box'`` for images without annotations.

    ``__getitem__`` returns ``(image_tensor, target)`` where ``target`` is the
    dict format expected by torchvision detection models
    (``boxes``/``labels``/``image_id``/``area``/``iscrowd``).
    """

    def __init__(self, df, base_dir):
        self.df = df
        self.base_dir = base_dir
        self.image_name = df['image_name']
        # Only conversion to tensor; augmentation is assumed to happen elsewhere.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.image_name)

    def __getitem__(self, index):
        image_name = self.df.iloc[index].image_name
        path = os.path.join(self.base_dir, 'train', 'train', f'{image_name}.png')
        # For training with pseudo-labels: fall back to the test image folder
        # when the image is not part of the official training set.
        if not os.path.isfile(path):
            path = os.path.join(self.base_dir, 'test', f'{image_name}.png')
        image = Image.open(path)

        bboxes_str = self.df.iloc[index].BoxesString
        boxes, areas = [], []
        for bbox in bboxes_str.split(';'):
            if bbox == 'no_box':
                continue
            box = list(map(float, bbox.split()))
            area = my_utils.area(box)
            # Skip implausibly large boxes (presumably annotation errors
            # — 200000 px^2 threshold inherited from the original code).
            if area > 200000:
                continue
            boxes.append(box)
            areas.append(area)

        # BUG FIX: count objects AFTER filtering. The original code used
        # len(bboxes_str.split(';')), so a filtered-out oversized box left
        # `labels`/`iscrowd` longer than `boxes`, misaligning the target.
        n_objects = len(boxes)

        if n_objects == 0:
            # Torchvision expects empty (0, 4) boxes, not an empty 1-D tensor.
            boxes = torch.zeros((0, 4), dtype=torch.float32)
            labels = torch.zeros(0, dtype=torch.int64)
            areas = torch.zeros(0, dtype=torch.float32)
            iscrowd = torch.zeros((0,), dtype=torch.int64)
        else:
            boxes = torch.as_tensor(boxes, dtype=torch.float32)
            # Single foreground class (wheat head) -> all labels are 1.
            labels = torch.ones((n_objects,), dtype=torch.int64)
            # Explicit float32 keeps dtype consistent with the empty branch.
            areas = torch.as_tensor(areas, dtype=torch.float32)
            # Suppose all instances are not crowd.
            iscrowd = torch.zeros((n_objects,), dtype=torch.int64)

        target = {
            'boxes': boxes,
            'labels': labels,
            'image_id': torch.tensor([index], dtype=torch.int64),
            'area': areas,
            'iscrowd': iscrowd
        }
        image = self.transform(image)
        return image, target
# Test dataset
class WheatDataset_test(Dataset):
    """Wheat-head detection test dataset.

    Expects ``df`` with ``image_name`` and ``domain`` columns. ``__getitem__``
    returns ``(image_id, image_tensor, domain)`` — no targets, since test
    labels are unknown at inference time.
    """

    def __init__(self, df, base_dir):
        self.base_dir = base_dir
        self.image_ids = df['image_name']
        self.domains = df['domain']
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
            ]
        )

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, index):
        # BUG FIX: use positional indexing (.iloc), matching the training
        # dataset. Plain Series[index] is label-based, so a DataFrame with a
        # non-default index (e.g. after filtering) raised KeyError or
        # silently returned the wrong row.
        image_id = self.image_ids.iloc[index]
        domain = self.domains.iloc[index]
        image = Image.open(os.path.join(self.base_dir, 'test', f'{image_id}.png'))
        image = self.transform(image)
        return image_id, image, domain