-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathscore_losses.py
171 lines (134 loc) · 6.56 KB
/
score_losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
import torch
import torch.nn
import numpy as np
import math
from utils.util import add_dimensions
def replace(y, p, n_classes, device):
    """Return ``y`` with each entry independently set to ``n_classes`` with probability ``p``.

    Args:
        y: Label tensor.
        p: Per-entry replacement probability.
        n_classes: Value written into the replaced entries (used as a null label).
        device: Device on which the helper tensors are created.

    Returns:
        A new tensor of the same shape as ``y``; ``y`` itself is not modified.
    """
    ones = torch.ones_like(y, device=device)
    # One Bernoulli(p) draw per entry decides whether that entry is replaced.
    drop_mask = torch.bernoulli(ones * p).bool()
    return torch.where(drop_mask, ones * n_classes, y)
def dropout_label_for_cfg_training(y, n_noise_samples, n_classes, p, device):
    """Randomly drop labels to a null class, then repeat each label per noise draw.

    Each label is independently replaced by the index ``n_classes`` with
    probability ``p`` (label dropout for classifier-free-guidance training).
    Then every label is repeated ``n_noise_samples`` times consecutively.

    Args:
        y: Label tensor, or ``None`` for unconditional training.
        n_noise_samples: Number of consecutive copies made of each label.
        n_classes: Number of real classes; its value is used as the null label.
        p: Probability of dropping each label.
        device: Device on which the helper tensors are created.

    Returns:
        A flattened label tensor of length ``y.numel() * n_noise_samples``, or
        ``None`` when ``y`` is ``None``.

    Raises:
        ValueError: If ``y`` is given but ``n_classes`` is ``None``.
    """
    if y is None:
        return None
    if n_classes is None:
        raise ValueError(
            "n_classes must be set when labels are provided: dropped labels "
            "are replaced by the null index n_classes.")
    with torch.no_grad():
        drop_mask = torch.bernoulli(
            p * torch.ones_like(y, device=device)).bool()
        null_label = n_classes * torch.ones_like(y, device=device)
        y = torch.where(drop_mask, null_label, y)
        # Without `dim`, repeat_interleave flattens y and emits each label
        # n_noise_samples times in a row.
        return y.repeat_interleave(n_noise_samples)
