From 5ca6c8034220de022961641cd47ad6e1a6e6f2e0 Mon Sep 17 00:00:00 2001 From: ano Date: Mon, 11 Mar 2019 11:24:56 +0100 Subject: [PATCH] fix and acc decode --- TAmodules.py | 29 ++++ mkcy.py | 2 +- modules.py | 7 +- scripts/mktest.sh | 6 +- tools/check/charatio.py | 2 +- tools/restore.py | 36 +++++ tools/sort.py | 4 +- tools/sorti.py | 38 +++++ transformer/AvgDecoder.py | 19 +++ transformer/Decoder.py | 20 ++- transformer/Encoder.py | 14 ++ transformer/RNMTDecoder.py | 6 +- transformer/TADecoder.py | 293 +++++++++++++++++++++++++++++++++++++ transformer/TAEncoder.py | 111 ++++++++++++++ translator.py | 114 ++++++++++++--- 15 files changed, 668 insertions(+), 33 deletions(-) create mode 100644 TAmodules.py create mode 100644 tools/restore.py create mode 100644 tools/sorti.py create mode 100644 transformer/TADecoder.py create mode 100644 transformer/TAEncoder.py diff --git a/TAmodules.py b/TAmodules.py new file mode 100644 index 0000000..e861a88 --- /dev/null +++ b/TAmodules.py @@ -0,0 +1,29 @@ +#encoding: utf-8 + +from math import sqrt, log, exp, pi +import torch +from torch import nn +from torch.nn import functional as nnFunc +from torch.autograd import Function + +from modules import GeLU_BERT +from modules import PositionwiseFF as PositionwiseFFBase + +class PositionwiseFF(PositionwiseFFBase): + + # isize: input dimension + # hsize: hidden dimension + + def __init__(self, isize, hsize=None, dropout=0.0, use_GeLU=False): + + super(PositionwiseFF, self).__init__(isize, hsize, dropout, False, use_GeLU) + + def forward(self, x): + + out = x + for net in self.nets: + out = net(out) + + out = self.normer(out + x) + + return out diff --git a/mkcy.py b/mkcy.py index ea64d81..90979bc 100644 --- a/mkcy.py +++ b/mkcy.py @@ -17,7 +17,7 @@ def get_name(fname): eccargs = ["-Ofast", "-march=native", "-pipe", "-fomit-frame-pointer"] - baselist = ["modules.py", "loss.py", "lrsch.py", "utils.py","rnncell.py", "translator.py", "discriminator.py"] + baselist = ["TAmodules.py", "modules.py", "loss.py", "lrsch.py", "utils.py","rnncell.py", "translator.py", "discriminator.py"] extlist = [Extension(get_name(pyf), [pyf], extra_compile_args=eccargs) for pyf in baselist] diff --git a/modules.py b/modules.py index bcc7c88..71ff3ac 100644 --- a/modules.py +++ b/modules.py @@ -26,7 +26,6 @@ def __init__(self, isize, hsize=None, dropout=0.0, norm_residue=False, use_GeLU= self.norm_residue = norm_residue - def forward(self, x): _out = self.normer(x) @@ -647,7 +646,7 @@ def __init__(self, isize, bias=True): super(Scorer, self).__init__() - self.w = nn.Parameter(torch.Tensor(isize).uniform_(sqrt(6.0 / isize), sqrt(6.0 / isize))) + self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize))) self.bias = nn.Parameter(torch.zeros(1)) if bias else None def forward(self, x): @@ -667,7 +666,7 @@ def __init__(self, isize, ahsize=None, num_head=8, attn_drop=0.0): super(MHAttnSummer, self).__init__() - self.w = nn.Parameter(torch.Tensor(1, 1, isize).uniform_(sqrt(6.0 / isize), sqrt(6.0 / isize))) + self.w = nn.Parameter(torch.Tensor(1, 1, isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize))) self.attn = CrossAttn(isize, isize if ahsize is None else ahsize, isize, num_head, dropout=attn_drop) # x: (bsize, seql, isize) @@ -700,7 +699,7 @@ def __init__(self, isize, minv = 0.125): super(Temperature, self).__init__() - self.w = nn.Parameter(torch.Tensor(isize).uniform_(sqrt(6.0 / isize), sqrt(6.0 / isize))) + self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize))) 
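# (Standalone sketch, not code from modules.py: the initializer fix applied here and to
# Scorer / MHAttnSummer above. The old call uniform_(sqrt(6/isize), sqrt(6/isize)) sampled
# from a zero-width interval, so every entry got the same positive constant; the corrected
# call samples symmetrically in [-sqrt(6/isize), sqrt(6/isize)], the usual Xavier-style
# range for a single weight vector. isize = 512 is an arbitrary example value.)
from math import sqrt
import torch

isize = 512
bound = sqrt(6.0 / isize)
w_old = torch.empty(isize).uniform_(bound, bound)    # old call: zero-width range, every entry == bound
w_new = torch.empty(isize).uniform_(-bound, bound)   # fixed call: symmetric range around zero
print(w_old.std().item(), w_new.std().item())        # ~0.0 vs. roughly bound / sqrt(3)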
self.bias = nn.Parameter(torch.zeros(1)) self.act = nn.Tanh() self.k = nn.Parameter(torch.ones(1)) diff --git a/scripts/mktest.sh b/scripts/mktest.sh index 890da90..c93c8b0 100644 --- a/scripts/mktest.sh +++ b/scripts/mktest.sh @@ -13,6 +13,8 @@ export tgtd=$cachedir/$dataid export bpef=out.bpe -python tools/mktest.py $srcd/$srctf $tgtd/src.vcb $tgtd/test.h5 $ngpu -python predict.py $tgtd/$bpef $tgtd/tgt.vcb $modelf +python tools/sorti.py $srcd/$srctf $tgtd/$srctf.srt +python tools/mktest.py $tgtd/$srctf.srt $tgtd/src.vcb $tgtd/test.h5 $ngpu +python predict.py $tgtd/$bpef.srt $tgtd/tgt.vcb $modelf +python tools/restore.py $srcd/$srctf $tgtd/$srctf.srt $tgtd/$bpef.srt $tgtd/$bpef sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf diff --git a/tools/check/charatio.py b/tools/check/charatio.py index ff5083c..f8d88a2 100644 --- a/tools/check/charatio.py +++ b/tools/check/charatio.py @@ -61,7 +61,7 @@ def get_ratio(strin): mrsc, _rsc, mrsb, _rsb, mrss, _rss = getfratio(srcfs) mrtc, _rtc, mrtb, _rtb, mrts, _rts = getfratio(srcft) - print("Max/mean/adv char ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv char ratio of target data: %.3f / %.3f / %.3f / %.3f\nMax/mean/adv bpe ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv bpe ratio of target data: %.3f / %.3f / %.3f\nMax/mean/adv seperated ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv seperated ratio of target data: %.3f / %.3f / %.3f" % (mrsc, _rsc, min(mrsc, _rsc * 2.5) + 0.001, mrtc, _rtc, min(mrtc, _rtc * 2.5) + 0.001, mrsb, _rsb, min(mrsb, _rsb * 2.5) + 0.001, mrtb, _rtb, min(mrtb, _rtb * 2.5) + 0.001, mrss, _rss, min(mrss, _rss * 2.5) + 0.001, mrts, _rts, min(mrts, _rts * 2.5) + 0.001)) + print("Max/mean/adv char ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv char ratio of target data: %.3f / %.3f / %.3f\nMax/mean/adv bpe ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv bpe ratio of target data: %.3f / %.3f / %.3f\nMax/mean/adv seperated ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv seperated ratio of target data: %.3f / %.3f / %.3f" % (mrsc, _rsc, min(mrsc, _rsc * 2.5) + 0.001, mrtc, _rtc, min(mrtc, _rtc * 2.5) + 0.001, mrsb, _rsb, min(mrsb, _rsb * 2.5) + 0.001, mrtb, _rtb, min(mrtb, _rtb * 2.5) + 0.001, mrss, _rss, min(mrss, _rss * 2.5) + 0.001, mrts, _rts, min(mrts, _rts * 2.5) + 0.001)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2]) diff --git a/tools/restore.py b/tools/restore.py new file mode 100644 index 0000000..9bb99fe --- /dev/null +++ b/tools/restore.py @@ -0,0 +1,36 @@ +#encoding: utf-8 + +import sys + +def handle(srcfs, srtsf, srttf, tgtf): + + def clean(lin): + rs = [] + for lu in lin.split(): + if lu: + rs.append(lu) + return " ".join(rs), len(rs) + + data = {} + + with open(srtsf, "rb") as fs, open(srttf, "rb") as ft: + for sl, tl in zip(fs, ft): + _sl, _tl = sl.strip(), tl.strip() + if _sl and _tl: + _sl, _ls = clean(_sl.decode("utf-8")) + _tl, _lt = clean(_tl.decode("utf-8")) + data[_sl] = _tl + + ens = "\n".encode("utf-8") + + with open(srcfs, "rb") as fs, open(tgtf, "wb") as ft: + for line in fs: + tmp = line.strip() + if tmp: + tmp, _ = clean(tmp.decode("utf-8")) + tmp = data.get(tmp, "") + ft.write(tmp.encode("utf-8")) + ft.write(ens) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/sort.py b/tools/sort.py index c0b5caa..d662fb9 100644 --- a/tools/sort.py +++ b/tools/sort.py @@ -67,9 +67,7 @@ def shuffle_pair(ls, lt): if (slen <= _max_len) and (tlen <= _max_len): lgth = slen + tlen if lgth not in data: - 
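# (Standalone sketch of the two-level length bucketing this sort.py hunk builds; the
# one-line dict literal replaces the temporary dict. Sentence pairs are grouped first by
# total length slen + tlen, then by target length, so batches come out in near-uniform
# length order. The pairs below are toy examples, not data from the patch.)
data = {}
for ls, lt in [("a b c", "x y"), ("d e", "z w"), ("f g h", "u v")]:
    slen, tlen = len(ls.split()), len(lt.split())
    lgth = slen + tlen
    if lgth not in data:
        data[lgth] = {tlen: [(ls, lt)]}
    elif tlen in data[lgth]:
        data[lgth][tlen].append((ls, lt))
    else:
        data[lgth][tlen] = [(ls, lt)]
# -> {5: {2: [("a b c", "x y"), ("f g h", "u v")]}, 4: {2: [("d e", "z w")]}}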
tmp = {} - tmp[tlen] = [(ls, lt)] - data[lgth] = tmp + data[lgth] = {tlen: [(ls, lt)]} else: if tlen in data[lgth]: data[lgth][tlen].append((ls, lt)) diff --git a/tools/sorti.py b/tools/sorti.py new file mode 100644 index 0000000..83d8787 --- /dev/null +++ b/tools/sorti.py @@ -0,0 +1,38 @@ +#encoding: utf-8 + +import sys + +def handle(srcfs, tgtfs): + + def clean(lin): + rs = [] + for lu in lin: + if lu: + rs.append(lu) + return " ".join(rs), len(rs) + + data = {} + + with open(srcfs, "rb") as fs: + for ls in fs: + ls = ls.strip() + if ls: + ls, lgth = clean(ls.decode("utf-8").split()) + if lgth not in data: + data[lgth] = set([ls]) + else: + if ls not in data[lgth]: + data[lgth].add(ls) + + length = list(data.keys()) + length.sort() + + ens = "\n".encode("utf-8") + + with open(tgtfs, "wb") as fs: + for lgth in length: + fs.write("\n".join(data[lgth]).encode("utf-8")) + fs.write(ens) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2]) diff --git a/transformer/AvgDecoder.py b/transformer/AvgDecoder.py index ae96fcd..04a01ff 100644 --- a/transformer/AvgDecoder.py +++ b/transformer/AvgDecoder.py @@ -145,6 +145,25 @@ def forward(self, inpute, inputo, src_pad_mask=None): return out + def load_base(self, base_decoder): + + self.drop = base_decoder.drop + + self.wemb = base_decoder.wemb + + self.pemb = base_decoder.pemb + + _nets = list(base_decoder.nets) + + self.nets = nn.ModuleList(_nets + list(self.nets[len(_nets):])) + + self.classifier = base_decoder.classifier + + self.lsm = base_decoder.lsm + + self.out_normer = None if self.out_normer is None else base_decoder.out_normer + + # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: # src_pad_mask = input.eq(0).unsqueeze(1) diff --git a/transformer/Decoder.py b/transformer/Decoder.py index 09c6880..b652d78 100644 --- a/transformer/Decoder.py +++ b/transformer/Decoder.py @@ -171,6 +171,24 @@ def forward(self, inpute, inputo, src_pad_mask=None): return out + def load_base(self, base_decoder): + + self.drop = base_decoder.drop + + self.wemb = base_decoder.wemb + + self.pemb = base_decoder.pemb + + _nets = list(base_decoder.nets) + + self.nets = nn.ModuleList(_nets + list(self.nets[len(_nets):])) + + self.classifier = base_decoder.classifier + + self.lsm = base_decoder.lsm + + self.out_normer = None if self.out_normer is None else base_decoder.out_normer + def _get_subsequent_mask(self, length): return self.mask.narrow(1, 0, length).narrow(2, 0, length) if length > self.xseql else torch.triu(self.mask.new_ones(length, length), 1).unsqueeze(0) @@ -425,7 +443,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt def get_sos_emb(self, inpute): - bsize, _, __ = inpute.size() + bsize = inpute.size(0) return self.wemb.weight[1].reshape(1, 1, -1).expand(bsize, 1, -1) diff --git a/transformer/Encoder.py b/transformer/Encoder.py index 34dbca1..6a12b43 100644 --- a/transformer/Encoder.py +++ b/transformer/Encoder.py @@ -100,3 +100,17 @@ def forward(self, inputs, mask=None): out = net(out, mask) return out if self.out_normer is None else self.out_normer(out) + + def load_base(self, base_encoder): + + self.drop = base_encoder.drop + + self.wemb = base_encoder.wemb + + self.pemb = base_encoder.pemb + + _nets = list(base_encoder.nets) + + self.nets = nn.ModuleList(_nets + list(self.nets[len(_nets):])) + + self.out_normer = None if self.out_normer is None else base_encoder.out_normer diff --git 
a/transformer/RNMTDecoder.py b/transformer/RNMTDecoder.py index c14713a..e16091a 100644 --- a/transformer/RNMTDecoder.py +++ b/transformer/RNMTDecoder.py @@ -126,10 +126,10 @@ def __init__(self, isize, nwd, num_layer, dropout=0.0, attn_drop=0.0, emb_w=None self.projector = nn.Linear(isize, isize, bias=False) if projector else None - self.classifier = nn.Sequential(nn.Linear(isize * 2, isize, bias=False), nn.Tanh(), nn.Linear(isize, nwd)) + self.classifier = nn.Linear(isize * 2, nwd)#nn.Sequential(nn.Linear(isize * 2, isize, bias=False), nn.Tanh(), nn.Linear(isize, nwd)) # be careful since this line of code is trying to share the weight of the wemb and the classifier, which may cause problems if torch.nn updates - if bindemb: - list(self.classifier.modules())[-1].weight = self.wemb.weight + #if bindemb: + #list(self.classifier.modules())[-1].weight = self.wemb.weight self.lsm = nn.LogSoftmax(-1) diff --git a/transformer/TADecoder.py b/transformer/TADecoder.py new file mode 100644 index 0000000..76a8956 --- /dev/null +++ b/transformer/TADecoder.py @@ -0,0 +1,293 @@ +#encoding: utf-8 + +import torch +from torch import nn +from modules import * +from math import sqrt + +from transformer.Decoder import Decoder as DecoderBase + +class Decoder(DecoderBase): + + # isize: size of word embedding + # nwd: number of words + # num_layer: number of encoder layers + # fhsize: number of hidden units for PositionwiseFeedForward + # attn_drop: dropout for MultiHeadAttention + # emb_w: weight for embedding. Use only when the encoder and decoder share a same dictionary + # num_head: number of heads in MultiHeadAttention + # xseql: maxmimum length of sequence + # ahsize: number of hidden units for MultiHeadAttention + # bindemb: bind embedding and classifier weight + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=512, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None): + + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindemb, forbidden_index) + + # inpute: encoded representation from encoder (bsize, seql, isize, num_layer) + # inputo: decoded translation (bsize, nquery) + # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: + # src_pad_mask = input.eq(0).unsqueeze(1) + + def forward(self, inpute, inputo, src_pad_mask=None): + + bsize, nquery = inputo.size() + + out = self.wemb(inputo) + + out = out * sqrt(out.size(-1)) + self.pemb(inputo, expand=False) + + if self.drop is not None: + out = self.drop(out) + + _mask = self._get_subsequent_mask(inputo.size(1)) + + # the following line of code is to mask for the decoder, + # which I think is useless, since only may pay attention to previous tokens, whos loss will be omitted by the loss function. 
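# (Standalone illustration of the two masks being discussed here; the pad id 0 follows the
# src_pad_mask comments elsewhere in this patch, the tensors are toy examples rather than
# code from TADecoder. The commented-out line just below combines both masks.)
import torch
nquery = 4
# causal / "subsequent" mask: position i may only attend to positions <= i
_causal = torch.triu(torch.ones(nquery, nquery, dtype=torch.uint8), 1).unsqueeze(0)   # (1, nquery, nquery)
inputo = torch.tensor([[1, 5, 2, 0]])                                                 # toy target batch, pad id 0
_pad = inputo.eq(0).unsqueeze(1).to(torch.uint8)                                      # (bsize, 1, nquery)
_combined = torch.gt(_causal + _pad, 0)                                               # causal mask plus padding mask
# Queries at <pad> positions only produce outputs scored against pad targets, whose loss
# is ignored, which is why the extra padding term can be dropped here.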
+ #_mask = torch.gt(_mask + inputo.eq(0).unsqueeze(1), 0) + + for net, inputu in zip(self.nets, inpute.unbind(dim=-1)): + out = net(inputu, out, src_pad_mask, _mask) + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.lsm(self.classifier(out)) + + return out + + # inpute: encoded representation from encoder (bsize, seql, isize) + # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: + # src_pad_mask = input.eq(0).unsqueeze(1) + # max_len: maximum length to generate + + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512): + + bsize, seql= inpute.size()[:2] + + sos_emb = self.get_sos_emb(inpute) + + sqrt_isize = sqrt(sos_emb.size(-1)) + + # out: input to the decoder for the first step (bsize, 1, isize) + + out = sos_emb * sqrt_isize + self.pemb.get_pos(0).view(1, 1, -1).expand(bsize, 1, -1) + + if self.drop is not None: + out = self.drop(out) + + states = {} + + for _tmp, (net, inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))): + out, _state = net(inputu, None, src_pad_mask, None, out, True) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + # out: (bsize, 1, nwd) + + out = self.lsm(self.classifier(out)) + + # wds: (bsize, 1) + + wds = torch.argmax(out, dim=-1) + + trans = [wds] + + # done_trans: (bsize) + + done_trans = wds.squeeze(1).eq(2) + + for i in range(1, max_len): + + out = self.wemb(wds) * sqrt_isize + self.pemb.get_pos(i).view(1, 1, -1).expand(bsize, 1, -1) + + if self.drop is not None: + out = self.drop(out) + + for _tmp, (net, inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))): + out, _state = net(inputu, states[_tmp], src_pad_mask, None, out, True) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + # out: (bsize, 1, nwd) + out = self.lsm(self.classifier(out)) + wds = torch.argmax(out, dim=-1) + + trans.append(wds) + + done_trans = torch.gt(done_trans + wds.squeeze(1).eq(2), 0) + if done_trans.sum().item() == bsize: + break + + return torch.cat(trans, 1) + + # inpute: encoded representation from encoder (bsize, seql, isize) + # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: + # src_pad_mask = input.eq(0).unsqueeze(1) + # beam_size: beam size + # max_len: maximum length to generate + + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=False): + + bsize, seql = inpute.size()[:2] + + beam_size2 = beam_size * beam_size + bsizeb2 = bsize * beam_size2 + real_bsize = bsize * beam_size + + sos_emb = self.get_sos_emb(inpute) + isize = sos_emb.size(-1) + sqrt_isize = sqrt(isize) + + if length_penalty > 0.0: + # lpv: length penalty vector for each beam (bsize * beam_size, 1) + lpv = sos_emb.new_ones(real_bsize, 1) + lpv_base = 6.0 ** length_penalty + + out = sos_emb * sqrt_isize + self.pemb.get_pos(0).view(1, 1, isize).expand(bsize, 1, isize) + + if self.drop is not None: + out = self.drop(out) + + states = {} + + for _tmp, (net, inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))): + out, _state = net(inputu, None, src_pad_mask, None, out, True) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + # out: (bsize, 1, nwd) + + out = self.lsm(self.classifier(out)) + + # scores: (bsize, 1, beam_size) => (bsize, beam_size) + # wds: (bsize * beam_size, 1) + # trans: (bsize * beam_size, 1) + + scores, wds = torch.topk(out, beam_size, 
dim=-1) + scores = scores.squeeze(1) + sum_scores = scores + wds = wds.view(real_bsize, 1) + trans = wds + + # done_trans: (bsize, beam_size) + + done_trans = wds.view(bsize, beam_size).eq(2) + + # inpute: (bsize, seql, isize, num_layer_enc) => (bsize * beam_size, seql, isize, num_layer_enc) + + inpute = inpute.repeat(1, beam_size, 1, 1).view(real_bsize, seql, isize, -1) + + # _src_pad_mask: (bsize, 1, seql) => (bsize * beam_size, 1, seql) + + _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) + + # states[i]: (bsize, 1, isize) => (bsize * beam_size, 1, isize) + + for key, value in states.items(): + states[key] = value.repeat(1, beam_size, 1).view(real_bsize, 1, isize) + + for step in range(1, max_len): + + out = self.wemb(wds) * sqrt_isize + self.pemb.get_pos(step).view(1, 1, isize).expand(real_bsize, 1, isize) + + if self.drop is not None: + out = self.drop(out) + + for _tmp, (net, inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))): + out, _state = net(inputu, states[_tmp], _src_pad_mask, None, out, True) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + # out: (bsize, beam_size, nwd) + + out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1) + + # find the top k ** 2 candidates and calculate route scores for them + # _scores: (bsize, beam_size, beam_size) + # done_trans: (bsize, beam_size) + # scores: (bsize, beam_size) + # _wds: (bsize, beam_size, beam_size) + # mask_from_done_trans: (bsize, beam_size) => (bsize, beam_size * beam_size) + # added_scores: (bsize, 1, beam_size) => (bsize, beam_size, beam_size) + + _scores, _wds = torch.topk(out, beam_size, dim=-1) + _scores = (_scores.masked_fill(done_trans.unsqueeze(2).expand(bsize, beam_size, beam_size), 0.0) + sum_scores.unsqueeze(2).expand(bsize, beam_size, beam_size)) + + if length_penalty > 0.0: + lpv = lpv.masked_fill(1 - done_trans.view(real_bsize, 1), ((step + 6.0) ** length_penalty) / lpv_base) + + # clip from k ** 2 candidate and remain the top-k for each path + # scores: (bsize, beam_size * beam_size) => (bsize, beam_size) + # _inds: indexes for the top-k candidate (bsize, beam_size) + + if clip_beam and (length_penalty > 0.0): + scores, _inds = torch.topk((_scores.view(real_bsize, beam_size) / lpv.expand(real_bsize, beam_size)).view(bsize, beam_size2), beam_size, dim=-1) + _tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + sum_scores = _scores.view(bsizeb2).index_select(0, _tinds).view(bsize, beam_size) + else: + scores, _inds = torch.topk(_scores.view(bsize, beam_size2), beam_size, dim=-1) + _tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + sum_scores = scores + + # select the top-k candidate with higher route score and update translation record + # wds: (bsize, beam_size, beam_size) => (bsize * beam_size, 1) + + wds = _wds.view(bsizeb2).index_select(0, _tinds).view(real_bsize, 1) + + # reduces indexes in _inds from (beam_size ** 2) to beam_size + # thus the fore path of the top-k candidate is pointed out + # _inds: indexes for the top-k candidate (bsize, beam_size) + + _inds = (_inds / beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + + # select the corresponding translation history for the top-k candidate and update translation 
records + # trans: (bsize * beam_size, nquery) => (bsize * beam_size, nquery + 1) + + trans = torch.cat((trans.index_select(0, _inds), wds), 1) + + done_trans = torch.gt(done_trans.view(real_bsize).index_select(0, _inds) + wds.eq(2).squeeze(1), 0).view(bsize, beam_size) + + # check early stop for beam search + # done_trans: (bsize, beam_size) + # scores: (bsize, beam_size) + + _done = False + if length_penalty > 0.0: + lpv = lpv.index_select(0, _inds) + elif (not return_all) and done_trans.select(1, 0).sum().item() == bsize: + _done = True + + # check beam states(done or not) + + if _done or (done_trans.sum().item() == real_bsize): + break + + # update the corresponding hidden states + # states[i]: (bsize * beam_size, nquery, isize) + # _inds: (bsize, beam_size) => (bsize * beam_size) + + for key, value in states.items(): + states[key] = value.index_select(0, _inds) + + # if length penalty is only applied in the last step, apply length penalty + if (not clip_beam) and (length_penalty > 0.0): + scores = scores / lpv.view(bsize, beam_size) + scores, _inds = torch.topk(scores, beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) + + if return_all: + + return trans, scores + else: + + return trans.view(bsize, beam_size, -1).select(1, 0) diff --git a/transformer/TAEncoder.py b/transformer/TAEncoder.py new file mode 100644 index 0000000..54a1e77 --- /dev/null +++ b/transformer/TAEncoder.py @@ -0,0 +1,111 @@ +#encoding: utf-8 + +import torch +from torch import nn +from modules import SelfAttn, PositionalEmb +from TAmodules import PositionwiseFF +from math import sqrt + +from transformer.Encoder import Encoder as EncoderBase + +# vocabulary: +# :0 +# :1 +# :2 +# :3 +# ... 
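# (Standalone sketch of how these reserved ids are used elsewhere in this patch: id 0 is
# the padding token masked via .eq(0), id 1 is the start symbol whose embedding
# get_sos_emb() selects, and id 2 marks end-of-sentence in the greedy/beam decoders;
# treating id 3 as the unknown-word id is an assumption. The tensors are toy values.)
import torch
inputs = torch.tensor([[4, 7, 9, 2, 0, 0]])     # toy source batch of token ids
mask = inputs.eq(0).unsqueeze(1)                # (bsize, 1, seql) padding mask fed to the encoder
wds = torch.tensor([[2]])
done_trans = wds.squeeze(1).eq(2)               # per-sentence "finished" flag used during decoding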
+# for the classier of the decoder, is omitted + +class EncoderLayer(nn.Module): + + # isize: input size + # fhsize: hidden size of PositionwiseFeedForward + # attn_drop: dropout for MultiHeadAttention + # num_head: number of heads in MultiHeadAttention + # ahsize: hidden size of MultiHeadAttention + # norm_residue: residue with layer normalized representation + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None): + + super(EncoderLayer, self).__init__() + + _ahsize = isize if ahsize is None else ahsize + + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + self.attn = SelfAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) + + self.ff = PositionwiseFF(isize, _fhsize, dropout) + + self.layer_normer = nn.LayerNorm(isize, eps=1e-06) + + self.drop = nn.Dropout(dropout, inplace=True) if dropout > 0.0 else None + + # inputs: input of this layer (bsize, seql, isize) + + def forward(self, inputs, mask=None): + + context = self.attn(inputs, mask=mask) + + if self.drop is not None: + context = self.drop(context) + + context = self.layer_normer(context + inputs) + + context = self.ff(context) + + return context + +class Encoder(EncoderBase): + + # isize: size of word embedding + # nwd: number of words + # num_layer: number of encoder layers + # fhsize: number of hidden units for PositionwiseFeedForward + # attn_drop: dropout for MultiHeadAttention + # num_head: number of heads in MultiHeadAttention + # xseql: maxmimum length of sequence + # ahsize: number of hidden units for MultiHeadAttention + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=512, ahsize=None, norm_output=True, num_layer_dec=6): + + _ahsize = isize if ahsize is None else ahsize + + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Encoder, self).__init__(isize, nwd, num_layer, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output) + + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + + self.tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(6.0 / num_layer + num_layer_dec + 1), sqrt(6.0 / num_layer + num_layer_dec + 1))) + self.tattn_drop = nn.Dropout(dropout) if dropout > 0.0 else None + + # inputs: (bsize, seql) + # mask: (bsize, 1, seql), generated with: + # mask = inputs.eq(0).unsqueeze(1) + + def forward(self, inputs, mask=None): + + bsize, seql = inputs.size() + out = self.wemb(inputs) + out = out * sqrt(out.size(-1)) + self.pemb(inputs, expand=False) + + if self.drop is not None: + out = self.drop(out) + + out = self.out_normer(out) + outs = [] + outs.append(out) + + for net in self.nets: + out = net(out, mask) + outs.append(out) + + out = torch.stack(outs, -1) + + osize = out.size() + out = torch.mm(out.view(-1, osize[-1]), self.tattn_w.softmax(dim=0) if self.tattn_drop is None else self.tattn_drop(self.tattn_w).softmax(dim=0)) + osize = list(osize) + osize[-1] = -1 + + return out.view(osize) diff --git a/translator.py b/translator.py index 6f3a805..0336c9a 100644 --- a/translator.py +++ b/translator.py @@ -6,15 +6,29 @@ from transformer.EnsembleNMT import NMT as Ensemble from parallel.parallelMT import DataParallelMT +from utils import * + has_unk = True +def clear_list(lin): + rs = [] + for tmpu in lin: + if tmpu: + rs.append(tmpu) + return rs + +def clean_len(line): + rs = clear_list(line.split()) + return " ".join(rs), len(rs) + +def clean_list(lin): + rs = [] + for lu in lin: + 
rs.append(" ".join(clear_list(lu.split()))) + return rs + def list_reader(fname): - def clear_list(lin): - rs = [] - for tmpu in lin: - if tmpu: - rs.append(tmpu) - return rs + with open(fname, "rb") as frd: for line in frd: tmp = line.strip() @@ -133,18 +147,63 @@ def data_loader(sentences_iter, vcbi, minbsize=1, bsize=64, maxpad=16, maxpart=4 for i_d in batch_padder(sentences_iter, vcbi, bsize, maxpad, maxpart, maxtoken, minbsize): yield torch.tensor(i_d, dtype=torch.long) -def load_model_cpu(modf, base_model): +def load_fixing(module): + if "fix_load" in dir(module): + module.fix_load() + +def sorti(lin): - mpg = torch.load(modf, map_location='cpu') + data = {} - for para, mp in zip(base_model.parameters(), mpg): - para.data = mp.data + for ls in lin: + ls = ls.strip() + if ls: + ls, lgth = clean_len(ls) + if lgth not in data: + data[lgth] = set([ls]) + elif ls not in data[lgth]: + data[lgth].add(ls) - return base_model + length = list(data.keys()) + length.sort() -def load_fixing(module): - if "fix_load" in dir(module): - module.fix_load() + rs = [] + + for lgth in length: + rs.extend(data[lgth]) + + return rs + +def restore(src, tsrc, trs): + + data = {} + + for sl, tl in zip(tsrc, trs): + _sl, _tl = sl.strip(), tl.strip() + if _sl and _tl: + data[_sl] = " ".join(clear_list(_tl.split())) + + rs = [] + _tl = [] + for line in src: + tmp = line.strip() + if tmp: + tmp = " ".join(clear_list(tmp.split())) + tmp = data.get(tmp, "").strip() + if tmp: + _tl.append(tmp) + elif _tl: + rs.append(" ".join(_tl)) + _tl = [] + elif _tl: + rs.append(" ".join(_tl)) + _tl = [] + else: + rs.append("") + if _tl: + rs.append(" ".join(_tl)) + + return rs class TranslatorCore: @@ -247,9 +306,9 @@ class Translator: def __init__(self, trans=None, sent_split=None, tok=None, detok=None, bpe=None, debpe=None, punc_norm=None, truecaser=None, detruecaser=None): + self.sent_split = sent_split + self.flow = [] - if sent_split is not None: - self.flow.append(sent_split) if punc_norm is not None: self.flow.append(punc_norm) if tok is not None: @@ -269,10 +328,29 @@ def __init__(self, trans=None, sent_split=None, tok=None, detok=None, bpe=None, def __call__(self, paragraph): - _tmp = paragraph + _tmp = [tmpu.strip() for tmpu in paragraph.strip().split("\n")] + _rs = [] + _tmpi = None + if self.sent_split is not None: + np = len(_tmp) - 1 + if np > 0: + for _i, _tmpu in enumerate(_tmp): + if _tmpu: + _rs.extend(self.sent_split(_tmpu)) + if _i < np: + _rs.append("") + _tmpi = sorti(_rs) + _tmp = _tmpi + else: + _tmp = [" ".join(clear_list(_tmp[0].split()))] + else: + _tmp = clean_list(_tmp) for pu in self.flow: _tmp = pu(_tmp) - return " ".join(_tmp) + if len(_rs) > 1: + _tmp = restore(_rs, _tmpi, _tmp) + return "\n".join(_tmp) + return " ".join(_tmp)
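# (Standalone sketch of the sort-then-restore round trip that both tools/sorti.py +
# tools/restore.py and the sorti()/restore() helpers above implement: sentences are
# deduplicated and reordered by length so decoding batches are more uniform, then the
# outputs are mapped back to the original order using the cleaned source sentence as the
# key. translate() below is a hypothetical stand-in for the real model.)
def translate(batch):
    # pretend "translation": reverse the tokens of each sentence
    return [" ".join(reversed(s.split())) for s in batch]

src = ["a b c d", "x", "e f"]                            # original order
srt = sorted(set(src), key=lambda s: len(s.split()))     # roughly what sorti() writes out
hyp = translate(srt)                                     # decode in length-sorted order
table = dict(zip(srt, hyp))                              # cleaned source sentence -> translation
restored = [table.get(s, "") for s in src]               # back to the original order, as restore() does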
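# (Standalone sketch of the layer mixing performed by TAEncoder above: the embedding output
# plus every encoder layer's output are stacked, and a learned (num_layer + 1, num_layer_dec)
# weight, softmax-normalised over the encoder-layer axis, yields one mixed encoding per
# decoder layer; TADecoder.forward then unbinds the last dimension and feeds one mixture to
# each of its layers. Shapes are toy values and the random tensors stand in for real outputs.)
import torch
bsize, seql, isize, num_layer, num_layer_dec = 2, 5, 8, 3, 3
outs = [torch.randn(bsize, seql, isize) for _ in range(num_layer + 1)]   # embeddings + each layer
stacked = torch.stack(outs, -1)                                          # (bsize, seql, isize, num_layer + 1)
tattn_w = torch.randn(num_layer + 1, num_layer_dec)
mix = stacked.view(-1, num_layer + 1).mm(tattn_w.softmax(dim=0))         # weighted sum over encoder layers
mix = mix.view(bsize, seql, isize, num_layer_dec)                        # one encoding per decoder layer
per_decoder_layer_inputs = mix.unbind(dim=-1)                            # what TADecoder.forward consumes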