-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathtrainer.py
121 lines (97 loc) · 3.87 KB
/
trainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import time
import torch
from tabulate import tabulate
def poison_train(model, loader, criterion, optimizer):
loss_meter = AverageMeter("loss")
acc_meter = AverageMeter("acc")
meter_list = [
loss_meter,
acc_meter,
]
model.train()
gpu = next(model.parameters()).device
start_time = time.time()
for batch_idx, batch in enumerate(loader):
data = batch["img"].cuda(gpu, non_blocking=True)
target = batch["target"].cuda(gpu, non_blocking=True)
output = model(data)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_meter.update(loss.item())
pred = output.argmax(dim=1, keepdim=True)
truth = pred.view_as(target).eq(target)
acc_meter.update((1.0 * torch.sum(truth) / len(truth)).item())
tabulate_step_meter(batch_idx, len(loader), 3, meter_list)
print("Training summary:")
tabulate_epoch_meter(time.time() - start_time, meter_list)
def test(model, loader, criterion):
loss_meter = AverageMeter("loss")
acc_meter = AverageMeter("acc")
meter_list = [loss_meter, acc_meter]
model.eval()
gpu = next(model.parameters()).device
start_time = time.time()
for batch_idx, batch in enumerate(loader):
data = batch["img"].cuda(gpu, non_blocking=True)
target = batch["target"].cuda(gpu, non_blocking=True)
with torch.no_grad():
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
loss = criterion(output, target)
loss_meter.update(loss.item())
pred = output.argmax(dim=1, keepdim=True)
truth = pred.view_as(target).eq(target)
acc_meter.update((torch.sum(truth).float() / len(truth)).item())
tabulate_step_meter(batch_idx, len(loader), 3, meter_list)
tabulate_epoch_meter(time.time() - start_time, meter_list)
def tabulate_step_meter(batch_idx, num_batches, num_intervals, meter_list):
""" Tabulate current average value of meters every `step_interval`.
Args:
batch_idx (int): The batch index in an epoch.
num_batches (int): The number of batch in an epoch.
num_intervals (int): The number of interval to tabulate.
meter_list (list or tuple of AverageMeter): A list of meters.
"""
step_interval = int(num_batches / num_intervals)
if batch_idx % step_interval == 0:
step_meter = {"Iteration": ["{}/{}".format(batch_idx, num_batches)]}
for m in meter_list:
step_meter[m.name] = [m.batch_avg]
table = tabulate(step_meter, headers="keys", tablefmt="github", floatfmt=".5f")
if batch_idx == 0:
table = table.split("\n")
table = "\n".join([table[1]] + table)
else:
table = table.split("\n")[2]
print(table)
def tabulate_epoch_meter(elapsed_time, meter_list):
""" Tabulate total average value of meters every epoch.
Args:
eplased_time (float): The elapsed time of a epoch.
meter_list (list or tuple of AverageMeter): A list of meters.
"""
epoch_meter = {m.name: [m.total_avg] for m in meter_list}
epoch_meter["time"] = [elapsed_time]
table = tabulate(epoch_meter, headers="keys", tablefmt="github", floatfmt=".5f")
table = table.split("\n")
table = "\n".join([table[1]] + table)
print(table)
class AverageMeter(object):
"""Computes and stores the average and current value.
Modified from https://github.com/pytorch/examples/blob/master/imagenet/main.py
"""
def __init__(self, name, fmt=None):
self.name = name
self.reset()
def reset(self):
self.batch_avg = 0
self.total_avg = 0
self.sum = 0
self.count = 0
def update(self, avg, n=1):
self.batch_avg = avg
self.sum += avg * n
self.count += n
self.total_avg = self.sum / self.count