-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathlosses.py
69 lines (45 loc) · 1.58 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import numpy as np
import pandas as pd
import gzip
import matplotlib.pyplot as plt
import io
import torch
import torch.utils#.data.Dataset
import glob
import imgaug as ia
from imgaug import augmenters as iaa
def dice_loss_classes(inpu, target):
    """Return the soft Dice losses for channels 0 and 1 as a pair.

    For each of the two channels the slice ``[:, c, :, :, :]`` of both
    ``inpu`` and ``target`` is flattened. The loss is then
    ``1 - 2 * sum(p * t) / (sum(p) + sum(t))``.

    Args:
        inpu: 5-D prediction tensor; channel axis is dim 1 (presumably
            probabilities in [0, 1] — confirm against caller).
        target: 5-D target tensor with the same layout as ``inpu``.

    Returns:
        Tuple ``(loss_channel0, loss_channel1)`` of scalar tensors.

    Note:
        If a channel sums to zero in both tensors the ratio is ``0 / 0``
        and the corresponding loss is NaN.
    """
    losses = []
    for channel in (0, 1):
        pred = inpu[:, channel, :, :, :].contiguous().view(-1)
        truth = target[:, channel, :, :, :].contiguous().view(-1)
        overlap = (pred * truth).sum()
        total = pred.sum() + truth.sum()
        losses.append(1 - 2 * (overlap / total))
    return losses[0], losses[1]
def tversky_loss(inpu, target, alpha, beta):
    """Return the soft Tversky losses for channels 0 and 1 as a pair.

    For each channel ``c`` the slices ``[:, c, :, :, :]`` of ``inpu`` (p)
    and ``target`` (t) are flattened and soft counts are formed:

        TP = sum(p * t)
        FP = sum(p * (1 - t))
        FN = sum((1 - p) * t)

    The Tversky index is ``TI = TP / (TP + alpha * FP + beta * FN)`` and the
    loss is ``1 - TI``. With ``alpha = beta = 0.5`` this reduces to the soft
    Dice loss ``1 - 2*TP / (2*TP + FP + FN)``.

    Args:
        inpu: 5-D prediction tensor; channel axis is dim 1 (presumably
            probabilities in [0, 1] — confirm against caller).
        target: 5-D target tensor with the same layout as ``inpu``.
        alpha: Weight on false positives.
        beta: Weight on false negatives.

    Returns:
        Tuple ``(loss_channel0, loss_channel1)`` of scalar tensors.
    """
    losses = []
    for channel in (0, 1):
        ip = inpu[:, channel, :, :, :].contiguous().view(-1)
        tar = target[:, channel, :, :, :].contiguous().view(-1)
        # Count FP/FN on this channel's slice only. The previous version used
        # the full tensors, so both channels shared the same, wrong counts.
        tps = (ip * tar).sum()
        fps = (ip * (1 - tar)).sum()
        fns = ((1 - ip) * tar).sum()
        denom = tps + alpha * fps + beta * fns
        # The denominator already contains TP, so no factor of 2 here: the
        # old `1 - 2*TI` gave a loss in [-1, 1] that did not reduce to Dice.
        losses.append(1 - tps / denom)
    return losses[0], losses[1]
def power_dice(x, alpha, n):
    """Return ``alpha * x ** n``.

    Works for any operands that support ``**`` and ``*`` (Python numbers,
    tensors, arrays).
    """
    powered = pow(x, n)
    return alpha * powered