-
Notifications
You must be signed in to change notification settings - Fork 15
/
Copy pathdata.py
131 lines (114 loc) · 5.82 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import config
import utils
import random
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from tqdm import tqdm
from torchvision.datasets.voc import VOCDetection
from torch.utils.data import Dataset
class YoloPascalVocDataset(Dataset):
    """Pascal VOC 2007 detection set wrapped to emit YOLO-style ground truth.

    Each item is a tuple ``(data, ground_truth, original_data)`` where
    ``ground_truth`` has shape ``(S, S, 5*B + C)``: per grid cell, a one-hot
    class vector of length C followed by B bounding-box slots of
    ``(x, y, w, h, confidence)``.
    """

    def __init__(self, set_type, normalize=False, augment=False):
        """Load the VOC2007 split and the (cached) class-name index.

        Args:
            set_type: ``'train'`` or ``'test'`` (``'test'`` maps to the VOC
                ``'val'`` image set).
            normalize: if True, apply ImageNet mean/std normalization in
                ``__getitem__``.
            augment: if True, apply a random affine + color jitter to each
                image and shift the box labels to match.
        """
        assert set_type in {'train', 'test'}
        self.dataset = VOCDetection(
            root=config.DATA_PATH,
            year='2007',
            image_set=('train' if set_type == 'train' else 'val'),
            download=True,
            transform=T.Compose([
                T.ToTensor(),
                T.Resize(config.IMAGE_SIZE)
            ])
        )
        self.normalize = normalize
        self.augment = augment
        self.classes = utils.load_class_dict()

        # Generate the class-name -> index mapping on first run and cache it
        # to disk so later runs skip this full pass over the dataset.
        if len(self.classes) == 0:
            index = 0
            for data, label in tqdm(self.dataset, desc='Generating class dict'):
                for name, coords in utils.get_bounding_boxes(label):
                    if name not in self.classes:
                        self.classes[name] = index
                        index += 1
            utils.save_class_dict(self.classes)

    def __getitem__(self, i):
        """Return ``(data, ground_truth, original_data)`` for sample ``i``."""
        data, label = self.dataset[i]
        original_data = data

        # Augment images. The random parameters are only drawn when needed,
        # and are kept so the box labels below can be shifted identically.
        if self.augment:
            x_shift = int((0.2 * random.random() - 0.1) * config.IMAGE_SIZE[0])
            y_shift = int((0.2 * random.random() - 0.1) * config.IMAGE_SIZE[1])
            scale = 1 + 0.2 * random.random()
            data = TF.affine(data, angle=0.0, scale=scale, translate=(x_shift, y_shift), shear=0.0)
            data = TF.adjust_hue(data, 0.2 * random.random() - 0.1)
            data = TF.adjust_saturation(data, 0.2 * random.random() + 0.9)
        if self.normalize:
            data = TF.normalize(data, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

        grid_size_x = data.size(dim=2) / config.S  # Images in PyTorch have size (channels, height, width)
        grid_size_y = data.size(dim=1) / config.S

        # Process bounding boxes into the SxSx(5*B+C) ground truth tensor.
        boxes = {}  # Number of bbox slots already filled per grid cell
        class_names = {}  # Track what class each grid cell has been assigned to
        depth = 5 * config.B + config.C  # 5 numbers per bbox, then one-hot encoding of label
        ground_truth = torch.zeros((config.S, config.S, depth))
        for name, coords in utils.get_bounding_boxes(label):
            assert name in self.classes, f"Unrecognized class '{name}'"
            class_index = self.classes[name]
            x_min, x_max, y_min, y_max = coords

            # Augment labels to match the affine transform applied above.
            if self.augment:
                half_width = config.IMAGE_SIZE[0] / 2
                half_height = config.IMAGE_SIZE[1] / 2
                x_min = utils.scale_bbox_coord(x_min, half_width, scale) + x_shift
                x_max = utils.scale_bbox_coord(x_max, half_width, scale) + x_shift
                y_min = utils.scale_bbox_coord(y_min, half_height, scale) + y_shift
                y_max = utils.scale_bbox_coord(y_max, half_height, scale) + y_shift

            # Calculate the position of center of bounding box, and the grid
            # cell responsible for it.
            mid_x = (x_max + x_min) / 2
            mid_y = (y_max + y_min) / 2
            col = int(mid_x // grid_size_x)
            row = int(mid_y // grid_size_y)

            if 0 <= col < config.S and 0 <= row < config.S:
                cell = (row, col)
                # A cell only holds boxes of a single class; later boxes of a
                # different class in the same cell are dropped.
                if cell not in class_names or name == class_names[cell]:
                    # Insert class one-hot encoding into ground truth
                    one_hot = torch.zeros(config.C)
                    one_hot[class_index] = 1.0
                    ground_truth[row, col, :config.C] = one_hot
                    class_names[cell] = name

                    # Insert bounding box into ground truth tensor
                    bbox_index = boxes.get(cell, 0)
                    if bbox_index < config.B:
                        bbox_truth = (
                            (mid_x - col * grid_size_x) / config.IMAGE_SIZE[0],  # X coord relative to grid square
                            (mid_y - row * grid_size_y) / config.IMAGE_SIZE[1],  # Y coord relative to grid square
                            (x_max - x_min) / config.IMAGE_SIZE[0],              # Width
                            (y_max - y_min) / config.IMAGE_SIZE[1],              # Height
                            1.0                                                  # Confidence
                        )
                        # Fill all bbox slots with current bbox (starting from current bbox slot, avoid overriding prev)
                        # This prevents having "dead" boxes (zeros) at the end, which messes up IOU loss calculations
                        bbox_start = 5 * bbox_index + config.C
                        ground_truth[row, col, bbox_start:] = torch.tensor(bbox_truth).repeat(config.B - bbox_index)
                        boxes[cell] = bbox_index + 1

        return data, ground_truth, original_data

    def __len__(self):
        """Number of samples in the underlying VOC split."""
        return len(self.dataset)
if __name__ == '__main__':
    # Visual sanity check: plot the boxes of every training sample while
    # accumulating a few debug statistics about label and pixel values.
    class_array = utils.load_class_array()
    dataset = YoloPascalVocDataset('train', normalize=True, augment=True)

    num_negative = 0
    lo, hi = 0, 0
    for image, target, _ in dataset:
        num_negative += torch.sum(target < 0).item()
        lo = min(lo, torch.min(image).item())
        hi = max(hi, torch.max(image).item())
        utils.plot_boxes(image, target, class_array, max_overlap=float('inf'))
    # print('num_negatives', num_negative)
    # print('dist', lo, hi)