-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathutils.py
63 lines (55 loc) · 2.24 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import scipy
import numpy as np
import lasagne as nn
import cv2
import config as cfg;
def cross_entropy(pred, label, cap = 1e-5):
    """Return the mean log-probability assigned to each sample's true class.

    Every entry of ``pred`` is clipped into ``[cap, 1 - cap]`` and each row
    is then rescaled to sum to one before the logarithm is taken.

    Note: despite the name, the value returned is the mean log-likelihood,
    i.e. the *negative* of the usual cross-entropy loss (it is <= 0).

    Parameters
    ----------
    pred : array-like, 2-D
        One row of per-class scores per sample.
    label : sequence
        True class of each sample; ``int(label[k])`` selects the column of
        row ``k``.
    cap : float
        Clipping bound that keeps the logarithm finite.
    """
    clipped = np.clip(pred, cap, 1 - cap)
    # Rescale rows so clipping does not leave them summing to something != 1.
    probs = clipped / clipped.sum(axis=1, keepdims=True)
    log_probs = np.log(probs)
    picked = [log_probs[k, int(label[k])] for k in range(len(label))]
    return np.mean(picked)
def binary_entropy(pred, label, cap = 1e-5):
    """Return the mean binary log-likelihood of ``label`` under ``pred``.

    ``pred`` is clipped into ``[cap, 1 - cap]`` so both ``log(p)`` and
    ``log(1 - p)`` stay finite.  Each sample contributes
    ``label * log(p) + (1 - label) * log(1 - p)``; the mean over all
    samples is returned.  As with ``cross_entropy``, this is the negative of
    the usual binary cross-entropy loss (<= 0 for labels in {0, 1}).

    Parameters
    ----------
    pred : array-like
        Predicted probability of the positive outcome per sample.
    label : array-like
        Target per sample; must support ``1 - label`` (e.g. a NumPy array).
    cap : float
        Clipping bound that keeps the logarithms finite.
    """
    p = np.clip(pred, cap, 1 - cap)
    positive = np.log(p) * label
    negative = np.log(1 - p) * (1 - label)
    return np.mean(positive + negative)
def accuracy(pred, label):
    """Return the fraction of samples whose argmax prediction equals the label.

    Sample ``k`` counts as correct when ``int(label[k])`` equals
    ``np.argmax(pred[k])``.

    Parameters
    ----------
    pred : sequence of array-like
        Per-sample score vectors, indexed positionally.
    label : sequence
        True class per sample; its length defines the number of samples.
    """
    total = len(label)
    matches = [int(label[k]) == np.argmax(pred[k]) for k in range(total)]
    return np.sum(matches) * 1.0 / total
def report(pred, label, num_classes=10):
    """Print the mean prediction vector for each true class.

    Prints a ``mean prediction`` header, then one line per class
    ``i`` in ``0 .. num_classes - 1``: the class index followed by the
    column-wise mean of ``pred`` over the samples whose label equals ``i``,
    each value formatted to two decimals.

    Parameters
    ----------
    pred : array-like, 2-D
        One row of per-class scores per sample.
    label : array-like
        True class of each sample (compared with ``==`` against ``i``).
    num_classes : int, optional
        Number of classes to report; defaults to 10, the value that was
        previously hard-coded.

    A class with no samples gives a mean over an empty slice, which NumPy
    returns as ``nan`` (typically with a RuntimeWarning), so that line
    prints ``nan`` values.
    """
    # Coerce so boolean-mask indexing also works for plain Python lists.
    pred = np.asarray(pred)
    label = np.asarray(label)
    print("mean prediction")
    for i in range(num_classes):
        mp = np.mean(pred[label == i], axis=0)
        print("{} {}".format(i, ' '.join(['{:.2f}'.format(x) for x in mp])))
