Skip to content
This repository has been archived by the owner on Aug 10, 2023. It is now read-only.

Commit

Permalink
fix and acc decode
Browse files Browse the repository at this point in the history
  • Loading branch information
hfxunlp committed Mar 11, 2019
1 parent a20a38e commit 5ca6c80
Show file tree
Hide file tree
Showing 15 changed files with 668 additions and 33 deletions.
29 changes: 29 additions & 0 deletions TAmodules.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#encoding: utf-8

from math import sqrt, log, exp, pi
import torch
from torch import nn
from torch.nn import functional as nnFunc
from torch.autograd import Function

from modules import GeLU_BERT
from modules import PositionwiseFF as PositionwiseFFBase

class PositionwiseFF(PositionwiseFFBase):
    """Feed-forward sub-layer that normalizes after the residual addition.

    Constructor arguments:
        isize: input dimension
        hsize: hidden dimension
        dropout: forwarded unchanged to the base constructor
        use_GeLU: forwarded unchanged to the base constructor
    """

    def __init__(self, isize, hsize=None, dropout=0.0, use_GeLU=False):
        # The fourth positional argument of the base constructor is pinned to
        # False (presumably ``norm_residue`` -- confirm against
        # modules.PositionwiseFF).
        super(PositionwiseFF, self).__init__(isize, hsize, dropout, False, use_GeLU)

    def forward(self, x):
        """Apply every module in ``self.nets`` to x, add x back, then normalize.

        The input is fed to the first module as-is; ``self.normer`` is applied
        only once, to the sum of the last module's output and the input.
        """
        hidden = x
        for layer in self.nets:
            hidden = layer(hidden)

        return self.normer(hidden + x)
2 changes: 1 addition & 1 deletion mkcy.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ def get_name(fname):

eccargs = ["-Ofast", "-march=native", "-pipe", "-fomit-frame-pointer"]

baselist = ["modules.py", "loss.py", "lrsch.py", "utils.py","rnncell.py", "translator.py", "discriminator.py"]
baselist = ["TAmodules.py", "modules.py", "loss.py", "lrsch.py", "utils.py","rnncell.py", "translator.py", "discriminator.py"]

extlist = [Extension(get_name(pyf), [pyf], extra_compile_args=eccargs) for pyf in baselist]

Expand Down
7 changes: 3 additions & 4 deletions modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ def __init__(self, isize, hsize=None, dropout=0.0, norm_residue=False, use_GeLU=

self.norm_residue = norm_residue


def forward(self, x):

_out = self.normer(x)
Expand Down Expand Up @@ -647,7 +646,7 @@ def __init__(self, isize, bias=True):

super(Scorer, self).__init__()

self.w = nn.Parameter(torch.Tensor(isize).uniform_(sqrt(6.0 / isize), sqrt(6.0 / isize)))
self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize)))
self.bias = nn.Parameter(torch.zeros(1)) if bias else None

def forward(self, x):
Expand All @@ -667,7 +666,7 @@ def __init__(self, isize, ahsize=None, num_head=8, attn_drop=0.0):

super(MHAttnSummer, self).__init__()

self.w = nn.Parameter(torch.Tensor(1, 1, isize).uniform_(sqrt(6.0 / isize), sqrt(6.0 / isize)))
self.w = nn.Parameter(torch.Tensor(1, 1, isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize)))
self.attn = CrossAttn(isize, isize if ahsize is None else ahsize, isize, num_head, dropout=attn_drop)

# x: (bsize, seql, isize)
Expand Down Expand Up @@ -700,7 +699,7 @@ def __init__(self, isize, minv = 0.125):

super(Temperature, self).__init__()

self.w = nn.Parameter(torch.Tensor(isize).uniform_(sqrt(6.0 / isize), sqrt(6.0 / isize)))
self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(6.0 / isize), sqrt(6.0 / isize)))
self.bias = nn.Parameter(torch.zeros(1))
self.act = nn.Tanh()
self.k = nn.Parameter(torch.ones(1))
Expand Down
6 changes: 4 additions & 2 deletions scripts/mktest.sh
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ export tgtd=$cachedir/$dataid

