diff --git a/adv/predict/doc/para/predict_doc_para.py b/adv/predict/doc/para/predict_doc_para.py index d8b2244..9b92749 100644 --- a/adv/predict/doc/para/predict_doc_para.py +++ b/adv/predict/doc/para/predict_doc_para.py @@ -3,6 +3,7 @@ import sys import torch +from torch.cuda.amp import autocast from tqdm import tqdm @@ -50,10 +51,8 @@ def load_fixing(module): mymodel.eval() -use_cuda = cnfg.use_cuda -gpuid = cnfg.gpuid - use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda_decode(cnfg.use_cuda, cnfg.gpuid, cnfg.multi_gpu_decoding) +use_amp = cnfg.use_amp and use_cuda # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) @@ -81,7 +80,8 @@ def load_fixing(module): seq_batch = seq_batch.to(cuda_device) bsize, _nsent, seql = seq_batch.size() _nsent_use = _nsent - 1 - output = mymodel.decode(seq_batch.narrow(1, 1, _nsent_use).contiguous(), seq_batch.narrow(1, 0, _nsent_use).contiguous(), beam_size, None, length_penalty).view(bsize, _nsent_use, -1) + with autocast(enabled=use_amp): + output = mymodel.decode(seq_batch.narrow(1, 1, _nsent_use).contiguous(), seq_batch.narrow(1, 0, _nsent_use).contiguous(), beam_size, None, length_penalty).view(bsize, _nsent_use, -1) if multi_gpu: tmp = [] for ou in output: diff --git a/adv/predict/predict_ape.py b/adv/predict/predict_ape.py new file mode 100644 index 0000000..d296392 --- /dev/null +++ b/adv/predict/predict_ape.py @@ -0,0 +1,100 @@ +#encoding: utf-8 + +import sys + +import torch +from torch.cuda.amp import autocast + +from tqdm import tqdm + +import h5py + +import cnfg.base as cnfg +from cnfg.ihyp import * + +from transformer.APE.NMT import NMT +from transformer.EnsembleNMT import NMT as Ensemble +from parallel.parallelMT import DataParallelMT + +from utils.base import * +from utils.fmt.base import ldvocab, reverse_dict, eos_id +from utils.fmt.base4torch import parse_cuda_decode + +def load_fixing(module): + + if "fix_load" in dir(module): + module.fix_load() + +td = h5py.File(cnfg.test_data, "r") + +ntest = td["ndata"][:].item() +nwordi = td["nword"][:].tolist()[0] +vcbt, nwordt = ldvocab(sys.argv[2]) +vcbt = reverse_dict(vcbt) + +if len(sys.argv) == 4: + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + + mymodel = load_model_cpu(sys.argv[3], mymodel) + mymodel.apply(load_fixing) + +else: + models = [] + for modelf in sys.argv[3:]: + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + + tmp = load_model_cpu(modelf, tmp) + tmp.apply(load_fixing) + + models.append(tmp) + mymodel = Ensemble(models) + +mymodel.eval() + +use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda_decode(cnfg.use_cuda, cnfg.gpuid, cnfg.multi_gpu_decoding) + +use_amp = cnfg.use_amp and use_cuda + +set_random_seed(cnfg.seed, use_cuda) + +if use_cuda: + mymodel.to(cuda_device) + if multi_gpu: + mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) + +beam_size = cnfg.beam_size + +length_penalty = cnfg.length_penalty + +ens = "\n".encode("utf-8") + +# using tgt instead of mt since data are processed by tools/mkiodata.py for the mt task +src_grp, mt_grp = td["src"], td["tgt"] +with open(sys.argv[1], "wb") as f: + with 
torch.no_grad(): + for i in tqdm(range(ntest)): + seq_batch = torch.from_numpy(src_grp[str(i)][:]).long() + seq_mt = torch.from_numpy(mt_grp[str(i)][:]).long() + if use_cuda: + seq_batch = seq_batch.to(cuda_device) + seq_mt = seq_mt.to(cuda_device) + with autocast(enabled=use_amp): + output = mymodel.decode(seq_batch, seq_mt, beam_size, None, length_penalty) + if multi_gpu: + tmp = [] + for ou in output: + tmp.extend(ou.tolist()) + output = tmp + else: + output = output.tolist() + for tran in output: + tmp = [] + for tmpu in tran: + if tmpu == eos_id: + break + else: + tmp.append(vcbt[tmpu]) + f.write(" ".join(tmp).encode("utf-8")) + f.write(ens) + +td.close() diff --git a/adv/rank/doc/para/rank_loss_doc_para.py b/adv/rank/doc/para/rank_loss_para.py similarity index 85% rename from adv/rank/doc/para/rank_loss_doc_para.py rename to adv/rank/doc/para/rank_loss_para.py index 5a135ae..ee59f74 100644 --- a/adv/rank/doc/para/rank_loss_doc_para.py +++ b/adv/rank/doc/para/rank_loss_para.py @@ -7,6 +7,7 @@ import sys import torch +from torch.cuda.amp import autocast from tqdm import tqdm @@ -23,6 +24,7 @@ from loss.base import LabelSmoothingLoss from utils.base import * +from utils.fmt.base import pad_id from utils.fmt.base4torch import parse_cuda def load_fixing(module): @@ -57,9 +59,10 @@ def load_fixing(module): mymodel.eval() -lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=0, reduction='none', forbidden_index=cnfg.forbidden_indexes) +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='none', forbidden_index=cnfg.forbidden_indexes) use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) +use_amp = cnfg.use_amp and use_cuda # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) @@ -91,10 +94,11 @@ def load_fixing(module): seq_o = seq_o.narrow(1, 1, _nsent_use) oi = seq_o.narrow(-1, 0, lo).contiguous() ot = seq_o.narrow(-1, 1, lo).contiguous() - output = mymodel(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()).view(bsize, _nsent_use, lo, -1) - loss = lossf(output, ot).sum(-1).view(bsize, -1).sum(-1) + with autocast(enabled=use_amp): + output = mymodel(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()).view(bsize, _nsent_use, lo, -1) + loss = lossf(output, ot).sum(-1).view(bsize, -1).sum(-1) if norm_token: - lenv = ot.ne(0).int().view(bsize, -1).sum(-1).to(loss) + lenv = ot.ne(pad_id).int().view(bsize, -1).sum(-1).to(loss) loss = loss / lenv f.write("\n".join([str(rsu) for rsu in loss.tolist()]).encode("utf-8")) loss = output = ot = seq_batch = seq_o = None diff --git a/adv/rank/doc/rank_loss_sent.py b/adv/rank/doc/rank_loss_sent.py index eb86dda..33128df 100644 --- a/adv/rank/doc/rank_loss_sent.py +++ b/adv/rank/doc/rank_loss_sent.py @@ -7,6 +7,7 @@ import sys import torch +from torch.cuda.amp import autocast from tqdm import tqdm @@ -23,6 +24,7 @@ from loss.base import LabelSmoothingLoss from utils.base import * +from utils.fmt.base import pad_id from utils.fmt.base4torch import parse_cuda def load_fixing(module): @@ -57,9 +59,10 @@ def load_fixing(module): mymodel.eval() -lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=0, reduction='none', forbidden_index=cnfg.forbidden_indexes) +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='none', forbidden_index=cnfg.forbidden_indexes) use_cuda, cuda_device, cuda_devices, 
multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) +use_amp = cnfg.use_amp and use_cuda # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) @@ -87,10 +90,11 @@ def load_fixing(module): seq_o = seq_o.to(cuda_device) lo = seq_o.size(-1) - 1 ot = seq_o.narrow(-1, 1, lo).contiguous() - output = mymodel(seq_batch.view(ebsize, -1), seq_o.narrow(-1, 0, lo).contiguous().view(ebsize, -1)).view(bsize, nsent, lo, -1) - loss = lossf(output, ot).sum(-1).view(bsize, -1).sum(-1) + with autocast(enabled=use_amp): + output = mymodel(seq_batch.view(ebsize, -1), seq_o.narrow(-1, 0, lo).contiguous().view(ebsize, -1)).view(bsize, nsent, lo, -1) + loss = lossf(output, ot).sum(-1).view(bsize, -1).sum(-1) if norm_token: - lenv = ot.ne(0).int().view(bsize, -1).sum(-1).to(loss) + lenv = ot.ne(pad_id).int().view(bsize, -1).sum(-1).to(loss) loss = loss / lenv f.write("\n".join([str(rsu) for rsu in loss.tolist()]).encode("utf-8")) loss = output = ot = seq_batch = seq_o = None diff --git a/adv/train/doc/para/train_doc_para.py b/adv/train/doc/para/train_doc_para.py index 5267983..33a1c8d 100644 --- a/adv/train/doc/para/train_doc_para.py +++ b/adv/train/doc/para/train_doc_para.py @@ -3,6 +3,7 @@ import sys import torch +from torch.cuda.amp import autocast, GradScaler from torch import optim @@ -12,7 +13,7 @@ from utils.base import * from utils.init import init_model_params from utils.h5serial import h5save, h5load -from utils.fmt.base import tostr, save_states, load_states +from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb from lrsch import GoogleLR @@ -33,19 +34,13 @@ from transformer.Doc.Para.Base.NMT import NMT from transformer.NMT import NMT as BaseNMT -def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False): +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): - sum_loss = 0.0 - sum_wd = 0 - part_loss = 0.0 - part_wd = 0 - _done_tokens = done_tokens + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl) model.train() - cur_b = 1 - ndata = len(tl) - _cur_checkid = cur_checkid - _cur_rstep = remain_steps - _ls = {} if save_loss else None + cur_b, _ls = 1, {} if save_loss else None src_grp, tgt_grp = td["src"], td["tgt"] for nsent, i_d in tqdm(tl): @@ -61,18 +56,19 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok seq_o = seq_o.narrow(1, 1, _nsent_use) oi = seq_o.narrow(-1, 0, lo).contiguous() ot = seq_o.narrow(-1, 1, lo).contiguous() - output = model(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()) - loss = lossf(output, ot) - if multi_gpu: - loss = loss.sum() + with autocast(enabled=_use_amp): + output = model(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() loss_add = loss.data.item() - if 
use_amp: - with amp.scale_loss(loss, optm) as scaled_loss: - scaled_loss.backward() - else: + + if scaler is None: loss.backward() + else: + scaler.scale(loss).backward() - wd_add = ot.ne(0).int().sum().item() + wd_add = ot.ne(pad_id).int().sum().item() loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: @@ -83,12 +79,10 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if _done_tokens >= tokens_optm: if multi_gpu: model.collect_gradients() - optm.step() - optm.zero_grad() + optm_step(optm, scaler) + optm.zero_grad() + if multi_gpu: model.update_replicas() - else: - optm.step() - optm.zero_grad() _done_tokens = 0 if _cur_rstep is not None: if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): @@ -116,7 +110,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_wd += wd_add if cur_b % nreport == 0: if report_eva: - _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu) + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva)) free_cache(mv_device) model.train() @@ -145,9 +139,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd)) return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls -def eva(ed, nd, model, lossf, mv_device, multi_gpu): - r = 0 - w = 0 +def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): + r = w = 0 sum_loss = 0.0 model.eval() @@ -166,15 +159,16 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu): seq_o = seq_o.narrow(1, 1, _nsent_use) oi = seq_o.narrow(-1, 0, lo).contiguous() ot = seq_o.narrow(-1, 1, lo).contiguous() - output = model(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()) - loss = lossf(output, ot) - if multi_gpu: - loss = loss.sum() - trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) - else: - trans = output.argmax(-1) + with autocast(enabled=use_amp): + output = model(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + else: + trans = output.argmax(-1) sum_loss += loss.data.item() - data_mask = ot.ne(0) + data_mask = ot.ne(pad_id) correct = (trans.eq(ot) & data_mask).int() w += data_mask.int().sum().item() r += correct.sum().item() @@ -231,16 +225,6 @@ def init_fixing(module): use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) -if use_cuda and cnfg.amp_opt: - try: - from apex import amp - use_amp = True - except Exception as e: - logger.info(str(e)) - use_amp = False -else: - use_amp = False - set_random_seed(cnfg.seed, use_cuda) td = h5py.File(cnfg.train_data, "r") @@ -259,6 +243,7 @@ def init_fixing(module): vl = [(str(nsent), str(_curd),) for nsent, ndata in zip(vd["nsent"][:].tolist(), vd["ndata"][:].tolist()) for _curd in range(ndata)] mymodel = init_model_params(mymodel) +mymodel.apply(init_fixing) if fine_tune_m is not None: logger.info("Load pre-trained model from: " + fine_tune_m) _tmpm = BaseNMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, 
cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) @@ -271,9 +256,8 @@ def init_fixing(module): _tmpm.dec.classifier.bias.requires_grad_(True) mymodel.load_base(_tmpm) _tmpm = None -mymodel.apply(init_fixing) -lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=0, reduction='sum', forbidden_index=cnfg.forbidden_indexes) +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) if cnfg.src_emb is not None: logger.info("Load source embedding from: " + cnfg.src_emb) load_emb(cnfg.src_emb, mymodel.enc.wemb.weight, nwordi, cnfg.scale_down_emb, cnfg.freeze_srcemb) @@ -288,8 +272,8 @@ def init_fixing(module): optimizer = optim.Adam(filter_para_grad(mymodel.parameters()), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) optimizer.zero_grad() -if use_amp: - mymodel, optimizer = amp.initialize(mymodel, optimizer, opt_level=cnfg.amp_opt) +use_amp = cnfg.use_amp and use_cuda +scaler = GradScaler() if use_amp else None if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) @@ -307,7 +291,7 @@ def init_fixing(module): tminerr = inf_default -minloss, minerr = eva(vd, vl, mymodel, lossf, cuda_device, multi_gpu) +minloss, minerr = eva(vd, vl, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: @@ -317,8 +301,8 @@ def init_fixing(module): cnt_states = cnfg.train_statesf if (cnt_states is not None) and p_check(cnt_states): logger.info("Continue last epoch") - tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, vl, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, use_amp) - vloss, vprec = eva(vd, vl, mymodel, lossf, cuda_device, multi_gpu) + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, vl, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + vloss, vprec = eva(vd, vl, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) if save_optm_state: @@ -344,8 +328,8 @@ def init_fixing(module): for i in range(1, maxrun + 1): shuffle(tl) free_cache(use_cuda) - terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, vl, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, use_amp) - vloss, vprec = eva(vd, vl, mymodel, lossf, cuda_device, multi_gpu) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, vl, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, 
num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, vl, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): @@ -375,8 +359,7 @@ def init_fixing(module): if done_tokens > 0: if multi_gpu: mymodel.collect_gradients() - optimizer.step() - #lrsch.step() + optm_step(optimizer, scaler) done_tokens = 0 logger.info("early stop") break @@ -397,7 +380,7 @@ def init_fixing(module): if done_tokens > 0: if multi_gpu: mymodel.collect_gradients() - optimizer.step() + optm_step(optimizer, scaler) save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) if save_optm_state: diff --git a/adv/train/train_ape.py b/adv/train/train_ape.py new file mode 100644 index 0000000..acf4996 --- /dev/null +++ b/adv/train/train_ape.py @@ -0,0 +1,381 @@ +#encoding: utf-8 + +import sys + +import torch +from torch.cuda.amp import autocast, GradScaler + +from torch import optim + +from parallel.base import DataParallelCriterion +from parallel.parallelMT import DataParallelMT + +from utils.base import * +from utils.init import init_model_params +from utils.h5serial import h5save, h5load +from utils.fmt.base import tostr, save_states, load_states, pad_id +from utils.fmt.base4torch import parse_cuda, load_emb + +from lrsch import GoogleLR +from loss.base import LabelSmoothingLoss + +from random import shuffle + +from tqdm import tqdm + +from os import makedirs +from os.path import exists as p_check + +import h5py + +import cnfg.base as cnfg +from cnfg.ihyp import * + +from transformer.APE.NMT import NMT + +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): + + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl) + model.train() + cur_b, _ls = 1, {} if save_loss else None + src_grp, mt_grp, tgt_grp = td["src"], td["mt"], td["tgt"] + for i_d in tqdm(tl): + seq_batch = torch.from_numpy(src_grp[i_d][:]).long() + seq_mt = torch.from_numpy(mt_grp[i_d][:]).long() + seq_o = torch.from_numpy(tgt_grp[i_d][:]).long() + lo = seq_o.size(1) - 1 + if mv_device: + seq_batch = seq_batch.to(mv_device) + seq_mt = seq_mt.to(mv_device) + seq_o = seq_o.to(mv_device) + + oi = seq_o.narrow(1, 0, lo) + ot = seq_o.narrow(1, 1, lo).contiguous() + with autocast(enabled=_use_amp): + output = model(seq_batch, seq_mt, oi) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() + loss_add = loss.data.item() + + if scaler is None: + loss.backward() + else: + scaler.scale(loss).backward() + + wd_add = ot.ne(pad_id).int().sum().item() + loss = output = oi = ot = seq_batch = seq_o = None + sum_loss += loss_add + if save_loss: + _ls[(i_d, t_d)] = loss_add / wd_add + sum_wd += wd_add + _done_tokens += wd_add + + if _done_tokens >= tokens_optm: + if multi_gpu: + model.collect_gradients() + optm_step(optm, scaler) + optm.zero_grad() + if multi_gpu: + model.update_replicas() + _done_tokens = 0 + if _cur_rstep is not None: + if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): + if num_checkpoint > 1: + 
_fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + if chkpof is not None: + _chkpof = chkpof[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + _chkpof = chkpof + save_model(model, _chkpf, multi_gpu, logger) + if chkpof is not None: + h5save(optm.state_dict(), _chkpof) + if statesf is not None: + save_states(statesf, tl[cur_b - 1:]) + _cur_rstep -= 1 + if _cur_rstep <= 0: + break + lrsch.step() + + if nreport is not None: + part_loss += loss_add + part_wd += wd_add + if cur_b % nreport == 0: + if report_eva: + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) + logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva)) + free_cache(mv_device) + model.train() + else: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd)) + part_loss = 0.0 + part_wd = 0 + + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + if chkpof is not None: + _chkpof = chkpof[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + _chkpof = chkpof + save_model(model, _chkpf, multi_gpu, logger) + if chkpof is not None: + h5save(optm.state_dict(), _chkpof) + if statesf is not None: + save_states(statesf, tl[cur_b - 1:]) + cur_b += 1 + if part_wd != 0.0: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd)) + return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls + +def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): + r = w = 0 + sum_loss = 0.0 + model.eval() + src_grp, mt_grp, tgt_grp = ed["src"], ed["mt"], ed["tgt"] + with torch.no_grad(): + for i in tqdm(range(nd)): + bid = str(i) + seq_batch = torch.from_numpy(src_grp[bid][:]).long() + seq_mt = torch.from_numpy(mt_grp[bid][:]).long() + seq_o = torch.from_numpy(tgt_grp[bid][:]).long() + lo = seq_o.size(1) - 1 + if mv_device: + seq_batch = seq_batch.to(mv_device) + seq_mt = seq_mt.to(mv_device) + seq_o = seq_o.to(mv_device) + ot = seq_o.narrow(1, 1, lo).contiguous() + with autocast(enabled=use_amp): + output = model(seq_batch, seq_mt, seq_o.narrow(1, 0, lo)) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + else: + trans = output.argmax(-1) + sum_loss += loss.data.item() + data_mask = ot.ne(pad_id) + correct = (trans.eq(ot) & data_mask).int() + w += data_mask.int().sum().item() + r += correct.sum().item() + correct = data_mask = trans = loss = output = ot = seq_batch = seq_o = None + w = float(w) + return sum_loss / w, (w - r) / w * 100.0 + +def hook_lr_update(optm, flags=None): + + reset_Adam(optm, flags) + +def init_fixing(module): + + if "fix_init" in dir(module): + module.fix_init() + +rid = cnfg.run_id +if len(sys.argv) > 1: + rid = sys.argv[1] + +earlystop = cnfg.earlystop + +maxrun = cnfg.maxrun + +tokens_optm = cnfg.tokens_optm + +done_tokens = 0 + +batch_report = cnfg.batch_report +report_eva = cnfg.report_eva + +use_ams = cnfg.use_ams + +save_optm_state = cnfg.save_optm_state + +save_every = cnfg.save_every +start_chkp_save = cnfg.epoch_start_checkpoint_save + +epoch_save = cnfg.epoch_save + +remain_steps = cnfg.training_steps + +wkdir = "".join(("expm/", cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) +if not 
p_check(wkdir): + makedirs(wkdir) + +chkpf = None +chkpof = None +statesf = None +if save_every is not None: + chkpf = wkdir + "checkpoint.h5" + if save_optm_state: + chkpof = wkdir + "checkpoint.optm.h5" + if cnfg.save_train_state: + statesf = wkdir + "checkpoint.states" + +logger = get_logger(wkdir + "train.log") + +use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) + +set_random_seed(cnfg.seed, use_cuda) + +td = h5py.File(cnfg.train_data, "r") +vd = h5py.File(cnfg.dev_data, "r") + +ntrain = td["ndata"][:].item() +nvalid = vd["ndata"][:].item() +nword = td["nword"][:].tolist() +nwordi, nwordt = nword[0], nword[-1] + +logger.info("Design models with seed: %d" % torch.initial_seed()) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + +fine_tune_m = cnfg.fine_tune_m + +tl = [str(i) for i in range(ntrain)] + +mymodel = init_model_params(mymodel) +mymodel.apply(init_fixing) +if fine_tune_m is not None: + logger.info("Load pre-trained model from: " + fine_tune_m) + mymodel = load_model_cpu(fine_tune_m, mymodel) + +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) + +if cnfg.src_emb is not None: + logger.info("Load source embedding from: " + cnfg.src_emb) + load_emb(cnfg.src_emb, mymodel.enc.wemb.weight, nwordi, cnfg.scale_down_emb, cnfg.freeze_srcemb) +if cnfg.tgt_emb is not None: + logger.info("Load target embedding from: " + cnfg.tgt_emb) + load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) + +if use_cuda: + mymodel.to(cuda_device) + lossf.to(cuda_device) + +optimizer = optim.Adam(mymodel.parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad() + +use_amp = cnfg.use_amp and use_cuda +scaler = GradScaler() if use_amp else None + +if multi_gpu: + mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) + lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) + +fine_tune_state = cnfg.fine_tune_state +if fine_tune_state is not None: + logger.info("Load optimizer state from: " + fine_tune_state) + optimizer.load_state_dict(h5load(fine_tune_state)) + +lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) + +num_checkpoint = cnfg.num_checkpoint +cur_checkid = 0 + +tminerr = inf_default + +minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) +logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) + +if fine_tune_m is None: + save_model(mymodel, wkdir + "init.h5", multi_gpu, logger) + logger.info("Initial model saved") +else: + cnt_states = cnfg.train_statesf + if (cnt_states is not None) and p_check(cnt_states): + logger.info("Continue last epoch") + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, 
cuda_device, multi_gpu, use_amp) + logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "train_0_%.3f_%.3f_%.2f.optm.h5" % (tminerr, vloss, vprec)) + logger.info("New best model saved") + +if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0: + dss_ws = int(cnfg.dss_ws * ntrain) + _Dws = {} + _prev_Dws = {} + _crit_inc = {} + if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0: + dss_rm = int(cnfg.dss_rm * ntrain * (1.0 - cnfg.dss_ws)) + else: + dss_rm = 0 +else: + dss_ws = 0 + dss_rm = 0 + _Dws = None + +namin = 0 + +for i in range(1, maxrun + 1): + shuffle(tl) + free_cache(use_cuda) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) + + if (vprec <= minerr) or (vloss <= minloss): + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "eva_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) + logger.info("New best model saved") + + namin = 0 + + if vprec < minerr: + minerr = vprec + if vloss < minloss: + minloss = vloss + + else: + if terr < tminerr: + tminerr = terr + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "train_%d_%.3f_%.3f_%.2f.optm.h5" % (i, terr, vloss, vprec)) + elif epoch_save: + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec), multi_gpu, logger) + + namin += 1 + if namin >= earlystop: + if done_tokens > 0: + if multi_gpu: + mymodel.collect_gradients() + optm_step(optimizer, scaler) + done_tokens = 0 + logger.info("early stop") + break + + if remain_steps is not None and remain_steps <= 0: + logger.info("Last training step reached") + break + + if dss_ws > 0: + if _prev_Dws: + for _key, _value in _Dws.items(): + if _key in _prev_Dws: + _ploss = _prev_Dws[_key] + _crit_inc[_key] = (_ploss - _value) / _ploss + tl = dynamic_sample(_crit_inc, dss_ws, dss_rm) + _prev_Dws = _Dws + +if done_tokens > 0: + if multi_gpu: + mymodel.collect_gradients() + optm_step(optimizer, scaler) + +save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) +if save_optm_state: + h5save(optimizer.state_dict(), wkdir + "last.optm.h5") +logger.info("model saved") + +td.close() +vd.close() diff --git a/adv/train/train_dynb.py b/adv/train/train_dynb.py index d5a7240..9f4db07 100644 --- a/adv/train/train_dynb.py +++ b/adv/train/train_dynb.py @@ -3,6 +3,7 @@ import sys import torch +from torch.cuda.amp import autocast, GradScaler from torch import optim @@ -13,7 +14,7 @@ from utils.init import init_model_params from utils.dynbatch import GradientMonitor from utils.h5serial import h5save, h5load -from utils.fmt.base import tostr, save_states, load_states +from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import 
parse_cuda, load_emb from lrsch import GoogleLR @@ -43,21 +44,15 @@ def select_function(modin, select_index): return _sel_m.parameters() -grad_mon = GradientMonitor(num_layer * 2, select_function, module=None, angle_alpha=cnfg.dyn_tol_alpha, num_tol_amin=cnfg.dyn_tol_amin, num_his_recoder=cnfg.num_dynb_his, num_his_gm=1) +grad_mon = GradientMonitor(num_layer * 2, select_function, module=None, angle_alpha=cnfg.dyn_tol_alpha, num_tol_amin=cnfg.dyn_tol_amin, num_his_record=cnfg.num_dynb_his, num_his_gm=1) -def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False): +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): - sum_loss = 0.0 - sum_wd = 0 - part_loss = 0.0 - part_wd = 0 - _done_tokens = done_tokens + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl) model.train() - cur_b = 1 - ndata = len(tl) - _cur_checkid = cur_checkid - _cur_rstep = remain_steps - _ls = {} if save_loss else None + cur_b, _ls = 1, {} if save_loss else None global grad_mon, update_angle @@ -72,19 +67,19 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - output = model(seq_batch, oi) - loss = lossf(output, ot) - if multi_gpu: - loss = loss.sum() + with autocast(enabled=_use_amp): + output = model(seq_batch, oi) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() loss_add = loss.data.item() - if use_amp: - with amp.scale_loss(loss, optm) as scaled_loss: - scaled_loss.backward() - else: + if scaler is None: loss.backward() + else: + scaler.scale(loss).backward() - wd_add = ot.ne(0).int().sum().item() + wd_add = ot.ne(pad_id).int().sum().item() loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: @@ -101,16 +96,13 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if _do_optm_step: if multi_gpu: model.collect_gradients() - optm.step() - optm.zero_grad() + optm_step(optm, scaler) + optm.zero_grad() + if multi_gpu: model.update_replicas() - else: - optm.step() - optm.zero_grad() lrsch.step() else: if multi_gpu: - #optm.zero_grad() model.reset_grad() else: optm.zero_grad() @@ -141,7 +133,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_wd += wd_add if cur_b % nreport == 0: if report_eva: - _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu) + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva)) free_cache(mv_device) model.train() @@ -171,9 +163,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls -def eva(ed, nd, model, lossf, mv_device, multi_gpu): - r = 0 - w = 0 +def eva(ed, nd, model, lossf, mv_device, multi_gpu, 
use_amp=False): + r = w = 0 sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] @@ -187,15 +178,16 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu): seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) ot = seq_o.narrow(1, 1, lo).contiguous() - output = model(seq_batch, seq_o.narrow(1, 0, lo)) - loss = lossf(output, ot) - if multi_gpu: - loss = loss.sum() - trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) - else: - trans = output.argmax(-1) + with autocast(enabled=use_amp): + output = model(seq_batch, seq_o.narrow(1, 0, lo)) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + else: + trans = output.argmax(-1) sum_loss += loss.data.item() - data_mask = ot.ne(0) + data_mask = ot.ne(pad_id) correct = (trans.eq(ot) & data_mask).int() w += data_mask.int().sum().item() r += correct.sum().item() @@ -252,16 +244,6 @@ def init_fixing(module): use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) -if use_cuda and cnfg.amp_opt: - try: - from apex import amp - use_amp = True - except Exception as e: - logger.info(str(e)) - use_amp = False -else: - use_amp = False - set_random_seed(cnfg.seed, use_cuda) td = h5py.File(cnfg.train_data, "r") @@ -285,7 +267,7 @@ def init_fixing(module): logger.info("Load pre-trained model from: " + fine_tune_m) mymodel = load_model_cpu(fine_tune_m, mymodel) -lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=0, reduction='sum', forbidden_index=cnfg.forbidden_indexes) +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) if cnfg.src_emb is not None: logger.info("Load source embedding from: " + cnfg.src_emb) @@ -301,8 +283,8 @@ def init_fixing(module): optimizer = optim.Adam(mymodel.parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) optimizer.zero_grad() -if use_amp: - mymodel, optimizer = amp.initialize(mymodel, optimizer, opt_level=cnfg.amp_opt) +use_amp = cnfg.use_amp and use_cuda +scaler = GradScaler() if use_amp else None if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) @@ -320,7 +302,7 @@ def init_fixing(module): tminerr = inf_default -minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) +minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: @@ -330,8 +312,8 @@ def init_fixing(module): cnt_states = cnfg.train_statesf if (cnt_states is not None) and p_check(cnt_states): logger.info("Continue last epoch") - tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, use_amp) - vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, 
num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) if save_optm_state: @@ -357,8 +339,8 @@ def init_fixing(module): for i in range(1, maxrun + 1): shuffle(tl) free_cache(use_cuda) - terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, use_amp) - vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): @@ -388,7 +370,7 @@ def init_fixing(module): if done_tokens > 0: if multi_gpu: mymodel.collect_gradients() - optimizer.step() + optm_step(optimizer, scaler) done_tokens = 0 logger.info("early stop") break @@ -409,7 +391,7 @@ def init_fixing(module): if done_tokens > 0: if multi_gpu: mymodel.collect_gradients() - optimizer.step() + optm_step(optimizer, scaler) save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) if save_optm_state: diff --git a/cnfg/README.md b/cnfg/README.md index b728e1e..a8736ba 100644 --- a/cnfg/README.md +++ b/cnfg/README.md @@ -51,8 +51,8 @@ report_eva = False # run on GPU or not, and GPU device(s) to use. Data Parallel depended multi-gpu support can be enabled with values like: 'cuda:0, 1, 3'. use_cuda = True gpuid = 'cuda:0' -# [EXP] enable mixed precision (FP16) with "O1" -amp_opt = None +# use mixed precision (FP16) +use_amp = False # bind the embedding matrix with the classifer weight in decoder bindDecoderEmb = True @@ -170,3 +170,48 @@ hdf5_model_compression_level = 0 use_unk = True ``` +## `ihyp.py` + +To interpret configurations in hyp.py. + +## `dynb.py` + +Additional configurations for dynamic batch sizes. + +``` +# If the angle change is greater than or equal to the minimum value in the history * dyn_tol_alpha, perform an optimization step. +dyn_tol_alpha = 1.1 +# If fails to obtain a smaller angle change after this number of steps, perform an optimization step. +dyn_tol_amin = 3 + +# override the maximum tokens per batch configuration in `cnfg/base.py`. If there are no less than this number of tokens in a batch, an optimization step will be performed. +tokens_optm = tokens_optm * 10 + +# perform optimization step only in case the angle change is smaller than update_angle. +update_angle = 90.0 / dyn_tol_alpha + +# number of records of the angle change reduction. +num_dynb_his = 50 + +# hyper parameter for parameter sampling. Ignored in case using softmax over normalized angle change reduction (default). Uncomment corresponding lines in `utils/dynbatch.py` to enable. 
+select_alpha = 3.0 +``` + +## `docpara.py` + +Additional configurations for context-aware models. + +``` +# number of previous context sentences utilized +num_prev_sent = 2 + +# freeze the loaded sentence-level model +freeze_load_model = True +# unfreeze the bias and the weight matrix of the classifier of the sentence-level model +unfreeze_bias = True +unfreeze_weight = False + +# number of layers for context encoding +num_layer_context = 1 +``` + diff --git a/cnfg/base.py b/cnfg/base.py index 81e20c5..5b2251c 100644 --- a/cnfg/base.py +++ b/cnfg/base.py @@ -33,8 +33,7 @@ use_cuda = True # enable Data Parallel multi-gpu support with values like: 'cuda:0, 1, 3'. gpuid = 'cuda:0, 1' -# [EXP] enable mixed precision (FP16) with "O1" -amp_opt = None +use_amp = False bindDecoderEmb = True share_emb = False diff --git a/cnfg/hyp.py b/cnfg/hyp.py index 64b21aa..d23f27d 100644 --- a/cnfg/hyp.py +++ b/cnfg/hyp.py @@ -4,7 +4,7 @@ lipschitz_initialization = True -# choices: None, "GeLU", "Swish", "Sigmoid" +# choices: None, "GeLU", "Swish", "Sigmoid", "Mish", "NormSwish" advance_activation_function = None # choices: "v1", "v2" diff --git a/cnfg/ihyp.py b/cnfg/ihyp.py index f4033fc..e26d27c 100644 --- a/cnfg/ihyp.py +++ b/cnfg/ihyp.py @@ -12,17 +12,19 @@ enable_ln_parameters = True -use_adv_act_default = False -override_GeLU_Swish = False -override_GeLU_Sigmoid = False +use_adv_act_default = custom_act_Sigmoid = custom_act_Swish = custom_act_Mish = use_norm_Swish = False if advance_activation_function is not None: use_adv_act_default = True _adv_act = advance_activation_function.lower() + use_norm_Swish = (_adv_act == "normswish") if _adv_act == "sigmoid": - override_GeLU_Sigmoid = True + custom_act_Sigmoid = True elif _adv_act == "swish": - override_GeLU_Swish = True -inplace_after_GeLU = use_adv_act_default and (not override_GeLU_Sigmoid) + custom_act_Swish = True + elif _adv_act == "mish": + custom_act_Mish = True + +inplace_after_Custom_Act = use_adv_act_default and (not custom_act_Sigmoid) norm_residual_default = not (computation_order.lower() == "v2") diff --git a/modules/LD.py b/modules/LD.py index b218e07..bf15bb3 100644 --- a/modules/LD.py +++ b/modules/LD.py @@ -4,19 +4,19 @@ from torch import nn from modules.base import Scorer, Linear, Dropout -from modules.act import GeLU +from modules.act import Custom_Act from cnfg.ihyp import * class ATTNCombiner(nn.Module): - def __init__(self, isize, hsize=None, dropout=0.0, use_GeLU=use_adv_act_default): + def __init__(self, isize, hsize=None, dropout=0.0, custom_act=use_adv_act_default): super(ATTNCombiner, self).__init__() _hsize = isize * 4 if hsize is None else hsize - self.net = nn.Sequential(Linear(isize * 2, _hsize), Dropout(dropout, inplace=True), GeLU() if use_GeLU else nn.Sigmoid(), Scorer(_hsize), nn.Sigmoid()) if dropout > 0.0 else nn.Sequential(Linear(isize * 2, _hsize), GeLU() if use_GeLU else nn.Sigmoid(), Scorer(_hsize), nn.Sigmoid()) + self.net = nn.Sequential(Linear(isize * 2, _hsize), Dropout(dropout, inplace=True), Custom_Act() if custom_act else nn.Sigmoid(), Scorer(_hsize), nn.Sigmoid()) if dropout > 0.0 else nn.Sequential(Linear(isize * 2, _hsize), Custom_Act() if custom_act else nn.Sigmoid(), Scorer(_hsize), nn.Sigmoid()) def forward(self, input1, input2, mask=None): @@ -37,13 +37,13 @@ def forward(self, input1, input2, mask=None): class DATTNCombiner(nn.Module): - def __init__(self, isize, hsize=None, dropout=0.0, use_GeLU=use_adv_act_default): + def __init__(self, isize, hsize=None, dropout=0.0, 
custom_act=use_adv_act_default): super(DATTNCombiner, self).__init__() _hsize = isize * 4 if hsize is None else hsize - self.net = nn.Sequential(Linear(isize * 2, _hsize), Dropout(dropout, inplace=True), GeLU() if use_GeLU else nn.Sigmoid(), Scorer(_hsize, bias=False)) if dropout > 0.0 else nn.Sequential(Linear(isize * 2, _hsize), GeLU() if use_GeLU else nn.Sigmoid(), Scorer(_hsize, bias=False)) + self.net = nn.Sequential(Linear(isize * 2, _hsize), Dropout(dropout, inplace=True), Custom_Act() if custom_act else nn.Sigmoid(), Scorer(_hsize, bias=False)) if dropout > 0.0 else nn.Sequential(Linear(isize * 2, _hsize), Custom_Act() if custom_act else nn.Sigmoid(), Scorer(_hsize, bias=False)) # input1: (bsize, 1, isize) # input2: (bsize, seql, isize) diff --git a/modules/act.py b/modules/act.py index 9a718a6..bf34992 100644 --- a/modules/act.py +++ b/modules/act.py @@ -2,6 +2,10 @@ import torch from torch import nn +from torch.autograd import Function +from torch.nn import functional as nnFunc + +from utils.base import reduce_model_list from math import sqrt @@ -33,11 +37,16 @@ def forward(self, x): return 0.5 * x * (1.0 + (x / self.k).erf()) +try: + GELU = nn.GELU +except Exception as e: + GELU = GeLU_BERT + # Swish approximates GeLU when beta=1.702 (https://mp.weixin.qq.com/s/LEPalstOc15CX6fuqMRJ8Q). # GELU is nonmonotonic function that has a shape similar to Swish with beta = 1.4 (https://arxiv.org/abs/1710.05941). class Swish(nn.Module): - def __init__(self, beta=1.0, freeze_beta=True, isize=None): + def __init__(self, beta=1.0, freeze_beta=True, isize=None, dim=-1 if use_norm_Swish else None, eps=ieps_default): super(Swish, self).__init__() @@ -47,10 +56,17 @@ def __init__(self, beta=1.0, freeze_beta=True, isize=None): else: self.reset_beta = beta self.beta = nn.Parameter(torch.tensor([beta])) if isize is None else nn.Parameter(torch.tensor([beta]).repeat(isize)) + self.dim, self.eps = dim, eps def forward(self, x): - return (x.sigmoid() * x) if self.beta is None else (x * (self.beta * x).sigmoid()) + if self.dim is None: + _norm_x = x + else: + _dx = x.detach() + _norm_x = (x - _dx.mean(dim=self.dim, keepdim=True)) / (_dx.std(dim=self.dim, keepdim=True) + self.eps) + + return (x.sigmoid() * _norm_x) if self.beta is None else (_norm_x * (self.beta * x).sigmoid()) def fix_init(self): @@ -58,9 +74,85 @@ def fix_init(self): if self.reset_beta is not None: self.beta.fill_(self.reset_beta) -if override_GeLU_Swish: - GeLU = Swish -elif override_GeLU_Sigmoid: - GeLU = nn.Sigmoid +class Mish(nn.Module): + + def forward(self, x): + + return x * nnFunc.softplus(x).tanh() + +if custom_act_Swish: + Custom_Act = Swish +elif custom_act_Sigmoid: + Custom_Act = nn.Sigmoid +elif custom_act_Mish: + Custom_Act = Mish else: - GeLU = GeLU_BERT + Custom_Act = GELU + +# SparseMax (https://arxiv.org/pdf/1602.02068) borrowed form OpenNMT-py( https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py) +class SparsemaxFunction(Function): + + @staticmethod + def forward(ctx, input, dim=0): + + def _threshold_and_support(input, dim=0): + + def _make_ix_like(input, dim=0): + + d = input.size(dim) + rho = torch.arange(1, d + 1, dtype=input.dtype, device=input.device) + view = [1] * input.dim() + view[0] = -1 + + return rho.view(view).transpose(0, dim) + + input_srt, _ = input.sort(descending=True, dim=dim) + input_cumsum = input_srt.cumsum(dim) - 1 + rhos = _make_ix_like(input, dim) + support = rhos * input_srt > input_cumsum + + support_size = support.sum(dim=dim).unsqueeze(dim) + tau = 
input_cumsum.gather(dim, support_size - 1) + tau /= support_size.to(input.dtype) + + return tau, support_size + + ctx.dim = dim + max_val, _ = input.max(dim=dim, keepdim=True) + input -= max_val + tau, supp_size = _threshold_and_support(input, dim=dim) + output = (input - tau).clamp(min=0) + ctx.save_for_backward(supp_size, output) + + return output + + @staticmethod + def backward(ctx, grad_output): + + supp_size, output = ctx.saved_tensors + dim = ctx.dim + grad_input = grad_output.clone() + grad_input[output == 0] = 0 + + v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() + v_hat = v_hat.unsqueeze(dim) + grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) + + return grad_input, None + +class Sparsemax(nn.Module): + + def __init__(self, dim=-1): + + super(Sparsemax, self).__init__() + self.dim = dim + + def forward(self, input): + + return SparsemaxFunction.apply(input, self.dim) + +def reduce_model(modin): + + rsm = reduce_model_list(modin, [nn.ReLU, nn.Softmax, Sparsemax, Swish], [lambda m: (m.inplace,), lambda m: (m.dim,), lambda m: (m.dim,), lambda m: (m.reset_beta, m.beta, m.dim, m.eps)]) + + return reduce_model_list(rsm, [GELU, GeLU_GPT, GeLU_BERT, Mish, nn.Tanh, nn.Sigmoid]) diff --git a/modules/base.py b/modules/base.py index 034232d..ccddf85 100644 --- a/modules/base.py +++ b/modules/base.py @@ -7,8 +7,10 @@ from torch.autograd import Function from utils.base import reduce_model_list -from modules.act import GeLU_GPT, GeLU_BERT, GeLU, Swish -from modules.dropout import Dropout, TokenDropout +from modules.act import Custom_Act +from modules.act import reduce_model as reduce_model_act +from modules.dropout import Dropout, TokenDropout, InfDropout +from modules.dropout import reduce_model as reduce_model_drop from cnfg.ihyp import * @@ -19,13 +21,13 @@ class PositionwiseFF(nn.Module): # isize: input dimension # hsize: hidden dimension - def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, use_GeLU=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): + def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): super(PositionwiseFF, self).__init__() _hsize = isize * 4 if hsize is None else hsize - self.net = nn.Sequential(Linear(isize, _hsize), GeLU() if use_GeLU else nn.ReLU(inplace=True), Dropout(dropout, inplace=inplace_after_GeLU), Linear(_hsize, isize, bias=enable_bias), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize, _hsize), GeLU() if use_GeLU else nn.ReLU(inplace=True), Linear(_hsize, isize, bias=enable_bias)) + self.net = nn.Sequential(Linear(isize, _hsize), Custom_Act() if custom_act else nn.ReLU(inplace=True), Dropout(dropout, inplace=inplace_after_Custom_Act), Linear(_hsize, isize, bias=enable_bias), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize, _hsize), Custom_Act() if custom_act else nn.ReLU(inplace=True), Linear(_hsize, isize, bias=enable_bias)) self.normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) @@ -206,16 +208,16 @@ class AverageAttn(nn.Module): # dropout: dropout rate for Feed-forward NN # num_pos: maximum length of sentence cached, extended length will be generated while needed and droped immediately after that - def __init__(self, isize, hsize=None, dropout=0.0, num_pos=cache_len_default, use_GeLU=use_adv_act_default): + def __init__(self, isize, hsize=None, 
dropout=0.0, num_pos=cache_len_default, custom_act=use_adv_act_default): super(AverageAttn, self).__init__() _hsize = isize if hsize is None else hsize self.num_pos = num_pos - self.register_buffer('w', torch.Tensor(num_pos, num_pos)) + self.register_buffer('w', torch.Tensor(num_pos, 1)) - self.ffn = nn.Sequential(Linear(isize, _hsize), Dropout(dropout, inplace=True), GeLU() if use_GeLU else nn.ReLU(inplace=True), Linear(_hsize, isize), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize, _hsize), GeLU() if use_GeLU else nn.ReLU(inplace=True), Linear(_hsize, isize)) + self.ffn = nn.Sequential(Linear(isize, _hsize), Dropout(dropout, inplace=True), Custom_Act() if custom_act else nn.ReLU(inplace=True), Linear(_hsize, isize), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize, _hsize), Custom_Act() if custom_act else nn.ReLU(inplace=True), Linear(_hsize, isize)) self.gw = Linear(isize * 2, isize * 2) @@ -230,13 +232,10 @@ def forward(self, iQ, iV, decoding=False): if decoding: avg = iV else: - bsize, seql = iV.size()[:2] - - # attn: (seql, seql) - attn = self.get_ext(seql) if seql > self.num_pos else self.w.narrow(0, 0, seql).narrow(1, 0, seql) + seql = iV.size(1) # avg: (bsize, seql, vsize) - avg = attn.unsqueeze(0).expand(bsize, seql, seql).matmul(iV) + avg = iv.cumsum(dim=1) * (self.get_ext(seql) if seql > self.num_pos else self.w.narrow(0, 0, seql)) avg = self.ffn(avg) @@ -250,7 +249,7 @@ def reset_parameters(self): def get_ext(self, npos): - return (1.0 / torch.arange(1, npos + 1, dtype=self.w.dtype, device=self.w.device)).unsqueeze(1).expand(-1, npos).tril(0.0) + return (torch.arange(1, npos + 1, dtype=self.w.dtype, device=self.w.device).reciprocal_()).unsqueeze(-1) # Accelerated MultiHeadAttn for self attention, use when Q == K == V class SelfAttn(nn.Module): @@ -390,14 +389,14 @@ class ResidueCombiner(nn.Module): # isize: input size of Feed-forward NN - def __init__(self, isize, ncomb=2, hsize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): + def __init__(self, isize, ncomb=2, hsize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): super(ResidueCombiner, self).__init__() _hsize = isize * 2 * ncomb if hsize is None else hsize # should dropout be in front of sigmoid or not? 
- self.net = nn.Sequential(Linear(isize * ncomb, _hsize), GeLU() if use_GeLU else nn.Sigmoid(), Dropout(dropout, inplace=inplace_after_GeLU), Linear(_hsize, isize, bias=enable_bias), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize * ncomb, _hsize), GeLU() if use_GeLU else nn.Sigmoid(), Linear(_hsize, isize, bias=enable_bias)) + self.net = nn.Sequential(Linear(isize * ncomb, _hsize), Custom_Act() if custom_act else nn.Sigmoid(), Dropout(dropout, inplace=inplace_after_Custom_Act), Linear(_hsize, isize, bias=enable_bias), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize * ncomb, _hsize), Custom_Act() if custom_act else nn.Sigmoid(), Linear(_hsize, isize, bias=enable_bias)) self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) @@ -492,68 +491,6 @@ def forward(self, weight, weight_loss, remain_value): return ACTLossFunction.apply(weight, weight_loss, remain_value) -# SparseMax (https://arxiv.org/pdf/1602.02068) borrowed form OpenNMT-py( https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/modules/sparse_activations.py) -class SparsemaxFunction(Function): - - @staticmethod - def forward(ctx, input, dim=0): - - def _threshold_and_support(input, dim=0): - - def _make_ix_like(input, dim=0): - - d = input.size(dim) - rho = torch.arange(1, d + 1, dtype=input.dtype, device=input.device) - view = [1] * input.dim() - view[0] = -1 - - return rho.view(view).transpose(0, dim) - - input_srt, _ = input.sort(descending=True, dim=dim) - input_cumsum = input_srt.cumsum(dim) - 1 - rhos = _make_ix_like(input, dim) - support = rhos * input_srt > input_cumsum - - support_size = support.sum(dim=dim).unsqueeze(dim) - tau = input_cumsum.gather(dim, support_size - 1) - tau /= support_size.to(input.dtype) - - return tau, support_size - - ctx.dim = dim - max_val, _ = input.max(dim=dim, keepdim=True) - input -= max_val - tau, supp_size = _threshold_and_support(input, dim=dim) - output = (input - tau).clamp(min=0) - ctx.save_for_backward(supp_size, output) - - return output - - @staticmethod - def backward(ctx, grad_output): - - supp_size, output = ctx.saved_tensors - dim = ctx.dim - grad_input = grad_output.clone() - grad_input[output == 0] = 0 - - v_hat = grad_input.sum(dim=dim) / supp_size.to(output.dtype).squeeze() - v_hat = v_hat.unsqueeze(dim) - grad_input = torch.where(output != 0, grad_input - v_hat, grad_input) - - return grad_input, None - -class Sparsemax(nn.Module): - - def __init__(self, dim=-1): - - super(Sparsemax, self).__init__() - self.dim = dim - - def forward(self, input): - - return SparsemaxFunction.apply(input, self.dim) - class ApproximateEmb(nn.Module): def __init__(self, weight): @@ -760,5 +697,6 @@ def fix_init(self): def reduce_model(modin): - rsm = reduce_model_list(modin, [Dropout, nn.ReLU, nn.Softmax, PositionalEmb, TokenDropout, Sparsemax, CoordinateEmb, Swish], [lambda m: (m.p, m.inplace,), lambda m: (m.inplace,), lambda m: (m.dim,), lambda m: (m.num_pos, m.num_dim, m.poff, m.doff, m.alpha,), lambda m: (m.p, m.keep_magnitude,), lambda m: (m.dim,), lambda m: (m.num_pos, m.num_dim, m.poff, m.doff, m.alpha, m.num_steps,), lambda m: (m.reset_beta, m.beta,)]) - return reduce_model_list(rsm, [GeLU_GPT, GeLU_BERT, nn.Tanh, nn.Sigmoid]) + rsm = reduce_model_list(modin, [PositionalEmb, CoordinateEmb], [lambda m: (m.num_pos, m.num_dim, m.poff, m.doff, m.alpha,), lambda m: (m.num_pos, m.num_dim, m.poff, m.doff, m.alpha, m.num_steps,),]) + + return reduce_model_drop(reduce_model_act(rsm)) 
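The `AverageAttn` hunk above drops the cached `(num_pos, num_pos)` averaging matrix and instead keeps a `(num_pos, 1)` column of 1/position factors, computing the running mean as a cumulative sum scaled by that column (the `iv` in the new line presumably refers to the `iV` argument). A minimal standalone sketch of why the two formulations agree follows; the helper names are illustrative only and not part of the patch:

```python
import torch

def running_mean_matrix(v):
    # old formulation: explicit (seql, seql) lower-triangular averaging matrix
    seql = v.size(1)
    w = (1.0 / torch.arange(1, seql + 1, dtype=v.dtype)).unsqueeze(1).expand(-1, seql).tril(0)
    return w.unsqueeze(0).matmul(v)

def running_mean_cumsum(v):
    # new formulation: cumulative sum scaled by a (seql, 1) column of 1/position
    seql = v.size(1)
    scale = torch.arange(1, seql + 1, dtype=v.dtype).reciprocal_().unsqueeze(-1)
    return v.cumsum(dim=1) * scale

x = torch.randn(2, 5, 8)
print(torch.allclose(running_mean_matrix(x), running_mean_cumsum(x), atol=1e-6))
```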
diff --git a/modules/dropout.py b/modules/dropout.py index fc4655e..e3be025 100644 --- a/modules/dropout.py +++ b/modules/dropout.py @@ -31,7 +31,7 @@ def forward(self, inpute): def norm(lin): _t = sum(lin) - return [lu / _t for lu in lin] + return tuple([lu / _t for lu in lin]) def sample(lin): @@ -77,3 +77,7 @@ def forward(self, inpute): return out else: return inpute + +def reduce_model(modin): + + return reduce_model_list(modin, [Dropout, TokenDropout, NGramDropout], [lambda m: (m.p, m.inplace,), lambda m: (m.p, m.inplace, m.keep_magnitude,), lambda m: (m.p, m.inplace, m.seqdim, m.keep_magnitude, m.sample_p, m.max_n,)]) diff --git a/modules/rnncells.py b/modules/rnncells.py index 7e4e7ec..3a66655 100644 --- a/modules/rnncells.py +++ b/modules/rnncells.py @@ -3,7 +3,7 @@ import torch from torch import nn from modules.base import * -from modules.act import GeLU +from modules.act import Custom_Act from cnfg.ihyp import * @@ -19,7 +19,7 @@ class LSTMCell4RNMT(nn.Module): # isize: input size of Feed-forward NN # dropout: dropout over hidden units, disabling it and applying dropout to outputs (_out) in most cases - def __init__(self, isize, osize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): + def __init__(self, isize, osize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): super(LSTMCell4RNMT, self).__init__() @@ -29,8 +29,8 @@ def __init__(self, isize, osize=None, dropout=0.0, use_GeLU=use_adv_act_default, self.trans = Linear(isize + _osize, _osize * 4, bias=enable_bias) self.normer = nn.LayerNorm((4, _osize), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - self.act = GeLU() if use_GeLU else nn.Tanh() - self.drop = Dropout(dropout, inplace=inplace_after_GeLU) if dropout > 0.0 else None + self.act = Custom_Act() if custom_act else nn.Tanh() + self.drop = Dropout(dropout, inplace=inplace_after_Custom_Act) if dropout > 0.0 else None self.osize = _osize @@ -57,7 +57,7 @@ class GRUCell4RNMT(nn.Module): # isize: input size of Feed-forward NN - def __init__(self, isize, osize=None, dropout=0.0, use_GeLU=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): + def __init__(self, isize, osize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): super(GRUCell4RNMT, self).__init__() @@ -69,8 +69,8 @@ def __init__(self, isize, osize=None, dropout=0.0, use_GeLU=use_adv_act_default, self.normer = nn.LayerNorm((2, _osize), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - self.act = GeLU() if use_GeLU else nn.Tanh() - self.drop = Dropout(dropout, inplace=inplace_after_GeLU) if dropout > 0.0 else None + self.act = Custom_Act() if custom_act else nn.Tanh() + self.drop = Dropout(dropout, inplace=inplace_after_Custom_Act) if dropout > 0.0 else None self.osize = _osize diff --git a/parallel/base.py b/parallel/base.py index 4aadea2..b15fc3f 100644 --- a/parallel/base.py +++ b/parallel/base.py @@ -2,6 +2,7 @@ import torch import torch.cuda.comm as comm +from torch.cuda.amp import autocast from utils.comm import secure_broadcast_coalesced from torch.jit import ScriptModule @@ -291,16 +292,15 @@ def parallel_apply(modules, inputs, devices, kwargs_tup=None): lock = Lock() results = {} - grad_enabled = torch.is_grad_enabled() + grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() def _worker(i, module, input, kwargs, device=None): - with torch.set_grad_enabled(grad_enabled): - with torch.cuda.device(device): - # this 
also avoids accidental slicing of `input` if it is a Tensor - if not isinstance(input, (list, tuple)): - input = (input,) - output = module(*input, **kwargs) + # this also avoids accidental slicing of `input` if it is a Tensor + if not isinstance(input, (list, tuple)): + input = (input,) + with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled): + output = module(*input, **kwargs) with lock: results[i] = output @@ -324,17 +324,16 @@ def criterion_parallel_apply(modules, inputs, targets, devices, kwargs_tup=None) lock = Lock() results = {} - grad_enabled = torch.is_grad_enabled() + grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() def _worker(i, module, input, target, kwargs, device): - with torch.set_grad_enabled(grad_enabled): - with torch.cuda.device(device): - if not isinstance(input, (list, tuple)): - input = (input,) - if not isinstance(target, (list, tuple)): - target = (target,) - output = module(*(input + target), **kwargs) + if not isinstance(input, (list, tuple)): + input = (input,) + if not isinstance(target, (list, tuple)): + target = (target,) + with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled): + output = module(*(input + target), **kwargs) with lock: results[i] = output diff --git a/parallel/parallelMT.py b/parallel/parallelMT.py index 69f6342..402ec94 100644 --- a/parallel/parallelMT.py +++ b/parallel/parallelMT.py @@ -1,6 +1,7 @@ #encoding: utf-8 import torch +from torch.cuda.amp import autocast from parallel.base import DataParallelModel @@ -54,15 +55,14 @@ def parallel_apply_decode(modules, inputs, devices, kwargs_tup=None): lock = Lock() results = {} - grad_enabled = torch.is_grad_enabled() + grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() def _worker(i, module, input, kwargs, device=None): - with torch.set_grad_enabled(grad_enabled): - with torch.cuda.device(device): - if not isinstance(input, (list, tuple)): - input = (input,) - output = module.decode(*input, **kwargs) + if not isinstance(input, (list, tuple)): + input = (input,) + with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled): + output = module.decode(*input, **kwargs) with lock: results[i] = output @@ -86,15 +86,14 @@ def parallel_apply_train_decode(modules, inputs, devices, kwargs_tup=None): lock = Lock() results = {} - grad_enabled = torch.is_grad_enabled() + grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled() def _worker(i, module, input, kwargs, device=None): - with torch.set_grad_enabled(grad_enabled): - with torch.cuda.device(device): - if not isinstance(input, (list, tuple)): - input = (input,) - output = module.train_decode(*input, **kwargs) + if not isinstance(input, (list, tuple)): + input = (input,) + with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled): + output = module.train_decode(*input, **kwargs) with lock: results[i] = output diff --git a/predict.py b/predict.py index 18fc6c5..8fa6117 100644 --- a/predict.py +++ b/predict.py @@ -3,6 +3,7 @@ import sys import torch +from torch.cuda.amp import autocast from tqdm import tqdm @@ -50,10 +51,8 @@ def load_fixing(module): mymodel.eval() -use_cuda = cnfg.use_cuda -gpuid = cnfg.gpuid - use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda_decode(cnfg.use_cuda, cnfg.gpuid, cnfg.multi_gpu_decoding) +use_amp = cnfg.use_amp and use_cuda # 
Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) @@ -76,7 +75,8 @@ def load_fixing(module): seq_batch = torch.from_numpy(src_grp[str(i)][:]).long() if use_cuda: seq_batch = seq_batch.to(cuda_device) - output = mymodel.decode(seq_batch, beam_size, None, length_penalty) + with autocast(enabled=use_amp): + output = mymodel.decode(seq_batch, beam_size, None, length_penalty) #output = mymodel.train_decode(seq_batch, beam_size, None, length_penalty) if multi_gpu: tmp = [] diff --git a/rank_loss.py b/rank_loss.py index 8a36c11..1962afa 100644 --- a/rank_loss.py +++ b/rank_loss.py @@ -7,6 +7,7 @@ import sys import torch +from torch.cuda.amp import autocast from tqdm import tqdm @@ -23,6 +24,7 @@ from loss.base import LabelSmoothingLoss from utils.base import * +from utils.fmt.base import pad_id from utils.fmt.base4torch import parse_cuda def load_fixing(module): @@ -57,9 +59,10 @@ def load_fixing(module): mymodel.eval() -lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=0, reduction='none', forbidden_index=cnfg.forbidden_indexes) +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='none', forbidden_index=cnfg.forbidden_indexes) use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) +use_amp = cnfg.use_amp and use_cuda # Important to make cudnn methods deterministic set_random_seed(cnfg.seed, use_cuda) @@ -85,10 +88,11 @@ def load_fixing(module): seq_o = seq_o.to(cuda_device) lo = seq_o.size(1) - 1 ot = seq_o.narrow(1, 1, lo).contiguous() - output = mymodel(seq_batch, seq_o.narrow(1, 0, lo)) - loss = lossf(output, ot).sum(-1).view(-1, lo).sum(-1) + with autocast(enabled=use_amp): + output = mymodel(seq_batch, seq_o.narrow(1, 0, lo)) + loss = lossf(output, ot).sum(-1).view(-1, lo).sum(-1) if norm_token: - lenv = ot.ne(0).int().sum(-1).to(loss) + lenv = ot.ne(pad_id).int().sum(-1).to(loss) loss = loss / lenv f.write("\n".join([str(rsu) for rsu in loss.tolist()]).encode("utf-8")) loss = output = ot = seq_batch = seq_o = None diff --git a/requirements.opt.txt b/requirements.opt.txt index fba2d18..5fe6659 100644 --- a/requirements.opt.txt +++ b/requirements.opt.txt @@ -1,4 +1,4 @@ -Cython>=0.29.20 +Cython>=0.29.21 subword-nmt>=0.3.7 sacremoses>=0.0.43 Flask>=1.1.2 diff --git a/requirements.txt b/requirements.txt index b84cc3d..6363b51 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -tqdm>=4.46.1 -torch>=1.5.0 +tqdm>=4.48.2 +torch>=1.6.0 h5py>=2.10.0 diff --git a/scripts/README.md b/scripts/README.md index d0ab1d9..01b7e5f 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -64,12 +64,20 @@ export dataid=w14ende # number of GPU(s) plan to use for decoding. export ngpu=1 + +# merge sub-words +export debpe=true ``` ## `bpe/` Scripts to perform sub-word segmentation. + ## `doc/` Corresponding scripts for document-level data processing. + +## `ape/` + +Scripts for data processing of APE. 
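The predict.py and rank_loss.py hunks above, together with the torch>=1.6.0 requirement bump, migrate from apex.amp to PyTorch's native mixed precision: forward and decoding passes are wrapped in torch.cuda.amp.autocast, and the training scripts use a GradScaler. A minimal sketch of that pattern, assuming a toy model and optimizer (the optm_step helper referenced in later hunks presumably wraps the scaler.step/scaler.update calls):

# Minimal sketch with an assumed toy model/optimizer, not the repository's exact helpers.
import torch
from torch.cuda.amp import autocast, GradScaler

use_amp = torch.cuda.is_available()
device = torch.device("cuda" if use_amp else "cpu")

model = torch.nn.Linear(8, 8).to(device)
optm = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = GradScaler() if use_amp else None

x = torch.randn(4, 8, device=device)
with autocast(enabled=use_amp):
    # forward pass runs in mixed precision when AMP is enabled
    loss = model(x).pow(2).mean()

if scaler is None:
    loss.backward()
    optm.step()
else:
    scaler.scale(loss).backward()
    scaler.step(optm)    # roughly what an optm_step(optm, scaler) helper would do
    scaler.update()
optm.zero_grad()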
diff --git a/scripts/ape/bpe/clean.sh b/scripts/ape/bpe/clean.sh
new file mode 100644
index 0000000..f7cbfcb
--- /dev/null
+++ b/scripts/ape/bpe/clean.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+export cachedir=cache
+
+export dataid=w19ape
+
+export srcd=w19ape
+export srcvf=dev/dev.src.tc
+export mtvf=dev/dev.mt.tc
+export tgtvf=dev/dev.pe.tc
+
+export maxtokens=256
+
+export bpeops=32000
+export minfreq=8
+export share_bpe=false
+
+export tgtd=$cachedir/$dataid
+
+# options for cleaning the data processed by bpe,
+# advised values (except numrules) can be calculated by:
+# python tools/check/charatio.py $tgtd/src.dev.bpe $tgtd/tgt.dev.bpe, and
+# python tools/check/biratio.py $tgtd/src.dev.bpe $tgtd/tgt.dev.bpe
+# with the development set.
+# As for numrules, choose from [1, 6]; less data will be dropped with a larger value, and no data will be dropped if it is set to 6; details are described in:
+# tools/check/chars.py
+export charatio=0.493
+export bperatio=2.401
+export seperatio=0.391
+export bibperatio=2.251
+export bioratio=2.501
+export numrules=1
+
+# clean the bpe results and apply bpe again
+python tools/clean/ape/chars.py $tgtd/src.train.bpe $tgtd/mt.train.bpe $tgtd/tgt.train.bpe $tgtd/src.clean.tmp $tgtd/mt.clean.tmp $tgtd/tgt.clean.tmp $charatio $bperatio $seperatio $bibperatio $bioratio $numrules
+
+sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/src.clean.tmp > $tgtd/src.train.tok.clean
+sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/mt.clean.tmp > $tgtd/mt.train.tok.clean
+sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/tgt.clean.tmp > $tgtd/tgt.train.tok.clean
+rm -fr $tgtd/src.clean.tmp $tgtd/mt.clean.tmp $tgtd/tgt.clean.tmp
+
+if $share_bpe; then
+# to learn joint bpe
+ export src_cdsf=$tgtd/bpe.cds
+ export tgt_cdsf=$tgtd/bpe.cds
+ subword-nmt learn-joint-bpe-and-vocab --input $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean -s $bpeops -o $src_cdsf --write-vocabulary $tgtd/src.vcb.bpe $tgtd/tgt.vcb.bpe
+else
+# to learn independent bpe:
+ export src_cdsf=$tgtd/src.cds
+ export tgt_cdsf=$tgtd/tgt.cds
+ subword-nmt learn-bpe -s $bpeops < $tgtd/src.train.tok.clean > $src_cdsf
+ subword-nmt learn-bpe -s $bpeops < $tgtd/tgt.train.tok.clean > $tgt_cdsf
+ subword-nmt apply-bpe -c $src_cdsf < $tgtd/src.train.tok.clean | subword-nmt get-vocab > $tgtd/src.vcb.bpe
+ subword-nmt apply-bpe -c $tgt_cdsf < $tgtd/tgt.train.tok.clean | subword-nmt get-vocab > $tgtd/tgt.vcb.bpe
+fi
+
+subword-nmt apply-bpe -c $src_cdsf --vocabulary $tgtd/src.vcb.bpe --vocabulary-threshold $minfreq < $tgtd/src.train.tok.clean > $tgtd/src.train.bpe
+subword-nmt apply-bpe -c $tgt_cdsf --vocabulary $tgtd/tgt.vcb.bpe --vocabulary-threshold $minfreq < $tgtd/mt.train.tok.clean > $tgtd/mt.train.bpe
+subword-nmt apply-bpe -c $tgt_cdsf --vocabulary $tgtd/tgt.vcb.bpe --vocabulary-threshold $minfreq < $tgtd/tgt.train.tok.clean > $tgtd/tgt.train.bpe
+
+subword-nmt apply-bpe -c $src_cdsf --vocabulary $tgtd/src.vcb.bpe --vocabulary-threshold $minfreq < $srcd/$srcvf > $tgtd/src.dev.bpe
+subword-nmt apply-bpe -c $tgt_cdsf --vocabulary $tgtd/tgt.vcb.bpe --vocabulary-threshold $minfreq < $srcd/$mtvf > $tgtd/mt.dev.bpe
+subword-nmt apply-bpe -c $tgt_cdsf --vocabulary $tgtd/tgt.vcb.bpe --vocabulary-threshold $minfreq < $srcd/$tgtvf > $tgtd/tgt.dev.bpe
+
+# then execute scripts/mktrain.sh to generate training and development data.
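The charatio/bperatio/seperatio settings above are per-corpus BPE statistics measured on the development set and then reused as cleaning thresholds. A simplified sketch (an assumption about the logic, not the repository's tools/check implementation) of how such ratios can be computed for a single BPE-segmented line:

# Simplified sketch; the repository's scripts compute corpus-level variants of these values.
def bpe_ratios(line):
    tokens = line.split()
    ntok = len(tokens)                                 # tokens after BPE
    nbpe = sum(1 for t in tokens if t.endswith("@@"))  # "@@"-ended sub-word pieces
    norig = len(line.replace("@@ ", "").split())       # tokens before BPE
    nsplit, prev_bpe = 0, False                        # original words split by BPE
    for t in tokens:
        if t.endswith("@@"):
            if not prev_bpe:
                nsplit += 1
            prev_bpe = True
        else:
            prev_bpe = False
    return nbpe / ntok, ntok / norig, nsplit / norig

# example: ratios for one segmented sentence
print(bpe_ratios("this is a to@@ ken@@ ized ex@@ ample"))

The three returned values correspond to the cratio, bratio and sratio definitions that appear later in tools/clean/ape/chars.py.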
diff --git a/scripts/ape/bpe/mk.sh b/scripts/ape/bpe/mk.sh new file mode 100644 index 0000000..e6138aa --- /dev/null +++ b/scripts/ape/bpe/mk.sh @@ -0,0 +1,59 @@ +#!/bin/bash + +export cachedir=cache + +export dataid=w19ape + +export srcd=w19ape +export srctf=train/all.src.tc +export mttf=train/all.mt.tc +export tgttf=train/all.pe.tc +export srcvf=dev/dev.src.tc +export mtvf=dev/dev.mt.tc +export tgtvf=dev/dev.pe.tc + +export vratio=0.2 +export maxtokens=256 + +export bpeops=32000 +export minfreq=8 +export share_bpe=false + +export tgtd=$cachedir/$dataid + +mkdir -p $tgtd + +# clean the data first by removing different translations with lower frequency of same sentences +python tools/clean/ape/maxkeeper.py $srcd/$srctf $srcd/$mttf $srcd/$tgttf $tgtd/src.clean.tmp $tgtd/mt.clean.tmp $tgtd/tgt.clean.tmp $maxtokens + +python tools/vocab.py $tgtd/src.clean.tmp $tgtd/src.full.vcb 1048576 +python tools/vocab.py $tgtd/tgt.clean.tmp $tgtd/tgt.full.vcb 1048576 +python tools/clean/ape/vocab.py $tgtd/src.clean.tmp $tgtd/mt.clean.tmp $tgtd/tgt.clean.tmp $tgtd/src.train.tok.clean $tgtd/mt.train.tok.clean $tgtd/tgt.train.tok.clean $tgtd/src.full.vcb $tgtd/tgt.full.vcb $vratio +rm -fr $tgtd/src.full.vcb $tgtd/tgt.full.vcb $tgtd/src.clean.tmp $tgtd/mt.clean.tmp $tgtd/tgt.clean.tmp + +if $share_bpe; then +# to learn joint bpe + export src_cdsf=$tgtd/bpe.cds + export tgt_cdsf=$tgtd/bpe.cds + subword-nmt learn-joint-bpe-and-vocab --input $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean -s $bpeops -o $src_cdsf --write-vocabulary $tgtd/src.vcb.bpe $tgtd/tgt.vcb.bpe +else +# to learn independent bpe: + export src_cdsf=$tgtd/src.cds + export tgt_cdsf=$tgtd/tgt.cds + subword-nmt learn-bpe -s $bpeops < $tgtd/src.train.tok.clean > $src_cdsf + subword-nmt learn-bpe -s $bpeops < $tgtd/tgt.train.tok.clean > $tgt_cdsf + subword-nmt apply-bpe -c $src_cdsf < $tgtd/src.train.tok.clean | subword-nmt get-vocab > $tgtd/src.vcb.bpe + subword-nmt apply-bpe -c $tgt_cdsf < $tgtd/tgt.train.tok.clean | subword-nmt get-vocab > $tgtd/tgt.vcb.bpe +fi + +subword-nmt apply-bpe -c $src_cdsf --vocabulary $tgtd/src.vcb.bpe --vocabulary-threshold $minfreq < $tgtd/src.train.tok.clean > $tgtd/src.train.bpe +subword-nmt apply-bpe -c $tgt_cdsf --vocabulary $tgtd/tgt.vcb.bpe --vocabulary-threshold $minfreq < $tgtd/mt.train.tok.clean > $tgtd/mt.train.bpe +subword-nmt apply-bpe -c $tgt_cdsf --vocabulary $tgtd/tgt.vcb.bpe --vocabulary-threshold $minfreq < $tgtd/tgt.train.tok.clean > $tgtd/tgt.train.bpe + +subword-nmt apply-bpe -c $src_cdsf --vocabulary $tgtd/src.vcb.bpe --vocabulary-threshold $minfreq < $srcd/$srcvf > $tgtd/src.dev.bpe +subword-nmt apply-bpe -c $tgt_cdsf --vocabulary $tgtd/tgt.vcb.bpe --vocabulary-threshold $minfreq < $srcd/$mtvf > $tgtd/mt.dev.bpe +subword-nmt apply-bpe -c $tgt_cdsf --vocabulary $tgtd/tgt.vcb.bpe --vocabulary-threshold $minfreq < $srcd/$tgtvf > $tgtd/tgt.dev.bpe + +# report devlopment set features for cleaning +python tools/check/charatio.py $tgtd/src.dev.bpe $tgtd/tgt.dev.bpe +python tools/check/biratio.py $tgtd/src.dev.bpe $tgtd/tgt.dev.bpe diff --git a/scripts/ape/mktest.sh b/scripts/ape/mktest.sh new file mode 100644 index 0000000..8bb2997 --- /dev/null +++ b/scripts/ape/mktest.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +export srcd=w19ape/test +export srctf=test.src.tc.w19ape +export srcmf=test.mt.tc.w19ape +export modelf="expm/w19ape/std/base/avg.h5" +export rsd=w19apetrs/std +export rsf=$rsd/base_avg.txt + +export share_vcb=false + +export cachedir=cache +export dataid=w19ape + +export ngpu=1 + +export 
debpe=true + +export tgtd=$cachedir/$dataid + +export bpef=out.bpe + +if $share_vcb; then + export src_vcb=$tgtd/common.vcb + export tgt_vcb=$src_vcb +else + export src_vcb=$tgtd/src.vcb + export tgt_vcb=$tgtd/tgt.vcb +fi + +mkdir -p $rsd + +python tools/sort.py $srcd/$srctf $srcd/$srcmf $tgtd/$srctf.srt $tgtd/$srcmf.srt 1048576 +python tools/mkiodata.py $tgtd/$srctf.srt $tgtd/$srcmf.srt $src_vcb $tgt_vcb $tgtd/test.h5 $ngpu +python predict_ape.py $tgtd/$bpef.srt $tgt_vcb $modelf +python tools/ape/restore.py $srcd/$srctf $srcd/$srcmf $tgtd/$srctf.srt $tgtd/$srcmf.srt $tgtd/$bpef.srt $tgtd/$bpef +if $debpe; then + sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + rm $tgtd/$bpef +else + mv $tgtd/$bpef $rsf +fi +rm $tgtd/$srctf.srt $tgtd/$srcmf.srt $tgtd/$bpef.srt diff --git a/scripts/ape/mktrain.sh b/scripts/ape/mktrain.sh new file mode 100644 index 0000000..cfa77a3 --- /dev/null +++ b/scripts/ape/mktrain.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +# take the processed data from scripts/mkbpe.sh and convert to tensor representation. + +export cachedir=cache +export dataid=w19ape + +export srcd=$cachedir/$dataid +export srctf=src.train.bpe +export mttf=mt.train.bpe +export tgttf=tgt.train.bpe +export srcvf=src.dev.bpe +export mtvf=tgt.dev.bpe +export tgtvf=tgt.dev.bpe + +export rsf_train=train.h5 +export rsf_dev=dev.h5 + +export share_vcb=false +export vsize=65536 + +export maxtokens=256 + +export ngpu=1 + +export wkd=$cachedir/$dataid + +python tools/ape/sort.py $srcd/$srctf $srcd/$mttf $srcd/$tgttf $wkd/src.train.srt $wkd/mt.train.srt $wkd/tgt.train.srt $maxtokens +python tools/ape/sort.py $srcd/$srcvf $srcd/$mtvf $srcd/$tgtvf $wkd/src.dev.srt $wkd/mt.dev.srt $wkd/tgt.dev.srt 1048576 + +if $share_vcb; then + export src_vcb=$wkd/common.vcb + export tgt_vcb=$src_vcb + python tools/share_vocab.py $wkd/src.train.srt $wkd/tgt.train.srt $wkd/mt.train.srt $src_vcb $vsize + python tools/check/fbindexes.py $tgt_vcb $wkd/tgt.train.srt $wkd/tgt.dev.srt $wkd/fbind.py +else + export src_vcb=$wkd/src.vcb + export tgt_vcb=$wkd/tgt.vcb + python tools/vocab.py $wkd/src.train.srt $src_vcb $vsize + python tools/share_vocab.py $wkd/tgt.train.srt $wkd/mt.train.srt $tgt_vcb $vsize +fi + +python tools/ape/mkiodata.py $wkd/src.train.srt $wkd/mt.train.srt $wkd/tgt.train.srt $src_vcb $tgt_vcb $wkd/$rsf_train $ngpu +python tools/ape/mkiodata.py $wkd/src.dev.srt $wkd/mt.dev.srt $wkd/tgt.dev.srt $src_vcb $tgt_vcb $wkd/$rsf_dev $ngpu diff --git a/scripts/doc/para/mktest.sh b/scripts/doc/para/mktest.sh index ac3208d..5d0c331 100644 --- a/scripts/doc/para/mktest.sh +++ b/scripts/doc/para/mktest.sh @@ -13,6 +13,8 @@ export dataid=w19edoc export ngpu=1 +export debpe=true + export tgtd=$cachedir/$dataid export bpef=out.bpe @@ -31,4 +33,10 @@ python tools/doc/mono/sort.py $srcd/$srctf $tgtd/$srctf.srt python tools/doc/para/mktest.py $tgtd/$srctf.srt $src_vcb $tgtd/test.h5 $ngpu python predict_doc_para.py $tgtd/$bpef.srt $tgt_vcb $modelf python tools/doc/para/restore.py $srcd/$srctf w19ed/test.en.w19ed w19edtrs/base_avg.tbrs $tgtd/$srctf.srt $tgtd/$bpef.srt $tgtd/$bpef -sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf +if $debpe; then + sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + rm $tgtd/$bpef +else + mv $tgtd/$bpef $rsf +fi +rm $tgtd/$srctf.srt $tgtd/$bpef.srt diff --git a/scripts/mktest.sh b/scripts/mktest.sh index 8ae05e3..0913708 100644 --- a/scripts/mktest.sh +++ b/scripts/mktest.sh @@ -13,6 +13,8 @@ export dataid=w14ed32 export ngpu=1 +export debpe=true + export tgtd=$cachedir/$dataid export bpef=out.bpe @@ -31,4 +33,10 
@@ python tools/sorti.py $srcd/$srctf $tgtd/$srctf.srt python tools/mktest.py $tgtd/$srctf.srt $src_vcb $tgtd/test.h5 $ngpu python predict.py $tgtd/$bpef.srt $tgt_vcb $modelf python tools/restore.py $srcd/$srctf $tgtd/$srctf.srt $tgtd/$bpef.srt $tgtd/$bpef -sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf +if $debpe; then + sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf + rm $tgtd/$bpef +else + mv $tgtd/$bpef $rsf +fi +rm $tgtd/$srctf.srt $tgtd/$bpef.srt diff --git a/scripts/mktrain.sh b/scripts/mktrain.sh index 6d3337b..f1111d6 100644 --- a/scripts/mktrain.sh +++ b/scripts/mktrain.sh @@ -32,7 +32,7 @@ if $share_vcb; then export src_vcb=$wkd/common.vcb export tgt_vcb=$src_vcb python tools/share_vocab.py $wkd/src.train.srt $wkd/tgt.train.srt $src_vcb $vsize - python tools/check/fbindexes.py $src_vcb $wkd/tgt.train.srt $wkd/fbind.py + python tools/check/fbindexes.py $tgt_vcb $wkd/tgt.train.srt $wkd/tgt.dev.srt $wkd/fbind.py else export src_vcb=$wkd/src.vcb export tgt_vcb=$wkd/tgt.vcb diff --git a/tools/README.md b/tools/README.md index fee7b28..8a8c6df 100644 --- a/tools/README.md +++ b/tools/README.md @@ -39,3 +39,6 @@ When you using a shared vocabulary for source side and target side, there are st Tools to filter the datasets. +## `ape/` + +Tools for APE. diff --git a/tools/ape/cnfg b/tools/ape/cnfg new file mode 120000 index 0000000..bcd9a88 --- /dev/null +++ b/tools/ape/cnfg @@ -0,0 +1 @@ +../../cnfg/ \ No newline at end of file diff --git a/tools/ape/mkiodata.py b/tools/ape/mkiodata.py new file mode 100644 index 0000000..aad2cd7 --- /dev/null +++ b/tools/ape/mkiodata.py @@ -0,0 +1,42 @@ +#encoding: utf-8 + +import sys + +import numpy +import h5py + +from utils.fmt.base import ldvocab +from utils.fmt.ape.triple import batch_padder + +from cnfg.ihyp import * + +def handle(finput, fmt, ftarget, fvocab_i, fvocab_t, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False): + vcbi, nwordi = ldvocab(fvocab_i, minfreq, vsize) + vcbt, nwordt = ldvocab(fvocab_t, minfreq, vsize) + if expand_for_mulgpu: + _bsize = bsize * minbsize + _maxtoken = maxtoken * minbsize + else: + _bsize = bsize + _maxtoken = maxtoken + rsf = h5py.File(frs, 'w') + src_grp = rsf.create_group("src") + mt_grp = rsf.create_group("mt") + tgt_grp = rsf.create_group("tgt") + curd = 0 + for i_d, md, td in batch_padder(finput, fmt, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = numpy.array(i_d, dtype = numpy.int32) + rmd = numpy.array(md, dtype = numpy.int32) + rtd = numpy.array(td, dtype = numpy.int32) + wid = str(curd) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + mt_grp.create_dataset(wid, data=rmd, **h5datawargs) + tgt_grp.create_dataset(wid, data=rtd, **h5datawargs) + curd += 1 + rsf["ndata"] = numpy.array([curd], dtype = numpy.int32) + rsf["nword"] = numpy.array([nwordi, nwordt], dtype = numpy.int32) + rsf.close() + print("Number of batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d" % (curd, nwordi, nwordt)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6], int(sys.argv[7])) diff --git a/tools/ape/restore.py b/tools/ape/restore.py new file mode 100644 index 0000000..fd06398 --- /dev/null +++ b/tools/ape/restore.py @@ -0,0 +1,33 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import clean_str + +def handle(srcfs, srcfm, srtsf, srtmf, srttf, tgtf): + + data = {} + 
+ with open(srtsf, "rb") as fs, open(srtmf, "rb") as fm, open(srttf, "rb") as ft: + for sl, ml, tl in zip(fs, fm, ft): + _sl, _ml, _tl = sl.strip(), ml.strip(), tl.strip() + if _sl and _tl: + _sl = clean_str(_sl.decode("utf-8")) + _ml = clean_str(_ml.decode("utf-8")) + _tl = clean_str(_tl.decode("utf-8")) + data[(_sl, _ml,)] = _tl + + ens = "\n".encode("utf-8") + + with open(srcfs, "rb") as fs, open(srcfm, "rb") as fm, open(tgtf, "wb") as ft: + for sl, ml in zip(fs, fm): + _sl, _ml = sl.strip(), ml.strip() + if _sl: + _sl = clean_str(_sl.decode("utf-8")) + _ml = clean_str(_ml.decode("utf-8")) + tmp = data.get((_sl, _ml,), "") + ft.write(tmp.encode("utf-8")) + ft.write(ens) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6]) diff --git a/tools/ape/sort.py b/tools/ape/sort.py new file mode 100644 index 0000000..f306c7e --- /dev/null +++ b/tools/ape/sort.py @@ -0,0 +1,47 @@ +#encoding: utf-8 + +import sys +from random import seed as rpyseed + +from utils.fmt.base import clean_liststr_lentok, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list + +# remove_same: reduce same data in the corpus +# shuf: shuffle the data of same source/target length +# max_remove: if one source has several targets, only keep those with highest frequency + +def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, max_len=256, remove_same=False, shuf=True, max_remove=False): + + _max_len = max(1, max_len - 2) + + data = {} + + with open(srcfs, "rb") as fs, open(srcfm, "rb") as fm, open(srcft, "rb") as ft: + for ls, lm, lt in zip(fs, fm, ft): + ls, lm, lt = ls.strip(), lm.strip(), lt.strip() + if ls and lm and lt: + ls, slen = clean_liststr_lentok(ls.decode("utf-8").split()) + lm, mlen = clean_liststr_lentok(lm.decode("utf-8").split()) + lt, tlen = clean_liststr_lentok(lt.decode("utf-8").split()) + if (slen <= _max_len) and (mlen <= _max_len) and (tlen <= _max_len): + data = dict_insert_list(data, (ls, lm, lt,), slen + tlen + mlen, tlen, mlen) + + ens = "\n".encode("utf-8") + + with open(tgtfs, "wb") as fs, open(tgtfm, "wb") as fm, open(tgtft, "wb") as ft: + for tmp in iter_dict_sort(data): + ls, lm, lt = zip(*tmp) + if len(ls) > 1: + if remove_same: + (ls, lm,), lt = maxfreq_filter((ls, lm,), lt, max_remove) + if shuf: + ls, lm, lt = shuffle_pair(ls, lm, lt) + fs.write("\n".join(ls).encode("utf-8")) + fs.write(ens) + fm.write("\n".join(lm).encode("utf-8")) + fm.write(ens) + ft.write("\n".join(lt).encode("utf-8")) + ft.write(ens) + +if __name__ == "__main__": + rpyseed(666666) + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6], int(sys.argv[7])) diff --git a/tools/ape/utils b/tools/ape/utils new file mode 120000 index 0000000..7d6b64a --- /dev/null +++ b/tools/ape/utils @@ -0,0 +1 @@ +../../utils/ \ No newline at end of file diff --git a/tools/check/dynb/report_dynb.py b/tools/check/dynb/report_dynb.py index 55c0fb3..b431a03 100644 --- a/tools/check/dynb/report_dynb.py +++ b/tools/check/dynb/report_dynb.py @@ -3,6 +3,7 @@ import sys import torch +from torch.cuda.amp import autocast, GradScaler from torch import optim @@ -13,7 +14,7 @@ from utils.init import init_model_params from utils.dynbatch import GradientMonitor from utils.h5serial import h5save, h5load -from utils.fmt.base import tostr, save_states, load_states +from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb from lrsch import GoogleLR @@ -45,21 +46,15 @@ def select_function(modin, 
select_index): return _sel_m.parameters() -grad_mon = GradientMonitor(num_layer * 2, select_function, module=None, angle_alpha=cnfg.dyn_tol_alpha, num_tol_amin=cnfg.dyn_tol_amin, num_his_recoder=cnfg.num_dynb_his, num_his_gm=max_his) +grad_mon = GradientMonitor(num_layer * 2, select_function, module=None, angle_alpha=cnfg.dyn_tol_alpha, num_tol_amin=cnfg.dyn_tol_amin, num_his_record=cnfg.num_dynb_his, num_his_gm=max_his) -def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False): +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): - sum_loss = 0.0 - sum_wd = 0 - part_loss = 0.0 - part_wd = 0 - _done_tokens = done_tokens + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl) model.train() - cur_b = 1 - ndata = len(tl) - _cur_checkid = cur_checkid - _cur_rstep = remain_steps - _ls = {} if save_loss else None + cur_b, _ls = 1, {} if save_loss else None global grad_mon, update_angle, num_layer, log_dyn_p, log_dynb, wkdir _log_f_dynbatch = open(wkdir+"dynbatch.log", "ab") @@ -76,19 +71,19 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - output = model(seq_batch, oi) - loss = lossf(output, ot) - if multi_gpu: - loss = loss.sum() + with autocast(enabled=_use_amp): + output = model(seq_batch, oi) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() loss_add = loss.data.item() - if use_amp: - with amp.scale_loss(loss, optm) as scaled_loss: - scaled_loss.backward() - else: + if scaler is None: loss.backward() + else: + scaler.scale(loss).backward() - wd_add = ot.ne(0).int().sum().item() + wd_add = ot.ne(pad_id).int().sum().item() loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: @@ -115,18 +110,15 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _log_f_dynbatch.write(("%d\n" % (_done_tokens,)).encode("utf-8")) if multi_gpu: model.collect_gradients() - optm.step() - optm.zero_grad() + optm_step(optm, scaler) + optm.zero_grad() + if multi_gpu: model.update_replicas() - else: - optm.step() - optm.zero_grad() lrsch.step() else: if log_dynb: _log_f_dynbatch.write(("D %d\n" % (_done_tokens,)).encode("utf-8")) if multi_gpu: - #optm.zero_grad() model.reset_grad() else: optm.zero_grad() @@ -158,7 +150,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_wd += wd_add if cur_b % nreport == 0: if report_eva: - _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu) + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva)) free_cache(mv_device) model.train() @@ -192,9 +184,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok return sum_loss / sum_wd, _done_tokens, _cur_checkid, 
_cur_rstep, _ls -def eva(ed, nd, model, lossf, mv_device, multi_gpu): - r = 0 - w = 0 +def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): + r = w = 0 sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] @@ -208,15 +199,16 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu): seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) ot = seq_o.narrow(1, 1, lo).contiguous() - output = model(seq_batch, seq_o.narrow(1, 0, lo)) - loss = lossf(output, ot) - if multi_gpu: - loss = loss.sum() - trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) - else: - trans = output.argmax(-1) + with autocast(enabled=use_amp): + output = model(seq_batch, seq_o.narrow(1, 0, lo)) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + else: + trans = output.argmax(-1) sum_loss += loss.data.item() - data_mask = ot.ne(0) + data_mask = ot.ne(pad_id) correct = (trans.eq(ot) & data_mask).int() w += data_mask.int().sum().item() r += correct.sum().item() @@ -273,16 +265,6 @@ def init_fixing(module): use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) -if use_cuda and cnfg.amp_opt: - try: - from apex import amp - use_amp = True - except Exception as e: - logger.info(str(e)) - use_amp = False -else: - use_amp = False - set_random_seed(cnfg.seed, use_cuda) td = h5py.File(cnfg.train_data, "r") @@ -306,7 +288,7 @@ def init_fixing(module): logger.info("Load pre-trained model from: " + fine_tune_m) mymodel = load_model_cpu(fine_tune_m, mymodel) -lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=0, reduction='sum', forbidden_index=cnfg.forbidden_indexes) +lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) if cnfg.src_emb is not None: logger.info("Load source embedding from: " + cnfg.src_emb) @@ -322,8 +304,8 @@ def init_fixing(module): optimizer = optim.Adam(mymodel.parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) optimizer.zero_grad() -if use_amp: - mymodel, optimizer = amp.initialize(mymodel, optimizer, opt_level=cnfg.amp_opt) +use_amp = cnfg.use_amp and use_cuda +scaler = GradScaler() if use_amp else None if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) @@ -341,7 +323,7 @@ def init_fixing(module): tminerr = inf_default -minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) +minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: @@ -351,8 +333,8 @@ def init_fixing(module): cnt_states = cnfg.train_statesf if (cnt_states is not None) and p_check(cnt_states): logger.info("Continue last epoch") - tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, use_amp) - vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, 
optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) if save_optm_state: @@ -378,8 +360,8 @@ def init_fixing(module): for i in range(1, maxrun + 1): shuffle(tl) free_cache(use_cuda) - terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, use_amp) - vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): @@ -409,7 +391,7 @@ def init_fixing(module): if done_tokens > 0: if multi_gpu: mymodel.collect_gradients() - optimizer.step() + optm_step(optimizer, scaler) done_tokens = 0 logger.info("early stop") break @@ -430,7 +412,7 @@ def init_fixing(module): if done_tokens > 0: if multi_gpu: mymodel.collect_gradients() - optimizer.step() + optm_step(optimizer, scaler) save_model(mymodel, wkdir + "last.h5", multi_gpu, logger) if save_optm_state: diff --git a/tools/check/para.py b/tools/check/para.py new file mode 100644 index 0000000..d6af351 --- /dev/null +++ b/tools/check/para.py @@ -0,0 +1,30 @@ +#encoding: utf-8 + +''' usage: + python tools/check/para.py $model_file.h5 +''' + +import sys + +import h5py + +def handle_group(srcg): + + rs = 0 + for k, v in srcg.items(): + if isinstance(v, h5py.Dataset): + rs += v[:].size + else: + rs += handle_group(v) + + return rs + +def handle(srcf): + + sfg = h5py.File(srcf, "r") + rs = handle_group(sfg) + sfg.close() + print(rs) + +if __name__ == "__main__": + handle(sys.argv[1]) diff --git a/tools/check/vsize/cnfg b/tools/check/vsize/cnfg new file mode 120000 index 0000000..a958f86 --- /dev/null +++ b/tools/check/vsize/cnfg @@ -0,0 +1 @@ +../../../cnfg/ \ No newline at end of file diff --git a/tools/check/vsize/copy.py b/tools/check/vsize/copy.py new file mode 100644 index 0000000..022b26c --- /dev/null +++ b/tools/check/vsize/copy.py @@ -0,0 +1,24 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import clean_list + +def handle(srcfl, tgtfl): + + nsrc = ntgt = ncopy = 0 + for srcf, tgtf in zip(srcfl, tgtfl): + with open(srcf, "rb") as fsrc, open(tgtf, "rb") as ftgt: + for srcl, tgtl in zip(fsrc, ftgt): + srcl, tgtl = srcl.strip(), tgtl.strip() + if srcl or tgtl: + srcvcb, tgtvcb = clean_list(srcl.decode("utf-8").split()), clean_list(tgtl.decode("utf-8").split()) + nsrc += len(srcvcb) + ntgt += len(tgtvcb) + ncopy += len(set(srcvcb)&set(tgtvcb)) + + print("src, tgt, copy: %d, %d, %d" % (nsrc, ntgt, ncopy,)) + +if 
__name__ == "__main__": + sep_index = (len(sys.argv) + 1) // 2 + handle(sys.argv[1:sep_index], sys.argv[sep_index:]) diff --git a/tools/check/vsize/detail.py b/tools/check/vsize/detail.py new file mode 100644 index 0000000..f1ac174 --- /dev/null +++ b/tools/check/vsize/detail.py @@ -0,0 +1,29 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import clean_list_iter + +def collect(fl): + + vcb = set() + for srcf in fl: + with open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + for token in clean_list_iter(tmp.decode("utf-8").split()): + if not token in vcb: + vcb.add(token) + + return vcb + +def handle(srcfl, tgtfl): + + src_vcb, tgt_vcb = collect(srcfl), collect(tgtfl) + + print("src/tgt vcb: %d, %d, shared token: %d" % (len(src_vcb), len(tgt_vcb), len(src_vcb&tgt_vcb),)) + +if __name__ == "__main__": + sep_index = (len(sys.argv) + 1) // 2 + handle(sys.argv[1:sep_index], sys.argv[sep_index:]) diff --git a/tools/check/vsize.py b/tools/check/vsize/mono.py similarity index 100% rename from tools/check/vsize.py rename to tools/check/vsize/mono.py diff --git a/tools/check/vsize/utils b/tools/check/vsize/utils new file mode 120000 index 0000000..256f914 --- /dev/null +++ b/tools/check/vsize/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/tools/clean/ape/chars.py b/tools/clean/ape/chars.py new file mode 100644 index 0000000..d99e42d --- /dev/null +++ b/tools/clean/ape/chars.py @@ -0,0 +1,78 @@ +#encoding: utf-8 + +import sys + +# cratio: number of "@@" ended tokens / number of all tokens +# bratio: number of bpe tokens / number of tokens before bpe processing +# sratio: number of tokens seperated by bpe / number of tokens before bpe processing +# pratio: max(source length, target length) / min(source length, target length) length after bpe processing +# oratio: same as pratio but before bpe processing +# num_rules_drop: choose from [1, 6], fewer data will be droped with larger value, none data would be droped if it was set to 6 + +def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, cratio=0.8, bratio=5.0, sratio=0.8, pratio=3.0, oratio=3.0, num_rules_drop=1): + + def legal_mono(strin, cratio, bratio, sratio): + ntokens = nchars = nsp = nrule = 0 + pbpe = False + for tmpu in strin.split(): + if tmpu: + if tmpu.endswith("@@"): + nchars += 1 + if not pbpe: + pbpe = True + nsp += 1 + elif pbpe: + pbpe = False + ntokens += 1 + ntokens = float(ntokens) + lorigin = float(len(strin.replace("@@ ", "").split())) + if float(nchars) / ntokens > cratio: + nrule += 1 + if ntokens / lorigin > bratio: + nrule += 1 + if float(nsp) / lorigin > sratio: + nrule += 1 + return nrule, ntokens, lorigin + + def legal(strins, strint, cratio, bratio, sratio, pratio, oratio, num_rules_drop): + + def ratio_bilingual(ls, lt): + if ls > lt: + return ls / lt + else: + return lt / ls + + ls, lens, lenso = legal_mono(strins, cratio, bratio, sratio) + lt, lent, lento = legal_mono(strint, cratio, bratio, sratio) + nrule = max(ls, lt) + if ratio_bilingual(lens, lent) > pratio: + nrule += 1 + if ratio_bilingual(lenso, lento) > oratio: + nrule += 1 + if nrule < num_rules_drop: + return True + else: + return False + + ens = "\n".encode("utf-8") + + with open(srcfs, "rb") as fs, open(srcfm, "rb") as fm, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtfm, "wb") as fmw, open(tgtft, "wb") as ftw: + total = keep = 0 + if num_rules_drop > 0: + for ls, lm, lt in zip(fs, fm, ft): + ls, lm, lt = ls.strip(), lm.strip(), lt.strip() + if ls and lm and lt: + ls, lm, lt = 
ls.decode("utf-8"), lm.decode("utf-8"), lt.decode("utf-8") + if (num_rules_drop > 5) or legal(ls, lt, cratio, bratio, sratio, pratio, oratio, num_rules_drop): + fsw.write(ls.encode("utf-8")) + fsw.write(ens) + fmw.write(lm.encode("utf-8")) + fmw.write(ens) + ftw.write(lt.encode("utf-8")) + ftw.write(ens) + keep += 1 + total += 1 + print("%d in %d data keeped with ratio %.2f" % (keep, total, float(keep) / float(total) * 100.0 if total > 0 else 0.0)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6], float(sys.argv[7]), float(sys.argv[8]), float(sys.argv[9]), float(sys.argv[10]), float(sys.argv[11]), int(sys.argv[12])) diff --git a/tools/clean/ape/cnfg b/tools/clean/ape/cnfg new file mode 120000 index 0000000..a958f86 --- /dev/null +++ b/tools/clean/ape/cnfg @@ -0,0 +1 @@ +../../../cnfg/ \ No newline at end of file diff --git a/tools/clean/ape/maxkeeper.py b/tools/clean/ape/maxkeeper.py new file mode 100644 index 0000000..5c3a4a8 --- /dev/null +++ b/tools/clean/ape/maxkeeper.py @@ -0,0 +1,79 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import clean_liststr_lentok + +def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, max_len=256): + + _max_len = max(1, max_len - 2) + + data = {} + + with open(srcfs, "rb") as fs, open(srcfm, "rb") as fm, open(srcft, "rb") as ft: + for ls, lm, lt in zip(fs, fm, ft): + ls, lm, lt = ls.strip(), lm.strip(), lt.strip() + if ls and lt: + ls, slen = clean_liststr_lentok(ls.decode("utf-8").split()) + lm, mlen = clean_liststr_lentok(lm.decode("utf-8").split()) + lt, tlen = clean_liststr_lentok(lt.decode("utf-8").split()) + if (slen <= _max_len) and (mlen <= _max_len) and (tlen <= _max_len): + if ls in data: + data[ls][(lm, lt,)] = data[ls].get((lm, lt,), 0) + 1 + else: + data[ls] = {(lm, lt,): 1} + + _clean = {} + for ls, v in data.items(): + if len(v) > 1: + rlt = [] + _maxf = 0 + for key, value in v.items(): + if value > _maxf: + _maxf = value + rlt = [key] + elif value == _maxf: + rlt.append(key) + for lt in rlt: + if lt in _clean: + _clean[lt][ls] = _clean[lt].get(ls, 0) + 1 + else: + _clean[lt] = {ls: 1} + else: + lt = list(v.keys())[0] + if lt in _clean: + _clean[lt][ls] = _clean[lt].get(ls, 0) + 1 + else: + _clean[lt] = {ls: 1} + + data = _clean + + ens = "\n".encode("utf-8") + + with open(tgtfs, "wb") as fs, open(tgtfm, "wb") as fm, open(tgtft, "wb") as ft: + for (lm, lt,), v in data.items(): + if len(v) > 1: + rls = [] + _maxf = 0 + for key, value in v.items(): + if value > _maxf: + _maxf = value + rls = [key] + elif value == _maxf: + rls.append(key) + rlm = "\n".join([lm for i in range(len(rls))]) + rlt = "\n".join([lt for i in range(len(rls))]) + rls = "\n".join(rls) + else: + rlm = lm + rlt = lt + rls = list(v.keys())[0] + fs.write(rls.encode("utf-8")) + fs.write(ens) + fm.write(rlm.encode("utf-8")) + fm.write(ens) + ft.write(rlt.encode("utf-8")) + ft.write(ens) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6], int(sys.argv[7])) diff --git a/tools/clean/ape/utils b/tools/clean/ape/utils new file mode 120000 index 0000000..256f914 --- /dev/null +++ b/tools/clean/ape/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/tools/clean/ape/vocab.py b/tools/clean/ape/vocab.py new file mode 100644 index 0000000..b1992d9 --- /dev/null +++ b/tools/clean/ape/vocab.py @@ -0,0 +1,39 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import ldvocab_list, legal_vocab + +# vratio: percentages of 
vocabulary size of retrieved words of least frequencies +# dratio: a datum will be dropped who contains high frequency words less than this ratio + +def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, vcbfs, vcbft, vratio, dratio=None): + + _dratio = vratio if dratio is None else dratio + + ens = "\n".encode("utf-8") + + vcbs, nvs = ldvocab_list(vcbfs) + vcbt, nvt = ldvocab_list(vcbft) + ilgs = set(vcbs[int(float(nvs) * (1.0 - vratio)):]) + ilgt = set(vcbt[int(float(nvt) * (1.0 - vratio)):]) + + with open(srcfs, "rb") as fs, open(srcfm, "rb") as fm, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtfm, "wb") as fmw, open(tgtft, "wb") as ftw: + total = keep = 0 + for ls, lm, lt in zip(fs, fm, ft): + ls, lm, lt = ls.strip(), lm.strip(), lt.strip() + if ls and lm and lt: + ls, lm, lt = ls.decode("utf-8"), lm.decode("utf-8"), lt.decode("utf-8") + if legal_vocab(ls, ilgs, _dratio) and legal_vocab(lt, ilgt, _dratio): + fsw.write(ls.encode("utf-8")) + fsw.write(ens) + fmw.write(lm.encode("utf-8")) + fmw.write(ens) + ftw.write(lt.encode("utf-8")) + ftw.write(ens) + keep += 1 + total += 1 + print("%d in %d data keeped with ratio %.2f" % (keep, total, float(keep) / float(total) * 100.0 if total > 0 else 0.0)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6], sys.argv[7], sys.argv[8], float(sys.argv[9])) diff --git a/train.py b/train.py index 2f2cc44..51c0608 100644 --- a/train.py +++ b/train.py @@ -3,6 +3,7 @@ import sys import torch +from torch.cuda.amp import autocast, GradScaler #from torch import nn from torch import optim @@ -13,7 +14,7 @@ from utils.base import * from utils.init import init_model_params from utils.h5serial import h5save, h5load -from utils.fmt.base import tostr, save_states, load_states +from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb from lrsch import GoogleLR @@ -33,19 +34,13 @@ from transformer.NMT import NMT -def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, use_amp=False): +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, chkpof=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): - sum_loss = 0.0 - sum_wd = 0 - part_loss = 0.0 - part_wd = 0 - _done_tokens = done_tokens + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp, ndata = done_tokens, cur_checkid, remain_steps, scaler is not None, len(tl) model.train() - cur_b = 1 - ndata = len(tl) - _cur_checkid = cur_checkid - _cur_rstep = remain_steps - _ls = {} if save_loss else None + cur_b, _ls = 1, {} if save_loss else None src_grp, tgt_grp = td["src"], td["tgt"] for i_d in tqdm(tl): seq_batch = torch.from_numpy(src_grp[i_d][:]).long() @@ -57,21 +52,21 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - output = model(seq_batch, oi) - loss = lossf(output, ot) - if multi_gpu: - loss = loss.sum() + with autocast(enabled=_use_amp): + output = model(seq_batch, oi) + loss = lossf(output, ot) 
+ if multi_gpu: + loss = loss.sum() loss_add = loss.data.item() # scale the sum of losses down according to the number of tokens adviced by: https://mp.weixin.qq.com/s/qAHZ4L5qK3rongCIIq5hQw, I think not reasonable. #loss /= wd_add - if use_amp: - with amp.scale_loss(loss, optm) as scaled_loss: - scaled_loss.backward() - else: + if scaler is None: loss.backward() + else: + scaler.scale(loss).backward() - wd_add = ot.ne(0).int().sum().item() + wd_add = ot.ne(pad_id).int().sum().item() loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: @@ -82,12 +77,10 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if _done_tokens >= tokens_optm: if multi_gpu: model.collect_gradients() - optm.step() - optm.zero_grad() + optm_step(optm, scaler) + optm.zero_grad() + if multi_gpu: model.update_replicas() - else: - optm.step() - optm.zero_grad() _done_tokens = 0 if _cur_rstep is not None: if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): @@ -115,7 +108,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_wd += wd_add if cur_b % nreport == 0: if report_eva: - _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu) + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva)) free_cache(mv_device) model.train() @@ -145,9 +138,8 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd)) return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls -def eva(ed, nd, model, lossf, mv_device, multi_gpu): - r = 0 - w = 0 +def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): + r = w = 0 sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] @@ -161,15 +153,16 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu): seq_batch = seq_batch.to(mv_device) seq_o = seq_o.to(mv_device) ot = seq_o.narrow(1, 1, lo).contiguous() - output = model(seq_batch, seq_o.narrow(1, 0, lo)) - loss = lossf(output, ot) - if multi_gpu: - loss = loss.sum() - trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) - else: - trans = output.argmax(-1) + with autocast(enabled=use_amp): + output = model(seq_batch, seq_o.narrow(1, 0, lo)) + loss = lossf(output, ot) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + else: + trans = output.argmax(-1) sum_loss += loss.data.item() - data_mask = ot.ne(0) + data_mask = ot.ne(pad_id) correct = (trans.eq(ot) & data_mask).int() w += data_mask.int().sum().item() r += correct.sum().item() @@ -230,16 +223,6 @@ def init_fixing(module): use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) -if use_cuda and cnfg.amp_opt: - try: - from apex import amp - use_amp = True - except Exception as e: - logger.info(str(e)) - use_amp = False -else: - use_amp = False - set_random_seed(cnfg.seed, use_cuda) td = h5py.File(cnfg.train_data, "r") @@ -266,7 +249,7 @@ def init_fixing(module): #lw = torch.ones(nwordt).float() #lw[0] = 0.0 #lossf = nn.NLLLoss(lw, ignore_index=0, reduction='sum') -lossf = LabelSmoothingLoss(nwordt, cnfg.label_smoothing, ignore_index=0, reduction='sum', forbidden_index=cnfg.forbidden_indexes) +lossf = LabelSmoothingLoss(nwordt, 
cnfg.label_smoothing, ignore_index=pad_id, reduction='sum', forbidden_index=cnfg.forbidden_indexes) if cnfg.src_emb is not None: logger.info("Load source embedding from: " + cnfg.src_emb) @@ -283,8 +266,8 @@ def init_fixing(module): optimizer = optim.Adam(mymodel.parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) optimizer.zero_grad() -if use_amp: - mymodel, optimizer = amp.initialize(mymodel, optimizer, opt_level=cnfg.amp_opt) +use_amp = cnfg.use_amp and use_cuda +scaler = GradScaler() if use_amp else None if multi_gpu: #mymodel = nn.DataParallel(mymodel, device_ids=cuda_devices, output_device=cuda_device.index) @@ -304,7 +287,7 @@ def init_fixing(module): tminerr = inf_default -minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) +minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: @@ -314,8 +297,8 @@ def init_fixing(module): cnt_states = cnfg.train_statesf if (cnt_states is not None) and p_check(cnt_states): logger.info("Continue last epoch") - tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, use_amp) - vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, load_states(cnt_states), vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec)) save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec), multi_gpu, logger) if save_optm_state: @@ -341,8 +324,8 @@ def init_fixing(module): for i in range(1, maxrun + 1): shuffle(tl) free_cache(use_cuda) - terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, use_amp) - vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, tokens_optm, batch_report, save_every, chkpf, chkpof, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec)) if (vprec <= minerr) or (vloss <= minloss): @@ -372,7 +355,7 @@ def init_fixing(module): if done_tokens > 0: if multi_gpu: mymodel.collect_gradients() - optimizer.step() + optm_step(optimizer, scaler) #lrsch.step() done_tokens = 0 #optimizer.zero_grad() @@ -402,7 +385,7 @@ def init_fixing(module): if 
done_tokens > 0: if multi_gpu: mymodel.collect_gradients() - optimizer.step() + optm_step(optimizer, scaler) #lrsch.step() #done_tokens = 0 #optimizer.zero_grad() diff --git a/transformer/APE/Decoder.py b/transformer/APE/Decoder.py new file mode 100644 index 0000000..0faf0f2 --- /dev/null +++ b/transformer/APE/Decoder.py @@ -0,0 +1,297 @@ +#encoding: utf-8 + +import torch +from torch import nn +from modules.base import CrossAttn + +from transformer.Decoder import DecoderLayer as DecoderLayerBase +from transformer.Decoder import Decoder as DecoderBase + +from utils.base import all_done, repeat_bsize_for_beam_tensor, mask_tensor_type +from math import sqrt + +from utils.fmt.base import pad_id + +from cnfg.ihyp import * + +class DecoderLayer(DecoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): + + _ahsize = isize if ahsize is None else ahsize + + super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + + self.cross_attn_mt = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) + self.layer_normer3 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + + def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None, tgt_pad_mask=None, query_unit=None): + + if query_unit is None: + _inputo = self.layer_normer1(inputo) + + states_return = None + + context = self.self_attn(_inputo, mask=tgt_pad_mask) + + if self.drop is not None: + context = self.drop(context) + + context = context + (_inputo if self.norm_residual else inputo) + + else: + _query_unit = self.layer_normer1(query_unit) + + _inputo = _query_unit if inputo is None else torch.cat((inputo, _query_unit,), 1) + + states_return = _inputo + + context = self.self_attn(_query_unit, iK=_inputo) + + if self.drop is not None: + context = self.drop(context) + + context = context + (_query_unit if self.norm_residual else query_unit) + + _context = self.layer_normer2(context) + _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) + + if self.drop is not None: + _context_new = self.drop(_context_new) + + context = _context_new + (_context if self.norm_residual else context) + + _context = self.layer_normer3(context) + _context_new = self.cross_attn_mt(_context, inputm, mask=mt_pad_mask) + + if self.drop is not None: + _context_new = self.drop(_context_new) + + context = _context_new + (_context if self.norm_residual else context) + + context = self.ff(context) + + if states_return is None: + return context + else: + return context, states_return + +class Decoder(DecoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs) + + if share_layer: + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, 
num_head, _ahsize) for i in range(num_layer)]) + + def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None): + + nquery = inputo.size(-1) + + out = self.wemb(inputo) + + out = out * sqrt(out.size(-1)) + if self.pemb is not None: + out = out + self.pemb(inputo, expand=False) + + if self.drop is not None: + out = self.drop(out) + + _mask = self._get_subsequent_mask(nquery) + + for net in self.nets: + out = net(inpute, inputm, out, src_pad_mask, mt_pad_mask, _mask) + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.lsm(self.classifier(out)) + + return out + + def decode(self, inpute, inputm, src_pad_mask, mt_pad_mask, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + + return self.beam_decode(inpute, inputm, src_pad_mask, mt_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, inputm, src_pad_mask, mt_pad_mask, max_len, fill_pad=fill_pad) + + def greedy_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, max_len=512, fill_pad=False, sample=False): + + bsize = inpute.size(0) + + sos_emb = self.get_sos_emb(inpute) + + sqrt_isize = sqrt(sos_emb.size(-1)) + + out = sos_emb * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(0) + + if self.drop is not None: + out = self.drop(out) + + states = {} + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, inputm, None, src_pad_mask, mt_pad_mask, None, out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) + + trans = [wds] + + done_trans = wds.eq(2) + + for i in range(1, max_len): + + out = self.wemb(wds) * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(i) + + if self.drop is not None: + out = self.drop(out) + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, inputm, states[_tmp], src_pad_mask, mt_pad_mask, None, out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) + + trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) + + done_trans = done_trans | wds.eq(2) + if all_done(done_trans, bsize): + break + + return torch.cat(trans, 1) + + def beam_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=False, fill_pad=False): + + bsize, seql = inpute.size()[:2] + mtl = inputm.size(1) + + beam_size2 = beam_size * beam_size + bsizeb2 = bsize * beam_size2 + real_bsize = bsize * beam_size + + sos_emb = self.get_sos_emb(inpute) + isize = sos_emb.size(-1) + sqrt_isize = sqrt(isize) + + if length_penalty > 0.0: + lpv = sos_emb.new_ones(real_bsize, 1) + lpv_base = 6.0 ** length_penalty + + out = sos_emb * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(0) + + if self.drop is not None: + out = self.drop(out) + + states = {} + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, inputm, None, src_pad_mask, mt_pad_mask, None, out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.lsm(self.classifier(out)) + + scores, wds = out.topk(beam_size, dim=-1) + scores = scores.squeeze(1) + sum_scores = scores + wds = wds.view(real_bsize, 1) + trans = 
wds + + done_trans = wds.view(bsize, beam_size).eq(2) + + inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) + inputm = inputm.repeat(1, beam_size, 1).view(real_bsize, mtl, isize) + + _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) + _mt_pad_mask = None if mt_pad_mask is None else mt_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, mtl) + + for key, value in states.items(): + states[key] = repeat_bsize_for_beam_tensor(value, beam_size) + + for step in range(1, max_len): + + out = self.wemb(wds) * sqrt_isize + if self.pemb is not None: + out = out + self.pemb.get_pos(step) + + if self.drop is not None: + out = self.drop(out) + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, inputm, states[_tmp], _src_pad_mask, _mt_pad_mask, None, out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1) + + _scores, _wds = out.topk(beam_size, dim=-1) + _scores = (_scores.masked_fill(done_trans.unsqueeze(2).expand(bsize, beam_size, beam_size), 0.0) + sum_scores.unsqueeze(2).expand(bsize, beam_size, beam_size)) + + if length_penalty > 0.0: + lpv = lpv.masked_fill(~done_trans.view(real_bsize, 1), ((step + 6.0) ** length_penalty) / lpv_base) + + if clip_beam and (length_penalty > 0.0): + scores, _inds = (_scores.view(real_bsize, beam_size) / lpv.expand(real_bsize, beam_size)).view(bsize, beam_size2).topk(beam_size, dim=-1) + _tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + sum_scores = _scores.view(bsizeb2).index_select(0, _tinds).view(bsize, beam_size) + else: + scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size, dim=-1) + _tinds = (_inds + torch.arange(0, bsizeb2, beam_size2, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + sum_scores = scores + + wds = _wds.view(bsizeb2).index_select(0, _tinds).view(real_bsize, 1) + + _inds = (_inds // beam_size + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + + trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), 0) if fill_pad else wds), 1) + + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + + _done = False + if length_penalty > 0.0: + lpv = lpv.index_select(0, _inds) + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): + _done = True + + if _done or all_done(done_trans, real_bsize): + break + + for key, value in states.items(): + states[key] = value.index_select(0, _inds) + + if (not clip_beam) and (length_penalty > 0.0): + scores = scores / lpv.view(bsize, beam_size) + scores, _inds = scores.topk(beam_size, dim=-1) + _inds = (_inds + torch.arange(0, real_bsize, beam_size, dtype=_inds.dtype, device=_inds.device).unsqueeze(1).expand_as(_inds)).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds).view(bsize, beam_size, -1) + + if return_all: + + return trans, scores + else: + + return trans.view(bsize, beam_size, -1).select(1, 0) diff --git a/transformer/APE/Encoder.py b/transformer/APE/Encoder.py new file mode 100644 index 0000000..129713c --- /dev/null +++ b/transformer/APE/Encoder.py @@ -0,0 +1,103 @@ +#encoding: utf-8 + +from torch import nn +from modules.base import Dropout, 
PositionalEmb + +from transformer.Encoder import Encoder as EncoderBase +from transformer.Decoder import DecoderLayer as MSEncoderLayerBase + +from utils.fmt.base import parse_double_value_tuple + +from math import sqrt + +from cnfg.ihyp import * + +class MSEncoderLayer(MSEncoderLayerBase): + + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None): + + _inputo = self.layer_normer1(inputo) + + context = self.self_attn(_inputo, mask=tgt_pad_mask) + + if self.drop is not None: + context = self.drop(context) + + context = context + (_inputo if self.norm_residual else inputo) + + _context = self.layer_normer2(context) + _context_new = self.cross_attn(_context, inpute, mask=src_pad_mask) + + if self.drop is not None: + _context_new = self.drop(_context_new) + + context = _context_new + (_context if self.norm_residual else context) + + context = self.ff(context) + + return context + +class MSEncoder(nn.Module): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, emb_w=None, share_layer=False, disable_pemb=disable_std_pemb_encoder): + + super(MSEncoder, self).__init__() + + _ahsize = isize if ahsize is None else ahsize + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None + + self.wemb = nn.Embedding(nwd, isize, padding_idx=0) + if emb_w is not None: + self.wemb.weight = emb_w + + self.pemb = None if disable_pemb else PositionalEmb(isize, xseql, 0, 0) + if share_layer: + _shared_layer = MSEncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([MSEncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + + self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None + + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None): + + nquery = inputo.size(-1) + + out = self.wemb(inputo) + + out = out * sqrt(out.size(-1)) + if self.pemb is not None: + out = out + self.pemb(inputo, expand=False) + + if self.drop is not None: + out = self.drop(out) + + for net in self.nets: + out = net(inpute, out, src_pad_mask=src_pad_mask, tgt_pad_mask=tgt_pad_mask) + + if self.out_normer is not None: + out = self.out_normer(out) + + return out + +class Encoder(nn.Module): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, global_emb=False, **kwargs): + + super(Encoder, self).__init__() + + nwd_src, nwd_tgt = parse_double_value_tuple(nwd) + + self.src_enc = EncoderBase(isize, nwd_src, num_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, **kwargs) + + emb_w = self.src_enc.wemb.weight if global_emb else None + + self.tgt_enc = MSEncoder(isize, nwd_tgt, num_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, emb_w, **kwargs) + + def forward(self, inpute, inputo, src_mask=None, tgt_mask=None): + + enc_src = self.src_enc(inpute, src_mask) + + return enc_src, self.tgt_enc(enc_src, inputo, src_mask, tgt_mask) diff --git a/transformer/APE/NMT.py b/transformer/APE/NMT.py new file mode 100644 index 0000000..31797d5 --- /dev/null +++ b/transformer/APE/NMT.py @@ -0,0 +1,48 @@ +#encoding: utf-8 + +from torch import nn + +from utils.relpos import 
share_rel_pos_cache +from utils.fmt.base import parse_double_value_tuple + +from transformer.APE.Encoder import Encoder +from transformer.APE.Decoder import Decoder + +from cnfg.ihyp import * + +class NMT(nn.Module): + + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=False, forbidden_index=None): + + super(NMT, self).__init__() + + enc_layer, dec_layer = parse_double_value_tuple(num_layer) + + self.enc = Encoder(isize, (snwd, tnwd,), enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, global_emb=global_emb) + + emb_w = self.enc.tgt_enc.wemb.weight if global_emb else None + + self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index) + + if rel_pos_enabled: + share_rel_pos_cache(self) + + def forward(self, inpute, inputm, inputo, src_mask=None, mt_mask=None): + + _src_mask = inpute.eq(0).unsqueeze(1) if src_mask is None else src_mask + _mt_mask = inputm.eq(0).unsqueeze(1) if mt_mask is None else mt_mask + + enc_src, enc_mt = self.enc(inpute, inputm, _src_mask, _mt_mask) + + return self.dec(enc_src, enc_mt, inputo, _src_mask, _mt_mask) + + def decode(self, inpute, inputm, beam_size=1, max_len=None, length_penalty=0.0): + + src_mask = inpute.eq(0).unsqueeze(1) + mt_mask = inputm.eq(0).unsqueeze(1) + + _max_len = inpute.size(1) + max(64, inpute.size(1) // 4) if max_len is None else max_len + + enc_src, enc_mt = self.enc(inpute, inputm, src_mask, mt_mask) + + return self.dec.decode(enc_src, enc_mt, src_mask, mt_mask, beam_size, _max_len, length_penalty) diff --git a/transformer/APE/__init__.py b/transformer/APE/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/transformer/APE/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/transformer/AvgDecoder.py b/transformer/AvgDecoder.py index e75a73e..bc5db2f 100644 --- a/transformer/AvgDecoder.py +++ b/transformer/AvgDecoder.py @@ -3,7 +3,9 @@ import torch from torch import nn from modules.base import * -from utils.base import repeat_bsize_for_beam_tensor +from utils.sampler import SampleMax +from utils.base import all_done, repeat_bsize_for_beam_tensor +from utils.aan import share_aan_cache from math import sqrt from transformer.Decoder import DecoderLayer as DecoderLayerBase @@ -106,6 +108,8 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + share_aan_cache(self) + # inpute: encoded representation from encoder (bsize, seql, isize) # inputo: decoded translation (bsize, nquery) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: @@ -157,9 +161,9 @@ def load_base(self, base_decoder): # src_pad_mask = input.eq(0).unsqueeze(1) # max_len: maximum length to generate - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): - bsize, seql = inpute.size()[:2] + bsize = inpute.size(0) sos_emb = self.get_sos_emb(inpute) @@ -184,12 +188,9 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): out = self.out_normer(out) # out: (bsize, 1, nwd) - - out = self.lsm(self.classifier(out)) - + out = self.classifier(out) # wds: (bsize, 1) - - wds = out.argmax(dim=-1) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -214,13 +215,14 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): out = self.out_normer(out) # out: (bsize, 1, nwd) - out = self.lsm(self.classifier(out)) - wds = out.argmax(dim=-1) + out = self.classifier(out) + # wds: (bsize, 1) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.cat(trans, 1) @@ -367,12 +369,12 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True # check beam states(done or not) - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break # update the corresponding hidden states diff --git a/transformer/Decoder.py b/transformer/Decoder.py index e4aae94..8a00fca 100644 --- a/transformer/Decoder.py +++ b/transformer/Decoder.py @@ -3,7 +3,8 @@ import torch from torch import nn from modules.base import * -from utils.base import repeat_bsize_for_beam_tensor, mask_tensor_type +from utils.sampler import SampleMax +from utils.base import all_done, repeat_bsize_for_beam_tensor, mask_tensor_type from math import sqrt from utils.fmt.base import pad_id @@ -208,10 +209,11 @@ def decode(self, inpute, src_pad_mask, beam_size=1, max_len=512, length_penalty= # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: # src_pad_mask = input.eq(0).unsqueeze(1) # max_len: maximum length to generate + # sample: for back translation - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): - bsize, seql = inpute.size()[:2] + bsize = inpute.size(0) sos_emb = self.get_sos_emb(inpute) @@ -236,12 +238,10 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): out = self.out_normer(out) # out: (bsize, 1, nwd) - - out = self.lsm(self.classifier(out)) - + # omit self.lsm for efficiency + out = 
self.classifier(out) # wds: (bsize, 1) - - wds = out.argmax(dim=-1) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -265,14 +265,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): if self.out_normer is not None: out = self.out_normer(out) - # out: (bsize, 1, nwd) - out = self.lsm(self.classifier(out)) - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.cat(trans, 1) @@ -419,12 +418,12 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True # check beam states(done or not) - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break # update the corresponding hidden states diff --git a/transformer/Doc/Para/Base/Decoder.py b/transformer/Doc/Para/Base/Decoder.py index d9bd3c6..d4c1cc0 100644 --- a/transformer/Doc/Para/Base/Decoder.py +++ b/transformer/Doc/Para/Base/Decoder.py @@ -3,8 +3,9 @@ import torch from torch import nn from modules.base import * +from utils.sampler import SampleMax from modules.paradoc import GateResidual -from utils.base import repeat_bsize_for_beam_tensor +from utils.base import all_done, repeat_bsize_for_beam_tensor from math import sqrt from transformer.Decoder import DecoderLayer as DecoderLayerBase @@ -143,9 +144,9 @@ def decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam_size return self.beam_decode(inpute, inputc, src_pad_mask, context_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, inputc, src_pad_mask, context_mask, max_len, fill_pad=fill_pad) - def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, max_len=512, fill_pad=False, sample=False): - bsize, seql = inpute.size()[:2] + bsize = inpute.size(0) sos_emb = self.get_sos_emb(inpute) @@ -167,9 +168,8 @@ def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, ma if self.out_normer is not None: out = self.out_normer(out) - out = self.lsm(self.classifier(out)) - - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -191,13 +191,13 @@ def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, ma if self.out_normer is not None: out = self.out_normer(out) - out = self.lsm(self.classifier(out)) - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.cat(trans, 1) @@ -300,10 +300,10 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam _done = 
False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break for key, value in states.items(): diff --git a/transformer/EnsembleAvgDecoder.py b/transformer/EnsembleAvgDecoder.py index b907c72..1f7d842 100644 --- a/transformer/EnsembleAvgDecoder.py +++ b/transformer/EnsembleAvgDecoder.py @@ -1,7 +1,8 @@ #encoding: utf-8 import torch -from utils.base import repeat_bsize_for_beam_tensor +from utils.sampler import SampleMax +from utils.base import all_done, repeat_bsize_for_beam_tensor from math import sqrt from transformer.EnsembleDecoder import Decoder as DecoderBase @@ -47,7 +48,7 @@ def forward(self, inpute, inputo, src_pad_mask=None): # src_pad_mask = input.eq(0).unsqueeze(1) # max_len: maximum length to generate - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): bsize, seql, isize = inpute[0].size() @@ -81,11 +82,8 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): outs.append(model.classifier(out).softmax(dim=-1)) - out = torch.stack(outs).mean(0).log() - - # wds: (bsize, 1) - - wds = out.argmax(dim=-1) + out = torch.stack(outs).mean(0) + wds = SampleMax(out, dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -116,14 +114,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): # outs: [(bsize, 1, nwd)...] outs.append(model.classifier(out).softmax(dim=-1)) - out = torch.stack(outs).mean(0).log() - - wds = out.argmax(dim=-1) + out = torch.stack(outs).mean(0) + wds = SampleMax(out, dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.cat(trans, 1) @@ -283,12 +280,12 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True # check beam states(done or not) - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break # update the corresponding hidden states diff --git a/transformer/EnsembleDecoder.py b/transformer/EnsembleDecoder.py index f5a0c38..1d7f03c 100644 --- a/transformer/EnsembleDecoder.py +++ b/transformer/EnsembleDecoder.py @@ -2,7 +2,8 @@ import torch from torch import nn -from utils.base import repeat_bsize_for_beam_tensor +from utils.sampler import SampleMax +from utils.base import all_done, repeat_bsize_for_beam_tensor from math import sqrt class Decoder(nn.Module): @@ -68,7 +69,7 @@ def decode(self, inpute, src_pad_mask, beam_size=1, max_len=512, length_penalty= # src_pad_mask = input.eq(0).unsqueeze(1) # max_len: maximum length to generate - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): bsize, seql, isize = inpute[0].size() @@ -102,11 +103,8 @@ 
def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): outs.append(model.classifier(out).softmax(dim=-1)) - out = torch.stack(outs).mean(0).log() - - # wds: (bsize, 1) - - wds = out.argmax(dim=-1) + out = torch.stack(outs).mean(0) + wds = SampleMax(out, dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -137,14 +135,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): # outs: [(bsize, 1, nwd)...] outs.append(model.classifier(out).softmax(dim=-1)) - out = torch.stack(outs).mean(0).log() - - wds = out.argmax(dim=-1) + out = torch.stack(outs).mean(0) + wds = SampleMax(out, dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.cat(trans, 1) @@ -302,12 +299,12 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True # check beam states(done or not) - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break # update the corresponding hidden states diff --git a/transformer/EnsembleNMT.py b/transformer/EnsembleNMT.py index e058620..8a8b2f4 100644 --- a/transformer/EnsembleNMT.py +++ b/transformer/EnsembleNMT.py @@ -54,7 +54,7 @@ def train_greedy_decode(self, inpute, mask=None, max_len=512): ence = self.enc(inpute, mask) - bsize, _ = inpute.size() + bsize = inpute.size(0) # out: input to the decoder for the first step (bsize, 1) diff --git a/transformer/LD/Decoder.py b/transformer/LD/Decoder.py index 9c727c0..fc495aa 100644 --- a/transformer/LD/Decoder.py +++ b/transformer/LD/Decoder.py @@ -4,9 +4,10 @@ from torch import nn from modules.base import CrossAttn, ResidueCombiner +from utils.sampler import SampleMax from modules.TA import PositionwiseFF -from utils.base import repeat_bsize_for_beam_tensor +from utils.base import all_done, repeat_bsize_for_beam_tensor from math import sqrt from transformer.Decoder import DecoderLayer as DecoderLayerBase @@ -111,9 +112,9 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None): return out - def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, max_len=512, fill_pad=False, sample=False): - bsize, seql= inpute.size()[:2] + bsize = inpute.size(0) sos_emb = self.get_sos_emb(inpute) @@ -134,9 +135,8 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, ma out, _state = net(inputu, inputhu, None, src_pad_mask, chk_pad_mask, None, out, True) states[_tmp] = _state - out = self.lsm(self.classifier(out)) - - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -157,13 +157,13 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, ma out, _state = net(inputu, inputhu, states[_tmp], src_pad_mask, chk_pad_mask, None, out, True) states[_tmp] = _state - out = self.lsm(self.classifier(out)) - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = 
SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.cat(trans, 1) @@ -261,10 +261,10 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break for key, value in states.items(): diff --git a/transformer/NMT.py b/transformer/NMT.py index ee31d7c..212183c 100644 --- a/transformer/NMT.py +++ b/transformer/NMT.py @@ -90,7 +90,7 @@ def train_greedy_decode(self, inpute, mask=None, max_len=512): ence = self.enc(inpute, mask) - bsize, _ = inpute.size() + bsize = inpute.size(0) # out: input to the decoder for the first step (bsize, 1) diff --git a/transformer/README.md b/transformer/README.md index eada07d..68a0171 100644 --- a/transformer/README.md +++ b/transformer/README.md @@ -51,3 +51,7 @@ Implementation of sentential context proposed in [Exploiting Sentential Context ## `Doc/` Implementation of context-aware Transformer proposed in [Improving the Transformer Translation Model with Document-Level Context](https://www.aclweb.org/anthology/D18-1049/). + +## `APE/` + +Implementation of an APE model. diff --git a/transformer/RNMTDecoder.py b/transformer/RNMTDecoder.py index 4509cfc..a0c3b85 100644 --- a/transformer/RNMTDecoder.py +++ b/transformer/RNMTDecoder.py @@ -6,6 +6,7 @@ from torch import nn from modules.base import * +from utils.sampler import SampleMax from modules.rnncells import * from utils.fmt.base import pad_id @@ -26,6 +27,8 @@ def __init__(self, isize, osize=None): self.init_hx = nn.Parameter(torch.zeros(1, osize)) self.init_cx = nn.Parameter(torch.zeros(1, osize)) + self.drop = Dropout(dropout, inplace=False) if dropout > 0.0 else None + # inputo: embedding of decoded translation (bsize, nquery, isize) # query_unit: single query to decode, used to support decoding for given step @@ -35,15 +38,22 @@ def forward(self, inputo, state=None, first_step=False): hx, cx = prepare_initState(self.init_hx, self.init_cx, inputo.size(0)) outs = [] - for i in range(inputo.size(1)): - hx, cx = self.net(inputo.select(1, i), (hx, cx)) + for _du in inputo.unbind(1): + hx, cx = self.net(_du, (hx, cx)) outs.append(hx) - return torch.stack(outs, 1) + outs = torch.stack(outs, 1) + + if self.drop is not None: + outs = self.drop(outs) + + return outs else: hx, cx = self.net(inputo, prepare_initState(self.init_hx, self.init_cx, inputo.size(0)) if first_step else state) - return hx, (hx, cx) + out = hx if self.drop is None else self.drop(hx) + + return out, (hx, cx) class DecoderLayer(nn.Module): @@ -74,8 +84,8 @@ def forward(self, inputo, attn, state=None, first_step=False): _inputo = torch.cat((inputo, attn), -1) - for i in range(_inputo.size(1)): - hx, cx = self.net(_inputo.select(1, i), (hx, cx)) + for _du in _inputo.unbind(1): + hx, cx = self.net(_du, (hx, cx)) outs.append(hx) outs = torch.stack(outs, 1) @@ -185,9 +195,9 @@ def decode(self, inpute, src_pad_mask, beam_size=1, max_len=512, length_penalty= # src_pad_mask = input.eq(0).unsqueeze(1) # max_len: maximum length to generate - def 
greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): - bsize, seql = inpute.size()[:2] + bsize = inpute.size(0) out = self.get_sos_emb(inpute) @@ -213,12 +223,9 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): out = self.out_normer(out) # out: (bsize, nwd) - - out = self.lsm(self.classifier(torch.cat((out, attn), -1))) - + out = self.classifier(torch.cat((out, attn), -1)) # wds: (bsize) - - wds = out.argmax(dim=-1) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -244,14 +251,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): if self.out_normer is not None: out = self.out_normer(out) - # out: (bsize, nwd) - out = self.lsm(self.classifier(torch.cat((out, attn), -1))) - wds = out.argmax(dim=-1) + out = self.classifier(torch.cat((out, attn), -1)) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.stack(trans, 1) @@ -404,12 +410,12 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True # check beam states(done or not) - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break # update the corresponding hidden states diff --git a/transformer/SC/Decoder.py b/transformer/SC/Decoder.py index fcdce05..022c79b 100644 --- a/transformer/SC/Decoder.py +++ b/transformer/SC/Decoder.py @@ -4,9 +4,10 @@ from torch import nn from modules.base import ResidueCombiner +from utils.sampler import SampleMax from modules.TA import PositionwiseFF -from utils.base import repeat_bsize_for_beam_tensor +from utils.base import all_done, repeat_bsize_for_beam_tensor from math import sqrt from transformer.Decoder import DecoderLayer as DecoderLayerBase @@ -110,9 +111,9 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None): return out - def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): - bsize, seql= inpute.size()[:2] + bsize = inpute.size(0) sos_emb = self.get_sos_emb(inpute) @@ -133,9 +134,8 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad out, _state = net(inputu, inputhu, None, src_pad_mask, None, out, True) states[_tmp] = _state - out = self.lsm(self.classifier(out)) - - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -156,13 +156,13 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad out, _state = net(inputu, inputhu, states[_tmp], src_pad_mask, None, out, True) states[_tmp] = _state - out = self.lsm(self.classifier(out)) - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) 
trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.cat(trans, 1) @@ -259,10 +259,10 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break for key, value in states.items(): diff --git a/transformer/TA/Decoder.py b/transformer/TA/Decoder.py index 970052c..28109e4 100644 --- a/transformer/TA/Decoder.py +++ b/transformer/TA/Decoder.py @@ -2,7 +2,8 @@ import torch from modules.base import * -from utils.base import repeat_bsize_for_beam_tensor +from utils.sampler import SampleMax +from utils.base import all_done, repeat_bsize_for_beam_tensor from math import sqrt from transformer.Decoder import Decoder as DecoderBase @@ -65,9 +66,9 @@ def forward(self, inpute, inputo, src_pad_mask=None): # src_pad_mask = input.eq(0).unsqueeze(1) # max_len: maximum length to generate - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): - bsize, seql= inpute.size()[:2] + bsize = inpute.size(0) sos_emb = self.get_sos_emb(inpute) @@ -91,13 +92,8 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): if self.out_normer is not None: out = self.out_normer(out) - # out: (bsize, 1, nwd) - - out = self.lsm(self.classifier(out)) - - # wds: (bsize, 1) - - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] @@ -122,13 +118,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False): out = self.out_normer(out) # out: (bsize, 1, nwd) - out = self.lsm(self.classifier(out)) - wds = out.argmax(dim=-1) + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans.append(wds.masked_fill(done_trans, 0) if fill_pad else wds) done_trans = done_trans | wds.eq(2) - if done_trans.int().sum().item() == bsize: + if all_done(done_trans, bsize): break return torch.cat(trans, 1) @@ -275,12 +271,12 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).int().sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True # check beam states(done or not) - if _done or (done_trans.int().sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break # update the corresponding hidden states diff --git a/translator.py b/translator.py index c2c2aef..f3da494 100644 --- a/translator.py +++ b/translator.py @@ -1,6 +1,7 @@ #encoding: utf-8 import torch +from torch.cuda.amp import autocast from transformer.NMT import NMT from transformer.EnsembleNMT import NMT as Ensemble @@ -87,6 +88,7 @@ def __init__(self, modelfs, fvocab_i, fvocab_t, cnfg, minbsize=1, expand_for_mul model.to(self.cuda_device) if self.multi_gpu: model = DataParallelMT(model, device_ids=cuda_devices, 
output_device=self.cuda_device.index, host_replicate=True, gather_output=False) + self.use_amp = cnfg.use_amp and self.use_cuda self.beam_size = cnfg.beam_size @@ -99,7 +101,8 @@ def __call__(self, sentences_iter): for seq_batch in data_loader(sentences_iter, self.vcbi, self.minbsize, self.bsize, self.maxpad, self.maxpart, self.maxtoken): if self.use_cuda: seq_batch = seq_batch.to(self.cuda_device) - output = self.net.decode(seq_batch, self.beam_size, None, self.length_penalty) + with autocast(enabled=self.use_amp): + output = self.net.decode(seq_batch, self.beam_size, None, self.length_penalty) if self.multi_gpu: tmp = [] for ou in output: diff --git a/utils/aan.py b/utils/aan.py new file mode 100644 index 0000000..1fdc968 --- /dev/null +++ b/utils/aan.py @@ -0,0 +1,19 @@ +#encoding: utf-8 + +from torch.nn import ModuleList +from modules.base import AverageAttn + +def share_aan_cache(netin): + + rel_cache_d = {} + for net in netin.modules(): + if isinstance(net, ModuleList): + _cache = None + for layer in net.modules(): + if isinstance(layer, AverageAttn): + if _cache is None: + _cache = layer.w + else: + layer.w = _cache + + return netin diff --git a/utils/base.py b/utils/base.py index 531096f..115f751 100644 --- a/utils/base.py +++ b/utils/base.py @@ -18,14 +18,24 @@ secure_type_map = {torch.float16: torch.float64, torch.float32: torch.float64, torch.uint8: torch.int64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64} +def all_done_bool(stat, *inputs, **kwargs): + + return stat.all().item() + +def all_done_byte(stat, bsize, **kwargs): + + return stat.int().sum().item() == bsize + # handling torch.bool -if torch.__version__ < "1.2.0": - mask_tensor_type = torch.uint8 - nccl_type_map = None -else: +try: mask_tensor_type = torch.bool secure_type_map[mask_tensor_type] = torch.int64 nccl_type_map = {torch.bool:torch.uint8} + all_done = all_done_bool +except Exception as e: + mask_tensor_type = torch.uint8 + nccl_type_map = None + all_done = all_done_byte def pad_tensors(tensor_list, dim=-1): @@ -208,8 +218,9 @@ def set_random_seed(seed, set_cuda=False): torch.manual_seed(_rseed) if set_cuda: torch.cuda.manual_seed_all(_rseed) - # Make cudnn methods deterministic according to: https://github.com/OpenNMT/OpenNMT-py/blob/master/onmt/utils/misc.py#L80-L82 + # Make cudnn methods deterministic according to: https://pytorch.org/docs/stable/notes/randomness.html#cudnn torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False def repeat_bsize_for_beam_tensor(tin, beam_size): @@ -361,3 +372,11 @@ def iternext(iterin): rs = None return rs + +def optm_step(optm, scaler=None): + + if scaler is None: + optm.step() + else: + scaler.step(optm) + scaler.update() diff --git a/utils/dynbatch.py b/utils/dynbatch.py index e033bd4..25e2b69 100644 --- a/utils/dynbatch.py +++ b/utils/dynbatch.py @@ -56,7 +56,7 @@ def cos_acc_pg(old_pg, new_pg): return sim -class EffRecoder: +class EffRecorder: def __init__(self, num_choice, num_his=50, init_value=180.0): @@ -81,15 +81,15 @@ class GradientMonitor: # select_func: a function takes (model, index) as input arguments, which returns the index parameter group of model. # angle_alpha: the alpha value, if the angle change is greater than or equal to the multiplication of the minimum value in the history and the alpha, this class will return True to require performing an optimization step. 
 	# num_tol_amin: number of tolerant steps after the minimum angle change, if fails to obtain a smaller angle change after this number of steps, will return True to ask performing an optimization step.
-	# num_his_recoder: number of records of the angle change reduction.
+	# num_his_record: number of records of the angle change reduction.
 	# num_his_gm: cache num_his_gm gradients into a history, and return this number of angle changes.
 	# returns: (update_r, angle_r), update_r: to performing an optimization step, angle_r: the angle change in current step.
 
-	def __init__(self, num_group, select_func, module=None, angle_alpha=1.1, num_tol_amin=3, num_his_recoder=50, num_his_gm=1):
+	def __init__(self, num_group, select_func, module=None, angle_alpha=1.1, num_tol_amin=3, num_his_record=50, num_his_gm=1):
 
 		self.scale = 180.0 / pi
 		self.num_group = num_group
-		self.recorder = EffRecoder(num_group, num_his=num_his_recoder, init_value=1.0)#init_value=180.0 if use sample_gumble_norm in self.reset
+		self.recorder = EffRecorder(num_group, num_his=num_his_record, init_value=1.0)#init_value=180.0 if use sample_gumble_norm in self.reset
 		self.select_func = select_func
 		self.module = module
 		self.alpha, self.num_tol_amin, self.num_his = angle_alpha, num_tol_amin, num_his_gm
diff --git a/utils/fmt/base.py b/utils/fmt/base.py
index 0512f93..10e09fc 100644
--- a/utils/fmt/base.py
+++ b/utils/fmt/base.py
@@ -191,7 +191,7 @@ def get_bsize(maxlen, maxtoken, maxbsize):
 
 	return min(rs, maxbsize)
 
-def no_unk_mapper(vcb, ltm, prompt=True):
+def no_unk_mapper(vcb, ltm, prompt=False):
 
 	if prompt:
 		rs = []
diff --git a/utils/fmt/base4torch.py b/utils/fmt/base4torch.py
index 6bd6aef..187648f 100644
--- a/utils/fmt/base4torch.py
+++ b/utils/fmt/base4torch.py
@@ -19,12 +19,8 @@ def parse_cuda(use_cuda_arg, gpuid):
 			cuda_devices = None
 			multi_gpu = False
 		torch.cuda.set_device(cuda_device.index)
-		#torch.backends.cudnn.benchmark = True
 	else:
-		use_cuda = False
-		cuda_device = False
-		cuda_devices = None
-		multi_gpu = False
+		use_cuda, cuda_device, cuda_devices, multi_gpu = False, False, None, False
 
 	return use_cuda, cuda_device, cuda_devices, multi_gpu
 
@@ -46,12 +42,8 @@ def parse_cuda_decode(use_cuda_arg, gpuid, multi_gpu_decoding):
 			cuda_devices = None
 			multi_gpu = False
 		torch.cuda.set_device(cuda_device.index)
-		#torch.backends.cudnn.benchmark = True
 	else:
-		use_cuda = False
-		cuda_device = False
-		cuda_devices = None
-		multi_gpu = False
+		use_cuda, cuda_device, cuda_devices, multi_gpu = False, False, None, False
 
 	return use_cuda, cuda_device, cuda_devices, multi_gpu
diff --git a/utils/sampler.py b/utils/sampler.py
new file mode 100644
index 0000000..9202a78
--- /dev/null
+++ b/utils/sampler.py
@@ -0,0 +1,10 @@
+#encoding: utf-8
+
+def SampleMax(input, dim=-1, keepdim=False):
+
+	_ics = input.cumsum(dim)
+	isize = list(input.size())
+	isize[dim] = 1
+	_sv = input.new_empty(isize).uniform_(0.0, 1.0)
+
+	return _ics.ge(_sv).int().cumsum(dim).eq(1).int().argmax(dim=dim, keepdim=keepdim)
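
A short usage sketch for two of the new helpers (an illustration only, not code from this repository): SampleMax in utils/sampler.py draws one index per position from a probability tensor by inverse-CDF sampling over its cumulative sum, which is why the updated greedy_decode methods act as ancestral samplers when called with sample=True; optm_step in utils/base.py routes the parameter update through a torch.cuda.amp GradScaler when one is supplied and falls back to a plain optimizer.step() otherwise. The toy model, the tensor values and the scaler.scale(loss).backward() recipe below are assumptions based on the standard AMP workflow rather than on code shown in this patch, and the AMP part assumes a CUDA device.

    #encoding: utf-8
    # illustrative sketch: SampleMax as a categorical sampler, optm_step with a GradScaler

    import torch
    from torch import nn
    from torch.cuda.amp import GradScaler, autocast

    from utils.sampler import SampleMax
    from utils.base import optm_step

    # SampleMax: inverse-CDF sampling from probabilities shaped like classifier(out).softmax(-1)
    probs = torch.tensor([[[0.1, 0.2, 0.7]]])
    counts = torch.zeros(3)
    for _ in range(1000):
        counts[SampleMax(probs, dim=-1, keepdim=False).item()] += 1
    # counts / counts.sum() should be close to (0.1, 0.2, 0.7)

    # optm_step: a scaled update when a GradScaler is given, a plain optimizer.step() when scaler is None
    model = nn.Linear(4, 4).cuda()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    scaler = GradScaler()
    with autocast():
        loss = model(torch.randn(8, 4, device="cuda")).sum()
    scaler.scale(loss).backward()
    optm_step(optimizer, scaler)
    optimizer.zero_grad()

Since the sample=True path takes out.softmax(-1) directly (the decoders now skip self.lsm on this path, per the "omit self.lsm for efficiency" comments), the greedy decoders gain the sampling mode flagged "for back translation" without changing their default argmax behaviour.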