def report_driver(pred, label, driver):
    """Print per-driver accuracy and per-driver mean prediction vectors.

    First section: for every distinct value ``d`` of ``driver`` (in
    ``np.unique`` order), the fraction of that driver's samples whose
    argmax prediction equals ``int(label[k])``.  Second section: for the
    same drivers, the column-wise mean of ``pred`` over that driver's
    samples, formatted to two decimals.

    Parameters
    ----------
    pred : array-like, 2-D
        Per-sample score vectors; indexed positionally and by boolean mask
        (so presumably a NumPy array -- confirm with callers).
    label : sequence
        True class per sample.
    driver : array-like
        Group id per sample; ``driver == d`` is used as a boolean mask, so
        this is assumed to be a NumPy array.
    """
    correct = np.asarray(
        [int(label[k]) == np.argmax(pred[k]) for k in range(len(label))]
    )
    groups = np.unique(driver)

    print("accuracy by driver : ")
    for d in groups:
        mask = driver == d
        acc = 1.0 * np.sum(correct[mask]) / np.sum(mask)
        print("driver {}: {}".format(d, acc))

    print("mean prediction by driver")
    for d in groups:
        means = np.mean(pred[driver == d], axis=0)
        cells = ' '.join(['{:.2f}'.format(v) for v in means])
        print("{} {}".format(d, cells))
def make_submission(filename, name_idx, res, cap = 1e-5):
    """Write predictions to ``filename`` as a submission CSV.

    The file gets the fixed header ``img,c0,...,c9`` followed by one line
    per entry of ``name_idx``: the name, then the values of ``res[i]``
    clipped into ``[cap, 1 - cap]``, comma-separated via ``str``.

    Parameters
    ----------
    filename : str or path-like
        Output path; overwritten if it exists.
    name_idx : sequence
        Image name per row, indexed positionally.
    res : array-like
        Prediction rows; ``res[i]`` is written after ``name_idx[i]``.  The
        header always lists ten columns regardless of the width of ``res``.
    cap : float
        Clipping bound applied to every prediction.
    """
    # Clip before opening so invalid input does not leave a partial file.
    res = np.clip(res, cap, 1 - cap)
    # `with` guarantees the handle is closed even if a write raises.
    with open(filename, 'w') as submit_csv:
        submit_csv.write("img,c0,c1,c2,c3,c4,c5,c6,c7,c8,c9\n")
        # Positional indexing so a too-short `res` raises IndexError
        # instead of being silently truncated as with zip().
        for i in range(len(name_idx)):
            row = ','.join([str(x) for x in res[i]])
            submit_csv.write("{},{}\n".format(name_idx[i], row))
def save_cv(filename, res, label, ids):
    """Write cross-validation predictions to ``filename`` as CSV.

    The file gets the header ``id,label,c0,...,c9`` followed by one line
    per entry of ``label``: ``ids[i]``, ``label[i]``, then the values of
    ``res[i]`` comma-separated via ``str`` (no clipping is applied).

    Parameters
    ----------
    filename : str or path-like
        Output path; overwritten if it exists.
    res : sequence of iterables
        Prediction row per sample.
    label : sequence
        True label per sample; its length defines the number of rows.
    ids : sequence
        Identifier per sample.
    """
    # `with` guarantees the handle is closed even if a write raises.
    with open(filename, 'w') as submit_csv:
        submit_csv.write("id,label,c0,c1,c2,c3,c4,c5,c6,c7,c8,c9\n")
        # Positional indexing keeps the original IndexError on length
        # mismatch rather than silently truncating like zip().
        for i in range(len(label)):
            row = ','.join([str(x) for x in res[i]])
            submit_csv.write("{},{},{}\n".format(ids[i], label[i], row))
def save_cv_cat(filename, res, label, ids, cat):
    """Write single-category cross-validation predictions to ``filename``.

    The file gets the header ``id,label,c<cat>`` followed by one line per
    entry of ``label``: ``ids[i]``, ``label[i]`` and ``res[i]``.

    Parameters
    ----------
    filename : str or path-like
        Output path; overwritten if it exists.
    res : sequence
        One prediction value per sample.
    label : sequence
        True label per sample; its length defines the number of rows.
    ids : sequence
        Identifier per sample.
    cat : object
        Category tag inserted into the last header column name.
    """
    # `with` guarantees the handle is closed even if a write raises.
    with open(filename, 'w') as submit_csv:
        submit_csv.write("id,label,c{}\n".format(cat))
        for i in range(len(label)):
            submit_csv.write("{},{},{}\n".format(ids[i], label[i], res[i]))