
Commit

fix SelfAttn and length penalty
hfxunlp committed Jan 28, 2019
1 parent 53ef6c3 commit e245637
Showing 15 changed files with 74 additions and 35 deletions.
4 changes: 2 additions & 2 deletions README.md
@@ -230,7 +230,7 @@ export rsf=trans.txt

## Exporting python files to C libraries

-You can convert python classes into C libraries with `python mkcy.py build_ext --inplace`, and codes will be checked before compling, which can serve as a simple to way to find typo and bugs as well. This function is supported by [Cython](https://cython.org/). These files can be removed with `rm -fr *.c *.so parallel/*.c parallel/*.so transformer/*.c transformer/*.so build/`. Loading modules from compiled C libraries may also accelerate, but not significantly.
+You can convert python classes into C libraries with `python mkcy.py build_ext --inplace`, and codes will be checked before compiling, which can serve as a simple to way to find typo and bugs as well. This function is supported by [Cython](https://cython.org/). These files can be removed with `rm -fr *.c *.so parallel/*.c parallel/*.so transformer/*.c transformer/*.so build/`. Loading modules from compiled C libraries may also accelerate, but not significantly.

## Ranking

@@ -414,7 +414,7 @@ Measured with `multi-bleu-detok.perl`:

## Acknowledgements

-The project starts when Hongfei XU (the developer) was a postgraduate student at [Zhengzhou University](http://www5.zzu.edu.cn/nlp/), and continues when he is a PhD candidate at [Saarland University](https://www.uni-saarland.de/nc/en/home.html) and a Junior Researcher at [DFKI (German Research Center for Artificial Intelligence)](https://www.dfki.de/en/web/research/research-departments-and-groups/multilingual-technologies/). Hongfei XU enjoys a doctoral grant from [China Scholarship Council](https://www.csc.edu.cn/) ([2018]3101, 201807040056) while maintaining this project.
+The project starts when Hongfei XU (the developer) was a postgraduate student at [Zhengzhou University](http://www5.zzu.edu.cn/nlp/), and continues when he is a PhD candidate at [Saarland University](https://www.uni-saarland.de/nc/en/home.html) supervised by [Prof. Dr. Josef van Genabith](https://www.dfki.de/en/web/about-us/employee/person/jova02/) and a Junior Researcher at [DFKI, MLT (German Research Center for Artificial Intelligence, Multilinguality and Language Technology)](https://www.dfki.de/en/web/research/research-departments-and-groups/multilinguality-and-language-technology/). Hongfei XU enjoys a doctoral grant from [China Scholarship Council](https://www.csc.edu.cn/) ([2018]3101, 201807040056) while maintaining this project.

## Contributor(s)

2 changes: 1 addition & 1 deletion mkcy.py
@@ -17,7 +17,7 @@ def get_name(fname):

eccargs = ["-Ofast", "-march=native", "-pipe", "-fomit-frame-pointer"]

-baselist = ["modules.py", "loss.py", "lrsch.py", "rnncell.py", "translator.py", "discriminator.py"]
+baselist = ["modules.py", "loss.py", "lrsch.py", "utils.py","rnncell.py", "translator.py", "discriminator.py"]

extlist = [Extension(get_name(pyf), [pyf], extra_compile_args=eccargs) for pyf in baselist]

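For orientation, here is a minimal sketch of how a Cython build script of this shape fits together. The `setup()`/`cythonize()` wrapper and the body of `get_name` are assumptions; only `eccargs`, `baselist`, and the `Extension` list appear in the diff:

```python
# a minimal sketch, not mkcy.py itself: setup()/cythonize() and the
# get_name() body are assumptions; eccargs, baselist and the Extension
# list mirror the lines shown in the diff above
from setuptools import setup
from setuptools.extension import Extension
from Cython.Build import cythonize

def get_name(fname):
	# hypothetical helper: strip the extension, e.g. "modules.py" -> "modules"
	return fname[:fname.rfind(".")]

eccargs = ["-Ofast", "-march=native", "-pipe", "-fomit-frame-pointer"]
baselist = ["modules.py", "loss.py", "lrsch.py", "utils.py", "rnncell.py", "translator.py", "discriminator.py"]

extlist = [Extension(get_name(pyf), [pyf], extra_compile_args=eccargs) for pyf in baselist]

# invoked as: python mkcy.py build_ext --inplace
setup(ext_modules=cythonize(extlist))
```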
19 changes: 1 addition & 18 deletions modules.py
@@ -263,7 +263,7 @@ def forward(self, iQ, mask=None, iK=None):
real_iQ, real_iK, real_iV = _out.narrow(-1, 0, self.hsize).contiguous().view(bsize, nquery, nheads, adim).transpose(1, 2), _out.narrow(-1, self.hsize, self.hsize).contiguous().view(bsize, nquery, nheads, adim).transpose(1, 2), _out.narrow(-1, self.hsize + self.hsize, self.hsize).contiguous().view(bsize, nquery, nheads, adim).transpose(1, 2)
else:

-real_iQ, _out = nnFunc.linear(iQ, self.adaptor.weight.narrow(0, 0, self.hsize), self.adaptor.bias.narrow(0, 0, self.hsize) if self.adaptor.bias else None), nnFunc.linear(iK, self.adaptor.weight.narrow(0, self.hsize, self.hsize + self.hsize), self.adaptor.bias.narrow(0, self.hsize, self.hsize + self.hsize) if self.adaptor.bias else None)
+real_iQ, _out = nnFunc.linear(iQ, self.adaptor.weight.narrow(0, 0, self.hsize), self.adaptor.bias.narrow(0, 0, self.hsize) if self.adaptor.bias else None).view(bsize, nquery, nheads, adim).transpose(1, 2), nnFunc.linear(iK, self.adaptor.weight.narrow(0, self.hsize, self.hsize + self.hsize), self.adaptor.bias.narrow(0, self.hsize, self.hsize + self.hsize) if self.adaptor.bias else None)

seql = iK.size(1)
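The fix restores the per-head reshape on `real_iQ` that the fused query/key/value branch above already applies. A minimal, self-contained sketch of that reshape pattern, with illustrative sizes (the tensor names follow the diff):

```python
import torch
import torch.nn as nn

bsize, nquery, nheads, adim = 2, 5, 4, 16
hsize = nheads * adim

iQ = torch.randn(bsize, nquery, hsize)
proj = nn.Linear(hsize, hsize)

# project, then split the model dimension into heads and move the head
# axis in front of the query axis: (bsize, nheads, nquery, adim)
real_iQ = proj(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2)
assert real_iQ.size() == (bsize, nheads, nquery, adim)
```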

@@ -356,23 +356,6 @@ def forward(self, iQ, iK, mask=None):

return self.outer(oMA.view(bsize, nquery, self.hsize))

-def freeze_module(module):
-
-	for p in module.parameters():
-		if p.requires_grad:
-			p.requires_grad_(False)
-
-def unfreeze_module(module):
-
-	def unfreeze_fixing(mod):
-		if "fix_unfreeze" in dir(mod):
-			mod.fix_unfreeze()
-
-	for p in module.parameters():
-		p.requires_grad_(True)
-
-	module.apply(unfreeze_fixing)

# Aggregation from: Exploiting Deep Representations for Neural Machine Translation
class ResidueCombiner(nn.Module):

16 changes: 2 additions & 14 deletions parallel/parallelMT.py
@@ -4,6 +4,8 @@

from parallel.parallel import DataParallelModel

+from utils import pad_tensors

from threading import Lock, Thread

class DataParallelMT(DataParallelModel):
@@ -40,20 +42,6 @@ def train_decode(self, *inputs, **kwargs):
outputs = parallel_apply_train_decode(replicas, inputs, devices, kwargs)
return self.gather(pad_tensors(outputs), self.output_device) if self.gather_output else outputs

-def pad_tensors(tensor_list):
-
-	def get_pad_size(tsize, stdlen):
-		nsize = list(tsize)
-		nsize[-1] = stdlen - tsize[-1]
-		return nsize
-
-	maxlen = 0
-	for tensor in tensor_list:
-		tlen = tensor.size(-1)
-		if tlen > maxlen:
-			maxlen = tlen
-	return [tensor if tensor.size(-1) == maxlen else torch.cat((tensor, tensor.new_zeros(get_pad_size(tensor.size(), maxlen))), -1) for tensor in tensor_list]

# update these two functions with the update of parallel_apply(https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py)

def parallel_apply_decode(modules, inputs, devices, kwargs_tup=None):
6 changes: 6 additions & 0 deletions predict.py
@@ -82,6 +82,12 @@ def reverse_dict(din):

return rs

+def load_model_cpu_old(modf, base_model):
+
+	base_model.load_state_dict(torch.load(modf, map_location='cpu'))
+
+	return base_model

def load_model_cpu(modf, base_model):

mpg = torch.load(modf, map_location='cpu')
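The added `load_model_cpu_old` restores checkpoints saved as a plain `state_dict`, mapping every tensor to CPU. A minimal usage sketch; the checkpoint path here is hypothetical:

```python
import torch
import torch.nn as nn

def load_model_cpu_old(modf, base_model):
	# as in the diff: load a state_dict checkpoint onto CPU and
	# copy its tensors into base_model
	base_model.load_state_dict(torch.load(modf, map_location='cpu'))
	return base_model

model = nn.Linear(8, 8)
torch.save(model.state_dict(), "old_model.t7")	# stand-in for a real checkpoint
model = load_model_cpu_old("old_model.t7", model)
```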
3 changes: 3 additions & 0 deletions transformer/AvgDecoder.py
@@ -368,6 +368,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

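The same three added lines appear in every decoder below: dividing by the per-hypothesis length penalty `lpv` can change the ranking of the beams, so the scores must be re-ranked with `topk` and the surviving hypotheses re-gathered from the flattened batch. A self-contained sketch of the index arithmetic with dummy tensors:

```python
import torch

bsize, beam_size, seql = 2, 3, 7
real_bsize = bsize * beam_size

scores = torch.randn(bsize, beam_size)			# length-normalised scores per beam
trans = torch.randint(0, 100, (real_bsize, seql))	# flattened hypotheses

# keep the beam_size best hypotheses of each sentence after re-scoring
scores, _inds = torch.topk(scores, beam_size, dim=-1)

# _inds indexes inside each sentence's beam; adding the per-sentence
# offsets (0, beam_size, 2*beam_size, ...) turns them into indices into
# the flattened (bsize * beam_size) batch
_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)

trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)
assert trans.size() == (bsize, beam_size, seql)
```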
3 changes: 3 additions & 0 deletions transformer/Decoder.py
@@ -407,6 +407,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

3 changes: 3 additions & 0 deletions transformer/EnsembleAvgDecoder.py
@@ -292,6 +292,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

3 changes: 3 additions & 0 deletions transformer/EnsembleDecoder.py
@@ -312,6 +312,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

3 changes: 3 additions & 0 deletions transformer/HierAvgDecoder.py
@@ -250,6 +250,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

3 changes: 3 additions & 0 deletions transformer/HierDecoder.py
@@ -246,6 +246,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

3 changes: 3 additions & 0 deletions transformer/InceptAvgDecoder.py
@@ -244,6 +244,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

3 changes: 3 additions & 0 deletions transformer/InceptDecoder.py
@@ -240,6 +240,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

3 changes: 3 additions & 0 deletions transformer/RNMTDecoder.py
@@ -423,6 +423,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt
# if length penalty is only applied in the last step, apply length penalty
if (not clip_beam) and (length_penalty > 0.0):
scores = scores / lpv.view(bsize, beam_size)
+scores, _inds = torch.topk(scores, beam_size, dim=-1)
+_inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize)
+trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1)

if return_all:

35 changes: 35 additions & 0 deletions utils.py
@@ -0,0 +1,35 @@
+#encoding: utf-8
+
+import torch
+
+def pad_tensors(tensor_list):
+
+	def get_pad_size(tsize, stdlen):
+		nsize = list(tsize)
+		nsize[-1] = stdlen - tsize[-1]
+		return nsize
+
+	maxlen = 0
+	for tensor in tensor_list:
+		tlen = tensor.size(-1)
+		if tlen > maxlen:
+			maxlen = tlen
+	return [tensor if tensor.size(-1) == maxlen else torch.cat((tensor, tensor.new_zeros(get_pad_size(tensor.size(), maxlen))), -1) for tensor in tensor_list]
+
+def freeze_module(module):
+
+	for p in module.parameters():
+		if p.requires_grad:
+			p.requires_grad_(False)
+
+def unfreeze_module(module):
+
+	def unfreeze_fixing(mod):
+		if "fix_unfreeze" in dir(mod):
+			mod.fix_unfreeze()
+
+	for p in module.parameters():
+		p.requires_grad_(True)
+
+	module.apply(unfreeze_fixing)
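The new `utils.py` collects `pad_tensors` (moved from `parallel/parallelMT.py`) and `freeze_module`/`unfreeze_module` (moved from `modules.py`). A short usage sketch, assuming `utils.py` is on the import path:

```python
import torch
import torch.nn as nn
from utils import pad_tensors, freeze_module, unfreeze_module

# pad_tensors right-pads with zeros to the longest last dimension
a, b = torch.ones(2, 3), torch.ones(2, 5)
assert all(t.size(-1) == 5 for t in pad_tensors([a, b]))

# freeze_module/unfreeze_module toggle requires_grad on all parameters
m = nn.Linear(4, 4)
freeze_module(m)
assert not any(p.requires_grad for p in m.parameters())
unfreeze_module(m)
assert all(p.requires_grad for p in m.parameters())
```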
