From d7cb516baaf5ec4af5b708a52f5c4dc4af34ebe2 Mon Sep 17 00:00:00 2001 From: Qiuhui Liu Date: Wed, 10 Jul 2019 12:14:43 +0200 Subject: [PATCH] fix the training script. --- modules/__init__.py | 1 + modules/{ => __no_significance__}/group.py | 0 modules/base.py | 8 +++-- tools/__no_significance__/add_tag.py | 24 ++++++++++++++ tools/__no_significance__/remove_layers.py | 36 +++++++++++++++++++++ train.py | 9 +++--- transformer/NMT.py | 4 +-- transformer/TA/{TADecoder.py => Decoder.py} | 0 transformer/TA/{TAEncoder.py => Encoder.py} | 4 +-- utils.py | 9 ++++++ 10 files changed, 83 insertions(+), 12 deletions(-) create mode 100644 modules/__init__.py rename modules/{ => __no_significance__}/group.py (100%) create mode 100644 tools/__no_significance__/add_tag.py create mode 100644 tools/__no_significance__/remove_layers.py rename transformer/TA/{TADecoder.py => Decoder.py} (100%) rename transformer/TA/{TAEncoder.py => Encoder.py} (97%) diff --git a/modules/__init__.py b/modules/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/modules/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/modules/group.py b/modules/__no_significance__/group.py similarity index 100% rename from modules/group.py rename to modules/__no_significance__/group.py diff --git a/modules/base.py b/modules/base.py index 849a4ed..e3e4046 100644 --- a/modules/base.py +++ b/modules/base.py @@ -256,7 +256,7 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=Fal self.drop = nn.Dropout(dropout, inplace=sparsenorm) if dropout > 0.0 else None # iQ: query (bsize, num_query, vsize) - # mask (bsize, num_query, seql) + # mask: (bsize, num_query, seql) # iK: key/value (bsize, seql, vsize), in case key != query, for efficient decoding def forward(self, iQ, mask=None, iK=None): @@ -389,7 +389,11 @@ def __init__(self, isize, ncomb=2, hsize=None, use_GeLU=False): def forward(self, *xl): - out = torch.stack([self.net(torch.cat(xl, -1))] + list(xl), -2).sum(-2) + # faster only when len(xl) is very large + #out = torch.stack([self.net(torch.cat(xl, -1))] + list(xl), -2).sum(-2) + out = self.net(torch.cat(xl, -1)) + for inputu in xl: + out = out + inputu return self.out_normer(out) diff --git a/tools/__no_significance__/add_tag.py b/tools/__no_significance__/add_tag.py new file mode 100644 index 0000000..8ff2322 --- /dev/null +++ b/tools/__no_significance__/add_tag.py @@ -0,0 +1,24 @@ +#encoding: utf-8 + +''' usage: + python tools/add_tag.py $src_file.t7 $rs_file.t7 $token +''' + +import sys + +def handle(srcf, rsf, token): + + _et = token.encode("utf-8") if token.endswith(" ") else (token + " ").encode("utf-8") + _ens = "\n".encode("utf-8") + + with open(srcf, "rb") as frd, open(rsf, "wb") as fwrt: + for line in frd: + tmp = line.strip() + if tmp: + tmp = tmp.decode("utf-8") + fwrt.write(_et) + fwrt.write(tmp.encode("utf-8")) + fwrt.write(_ens) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/__no_significance__/remove_layers.py b/tools/__no_significance__/remove_layers.py new file mode 100644 index 0000000..5121c54 --- /dev/null +++ b/tools/__no_significance__/remove_layers.py @@ -0,0 +1,36 @@ +#encoding: utf-8 + +''' usage: + python remove_layers.py $src.t7 $rs.t7 enc/dec layers... 
+''' + +import sys + +import torch +from torch.nn import ModuleList + +from transformer.NMT import NMT + +from utils import * + +import h5py +import cnfg + +def handle(srcf, rsf, typ, rlist): + + td = h5py.File(cnfg.dev_data, "r") + nwordi = int(td["nwordi"][:][0]) + nwordt = int(td["nwordt"][:][0]) + td.close() + + _tmpm = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cnfg.cache_len, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + _tmpm = load_model_cpu(srcf, _tmpm) + if typ == "enc": + _tmpm.enc.nets = ModuleList(remove_layers(list(_tmpm.enc.nets), rlist)) + elif typ == "dec": + _tmpm.dec.nets = ModuleList(remove_layers(list(_tmpm.dec.nets), rlist)) + + save_model(_tmpm, rsf, False) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], [int(_t) for _t in sys.argv[4:]]) diff --git a/train.py b/train.py index 3508d77..346bc66 100644 --- a/train.py +++ b/train.py @@ -51,10 +51,6 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) - if _done_tokens >= tokens_optm: - optm.zero_grad() - _done_tokens = 0 - oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() output = model(seq_batch, oi) @@ -88,9 +84,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if multi_gpu: model.collect_gradients() optm.step() + optm.zero_grad() model.update_replicas() else: optm.step() + optm.zero_grad() + _done_tokens = 0 if _cur_rstep is not None: if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): if num_checkpoint > 1: @@ -208,7 +207,7 @@ def init_fixing(module): tokens_optm = cnfg.tokens_optm -done_tokens = tokens_optm +done_tokens = 0 batch_report = cnfg.batch_report report_eva = cnfg.report_eva diff --git a/transformer/NMT.py b/transformer/NMT.py index 9c71003..4f9e507 100644 --- a/transformer/NMT.py +++ b/transformer/NMT.py @@ -3,10 +3,10 @@ import torch from torch import nn -# import Encoder and Decoder from transformer.AGG.InceptEncoder and transformer.AGG.InceptDecoder/transformer.AGG.InceptAvgDecoder to learn complex representation with incepted transformer, TA/TAEncoder for Transparent Encoder +# import Encoder and Decoder from transformer.AGG.InceptEncoder and transformer.AGG.InceptDecoder/transformer.AGG.InceptAvgDecoder to learn complex representation with incepted transformer, transformer.TA.Encoder for Transparent Encoder. from transformer.Encoder import Encoder -# switch the comment between the following two lines to choose standard decoder or average decoder. Using TA/TADecoder for Transparent Decoder. +# switch the comment between the following two lines to choose standard decoder or average decoder. Using transformer.TA.Decoder for Transparent Decoder. 
from transformer.Decoder import Decoder #from transformer.AvgDecoder import Decoder diff --git a/transformer/TA/TADecoder.py b/transformer/TA/Decoder.py similarity index 100% rename from transformer/TA/TADecoder.py rename to transformer/TA/Decoder.py diff --git a/transformer/TA/TAEncoder.py b/transformer/TA/Encoder.py similarity index 97% rename from transformer/TA/TAEncoder.py rename to transformer/TA/Encoder.py index 8fcc7bf..5af0478 100644 --- a/transformer/TA/TAEncoder.py +++ b/transformer/TA/Encoder.py @@ -23,7 +23,6 @@ class EncoderLayer(nn.Module): # attn_drop: dropout for MultiHeadAttention # num_head: number of heads in MultiHeadAttention # ahsize: hidden size of MultiHeadAttention - # norm_residue: residue with layer normalized representation def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None): @@ -94,8 +93,7 @@ def forward(self, inputs, mask=None): out = self.drop(out) out = self.out_normer(out) - outs = [] - outs.append(out) + outs = [out] for net in self.nets: out = net(out, mask) diff --git a/utils.py b/utils.py index a951693..14d609e 100644 --- a/utils.py +++ b/utils.py @@ -195,3 +195,12 @@ def expand_bsize_for_beam(*inputs, beam_size=1): outputs.append(repeat_bsize_for_beam_tensor(inputu, beam_size)) return outputs[0] if len(inputs) == 1 else tuple(outputs) + +def remove_layers(all_layers, ltr): + + rs = [] + for i, _l in enumerate(all_layers): + if i not in ltr: + rs.append(_l) + + return rs
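
A toy usage sketch (made-up values, not part of the patch) for the new utils.remove_layers helper: it keeps every element whose index is not listed, which is how tools/__no_significance__/remove_layers.py prunes encoder or decoder layers.

layers = ["layer0", "layer1", "layer2", "layer3", "layer4", "layer5"]
kept = remove_layers(layers, [1, 3])
# kept == ["layer0", "layer2", "layer4", "layer5"]

On the command line this corresponds to something like: python tools/__no_significance__/remove_layers.py old_model.t7 pruned_model.t7 enc 1 3 (hypothetical file names), which loads the checkpoint, rebuilds the encoder's ModuleList without layers 1 and 3, and saves the result.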
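
The ResidueCombiner change in modules/base.py replaces the stack-and-reduce with an explicit running sum. A minimal standalone sketch (toy shapes; net_out stands in for self.net(torch.cat(xl, -1))) showing that the two formulations give the same result:

import torch

xl = [torch.randn(2, 5, 8) for _ in range(3)]  # hypothetical combiner inputs
net_out = torch.randn(2, 5, 8)                 # stands in for self.net(torch.cat(xl, -1))

# old formulation: stack along a new dimension, then reduce over it
stacked = torch.stack([net_out] + list(xl), -2).sum(-2)

# new formulation: accumulate with elementwise additions
looped = net_out
for inputu in xl:
	looped = looped + inputu

assert torch.allclose(stacked, looped)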
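
In train.py, optm.zero_grad() now runs right after optm.step() in both the single- and multi-GPU branches, _done_tokens is reset at the same point, and done_tokens starts from 0 instead of tokens_optm. A minimal sketch of the resulting token-based gradient accumulation, assuming hypothetical helpers (batches, compute_loss) in place of the real data pipeline and loss code:

_done_tokens = 0
for seq_batch, seq_o in batches:                   # hypothetical batch iterator
	loss = compute_loss(model, seq_batch, seq_o)   # hypothetical helper
	loss.backward()                                # gradients keep accumulating
	_done_tokens += int(seq_o.ne(0).sum().item())  # count target tokens; padding index 0 assumed
	if _done_tokens >= tokens_optm:
		optm.step()
		optm.zero_grad()                           # clear right after the update
		_done_tokens = 0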
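
tools/__no_significance__/add_tag.py prepends the given token (followed by a space) to every non-empty line of a text file; blank lines produce no output line. A hypothetical invocation (file names and token are made up):

python tools/__no_significance__/add_tag.py corpus.de corpus.tag.de "<2en>"

With this call, a source line such as "ein Beispielsatz" is written out as "<2en> ein Beispielsatz".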