From e24563715010e13b82ee2c49a06e43574732a26d Mon Sep 17 00:00:00 2001 From: ano Date: Mon, 28 Jan 2019 13:49:14 +0100 Subject: [PATCH] fix SelfAttn and length penalty --- README.md | 4 ++-- mkcy.py | 2 +- modules.py | 19 +---------------- parallel/parallelMT.py | 16 ++------------ predict.py | 6 ++++++ transformer/AvgDecoder.py | 3 +++ transformer/Decoder.py | 3 +++ transformer/EnsembleAvgDecoder.py | 3 +++ transformer/EnsembleDecoder.py | 3 +++ transformer/HierAvgDecoder.py | 3 +++ transformer/HierDecoder.py | 3 +++ transformer/InceptAvgDecoder.py | 3 +++ transformer/InceptDecoder.py | 3 +++ transformer/RNMTDecoder.py | 3 +++ utils.py | 35 +++++++++++++++++++++++++++++++ 15 files changed, 74 insertions(+), 35 deletions(-) create mode 100644 utils.py diff --git a/README.md b/README.md index cbeebae..f4bd0dd 100644 --- a/README.md +++ b/README.md @@ -230,7 +230,7 @@ export rsf=trans.txt ## Exporting python files to C libraries -You can convert python classes into C libraries with `python mkcy.py build_ext --inplace`, and codes will be checked before compling, which can serve as a simple to way to find typo and bugs as well. This function is supported by [Cython](https://cython.org/). These files can be removed with `rm -fr *.c *.so parallel/*.c parallel/*.so transformer/*.c transformer/*.so build/`. Loading modules from compiled C libraries may also accelerate, but not significantly. +You can convert python classes into C libraries with `python mkcy.py build_ext --inplace`, and codes will be checked before compiling, which can serve as a simple to way to find typo and bugs as well. This function is supported by [Cython](https://cython.org/). These files can be removed with `rm -fr *.c *.so parallel/*.c parallel/*.so transformer/*.c transformer/*.so build/`. Loading modules from compiled C libraries may also accelerate, but not significantly. 
## Ranking @@ -414,7 +414,7 @@ Measured with `multi-bleu-detok.perl`: ## Acknowledgements -The project starts when Hongfei XU (the developer) was a postgraduate student at [Zhengzhou University](http://www5.zzu.edu.cn/nlp/), and continues when he is a PhD candidate at [Saarland University](https://www.uni-saarland.de/nc/en/home.html) and a Junior Researcher at [DFKI (German Research Center for Artificial Intelligence)](https://www.dfki.de/en/web/research/research-departments-and-groups/multilingual-technologies/). Hongfei XU enjoys a doctoral grant from [China Scholarship Council](https://www.csc.edu.cn/) ([2018]3101, 201807040056) while maintaining this project. +The project starts when Hongfei XU (the developer) was a postgraduate student at [Zhengzhou University](http://www5.zzu.edu.cn/nlp/), and continues when he is a PhD candidate at [Saarland University](https://www.uni-saarland.de/nc/en/home.html) supervised by [Prof. Dr. Josef van Genabith](https://www.dfki.de/en/web/about-us/employee/person/jova02/) and a Junior Researcher at [DFKI, MLT (German Research Center for Artificial Intelligence, Multilinguality and Language Technology)](https://www.dfki.de/en/web/research/research-departments-and-groups/multilinguality-and-language-technology/). Hongfei XU enjoys a doctoral grant from [China Scholarship Council](https://www.csc.edu.cn/) ([2018]3101, 201807040056) while maintaining this project. 
## Contributor(s) diff --git a/mkcy.py b/mkcy.py index 0aaa4b2..ea64d81 100644 --- a/mkcy.py +++ b/mkcy.py @@ -17,7 +17,7 @@ def get_name(fname): eccargs = ["-Ofast", "-march=native", "-pipe", "-fomit-frame-pointer"] - baselist = ["modules.py", "loss.py", "lrsch.py", "rnncell.py", "translator.py", "discriminator.py"] + baselist = ["modules.py", "loss.py", "lrsch.py", "utils.py","rnncell.py", "translator.py", "discriminator.py"] extlist = [Extension(get_name(pyf), [pyf], extra_compile_args=eccargs) for pyf in baselist] diff --git a/modules.py b/modules.py index 24c0eae..e7ef1a9 100644 --- a/modules.py +++ b/modules.py @@ -263,7 +263,7 @@ def forward(self, iQ, mask=None, iK=None): real_iQ, real_iK, real_iV = _out.narrow(-1, 0, self.hsize).contiguous().view(bsize, nquery, nheads, adim).transpose(1, 2), _out.narrow(-1, self.hsize, self.hsize).contiguous().view(bsize, nquery, nheads, adim).transpose(1, 2), _out.narrow(-1, self.hsize + self.hsize, self.hsize).contiguous().view(bsize, nquery, nheads, adim).transpose(1, 2) else: - real_iQ, _out = nnFunc.linear(iQ, self.adaptor.weight.narrow(0, 0, self.hsize), self.adaptor.bias.narrow(0, 0, self.hsize) if self.adaptor.bias else None), nnFunc.linear(iK, self.adaptor.weight.narrow(0, self.hsize, self.hsize + self.hsize), self.adaptor.bias.narrow(0, self.hsize, self.hsize + self.hsize) if self.adaptor.bias else None) + real_iQ, _out = nnFunc.linear(iQ, self.adaptor.weight.narrow(0, 0, self.hsize), self.adaptor.bias.narrow(0, 0, self.hsize) if self.adaptor.bias else None).view(bsize, nquery, nheads, adim).transpose(1, 2), nnFunc.linear(iK, self.adaptor.weight.narrow(0, self.hsize, self.hsize + self.hsize), self.adaptor.bias.narrow(0, self.hsize, self.hsize + self.hsize) if self.adaptor.bias else None) seql = iK.size(1) @@ -356,23 +356,6 @@ def forward(self, iQ, iK, mask=None): return self.outer(oMA.view(bsize, nquery, self.hsize)) -def freeze_module(module): - - for p in module.parameters(): - if p.requires_grad: - 
p.requires_grad_(False) - -def unfreeze_module(module): - - def unfreeze_fixing(mod): - if "fix_unfreeze" in dir(mod): - mod.fix_unfreeze() - - for p in module.parameters(): - p.requires_grad_(True) - - module.apply(unfreeze_fixing) - # Aggregation from: Exploiting Deep Representations for Neural Machine Translation class ResidueCombiner(nn.Module): diff --git a/parallel/parallelMT.py b/parallel/parallelMT.py index cbb09e1..ab3093e 100644 --- a/parallel/parallelMT.py +++ b/parallel/parallelMT.py @@ -4,6 +4,8 @@ from parallel.parallel import DataParallelModel +from utils import pad_tensors + from threading import Lock, Thread class DataParallelMT(DataParallelModel): @@ -40,20 +42,6 @@ def train_decode(self, *inputs, **kwargs): outputs = parallel_apply_train_decode(replicas, inputs, devices, kwargs) return self.gather(pad_tensors(outputs), self.output_device) if self.gather_output else outputs -def pad_tensors(tensor_list): - - def get_pad_size(tsize, stdlen): - nsize = list(tsize) - nsize[-1] = stdlen - tsize[-1] - return nsize - - maxlen = 0 - for tensor in tensor_list: - tlen = tensor.size(-1) - if tlen > maxlen: - maxlen = tlen - return [tensor if tensor.size(-1) == maxlen else torch.cat((tensor, tensor.new_zeros(get_pad_size(tensor.size(), maxlen))), -1) for tensor in tensor_list] - # update these two functions with the update of parallel_apply(https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py) def parallel_apply_decode(modules, inputs, devices, kwargs_tup=None): diff --git a/predict.py b/predict.py index 1db4445..af69fe9 100644 --- a/predict.py +++ b/predict.py @@ -82,6 +82,12 @@ def reverse_dict(din): return rs +def load_model_cpu_old(modf, base_model): + + base_model.load_state_dict(torch.load(modf, map_location='cpu')) + + return base_model + def load_model_cpu(modf, base_model): mpg = torch.load(modf, map_location='cpu') diff --git a/transformer/AvgDecoder.py b/transformer/AvgDecoder.py index 9ff1dcc..85e9a66 100644 --- 
a/transformer/AvgDecoder.py +++ b/transformer/AvgDecoder.py @@ -368,6 +368,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/transformer/Decoder.py b/transformer/Decoder.py index 147ceb8..d3a55f8 100644 --- a/transformer/Decoder.py +++ b/transformer/Decoder.py @@ -407,6 +407,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/transformer/EnsembleAvgDecoder.py b/transformer/EnsembleAvgDecoder.py index 556e3e7..d119f82 100644 --- a/transformer/EnsembleAvgDecoder.py +++ b/transformer/EnsembleAvgDecoder.py @@ -292,6 +292,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, 
device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/transformer/EnsembleDecoder.py b/transformer/EnsembleDecoder.py index 27e0e62..dac4c4c 100644 --- a/transformer/EnsembleDecoder.py +++ b/transformer/EnsembleDecoder.py @@ -312,6 +312,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/transformer/HierAvgDecoder.py b/transformer/HierAvgDecoder.py index 63b34dd..19116d5 100644 --- a/transformer/HierAvgDecoder.py +++ b/transformer/HierAvgDecoder.py @@ -250,6 +250,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/transformer/HierDecoder.py b/transformer/HierDecoder.py index 6b0f39d..2ffc5ef 100644 --- a/transformer/HierDecoder.py +++ b/transformer/HierDecoder.py @@ -246,6 +246,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply 
length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/transformer/InceptAvgDecoder.py b/transformer/InceptAvgDecoder.py index 52911f4..9b29e2c 100644 --- a/transformer/InceptAvgDecoder.py +++ b/transformer/InceptAvgDecoder.py @@ -244,6 +244,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/transformer/InceptDecoder.py b/transformer/InceptDecoder.py index 6cceba7..2801715 100644 --- a/transformer/InceptDecoder.py +++ b/transformer/InceptDecoder.py @@ -240,6 +240,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/transformer/RNMTDecoder.py b/transformer/RNMTDecoder.py index 
043a7d3..c14713a 100644 --- a/transformer/RNMTDecoder.py +++ b/transformer/RNMTDecoder.py @@ -423,6 +423,9 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) if return_all: diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..7619598 --- /dev/null +++ b/utils.py @@ -0,0 +1,35 @@ +#encoding: utf-8 + +import torch + +def pad_tensors(tensor_list): + + def get_pad_size(tsize, stdlen): + nsize = list(tsize) + nsize[-1] = stdlen - tsize[-1] + return nsize + + maxlen = 0 + for tensor in tensor_list: + tlen = tensor.size(-1) + if tlen > maxlen: + maxlen = tlen + return [tensor if tensor.size(-1) == maxlen else torch.cat((tensor, tensor.new_zeros(get_pad_size(tensor.size(), maxlen))), -1) for tensor in tensor_list] + +def freeze_module(module): + + for p in module.parameters(): + if p.requires_grad: + p.requires_grad_(False) + +def unfreeze_module(module): + + def unfreeze_fixing(mod): + if "fix_unfreeze" in dir(mod): + mod.fix_unfreeze() + + for p in module.parameters(): + p.requires_grad_(True) + + module.apply(unfreeze_fixing) +