
Commit

fix the training script.
liuqiuhui2015 committed Jul 10, 2019
1 parent 6809ade commit d7cb516
Showing 10 changed files with 83 additions and 12 deletions.
1 change: 1 addition & 0 deletions modules/__init__.py
@@ -0,0 +1 @@
#encoding: utf-8
File renamed without changes.
8 changes: 6 additions & 2 deletions modules/base.py
@@ -256,7 +256,7 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=Fal
self.drop = nn.Dropout(dropout, inplace=sparsenorm) if dropout > 0.0 else None

# iQ: query (bsize, num_query, vsize)
# mask (bsize, num_query, seql)
# mask: (bsize, num_query, seql)
# iK: key/value (bsize, seql, vsize), in case key != query, for efficient decoding

def forward(self, iQ, mask=None, iK=None):
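
The shape comments above document the attention interface: queries of shape (bsize, num_query, vsize), an optional mask over source positions, and an optional separate key/value input for efficient decoding. A minimal shape sketch, not part of the commit, assuming nonzero mask entries mark positions to be ignored:

import torch

bsize, num_query, seql, vsize = 2, 4, 6, 8
iQ = torch.randn(bsize, num_query, vsize)              # queries
iK = torch.randn(bsize, seql, vsize)                   # keys/values when key != query
mask = torch.zeros(bsize, num_query, seql, dtype=torch.bool)
mask[:, :, -2:] = True                                 # e.g. the last two source positions are padding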
@@ -389,7 +389,11 @@ def __init__(self, isize, ncomb=2, hsize=None, use_GeLU=False):

def forward(self, *xl):

out = torch.stack([self.net(torch.cat(xl, -1))] + list(xl), -2).sum(-2)
# faster only when len(xl) is very large
#out = torch.stack([self.net(torch.cat(xl, -1))] + list(xl), -2).sum(-2)
out = self.net(torch.cat(xl, -1))
for inputu in xl:
out = out + inputu

return self.out_normer(out)

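The removed stack-and-sum line and the new addition loop compute the same residual combination; the loop avoids materialising the stacked tensor, which, as the new comment notes, only pays off for very long xl. A standalone equivalence check, not part of the commit, with hypothetical tensors:

import torch

xl = [torch.randn(2, 5, 8) for _ in range(3)]             # hypothetical inputs
net_out = torch.randn(2, 5, 8)                             # stands in for self.net(torch.cat(xl, -1))

stacked = torch.stack([net_out] + list(xl), -2).sum(-2)    # old formulation

looped = net_out                                           # new formulation
for inputu in xl:
	looped = looped + inputu

print(torch.allclose(stacked, looped))                     # True (up to float rounding)
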
24 changes: 24 additions & 0 deletions tools/__no_significance__/add_tag.py
@@ -0,0 +1,24 @@
#encoding: utf-8

''' usage:
python tools/add_tag.py $src_file.t7 $rs_file.t7 $token
'''

import sys

def handle(srcf, rsf, token):

_et = token.encode("utf-8") if token.endswith(" ") else (token + " ").encode("utf-8")
_ens = "\n".encode("utf-8")

with open(srcf, "rb") as frd, open(rsf, "wb") as fwrt:
for line in frd:
tmp = line.strip()
if tmp:
tmp = tmp.decode("utf-8")
fwrt.write(_et)
fwrt.write(tmp.encode("utf-8"))
fwrt.write(_ens)

if __name__ == "__main__":
handle(sys.argv[1], sys.argv[2], sys.argv[3])
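
add_tag.py prepends the given token (followed by a space) to every non-empty line of the source file and drops blank lines. An in-memory sketch of the same logic, not part of the commit, using hypothetical data:

token = "<2de>"                                            # hypothetical tag
lines = [b"ein kleiner Test", b"", b"noch ein Test"]       # hypothetical input lines

_et = (token if token.endswith(" ") else token + " ").encode("utf-8")
tagged = [_et + l.strip() + "\n".encode("utf-8") for l in lines if l.strip()]
# tagged == [b"<2de> ein kleiner Test\n", b"<2de> noch ein Test\n"]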
36 changes: 36 additions & 0 deletions tools/__no_significance__/remove_layers.py
@@ -0,0 +1,36 @@
#encoding: utf-8

''' usage:
python remove_layers.py $src.t7 $rs.t7 enc/dec layers...
'''

import sys

import torch
from torch.nn import ModuleList

from transformer.NMT import NMT

from utils import *

import h5py
import cnfg

def handle(srcf, rsf, typ, rlist):

td = h5py.File(cnfg.dev_data, "r")
nwordi = int(td["nwordi"][:][0])
nwordt = int(td["nwordt"][:][0])
td.close()

_tmpm = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cnfg.cache_len, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes)
_tmpm = load_model_cpu(srcf, _tmpm)
if typ == "enc":
_tmpm.enc.nets = ModuleList(remove_layers(list(_tmpm.enc.nets), rlist))
elif typ == "dec":
_tmpm.dec.nets = ModuleList(remove_layers(list(_tmpm.dec.nets), rlist))

save_model(_tmpm, rsf, False)

if __name__ == "__main__":
handle(sys.argv[1], sys.argv[2], sys.argv[3], [int(_t) for _t in sys.argv[4:]])
9 changes: 4 additions & 5 deletions train.py
@@ -51,10 +51,6 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok
seq_batch = seq_batch.to(mv_device)
seq_o = seq_o.to(mv_device)

if _done_tokens >= tokens_optm:
optm.zero_grad()
_done_tokens = 0

oi = seq_o.narrow(1, 0, lo)
ot = seq_o.narrow(1, 1, lo).contiguous()
output = model(seq_batch, oi)
@@ -88,9 +84,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok
if multi_gpu:
model.collect_gradients()
optm.step()
optm.zero_grad()
model.update_replicas()
else:
optm.step()
optm.zero_grad()
_done_tokens = 0
if _cur_rstep is not None:
if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0):
if num_checkpoint > 1:
@@ -208,7 +207,7 @@ def init_fixing(module):

tokens_optm = cnfg.tokens_optm

done_tokens = tokens_optm
done_tokens = 0

batch_report = cnfg.batch_report
report_eva = cnfg.report_eva
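Taken together, the train.py changes move optm.zero_grad() to immediately after optm.step() and start done_tokens at 0, so gradients accumulate until the token budget is reached and are cleared only once an update has actually been applied. A condensed sketch of the resulting update loop, not the training script itself (compute_loss, pad_id and the batch iterator are placeholders):

done_tokens = 0
for seq_batch, seq_o in batches:                    # placeholder batch iterator
	loss = compute_loss(model, seq_batch, seq_o)    # placeholder loss computation
	loss.backward()
	done_tokens += int(seq_o.ne(pad_id).sum().item())
	if done_tokens >= tokens_optm:
		optm.step()                                 # update on the accumulated gradients
		optm.zero_grad()                            # clear only after a real step
		done_tokens = 0
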
4 changes: 2 additions & 2 deletions transformer/NMT.py
@@ -3,10 +3,10 @@
import torch
from torch import nn

# import Encoder and Decoder from transformer.AGG.InceptEncoder and transformer.AGG.InceptDecoder/transformer.AGG.InceptAvgDecoder to learn complex representation with incepted transformer, TA/TAEncoder for Transparent Encoder
# import Encoder and Decoder from transformer.AGG.InceptEncoder and transformer.AGG.InceptDecoder/transformer.AGG.InceptAvgDecoder to learn complex representation with incepted transformer, transformer.TA.Encoder for Transparent Encoder.
from transformer.Encoder import Encoder

# switch the comment between the following two lines to choose standard decoder or average decoder. Using TA/TADecoder for Transparent Decoder.
# switch the comment between the following two lines to choose standard decoder or average decoder. Using transformer.TA.Decoder for Transparent Decoder.
from transformer.Decoder import Decoder
#from transformer.AvgDecoder import Decoder

File renamed without changes.
4 changes: 1 addition & 3 deletions transformer/TA/TAEncoder.py → transformer/TA/Encoder.py
@@ -23,7 +23,6 @@ class EncoderLayer(nn.Module):
# attn_drop: dropout for MultiHeadAttention
# num_head: number of heads in MultiHeadAttention
# ahsize: hidden size of MultiHeadAttention
# norm_residue: residue with layer normalized representation

def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None):

@@ -94,8 +93,7 @@ def forward(self, inputs, mask=None):
out = self.drop(out)

out = self.out_normer(out)
outs = []
outs.append(out)
outs = [out]

for net in self.nets:
out = net(out, mask)
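The Transparent Encoder keeps the output of the normalised embedding stage plus the output of every layer, so a later combiner can weight all representation levels; the edit simply builds that list as outs = [out] in one step. A minimal sketch of the collection pattern with stand-in layers, not the repository's encoder:

import torch
from torch import nn

nets = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])  # stand-ins for encoder layers
out = torch.randn(2, 5, 8)                                  # stands in for the normalised embedding output

outs = [out]
for net in nets:
	out = net(out)
	outs.append(out)
# outs holds len(nets) + 1 tensors, one per representation level
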
9 changes: 9 additions & 0 deletions utils.py
@@ -195,3 +195,12 @@ def expand_bsize_for_beam(*inputs, beam_size=1):
outputs.append(repeat_bsize_for_beam_tensor(inputu, beam_size))

return outputs[0] if len(inputs) == 1 else tuple(outputs)

def remove_layers(all_layers, ltr):

rs = []
for i, _l in enumerate(all_layers):
if i not in ltr:
rs.append(_l)

return rs
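
remove_layers keeps the layers whose index is not listed in ltr, preserving order. A tiny usage illustration with hypothetical values:

layers = ["layer0", "layer1", "layer2", "layer3"]
print(remove_layers(layers, [1, 3]))                # ['layer0', 'layer2']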