export bpef=out.bpe

python tools/mktest.py $srcd/$srctf $tgtd/src.vcb $tgtd/test.h5 $ngpu
python predict.py $tgtd/$bpef $tgtd/tgt.vcb $modelf
python tools/sorti.py $srcd/$srctf $tgtd/$srctf.srt
python tools/mktest.py $tgtd/$srctf.srt $tgtd/src.vcb $tgtd/test.h5 $ngpu
python predict.py $tgtd/$bpef.srt $tgtd/tgt.vcb $modelf
python tools/restore.py $srcd/$srctf $tgtd/$srctf.srt $tgtd/$bpef.srt $tgtd/$bpef
sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf
2 changes: 1 addition & 1 deletion tools/check/charatio.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_ratio(strin):
mrsc, _rsc, mrsb, _rsb, mrss, _rss = getfratio(srcfs)
mrtc, _rtc, mrtb, _rtb, mrts, _rts = getfratio(srcft)

print("Max/mean/adv char ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv char ratio of target data: %.3f / %.3f / %.3f / %.3f\nMax/mean/adv bpe ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv bpe ratio of target data: %.3f / %.3f / %.3f\nMax/mean/adv seperated ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv seperated ratio of target data: %.3f / %.3f / %.3f" % (mrsc, _rsc, min(mrsc, _rsc * 2.5) + 0.001, mrtc, _rtc, min(mrtc, _rtc * 2.5) + 0.001, mrsb, _rsb, min(mrsb, _rsb * 2.5) + 0.001, mrtb, _rtb, min(mrtb, _rtb * 2.5) + 0.001, mrss, _rss, min(mrss, _rss * 2.5) + 0.001, mrts, _rts, min(mrts, _rts * 2.5) + 0.001))
print("Max/mean/adv char ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv char ratio of target data: %.3f / %.3f / %.3f\nMax/mean/adv bpe ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv bpe ratio of target data: %.3f / %.3f / %.3f\nMax/mean/adv seperated ratio of source data: %.3f / %.3f / %.3f\nMax/mean/adv seperated ratio of target data: %.3f / %.3f / %.3f" % (mrsc, _rsc, min(mrsc, _rsc * 2.5) + 0.001, mrtc, _rtc, min(mrtc, _rtc * 2.5) + 0.001, mrsb, _rsb, min(mrsb, _rsb * 2.5) + 0.001, mrtb, _rtb, min(mrtb, _rtb * 2.5) + 0.001, mrss, _rss, min(mrss, _rss * 2.5) + 0.001, mrts, _rts, min(mrts, _rts * 2.5) + 0.001))

# Script entry point: forwards the first two command-line arguments to
# handle(). Fewer than two arguments raises IndexError.
if __name__ == "__main__":
    handle(sys.argv[1], sys.argv[2])
36 changes: 36 additions & 0 deletions tools/restore.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#encoding: utf-8

import sys

def handle(srcfs, srtsf, srttf, tgtf):
    """Write, for every line of srcfs, the line paired with it in srtsf/srttf.

    srtsf and srttf are read in lockstep; each pair of non-blank lines adds a
    mapping from the whitespace-normalized srtsf line to the
    whitespace-normalized srttf line (a later duplicate key overwrites an
    earlier one). srcfs is then read line by line and the mapped text is
    written to tgtf.

    srcfs: file whose line order the output follows
    srtsf: file providing the lookup keys
    srttf: file providing the values, line-aligned with srtsf
    tgtf:  output file (overwritten)

    All files are treated as UTF-8. Exactly one output line is produced per
    line of srcfs: blank input lines and lines without a mapping produce an
    empty output line, so tgtf stays line-aligned with srcfs.
    """

    def clean(lin):
        # str.split() without arguments never yields empty tokens, so joining
        # its result is enough to collapse every whitespace run to one space.
        return " ".join(lin.split())

    data = {}

    with open(srtsf, "rb") as fs, open(srttf, "rb") as ft:
        for sl, tl in zip(fs, ft):
            _sl, _tl = sl.strip(), tl.strip()
            # Pairs with a blank side are skipped; such keys later fall back
            # to the empty string through data.get().
            if _sl and _tl:
                data[clean(_sl.decode("utf-8"))] = clean(_tl.decode("utf-8"))

    ens = "\n".encode("utf-8")

    with open(srcfs, "rb") as fs, open(tgtf, "wb") as ft:
        for line in fs:
            tmp = line.strip()
            if tmp:
                ft.write(data.get(clean(tmp.decode("utf-8")), "").encode("utf-8"))
            # Written unconditionally to keep one output line per input line.
            ft.write(ens)

