-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
323 lines (251 loc) · 12.6 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
import xml.etree.ElementTree as ET
import numpy as np
import os
import matplotlib.pyplot as plt
import matplotlib.patches as patches
from tqdm import tqdm
import torch
from torchvision import ops
import torch.nn.functional as F
import torch.optim as optim
# -------------- Data Utils -------------------
# CUDA device string that the helpers in this file pass to ``.cuda(device)``.
device = 'cuda:2'
def parse_annotation(annotation_path, image_dir, img_size):
    '''
    Traverse the xml tree, get the annotations, and resize them to the scaled image size

    Every ``<image>`` element under the root contributes one entry to each of
    the returned lists. Its ``<box>`` children provide corner coordinates
    (``xtl``, ``ytl``, ``xbr``, ``ybr``) expressed in the raw image resolution
    (``width`` / ``height`` attributes of the ``<image>`` element).

    Input
    ------
    annotation_path - path of the XML annotation file
    image_dir - directory joined in front of every image ``name``
    img_size - (height, width) that the boxes are rescaled to

    Returns
    ---------
    gt_boxes_all - list of tensors, one per image, each of shape (n_boxes, 4)
        in xyxy order; an image without boxes yields a (0, 4) tensor
    gt_classes_all - list of lists with the ``label`` string of every box
    img_paths - list of image paths
    '''
    img_h, img_w = img_size

    # Hand the path to ElementTree so the parser honours the encoding declared
    # in the XML prolog rather than the locale default of a text-mode file.
    root = ET.parse(annotation_path).getroot()

    img_paths = []
    gt_boxes_all = []
    gt_classes_all = []

    for image_el in root.findall('image'):
        img_paths.append(os.path.join(image_dir, image_el.get("name")))

        # raw image size the coordinates refer to
        orig_w = int(image_el.get("width"))
        orig_h = int(image_el.get("height"))

        box_elements = image_el.findall('box')
        corners = [
            [float(box_el.get(key)) for key in ("xtl", "ytl", "xbr", "ybr")]
            for box_el in box_elements
        ]

        # An explicit row count keeps a well-formed (0, 4) tensor when the
        # image has no boxes (reshape(-1, 4) is ambiguous for 0 elements).
        boxes = torch.Tensor(corners).reshape(len(corners), 4)

        # rescale all boxes of the image at once
        boxes[:, [0, 2]] = boxes[:, [0, 2]] * img_w / orig_w
        boxes[:, [1, 3]] = boxes[:, [1, 3]] * img_h / orig_h

        gt_boxes_all.append(boxes)
        gt_classes_all.append([box_el.get("label") for box_el in box_elements])

    return gt_boxes_all, gt_classes_all, img_paths
# -------------- Preprocessing utils ----------------
def calc_gt_offsets(pos_anc_coords, gt_bbox_mapping):
    '''
    Compute the regression targets between anchors and their matched gt boxes

    Both inputs are converted from xyxy to cxcywh and moved to ``device``.
    The last axis of the result holds, in order: the centre-x and centre-y
    differences divided by the anchor width / height, followed by the log of
    the gt-to-anchor width and height ratios.
    '''
    anchors = ops.box_convert(pos_anc_coords, in_fmt='xyxy', out_fmt='cxcywh').cuda(device)
    targets = ops.box_convert(gt_bbox_mapping, in_fmt='xyxy', out_fmt='cxcywh').cuda(device)

    # split the four cxcywh columns into separate 1-D tensors
    anc_cx, anc_cy, anc_w, anc_h = anchors.unbind(dim=1)
    gt_cx, gt_cy, gt_w, gt_h = targets.unbind(dim=1)

    components = (
        (gt_cx - anc_cx) / anc_w,
        (gt_cy - anc_cy) / anc_h,
        torch.log(gt_w / anc_w),
        torch.log(gt_h / anc_h),
    )
    return torch.stack(components, dim=-1)
