# FRCNN_Resnet_training.py
# Training script for the Faster R-CNN wheat detector.
import os
import sys
import numpy as np
import pandas as pd
from PIL import Image
import torch
from torch import nn
from torch.utils.data import DataLoader
print('torch-version = ',torch.__version__ )
import torchvision
from torchvision import transforms
print('torchvision-version = ',torchvision.__version__)
import models
import my_utils
from datasets import WheatDataset_training
# To pin training to a specific GPU, uncomment:
# torch.cuda.set_device(1)

# Root of the repository checkout. Set the WHEAT_BASE_DIR environment variable
# to point at your own checkout instead of editing this line; when it is unset
# the original author's path is used.
base_dir = os.environ.get("WHEAT_BASE_DIR", "/raid/sahil_g_ma/wheatDetection")

# The training helpers (`engine` and `utils`) are expected under
# <base_dir>/detection, so that directory has to be on sys.path before they
# are imported.
sys.path.append(os.path.join(base_dir, 'detection'))
from engine import train_one_epoch, evaluate  # noqa: E402
import utils  # noqa: E402
# Load the training annotations and clean them up.
# Rows 7/72 and 16/85 refer to the same images and carry inappropriate labels,
# so all four rows are removed. This avoids a key error in the dataloader and
# gives cleaner training data. The drop is by index label; read_csv produces
# a default RangeIndex, so labels equal row positions here.
train_df = (
    pd.read_csv(os.path.join(base_dir, 'train', 'train.csv'))
    .drop(index=[7, 72, 16, 85])
    # Optional: skip images that have no boxes by uncommenting the next line.
    # .loc[lambda df: df.BoxesString != 'no_box']
    .reset_index(drop=True)
)

# Optional: train with pseudo labels by appending rows from a submission file.
# The submission stores boxes under 'PredString', train.csv under 'BoxesString'.
# pseudo_df = pd.read_csv(os.path.join(base_dir, 'submissions', 'final_sub_resnet152fpn3_igbox_pseudo2.csv'))
# pseudo_df = pseudo_df.rename(columns={'PredString': 'BoxesString'})
# train_df = pd.concat([train_df, pseudo_df]).reset_index(drop=True)
# Train on the GPU when one is available, otherwise fall back to the CPU.
# To see how many GPUs are visible, uncomment:
# print('GPU_count=', torch.cuda.device_count())
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the detector via the project's helper with both pretrained flags off,
# then move it to the chosen device (nn.Module.to returns the module itself).
model = models.FRCNN_resnet50_fpn(
    pre_trained=False,
    pretrained_backbone=False,
).to(device)

# Multi-GPU training: uncomment to wrap the model.
# model = nn.DataParallel(model)

# Training on pseudo labels: uncomment to start from a saved checkpoint.
# model.load_state_dict(torch.load(os.path.join(base_dir, 'saved_models', 'frcnn_resnet152fpn_ignore_nobox5_pseudo2.pth'),
#                                  map_location=device))
# Two dataset objects over the same annotations: `dataset` backs the training
# subset and `dataset_test` backs the validation subset.
# NOTE(review): both are constructed identically. If WheatDataset_training
# applies training-time augmentation, validation sees it too — confirm whether
# a separate evaluation dataset/transform was intended.
dataset = WheatDataset_training(train_df, base_dir)
dataset_test = WheatDataset_training(train_df, base_dir)

# Random train/validation split: the last `num_val` indices of a random
# permutation are held out for validation. The permutation is unseeded, so
# the split differs between runs. Splitting at an explicit position (rather
# than slicing with -num_val) keeps num_val = 0 meaning "no validation".
num_val = 100
indices = torch.randperm(len(dataset)).tolist()
split = len(indices) - num_val
dataset_train = torch.utils.data.Subset(dataset, indices[:split])
dataset_validation = torch.utils.data.Subset(dataset_test, indices[split:])

# Training and validation data loaders. The training loader shuffles every
# epoch; the validation loader keeps a fixed order.
data_loader_train = torch.utils.data.DataLoader(
    dataset_train, batch_size=8, shuffle=True, num_workers=2,
    collate_fn=utils.collate_fn)
data_loader_validation = torch.utils.data.DataLoader(
    dataset_validation, batch_size=16, shuffle=False, num_workers=2,
    collate_fn=utils.collate_fn)
# SGD with momentum and weight decay over the parameters that still require
# gradients (any frozen parameters are left out).
params = list(filter(lambda p: p.requires_grad, model.parameters()))
optimizer = torch.optim.SGD(
    params,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
)

# Step decay: the learning rate is multiplied by 0.1 every 6 scheduler steps.
lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=6, gamma=0.1)
# Train for `num_epochs` epochs. Each epoch makes one pass over the training
# loader, takes one learning-rate scheduler step, and runs one evaluation on
# the validation split.
num_epochs = 30
for epoch in range(num_epochs):
    # train for one epoch; print_freq=100 sets the logging interval
    train_one_epoch(model, optimizer, data_loader_train, device, epoch, print_freq=100)
    # update the learning rate
    lr_scheduler.step()
    # evaluate on the held-out validation split
    evaluate(model, data_loader_validation, device=device)

print('done')

# Save the trained weights. Create the output directory first, so that a
# missing saved_models/ folder cannot discard the whole run at the final step.
# NOTE(review): the file name says "scratch20" while num_epochs is 30 — confirm
# the intended checkpoint name before relying on it downstream.
save_dir = os.path.join(base_dir, "saved_models")
os.makedirs(save_dir, exist_ok=True)
torch.save(model.state_dict(), os.path.join(save_dir, "frcnn_resnet50fpn_scratch20.pth"))

# Debugging aid for index errors: iterate the dataset and print each index.
# for idx, (data, image) in enumerate(dataset):
#     print(idx)