# Script entry point: the four positional command-line arguments are passed to
# handle() in order (srcfs, srtsf, srttf, tgtf). Fewer than four arguments
# raises IndexError.
if __name__ == "__main__":
    handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4])
4 changes: 1 addition & 3 deletions tools/sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ def shuffle_pair(ls, lt):
if (slen <= _max_len) and (tlen <= _max_len):
lgth = slen + tlen
if lgth not in data:
tmp = {}
tmp[tlen] = [(ls, lt)]
data[lgth] = tmp
data[lgth] = {tlen: [(ls, lt)]}
else:
if tlen in data[lgth]:
data[lgth][tlen].append((ls, lt))
Expand Down
38 changes: 38 additions & 0 deletions tools/sorti.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#encoding: utf-8

import sys

def handle(srcfs, tgtfs):
    """Write the distinct lines of srcfs to tgtfs ordered by token count.

    Each non-blank line of srcfs is decoded as UTF-8, split on whitespace and
    re-joined with single spaces. Duplicates (after that normalization) are
    dropped. Lines are written in ascending order of their token count; lines
    with the same token count are written in sorted order so that the output
    is identical from run to run.

    srcfs: input file
    tgtfs: output file (overwritten); nothing is written if srcfs has no
           non-blank line
    """

    # token count -> set of normalized lines having that many tokens
    data = {}

    with open(srcfs, "rb") as fs:
        for ls in fs:
            ls = ls.strip()
            if ls:
                # split() without arguments never yields empty tokens.
                tokens = ls.decode("utf-8").split()
                data.setdefault(len(tokens), set()).add(" ".join(tokens))

    ens = "\n".encode("utf-8")

    with open(tgtfs, "wb") as fs:
        for lgth in sorted(data):
            # Set iteration order depends on string hashing, which may differ
            # between interpreter runs; sorting pins the order inside a bucket.
            fs.write("\n".join(sorted(data[lgth])).encode("utf-8"))
            fs.write(ens)

# Script entry point: the two positional command-line arguments are passed to
# handle() in order (srcfs, tgtfs). Fewer than two arguments raises IndexError.
if __name__ == "__main__":
    handle(sys.argv[1], sys.argv[2])
19 changes: 19 additions & 0 deletions transformer/AvgDecoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,25 @@ def forward(self, inpute, inputo, src_pad_mask=None):

return out

def load_base(self, base_decoder):
    """Replace this decoder's sub-modules with those of base_decoder.

    drop, wemb, pemb, classifier and lsm are taken from base_decoder by
    reference; nothing is copied. The new layer list consists of all layers
    of base_decoder followed by this decoder's own layers at positions past
    the base's layer count. out_normer is taken from base_decoder only when
    this decoder already has one; otherwise it stays None.
    """

    self.drop = base_decoder.drop
    self.wemb = base_decoder.wemb
    self.pemb = base_decoder.pemb

    base_layers = list(base_decoder.nets)
    own_tail = list(self.nets[len(base_layers):])
    self.nets = nn.ModuleList(base_layers + own_tail)

    self.classifier = base_decoder.classifier
    self.lsm = base_decoder.lsm

    if self.out_normer is not None:
        self.out_normer = base_decoder.out_normer


