Skip to content

Commit

Permalink
Merge pull request #1374 from leondgarse/master
Browse files Browse the repository at this point in the history
Fix 1N evaluation in IJB_evals.py and add pytorch model interface
  • Loading branch information
nttstar authored Jan 4, 2021
2 parents 253a5eb + 4a78741 commit f62db73
Showing 1 changed file with 43 additions and 14 deletions.
57 changes: 43 additions & 14 deletions evaluation/IJB/IJB_evals.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@


class Mxnet_model_interf:
import mxnet as mx

def __init__(self, model_file, layer="fc1", image_size=(112, 112)):
import mxnet as mx

self.mx = mx
cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
if len(cvd) > 0 and int(cvd) != -1:
ctx = [self.mx.gpu(ii) for ii in range(len(cvd.split(",")))]
Expand All @@ -39,6 +40,28 @@ def __call__(self, imgs):
return emb


class Torch_model_interf:
def __init__(self, model_file, image_size=(112, 112)):
import torch

self.torch = torch
cvd = os.environ.get("CUDA_VISIBLE_DEVICES", "").strip()
device_name = "cuda:0" if len(cvd) > 0 and int(cvd) != -1 else "cpu"
self.device = self.torch.device(device_name)
try:
self.model = self.torch.jit.load(model_file, map_location=device_name)
except:
print("Error: %s is weights only, please load and save the entire model by `torch.jit.save`" % model_file)
self.model = None

def __call__(self, imgs):
# print(imgs.shape, imgs[0])
imgs = imgs.transpose(0, 3, 1, 2).copy().astype("float32")
imgs = (imgs - 127.5) * 0.0078125
output = self.model(self.torch.from_numpy(imgs).to(self.device).float())
return output.cpu().detach().numpy()


def keras_model_interf(model_file):
import tensorflow as tf

Expand Down Expand Up @@ -225,6 +248,7 @@ def image2template_feature(img_feats=None, templates=None, medias=None, choose_t
unique_templates = np.unique(templates)
unique_subjectids = None

# template_feats = np.zeros((len(unique_templates), img_feats.shape[1]), dtype=img_feats.dtype)
template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
for count_template, uqt in tqdm(enumerate(unique_templates), "Extract template feature", total=len(unique_templates)):
(ind_t,) = np.where(templates == uqt)
Expand All @@ -246,15 +270,15 @@ def image2template_feature(img_feats=None, templates=None, medias=None, choose_t


def verification_11(template_norm_feats=None, unique_templates=None, p1=None, p2=None, batch_size=100000):
template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
template2id = np.zeros(max(unique_templates) + 1, dtype=int)
for count_template, uqt in enumerate(unique_templates):
template2id[uqt] = count_template

steps = int(np.ceil(len(p1) / batch_size))
score = []
for id in tqdm(range(steps), "Verification"):
feat1 = template_norm_feats[template2id[p1[id * batch_size : (id + 1) * batch_size]].flatten()]
feat2 = template_norm_feats[template2id[p2[id * batch_size : (id + 1) * batch_size]].flatten()]
feat1 = template_norm_feats[template2id[p1[id * batch_size : (id + 1) * batch_size]]]
feat2 = template_norm_feats[template2id[p2[id * batch_size : (id + 1) * batch_size]]]
score.extend(np.sum(feat1 * feat2, -1))
return np.array(score)

Expand Down Expand Up @@ -309,7 +333,7 @@ def evaluation_1N(query_feats, gallery_feats, query_ids, reg_ids):
neg_sims_sorted = heapq.nlargest(max(required_topk), neg_sims) # heap sort
print("pos_sims: %s, neg_sims: %s, neg_sims_sorted: %d" % (pos_sims.shape, neg_sims.shape, len(neg_sims_sorted)))
for far, pos in zip(Fars, required_topk):
th = neg_sims[pos - 1]
th = neg_sims_sorted[pos - 1]
recall = np.sum(pos_sims > th) / query_num
print("far = {:.10f} pr = {:.10f} th = {:.10f}".format(far, recall, th))

Expand All @@ -320,7 +344,12 @@ def __init__(self, model_file, data_path, subset, batch_size=64, force_reload=Fa
data_path, subset, force_reload=force_reload
)
if model_file != None:
interf_func = keras_model_interf(model_file) if model_file.endswith(".h5") else Mxnet_model_interf(model_file)
if model_file.endswith(".h5"):
interf_func = keras_model_interf(model_file)
elif model_file.endswith(".pth") or model_file.endswith(".pt"):
interf_func = Torch_model_interf(model_file)
else:
interf_func = Mxnet_model_interf(model_file)
self.embs, self.embs_f = get_embeddings(interf_func, img_names, landmarks, batch_size=batch_size)
elif restore_embs != None:
print(">>>> Reload embeddings from:", restore_embs)
Expand All @@ -330,9 +359,10 @@ def __init__(self, model_file, data_path, subset, batch_size=64, force_reload=Fa
else:
print("ERROR: %s NOT containing embs / embs_f" % restore_embs)
exit(1)
print(">>>> Done.")
self.data_path, self.subset, self.force_reload = data_path, subset, force_reload
self.templates, self.medias, self.p1, self.p2, self.face_scores = templates, medias, p1, p2, face_scores
self.label = label
self.templates, self.medias, self.p1, self.p2, self.label = templates, medias, p1, p2, label
self.face_scores = face_scores.astype(self.embs.dtype)

def run_model_test_single(self, use_flip_test=True, use_norm_score=False, use_detector_score=True):
img_input_feats = process_embeddings(
Expand Down Expand Up @@ -404,8 +434,7 @@ def plot_roc_and_calculate_tpr(scores, names=None, label=None):
score_dict[name] = np.load(score)
elif isinstance(score, str) and score.endswith(".txt"):
# IJB meta data like ijbb_template_pair_label.txt
pairs = np.loadtxt(score, dtype=str)
label = pairs[:, 2].astype(np.int)
label = pd.read_csv(score, sep=" ").values[:, 2]
else:
name = name if name is not None else str(id)
score_dict[name] = score
Expand Down Expand Up @@ -457,7 +486,7 @@ def parse_arguments(argv):

default_save_result_name = "IJB_result/{model_name}_{subset}.npz"
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument("-m", "--model_file", type=str, default=None, help="Saved model file path, could be keras / mxnet one")
parser.add_argument("-m", "--model_file", type=str, default=None, help="Saved model file, could be keras h5 / pytorch jit pth / mxnet")
parser.add_argument("-d", "--data_path", type=str, default="./", help="Dataset path")
parser.add_argument("-s", "--subset", type=str, default="IJBB", help="Subset test target, could be IJBB / IJBC")
parser.add_argument("-b", "--batch_size", type=int, default=64, help="Batch size for get_embeddings")
Expand All @@ -484,8 +513,8 @@ def parse_arguments(argv):
print("Please provide -m MODEL_FILE, see `--help` for usage.")
exit(1)
elif args.model_file != None:
if args.model_file.endswith(".h5"):
# Keras model file "model.h5"
if args.model_file.endswith(".h5") or args.model_file.endswith(".pth") or args.model_file.endswith(".pt"):
# Keras model file "model.h5", pytorch model ends with `.pth` or `.pt`
model_name = os.path.splitext(os.path.basename(args.model_file))[0]
else:
# MXNet model file "models/r50-arcface-emore/model,1"
Expand Down

0 comments on commit f62db73

Please sign in to comment.