# Source: "(ICLR 2023)ContraNorm(对比归一化层).py"
# (from a fork of the ai-dawang/PlugNPlay-Modules repository).
import torch
import torch.nn as nn
# Paper: ContraNorm: A Contrastive Learning Perspective on Oversmoothing and Beyond
# Paper URL: https://ar5iv.labs.arxiv.org/html/2303.06562
class ContraNorm(nn.Module):
    """Contrastive normalization layer followed by LayerNorm.

    For an input ``x`` whose last two axes are ``(tokens, dim)``, the layer
    builds a softmax-normalized cosine-similarity matrix between tokens,
    uses it to aggregate the tokens into ``x_neg``, subtracts a scaled
    ``x_neg`` from ``x`` and finally applies ``nn.LayerNorm`` over the last
    axis.

    Args:
        dim: Size of the last (feature) axis; used for the LayerNorm and for
            the learnable per-feature scale.
        scale: Strength of the subtracted term. When ``scale <= 0`` the
            contrastive step is skipped and only LayerNorm is applied.
        dual_norm: If true, the similarity matrix is the sum of a row-wise
            and a column-wise softmax instead of a row-wise softmax only.
        pre_norm: If true, the L2-normalized input replaces ``x`` before the
            aggregation and subtraction.
        temp: Temperature dividing the similarity logits.
        learnable: If true (and ``scale > 0``), a per-feature
            ``scale_param`` of shape ``(dim,)`` is learned instead of using
            the fixed ``scale``.
        positive: With ``learnable``, store the parameter in log-space and
            apply ``exp`` in ``forward`` so the effective scale stays
            positive. Ignored when ``learnable`` is false.
        identity: Selects the alternative update formula (see ``forward``).
    """

    def __init__(self, dim: int, scale: float = 0.1, dual_norm: bool = False,
                 pre_norm: bool = False, temp: float = 1.0, learnable: bool = False,
                 positive: bool = False, identity: bool = False) -> None:
        super().__init__()
        if learnable and scale > 0:
            # Local import keeps this block self-contained (the file does not
            # import ``math`` at module level).
            import math

            # In "positive" mode the parameter lives in log-space so that
            # exp(param) == scale at initialization.
            scale_init = math.log(scale) if positive else scale
            self.scale_param = nn.Parameter(torch.full((dim,), float(scale_init)))
        self.dual_norm = dual_norm
        self.scale = scale
        self.pre_norm = pre_norm
        self.temp = temp
        self.learnable = learnable
        self.positive = positive
        self.identity = identity
        self.layernorm = nn.LayerNorm(dim, eps=1e-6)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the contrastive update (if enabled) and LayerNorm.

        Args:
            x: Tensor of shape ``(..., tokens, dim)``. The classic
                ``(batch, tokens, dim)`` layout behaves exactly as before;
                any number of leading batch axes (including none) is also
                accepted.

        Returns:
            Tensor with the same shape as ``x``.
        """
        if self.scale > 0.0:
            # Unit-length token vectors -> their dot products are cosine
            # similarities.
            xn = nn.functional.normalize(x, dim=-1)
            if self.pre_norm:
                x = xn

            # Token-to-token similarity logits, shape (..., tokens, tokens).
            sim = torch.matmul(xn, xn.transpose(-1, -2)) / self.temp
            if self.dual_norm:
                sim = nn.functional.softmax(sim, dim=-1) + nn.functional.softmax(sim, dim=-2)
            else:
                sim = nn.functional.softmax(sim, dim=-1)

            # Similarity-weighted aggregation of the tokens.
            x_neg = torch.matmul(sim, x)

            if not self.learnable:
                if self.identity:
                    x = (1 + self.scale) * x - self.scale * x_neg
                else:
                    x = x - self.scale * x_neg
            else:
                # Shape (dim,) broadcasts over every leading axis of x.
                scale = torch.exp(self.scale_param) if self.positive else self.scale_param
                if self.identity:
                    # NOTE(review): the fixed-scale identity branch uses
                    # (1 + scale) * x, while this one uses scale * x. Kept
                    # as-is to preserve existing numerical behavior; confirm
                    # whether (1 + scale) * x was intended.
                    x = scale * x - scale * x_neg
                else:
                    x = x - scale * x_neg
        x = self.layernorm(x)
        return x
if __name__ == '__main__':
    # Smoke test: run a random (batch=32, tokens=784, dim=128) tensor through
    # the layer and print the input and output sizes.
    block = ContraNorm(dim=128, scale=0.1, dual_norm=False, pre_norm=False, temp=1.0,
                       learnable=False, positive=False, identity=False)
    # Named ``inputs`` rather than ``input`` to avoid shadowing the builtin.
    inputs = torch.rand(32, 784, 128)
    output = block(inputs)
    print("Input size:", inputs.size())
    print("Output size:", output.size())