-- train.lua
require 'cunn'
local optim = require 'optim'
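-- Learning-rate schedule: each entry is {startEpoch, endEpoch, learningRate},
-- applied when startEpoch < epoch <= endEpoch (see Trainer:learningRate)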
local lr_policy = {
   {0, 50, 2.5e-4},
   {50, 70, 1e-4},
   {70, 90, 5e-5},
   {90, 100, 1e-5},
   {100, 110, 5e-6}
}
local M = {}
local Trainer = torch.class('Trainer', M)
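-- Wraps a model/criterion pair. opt is expected to provide LR, momentum,
-- weightDecay, nGPU and nStacks; optimState may be supplied by the caller,
-- otherwise a fresh RMSprop configuration is built from opt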
function Trainer:__init(model, criterion, opt, optimState)
   self.model = model
   self.criterion = criterion
   self.optimState = optimState or {
      learningRate = opt.LR,
      learningRateDecay = 0.0,
      momentum = opt.momentum,
      epsilon = 1e-8,
      weightDecay = opt.weightDecay,
   }
   self.opt = opt
   self.params, self.gradParams = model:getParameters()
end
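-- Runs one epoch of RMSprop training over the dataloader and returns the
-- average loss; accuracy is not computed here, so the second return value is 0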
function Trainer:train(epoch, dataloader)
   local avgLoss, avgAcc = 0.0, 0.0
   self.optimState.learningRate = self:learningRate(epoch)
   local timer = torch.Timer()
   local dataTimer = torch.Timer()

   -- forward/backward are done explicitly below, so feval only hands
   -- optim.rmsprop the already-computed loss and gradients
   local function feval()
      return self.criterion.output, self.gradParams
   end

   local trainSize = dataloader:size()
   local N = 0

   print('=> Training epoch # ' .. epoch)
   self.model:training()

   for n, sample in dataloader:run() do
      local dataTime = dataTimer:time().real

      self:copyInputs(sample)

      self.model:zeroGradParameters()
      local output = self.model:forward(self.input)
      local loss = self.criterion:forward(output, self.label)
      self.criterion:backward(output, self.label)
      self.model:backward(self.input, self.criterion.gradInput)
      optim.rmsprop(feval, self.params, self.optimState)

      avgLoss = avgLoss + loss
      N = N + 1

      print((' | Epoch: [%d][%d/%d] Time %.3f Data %.3f Err %1.4f'):format(
         epoch, n, trainSize, timer:time().real, dataTime, loss))

      -- check that the storage didn't get changed due to an unfortunate getParameters call
      assert(self.params:storage() == self.model:parameters()[1]:storage())

      collectgarbage()
      timer:reset()
      dataTimer:reset()
   end

   return avgLoss / N, avgAcc / N
end
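-- Looks up the learning rate for the given epoch in lr_policy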
function Trainer:learningRate(epoch)
   for i = 1, #lr_policy do
      if epoch > lr_policy[i][1] and epoch <= lr_policy[i][2] then
         print(string.format('Using lr_rate: %f', lr_policy[i][3]))
         return lr_policy[i][3]
      end
   end
   -- epoch falls outside every range in lr_policy: keep the last rate rather
   -- than returning nil, which rmsprop would silently replace with its default
   return lr_policy[#lr_policy][3]
end
function Trainer:copyInputs(sample)
   -- Copies the input to a CUDA tensor, if using 1 GPU, or to pinned memory,
   -- if using DataParallelTable. The target is always copied to a CUDA tensor
   self.input = self.input or (self.opt.nGPU == 1
      and torch.CudaTensor()
      or cutorch.createCudaHostTensor())
   -- keep the label buffer on self instead of leaking it as a global
   self.labelBuffer = self.labelBuffer or torch.CudaTensor()
   self.input:resize(sample.input:size()):copy(sample.input)
   self.labelBuffer:resize(sample.label:size()):copy(sample.label)

   -- Adjust the label according to the network architecture: stacked models
   -- receive one copy of the label per stack
   if self.opt.nStacks > 1 then
      local tempLabel = {}
      for i = 1, self.opt.nStacks do
         table.insert(tempLabel, self.labelBuffer)
      end
      self.label = tempLabel
   else
      self.label = self.labelBuffer
   end
end
return M.Trainer
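-- Minimal usage sketch. Assumptions: `model`, `criterion`, `opt`, `dataLoader`
-- and `opt.nEpochs` are hypothetical objects/fields built elsewhere in the
-- project; the dataloader must expose :size() and :run() as used in train().
--
--   local Trainer = require 'train'
--   local trainer = Trainer(model, criterion, opt)
--   for epoch = 1, opt.nEpochs do
--      local trainLoss = trainer:train(epoch, dataLoader)
--   end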