-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path test_LAPT.py
75 lines (49 loc) · 2.44 KB
/
test_LAPT.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
"""
This file contains code from https://github.com/nv-tlabs/lift-splat-shoot
License available in https://github.com/nv-tlabs/lift-splat-shoot/blob/master/LICENSE
"""
import torch
from models.LAPTNet import compile_model
from dataloaders.data import compile_data
from tools import (get_val_info, compile_loss)
from tools import get_cfgs
def test(cfg=None, weight_path=None):
    """Evaluate a trained LAPTNet checkpoint on the validation split.

    Loads the train/data configuration via ``get_cfgs(cfg)`` and builds the
    validation dataloader and the model from it. It then loads the
    ``state_dict`` stored at *weight_path* and runs ``get_val_info`` over
    the validation set. Finally it prints the IoU of each output channel
    (scaled by 100) and the validation loss.

    Args:
        cfg: Config identifier/path, forwarded unchanged to ``get_cfgs``.
        weight_path: Path to a file saved with ``torch.save`` that holds the
            model's ``state_dict``.

    Raises:
        ValueError: If *weight_path* is ``None``. Checking this first means
            no config or dataloader is built for a run that cannot load
            weights.
    """
    if weight_path is None:
        raise ValueError('weight_path must point to a saved model state_dict')

    train_config, data_config = get_cfgs(cfg)
    _, valloader = compile_data(data_config)
    device = (torch.device(f'cuda:{train_config.gpuid}')
              if torch.cuda.is_available() else torch.device('cpu'))

    num_classes = len(data_config.train_label)
    if data_config.add_map:
        num_classes += 2  # Drivable area + Unknown

    data_aug_conf = data_config.data_aug_conf.to_dict()
    grid_conf = data_config.grid_conf.to_dict()
    model = compile_model(grid_conf, data_aug_conf,
                          outC=num_classes, use_fpn=train_config.use_fpn)

    print("Loading weights under:", weight_path)
    # Remap every stored tensor onto the evaluation device. Without
    # map_location, a checkpoint saved on a GPU cannot be loaded on a
    # CPU-only machine (the CPU fallback chosen above), and it would be
    # restored onto the GPU it was saved from instead of cuda:{gpuid}.
    weight_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(weight_dict)
    model.to(device)

    if num_classes == 1 and not data_config.add_map:
        loss_fn = compile_loss(num_classes, data_config.add_map,
                               train_config.gpuid, task='semseg')
        # Single-class report below names the class from train_label
        # directly, so no label list is needed here.
        text_labels = None
    else:
        # In this configuration compile_loss also returns the labels used
        # below to name each IoU entry.
        loss_fn, text_labels = compile_loss(
            num_classes, data_config.add_map, train_config.gpuid,
            task='semseg', train_label=data_config.train_label[0])

    print('Dataloader samples:', len(valloader))
    model.eval()
    print('Validation:')
    val_info = get_val_info(model, valloader, loss_fn, device,
                            num_classes=num_classes, use_tqdm=True)

    if num_classes == 1:
        class_name = data_config.train_label[0]
        print('{} IoU: {:.3f}'.format(class_name, val_info['iou'][0] * 100))
        print('{} Loss: {:.3f}'.format(class_name, val_info['loss']))
    else:
        for i, label in enumerate(text_labels):
            print('{} IoU: {:.3f}'.format(label, val_info['iou'][i] * 100))
        print('Loss: {:.3f}'.format(val_info['loss']))
if __name__ == '__main__':
    import argparse

    # Command-line entry point: both options are mandatory strings.
    # Each entry below is (flag, help text); they are registered in order.
    cli_options = (
        ('--cfg', 'Path to the config file.'),
        ('--weights', 'Path to weights of the trained model.'),
    )
    cli = argparse.ArgumentParser(description='Evaluate the LAPTNet model.')
    for flag, help_text in cli_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    parsed = cli.parse_args()
    test(cfg=parsed.cfg, weight_path=parsed.weights)