-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathevaluation_utils.py
37 lines (27 loc) · 1.09 KB
/
evaluation_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from sklearn.metrics import f1_score
import torch
import torch.nn.functional as F

from train_utils import output_to_class
def get_mean_F1(model, validation_loader, device=None):
    """
    Return the mean per-batch micro-averaged F1 score of ``model`` over a dataloader.

    The model is switched to eval mode (and left in it). Batches are evaluated
    under ``torch.no_grad()`` so no autograd graph is built during evaluation.

    Args:
        model: module to evaluate; called as ``model(data)``.
        validation_loader: iterable of ``(data, target)`` batches that also
            supports ``len()`` (the number of batches).
        device: optional device. When given, each batch is moved there before
            the forward pass. ``None`` (default) uses the batches as yielded
            by the loader, which is the previous behaviour.

    Returns:
        The sum over batches of ``f1_score(..., average='micro') / len(loader)``,
        i.e. an unweighted mean of per-batch scores; ``0`` for an empty loader.
    """
    model.eval()
    n_batches = len(validation_loader)
    mean_f1 = 0
    # Evaluation only: skip gradient tracking to save memory and time.
    with torch.no_grad():
        for data, target in validation_loader:
            if device is not None:
                data, target = data.to(device), target.to(device)
            output = model(data)
            # NOTE(review): presumably output_to_class turns model outputs into
            # predicted class labels (e.g. argmax) - confirm in train_utils.
            batch_f1 = f1_score(
                target.detach().cpu().numpy(),
                output_to_class(output),
                average='micro',
            )
            # Every batch is weighted equally, so a smaller final batch counts
            # as much as a full one.
            mean_f1 += batch_f1 / n_batches
    return mean_f1
def get_loss(model, validation_loader, device):
    """
    Return the mean NLL loss and mean micro-averaged F1 of ``model`` over a dataloader.

    The model is switched to eval mode (and left in it). Batches are moved to
    ``device`` and evaluated under ``torch.no_grad()`` so no autograd graph is
    built during evaluation.

    Args:
        model: module to evaluate; called as ``model(data)``.
        validation_loader: iterable of ``(data, target)`` batches that also
            supports ``len()`` (the number of batches).
        device: device each batch is moved to before the forward pass.

    Returns:
        ``(loss_epoch, mean_f1)``. ``loss_epoch`` is the unweighted mean over
        batches of ``F.nll_loss(output, target).item()``. ``mean_f1`` is the
        unweighted mean over batches of ``f1_score(..., average='micro')``.
        Both are ``0`` for an empty loader.
    """
    model.eval()
    n_batches = len(validation_loader)
    loss_epoch = 0
    mean_f1 = 0
    # Evaluation only: skip gradient tracking to save memory and time.
    with torch.no_grad():
        for data, target in validation_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            # F.nll_loss expects log-probabilities; this assumes the model ends
            # in a log_softmax - confirm against the model definition.
            loss_epoch += F.nll_loss(output, target).item() / n_batches
            # NOTE(review): presumably output_to_class turns model outputs into
            # predicted class labels (e.g. argmax) - confirm in train_utils.
            mean_f1 += f1_score(
                target.detach().cpu().numpy(),
                output_to_class(output),
                average='micro',
            ) / n_batches
    return loss_epoch, mean_f1