diff --git a/README.md b/README.md index a666237..fa3b5f1 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Neutron -Neutron: A pytorch based implementation of [Transformer](https://arxiv.org/abs/1706.03762) and its variants. +Neutron: A pytorch based implementation of the [Transformer](https://arxiv.org/abs/1706.03762) and its variants. This project is developed with python 3.8. @@ -96,11 +96,11 @@ Tokenized case-sensitive BLEU measured with [multi-bleu.perl](https://github.com | | BLEU | Training Speed | Decoding Speed | | :------| ------: | ------: | ------: | | Attention is all you need | 27.3 | | | -| Neutron | 28.07 | 21562.98 | 68.25 | +| Neutron | 28.07 | 22424.63 | 150.15 | ## Acknowledgments -The project starts when Hongfei XU (the developer) was a postgraduate student at [Zhengzhou University](http://www5.zzu.edu.cn/nlp/), and continues when he is a PhD candidate at [Saarland University](https://www.uni-saarland.de/nc/en/home.html) supervised by [Prof. Dr. Josef van Genabith](https://www.dfki.de/en/web/about-us/employee/person/jova02/) and [Prof. Dr. Deyi Xiong](http://cic.tju.edu.cn/faculty/xiongdeyi/), and a Junior Researcher at [DFKI, MLT (German Research Center for Artificial Intelligence, Multilinguality and Language Technology)](https://www.dfki.de/en/web/research/research-departments-and-groups/multilinguality-and-language-technology/). Hongfei XU enjoys a doctoral grant from [China Scholarship Council](https://www.csc.edu.cn/) ([2018]3101, 201807040056) while maintaining this project. +Hongfei Xu enjoys a doctoral grant from [China Scholarship Council](https://www.csc.edu.cn/) ([2018]3101, 201807040056) while maintaining this project. Details of this project can be found [here](https://arxiv.org/abs/1903.07402), and please cite it if you enjoy the implementation :) diff --git a/adv/predict/doc/para/predict_doc_para.py b/adv/predict/doc/para/predict_doc_para.py index 38f89d6..eee8409 100644 --- a/adv/predict/doc/para/predict_doc_para.py +++ b/adv/predict/doc/para/predict_doc_para.py @@ -57,7 +57,7 @@ def load_fixing(module): # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) @@ -75,9 +75,10 @@ def load_fixing(module): with open(sys.argv[1], "wb") as f: with torch.no_grad(): for nsent, i_d in tqdm(tl): - seq_batch = torch.from_numpy(src_grp[nsent][i_d][:]).long() - if use_cuda: + seq_batch = torch.from_numpy(src_grp[nsent][i_d][:]) + if cuda_device: seq_batch = seq_batch.to(cuda_device) + seq_batch = seq_batch.long() bsize, _nsent, seql = seq_batch.size() _nsent_use = _nsent - 1 with autocast(enabled=use_amp): diff --git a/adv/predict/predict_ape.py b/adv/predict/predict_ape.py index 5c969d7..1038c55 100644 --- a/adv/predict/predict_ape.py +++ b/adv/predict/predict_ape.py @@ -57,7 +57,7 @@ def load_fixing(module): set_random_seed(cnfg.seed, use_cuda) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) @@ -73,11 +73,12 @@ def load_fixing(module): with open(sys.argv[1], "wb") as f: with torch.no_grad(): for i in tqdm(range(ntest)): - seq_batch = torch.from_numpy(src_grp[str(i)][:]).long() - seq_mt = torch.from_numpy(mt_grp[str(i)][:]).long() - if use_cuda: + seq_batch = 
torch.from_numpy(src_grp[str(i)][:]) + seq_mt = torch.from_numpy(mt_grp[str(i)][:]) + if cuda_device: seq_batch = seq_batch.to(cuda_device) seq_mt = seq_mt.to(cuda_device) + seq_batch, seq_mt = seq_batch.long(), seq_mt.long() with autocast(enabled=use_amp): output = mymodel.decode(seq_batch, seq_mt, beam_size, None, length_penalty) if multi_gpu: diff --git a/adv/rank/doc/para/rank_loss_para.py b/adv/rank/doc/para/rank_loss_para.py index a3bd653..6f9e406 100644 --- a/adv/rank/doc/para/rank_loss_para.py +++ b/adv/rank/doc/para/rank_loss_para.py @@ -65,7 +65,7 @@ def load_fixing(module): # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) lossf.to(cuda_device) if multi_gpu: @@ -81,12 +81,13 @@ def load_fixing(module): with torch.no_grad(): for i in tqdm(range(ntest)): _curid = str(i) - seq_batch = torch.from_numpy(src_grp[_curid][:]).long() - seq_o = torch.from_numpy(tgt_grp[_curid][:]).long() + seq_batch = torch.from_numpy(src_grp[_curid][:]) + seq_o = torch.from_numpy(tgt_grp[_curid][:]) lo = seq_o.size(-1) - 1 - if use_cuda: + if cuda_device: seq_batch = seq_batch.to(cuda_device) seq_o = seq_o.to(cuda_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() bsize, _nsent = seq_batch.size()[:2] _nsent_use = _nsent - 1 seq_o = seq_o.narrow(1, 1, _nsent_use) diff --git a/adv/rank/doc/rank_loss_sent.py b/adv/rank/doc/rank_loss_sent.py index a126bf7..61076e7 100644 --- a/adv/rank/doc/rank_loss_sent.py +++ b/adv/rank/doc/rank_loss_sent.py @@ -65,7 +65,7 @@ def load_fixing(module): # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) lossf.to(cuda_device) if multi_gpu: @@ -79,13 +79,14 @@ def load_fixing(module): with torch.no_grad(): for i in tqdm(range(ntest)): _curid = str(i) - seq_batch = torch.from_numpy(src_grp[_curid][:]).long() - seq_o = torch.from_numpy(tgt_grp[_curid][:]).long() + seq_batch = torch.from_numpy(src_grp[_curid][:]) + seq_o = torch.from_numpy(tgt_grp[_curid][:]) bsize, nsent = seq_batch.size()[:2] ebsize = bsize * nsent - if use_cuda: + if cuda_device: seq_batch = seq_batch.to(cuda_device) seq_o = seq_o.to(cuda_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() lo = seq_o.size(-1) - 1 ot = seq_o.narrow(-1, 1, lo).contiguous() with autocast(enabled=use_amp): diff --git a/adv/train/doc/para/train_doc_para.py b/adv/train/doc/para/train_doc_para.py index 3e2b21d..50672ea 100644 --- a/adv/train/doc/para/train_doc_para.py +++ b/adv/train/doc/para/train_doc_para.py @@ -42,12 +42,13 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok src_grp, tgt_grp = td["src"], td["tgt"] for nsent, i_d in tqdm(tl): - seq_batch = torch.from_numpy(src_grp[nsent][i_d][:]).long() - seq_o = torch.from_numpy(tgt_grp[nsent][i_d][:]).long() + seq_batch = torch.from_numpy(src_grp[nsent][i_d][:]) + seq_o = torch.from_numpy(tgt_grp[nsent][i_d][:]) lo = seq_o.size(-1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() _nsent = seq_batch.size(1) _nsent_use = _nsent - 1 @@ -145,12 +146,13 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): src_grp, tgt_grp = ed["src"], ed["tgt"] with torch.no_grad(): for nsent, i_d in tqdm(nd): - seq_batch = torch.from_numpy(src_grp[nsent][i_d][:]).long() - seq_o = torch.from_numpy(tgt_grp[nsent][i_d][:]).long() + seq_batch = 
torch.from_numpy(src_grp[nsent][i_d][:]) + seq_o = torch.from_numpy(tgt_grp[nsent][i_d][:]) lo = seq_o.size(-1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() _nsent = seq_batch.size(1) _nsent_use = _nsent - 1 @@ -261,7 +263,7 @@ def init_fixing(module): logger.info("Load target embedding from: " + cnfg.tgt_emb) load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) lossf.to(cuda_device) diff --git a/adv/train/train_ape.py b/adv/train/train_ape.py index 0d35118..84d799c 100644 --- a/adv/train/train_ape.py +++ b/adv/train/train_ape.py @@ -40,14 +40,15 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok cur_b, _ls = 1, {} if save_loss else None src_grp, mt_grp, tgt_grp = td["src"], td["mt"], td["tgt"] for i_d in tqdm(tl): - seq_batch = torch.from_numpy(src_grp[i_d][:]).long() - seq_mt = torch.from_numpy(mt_grp[i_d][:]).long() - seq_o = torch.from_numpy(tgt_grp[i_d][:]).long() + seq_batch = torch.from_numpy(src_grp[i_d][:]) + seq_mt = torch.from_numpy(mt_grp[i_d][:]) + seq_o = torch.from_numpy(tgt_grp[i_d][:]) lo = seq_o.size(1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_mt = seq_mt.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_mt, seq_o = seq_batch.long(), seq_mt.long(), seq_o.long() oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() @@ -142,14 +143,15 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): with torch.no_grad(): for i in tqdm(range(nd)): bid = str(i) - seq_batch = torch.from_numpy(src_grp[bid][:]).long() - seq_mt = torch.from_numpy(mt_grp[bid][:]).long() - seq_o = torch.from_numpy(tgt_grp[bid][:]).long() + seq_batch = torch.from_numpy(src_grp[bid][:]) + seq_mt = torch.from_numpy(mt_grp[bid][:]) + seq_o = torch.from_numpy(tgt_grp[bid][:]) lo = seq_o.size(1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_mt = seq_mt.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_mt, seq_o = seq_batch.long(), seq_mt.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() with autocast(enabled=use_amp): output = model(seq_batch, seq_mt, seq_o.narrow(1, 0, lo)) @@ -251,7 +253,7 @@ def init_fixing(module): logger.info("Load target embedding from: " + cnfg.tgt_emb) load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) lossf.to(cuda_device) diff --git a/adv/train/train_dynb.py b/adv/train/train_dynb.py index d800e8c..cde714a 100644 --- a/adv/train/train_dynb.py +++ b/adv/train/train_dynb.py @@ -56,12 +56,13 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok src_grp, tgt_grp = td["src"], td["tgt"] for i_d in tqdm(tl): - seq_batch = torch.from_numpy(src_grp[i_d][:]).long() - seq_o = torch.from_numpy(tgt_grp[i_d][:]).long() + seq_batch = torch.from_numpy(src_grp[i_d][:]) + seq_o = torch.from_numpy(tgt_grp[i_d][:]) lo = seq_o.size(1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() @@ -169,12 +170,13 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): with torch.no_grad(): for i in tqdm(range(nd)): bid = str(i) - seq_batch = torch.from_numpy(src_grp[bid][:]).long() - seq_o = 
torch.from_numpy(tgt_grp[bid][:]).long() + seq_batch = torch.from_numpy(src_grp[bid][:]) + seq_o = torch.from_numpy(tgt_grp[bid][:]) lo = seq_o.size(1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() with autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) @@ -272,7 +274,7 @@ def init_fixing(module): logger.info("Load target embedding from: " + cnfg.tgt_emb) load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) lossf.to(cuda_device) diff --git a/cnfg/README.md b/cnfg/README.md index a8736ba..52ee3f2 100644 --- a/cnfg/README.md +++ b/cnfg/README.md @@ -152,10 +152,13 @@ cache_len_default = 256 use_k_relative_position = 0 disable_std_pemb = False +# using a fast implementation of the label smoothing loss, but it cannot exclude the negative impact of special tokens, like <pad>, on training. `forbidden_indexes` in `cnfg/base.py` has to be set to None to enable it. +use_fast_loss = False + # configure maximum batch size w.r.t GPU memory -max_sentences_gpu = 768 -max_tokens_gpu = 4608 -max_pad_tokens_sentence = 16 +max_sentences_gpu = 2048 +max_tokens_gpu = 6144 +max_pad_tokens_sentence = 32 normal_tokens_vs_pad_tokens = 4 # trade CPU for IO and disk space, see [h5py](http://docs.h5py.org/en/stable/high/dataset.html) for details. @@ -168,11 +171,14 @@ hdf5_model_compression_level = 0 # For BPE (using full vocabulary), the special token <unk> will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. use_unk = True + +# prune with length penalty in each beam decoding step +clip_beam_with_lp = True ``` ## `ihyp.py` -To interpret configurations in hyp.py. +To interpret configurations in `hyp.py`. ## `dynb.py` diff --git a/cnfg/hyp.py b/cnfg/hyp.py index 8d38b6e..ea07a8d 100644 --- a/cnfg/hyp.py +++ b/cnfg/hyp.py @@ -13,7 +13,7 @@ # default cached sequence length (for positional embedding, etc.) cache_len_default = 256 -# window size (one side) of relative positional embeddings, 0 to disable. 16 and 8 are used in [Self-Attention with Relative Position Representations](https://www.aclweb.org/anthology/N18-2074/) for Transformer Base and Big respectively. disable_std_pemb to disable the standard positional embedding when use the relative position, or to disable only the decoder side with a tuple (False, True,), useful for AAN. +# window size (one side) of relative positional embeddings, 0 to disable. 8 and 16 are used in [Self-Attention with Relative Position Representations](https://www.aclweb.org/anthology/N18-2074/) for Transformer Base and Big respectively. disable_std_pemb to disable the standard positional embedding when use the relative position, or to disable only the decoder side with a tuple (False, True,), useful for AAN. use_k_relative_position = 0 disable_std_pemb = False @@ -21,9 +21,9 @@ use_fast_loss = False # configure maximum batch size w.r.t GPU memory -max_sentences_gpu = 768 -max_tokens_gpu = 4608 -max_pad_tokens_sentence = 16 +max_sentences_gpu = 2048 +max_tokens_gpu = 6144 +max_pad_tokens_sentence = 32 normal_tokens_vs_pad_tokens = 4 # trade CPU for IO and disk space, see [h5py](http://docs.h5py.org/en/stable/high/dataset.html) for details.
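For reference, the batch-loading change repeated across the training and decoding scripts above boils down to the pattern below (a minimal sketch; the helper `load_batch` and its arguments are illustrative, not part of the repository): the tensor keeps the narrow integer dtype stored in HDF5 for the host-to-device copy and is only cast to `int64` once it is on the device.

```python
# Minimal sketch of the batch-loading pattern adopted above.
# Assumption: load_batch is a hypothetical helper, not a function of this repository.
import torch

def load_batch(grp, key, device=None):
    seq = torch.from_numpy(grp[key][:])  # keep the narrow on-disk integer dtype
    if device:
        seq = seq.to(device)  # transfer the smaller tensor to the GPU first
    return seq.long()  # cast to int64 on the device; embedding lookups need long indices

# e.g. seq_batch = load_batch(src_grp, str(i), cuda_device)
```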
diff --git a/cnfg/ihyp.py b/cnfg/ihyp.py index e26d27c..9f55949 100644 --- a/cnfg/ihyp.py +++ b/cnfg/ihyp.py @@ -45,6 +45,7 @@ use_k_relative_position_encoder, use_k_relative_position_decoder = parse_double_value_tuple(use_k_relative_position) rel_pos_enabled = (max(use_k_relative_position_encoder, use_k_relative_position_decoder) > 0) disable_std_pemb_encoder, disable_std_pemb_decoder = parse_double_value_tuple(disable_std_pemb) +relpos_reduction_with_zeros = True h5datawargs = {} if hdf5_data_compression is None else {"compression": hdf5_data_compression, "compression_opts": hdf5_data_compression_level, "shuffle":True} h5modelwargs = {} if hdf5_model_compression is None else {"compression": hdf5_model_compression, "compression_opts": hdf5_model_compression_level, "shuffle":True} diff --git a/modules/attn/res.py b/modules/attn/res.py index e4c3067..cb61ebd 100644 --- a/modules/attn/res.py +++ b/modules/attn/res.py @@ -1,5 +1,6 @@ #encoding: utf-8 +import torch from torch.nn import functional as nnFunc from math import sqrt @@ -10,29 +11,27 @@ class SelfAttn(SelfAttnBase): - def forward(self, iQ, mask=None, iK=None, resin=None): + def forward(self, iQ, mask=None, states=None, resin=None): bsize, nquery = iQ.size()[:2] nheads = self.num_head adim = self.attn_dim - if iK is None: - - real_iQ, real_iK, real_iV = self.adaptor(iQ).view(bsize, nquery, 3, nheads, adim).unbind(2) - - else: - - seql = iK.size(1) - - real_iQ, _out = nnFunc.linear(iQ, self.adaptor.weight.narrow(0, 0, self.hsize), None if self.adaptor.bias is None else self.adaptor.bias.narrow(0, 0, self.hsize)).view(bsize, nquery, nheads, adim), nnFunc.linear(iK, self.adaptor.weight.narrow(0, self.hsize, self.hsize + self.hsize), None if self.adaptor.bias is None else self.adaptor.bias.narrow(0, self.hsize, self.hsize + self.hsize)).view(bsize, seql, 2, nheads, adim) - real_iK, real_iV = _out.unbind(2) - + real_iQ, real_iK, real_iV = self.adaptor(iQ).view(bsize, nquery, 3, nheads, adim).unbind(2) real_iQ, real_iK, real_iV = real_iQ.transpose(1, 2), real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) + if states is not None: + _h_real_iK, _h_real_iV = states + if _h_real_iK is None: + seql = nquery + else: + seql = nquery + _h_real_iK.size(-1) + real_iK, real_iV = torch.cat((_h_real_iK, real_iK,), dim=-1), torch.cat((_h_real_iV, real_iV,), dim=2) + scores = real_iQ.matmul(real_iK) if self.rel_pemb is not None: - if iK is None: + if states is None: self.rel_pos_cache = self.get_rel_pos(nquery).contiguous() if self.ref_rel_posm is None else self.ref_rel_posm.rel_pos_cache scores += real_iQ.permute(2, 0, 1, 3).contiguous().view(nquery, bsize * nheads, adim).bmm(self.rel_pemb(self.rel_pos_cache).transpose(1, 2)).view(nquery, bsize, nheads, nquery).permute(1, 2, 0, 3) else: @@ -54,9 +53,12 @@ def forward(self, iQ, mask=None, iK=None, resin=None): if self.drop is not None: scores = self.drop(scores) - oMA = scores.matmul(real_iV).transpose(1, 2).contiguous() + out = self.outer(scores.matmul(real_iV).transpose(1, 2).contiguous().view(bsize, nquery, self.hsize)) - return self.outer(oMA.view(bsize, nquery, self.hsize)), resout + if states is None: + return out, resout + else: + return out, (real_iK, real_iV,), resout class CrossAttn(CrossAttnBase): @@ -67,10 +69,14 @@ def forward(self, iQ, iK, mask=None, resin=None): nheads = self.num_head adim = self.attn_dim - real_iQ, _out = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim), self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim) - real_iK, real_iV = _out.unbind(2) - - real_iQ, 
real_iK, real_iV = real_iQ.transpose(1, 2), real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) + real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) + if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + real_iK, real_iV = self.real_iK, self.real_iV + else: + real_iK, real_iV = self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim).unbind(2) + real_iK, real_iV = real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) + if not self.training: + self.iK, self.real_iK, self.real_iV = iK, real_iK, real_iV scores = real_iQ.matmul(real_iK) / sqrt(adim) @@ -87,6 +93,4 @@ def forward(self, iQ, iK, mask=None, resin=None): if self.drop is not None: scores = self.drop(scores) - oMA = scores.matmul(real_iV).transpose(1, 2).contiguous() - - return self.outer(oMA.view(bsize, nquery, self.hsize)), resout + return self.outer(scores.matmul(real_iV).transpose(1, 2).contiguous().view(bsize, nquery, self.hsize)), resout diff --git a/modules/base.py b/modules/base.py index 0e75e04..ef19fad 100644 --- a/modules/base.py +++ b/modules/base.py @@ -6,7 +6,7 @@ from torch.nn import functional as nnFunc from torch.autograd import Function -from utils.base import reduce_model_list +from utils.base import reduce_model_list, repeat_bsize_for_beam_tensor from modules.act import Custom_Act from modules.act import reduce_model as reduce_model_act from modules.dropout import Dropout @@ -114,10 +114,14 @@ class MultiHeadAttn(nn.Module): # osize: output size of this layer # num_head: number of heads # dropout: dropout probability + # k_rel_pos: uni-directional window size of relative positional encoding + # uni_direction_reduction: performing resource reduction for uni-directional self-attention + # is_left_to_right_reduction: only for uni_direction_reduction, indicating left-to-right self-attention or right-to-left + # zero_reduction: only for uni_direction_reduction, using zeros for padding positions in the relative positional matrix # sparsenorm: using sparse normer or standard softmax # bind_qk: query and key can share a same linear transformation for the Reformer: The Efficient Transformer (https://arxiv.org/abs/2001.04451) paper. 
- def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=0, sparsenorm=False, bind_qk=False, xseql=cache_len_default): + def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=0, uni_direction_reduction=False, is_left_to_right_reduction=True, zero_reduction=relpos_reduction_with_zeros, sparsenorm=False, bind_qk=False, xseql=cache_len_default): super(MultiHeadAttn, self).__init__() @@ -138,22 +142,47 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v self.drop = Dropout(dropout, inplace=sparsenorm) if dropout > 0.0 else None if k_rel_pos > 0: - self.k_rel_pos = k_rel_pos - self.rel_pemb = nn.Embedding(k_rel_pos * 2 + 1, self.attn_dim) + self.rel_shift = k_rel_pos + padding_idx = None + if uni_direction_reduction: + _n_pemb = k_rel_pos + 1 + if is_left_to_right_reduction: + self.clamp_min, self.clamp_max = -k_rel_pos, 0, + else: + self.clamp_min, self.clamp_max, self.rel_shift = 0, k_rel_pos, 0 + if zero_reduction: + _n_pemb += 1 + if is_left_to_right_reduction: + self.clamp_max += 1 + padding_idx = self.clamp_max + else: + self.clamp_min -= 1 + self.rel_shift += 1 + padding_idx = 0 + else: + _n_pemb = k_rel_pos + k_rel_pos + 1 + self.clamp_min, self.clamp_max = -k_rel_pos, k_rel_pos + self.rel_pemb = nn.Embedding(_n_pemb, self.attn_dim, padding_idx=padding_idx) _rpm = torch.arange(-xseql + 1, 1, dtype=torch.long).unsqueeze(0) - self.register_buffer("rel_pos", (_rpm - _rpm.t()).clamp(min=-k_rel_pos, max=k_rel_pos) + k_rel_pos) + self.register_buffer("rel_pos", (_rpm - _rpm.t()).clamp(min=self.clamp_min, max=self.clamp_max) + self.rel_shift) self.xseql = xseql # the buffer can be shared inside the encoder or the decoder across layers for saving memory, by setting self.ref_rel_posm of self attns in deep layers to SelfAttn in layer 0, and sharing corresponding self.rel_pos self.ref_rel_posm = None + self.register_buffer("rel_pos_cache", None) else: self.rel_pemb = None + self.register_buffer('real_iK', None) + self.register_buffer('real_iV', None) + self.register_buffer('iK', None) + self.register_buffer('iV', None) + # iQ: query (bsize, num_query, vsize) # iK: keys (bsize, seql, vsize) # iV: values (bsize, seql, vsize) # mask (bsize, num_query, seql) - def forward(self, iQ, iK, iV, mask=None): + def forward(self, iQ, iK, iV, mask=None, states=None): bsize, nquery = iQ.size()[:2] seql = iK.size(1) @@ -164,7 +193,26 @@ def forward(self, iQ, iK, iV, mask=None): # real_iK: MultiHead iK (bsize, seql, vsize) => (bsize, nheads, adim, seql) # real_iV: MultiHead iV (bsize, seql, vsize) => (bsize, nheads, seql, adim) - real_iQ, real_iK, real_iV = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2), self.key_adaptor(iK).view(bsize, seql, nheads, adim).permute(0, 2, 3, 1), self.value_adaptor(iV).view(bsize, seql, nheads, adim).transpose(1, 2) + real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) + + if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + real_iK = self.real_iK + else: + real_iK = self.key_adaptor(iK).view(bsize, seql, nheads, adim).permute(0, 2, 3, 1) + if not self.training: + self.iK, self.real_iK = iK, real_iK + if (self.real_iV is not None) and self.iV.is_set_to(iV) and (not self.training): + real_iV = self.real_iV + 
else: + real_iV = self.value_adaptor(iV).view(bsize, seql, nheads, adim).transpose(1, 2) + if not self.training: + self.iV, self.real_iV = iV, real_iV + + if states is not None: + _h_real_iK, _h_real_iV = states + if _h_real_iK is not None: + seql += _h_real_iK.size(-1) + real_iK, real_iV = torch.cat((_h_real_iK, real_iK,), dim=-1), torch.cat((_h_real_iV, real_iV,), dim=2) # scores (bsize, nheads, nquery, adim) * (bsize, nheads, adim, seql) => (bsize, nheads, nquery, seql) @@ -184,13 +232,23 @@ def forward(self, iQ, iK, iV, mask=None): if self.drop is not None: scores = self.drop(scores) - # oMA: output of MultiHeadAttention T((bsize, nheads, nquery, seql) * (bsize, nheads, seql, adim)) => (bsize, nquery, nheads, adim) + # output of this layer T((bsize, nheads, nquery, seql) * (bsize, nheads, seql, adim)) => (bsize, nquery, nheads, adim) => (bsize, nquery, osize) + + out = self.outer(scores.matmul(real_iV).transpose(1, 2).contiguous().view(bsize, nquery, self.hsize)) + + if states is None: + return out + else: + return out, (real_iK, real_iV,) + + def train(self, mode=True): - oMA = scores.matmul(real_iV).transpose(1, 2).contiguous() + super(MultiHeadAttn, self).train(mode) - # output of this layer (bsize, nquery, nheads, adim) => (bsize, nquery, osize) + if mode: + self.reset_buffer() - return self.outer(oMA.view(bsize, nquery, self.hsize)) + return self def get_rel_pos(self, length): @@ -198,7 +256,25 @@ def get_rel_pos(self, length): return self.rel_pos.narrow(0, 0, length).narrow(1, 0, length) else: _rpm = torch.arange(-length + 1, 1, dtype=self.rel_pos.dtype, device=self.rel_pos.device).unsqueeze(0) - return ((_rpm - _rpm.t()).clamp(min=-self.k_rel_pos, max=self.k_rel_pos) + self.k_rel_pos) + return ((_rpm - _rpm.t()).clamp(min=self.clamp_min, max=self.clamp_max) + self.rel_shift) + + def reset_buffer(self, value=None): + + self.iK = self.iV = self.real_iK = self.real_iV = self.rel_pos_cache = value + + def repeat_buffer(self, beam_size): + + if self.real_iK is not None: + self.real_iK = repeat_bsize_for_beam_tensor(self.real_iK, beam_size) + if self.real_iV is not None: + self.real_iV = repeat_bsize_for_beam_tensor(self.real_iV, beam_size) + + def index_buffer(self, indices, dim=0): + + if self.real_iK is not None: + self.real_iK = self.real_iK.index_select(dim, indices) + if self.real_iV is not None: + self.real_iV = self.real_iV.index_select(dim, indices) # Average Attention is proposed in Accelerating Neural Transformer via an Average Attention Network (https://www.aclweb.org/anthology/P18-1166/) class AverageAttn(nn.Module): @@ -254,7 +330,7 @@ def get_ext(self, npos): # Accelerated MultiHeadAttn for self attention, use when Q == K == V class SelfAttn(nn.Module): - def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=use_k_relative_position, sparsenorm=False, xseql=cache_len_default): + def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=use_k_relative_position, uni_direction_reduction=False, is_left_to_right_reduction=True, zero_reduction=relpos_reduction_with_zeros, sparsenorm=False, xseql=cache_len_default): super(SelfAttn, self).__init__() @@ -272,10 +348,29 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=ena self.drop = Dropout(dropout, inplace=sparsenorm) if dropout > 0.0 else None if k_rel_pos > 0: - self.k_rel_pos = k_rel_pos - 
self.rel_pemb = nn.Embedding(k_rel_pos * 2 + 1, self.attn_dim) + self.rel_shift = k_rel_pos + padding_idx = None + if uni_direction_reduction: + _n_pemb = k_rel_pos + 1 + if is_left_to_right_reduction: + self.clamp_min, self.clamp_max = -k_rel_pos, 0, + else: + self.clamp_min, self.clamp_max, self.rel_shift = 0, k_rel_pos, 0 + if zero_reduction: + _n_pemb += 1 + if is_left_to_right_reduction: + self.clamp_max += 1 + padding_idx = self.clamp_max + else: + self.clamp_min -= 1 + self.rel_shift += 1 + padding_idx = 0 + else: + _n_pemb = k_rel_pos + k_rel_pos + 1 + self.clamp_min, self.clamp_max = -k_rel_pos, k_rel_pos + self.rel_pemb = nn.Embedding(_n_pemb, self.attn_dim, padding_idx=padding_idx) _rpm = torch.arange(-xseql + 1, 1, dtype=torch.long).unsqueeze(0) - self.register_buffer("rel_pos", (_rpm - _rpm.t()).clamp(min=-k_rel_pos, max=k_rel_pos) + k_rel_pos) + self.register_buffer("rel_pos", (_rpm - _rpm.t()).clamp(min=self.clamp_min, max=self.clamp_max) + self.rel_shift) self.xseql = xseql # the buffer can be shared inside the encoder or the decoder across layers for saving memory, by setting self.ref_rel_posm of self attns in deep layers to SelfAttn in layer 0, and sharing corresponding self.rel_pos self.ref_rel_posm = None @@ -283,29 +378,27 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=ena else: self.rel_pemb = None - def forward(self, iQ, mask=None, iK=None): + def forward(self, iQ, mask=None, states=None): bsize, nquery = iQ.size()[:2] nheads = self.num_head adim = self.attn_dim - if iK is None: - - real_iQ, real_iK, real_iV = self.adaptor(iQ).view(bsize, nquery, 3, nheads, adim).unbind(2) - - else: - - seql = iK.size(1) - - real_iQ, _out = nnFunc.linear(iQ, self.adaptor.weight.narrow(0, 0, self.hsize), None if self.adaptor.bias is None else self.adaptor.bias.narrow(0, 0, self.hsize)).view(bsize, nquery, nheads, adim), nnFunc.linear(iK, self.adaptor.weight.narrow(0, self.hsize, self.hsize + self.hsize), None if self.adaptor.bias is None else self.adaptor.bias.narrow(0, self.hsize, self.hsize + self.hsize)).view(bsize, seql, 2, nheads, adim) - real_iK, real_iV = _out.unbind(2) - + real_iQ, real_iK, real_iV = self.adaptor(iQ).view(bsize, nquery, 3, nheads, adim).unbind(2) real_iQ, real_iK, real_iV = real_iQ.transpose(1, 2), real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) + if states is not None: + _h_real_iK, _h_real_iV = states + if _h_real_iK is None: + seql = nquery + else: + seql = nquery + _h_real_iK.size(-1) + real_iK, real_iV = torch.cat((_h_real_iK, real_iK,), dim=-1), torch.cat((_h_real_iV, real_iV,), dim=2) + scores = real_iQ.matmul(real_iK) if self.rel_pemb is not None: - if iK is None: + if states is None: self.rel_pos_cache = self.get_rel_pos(nquery).contiguous() if self.ref_rel_posm is None else self.ref_rel_posm.rel_pos_cache scores += real_iQ.permute(2, 0, 1, 3).contiguous().view(nquery, bsize * nheads, adim).bmm(self.rel_pemb(self.rel_pos_cache).transpose(1, 2)).view(nquery, bsize, nheads, nquery).permute(1, 2, 0, 3) else: @@ -322,9 +415,12 @@ def forward(self, iQ, mask=None, iK=None): if self.drop is not None: scores = self.drop(scores) - oMA = scores.matmul(real_iV).transpose(1, 2).contiguous() + out = self.outer(scores.matmul(real_iV).transpose(1, 2).contiguous().view(bsize, nquery, self.hsize)) - return self.outer(oMA.view(bsize, nquery, self.hsize)) + if states is None: + return out + else: + return out, (real_iK, real_iV,) def get_rel_pos(self, length): @@ -332,7 +428,11 @@ def get_rel_pos(self, length): return 
self.rel_pos.narrow(0, 0, length).narrow(1, 0, length) else: _rpm = torch.arange(-length + 1, 1, dtype=self.rel_pos.dtype, device=self.rel_pos.device).unsqueeze(0) - return ((_rpm - _rpm.t()).clamp(min=-self.k_rel_pos, max=self.k_rel_pos) + self.k_rel_pos) + return ((_rpm - _rpm.t()).clamp(min=self.clamp_min, max=self.clamp_max) + self.rel_shift) + + def reset_buffer(self, value=None): + + self.rel_pos_cache = value # Accelerated MultiHeadAttn for cross attention, use when K == V class CrossAttn(nn.Module): @@ -356,6 +456,10 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, e self.drop = Dropout(dropout, inplace=sparsenorm) if dropout > 0.0 else None + self.register_buffer('real_iK', None) + self.register_buffer('real_iV', None) + self.register_buffer('iK', None) + def forward(self, iQ, iK, mask=None): bsize, nquery = iQ.size()[:2] @@ -363,10 +467,14 @@ def forward(self, iQ, iK, mask=None): nheads = self.num_head adim = self.attn_dim - real_iQ, _out = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim), self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim) - real_iK, real_iV = _out.unbind(2) - - real_iQ, real_iK, real_iV = real_iQ.transpose(1, 2), real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) + real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) + if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + real_iK, real_iV = self.real_iK, self.real_iV + else: + real_iK, real_iV = self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim).unbind(2) + real_iK, real_iV = real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) + if not self.training: + self.iK, self.real_iK, self.real_iV = iK, real_iK, real_iV scores = real_iQ.matmul(real_iK) / sqrt(adim) @@ -378,9 +486,30 @@ def forward(self, iQ, iK, mask=None): if self.drop is not None: scores = self.drop(scores) - oMA = scores.matmul(real_iV).transpose(1, 2).contiguous() + return self.outer(scores.matmul(real_iV).transpose(1, 2).contiguous().view(bsize, nquery, self.hsize)) + + def train(self, mode=True): + + super(CrossAttn, self).train(mode) + + if mode: + self.reset_buffer() + + return self + + def reset_buffer(self, value=None): + + self.iK = self.real_iK = self.real_iV = value + + def repeat_buffer(self, beam_size): + + if self.real_iK is not None: + self.real_iK, self.real_iV = repeat_bsize_for_beam_tensor(self.real_iK, beam_size), repeat_bsize_for_beam_tensor(self.real_iV, beam_size) + + def index_buffer(self, indices, dim=0): - return self.outer(oMA.view(bsize, nquery, self.hsize)) + if self.real_iK is not None: + self.real_iK, self.real_iV = self.real_iK.index_select(dim, indices), self.real_iV.index_select(dim, indices) # Aggregation from: Exploiting Deep Representations for Neural Machine Translation class ResidueCombiner(nn.Module): diff --git a/predict.py b/predict.py index 9231e54..6de4914 100644 --- a/predict.py +++ b/predict.py @@ -57,7 +57,7 @@ def load_fixing(module): # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) @@ -72,9 +72,10 @@ def load_fixing(module): with open(sys.argv[1], "wb") as f: with torch.no_grad(): for i in tqdm(range(ntest)): - seq_batch = torch.from_numpy(src_grp[str(i)][:]).long() - if use_cuda: + seq_batch = torch.from_numpy(src_grp[str(i)][:]) + if cuda_device: seq_batch = 
seq_batch.to(cuda_device) + seq_batch = seq_batch.long() with autocast(enabled=use_amp): output = mymodel.decode(seq_batch, beam_size, None, length_penalty) #output = mymodel.train_decode(seq_batch, beam_size, None, length_penalty) diff --git a/rank_loss.py b/rank_loss.py index ac4374d..66b5934 100644 --- a/rank_loss.py +++ b/rank_loss.py @@ -65,7 +65,7 @@ def load_fixing(module): # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) lossf.to(cuda_device) if multi_gpu: @@ -79,11 +79,12 @@ def load_fixing(module): with torch.no_grad(): for i in tqdm(range(ntest)): _curid = str(i) - seq_batch = torch.from_numpy(src_grp[_curid][:]).long() - seq_o = torch.from_numpy(tgt_grp[_curid][:]).long() - if use_cuda: + seq_batch = torch.from_numpy(src_grp[_curid][:]) + seq_o = torch.from_numpy(tgt_grp[_curid][:]) + if cuda_device: seq_batch = seq_batch.to(cuda_device) seq_o = seq_o.to(cuda_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() lo = seq_o.size(1) - 1 ot = seq_o.narrow(1, 1, lo).contiguous() with autocast(enabled=use_amp): diff --git a/requirements.txt b/requirements.txt index fff5fc0..cbaf7f0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -tqdm>=4.55.1 +tqdm>=4.56.0 torch>=1.7.1 h5py>=2.10.0 diff --git a/tools/README.md b/tools/README.md index 8a8c6df..fdd8847 100644 --- a/tools/README.md +++ b/tools/README.md @@ -35,6 +35,10 @@ Tools to check the implementation and the data. When you use a shared vocabulary for the source and target sides, there are still some words which only appear on the source side even if joint BPE is applied. Those words take up probabilities in the label smoothing classifier, and this tool can prevent this by generating a larger and well-covered forbidden indexes list which can be concatenated to `forbidden_indexes` in `cnfg/base.py`. +### `prune_model_vocab.py` + +Prunes the source and target vocabularies of a trained model, useful for reducing the vocabulary sizes when a shared vocabulary is used during training. + ## `clean/` Tools to filter the datasets. diff --git a/tools/average_model.py b/tools/average_model.py index 20f944a..adc1727 100644 --- a/tools/average_model.py +++ b/tools/average_model.py @@ -1,7 +1,7 @@ #encoding: utf-8 ''' usage: - python tools/average_model.py $averaged_model_file.h5 $model1.h5 $ model2.h5 ... + python tools/average_model.py $averaged_model_file.h5 $model1.h5 $model2.h5 ...
''' import sys diff --git a/tools/check/dynb/report_dynb.py b/tools/check/dynb/report_dynb.py index 8cdb7a7..3ca08b9 100644 --- a/tools/check/dynb/report_dynb.py +++ b/tools/check/dynb/report_dynb.py @@ -61,12 +61,13 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok src_grp, tgt_grp = td["src"], td["tgt"] for i_d in tqdm(tl): - seq_batch = torch.from_numpy(src_grp[i_d][:]).long() - seq_o = torch.from_numpy(tgt_grp[i_d][:]).long() + seq_batch = torch.from_numpy(src_grp[i_d][:]) + seq_o = torch.from_numpy(tgt_grp[i_d][:]) lo = seq_o.size(1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() @@ -191,12 +192,13 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): with torch.no_grad(): for i in tqdm(range(nd)): bid = str(i) - seq_batch = torch.from_numpy(src_grp[bid][:]).long() - seq_o = torch.from_numpy(tgt_grp[bid][:]).long() + seq_batch = torch.from_numpy(src_grp[bid][:]) + seq_o = torch.from_numpy(tgt_grp[bid][:]) lo = seq_o.size(1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() with autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) @@ -296,7 +298,7 @@ def init_fixing(module): logger.info("Load target embedding from: " + cnfg.tgt_emb) load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) lossf.to(cuda_device) diff --git a/tools/check/tspeed.py b/tools/check/tspeed.py index dc57da0..60c4d3a 100644 --- a/tools/check/tspeed.py +++ b/tools/check/tspeed.py @@ -71,7 +71,7 @@ def load_fixing(module): multi_gpu = False cuda_devices = None -if use_cuda: +if cuda_device: mymodel.to(cuda_device) if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) @@ -83,9 +83,10 @@ def load_fixing(module): src_grp = td["src"] with torch.no_grad(): for i in tqdm(range(ntest)): - seq_batch = torch.from_numpy(src_grp[str(i)][:]).long() - if use_cuda: + seq_batch = torch.from_numpy(src_grp[str(i)][:]) + if cuda_device: seq_batch = seq_batch.to(cuda_device) + seq_batch = seq_batch.long() output = mymodel.decode(seq_batch, beam_size, None, length_penalty) td.close() diff --git a/tools/prune_model_vocab.py b/tools/prune_model_vocab.py new file mode 100644 index 0000000..117105e --- /dev/null +++ b/tools/prune_model_vocab.py @@ -0,0 +1,40 @@ +#encoding: utf-8 + +''' this file aims at pruning source/target vocabulary of the trained model using a shared vocabulary. It depends on the model implementation, and has to be executed at the root path of the project. 
Usage: + python prune_model_vocab.py path/to/common.vcb path/to/src.vcb path/to/tgt.vcb path/to/model.h5 path/to/pruned_model.h5 +''' + +import sys + +import torch +from utils.base import load_model_cpu, save_model +from utils.fmt.base import ldvocab, reverse_dict +from transformer.NMT import NMT + +import cnfg.base as cnfg +from cnfg.ihyp import * + +def handle(common, src, tgt, srcm, rsm, minfreq=False, vsize=False): + + vcbc, nwordf = ldvocab(common, minf=minfreq, omit_vsize=vsize, vanilla=False) + + if src == common: + src_indices = None + else: + vcbw, nword = ldvocab(src, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbw = reverse_dict(vcbw) + src_indices = torch.tensor([vcbc.get(vcbw[i], 0) for i in range(nword)], dtype=torch.long) + if tgt == common: + tgt_indices = None + else: + vcbw, nword = ldvocab(tgt, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbw = reverse_dict(vcbw) + tgt_indices = torch.tensor([vcbc.get(vcbw[i], 0) for i in range(nword)], dtype=torch.long) + + mymodel = NMT(cnfg.isize, nwordf, nwordf, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = load_model_cpu(srcm, mymodel) + mymodel.update_vocab(src_indices=src_indices, tgt_indices=tgt_indices) + save_model(mymodel, rsm, sub_module=False, logger=None, h5args=h5zipargs) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5]) diff --git a/train.py b/train.py index 7e873c8..0fb635f 100644 --- a/train.py +++ b/train.py @@ -41,12 +41,13 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok cur_b, _ls = 1, {} if save_loss else None src_grp, tgt_grp = td["src"], td["tgt"] for i_d in tqdm(tl): - seq_batch = torch.from_numpy(src_grp[i_d][:]).long() - seq_o = torch.from_numpy(tgt_grp[i_d][:]).long() + seq_batch = torch.from_numpy(src_grp[i_d][:]) + seq_o = torch.from_numpy(tgt_grp[i_d][:]) lo = seq_o.size(1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() @@ -144,12 +145,13 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): with torch.no_grad(): for i in tqdm(range(nd)): bid = str(i) - seq_batch = torch.from_numpy(src_grp[bid][:]).long() - seq_o = torch.from_numpy(tgt_grp[bid][:]).long() + seq_batch = torch.from_numpy(src_grp[bid][:]) + seq_o = torch.from_numpy(tgt_grp[bid][:]) lo = seq_o.size(1) - 1 if mv_device: seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) + seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() with autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) @@ -254,7 +256,7 @@ def init_fixing(module): logger.info("Load target embedding from: " + cnfg.tgt_emb) load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) -if use_cuda: +if cuda_device: mymodel.to(cuda_device) lossf.to(cuda_device) diff --git a/transformer/AGG/HierDecoder.py b/transformer/AGG/HierDecoder.py index 6e41e61..0187426 100644 --- a/transformer/AGG/HierDecoder.py +++ b/transformer/AGG/HierDecoder.py @@ -27,7 +27,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.comb_input = comb_input - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, 
concat_query=False): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): outs = [] if query_unit is None: @@ -44,12 +44,10 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un outs.append(out) states_return = [] for _tmp, net in enumerate(self.nets): - out, _state = net(inpute, None if inputo is None else inputo.select(-2, _tmp), src_pad_mask, tgt_pad_mask, out, concat_query) + out, _state = net(inpute, None if inputo is None else inputo[_tmp], src_pad_mask, tgt_pad_mask, out) outs.append(out) states_return.append(_state) - states_return = torch.stack(states_return, -2) - out = self.combiner(*outs) if query_unit is None: diff --git a/transformer/AGG/InceptDecoder.py b/transformer/AGG/InceptDecoder.py index e3a7fc1..429c23b 100644 --- a/transformer/AGG/InceptDecoder.py +++ b/transformer/AGG/InceptDecoder.py @@ -23,7 +23,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.combiner = ResidueCombiner(isize, num_sub, _fhsize) - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, concat_query=False): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): outs = [] if query_unit is None: @@ -36,12 +36,10 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un out = query_unit states_return = [] for _tmp, net in enumerate(self.nets): - out, _state = net(inpute, None if inputo is None else inputo.select(-2, _tmp), src_pad_mask, tgt_pad_mask, out, concat_query) + out, _state = net(inpute, None if inputo is None else inputo[_tmp], src_pad_mask, tgt_pad_mask, out) outs.append(out) states_return.append(_state) - states_return = torch.stack(states_return, -2) - out = self.combiner(*outs) if query_unit is None: diff --git a/transformer/APE/Decoder.py b/transformer/APE/Decoder.py index bebc9a8..b156944 100644 --- a/transformer/APE/Decoder.py +++ b/transformer/APE/Decoder.py @@ -7,7 +7,7 @@ from transformer.Decoder import DecoderLayer as DecoderLayerBase from transformer.Decoder import Decoder as DecoderBase -from utils.base import all_done, repeat_bsize_for_beam_tensor, mask_tensor_type +from utils.base import all_done, index_tensors, expand_bsize_for_beam, mask_tensor_type from math import sqrt from utils.fmt.base import pad_id @@ -40,11 +40,7 @@ def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None, t else: _query_unit = self.layer_normer1(query_unit) - _inputo = _query_unit if inputo is None else torch.cat((inputo, _query_unit,), 1) - - states_return = _inputo - - context = self.self_attn(_query_unit, iK=_inputo) + context, states_return = self.self_attn(_query_unit, states=inputo) if self.drop is not None: context = self.drop(context) @@ -219,14 +215,13 @@ def beam_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, beam_ done_trans = wds.view(bsize, beam_size).eq(2) - inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) - inputm = inputm.repeat(1, beam_size, 1).view(real_bsize, mtl, isize) + #inputm = inputm.repeat(1, beam_size, 1).view(real_bsize, mtl, isize) + self.repeat_cross_attn_buffer(beam_size) _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) _mt_pad_mask = None if mt_pad_mask is None else mt_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, mtl) - for key, value in states.items(): - states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + states = 
expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -278,8 +273,7 @@ def beam_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, beam_ if _done or all_done(done_trans, real_bsize): break - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) diff --git a/transformer/APE/Encoder.py b/transformer/APE/Encoder.py index 129713c..e816949 100644 --- a/transformer/APE/Encoder.py +++ b/transformer/APE/Encoder.py @@ -6,7 +6,7 @@ from transformer.Encoder import Encoder as EncoderBase from transformer.Decoder import DecoderLayer as MSEncoderLayerBase -from utils.fmt.base import parse_double_value_tuple +from utils.fmt.base import pad_id, parse_double_value_tuple from math import sqrt @@ -48,7 +48,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - self.wemb = nn.Embedding(nwd, isize, padding_idx=0) + self.wemb = nn.Embedding(nwd, isize, padding_idx=pad_id) if emb_w is not None: self.wemb.weight = emb_w @@ -101,3 +101,17 @@ def forward(self, inpute, inputo, src_mask=None, tgt_mask=None): enc_src = self.src_enc(inpute, src_mask) return enc_src, self.tgt_enc(enc_src, inputo, src_mask, tgt_mask) + + def update_vocab(self, indices): + + _bind_emb = self.src_enc.wemb.weight.is_set_to(self.tgt_enc.wemb.weight) + _swemb = nn.Embedding(len(indices), self.src_enc.wemb.weight.size(-1), padding_idx=pad_id) + _twemb = nn.Embedding(len(indices), self.tgt_enc.wemb.weight.size(-1), padding_idx=pad_id) + with torch.no_grad(): + _swemb.weight.copy_(self.src_enc.wemb.weight.index_select(0, indices)) + if _bind_emb: + _twemb.weight = _swemb.weight + else: + with torch.no_grad(): + _twemb.weight.copy_(self.tgt_enc.wemb.weight.index_select(0, indices)) + self.src_enc.wemb, self.tgt_enc.wemb = _swemb, _twemb diff --git a/transformer/APE/NMT.py b/transformer/APE/NMT.py index 31797d5..e0e7e37 100644 --- a/transformer/APE/NMT.py +++ b/transformer/APE/NMT.py @@ -1,23 +1,22 @@ #encoding: utf-8 -from torch import nn - from utils.relpos import share_rel_pos_cache from utils.fmt.base import parse_double_value_tuple from transformer.APE.Encoder import Encoder from transformer.APE.Decoder import Decoder +from transformer.NMT import NMT as NMTBase from cnfg.ihyp import * -class NMT(nn.Module): +class NMT(NMTBase): def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): - super(NMT, self).__init__() - enc_layer, dec_layer = parse_double_value_tuple(num_layer) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + self.enc = Encoder(isize, (snwd, tnwd,), enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, global_emb=global_emb) emb_w = self.enc.tgt_enc.wemb.weight if global_emb else None diff --git a/transformer/AvgDecoder.py b/transformer/AvgDecoder.py index bded082..b530123 100644 --- a/transformer/AvgDecoder.py +++ b/transformer/AvgDecoder.py 
@@ -4,7 +4,7 @@ from torch import nn from modules.base import * from utils.sampler import SampleMax -from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.base import all_done, index_tensors, expand_bsize_for_beam from utils.aan import share_aan_cache from math import sqrt @@ -99,14 +99,20 @@ class Decoder(DecoderBase): # ahsize: number of hidden units for MultiHeadAttention # bindemb: bind embedding and classifier weight - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): _ahsize = isize if ahsize is None else ahsize _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, _fhsize, dropout, attn_drop, emb_w, num_head, xseql, _ahsize, norm_output, bindemb, forbidden_index) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + if share_layer: + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + + self.mask = None share_aan_cache(self) @@ -286,7 +292,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) - inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) + self.repeat_cross_attn_buffer(beam_size) # _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql) @@ -294,8 +300,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i]: (bsize, 1, isize) => (bsize * beam_size, 1, isize) - for key, value in states.items(): - states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(2, max_len + 1): @@ -381,8 +386,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i]: (bsize * beam_size, nquery, isize) # _inds: (bsize, beam_size) => (bsize * beam_size) - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): diff --git a/transformer/Decoder.py b/transformer/Decoder.py index 978ba44..e7b953e 100644 --- a/transformer/Decoder.py +++ b/transformer/Decoder.py @@ -4,7 +4,7 @@ from torch import nn from modules.base import * from utils.sampler import SampleMax -from utils.base import all_done, repeat_bsize_for_beam_tensor, mask_tensor_type +from utils.base import all_done, index_tensors, expand_bsize_for_beam, mask_tensor_type from math import sqrt from utils.fmt.base import pad_id @@ -28,7 +28,7 @@ def 
__init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a _ahsize = isize if ahsize is None else ahsize _fhsize = _ahsize * 4 if fhsize is None else fhsize - self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos) + self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos, uni_direction_reduction=True) self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) @@ -62,11 +62,7 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un else: _query_unit = self.layer_normer1(query_unit) - _inputo = _query_unit if inputo is None else torch.cat((inputo, _query_unit,), 1) - - states_return = _inputo - - context = self.self_attn(_query_unit, iK=_inputo) + context, states_return = self.self_attn(_query_unit, states=inputo) if self.drop is not None: context = self.drop(context) @@ -115,7 +111,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. self.xseql = xseql self.register_buffer('mask', torch.ones(xseql, xseql, dtype=mask_tensor_type).triu(1).unsqueeze(0)) - self.wemb = nn.Embedding(nwd, isize, padding_idx=0) + self.wemb = nn.Embedding(nwd, isize, padding_idx=pad_id) if emb_w is not None: self.wemb.weight = emb_w @@ -193,6 +189,14 @@ def _get_subsequent_mask(self, length): return self.mask.narrow(1, 0, length).narrow(2, 0, length) if length <= self.xseql else self.mask.new_ones(length, length).triu(1).unsqueeze(0) + # this function repeats buffers of all cross-attention keys/values, corresponding inputs do not need to be repeated in beam search. + + def repeat_cross_attn_buffer(self, beam_size): + + for _m in self.modules(): + if isinstance(_m, (CrossAttn, MultiHeadAttn,)): + _m.repeat_buffer(beam_size) + # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, seql), see Encoder, get by: # src_pad_mask = input.eq(0).unsqueeze(1) @@ -229,7 +233,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, states = {} for _tmp, net in enumerate(self.nets): - out, _state = net(inpute, None, src_pad_mask, None, out) + out, _state = net(inpute, (None, None,), src_pad_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -307,7 +311,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt states = {} for _tmp, net in enumerate(self.nets): - out, _state = net(inpute, None, src_pad_mask, None, out) + out, _state = net(inpute, (None, None,), src_pad_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -331,9 +335,10 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt done_trans = wds.view(bsize, beam_size).eq(2) - # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) + # instead of update inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) with the following line, we only update cross-attention buffers. 
+ #inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) - inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) + self.repeat_cross_attn_buffer(beam_size) # _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql) @@ -341,8 +346,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i]: (bsize, 1, isize) => (bsize * beam_size, 1, isize) - for key, value in states.items(): - states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -428,8 +432,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i]: (bsize * beam_size, nquery, isize) # _inds: (bsize, beam_size) => (bsize * beam_size) - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): @@ -457,7 +460,7 @@ def fix_init(self): self.fix_load() with torch.no_grad(): - self.wemb.weight[pad_id].zero_() + #self.wemb.weight[pad_id].zero_() self.classifier.weight[pad_id].zero_() def fix_load(self): @@ -475,6 +478,28 @@ def unbind_classifier_weight(self): _new_w.data.copy_(_tmp.data) self.classifier.weight = _new_w + # this function will untie the decoder embedding from the encoder + + def update_vocab(self, indices): + + _nwd = len(indices) + _wemb = nn.Embedding(_nwd, self.wemb.weight.size(-1), padding_idx=pad_id) + _classifier = Linear(self.classifier.weight.size(-1), _nwd) + with torch.no_grad(): + _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) + if self.classifier.weight.is_set_to(self.wemb.weight): + _classifier.weight = _wemb.weight + else: + _classifier.weight.copy_(self.classifier.weight.index_select(0, indices)) + _classifier.bias.copy_(self.classifier.bias.index_select(0, indices)) + self.wemb, self.classifier = _wemb, _classifier + + def index_cross_attn_buffer(self, indices, dim=0): + + for _m in self.modules(): + if isinstance(_m, (CrossAttn, MultiHeadAttn,)): + _m.index_buffer(indices, dim=dim) + # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, seql), see Encoder, get by: # src_pad_mask = input.eq(0).unsqueeze(1) @@ -510,7 +535,7 @@ def greedy_decode_clip(self, inpute, src_pad_mask=None, max_len=512, return_mat= states = {} for _tmp, net in enumerate(self.nets): - out, _state = net(inpute, None, src_pad_mask, None, out) + out, _state = net(inpute, (None, None,), src_pad_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -569,11 +594,11 @@ def greedy_decode_clip(self, inpute, src_pad_mask=None, max_len=512, return_mat= _ndid = (~done_trans).nonzero().squeeze(1) bsize = _ndid.size(0) wds = wds.index_select(0, _ndid) - inpute = inpute.index_select(0, _ndid) + #inpute = inpute.index_select(0, _ndid) + self.index_cross_attn_buffer(_ndid) if src_pad_mask is not None: src_pad_mask = src_pad_mask.index_select(0, _ndid) - for k, value in states.items(): - states[k] = value.index_select(0, _ndid) + states = index_tensors(states, indices=_ndid, dim=0) trans = list(_trans.index_select(0, _ndid).unbind(1)) # update mapper @@ -615,7 +640,7 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, states = {} for _tmp, net in enumerate(self.nets): - out, _state = net(inpute, None, 
src_pad_mask, None, out) + out, _state = net(inpute, (None, None,), src_pad_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -640,8 +665,9 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, done_trans = wds.view(bsize, beam_size).eq(2) # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) + #inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) - inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) + self.repeat_cross_attn_buffer(beam_size) # _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql) @@ -649,8 +675,7 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, # states[i]: (bsize, 1, isize) => (bsize * beam_size, 1, isize) - for key, value in states.items(): - states[key] = value.repeat(1, beam_size, 1).view(real_bsize, 1, isize) + states = expand_bsize_for_beam(states, beam_size=beam_size) mapper = list(range(bsize)) rs = [None for i in range(bsize)] @@ -758,8 +783,7 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, # states[i]: (bsize * beam_size, nquery, isize) # _inds: (bsize, beam_size) => (bsize * beam_size) - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) if _ndone > 0: _dind = _done_trans_u.nonzero().squeeze(1) @@ -790,11 +814,14 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, _real_bsize = _bsize * beam_size wds = wds.view(bsize, beam_size).index_select(0, _ndid).view(_real_bsize, 1) - inpute = inpute.view(bsize, beam_size, seql, isize).index_select(0, _ndid).view(_real_bsize, seql, isize) + #inpute = inpute.view(bsize, beam_size, seql, isize).index_select(0, _ndid).view(_real_bsize, seql, isize) + for _m in self.modules(): + if isinstance(_m, (CrossAttn, MultiHeadAttn,)) and _m.real_iK is not None: + _m.real_iK, _m.real_iV = tuple(_vu.view(bsize, beam_size, *list(_vu.size()[1:])).index_select(0, _ndid).view(_real_bsize, *list(_vu.size()[1:])) for _vu in (_m.real_iK, _m.real_iV,)) if _src_pad_mask is not None: _src_pad_mask = _src_pad_mask.view(bsize, beam_size, 1, seql).index_select(0, _ndid).view(_real_bsize, 1, seql) for k, value in states.items(): - states[k] = value.view(bsize, beam_size, -1, isize).index_select(0, _ndid).view(_real_bsize, -1, isize) + states[k] = [_vu.view(bsize, beam_size, *list(_vu.size()[1:])).index_select(0, _ndid).view(_real_bsize, *list(_vu.size()[1:])) for _vu in value] sum_scores = sum_scores.index_select(0, _ndid) trans = _trans.index_select(0, _ndid).view(_real_bsize, -1) if length_penalty > 0.0: diff --git a/transformer/Doc/Para/Base/Decoder.py b/transformer/Doc/Para/Base/Decoder.py index 694cfb5..bf41cc3 100644 --- a/transformer/Doc/Para/Base/Decoder.py +++ b/transformer/Doc/Para/Base/Decoder.py @@ -5,7 +5,7 @@ from modules.base import * from utils.sampler import SampleMax from modules.paradoc import GateResidual -from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.base import all_done, index_tensors, expand_bsize_for_beam from math import sqrt from utils.fmt.base import pad_id @@ -42,11 +42,7 @@ def forward(self, inpute, inputo, inputc, src_pad_mask=None, tgt_pad_mask=None, else: _query_unit = self.layer_normer1(query_unit) - _inputo = _query_unit if inputo is None else torch.cat((inputo, _query_unit,), 1) - - states_return = _inputo - - context = self.self_attn(_query_unit, iK=_inputo) + context,
states_return = self.self_attn(_query_unit, states=inputo) if self.drop is not None: context = self.drop(context) @@ -162,7 +158,7 @@ def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, ma states = {} for _tmp, net in enumerate(self.nets): - out, _state = net(inpute, None, inputc, src_pad_mask, context_mask, None, out) + out, _state = net(inpute, (None, None,), inputc, src_pad_mask, context_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -228,7 +224,7 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam states = {} for _tmp, net in enumerate(self.nets): - out, _state = net(inpute, None, inputc, src_pad_mask, context_mask, None, out) + out, _state = net(inpute, (None, None,), inputc, src_pad_mask, context_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -244,7 +240,7 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam done_trans = wds.view(bsize, beam_size).eq(2) - inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) + self.repeat_cross_attn_buffer(beam_size) _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) _cbsize, _cseql = inputc[0].size()[:2] @@ -253,8 +249,7 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam _inputc = [inputu.repeat(1, beam_size, 1).view(_creal_bsize, _cseql, isize) for inputu in inputc] - for key, value in states.items(): - states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -306,8 +301,7 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam if _done or all_done(done_trans, real_bsize): break - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) diff --git a/transformer/Doc/Para/Base/Encoder.py b/transformer/Doc/Para/Base/Encoder.py index 27e6ff8..68c8b2f 100644 --- a/transformer/Doc/Para/Base/Encoder.py +++ b/transformer/Doc/Para/Base/Encoder.py @@ -150,3 +150,7 @@ def get_pad(self, seql): def get_padmask(self, seql): return self.pad_mask.narrow(-1, 0, seql) if seql <= self.xseql else torch.cat((self.pad_mask, self.pad_mask.new_ones(1, 1, self.nprev_context - 1, seql - self.xseql),), dim=-1) + + def update_vocab(self, indices): + + self.context_enc.update_vocab(indices) diff --git a/transformer/Encoder.py b/transformer/Encoder.py index ffc1162..c1e415f 100644 --- a/transformer/Encoder.py +++ b/transformer/Encoder.py @@ -82,7 +82,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - self.wemb = nn.Embedding(nwd, isize, padding_idx=0) + self.wemb = nn.Embedding(nwd, isize, padding_idx=pad_id) self.pemb = None if disable_pemb else PositionalEmb(isize, xseql, 0, 0) if share_layer: @@ -126,9 +126,16 @@ def load_base(self, base_encoder): self.out_normer = None if self.out_normer is None else base_encoder.out_normer + def update_vocab(self, indices): + + _wemb = nn.Embedding(len(indices), self.wemb.weight.size(-1), padding_idx=pad_id) + with torch.no_grad(): + _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) + self.wemb = _wemb + def fix_init(self): if hasattr(self, "fix_load"): self.fix_load() - with torch.no_grad(): - self.wemb.weight[pad_id].zero_() + #with torch.no_grad(): + # self.wemb.weight[pad_id].zero_() diff --git a/transformer/EnsembleAvgDecoder.py b/transformer/EnsembleAvgDecoder.py index 0a6e76a..b3e7494 100644 --- a/transformer/EnsembleAvgDecoder.py +++ b/transformer/EnsembleAvgDecoder.py @@ -2,7 +2,7 @@ import torch from utils.sampler import SampleMax -from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.base import all_done, index_tensors, expand_bsize_for_beam from math import sqrt from utils.fmt.base import pad_id @@ -202,9 +202,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i][j]: (bsize, 1, isize) => (bsize * beam_size, 1, isize) - for key, value in states.items(): - for _key, _value in value.items(): - value[_key] = repeat_bsize_for_beam_tensor(_value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(2, max_len + 1): @@ -296,9 +294,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i][j]: (bsize * beam_size, nquery, isize) # _inds: (bsize, beam_size) => (bsize * beam_size) - for key, value in states.items(): - for _key, _value in value.items(): - value[_key] = _value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): diff --git a/transformer/EnsembleDecoder.py b/transformer/EnsembleDecoder.py index 2e5a22a..10fd9c7 100644 --- a/transformer/EnsembleDecoder.py +++ b/transformer/EnsembleDecoder.py @@ -3,7 +3,7 @@ import torch from torch import nn from utils.sampler import SampleMax -from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.base import all_done, index_tensors, expand_bsize_for_beam from math import sqrt from utils.fmt.base import pad_id @@ -221,9 +221,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i][j]: (bsize, 1, isize) => (bsize * beam_size, 1, isize) - for key, value in states.items(): - for _key, _value in value.items(): - value[_key] = repeat_bsize_for_beam_tensor(_value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -315,9 +313,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i][j]: (bsize * beam_size, nquery, isize) # _inds: (bsize, beam_size) => (bsize * beam_size) - for key, value in states.items(): - for _key, _value in value.items(): - value[_key] = _value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): diff --git 
a/transformer/LD/Decoder.py b/transformer/LD/Decoder.py index d29d3a5..e54db20 100644 --- a/transformer/LD/Decoder.py +++ b/transformer/LD/Decoder.py @@ -7,7 +7,7 @@ from utils.sampler import SampleMax from modules.TA import PositionwiseFF -from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.base import all_done, index_tensors, expand_bsize_for_beam, repeat_bsize_for_beam_tensor from math import sqrt from utils.fmt.base import pad_id @@ -32,7 +32,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.ff = PositionwiseFF(isize, _fhsize, dropout) self.scff = ResidueCombiner(isize, 2, _fhsize, dropout) - def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, tgt_pad_mask=None, query_unit=None, concat_query=False): + def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, tgt_pad_mask=None, query_unit=None): if query_unit is None: @@ -45,13 +45,7 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, else: - if concat_query: - - inputo = query_unit if inputo is None else torch.cat((inputo, query_unit,), 1) - - states_return = inputo - - context = self.self_attn(query_unit, iK=inputo) + context, states_return = self.self_attn(query_unit, states=inputo) if self.drop is not None: context = self.drop(context) @@ -209,14 +203,13 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam done_trans = wds.view(bsize, beam_size).eq(2) - inpute = inpute.repeat(1, beam_size, 1, 1).view(real_bsize, seql, isize, -1) - inputh = repeat_bsize_for_beam_tensor(inputh, beam_size) + #inputh = repeat_bsize_for_beam_tensor(inputh, beam_size) + self.repeat_cross_attn_buffer(beam_size) _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) _chk_pad_mask = None if chk_pad_mask is None else repeat_bsize_for_beam_tensor(chk_pad_mask, beam_size) - for key, value in states.items(): - states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -267,8 +260,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam if _done or all_done(done_trans, real_bsize): break - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) diff --git a/transformer/LD/NMT.py b/transformer/LD/NMT.py index 38f3817..a42417b 100644 --- a/transformer/LD/NMT.py +++ b/transformer/LD/NMT.py @@ -1,23 +1,22 @@ #encoding: utf-8 -from torch import nn - from utils.relpos import share_rel_pos_cache from utils.fmt.base import parse_double_value_tuple from transformer.LD.Encoder import Encoder from transformer.LD.Decoder import Decoder +from transformer.NMT import NMT as NMTBase from cnfg.ihyp import * -class NMT(nn.Module): +class NMT(NMTBase): def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): - super(NMT, self).__init__() - enc_layer, dec_layer = parse_double_value_tuple(num_layer) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, 
norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, dec_layer) emb_w = self.enc.wemb.weight if global_emb else None diff --git a/transformer/NMT.py b/transformer/NMT.py index 7d8255e..f8e3af0 100644 --- a/transformer/NMT.py +++ b/transformer/NMT.py @@ -66,6 +66,13 @@ def load_base(self, base_nmt): else: self.dec = base_nmt.dec + def update_vocab(self, src_indices=None, tgt_indices=None): + + if (src_indices is not None) and hasattr(self.enc, "update_vocab"): + self.enc.update_vocab(src_indices) + if (tgt_indices is not None) and hasattr(self.dec, "update_vocab"): + self.dec.update_vocab(tgt_indices) + # inpute: source sentences from encoder (bsize, seql) # beam_size: the beam size for beam search # max_len: maximum length to generate diff --git a/transformer/RNMTDecoder.py b/transformer/RNMTDecoder.py index f459caf..fa58e43 100644 --- a/transformer/RNMTDecoder.py +++ b/transformer/RNMTDecoder.py @@ -12,6 +12,8 @@ from utils.fmt.base import pad_id +from transformer.Decoder import Decoder as DecoderBase + from cnfg.ihyp import * class FirstLayer(nn.Module): @@ -103,7 +105,7 @@ def forward(self, inputo, attn, state=None, first_step=False): return out + inputo if self.residual else out, (hx, cx) -class Decoder(nn.Module): +class Decoder(DecoderBase): # isize: size of word embedding # nwd: number of words @@ -115,19 +117,11 @@ class Decoder(nn.Module): # ahsize: number of hidden units for MultiHeadAttention # bindemb: bind embedding and classifier weight - def __init__(self, isize, nwd, num_layer, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, projector=True): - - super(Decoder, self).__init__() + def __init__(self, isize, nwd, num_layer, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, projector=True, **kwargs): _ahsize = isize if ahsize is None else ahsize - self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - - self.xseql = xseql - - self.wemb = nn.Embedding(nwd, isize, padding_idx=0) - if emb_w is not None: - self.wemb.weight = emb_w + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=isize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) self.flayer = FirstLayer(isize, osize=isize, dropout=dropout) @@ -142,11 +136,7 @@ def __init__(self, isize, nwd, num_layer, dropout=0.0, attn_drop=0.0, emb_w=None #if bindemb: #list(self.classifier.modules())[-1].weight = self.wemb.weight - self.lsm = nn.LogSoftmax(-1) - - self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None - - self.fbl = None if forbidden_index is None else tuple(set(forbidden_index)) + self.mask = None # inpute: encoded representation from encoder (bsize, seql, isize) # inputo: decoded translation (bsize, nquery) @@ -181,16 +171,6 @@ def forward(self, inpute, inputo, src_pad_mask=None): return out - # inpute: encoded representation from encoder (bsize, seql, isize) - # src_pad_mask: mask for given encoding source sentence (bsize, seql), see Encoder, get by: - # src_pad_mask = input.eq(0).unsqueeze(1) - # beam_size: the beam size for beam search - # max_len: maximum 
length to generate - - def decode(self, inpute, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): - - return self.beam_decode(inpute, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, src_pad_mask, max_len, fill_pad=fill_pad) - # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: # src_pad_mask = input.eq(0).unsqueeze(1) @@ -324,7 +304,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) - inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) + self.repeat_cross_attn_buffer(beam_size) # _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql) @@ -333,8 +313,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i]: (bsize, 2, isize) => (bsize * beam_size, 2, isize) statefl = statefl.repeat(1, beam_size, 1).view(real_bsize, 2, isize) - for key, value in states.items(): - states[key] = value.repeat(1, beam_size, 1).view(real_bsize, 2, isize) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -424,8 +403,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # _inds: (bsize, beam_size) => (bsize * beam_size) statefl = statefl.index_select(0, _inds) - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): @@ -441,24 +419,8 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt return trans.view(bsize, beam_size, -1).select(1, 0) - # inpute: encoded representation from encoder (bsize, seql, isize) - - def get_sos_emb(self, inpute, bsize=None): - - bsize = inpute.size(0) if bsize is None else bsize - - return self.wemb.weight[1].view(1, -1).expand(bsize, -1) - - def fix_init(self): - - self.fix_load() - with torch.no_grad(): - self.wemb.weight[pad_id].zero_() - self.classifier.weight[pad_id].zero_() - - def fix_load(self): + '''def fix_load(self): if self.fbl is not None: with torch.no_grad(): - #list(self.classifier.modules())[-1].bias.index_fill_(0, torch.tensor(self.fbl, dtype=torch.long, device=self.classifier.bias.device), -inf_default) - self.classifier.bias.index_fill_(0, torch.tensor(self.fbl, dtype=torch.long, device=self.classifier.bias.device), -inf_default) + list(self.classifier.modules())[-1].bias.index_fill_(0, torch.tensor(self.fbl, dtype=torch.long, device=self.classifier.bias.device), -inf_default)''' diff --git a/transformer/RealFormer/Decoder.py b/transformer/RealFormer/Decoder.py index d9cd5e4..f3e12a5 100644 --- a/transformer/RealFormer/Decoder.py +++ b/transformer/RealFormer/Decoder.py @@ -9,7 +9,7 @@ from transformer.Decoder import Decoder as DecoderBase from utils.sampler import SampleMax -from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.base import all_done, index_tensors, expand_bsize_for_beam from math import sqrt from utils.fmt.base import pad_id @@ -25,7 +25,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, 
ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos) - self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos) + self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos, uni_direction_reduction=True) self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, resin=None): @@ -48,11 +48,7 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un else: _query_unit = self.layer_normer1(query_unit) - _inputo = _query_unit if inputo is None else torch.cat((inputo, _query_unit,), 1) - - states_return = _inputo - - context, sresout = self.self_attn(_query_unit, iK=_inputo, resin=sresin) + context, states_return, sresout = self.self_attn(_query_unit, states=inputo, resin=sresin) if self.drop is not None: context = self.drop(context) @@ -132,7 +128,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, states = {} attnres = None for _tmp, net in enumerate(self.nets): - out, _state, attnres = net(inpute, None, src_pad_mask, None, out, resin=attnres) + out, _state, attnres = net(inpute, (None, None,), src_pad_mask, None, out, resin=attnres) states[_tmp] = _state if self.out_normer is not None: @@ -199,7 +195,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt states = {} attnres = None for _tmp, net in enumerate(self.nets): - out, _state, attnres = net(inpute, None, src_pad_mask, None, out, resin=attnres) + out, _state, attnres = net(inpute, (None, None,), src_pad_mask, None, out, resin=attnres) states[_tmp] = _state if self.out_normer is not None: @@ -215,12 +211,11 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt done_trans = wds.view(bsize, beam_size).eq(2) - inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) + self.repeat_cross_attn_buffer(beam_size) _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) - for key, value in states.items(): - states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -273,8 +268,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt if _done or all_done(done_trans, real_bsize): break - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) diff --git a/transformer/SC/Decoder.py b/transformer/SC/Decoder.py index 89a0e0d..06d113d 100644 --- a/transformer/SC/Decoder.py +++ b/transformer/SC/Decoder.py @@ -7,7 +7,7 @@ from utils.sampler import SampleMax from modules.TA import PositionwiseFF -from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.base import all_done, index_tensors, expand_bsize_for_beam, repeat_bsize_for_beam_tensor from math import sqrt from utils.fmt.base import pad_id @@ -29,7 +29,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.ff = PositionwiseFF(isize, _fhsize, dropout) self.scff = ResidueCombiner(isize, 2, _fhsize) - def forward(self, inpute, inputh, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, concat_query=False): + def forward(self, inpute, 
inputh, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): if query_unit is None: @@ -46,13 +46,7 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None, tgt_pad_mask=None, _query_unit = self.scff(query_unit, inputh) - if concat_query: - - inputo = _query_unit if inputo is None else torch.cat((inputo, _query_unit,), 1) - - states_return = inputo - - context = self.self_attn(_query_unit, iK=inputo) + context, states_return = self.self_attn(_query_unit, states=inputo) if self.drop is not None: context = self.drop(context) @@ -208,13 +202,12 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 done_trans = wds.view(bsize, beam_size).eq(2) - inpute = inpute.repeat(1, beam_size, 1, 1).view(real_bsize, seql, isize, -1) inputh = repeat_bsize_for_beam_tensor(inputh, beam_size) + self.repeat_cross_attn_buffer(beam_size) _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) - for key, value in states.items(): - states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -265,8 +258,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 if _done or all_done(done_trans, real_bsize): break - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) if (not clip_beam) and (length_penalty > 0.0): scores = scores / lpv.view(bsize, beam_size) diff --git a/transformer/SC/NMT.py b/transformer/SC/NMT.py index 949cd78..af9cbf2 100644 --- a/transformer/SC/NMT.py +++ b/transformer/SC/NMT.py @@ -1,25 +1,24 @@ #encoding: utf-8 -from torch import nn - from utils.relpos import share_rel_pos_cache from utils.fmt.base import parse_double_value_tuple from transformer.SC.Encoder import Encoder from transformer.SC.Decoder import Decoder +from transformer.NMT import NMT as NMTBase from math import sqrt from cnfg.ihyp import * -class NMT(nn.Module): +class NMT(NMTBase): def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): - super(NMT, self).__init__() - enc_layer, dec_layer = parse_double_value_tuple(num_layer) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, num_layer) emb_w = self.enc.wemb.weight if global_emb else None diff --git a/transformer/TA/Decoder.py b/transformer/TA/Decoder.py index b74327e..2217980 100644 --- a/transformer/TA/Decoder.py +++ b/transformer/TA/Decoder.py @@ -3,7 +3,7 @@ import torch from modules.base import * from utils.sampler import SampleMax -from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.base import all_done, index_tensors, expand_bsize_for_beam from math import sqrt from utils.fmt.base import pad_id @@ -190,7 +190,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # inpute: (bsize, seql, isize, num_layer_enc) => (bsize * beam_size, seql, isize, num_layer_enc) - inpute = inpute.repeat(1, beam_size, 1, 
1).view(real_bsize, seql, isize, -1) + self.repeat_cross_attn_buffer(beam_size) # _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql) @@ -198,8 +198,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i]: (bsize, 1, isize) => (bsize * beam_size, 1, isize) - for key, value in states.items(): - states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + states = expand_bsize_for_beam(states, beam_size=beam_size) for step in range(1, max_len): @@ -285,8 +284,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # states[i]: (bsize * beam_size, nquery, isize) # _inds: (bsize, beam_size) => (bsize * beam_size) - for key, value in states.items(): - states[key] = value.index_select(0, _inds) + states = index_tensors(states, indices=_inds, dim=0) # if length penalty is only applied in the last step, apply length penalty if (not clip_beam) and (length_penalty > 0.0): diff --git a/transformer/UniEncoder.py b/transformer/UniEncoder.py index 77eec64..d5ceed6 100644 --- a/transformer/UniEncoder.py +++ b/transformer/UniEncoder.py @@ -1,30 +1,17 @@ #encoding: utf-8 +import torch from torch import nn from modules.base import * -from transformer.Encoder import EncoderLayer -# vocabulary: -# :0 -# :1 -# :2 -# :3 -# ... -# for the classier of the decoder, is omitted +from utils.fmt.base import pad_id + +from transformer.Encoder import EncoderLayer from cnfg.ihyp import * class Encoder(nn.Module): - # isize: size of word embedding - # nwd: number of words - # num_layer: number of encoder layers - # fhsize: number of hidden units for PositionwiseFeedForward - # attn_drop: dropout for MultiHeadAttention - # num_head: number of heads in MultiHeadAttention - # xseql: maxmimum length of sequence - # ahsize: number of hidden units for MultiHeadAttention - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True): super(Encoder, self).__init__() @@ -36,7 +23,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - self.wemb = nn.Embedding(nwd, isize, padding_idx=0) + self.wemb = nn.Embedding(nwd, isize, padding_idx=pad_id) self.pemb = CoordinateEmb(isize, xseql, num_layer, 0, 0) self.net = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) @@ -101,3 +88,10 @@ def forward(self, inputs, mask=None): return out else: return out, loss_act + + def update_vocab(self, indices): + + _wemb = nn.Embedding(len(indices), self.wemb.weight.size(-1), padding_idx=pad_id) + with torch.no_grad(): + _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) + self.wemb = _wemb diff --git a/utils/base.py b/utils/base.py index d5c2596..a57f6e8 100644 --- a/utils/base.py +++ b/utils/base.py @@ -1,6 +1,7 @@ #encoding: utf-8 import torch +from torch import Tensor from torch.nn import ModuleDict from threading import Thread @@ -267,19 +268,33 @@ def expand_bsize_for_beam(*inputs, beam_size=1): outputs = [] for inputu in inputs: - if inputu is None: - outputs.append(None) - elif isinstance(inputu, list): - outputs.append(list(expand_bsize_for_beam(*inputu, beam_size=beam_size))) + if isinstance(inputu, Tensor): + outputs.append(repeat_bsize_for_beam_tensor(inputu, beam_size)) + elif isinstance(inputu, dict): + outputs.append({k: expand_bsize_for_beam(v, beam_size=beam_size) for k, v in inputu.items()}) elif isinstance(inputu, tuple): - outputs.append(tuple(expand_bsize_for_beam(*inputu, beam_size=beam_size))) + outputs.append(tuple(expand_bsize_for_beam(tmpu, beam_size=beam_size) for tmpu in inputu)) + elif isinstance(inputu, list): + outputs.append([expand_bsize_for_beam(tmpu, beam_size=beam_size) for tmpu in inputu]) + else: + outputs.append(inputu) + + return outputs[0] if len(inputs) == 1 else tuple(outputs) + +def index_tensors(*inputs, indices=None, dim=0): + + outputs = [] + for inputu in inputs: + if isinstance(inputu, Tensor): + outputs.append(inputu.index_select(dim, indices)) elif isinstance(inputu, dict): - _tmp = {} - for _k, _v in inputu.items(): - _tmp[_k] = expand_bsize_for_beam(_v, beam_size=beam_size) - outputs.append(_tmp) + outputs.append({k: index_tensors(v, indices=indices, dim=dim) for k, v in inputu.items()}) + elif isinstance(inputu, tuple): + outputs.append(tuple(index_tensors(tmpu, indices=indices, dim=dim) for tmpu in inputu)) + elif isinstance(inputu, list): + outputs.append([index_tensors(tmpu, indices=indices, dim=dim) for tmpu in inputu]) else: - outputs.append(repeat_bsize_for_beam_tensor(inputu, beam_size)) + outputs.append(inputu) return outputs[0] if len(inputs) == 1 else tuple(outputs) diff --git a/utils/relpos.py b/utils/relpos.py index ef1be5d..bf7d94b 100644 --- a/utils/relpos.py +++ b/utils/relpos.py @@ -8,18 +8,19 @@ def share_rel_pos_cache(netin): rel_cache_d = {} for net in netin.modules(): if isinstance(net, ModuleList): - _base_net = None + base_nets = {} for layer in net.modules(): if isinstance(layer, (SelfAttn, MultiHeadAttn,)): if layer.rel_pemb is not None: - if _base_net is None: - _base_net = layer + _key = (layer.clamp_min, layer.clamp_max, layer.rel_shift,) + if _key in base_nets: + layer.ref_rel_posm = base_nets[_key] else: - layer.ref_rel_posm = _base_net - _rel_c_size = layer.rel_pos.size() - if _rel_c_size in rel_cache_d: - layer.rel_pos = rel_cache_d[_rel_c_size] + base_nets[_key] = layer + _key = (layer.clamp_min, layer.clamp_max, layer.rel_shift, layer.rel_pos.size(),) + if _key in rel_cache_d: + layer.rel_pos = rel_cache_d[_key] else: - rel_cache_d[_rel_c_size] = 
layer.rel_pos + rel_cache_d[_key] = layer.rel_pos return netin diff --git a/utils/sampler.py b/utils/sampler.py index 5ff4c11..459254e 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -14,4 +14,4 @@ def SampleMax(input, dim=-1, keepdim=False): _msize[dim] = 1 _ms.logical_xor_(torch.cat((_ms.new_zeros(_msize, dtype=_ms.dtype, device=_ms.device), _ms.narrow(dim, 0, _nkeep),), dim=dim)) - return _ms.long().argmax(dim=dim, keepdim=keepdim) + return _ms.byte().argmax(dim=dim, keepdim=keepdim)
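
Below is a minimal usage sketch (not part of the patch) of the nested-state helpers in utils/base.py, expand_bsize_for_beam and index_tensors, which repeat or re-index every tensor inside a possibly nested dict/list/tuple of decoder states; this is why the per-key loops in the decoders above collapse to single calls. The snippet assumes the repository root is on PYTHONPATH; tensor sizes are illustrative only.

import torch

from utils.base import expand_bsize_for_beam, index_tensors

bsize, beam_size, isize = 2, 3, 4

# per-layer decoding states as the decoders keep them: a dict of tensors
states = {0: torch.randn(bsize, 1, isize), 1: torch.randn(bsize, 1, isize)}

# (bsize, 1, isize) => (bsize * beam_size, 1, isize), recursing through the dict
states = expand_bsize_for_beam(states, beam_size=beam_size)

# reorder the beam dimension with the indices selected at a decoding step
_inds = torch.arange(bsize * beam_size, dtype=torch.long)
states = index_tensors(states, indices=_inds, dim=0)

print(states[0].size())	# expected: torch.Size([6, 1, 4])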