def gen_anc_centers(out_size):
    '''
    Return the x and y coordinates of the cell centres of a grid

    ``out_size`` is (height, width); the result is a pair of 1-D tensors
    ``(0.5, 1.5, ...)`` with ``width`` and ``height`` entries respectively.
    '''
    out_h, out_w = out_size
    # x first, then y - the same order as the returned pair
    centers_x, centers_y = (torch.arange(0, extent) + 0.5 for extent in (out_w, out_h))
    return centers_x, centers_y
def project_bboxes(bboxes, width_scale_factor, height_scale_factor, mode='a2p'):
    '''
    Scale boxes between activation-map and pixel-image coordinates

    mode 'a2p' multiplies the x columns (0, 2) by ``width_scale_factor`` and
    the y columns (1, 3) by ``height_scale_factor``; mode 'p2a' divides
    instead. Entries equal to -1 (padding) are left at -1. The input is not
    modified; the result has the same shape as ``bboxes``.
    '''
    assert mode in ['a2p', 'p2a']

    n_images = bboxes.size(dim=0)
    scaled = bboxes.clone().reshape(n_images, -1, 4)

    # remember the padded entries before they get scaled
    padding = scaled == -1

    if mode != 'a2p':
        # pixel image to activation map
        scaled[:, :, [0, 2]] /= width_scale_factor
        scaled[:, :, [1, 3]] /= height_scale_factor
    else:
        # activation map to pixel image
        scaled[:, :, [0, 2]] *= width_scale_factor
        scaled[:, :, [1, 3]] *= height_scale_factor

    # restore the padding marker and the caller's original shape
    scaled[padding] = -1
    scaled.resize_as_(bboxes)
    return scaled
def generate_proposals(anchors, offsets):
    '''
    Apply predicted offsets to anchors and return the resulting boxes

    Anchors arrive in xyxy format, are shifted / scaled in cxcywh space on
    ``device`` and are returned in xyxy format again.
    '''
    # work in centre / size space
    anc = ops.box_convert(anchors, in_fmt='xyxy', out_fmt='cxcywh').cuda(device)
    deltas = offsets.cuda(device)

    decoded = torch.zeros_like(anc).cuda(device)
    # centres move by the offset scaled with the anchor size
    decoded[:, 0:2] = anc[:, 0:2] + deltas[:, 0:2] * anc[:, 2:4]
    # sizes are scaled by the exponential of the offset
    decoded[:, 2:4] = anc[:, 2:4] * torch.exp(deltas[:, 2:4])

    # back to corner format
    return ops.box_convert(decoded, in_fmt='cxcywh', out_fmt='xyxy')
