"""
Refactored Bi-level Routing Attention that takes NCHW input.
author: ZHU Lei
github: https://github.com/rayleizhu
email: [email protected]
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
from typing import List, Optional
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch import LongTensor, Tensor
from ops.torch.rrsda import regional_routing_attention_torch


class nchwBRA(nn.Module):
"""Bi-Level Routing Attention that takes nchw input
Compared to legacy version, this implementation:
* removes unused args and components
* uses nchw input format to avoid frequent permutation
When the size of inputs is not divisible by the region size, there is also a numerical difference
than legacy implementation, due to:
* different way to pad the input feature map (padding after linear projection)
* different pooling behavior (count_include_pad=False)
Current implementation is more reasonable, hence we do not keep backward numerical compatiability
"""
def __init__(self, dim, num_heads=8, n_win=7, qk_scale=None, topk=4, side_dwconv=3, auto_pad=False, attn_backend='torch'):
super().__init__()
# local attention setting
self.dim = dim
self.num_heads = num_heads
assert self.dim % num_heads == 0, 'dim must be divisible by num_heads!'
self.head_dim = self.dim // self.num_heads
self.scale = qk_scale or self.dim ** -0.5 # NOTE: to be consistent with old models.
################side_dwconv (i.e. LCE in Shunted Transformer)###########
self.lepe = nn.Conv2d(dim, dim, kernel_size=side_dwconv, stride=1, padding=side_dwconv//2, groups=dim) if side_dwconv > 0 else \
lambda x: torch.zeros_like(x)
################ regional routing setting #################
self.topk = topk
self.n_win = n_win # number of windows per row/col
##########################################
self.qkv_linear = nn.Conv2d(self.dim, 3*self.dim, kernel_size=1)
self.output_linear = nn.Conv2d(self.dim, self.dim, kernel_size=1)
if attn_backend == 'torch':
self.attn_fn = regional_routing_attention_torch
else:
            raise ValueError('CUDA implementation is not available yet. Please stay tuned.')

    def forward(self, x: Tensor, ret_attn_mask=False):
"""
Args:
x: NCHW tensor, better to be channel_last (https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html)
Return:
NCHW tensor
"""
N, C, H, W = x.size()
region_size = (H//self.n_win, W//self.n_win)
# STEP 1: linear projection
        qkv = self.qkv_linear(x)  # ncHW
q, k, v = qkv.chunk(3, dim=1) # ncHW
# STEP 2: region-to-region routing
        # NOTE: ceil_mode=True with count_include_pad=False amounts to automatic padding
        # NOTE: gradients flow back only through token-to-token attention, not through
        # the routing step (hence the detach() on q and k below). See Appendix A for the intuition.
q_r = F.avg_pool2d(q.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False)
k_r = F.avg_pool2d(k.detach(), kernel_size=region_size, ceil_mode=True, count_include_pad=False) # nchw
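        # q_r, k_r: one descriptor per region, roughly (n, c, n_win, n_win);
        # ceil_mode may add an extra row/col of regions when H or W is not
        # divisible by the region size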
q_r:Tensor = q_r.permute(0, 2, 3, 1).flatten(1, 2) # n(hw)c
k_r:Tensor = k_r.flatten(2, 3) # nc(hw)
a_r = q_r @ k_r # n(hw)(hw), adj matrix of regional graph
_, idx_r = torch.topk(a_r, k=self.topk, dim=-1) # n(hw)k long tensor
idx_r:LongTensor = idx_r.unsqueeze_(1).expand(-1, self.num_heads, -1, -1)
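        # idx_r now has shape (n, num_heads, #regions, topk): for each region,
        # the indices of its top-k most related regions, shared across all heads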
# STEP 3: token to token attention (non-parametric function)
output, attn_mat = self.attn_fn(query=q, key=k, value=v, scale=self.scale,
region_graph=idx_r, region_size=region_size
)
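        # output: ncHW attended features; attn_mat: token-to-token attention
        # weights from the backend, exposed via ret_attn_mask below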
output = output + self.lepe(v) # ncHW
output = self.output_linear(output) # ncHW
if ret_attn_mask:
return output, attn_mat
return output
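

if __name__ == '__main__':
    # Minimal usage sketch (illustrative; not part of the original file). The
    # tensor sizes are assumptions chosen so that H and W are divisible by n_win.
    attn = nchwBRA(dim=64, num_heads=8, n_win=7, topk=4)
    x = torch.randn(2, 64, 56, 56)  # NCHW; with n_win=7, each region is 8x8
    out = attn(x)
    assert out.shape == x.shape  # BRA preserves the input shape
    out, attn_mat = attn(x, ret_attn_mask=True)  # optionally inspect attention weights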