This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Commit

fully take over parameter initialization.
liuqiuhui2015 committed Jun 12, 2019
1 parent 1cf7628 commit 6809ade
Showing 7 changed files with 127 additions and 10 deletions.
6 changes: 3 additions & 3 deletions modules/base.py
@@ -651,7 +651,7 @@ def __init__(self, isize, bias=True):

super(Scorer, self).__init__()

- self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize)))
+ self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(1.0 / isize), sqrt(1.0 / isize)))
self.bias = nn.Parameter(torch.zeros(1)) if bias else None

def forward(self, x):
@@ -671,7 +671,7 @@ def __init__(self, isize, ahsize=None, num_head=8, attn_drop=0.0):

super(MHAttnSummer, self).__init__()

- self.w = nn.Parameter(torch.Tensor(1, 1, isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize)))
+ self.w = nn.Parameter(torch.Tensor(1, 1, isize).uniform_(- sqrt(1.0 / isize), sqrt(1.0 / isize)))
self.attn = CrossAttn(isize, isize if ahsize is None else ahsize, isize, num_head, dropout=attn_drop)

# x: (bsize, seql, isize)
@@ -704,7 +704,7 @@ def __init__(self, isize, minv = 0.125):

super(Temperature, self).__init__()

- self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize)))
+ self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(1.0 / isize), sqrt(1.0 / isize)))
self.bias = nn.Parameter(torch.zeros(1))
self.act = nn.Tanh()
self.k = nn.Parameter(torch.ones(1))
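The three modules/base.py changes above share one pattern: the uniform-initialization bound for these 1-D weight vectors drops from the Xavier-style sqrt(6.0 / isize) to the tighter sqrt(1.0 / isize). A minimal, self-contained sketch comparing the two bounds (the helper names and the example size are ours, not the repository's):

# Illustrative sketch only: compares the old and new uniform-init bounds
# applied to a 1-D weight vector of size isize (helper names are ours).
from math import sqrt

import torch

def old_bound(isize):
    return sqrt(6.0 / isize)  # Xavier-style bound used before this commit

def new_bound(isize):
    return sqrt(1.0 / isize)  # tighter bound introduced by this commit

isize = 512
w_old = torch.empty(isize).uniform_(-old_bound(isize), old_bound(isize))
w_new = torch.empty(isize).uniform_(-new_bound(isize), new_bound(isize))
print(w_old.std().item(), w_new.std().item())  # the new vector starts with noticeably smaller variance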
74 changes: 74 additions & 0 deletions modules/group.py
@@ -0,0 +1,74 @@
#encoding: utf-8

from math import sqrt
import torch
from torch import nn
from modules.base import GeLU_BERT
from modules.base import PositionwiseFF as PositionwiseFFBase

class GroupLinearCore(nn.Module):

# isize: input dimension
# osize: output dimension
# ngroup: number of groups
# bias: enable bias or not

def __init__(self, isize, osize, ngroup, bias=True):

super(GroupLinearCore, self).__init__()

self.ngroup = ngroup
self.isize = isize // ngroup
_osize = osize // ngroup

self.weight = nn.Parameter(torch.Tensor(ngroup, self.isize, _osize).uniform_(- sqrt(2.0 / (self.isize + _osize)), sqrt(2.0 / (self.isize + _osize))))
self.bias = nn.Parameter(torch.zeros(osize)) if bias else None

# inputu: (bsize, isize)

def forward(self, inputu):

_bsize = inputu.size(0)
out = inputu.view(_bsize, self.ngroup, self.isize).transpose(0, 1).bmm(self.weight).transpose(0, 1).contiguous().view(_bsize, -1)

return out if self.bias is None else out + self.bias

class GroupLinear(nn.Module):

# isize: input dimension
# osize: output dimension
# ngroup: number of groups
# bias: enable bias or not

def __init__(self, isize, osize, ngroup, bias=True):

super(GroupLinear, self).__init__()

self.net = GroupLinearCore(isize, osize, ngroup, bias)

# inputu: (..., isize)

def forward(self, inputu):

_size = list(inputu.size())
_isize = _size[-1]
_size[-1] = -1

return self.net(inputu.view(-1, _isize)).view(_size)

class PositionwiseFF(PositionwiseFFBase):

# isize: input dimension
# hsize: hidden dimension

def __init__(self, isize, hsize=None, dropout=0.0, norm_residue=False, use_GeLU=False, ngroup=16):

_hsize = isize * 4 if hsize is None else hsize

super(PositionwiseFF, self).__init__(isize, _hsize, dropout, norm_residue, use_GeLU)

if dropout > 0.0:
#self.nets = nn.ModuleList([nn.Linear(isize, _hsize), nn.Dropout(dropout, inplace=True), GeLU_BERT() if use_GeLU else nn.ReLU(inplace=True), GroupLinear(_hsize, _hsize, ngroup), nn.Dropout(dropout, inplace=True), GeLU_BERT() if use_GeLU else nn.ReLU(inplace=True), nn.Linear(_hsize, isize), nn.Dropout(dropout, inplace=True)])
self.nets = nn.ModuleList([nn.Linear(isize, _hsize), nn.Dropout(dropout, inplace=True), GeLU_BERT() if use_GeLU else nn.ReLU(inplace=True), nn.Linear(_hsize, _hsize), nn.Dropout(dropout, inplace=True), GeLU_BERT() if use_GeLU else nn.ReLU(inplace=True), nn.Linear(_hsize, isize), nn.Dropout(dropout, inplace=True)])
else:
self.nets = nn.ModuleList([nn.Linear(isize, _hsize), GeLU_BERT() if use_GeLU else nn.ReLU(inplace=True), GroupLinear(_hsize, _hsize, ngroup), GeLU_BERT() if use_GeLU else nn.ReLU(inplace=True), nn.Linear(_hsize, isize)])
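A quick usage sketch for the GroupLinear added above (toy shapes are ours; it assumes the repository root is on the Python path). GroupLinearCore splits the last dimension into ngroup equal blocks and applies an independent linear map to each block with a single batched matrix multiply; GroupLinear only reshapes arbitrary leading dimensions around that core:

# Usage sketch with toy shapes (ours); assumes the repository root is importable.
import torch

from modules.group import GroupLinear

bsize, seql, isize, osize, ngroup = 2, 5, 32, 64, 4
net = GroupLinear(isize, osize, ngroup)

x = torch.randn(bsize, seql, isize)
y = net(x)
print(y.size())  # torch.Size([2, 5, 64]): each 8-dim input block maps to its own 16-dim output block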
2 changes: 1 addition & 1 deletion modules/rnncells.py
@@ -87,6 +87,6 @@ def forward(self, x, cell):

p, q = self.t1(x), self.t2(cell)

- igate, fgate = torch.sigmoid(p + q), torch.sigmoid(p - q)
+ igate, fgate = (p + q).sigmoid(), (p - q).sigmoid()

return igate * p + fgate * q
2 changes: 1 addition & 1 deletion transformer/TA/TAEncoder.py
@@ -77,7 +77,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.

self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)])

- self.tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(6.0 / (num_layer + num_layer_dec + 1)), sqrt(6.0 / (num_layer + num_layer_dec + 1))))
+ self.tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(2.0 / (num_layer + num_layer_dec + 1)), sqrt(2.0 / (num_layer + num_layer_dec + 1))))
self.tattn_drop = nn.Dropout(dropout) if dropout > 0.0 else None

# inputs: (bsize, seql)
1 change: 1 addition & 0 deletions transformer/TA/__init__.py
@@ -0,0 +1 @@
#encoding: utf-8
@@ -125,7 +125,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.

self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)])

- self.tattn_w = nn.Parameter(torch.Tensor(num_layer * num_head).uniform_(- sqrt(6.0 / (num_layer * num_head + 1)), sqrt(6.0 / (num_layer * num_head + 1))))
+ self.tattn_w = nn.Parameter(torch.Tensor(num_layer * num_head).uniform_(- sqrt(1.0 / (num_layer * num_head)), sqrt(1.0 / (num_layer * num_head))))
self.tattn_drop = nn.Dropout(dropout) if dropout > 0.0 else None

self.classifier = nn.Sequential(nn.Linear(isize * 2, isize, bias=False), nn.Linear(isize, nwd))
50 changes: 46 additions & 4 deletions utils.py
@@ -1,7 +1,10 @@
#encoding: utf-8

import torch
- from torch.nn.init import xavier_uniform_
+ from torch.nn.init import xavier_uniform_, kaiming_uniform_
from torch.nn import Embedding, Linear, LayerNorm

from math import sqrt

from random import sample
from random import seed as rpyseed
@@ -11,15 +14,18 @@
def pad_tensors(tensor_list):

def get_pad_size(tsize, stdlen):

nsize = list(tsize)
nsize[-1] = stdlen - tsize[-1]

return nsize

maxlen = 0
for tensor in tensor_list:
tlen = tensor.size(-1)
if tlen > maxlen:
maxlen = tlen

return [tensor if tensor.size(-1) == maxlen else torch.cat((tensor, tensor.new_zeros(get_pad_size(tensor.size(), maxlen))), -1) for tensor in tensor_list]
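pad_tensors right-pads every tensor in a list with zeros along the last dimension so all tensors share the longest last-dimension length. A small usage sketch (toy tensors are ours; assumes the repository root is importable):

# Usage sketch for pad_tensors (toy tensors are ours).
import torch

from utils import pad_tensors

a, b = torch.ones(2, 3), torch.ones(2, 5)
padded = pad_tensors([a, b])
print([t.size() for t in padded])  # [torch.Size([2, 5]), torch.Size([2, 5])]: the shorter tensor is zero-padded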

def freeze_module(module):
@@ -31,6 +37,7 @@ def freeze_module(module):
def unfreeze_module(module):

def unfreeze_fixing(mod):

if "fix_unfreeze" in dir(mod):
mod.fix_unfreeze()

@@ -39,19 +46,22 @@ def unfreeze_fixing(mod):

module.apply(unfreeze_fixing)


def getlr(optm):

lr = []
for i, param_group in enumerate(optm.param_groups):
lr.append(float(param_group['lr']))

return lr

def updated_lr(oldlr, newlr):

rs = False
for olr, nlr in zip(oldlr, newlr):
if olr != nlr:
rs = True
break

return rs

def dynamic_sample(incd, dss_ws, dss_rm):
@@ -106,15 +116,47 @@ def get_logger(fname):

logger.addHandler(handler)
logger.addHandler(console)

return logger

- def init_model_params(modin):
+ def init_model_params_glorot(modin, hyp=None):

_scale = sqrt(1.0 / 3.0) if hyp is None else hyp

for p in modin.parameters():
if p.requires_grad and (p.dim() > 1):
- xavier_uniform_(p)
+ xavier_uniform_(p, gain=_scale)

return modin

def init_model_params_kaiming(modin, hyp=None):

_scale = sqrt(5.0) if hyp is None else hyp

for p in modin.parameters():
if p.requires_grad and (p.dim() > 1):
kaiming_uniform_(p, a=_scale)

return modin

def init_model_params(modin, scale_glorot=None, scale_kaiming=None):

_tmpm = init_model_params_kaiming(modin, scale_kaiming)

for _m in _tmpm.modules():
if isinstance(_m, Embedding):
init_model_params_glorot(_m, scale_glorot)
elif isinstance(_m, Linear):
if _m.bias is not None:
with torch.no_grad():
_m.bias.zero_()
elif isinstance(_m, LayerNorm):
with torch.no_grad():
_m.weight.fill_(1.0)
_m.bias.zero_()

return _tmpm

def set_random_seed(seed, set_cuda=False):

_rseed = torch.initial_seed() if seed is None else seed
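Taken together, these utils.py changes let init_model_params fully take over parameter initialization, matching the commit title: every trainable parameter with more than one dimension first gets kaiming_uniform_ (a = sqrt(5.0) unless a hyperparameter is passed), Embedding weights are then re-initialized with xavier_uniform_ at gain sqrt(1.0 / 3.0), Linear biases are zeroed, and LayerNorm weights and biases are reset to 1 and 0. A minimal usage sketch (the toy model is ours, not from the repository; assumes the repository root is importable):

# Minimal usage sketch; the toy model below is ours, not part of the repository.
import torch
from torch import nn

from utils import init_model_params

model = nn.Sequential(nn.Embedding(100, 32), nn.Linear(32, 32), nn.LayerNorm(32))
model = init_model_params(model)

print(model[2].weight.sum().item())      # 32.0: LayerNorm weight reset to ones
print(model[1].bias.abs().sum().item())  # 0.0: Linear bias zeroed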
