-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathpredict_img.py
executable file
·59 lines (47 loc) · 2.1 KB
/
predict_img.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import argparse
import cv2
import torch
import torch.nn.functional as F
from network import ResnetUnetHybrid
import image_utils
def predict_img(img_path, focal_len):
    """Run depth estimation and semantic segmentation on one image and plot the results.

    The image is read with OpenCV, pre-processed with ``image_utils``
    (scale, center crop, tensor transform) and fed through two pretrained
    ``ResnetUnetHybrid`` models (``output_type='depth'`` and ``'seg'``).
    Both outputs are bilinearly up-sampled to 320x320, a softmax over dim 1
    is applied to the segmentation output, and everything is handed to
    ``image_utils.create_plots`` (intended to build a 3D view of the scene
    from depth + segmentation).

    Args:
        img_path: Path of the image file to read.
        focal_len: Camera focal length, forwarded unchanged to
            ``image_utils.create_plots``.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``img_path``
            (``cv2.imread`` signals failure by returning ``None``).
    """
    # spatial size the network outputs are up-sampled to before plotting
    out_size = (320, 320)

    # switch to GPU if possible
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print('Use GPU: {}'.format(str(device) != 'cpu'))

    # Load the image first so a bad path fails fast, before the (slow) model
    # loading. cv2.imread returns None rather than raising on failure.
    img = cv2.imread(img_path)
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(img_path))
    # OpenCV decodes to BGR channel order; reverse the last axis to get RGB
    img = img[..., ::-1]
    img = image_utils.scale_image(img)
    img = image_utils.center_crop(img)
    inp = image_utils.img_transform(img)
    # prepend a batch axis of size 1 and move to the selected device
    inp = inp[None, :, :, :].to(device)

    # load models
    print('Loading models...')
    model_de = ResnetUnetHybrid.load_pretrained(output_type='depth', device=device)
    model_seg = ResnetUnetHybrid.load_pretrained(output_type='seg', device=device)
    model_de.eval()
    model_seg.eval()

    print('Running inference...')
    # Inference only: no_grad avoids building (and holding memory for) an
    # autograd graph that is never used.
    with torch.no_grad():
        output_de = model_de(inp)
        output_seg = model_seg(inp)

        # up-sample outputs
        output_de = F.interpolate(output_de, size=out_size, mode='bilinear', align_corners=True)
        output_seg = F.interpolate(output_seg, size=out_size, mode='bilinear', align_corners=True)

        # softmax over dim 1 (presumably the class channel — TODO confirm)
        output_seg = F.softmax(output_seg, dim=1)

    # drop the batch axis and convert to NumPy for plotting
    output_de = output_de.detach().cpu()[0].numpy()
    output_seg = output_seg.detach().cpu()[0].numpy()

    # plot the results
    print('Plotting...')
    image_utils.create_plots(img, output_de, output_seg, focal_len, uncertainty_threshold=0.9, apply_depth_mask=True)
def get_arguments():
    """Parse the script's command-line options.

    Returns:
        argparse.Namespace with ``img_path`` (str, required, ``-i``) and
        ``focal_len`` (float, optional ``-f``, default 2264.0).
    """
    cli = argparse.ArgumentParser()

    cli.add_argument(
        '-i', '--img_path',
        type=str,
        required=True,
        help='Path to input image.',
    )

    focal_help = ('The focal length of the camera. '
                  'Default: 2264 (this value should work for the example images).')
    cli.add_argument(
        '-f', '--focal_len',
        type=float,
        required=False,
        default=2264.0,
        help=focal_help,
    )

    return cli.parse_args()
if __name__ == '__main__':
    # CLI entry point: parse options, then run the prediction/plotting pipeline
    cli_args = get_arguments()
    predict_img(img_path=cli_args.img_path, focal_len=cli_args.focal_len)