We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
eg.
https://github.com/szagoruyko/pytorchviz
UPDATE
This may be a more suitable library. The graphs are designed to communicate the high-level architecture. https://github.com/waleedka/hiddenlayer/tree/master
It's not hard to plug in. Below is a demonstration of plugging it into existing code from this repo (EMSA attention):
import numpy as np
import torch
from torch import nn
from torch.nn import init


class EMSA(nn.Module):
    """EMSA multi-head attention block (presumably "Efficient Multi-head
    Self-Attention") with optional spatial reduction of keys/values.

    When ``ratio > 1`` the keys and values are NOT taken from the ``keys`` /
    ``values`` arguments: they are derived from ``queries`` after a depthwise
    strided convolution followed by LayerNorm. When ``apply_transform`` is
    enabled (and ``h > 1``) the attention logits go through a 1x1 convolution
    across heads, a softmax and an InstanceNorm instead of a plain softmax.

    Args:
        d_model: Channel size of the input and output tokens.
        d_k: Per-head size of queries and keys.
        d_v: Per-head size of values.
        h: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
        H: Height of the token grid; only used when ``ratio > 1``.
        W: Width of the token grid; only used when ``ratio > 1``.
        ratio: Stride of the spatial-reduction convolution; values ``<= 1``
            disable the reduction.
        apply_transform: Enable the conv/softmax/InstanceNorm transform on
            the attention map (ignored when ``h == 1``).
    """

    def __init__(self, d_model, d_k, d_v, h, dropout=.1, H=7, W=7, ratio=3, apply_transform=True):
        super().__init__()
        self.H = H
        self.W = W
        self.fc_q = nn.Linear(d_model, h * d_k)
        self.fc_k = nn.Linear(d_model, h * d_k)
        self.fc_v = nn.Linear(d_model, h * d_v)
        self.fc_o = nn.Linear(h * d_v, d_model)
        self.dropout = nn.Dropout(dropout)

        self.ratio = ratio
        if self.ratio > 1:
            # Empty container that forward() never uses; kept only so the
            # attribute and module tree stay the same for existing users.
            self.sr = nn.Sequential()
            # Depthwise (groups=d_model) strided conv that shrinks the token grid.
            self.sr_conv = nn.Conv2d(d_model, d_model, kernel_size=ratio + 1, stride=ratio,
                                     padding=ratio // 2, groups=d_model)
            self.sr_ln = nn.LayerNorm(d_model)

        self.apply_transform = apply_transform and h > 1
        if self.apply_transform:
            # Order matters: _attention_weights() unpacks these three modules.
            self.transform = nn.Sequential()
            self.transform.add_module('conv', nn.Conv2d(h, h, kernel_size=1, stride=1))
            self.transform.add_module('softmax', nn.Softmax(-1))
            self.transform.add_module('in', nn.InstanceNorm2d(h))

        self.d_model = d_model
        self.d_k = d_k
        self.d_v = d_v
        self.h = h

        self.init_weights()

    def init_weights(self):
        """Re-initialise every Conv2d, BatchNorm2d and Linear submodule."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def _attention_weights(self, scores, attention_mask):
        """Turn scaled logits ``(b_s, h, nq, n')`` into attention weights.

        Positions where ``attention_mask`` is True get exactly zero weight.
        """
        if attention_mask is None:
            if self.apply_transform:
                return self.transform(scores)
            return torch.softmax(scores, -1)

        neg_inf = float('-inf')
        if not self.apply_transform:
            return torch.softmax(scores.masked_fill(attention_mask, neg_inf), -1)

        head_mix, softmax, inst_norm = self.transform
        # Mask after the head-mixing conv: feeding -inf into the conv would
        # spread inf/NaN to every head.
        masked_logits = head_mix(scores).masked_fill(attention_mask, neg_inf)
        # InstanceNorm shifts the zeros produced by the softmax, so masked
        # positions have to be zeroed again afterwards.
        return inst_norm(softmax(masked_logits)).masked_fill(attention_mask, 0.0)

    def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
        """Compute multi-head attention.

        Args:
            queries: Tensor of shape ``(b_s, nq, d_model)``. When
                ``ratio > 1`` it is reshaped to ``(b_s, d_model, H, W)``, so
                ``nq`` must equal ``H * W``.
            keys: Tensor of shape ``(b_s, nk, d_model)``; its content is
                ignored when ``ratio > 1``.
            values: Tensor of shape ``(b_s, nk, d_model)``; its content is
                ignored when ``ratio > 1``.
            attention_mask: Optional boolean tensor broadcastable to
                ``(b_s, h, nq, n')``; True marks positions that must not be
                attended to.
            attention_weights: Optional tensor broadcastable to
                ``(b_s, h, nq, n')`` multiplied into the normalised
                attention weights.

        Returns:
            Tensor of shape ``(b_s, nq, d_model)``.
        """
        b_s, nq, c = queries.shape
        nk = keys.shape[1]

        q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3)  # (b_s, h, nq, d_k)

        if self.ratio > 1:
            x = queries.permute(0, 2, 1).view(b_s, c, self.H, self.W)  # (b_s, c, H, W)
            x = self.sr_conv(x)  # (b_s, c, h', w')
            x = x.contiguous().view(b_s, c, -1).permute(0, 2, 1)  # (b_s, n', c)
            x = self.sr_ln(x)
            k = self.fc_k(x).view(b_s, -1, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, n')
            v = self.fc_v(x).view(b_s, -1, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, n', d_v)
        else:
            k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1)  # (b_s, h, d_k, nk)
            v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3)  # (b_s, h, nk, d_v)

        scores = torch.matmul(q, k) / np.sqrt(self.d_k)  # (b_s, h, nq, n')
        att = self._attention_weights(scores, attention_mask)  # (b_s, h, nq, n')
        if attention_weights is not None:
            att = att * attention_weights

        att = self.dropout(att)

        out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v)  # (b_s, nq, h*d_v)
        out = self.fc_o(out)  # (b_s, nq, d_model)
        return out


def main():
    """Run EMSA on random data and render its autograd graph with torchviz."""
    # Imported here so the module stays importable without the optional
    # visualization dependency.
    from torchviz import make_dot

    x = torch.randn(50, 64, 512)
    emsa = EMSA(d_model=512, d_k=512, d_v=512, h=8, H=8, W=8, ratio=2, apply_transform=True)
    output = emsa(x, x, x)

    # Visualize the graph
    dot = make_dot(output, params=dict(emsa.named_parameters()))
    # Attempt to modify the canvas size
    dot.graph_attr['size'] = "120,120"  # Set the size as width,height in inches
    dot.render('emsa_graph', format='svg')  # Save the graph as SVG

    print(output.shape)


if __name__ == '__main__':
    main()
The text was updated successfully, but these errors were encountered:
No branches or pull requests
eg.
https://github.com/szagoruyko/pytorchviz
UPDATE
This may be a more suitable library.
The graphs are designed to communicate the high-level architecture.
https://github.com/waleedka/hiddenlayer/tree/master
It's not hard to plug in. Demonstration of plugging it into existing code from this repo: EMSA attention.
The text was updated successfully, but these errors were encountered: