cocob.py
import torch
import torch.optim as optim

###########################################################################
# Training Deep Networks without Learning Rates Through Coin Betting
# Paper: https://arxiv.org/abs/1705.07795
#
# State tensors are allocated with torch.zeros_like / torch.ones_like, so
# they live on the same device as each parameter and the optimizer runs on
# CPU and GPU alike.
###########################################################################
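# ---------------------------------------------------------------------------
# Sketch of the per-coordinate update performed in step() (Algorithm 2,
# COCOB-Backprop, in the paper's notation):
#   L_t      = max(L_{t-1}, |g_t|)                       running max |gradient|
#   theta_t  = theta_{t-1} + g_t                         sum of gradients
#   G_t      = G_{t-1} + |g_t|                           sum of |gradient|
#   Reward_t = max(Reward_{t-1} - g_t * (w_t - w_1), 0)  betting reward
#   w_{t+1}  = w_1 - theta_t / (L_t * max(G_t + L_t, alpha * L_t))
#                  * (Reward_t + L_t)
# where w_1 is the initial point and tilde_w tracks the offset w_t - w_1.
# ---------------------------------------------------------------------------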
class COCOBBackprop(optim.Optimizer):
    """COCOB-Backprop: stochastic optimization without a learning rate,
    framed as per-coordinate coin betting (Orabona & Tommasi, 2017)."""

    def __init__(self, params, alpha=100, epsilon=1e-8):
        defaults = dict(alpha=alpha, epsilon=epsilon)
        super(COCOBBackprop, self).__init__(params, defaults)
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            alpha = group['alpha']
            epsilon = group['epsilon']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]

                # Lazy state initialization on the parameter's own device.
                if len(state) == 0:
                    state['gradients_sum'] = torch.zeros_like(p.data)
                    state['grad_norm_sum'] = torch.zeros_like(p.data)
                    state['L'] = epsilon * torch.ones_like(p.data)
                    state['tilde_w'] = torch.zeros_like(p.data)
                    state['reward'] = torch.zeros_like(p.data)

                gradients_sum = state['gradients_sum']
                grad_norm_sum = state['grad_norm_sum']
                tilde_w = state['tilde_w']
                L = state['L']
                reward = state['reward']

                # Running per-coordinate statistics: max |gradient|, sum of
                # gradients, and sum of absolute gradients.
                L_update = torch.max(L, torch.abs(grad))
                gradients_sum_update = gradients_sum + grad
                grad_norm_sum_update = grad_norm_sum + torch.abs(grad)

                # Reward of the betting game so far, never below zero.
                reward_update = torch.clamp(reward - grad * tilde_w, min=0)

                # New bet, expressed as an offset from the initial point w_1.
                new_w = (-gradients_sum_update
                         / (L_update * torch.max(grad_norm_sum_update + L_update,
                                                 alpha * L_update))
                         * (reward_update + L_update))

                # Recover w_1 = w_t - tilde_w, then move to w_1 + new_w.
                p.data = p.data - tilde_w + new_w

                state['gradients_sum'] = gradients_sum_update
                state['grad_norm_sum'] = grad_norm_sum_update
                state['L'] = L_update
                state['tilde_w'] = new_w
                state['reward'] = reward_update

        return loss
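

# ---------------------------------------------------------------------------
# Minimal usage sketch: the linear model, random batch, and MSE loss are
# illustrative placeholders; any nn.Module and differentiable loss work.
# Note there is no learning rate to tune.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 1)                  # hypothetical toy model
    optimizer = COCOBBackprop(model.parameters())
    criterion = nn.MSELoss()

    x = torch.randn(32, 10)                   # hypothetical random batch
    y = torch.randn(32, 1)

    for _ in range(100):
        optimizer.zero_grad()
        loss = criterion(model(x), y)
        loss.backward()
        optimizer.step()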