-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathdata_utils.py
executable file
·71 lines (63 loc) · 3.65 KB
/
data_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import hickle as hkl
import numpy as np
from keras import backend as K
from keras.preprocessing.image import Iterator
# Adapted from the PredNet code of Lotter et al. 2016 (https://arxiv.org/abs/1605.08104, https://github.com/coxlab/prednet).
# Data generator that creates frame sequences for input into PreCNet/PredNet.
class SequenceGenerator(Iterator):
    """Batch iterator yielding fixed-length sequences of consecutive frames.

    Frames are read from ``data_file`` and one source label per frame from
    ``source_file`` (both through ``hkl.load``). A window of ``nt``
    consecutive frames is accepted as a sequence only when its first and last
    frame share the same source label, so a sequence does not span two
    sources. Only the two end frames are compared, which assumes frames of
    one source are stored contiguously -- TODO confirm against the data prep.

    Args:
        data_file: path handed to ``hkl.load``; the frames are expected as a
            4-D channels-last array (n_images, dim1, dim2, n_channels).
        source_file: path handed to ``hkl.load``; indexable per-frame labels.
        nt: number of frames per sequence.
        batch_size: number of sequences per batch.
        shuffle: if true, the candidate start positions are permuted once
            here (NumPy global RNG) and the flag is also forwarded to the
            base ``Iterator``.
        seed: forwarded to the base ``Iterator`` only.
        output_mode: ``'error'`` (targets are zeros, one per sequence) or
            ``'prediction'`` (targets are the input frames themselves).
        sequence_start_mode: ``'all'`` (every valid start frame) or
            ``'unique'`` (non-overlapping sequences).
        N_seq: optional cap on the number of sequences kept.
        data_format: ``'channels_first'`` transposes the frames to
            (n_images, n_channels, dim1, dim2); anything else leaves them
            as loaded.

    Raises:
        AssertionError: if ``sequence_start_mode`` or ``output_mode`` is not
            one of the accepted values.
    """

    def __init__(self, data_file, source_file, nt,
                 batch_size=8, shuffle=False, seed=None,
                 output_mode='error', sequence_start_mode='all', N_seq=None,
                 data_format=K.image_data_format()):
        self.X = hkl.load(data_file)  # X will be like (n_images, nb_cols, nb_rows, nb_channels)
        self.sources = hkl.load(source_file)  # source of each frame, used to keep sequences within one video
        self.nt = nt
        self.batch_size = batch_size
        self.data_format = data_format
        assert sequence_start_mode in {'all', 'unique'}, 'sequence_start_mode must be in {all, unique}'
        self.sequence_start_mode = sequence_start_mode
        assert output_mode in {'error', 'prediction'}, 'output_mode must be in {error, prediction}'
        self.output_mode = output_mode

        if self.data_format == 'channels_first':
            self.X = np.transpose(self.X, (0, 3, 1, 2))
        self.im_shape = self.X[0].shape

        # A start i is valid when frames i .. i + nt - 1 all exist, i.e.
        # i <= n_frames - nt, hence the exclusive bound n_frames - nt + 1.
        start_limit = self.X.shape[0] - self.nt + 1

        if self.sequence_start_mode == 'all':
            # Any frame may start a sequence, so sequences can overlap.
            self.possible_starts = np.array([
                i for i in range(start_limit)
                if self.sources[i] == self.sources[i + self.nt - 1]
            ])
        elif self.sequence_start_mode == 'unique':
            # Each frame belongs to at most one sequence: after accepting a
            # start, jump past the whole window; otherwise slide by one.
            curr_location = 0
            possible_starts = []
            while curr_location < start_limit:
                if self.sources[curr_location] == self.sources[curr_location + self.nt - 1]:
                    possible_starts.append(curr_location)
                    curr_location += self.nt
                else:
                    curr_location += 1
            self.possible_starts = possible_starts

        if shuffle:
            self.possible_starts = np.random.permutation(self.possible_starts)
        if N_seq is not None and len(self.possible_starts) > N_seq:
            # Keep only the first N_seq candidates (after the optional shuffle).
            self.possible_starts = self.possible_starts[:N_seq]
        self.N_sequences = len(self.possible_starts)
        super(SequenceGenerator, self).__init__(len(self.possible_starts), batch_size, shuffle, seed)

    def __getitem__(self, null):
        """Return the next batch; the requested index is ignored."""
        return self.next()

    def next(self):
        """Build and return the next ``(batch_x, batch_y)`` pair.

        ``batch_x`` has shape ``(batch, nt) + im_shape`` in float32, scaled by
        ``preprocess``. ``batch_y`` is a float32 zero vector of length
        ``batch`` in ``'error'`` mode, or ``batch_x`` itself (same object) in
        ``'prediction'`` mode.
        """
        # Only drawing the indices needs the lock; assembling the batch does
        # not touch the shared generator state.
        with self.lock:
            index_array = next(self.index_generator)
        # The index generator can yield fewer than batch_size indices for the
        # last batch of an epoch; size the arrays from what was actually
        # drawn so no all-zero padding sequences are emitted.
        current_batch_size = len(index_array)
        batch_x = np.zeros((current_batch_size, self.nt) + self.im_shape, np.float32)
        for i, idx in enumerate(index_array):
            start = self.possible_starts[idx]
            batch_x[i] = self.preprocess(self.X[start:start + self.nt])
        if self.output_mode == 'error':  # model outputs errors, so y should be zeros
            batch_y = np.zeros(current_batch_size, np.float32)
        elif self.output_mode == 'prediction':  # output actual pixels
            batch_y = batch_x
        return batch_x, batch_y

    def preprocess(self, X):
        """Cast ``X`` to float32 and divide by 255."""
        return X.astype(np.float32) / 255

    def create_all(self):
        """Return every selected sequence, preprocessed, as one float32 array.

        The result has shape ``(N_sequences, nt) + im_shape`` and follows the
        order of ``possible_starts``.
        """
        X_all = np.zeros((self.N_sequences, self.nt) + self.im_shape, np.float32)
        for i, idx in enumerate(self.possible_starts):
            X_all[i] = self.preprocess(self.X[idx:idx + self.nt])
        return X_all