class VPSDELoss:
    """Weighted denoising loss with a VP-style noise schedule.

    Time ``t`` is drawn uniformly from ``[eps_t, 1)`` and mapped to the noise
    level ``sigma(t) = sqrt(exp(0.5 * beta_d * t**2 + beta_min * t) - 1)``.
    The model is trained to recover the clean sample from ``x + sigma * noise``
    with per-sample weight ``1 / sigma**2``.

    Args:
        beta_min: Linear coefficient of the schedule exponent.
        beta_d: Quadratic coefficient of the schedule exponent (multiplied by 0.5).
        eps_t: Smallest sampled time; keeps ``sigma`` (and thus the weight) finite.
        n_noise_samples: Number of independent noise draws per input sample.
        label_unconditioning_prob: Probability of replacing a label with the
            null index ``n_classes`` during training.
        n_classes: Number of real classes; required when labels are passed.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, beta_min, beta_d, eps_t, n_noise_samples=1, label_unconditioning_prob=.1, n_classes=None, **kwargs):
        self.beta_min = beta_min
        self.beta_d = beta_d
        self.eps_t = eps_t
        self.n_noise_samples = n_noise_samples
        self.label_unconditioning_prob = label_unconditioning_prob
        self.n_classes = n_classes

    def _sigma(self, t):
        """Map time tensor ``t`` to noise level ``sqrt(exp(0.5*beta_d*t**2 + beta_min*t) - 1)``."""
        # expm1 avoids catastrophic cancellation of exp(a) - 1 for small a.
        # With t near eps_t (e.g. 1e-5), a is about 1e-6, and float32 exp() - 1
        # would be off by several percent.
        return torch.expm1(.5 * self.beta_d * t ** 2. + self.beta_min * t).sqrt()

    def get_loss(self, model, x, y):
        """Return the weighted denoising loss of each sample in the batch.

        Args:
            model: Callable ``model(x_noisy, sigma, y)``. It receives
                ``x.shape[0] * n_noise_samples`` samples at once and must return
                a tensor reshapeable to that many samples of shape
                ``x.shape[1:]``, which is compared against the clean data.
            x: Clean data batch; the first dimension is the batch dimension.
            y: Labels for the batch (presumably shape ``(x.shape[0],)``), or ``None``.

        Returns:
            Tensor of shape ``(x.shape[0],)``: each sample's loss averaged over
            its noise draws and all non-batch elements.
        """
        y = dropout_label_for_cfg_training(
            y, self.n_noise_samples, self.n_classes,
            self.label_unconditioning_prob, x.device)
        # One time per (sample, noise draw) pair, uniform on [eps_t, 1).
        t = (1. - self.eps_t) * torch.rand(
            (x.shape[0], self.n_noise_samples), device=x.device) + self.eps_t
        sigma = self._sigma(t)
        # NOTE(review): add_dimensions presumably appends len(x.shape) - 1
        # trailing singleton dims so sigma broadcasts against x_repeated.
        sigma = add_dimensions(sigma, len(x.shape) - 1)
        # x_repeated has shape (B, n_noise_samples, *x.shape[1:]).
        x_repeated = x.unsqueeze(1).repeat_interleave(
            self.n_noise_samples, dim=1)
        x_noisy = x_repeated + sigma * torch.randn_like(x_repeated)
        w = 1. / sigma ** 2.
        # Fold the noise-draw axis into the batch axis for one model call.
        pred = model(x_noisy.reshape(-1, *x.shape[1:]),
                     sigma.reshape(-1, *sigma.shape[2:]), y).reshape(
            x.shape[0], self.n_noise_samples, *x.shape[1:])
        loss = w * (pred - x_repeated) ** 2.
        return torch.mean(loss.reshape(loss.shape[0], -1), dim=-1)
class VESDELoss:
    """Weighted denoising loss with log-uniformly sampled noise levels.

    ``log(sigma)`` is drawn uniformly between ``log(sigma_min)`` and
    ``log(sigma_max)``. The model is trained to recover the clean sample from
    ``x + sigma * noise`` with per-sample weight ``1 / sigma**2``.

    Args:
        sigma_min: Lower end of the sampled noise range.
        sigma_max: Upper end of the sampled noise range.
        n_noise_samples: Number of independent noise draws per input sample.
        label_unconditioning_prob: Probability of replacing a label with the
            null index ``n_classes`` during training.
        n_classes: Number of real classes; required when labels are passed.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, sigma_min, sigma_max, n_noise_samples=1,
                 label_unconditioning_prob=.1, n_classes=None, **kwargs):
        self.sigma_min, self.sigma_max = sigma_min, sigma_max
        self.n_noise_samples = n_noise_samples
        self.label_unconditioning_prob = label_unconditioning_prob
        self.n_classes = n_classes

    def get_loss(self, model, x, y):
        """Compute the per-sample weighted denoising loss.

        Args:
            model: Callable ``model(x_noisy, sigma, labels)`` applied to
                ``x.shape[0] * n_noise_samples`` samples at once. Its output is
                reshaped to ``(B, n_noise_samples, *x.shape[1:])`` and compared
                against the clean data.
            x: Clean data batch; the first dimension is the batch dimension.
            y: Labels for the batch, or ``None``.

        Returns:
            Tensor of shape ``(x.shape[0],)`` with each sample's loss averaged
            over its noise draws and all non-batch elements.
        """
        k = self.n_noise_samples
        batch, item_shape = x.shape[0], x.shape[1:]
        labels = dropout_label_for_cfg_training(
            y, k, self.n_classes, self.label_unconditioning_prob, x.device)

        # Log-uniform noise levels, one per (sample, draw) pair.
        log_lo = np.log(self.sigma_min)
        log_span = np.log(self.sigma_max) - log_lo
        u = torch.rand((batch, k), device=x.device)
        # NOTE(review): add_dimensions presumably appends trailing singleton
        # dims so sigma broadcasts against the repeated data.
        sigma = add_dimensions((log_span * u + log_lo).exp(), len(x.shape) - 1)

        clean = x.unsqueeze(1).repeat_interleave(k, dim=1)
        noisy = clean + sigma * torch.randn_like(clean)
        weight = 1. / sigma ** 2.

        flat_pred = model(noisy.reshape(-1, *item_shape),
                          sigma.reshape(-1, *sigma.shape[2:]), labels)
        pred = flat_pred.reshape(batch, k, *item_shape)
        per_element = weight * (pred - clean) ** 2.
        return per_element.reshape(batch, -1).mean(dim=-1)
