-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtripletgen.py
73 lines (63 loc) · 2.98 KB
/
tripletgen.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` yielding batches of (anchor, positive, negative) image triplets.

    The dataset index is ``<dataset_path>/list.txt``: a whitespace-separated list of
    ``<person>/<file>`` entries. Images are read from
    ``<dataset_path>/images/<person>/<file>``.

    Batch ``i`` covers people ``i * batch_size`` .. ``(i + 1) * batch_size - 1`` in the
    current (optionally shuffled) order. For each such person one anchor and one
    positive image are drawn from that person's images, and one negative image is
    drawn from a different, randomly chosen person.
    """

    def __init__(self, dataset_path, batch_size=32, shuffle=True):
        """Index the dataset and prepare the first epoch.

        Args:
            dataset_path: Directory containing ``list.txt`` and the ``images`` folder.
            batch_size: Number of people (hence triplets) per batch.
            shuffle: If true, the person order is reshuffled at every epoch end
                (and once here, before the first epoch).
        """
        self.dataset = self.curate_dataset(dataset_path)
        self.dataset_path = dataset_path
        self.shuffle = shuffle
        self.batch_size = batch_size
        self.no_of_people = len(self.dataset)
        self.on_epoch_end()

    def __getitem__(self, index):
        """Return batch ``index`` as ``[anchors, positives, negatives]`` numpy arrays.

        Raises:
            ValueError: if the dataset has fewer than two people, so no negative
                can be drawn.
            FileNotFoundError: if an image file cannot be read.
        """
        # Build the person list once per batch instead of once per sampling attempt.
        all_people = list(self.dataset)
        batch_people = all_people[index * self.batch_size:(index + 1) * self.batch_size]

        anchors, positives, negatives = [], [], []
        for person in batch_people:
            anchor_index = random.randint(0, len(self.dataset[person]) - 1)
            positive_index = self._sample_positive_index(person, anchor_index)
            negative_person = self._sample_negative_person(person, all_people)
            negative_index = random.randint(0, len(self.dataset[negative_person]) - 1)

            anchors.append(self.get_image(person, anchor_index))
            positives.append(self.get_image(person, positive_index))
            negatives.append(self.get_image(negative_person, negative_index))

        return [np.asarray(anchors), np.asarray(positives), np.asarray(negatives)]

    def _sample_positive_index(self, person, anchor_index):
        """Pick an image index of *person* uniformly among those other than *anchor_index*."""
        image_count = len(self.dataset[person])
        if image_count < 2:
            # Only one image for this person: reuse it as the positive rather than
            # searching forever for a different index.
            return anchor_index
        # Draw from the image_count - 1 non-anchor slots, shifting past the anchor.
        candidate = random.randint(0, image_count - 2)
        return candidate + 1 if candidate >= anchor_index else candidate

    def _sample_negative_person(self, person, all_people):
        """Pick a person uniformly among *all_people* other than *person*."""
        if len(all_people) < 2:
            raise ValueError(
                "DataGenerator needs at least two people to sample a negative image"
            )
        negative_person = person
        while negative_person == person:
            negative_person = random.choice(all_people)
        return negative_person

    def __len__(self):
        """Number of full batches per epoch; a trailing partial batch is dropped."""
        return self.no_of_people // self.batch_size

    def curate_dataset(self, dataset_path):
        """Read ``list.txt`` and group file names by person folder.

        Each whitespace-separated entry is split at its first ``/`` into
        ``(folder_name, file_name)``. Returns a dict mapping folder name to the
        list of its file names, in file order.
        """
        with open(os.path.join(dataset_path, 'list.txt'), 'r') as f:
            entries = f.read().split()
        dataset = {}
        for entry in entries:
            folder_name, file_name = entry.split('/', 1)
            dataset.setdefault(folder_name, []).append(file_name)
        return dataset

    def on_epoch_end(self):
        """Reshuffle the person order when ``shuffle`` is enabled."""
        if self.shuffle:
            people = list(self.dataset)
            random.shuffle(people)
            # Rebuilding the dict reorders it, since dicts keep insertion order.
            self.dataset = {person: self.dataset[person] for person in people}

    def get_image(self, person, index):
        """Load, resize to 224x224 and preprocess image *index* of *person*.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        path = os.path.join(self.dataset_path, 'images', person, self.dataset[person][index])
        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise FileNotFoundError(f"could not read image: {path}")
        # NOTE(review): cv2.imread returns BGR; confirm the imported
        # preprocess_input expects that channel order.
        img = cv2.resize(img, (224, 224))
        img = np.asarray(img, dtype=np.float64)
        return preprocess_input(img)
# Shared generator over the dataset rooted at ./dataset/. It is built at import
# time, so ./dataset/list.txt is read as soon as this module is loaded.
data_generator = DataGenerator(
    './dataset/',
)