-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathmain_train.lua
40 lines (36 loc) · 1.36 KB
/
main_train.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
-- Require the detection package
require 'detection'
-- Paths: everything below is derived from the global `config` table.
local dataset_name = config.dataset
local image_set = config.train_img_set
local param_path = config.pre_trained_file   -- pretrained weights
local model_path = config.model_def          -- network definition
local dataset_dir = paths.concat(config.dataset_path, dataset_name)

-- Selective-search proposals are read from a fixed directory, one .mat file
-- per (dataset, image set) pair, named "<dataset>_<image_set>.mat".
local ss_dir = './data/datasets/selective_search_data/'
local ss_basename = dataset_name .. '_' .. image_set .. '.mat'
local ss_file = paths.concat(ss_dir, ss_basename)
-- Loading the dataset.
-- `model_opt` is assigned without `local` (i.e. it is a global). NOTE(review):
-- presumably it is read elsewhere (e.g. inside the detection package); confirm
-- before making it local. Its `nclass` field is filled from dataset:nclass()
-- when the network is created, so no per-dataset class count is hard-coded here.
local dataset
model_opt = {}
if dataset_name == 'MSCOCO' then
   print('MSCOCO ' .. image_set)
   dataset = detection.DataSetCoco({
      image_set = image_set,
      datadir = dataset_dir,
      test_mode = false,
   })
else
   print('VOC ' .. image_set)
   -- Default to VOC2007 unless the dataset name mentions 2012.
   -- Plain-text search (4th arg `true`), so no Lua pattern semantics apply.
   local year = 2007
   if dataset_name:find('2012', 1, true) then
      year = 2012
   end
   dataset = detection.DataSetPascal({
      image_set = image_set,
      datadir = dataset_dir,
      roidbdir = ss_dir,
      roidbfile = ss_file,
      year = year,
   })
end
-- Creating the detection network.
-- The options go into the global `model_opt` table, and the network is then
-- built from the model definition plus the pretrained parameters.
-- `network` is intentionally a global: detection.NetworkWrapper() is
-- constructed with no arguments and is expected to operate on it.
do
   local opt = model_opt
   opt.test = false                               -- training, not testing
   opt.nclass = dataset:nclass()                  -- class count taken from the dataset
   opt.fine_tunning = not config.resume_training  -- fine-tune unless resuming; key spelling kept (presumably what detection.Net reads)
   network = detection.Net(model_path, param_path, opt)
end
-- Creating the network wrapper.
-- The wrapper takes no arguments. It adds train/test functionality on top of
-- the global `network` and exposes trainNetwork(dataset).
local network_wrapper = detection.NetworkWrapper()

-- Train the network on the loaded dataset.
print('Training the network...')
network_wrapper:trainNetwork(dataset)