def gen_anc_base(anc_pts_x, anc_pts_y, anc_scales, anc_ratios, out_size):
    '''
    Build every anchor box for a grid of centre points

    For each centre taken from ``anc_pts_x`` x ``anc_pts_y`` and each
    (scale, ratio) pair, an xyxy box of height ``scale`` and width
    ``scale * ratio`` is centred on the point and clipped to ``out_size``.

    Input
    ------
    anc_pts_x - 1-D tensor of centre x coordinates
    anc_pts_y - 1-D tensor of centre y coordinates
    anc_scales - iterable of box heights
    anc_ratios - iterable of width / height ratios
    out_size - (height, width) the boxes are clipped to

    Returns
    ---------
    torch.Tensor of shape (1, len(anc_pts_x), len(anc_pts_y), n_anchor_boxes, 4)
        note that the x axis comes first; anchors are ordered with the scale
        varying slowest and the ratio fastest
    '''
    dtype = torch.get_default_dtype()

    # Box sizes depend only on (scale, ratio), so they are computed once
    # instead of once per grid point. The nesting order fixes the anchor index.
    sizes = [(scale * ratio, scale) for scale in anc_scales for ratio in anc_ratios]
    half_w = torch.tensor([w / 2 for w, _ in sizes], dtype=dtype).reshape(1, 1, -1)
    half_h = torch.tensor([h / 2 for _, h in sizes], dtype=dtype).reshape(1, 1, -1)

    # Lay centres out so broadcasting yields a (W, H, A) grid per coordinate.
    # NOTE(review): the centre tensors are assumed to live on the CPU - confirm.
    cx = anc_pts_x.to(dtype).reshape(-1, 1, 1)
    cy = anc_pts_y.to(dtype).reshape(1, -1, 1)

    corners = torch.broadcast_tensors(cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    anc_boxes = torch.stack(corners, dim=-1)  # shape - [W, H, n_anchor_boxes, 4]

    # clip to the map and add the leading singleton (batch) dimension
    return ops.clip_boxes_to_image(anc_boxes, size=out_size).unsqueeze(0)
def get_iou_mat(batch_size, anc_boxes_all, gt_bboxes_all):
    '''
    Compute the IoU of every anchor box with every gt box, image by image

    Input
    ------
    batch_size - number of images
    anc_boxes_all - anchor boxes; reshaped to (batch_size, -1, 4)
    gt_bboxes_all - gt boxes indexed as gt_bboxes_all[i] -> (n_gt, 4)

    Returns
    ---------
    torch.Tensor of shape (batch_size, total_anchors, n_gt), allocated with
    ``torch.zeros`` on the default device
    '''
    # flatten anchor boxes and move them to the device once, not per image
    anc_boxes_flat = anc_boxes_all.reshape(batch_size, -1, 4).cuda(device)
    # get total anchor boxes for a single image
    tot_anc_boxes = anc_boxes_flat.size(dim=1)

    # create a placeholder to compute IoUs amongst the boxes
    ious_mat = torch.zeros((batch_size, tot_anc_boxes, gt_bboxes_all.size(dim=1)))

    # compute IoU of the anc boxes with the gt boxes for all the images
    for i in range(batch_size):
        # box_iou needs both operands on one device, so the gt boxes are moved
        # as well (a no-op when they already live on ``device``)
        gt_bboxes = gt_bboxes_all[i].cuda(device)
        ious_mat[i, :] = ops.box_iou(anc_boxes_flat[i], gt_bboxes)

    return ious_mat
def get_req_anchors(anc_boxes_all, gt_bboxes_all, gt_classes_all, pos_thresh=0.7, neg_thresh=0.2):
    '''
    Prepare necessary data required for training

    Input
    ------
    anc_boxes_all - torch.Tensor of shape (B, w_amap, h_amap, n_anchor_boxes, 4)
        all anchor boxes for a batch of images
    gt_bboxes_all - torch.Tensor of shape (B, max_objects, 4)
        padded ground truth boxes for a batch of images
    gt_classes_all - torch.Tensor of shape (B, max_objects)
        padded ground truth classes for a batch of images
    pos_thresh - an anchor whose IoU with any gt box exceeds this is positive
    neg_thresh - an anchor whose best IoU is below this is a negative candidate

    Returns
    ---------
    positive_anc_ind - torch.Tensor of shape (n_pos,)
        flattened positive indices for all the images in the batch
    negative_anc_ind - torch.Tensor of shape (n_pos,)
        flattened negative indices for all the images in the batch, sampled
        with replacement to match n_pos; empty when no anchor is negative
    GT_conf_scores - torch.Tensor of shape (n_pos,), IoU scores of +ve anchors
    GT_offsets - torch.Tensor of shape (n_pos, 4),
        offsets between +ve anchors and their corresponding ground truth boxes
    GT_class_pos - torch.Tensor of shape (n_pos,)
        mapped classes of +ve anchors
    positive_anc_coords - (n_pos, 4) coords of +ve anchors (for visualization)
    negative_anc_coords - (n_pos, 4) coords of -ve anchors (for visualization)
    positive_anc_ind_sep - torch.Tensor of shape (n_pos,)
        index along the batch dimension of every +ve match
    '''
    # get the size and shape parameters
    B, w_amap, h_amap, A, _ = anc_boxes_all.shape
    N = gt_bboxes_all.shape[1]  # max number of groundtruth bboxes in a batch

    # get total number of anchor boxes in a single image
    tot_anc_boxes = A * w_amap * h_amap

    # get the iou matrix which contains iou of every anchor box
    # against all the groundtruth bboxes in an image
    iou_mat = get_iou_mat(B, anc_boxes_all, gt_bboxes_all)

    # for every groundtruth bbox in an image, find the iou
    # with the anchor box which it overlaps the most
    max_iou_per_gt_box, _ = iou_mat.max(dim=1, keepdim=True)

    # get positive anchor boxes
    # condition 1: the anchor box with the max iou for every gt bbox
    positive_anc_mask = torch.logical_and(iou_mat == max_iou_per_gt_box, max_iou_per_gt_box > 0)
    # condition 2: anchor boxes with iou above a threshold with any of the gt bboxes
    positive_anc_mask = torch.logical_or(positive_anc_mask, iou_mat > pos_thresh)

    positive_anc_ind_sep = torch.where(positive_anc_mask)[0]  # get separate indices in the batch
    # combine all the batches and get the idxs of the +ve anchor boxes
    positive_anc_mask = positive_anc_mask.flatten(start_dim=0, end_dim=1)
    positive_anc_ind = torch.where(positive_anc_mask)[0]

    # for every anchor box, get the iou and the idx of the
    # gt bbox it overlaps with the most
    max_iou_per_anc, max_iou_per_anc_ind = iou_mat.max(dim=-1)
    max_iou_per_anc = max_iou_per_anc.flatten(start_dim=0, end_dim=1)

    # get iou scores of the +ve anchor boxes
    GT_conf_scores = max_iou_per_anc[positive_anc_ind]

    # the best-gt index is needed on the device for both gathers below
    best_gt_ind = max_iou_per_anc_ind.cuda(device)

    # get gt classes of the +ve anchor boxes
    # expand gt classes to map against every anchor box
    gt_classes_expand = gt_classes_all.view(B, 1, N).expand(B, tot_anc_boxes, N)
    # for every anchor box, consider only the class of the gt bbox it overlaps with the most
    GT_class = torch.gather(gt_classes_expand.cuda(device), -1, best_gt_ind.unsqueeze(-1)).squeeze(-1)
    # combine all the batches and get the mapped classes of the +ve anchor boxes
    GT_class = GT_class.flatten(start_dim=0, end_dim=1)
    GT_class_pos = GT_class[positive_anc_ind]

    # get gt bbox coordinates of the +ve anchor boxes
    # expand all the gt bboxes to map against every anchor box
    gt_bboxes_expand = gt_bboxes_all.view(B, 1, N, 4).expand(B, tot_anc_boxes, N, 4)
    # for every anchor box, consider only the coordinates of the gt bbox it overlaps with the most
    gather_ind = best_gt_ind.reshape(B, tot_anc_boxes, 1, 1).repeat(1, 1, 1, 4)
    GT_bboxes = torch.gather(gt_bboxes_expand.cuda(device), -2, gather_ind)
    # combine all the batches and get the mapped gt bbox coordinates of the +ve anchor boxes
    GT_bboxes = GT_bboxes.flatten(start_dim=0, end_dim=2)
    GT_bboxes_pos = GT_bboxes[positive_anc_ind]

    # get coordinates of +ve anc boxes
    anc_boxes_flat = anc_boxes_all.flatten(start_dim=0, end_dim=-2)  # flatten all the anchor boxes
    positive_anc_coords = anc_boxes_flat[positive_anc_ind]

    # calculate gt offsets
    GT_offsets = calc_gt_offsets(positive_anc_coords, GT_bboxes_pos)

    # get -ve anchors
    # condition: select the anchor boxes with max iou less than the threshold
    negative_anc_mask = (max_iou_per_anc < neg_thresh)
    negative_anc_ind = torch.where(negative_anc_mask)[0]

    # sample -ve samples to match the +ve samples
    # torch.randint rejects an empty range, so skip the sampling (and return
    # an empty selection) when no anchor qualifies as negative
    n_neg = negative_anc_ind.shape[0]
    if n_neg > 0:
        sampled = torch.randint(0, n_neg, (positive_anc_ind.shape[0],))
        negative_anc_ind = negative_anc_ind[sampled]
    negative_anc_coords = anc_boxes_flat[negative_anc_ind]

    return positive_anc_ind, negative_anc_ind, GT_conf_scores, GT_offsets, GT_class_pos, \
        positive_anc_coords, negative_anc_coords, positive_anc_ind_sep
# -------------- Visualization utils ----------------
def display_img(img_data, fig, axes):
    '''
    Show every image of ``img_data`` on the axis with the same index

    Tensors (including Tensor subclasses) are permuted with (1, 2, 0) and
    converted to numpy first; anything else is passed to ``imshow`` as is.
    Returns ``(fig, axes)`` unchanged.
    '''
    for i, img in enumerate(img_data):
        # isinstance also catches Tensor subclasses such as nn.Parameter
        if isinstance(img, torch.Tensor):
            # presumably channel-first (C, H, W) input - TODO confirm.
            # detach/cpu so tensors that require grad or live on a GPU convert
            img = img.permute(1, 2, 0).detach().cpu().numpy()
        axes[i].imshow(img)

    return fig, axes
def display_bbox(bboxes, fig, ax, classes=None, in_format='xyxy', color='y', line_width=3):
    '''
    Draw bounding boxes, and optionally their class labels, on an axis

    Input
    ------
    bboxes - torch.Tensor or np.ndarray of boxes in ``in_format``
    fig, ax - figure and axis to draw on; both are returned unchanged
    classes - optional sequence with one label per box; boxes labelled 'pad'
        are drawn without a text label
    in_format - box format understood by ``ops.box_convert``
    color - edge colour of the rectangles
    line_width - line width of the rectangles
    '''
    if isinstance(bboxes, np.ndarray):
        bboxes = torch.from_numpy(bboxes)

    # explicit length test so array-like label containers work as well
    has_labels = classes is not None and len(classes) > 0
    if has_labels:
        assert len(bboxes) == len(classes)

    # convert boxes to xywh format
    bboxes = ops.box_convert(bboxes, in_fmt=in_format, out_fmt='xywh')

    # enumerate keeps the label index in step with the box even after a
    # 'pad' entry (the manual counter used to get stuck on the first one)
    for idx, box in enumerate(bboxes):
        x, y, w, h = box.detach().cpu().numpy()
        # display bounding box
        rect = patches.Rectangle((x, y), w, h, linewidth=line_width, edgecolor=color, facecolor='none')
        ax.add_patch(rect)
        # display category
        if has_labels and classes[idx] != 'pad':
            ax.text(x + 5, y + 20, classes[idx], bbox=dict(facecolor='yellow', alpha=0.5))

    return fig, ax
def display_grid(x_points, y_points, fig, ax, special_point=None):
    '''
    Scatter a '+' marker on every (x, y) combination of the given points

    Input
    ------
    x_points, y_points - iterables of coordinates spanning the grid
    fig, ax - figure and axis to draw on; both are returned unchanged
    special_point - optional (x, y) pair drawn in red on top of the grid;
        ``None`` or an empty sequence draws nothing extra
    '''
    # plot grid
    for x in x_points:
        for y in y_points:
            ax.scatter(x, y, color="w", marker='+')

    # plot a special point we want to emphasize on the grid
    # (explicit None / length test: the truth value of an array or tensor
    # with two elements is ambiguous and would raise)
    if special_point is not None and len(special_point) > 0:
        x, y = special_point
        ax.scatter(x, y, color="red", marker='+')

    return fig, ax