class VLoss:
    """Weighted denoising loss with noise levels bounded by a log-SNR range.

    Time ``t`` is drawn uniformly from ``[eps_min, eps_max)`` and mapped to
    ``sigma(t) = tan(pi * t / 2)``. The endpoints are chosen so that
    ``logsnr = -2 * log(sigma)`` lies in ``(logsnr_min, logsnr_max]``. The model
    is trained to recover the clean sample from ``x + sigma * noise`` with
    per-sample weight ``(sigma**2 + 1) / sigma**2``.

    Args:
        logsnr_min: Lowest log-SNR (largest noise) to sample.
        logsnr_max: Highest log-SNR (smallest noise) to sample.
        n_noise_samples: Number of independent noise draws per input sample.
        label_unconditioning_prob: Probability of replacing a label with the
            null index ``n_classes`` during training.
        n_classes: Number of real classes; required when labels are passed.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, logsnr_min, logsnr_max, n_noise_samples=1, label_unconditioning_prob=.1, n_classes=None, **kwargs):
        self.logsnr_min = logsnr_min
        self.logsnr_max = logsnr_max
        # High log-SNR means small noise, i.e. small t, so the bounds swap roles.
        self.eps_min = self._t_given_logsnr(logsnr_max)
        self.eps_max = self._t_given_logsnr(logsnr_min)
        self.n_noise_samples = n_noise_samples
        self.label_unconditioning_prob = label_unconditioning_prob
        self.n_classes = n_classes

    def _t_given_logsnr(self, logsnr):
        """Return the time ``t`` at which ``sigma(t)**2 == exp(-logsnr)``."""
        return 2. * np.arccos(1. / np.sqrt(1. + np.exp(-logsnr))) / np.pi

    def _sigma(self, t):
        """Map time tensor ``t`` in ``[0, 1)`` to noise level ``tan(pi * t / 2)``."""
        # Mathematically equal to sqrt(cos(pi t / 2)**-2 - 1), but that form
        # cancels catastrophically. For small t (high log-SNR), float32 cos()
        # rounds to exactly 1.0, which gives sigma == 0 and an infinite weight.
        return torch.tan(np.pi * t / 2.)

    def get_loss(self, model, x, y):
        """Return the weighted denoising loss of each sample in the batch.

        Args:
            model: Callable ``model(x_noisy, sigma, y)``. It receives
                ``x.shape[0] * n_noise_samples`` samples at once and must return
                a tensor reshapeable to that many samples of shape
                ``x.shape[1:]``, which is compared against the clean data.
            x: Clean data batch; the first dimension is the batch dimension.
            y: Labels for the batch (presumably shape ``(x.shape[0],)``), or ``None``.

        Returns:
            Tensor of shape ``(x.shape[0],)``: each sample's loss averaged over
            its noise draws and all non-batch elements.
        """
        y = dropout_label_for_cfg_training(
            y, self.n_noise_samples, self.n_classes,
            self.label_unconditioning_prob, x.device)
        # One time per (sample, noise draw) pair, uniform on [eps_min, eps_max).
        t = (self.eps_max - self.eps_min) * torch.rand(
            (x.shape[0], self.n_noise_samples), device=x.device) + self.eps_min
        sigma = self._sigma(t)
        # NOTE(review): add_dimensions presumably appends len(x.shape) - 1
        # trailing singleton dims so sigma broadcasts against x_repeated.
        sigma = add_dimensions(sigma, len(x.shape) - 1)
        # x_repeated has shape (B, n_noise_samples, *x.shape[1:]).
        x_repeated = x.unsqueeze(1).repeat_interleave(
            self.n_noise_samples, dim=1)
        x_noisy = x_repeated + sigma * torch.randn_like(x_repeated)
        w = (sigma ** 2. + 1.) / sigma ** 2.
        # Fold the noise-draw axis into the batch axis for one model call.
        pred = model(x_noisy.reshape(-1, *x.shape[1:]),
                     sigma.reshape(-1, *sigma.shape[2:]), y).reshape(
            x.shape[0], self.n_noise_samples, *x.shape[1:])
        loss = w * (pred - x_repeated) ** 2.
        return torch.mean(loss.reshape(loss.shape[0], -1), dim=-1)
class EDMLoss:
    """Weighted denoising loss with log-normally sampled noise levels.

    ``log(sigma)`` is drawn from a normal distribution with mean ``p_mean`` and
    standard deviation ``p_std``. The model is trained to recover the clean
    sample from ``x + sigma * noise`` with per-sample weight
    ``(sigma**2 + sigma_data**2) / (sigma * sigma_data)**2``.

    Args:
        p_mean: Mean of ``log(sigma)``.
        p_std: Standard deviation of ``log(sigma)``.
        sigma_data: Data scale used in the weighting (default ``sqrt(1/3)``).
        n_noise_samples: Number of independent noise draws per input sample.
        label_unconditioning_prob: Probability of replacing a label with the
            null index ``n_classes`` during training.
        n_classes: Number of real classes; required when labels are passed.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, p_mean, p_std, sigma_data=math.sqrt(1. / 3),
                 n_noise_samples=1, label_unconditioning_prob=.1,
                 n_classes=None, **kwargs):
        self.p_mean, self.p_std = p_mean, p_std
        self.sigma_data = sigma_data
        self.n_noise_samples = n_noise_samples
        self.label_unconditioning_prob = label_unconditioning_prob
        self.n_classes = n_classes

    def get_loss(self, model, x, y):
        """Compute the per-sample weighted denoising loss.

        Args:
            model: Callable ``model(x_noisy, sigma, labels)`` applied to
                ``x.shape[0] * n_noise_samples`` samples at once. Its output is
                reshaped to ``(B, n_noise_samples, *x.shape[1:])`` and compared
                against the clean data.
            x: Clean data batch; the first dimension is the batch dimension.
            y: Labels for the batch, or ``None``.

        Returns:
            Tensor of shape ``(x.shape[0],)`` with each sample's loss averaged
            over its noise draws and all non-batch elements.
        """
        k = self.n_noise_samples
        batch, item_shape = x.shape[0], x.shape[1:]
        labels = dropout_label_for_cfg_training(
            y, k, self.n_classes, self.label_unconditioning_prob, x.device)

        # Log-normal noise levels, one per (sample, draw) pair.
        z = torch.randn((batch, k), device=x.device)
        # NOTE(review): add_dimensions presumably appends trailing singleton
        # dims so sigma broadcasts against the repeated data.
        sigma = add_dimensions((self.p_mean + self.p_std * z).exp(),
                               len(x.shape) - 1)

        clean = x.unsqueeze(1).repeat_interleave(k, dim=1)
        noisy = clean + sigma * torch.randn_like(clean)
        weight = (sigma ** 2. + self.sigma_data ** 2.) / \
            (sigma * self.sigma_data) ** 2.

        flat_pred = model(noisy.reshape(-1, *item_shape),
                          sigma.reshape(-1, *sigma.shape[2:]), labels)
        pred = flat_pred.reshape(batch, k, *item_shape)
        per_element = weight * (pred - clean) ** 2.
        return per_element.reshape(batch, -1).mean(dim=-1)