-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_and_test.py
47 lines (43 loc) · 1.25 KB
/
train_and_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from feature_extractor import *
from classifier import *
from dataset import *
import numpy as np
import pickle
import h5py
# Load the train / test image folders; both are built with the same
# (256, 256) size arguments to DirDataset.
print('loading datasets and models...')
train_ds = DirDataset('demo1-dataset/train', 256, 256)
test_ds = DirDataset('demo1-dataset/test1', 256, 256)

# RNG seed re-applied before every grid entry so runs are comparable.
seed = 42

print('grid searching best combination...')
# Each grid entry is a Python expression string: a classifier template whose
# '{}' placeholder is filled with a feature-extractor expression. The strings
# are eval'd later. Disabled options stay commented out for quick toggling.
classifier_templates = [
    'KNNClassifier({}, k=1)',
    # 'SklearnClassifier({}, LogisticRegression(max_iter=1000))',
    # 'SklearnClassifier({}, GaussianNB())',
    # 'GMMClassifier({}, n_components=5, covariance_type="full")',
    # 'RandomForest({})'
]
extractor_exprs = [
    'Resnet("resnet18")',
    # NOTE(review): this extractor is listed twice, so the identical
    # KNN + resnet18 configuration is trained, tested and pickled twice.
    # Confirm this is intentional (e.g. comparing a cold vs. warm cache run).
    'Resnet("resnet18")',
    # 'Resnet("resnet34")',
    # 'Resnet("resnet50")',
]
# Cartesian product in classifier-major order: all extractors for the first
# template, then all extractors for the next one.
grid = [
    template.format(expr)
    for template in classifier_templates
    for expr in extractor_exprs
]
# Alternative experiment: replace the whole grid with one voting ensemble.
#grid = []
#grid.append('''\
#VoteEnsembleClassifier(
#    KNNClassifier(Resnet("resnet50"), k=1),
#    SklearnClassifier(Resnet("resnet18"), LogisticRegression(max_iter=1000)),
#    SklearnClassifier(Resnet("resnet50"), GaussianNB()),
#)''')
# Shared HDF5 cache handed to every classifier's train/test call.
# Without an explicit mode, h5py >= 3.0 opens files read-only ('r'). That
# fails when cache.h5 does not exist yet and rejects writes into the cache.
# 'a' (read/write, create if missing) restores the pre-3.0 default.
# The `with` block guarantees the file is flushed and closed, even if a run
# raises partway through the grid.
with h5py.File('cache.h5', 'a') as cache:
    for i, g in enumerate(grid):
        # Re-seed before every entry so each configuration starts from the
        # same NumPy RNG state.
        np.random.seed(seed)
        print(i, g, flush=True)
        # `g` is one of the hard-coded expression strings in `grid` (never
        # external input), so eval is acceptable here. Do not feed it
        # user-supplied text.
        cfer = eval(g)
        cfer.train(train_ds, cache=cache)
        cfer.test(test_ds, cache=cache)
        # Persist the classifier to cfer<i>.pkl. The context manager closes
        # the handle so the pickle is fully written to disk.
        with open(f'cfer{i}.pkl', 'wb') as fh:
            pickle.dump(cfer, fh)
        print()