# inpute: encoded representation from encoder (bsize, seql, isize)
# src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with:
# src_pad_mask = input.eq(0).unsqueeze(1)
Expand Down
20 changes: 19 additions & 1 deletion transformer/Decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,24 @@ def forward(self, inpute, inputo, src_pad_mask=None):

return out

def load_base(self, base_decoder):
    """Adopt the sub-modules of base_decoder in place of this decoder's own.

    The attributes drop, wemb, pemb, classifier and lsm are rebound to the
    very objects held by base_decoder (no copy). For nets, every layer of
    base_decoder comes first, followed by whichever of this decoder's layers
    sit at indices beyond the base's layer count. out_normer is rebound to
    the base's only if this decoder's out_normer is not None.
    """

    self.drop = base_decoder.drop
    self.wemb = base_decoder.wemb
    self.pemb = base_decoder.pemb

    shared_layers = list(base_decoder.nets)
    extra_layers = list(self.nets[len(shared_layers):])
    self.nets = nn.ModuleList(shared_layers + extra_layers)

    self.classifier = base_decoder.classifier
    self.lsm = base_decoder.lsm

    if self.out_normer is not None:
        self.out_normer = base_decoder.out_normer

def _get_subsequent_mask(self, length):

return self.mask.narrow(1, 0, length).narrow(2, 0, length) if length > self.xseql else torch.triu(self.mask.new_ones(length, length), 1).unsqueeze(0)
Expand Down Expand Up @@ -425,7 +443,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt

def get_sos_emb(self, inpute):
    """Return row 1 of the embedding table repeated once per batch entry.

    inpute: tensor whose first dimension is the batch size; only that size is
            read, so the remaining dimensions are unconstrained.

    Returns an expanded (non-copying) view of shape (bsize, 1, emb_dim) of
    ``self.wemb.weight[1]`` -- presumably the start-of-sentence embedding,
    going by the method name (confirm against the vocabulary layout).
    """

    bsize = inpute.size(0)

    return self.wemb.weight[1].reshape(1, 1, -1).expand(bsize, 1, -1)

Expand Down
14 changes: 14 additions & 0 deletions transformer/Encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,17 @@ def forward(self, inputs, mask=None):
out = net(out, mask)

return out if self.out_normer is None else self.out_normer(out)

def load_base(self, base_encoder):
    """Adopt the sub-modules of base_encoder in place of this encoder's own.

    drop, wemb and pemb are rebound to the objects held by base_encoder (no
    copy). The new layer list holds every layer of base_encoder followed by
    this encoder's own layers at indices beyond the base's layer count.
    out_normer is rebound to the base's only if this encoder's out_normer is
    not None.
    """

    self.drop = base_encoder.drop
    self.wemb = base_encoder.wemb
    self.pemb = base_encoder.pemb

    base_layers = list(base_encoder.nets)
    own_tail = list(self.nets[len(base_layers):])
    self.nets = nn.ModuleList(base_layers + own_tail)

    if self.out_normer is not None:
        self.out_normer = base_encoder.out_normer
6 changes: 3 additions & 3 deletions transformer/RNMTDecoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ def __init__(self, isize, nwd, num_layer, dropout=0.0, attn_drop=0.0, emb_w=None

self.projector = nn.Linear(isize, isize, bias=False) if projector else None

self.classifier = nn.Sequential(nn.Linear(isize * 2, isize, bias=False), nn.Tanh(), nn.Linear(isize, nwd))
self.classifier = nn.Linear(isize * 2, nwd)#nn.Sequential(nn.Linear(isize * 2, isize, bias=False), nn.Tanh(), nn.Linear(isize, nwd))
# be careful since this line of code is trying to share the weight of the wemb and the classifier, which may cause problems if torch.nn updates
if bindemb:
list(self.classifier.modules())[-1].weight = self.wemb.weight
#if bindemb:
#list(self.classifier.modules())[-1].weight = self.wemb.weight

self.lsm = nn.LogSoftmax(-1)

Expand Down
Loading

0 comments on commit 5ca6c80

Please sign in to comment.