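"""MAML on Omniglot: an example combining Torchmeta's data loaders with the
`higher` library for differentiable inner-loop optimization."""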
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import logging
import higher # tested with higher v0.2
from torchmeta.datasets.helpers import omniglot
from torchmeta.utils.data import BatchMetaDataLoader
logger = logging.getLogger(__name__)
def conv3x3(in_channels, out_channels, **kwargs):
    # momentum=1. and track_running_stats=False make BatchNorm use per-batch
    # statistics only, the usual choice for MAML-style adaptation.
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, **kwargs),
        nn.BatchNorm2d(out_channels, momentum=1., track_running_stats=False),
        nn.ReLU(),
        nn.MaxPool2d(2)
    )
class ConvolutionalNeuralNetwork(nn.Module):
    def __init__(self, in_channels, out_features, hidden_size=64):
        super(ConvolutionalNeuralNetwork, self).__init__()
        self.in_channels = in_channels
        self.out_features = out_features
        self.hidden_size = hidden_size

        self.features = nn.Sequential(
            conv3x3(in_channels, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size),
            conv3x3(hidden_size, hidden_size)
        )
        self.classifier = nn.Linear(hidden_size, out_features)

    def forward(self, inputs):
        # With `higher`, adapted parameters live inside the functional model
        # (`fmodel`), so no explicit `params` argument is needed here.
        features = self.features(inputs)
        features = features.view((features.size(0), -1))
        logits = self.classifier(features)
        return logits
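# Illustrative sanity check (assumes 28x28 Omniglot inputs): four conv3x3
# blocks, each halving the spatial size, reduce 28 -> 14 -> 7 -> 3 -> 1, so the
# flattened features have `hidden_size` entries per example.
# >>> net = ConvolutionalNeuralNetwork(1, 5)
# >>> net(torch.zeros(4, 1, 28, 28)).shape
# torch.Size([4, 5])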
def get_accuracy(logits, targets):
    """Compute the accuracy (after adaptation) of MAML on the test/query points.

    Parameters
    ----------
    logits : `torch.FloatTensor` instance
        Outputs/logits of the model on the query points. This tensor has shape
        `(num_examples, num_classes)`.

    targets : `torch.LongTensor` instance
        A tensor containing the targets of the query points. This tensor has
        shape `(num_examples,)`.

    Returns
    -------
    accuracy : `torch.FloatTensor` instance
        Mean accuracy on the query points.
    """
    _, predictions = torch.max(logits, dim=-1)
    return torch.mean(predictions.eq(targets).float())
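# Example (illustrative): argmax predictions [0, 1, 0] against targets
# [0, 1, 1] give 2/3 correct.
# >>> get_accuracy(torch.tensor([[2., 0.], [0., 1.], [3., 0.]]),
# ...              torch.tensor([0, 1, 1]))
# tensor(0.6667)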
def train(args):
    logger.warning('This script is an example to showcase the data-loading '
        'features of Torchmeta in conjunction with using higher to '
        'make models "unrollable" and optimizers differentiable, '
        'and as such has been very lightly tested.')

    dataset = omniglot(args.folder,
                       shots=args.num_shots,
                       ways=args.num_ways,
                       shuffle=True,
                       test_shots=15,
                       meta_train=True,
                       download=args.download)
    dataloader = BatchMetaDataLoader(dataset,
                                     batch_size=args.batch_size,
                                     shuffle=True,
                                     num_workers=args.num_workers)

    model = ConvolutionalNeuralNetwork(1,
                                       args.num_ways,
                                       hidden_size=args.hidden_size)
    model.to(device=args.device)
    model.train()

    # The inner (adaptation) loop uses plain SGD; `higher` makes its steps
    # differentiable. The outer (meta) loop updates the initialization via Adam.
    inner_optimizer = torch.optim.SGD(model.parameters(), lr=args.step_size)
    meta_optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Training loop
    with tqdm(dataloader, total=args.num_batches) as pbar:
        for batch_idx, batch in enumerate(pbar):
            model.zero_grad()

            train_inputs, train_targets = batch['train']
            train_inputs = train_inputs.to(device=args.device)
            train_targets = train_targets.to(device=args.device)

            test_inputs, test_targets = batch['test']
            test_inputs = test_inputs.to(device=args.device)
            test_targets = test_targets.to(device=args.device)

            outer_loss = torch.tensor(0., device=args.device)
            accuracy = torch.tensor(0., device=args.device)
            for train_input, train_target, test_input, test_target in zip(
                    train_inputs, train_targets, test_inputs, test_targets):
                # copy_initial_weights=False keeps the gradient path back to
                # the initial parameters, which MAML differentiates through.
                with higher.innerloop_ctx(model, inner_optimizer,
                        copy_initial_weights=False) as (fmodel, diffopt):
                    train_logit = fmodel(train_input)
                    inner_loss = F.cross_entropy(train_logit, train_target)

                    # One differentiable adaptation step on the support set.
                    diffopt.step(inner_loss)

                    test_logit = fmodel(test_input)
                    outer_loss += F.cross_entropy(test_logit, test_target)

                    with torch.no_grad():
                        accuracy += get_accuracy(test_logit, test_target)

            outer_loss.div_(args.batch_size)
            accuracy.div_(args.batch_size)

            outer_loss.backward()
            meta_optimizer.step()

            pbar.set_postfix(accuracy='{0:.4f}'.format(accuracy.item()))
            # Stop after exactly `num_batches` meta-batches.
            if batch_idx + 1 >= args.num_batches:
                break

    # Save model
    if args.output_folder is not None:
        filename = os.path.join(args.output_folder, 'maml_omniglot_'
            '{0}shot_{1}way.th'.format(args.num_shots, args.num_ways))
        with open(filename, 'wb') as f:
            state_dict = model.state_dict()
            torch.save(state_dict, f)
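# Note: each task above takes a single adaptation step. A multi-step variant
# (a sketch; `num_steps` is a hypothetical extra hyperparameter, not an
# existing flag) would repeat the differentiable update inside innerloop_ctx:
#     for _ in range(num_steps):
#         inner_loss = F.cross_entropy(fmodel(train_input), train_target)
#         diffopt.step(inner_loss)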
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser('Model-Agnostic Meta-Learning (MAML)')

    parser.add_argument('folder', type=str,
        help='Path to the folder the data is downloaded to.')
    parser.add_argument('--num-shots', type=int, default=5,
        help='Number of examples per class (k in "k-shot", default: 5).')
    parser.add_argument('--num-ways', type=int, default=5,
        help='Number of classes per task (N in "N-way", default: 5).')
    parser.add_argument('--step-size', type=float, default=0.4,
        help='Step-size for the gradient step for adaptation (default: 0.4).')
    parser.add_argument('--hidden-size', type=int, default=64,
        help='Number of channels for each convolutional layer (default: 64).')
    parser.add_argument('--output-folder', type=str, default=None,
        help='Path to the output folder for saving the model (optional).')
    parser.add_argument('--batch-size', type=int, default=16,
        help='Number of tasks in a mini-batch of tasks (default: 16).')
    parser.add_argument('--num-batches', type=int, default=100,
        help='Number of batches the model is trained over (default: 100).')
    parser.add_argument('--num-workers', type=int, default=1,
        help='Number of workers for data loading (default: 1).')
    parser.add_argument('--download', action='store_true',
        help='Download the Omniglot dataset in the data folder.')
    parser.add_argument('--use-cuda', action='store_true',
        help='Use CUDA if available.')

    args = parser.parse_args()
    args.device = torch.device('cuda' if args.use_cuda
        and torch.cuda.is_available() else 'cpu')

    train(args)
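# Example invocation (assuming the dataset should go into ./data):
#   python train.py data --num-shots 5 --num-ways 5 --download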