-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathcyclicLR.py
131 lines (98 loc) · 5.32 KB
/
cyclicLR.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import math
from bisect import bisect_right,bisect_left
import torch
import numpy as np
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
class CyclicCosAnnealingLR(_LRScheduler):
    r"""Cosine-annealed learning rate that restarts at every milestone.

    The epoch axis is split into cycles by ``milestones``: the first cycle
    starts at epoch 0 and every milestone starts a new one. Inside a cycle
    the learning rate of each parameter group follows the cosine annealing
    schedule from `SGDR: Stochastic Gradient Descent with Warm Restarts`_:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})(1 +
        \cos(\frac{T_{cur}}{T_{max}}\pi))

    where :math:`T_{cur}` is the number of epochs since the start of the
    current cycle, :math:`T_{max}` is the width of the current cycle and
    :math:`\eta_{max}` is the group's initial lr multiplied by ``gamma``
    once for every decay milestone that is ``<= last_epoch``.

    When ``last_epoch >= milestones[-1]`` the lr is fixed at
    :math:`\eta_{min}`.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list of ints): Non-empty list of epoch indices at which
            a cycle ends and the next one starts. Must be increasing.
        decay_milestones (list of ints): Increasing list of epoch indices at
            which the peak lr is multiplied by ``gamma``. Ideally these
            coincide with entries of ``milestones``. ``None`` or an empty
            list disables the decay. Default: None.
        gamma (float): Factor by which the peak lr is decayed at each decay
            milestone. Default: 0.5.
        eta_min (float): Minimum learning rate. Default: 1e-6.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: If ``milestones`` is empty or is not sorted in
            increasing order.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, milestones, decay_milestones=None,
                 gamma=0.5, eta_min=1e-6, last_epoch=-1):
        milestone_list = list(milestones)
        if not milestone_list:
            # Without this check the failure would surface later as an
            # IndexError from ``self.milestones[-1]`` inside get_lr().
            raise ValueError('Milestones should be a non-empty list of'
                             ' increasing integers. Got {}'.format(milestones))
        if milestone_list != sorted(milestone_list):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.eta_min = eta_min
        self.milestones = milestones
        self.milestones2 = decay_milestones
        self.gamma = gamma
        # The base constructor performs an initial step that calls get_lr(),
        # so every attribute read there must be assigned before this call.
        super(CyclicCosAnnealingLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``."""
        if self.last_epoch >= self.milestones[-1]:
            return [self.eta_min for _ in self.base_lrs]

        # bisect_right puts an epoch equal to a milestone into the cycle that
        # starts at that milestone, i.e. the restart happens on the milestone.
        idx = bisect_right(self.milestones, self.last_epoch)
        left_barrier = 0 if idx == 0 else self.milestones[idx - 1]
        right_barrier = self.milestones[idx]
        width = right_barrier - left_barrier
        curr_pos = self.last_epoch - left_barrier

        if self.milestones2:
            # One factor of gamma per decay milestone already reached.
            decay = self.gamma ** bisect_right(self.milestones2, self.last_epoch)
            peak_lrs = [base_lr * decay for base_lr in self.base_lrs]
        else:
            peak_lrs = list(self.base_lrs)

        cos_term = 1 + math.cos(math.pi * curr_pos / width)
        return [self.eta_min + (peak_lr - self.eta_min) * cos_term / 2
                for peak_lr in peak_lrs]
class CyclicLinearLR(_LRScheduler):
    r"""Linearly decayed learning rate that restarts at every milestone.

    The epoch axis is split into cycles by ``milestones``: the first cycle
    starts at epoch 0 and every milestone starts a new one. Inside a cycle
    the learning rate of each parameter group decays linearly:

    .. math::
        \eta_t = \eta_{min} + (\eta_{max} - \eta_{min})(1 -\frac{T_{cur}}{T_{max}})

    where :math:`T_{cur}` is the number of epochs since the start of the
    current cycle, :math:`T_{max}` is the width of the current cycle and
    :math:`\eta_{max}` is the group's initial lr multiplied by ``gamma``
    once for every decay milestone that is ``<= last_epoch``.

    When ``last_epoch >= milestones[-1]`` the lr is fixed at
    :math:`\eta_{min}`.

    The warm-restart idea follows
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list of ints): Non-empty list of epoch indices at which
            a cycle ends and the next one starts. Must be increasing.
        decay_milestones (list of ints): Increasing list of epoch indices at
            which the peak lr is multiplied by ``gamma``. Ideally these
            coincide with entries of ``milestones``. ``None`` or an empty
            list disables the decay. Default: None.
        gamma (float): Factor by which the peak lr is decayed at each decay
            milestone. Default: 0.5.
        eta_min (float): Minimum learning rate. Default: 1e-6.
        last_epoch (int): The index of last epoch. Default: -1.

    Raises:
        ValueError: If ``milestones`` is empty or is not sorted in
            increasing order.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self, optimizer, milestones, decay_milestones=None,
                 gamma=0.5, eta_min=1e-6, last_epoch=-1):
        milestone_list = list(milestones)
        if not milestone_list:
            # Without this check the failure would surface later as an
            # IndexError from ``self.milestones[-1]`` inside get_lr().
            raise ValueError('Milestones should be a non-empty list of'
                             ' increasing integers. Got {}'.format(milestones))
        if milestone_list != sorted(milestone_list):
            raise ValueError('Milestones should be a list of'
                             ' increasing integers. Got {}'.format(milestones))
        self.eta_min = eta_min
        self.gamma = gamma
        self.milestones = milestones
        self.milestones2 = decay_milestones
        # The base constructor performs an initial step that calls get_lr(),
        # so every attribute read there must be assigned before this call.
        super(CyclicLinearLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``."""
        if self.last_epoch >= self.milestones[-1]:
            return [self.eta_min for _ in self.base_lrs]

        # bisect_right puts an epoch equal to a milestone into the cycle that
        # starts at that milestone, i.e. the restart happens on the milestone.
        idx = bisect_right(self.milestones, self.last_epoch)
        left_barrier = 0 if idx == 0 else self.milestones[idx - 1]
        right_barrier = self.milestones[idx]
        width = right_barrier - left_barrier
        curr_pos = self.last_epoch - left_barrier

        if self.milestones2:
            # One factor of gamma per decay milestone already reached.
            decay = self.gamma ** bisect_right(self.milestones2, self.last_epoch)
            peak_lrs = [base_lr * decay for base_lr in self.base_lrs]
        else:
            peak_lrs = list(self.base_lrs)

        # Fraction of the current cycle still ahead: 1 at the start, -> 0 at the end.
        remaining = 1. - 1.0 * curr_pos / width
        return [self.eta_min + (peak_lr - self.eta_min) * remaining
                for peak_lr in peak_lrs]