-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_export_new_labels.py
106 lines (89 loc) · 3.72 KB
/
main_export_new_labels.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import sys
import os
import h5py
import numpy as np
from main import getModel
from neural_wrappers.readers import CitySimReader
from argparse import ArgumentParser
from neural_wrappers.callbacks import Callback
from loss import l2_loss, classification_loss
from main import SchedulerCallback
def getArgs():
    """Parse and validate the command-line arguments of the export script.

    Validation is delegated to argparse (``choices`` / ``required``) instead
    of ``assert`` statements, so it still runs under ``python -O`` and the
    user gets a usage message rather than a bare AssertionError traceback.

    Returns:
        The argparse namespace. ``weights_file`` is converted to an absolute
        path, and two attributes are derived from the task: ``data_dims`` is
        always ``["rgb"]`` and ``label_dims`` is ``["hvn_gt_p1"]`` for
        classification or ``["depth"]`` for regression.

    Raises:
        SystemExit: raised by argparse when the task or model is not one of
            the supported values, or when --model / --weights_file is missing.
    """
    parser = ArgumentParser()
    parser.add_argument("task", choices=("classification", "regression"),
        help="regression for Depth / classification for HVN")
    parser.add_argument("dataset_path", help="Path to dataset")
    parser.add_argument("--batch_size", default=10, type=int)
    # Model stuff
    parser.add_argument("--model", type=str, required=True,
        choices=("unet_big_concatenate", "unet_tiny_sum"))
    parser.add_argument("--weights_file", required=True)
    # Still accepted so existing command lines keep working, but the value is
    # ignored: label_dims is always derived from the task below.
    parser.add_argument("--label_dims")

    args = parser.parse_args()
    args.weights_file = os.path.abspath(args.weights_file)

    args.data_dims = ["rgb"]
    if args.task == "classification":
        args.label_dims = ["hvn_gt_p1"]
    else:
        args.label_dims = ["depth"]
    return args
class ExportCallback(Callback):
    """Callback that stores network predictions as a new labels HDF5 file.

    The file "<newLabelsName>.h5" is created (truncating any existing file)
    relative to the current working directory. It contains a "train" and a
    "validation" group, each with a single dataset called ``newLabelsName``.
    Call ``setGroup`` before each test pass to choose which dataset the
    subsequent ``onIterationEnd`` calls write to, and ``close`` when done.
    """

    def __init__(self, reader, originalLabelsName, newLabelsName, resolution=(240, 320)):
        """Create the output file and its two empty datasets.

        Args:
            reader: object exposing ``reader.dataset[group][name]`` entries
                with ``dtype`` and ``shape`` attributes for the "train" and
                "validation" groups.
            originalLabelsName: name of the existing label whose dtype and
                shape are used as a template for the new datasets.
            newLabelsName: name of the new dataset; also the base name of the
                output file.
            resolution: pair written over axes 1 and 2 of the template shape
                (presumably height and width -- confirm against the reader's
                resizer). Defaults to (240, 320).
        """
        # NOTE(review): the base-class initializer is not invoked here, which
        # preserves existing behavior -- confirm Callback needs no setup.
        # Read the template metadata before touching the disk so that an
        # unknown label name does not leave a truncated, open output file.
        trainTemplate = reader.dataset["train"][originalLabelsName]
        valTemplate = reader.dataset["validation"][originalLabelsName]
        dType = trainTemplate.dtype
        height, width = resolution
        trainShape = list(trainTemplate.shape)
        valShape = list(valTemplate.shape)
        trainShape[1 : 3] = height, width
        valShape[1 : 3] = height, width

        self.newLabelsName = newLabelsName
        self.file = h5py.File("%s.h5" % (newLabelsName), "w")
        self.file.create_group("train")
        self.file.create_group("validation")
        self.file["train"].create_dataset(newLabelsName, dtype=dType, shape=trainShape)
        self.file["validation"].create_dataset(newLabelsName, dtype=dType, shape=valShape)

        # Both are set by setGroup(); None means no group was selected yet.
        self.group = None
        self.index = None

    def setGroup(self, group):
        """Select the group ("train" or "validation") to write to and rewind the write index."""
        assert group in ("train", "validation")
        self.group = group
        self.index = 0

    def close(self):
        """Close the output HDF5 file so that all written data reaches the disk."""
        self.file.close()

    def onIterationEnd(self, **kwargs):
        """Write every item of ``kwargs["results"]`` to the currently selected dataset.

        When the new label name contains "hvn", each item is reduced with an
        argmax over its last axis; otherwise channel 0 of the last axis is
        kept. Items are stored at consecutive indices starting from the
        position reached by the previous call (reset by ``setGroup``).
        """
        results = kwargs["results"]
        dataset = self.file[self.group][self.newLabelsName]
        # The label kind does not change between items, so test it only once.
        isHvn = "hvn" in self.newLabelsName

        for result in results:
            if isHvn:
                result = np.argmax(result, axis=-1)
            else:
                result = result[..., 0]

            dataset[self.index] = result
            self.index += 1
            # Progress report every 10 exported items.
            if self.index % 10 == 0:
                print("Done %d" % (self.index))
def main():
    """Run a trained model over the train and validation sets and export its predictions.

    The predictions are written by an ExportCallback into an HDF5 file named
    after the new label: "depth_<tiny|big>_it1.h5" when the label is depth,
    "hvn_<tiny|big>_it1_p1.h5" otherwise. The file is closed even if one of
    the export passes fails, so data written up to that point is not lost
    to an unflushed handle.
    """
    args = getArgs()

    hvnTransform = "hvn_two_dims" if args.task == "regression" else "identity_long"
    reader = CitySimReader(args.dataset_path, dataDims=args.data_dims, labelDims=args.label_dims,
        resizer=(240, 320), hvnTransform=hvnTransform, dataGroup="all")
    trainGenerator = reader.iterate_once("train", args.batch_size)
    trainSteps = reader.getNumIterations("train", args.batch_size)
    valGenerator = reader.iterate_once("validation", args.batch_size)
    valSteps = reader.getNumIterations("validation", args.batch_size)
    print(reader.summary())

    dIn = CitySimReader.getNumDimensions(args.data_dims, hvnTransform)
    # For classification, we need to output probabilities for all 3 classes.
    if args.task == "classification":
        dOut = 3
    else:
        dOut = CitySimReader.getNumDimensions(args.label_dims, hvnTransform)

    model = getModel(args, dIn=dIn, dOut=dOut)
    model.loadWeights(args.weights_file)
    criterion = l2_loss if args.task == "regression" else classification_loss
    model.setCriterion(criterion)
    print(model.summary())

    modelName = "tiny" if "tiny" in args.model else "big"
    if args.label_dims[0] == "depth":
        newLabelDim = "depth_%s_it1" % (modelName)
    else:
        newLabelDim = "hvn_%s_it1_p1" % (modelName)

    callback = ExportCallback(reader, args.label_dims[0], newLabelDim)
    try:
        callback.setGroup("train")
        model.test_generator(trainGenerator, trainSteps, callbacks=[callback])

        callback.setGroup("validation")
        model.test_generator(valGenerator, valSteps, callbacks=[callback])
    finally:
        # The callback owns an open HDF5 file; close it explicitly so the
        # exported labels are flushed to disk whatever happens above.
        callback.file.close()
# Run the export only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()