-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathhw4_eval.py
59 lines (42 loc) · 1.8 KB
/
hw4_eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import argparse
from modules.reader import getVideoList
def read_labels(path):
    """Read one label per line from the text file at *path*.

    Leading/trailing whitespace (including the newline) is stripped from each
    line. Blank lines are kept as empty strings, so the line count matches
    ``readlines()``.
    """
    with open(path, "r", encoding="utf-8") as fin:
        return [line.strip() for line in fin]


def count_matches(targets, preds):
    """Return how many positions of *targets* and *preds* hold the same label.

    Both sides are converted with ``str`` and stripped before comparison, so
    stray whitespace in either input does not count as a mismatch.

    Raises:
        ValueError: if the two sequences have different lengths.
    """
    if len(targets) != len(preds):
        # Explicit raise instead of ``assert``: asserts vanish under ``python -O``
        # and ``zip`` would then silently truncate the longer sequence.
        raise ValueError("Number of ground-truth and predicts not same!")
    return sum(
        str(target).strip() == str(pred).strip()
        for target, pred in zip(targets, preds)
    )


def accuracy(n_correct, n_total, source):
    """Return ``n_correct / n_total``.

    Raises:
        ValueError: if *n_total* is zero (accuracy is undefined). *source*
            names the offending file or category in the message.
    """
    if n_total == 0:
        raise ValueError(
            "No ground-truth labels in {}; accuracy is undefined.".format(source)
        )
    return n_correct / n_total


def evaluate_per_category(gt_dir, pred_dir):
    """Score every ground-truth file in *gt_dir* against its namesake in *pred_dir*.

    Entries of *gt_dir* that are not regular files are skipped.

    Returns:
        list of ``(category, n_correct, n_total)`` tuples, sorted by file name.

    Raises:
        ValueError: if a ground-truth file and its prediction file have a
            different number of lines.
    """
    results = []
    for category in sorted(os.listdir(gt_dir)):
        gt_path = os.path.join(gt_dir, category)
        if not os.path.isfile(gt_path):
            continue
        targets = read_labels(gt_path)
        preds = read_labels(os.path.join(pred_dir, category))
        results.append((category, count_matches(targets, preds), len(targets)))
    return results


def main():
    """Parse the command line and print HW4 accuracy for the chosen problem.

    Problems 1 and 2 print a single accuracy. Problem 3 prints the accuracy
    of each category file, then the mean of those per-category accuracies,
    then the accuracy pooled over all lines of all categories.
    """
    # ``description=`` is required here: the first positional argument of
    # ArgumentParser is ``prog``, which would replace the program name in usage.
    parser = argparse.ArgumentParser(description="Evaluation of HW4.")
    parser.add_argument(
        "problem", type=str, choices=["1", "2", "3"], help="Problem 1, 2 or 3."
    )
    parser.add_argument(
        "gt", type=str, help="Ground-truth file (directory for problem 3)."
    )
    parser.add_argument(
        "pred", type=str, help="Predicted file (directory for problem 3)."
    )
    args = parser.parse_args()

    if args.problem != "3":
        # NOTE(review): assumes getVideoList returns a mapping whose
        # "Action_labels" entry is a sequence of labels; helper not visible here.
        gt = getVideoList(args.gt)["Action_labels"]
        preds = read_labels(args.pred)
        print(accuracy(count_matches(gt, preds), len(gt), args.gt))
        return

    results = evaluate_per_category(args.gt, args.pred)
    if not results:
        raise ValueError("No ground-truth files found in {}.".format(args.gt))

    per_category = []
    for category, n_correct, n_total in results:
        category_acc = accuracy(n_correct, n_total, category)
        per_category.append(category_acc)
        print(category, category_acc)
    # Mean of per-category accuracies (each category weighted equally).
    print(sum(per_category) / len(per_category))
    # Pooled accuracy; the total is non-zero because every category was checked above.
    total_correct = sum(n_correct for _, n_correct, _ in results)
    total_lines = sum(n_total for _, _, n_total in results)
    print(total_correct / total_lines)


if __name__ == "__main__":
    main()