-
Notifications
You must be signed in to change notification settings - Fork 119
/
Copy pathrun.py
55 lines (41 loc) · 1.47 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#!/usr/bin/env python
from __future__ import division
import tensorflow as tf
import model
import cv2
import subprocess as sp
import itertools
import params
import sys
import os
import preprocess
import visualize
import time
import local_common as cm
# Restore the trained model from <params.save_dir>/model.ckpt.
# tf.InteractiveSession installs itself as the default session, which is what
# lets model.y.eval(...) below run without passing a session explicitly.
sess = tf.InteractiveSession()
saver = tf.train.Saver()
model_name = 'model.ckpt'
model_path = cm.jn(params.save_dir, model_name)
saver.restore(sess, model_path)


def _predict_steering(vid_path, frame_count):
    """Run the model on the first ``frame_count`` frames of a video.

    Each frame is read with OpenCV, passed through preprocess.preprocess, and
    fed to the model one frame at a time.

    Args:
        vid_path: path of the video file to read.
        frame_count: number of frames to read (the caller obtains it from
            cm.frame_count).

    Returns:
        A ``(predictions, elapsed_seconds)`` tuple. ``predictions`` holds
        ``model.y[...][0][0]`` for each frame, in frame order (presumably a
        steering angle in degrees, given the original ``deg`` name — confirm).

    Raises:
        IOError: if the video cannot be opened or a frame cannot be read.
    """
    cap = cv2.VideoCapture(vid_path)
    if not cap.isOpened():
        raise IOError('could not open video {}'.format(vid_path))
    predictions = []
    time_start = time.time()
    try:
        for frame_id in range(frame_count):
            ret, img = cap.read()
            # Raise instead of assert: asserts are stripped under `python -O`.
            if not ret:
                raise IOError('failed to read frame {} of {}'.format(frame_id, vid_path))
            img = preprocess.preprocess(img)
            # keep_prob is presumably the dropout keep probability; 1.0 keeps
            # every unit, as wanted at inference time.
            deg = model.y.eval(feed_dict={model.x: [img], model.keep_prob: 1.0})[0][0]
            predictions.append(deg)
    finally:
        # Release the capture even when a read or inference step fails.
        cap.release()
    return predictions, time.time() - time_start


# Every epoch id mentioned anywhere in params.epochs, de-duplicated and sorted.
epoch_ids = sorted(set(itertools.chain.from_iterable(params.epochs.values())))
for epoch_id in epoch_ids:
    print('---------- processing video for epoch {} ----------'.format(epoch_id))
    vid_path = cm.jn(params.data_dir, 'epoch{:0>2}_front.mkv'.format(epoch_id))
    if not os.path.isfile(vid_path):
        raise IOError('video not found: {}'.format(vid_path))
    frame_count = cm.frame_count(vid_path)

    print('performing inference...')
    machine_steering, elapsed = _predict_steering(vid_path, frame_count)
    # Guard against a zero elapsed time (e.g. no frames, coarse timer).
    fps = frame_count / elapsed if elapsed > 0 else float('inf')
    print('completed inference, total frames: {}, average fps: {} Hz'.format(frame_count, round(fps, 1)))

    print('performing visualization...')
    visualize.visualize(epoch_id, machine_steering, params.out_dir,
                        verbose=True, frame_count_limit=None)