diff --git a/adv/predict/doc/para/predict_doc_para.py b/adv/predict/doc/para/predict_doc_para.py index 9b92749..38f89d6 100644 --- a/adv/predict/doc/para/predict_doc_para.py +++ b/adv/predict/doc/para/predict_doc_para.py @@ -22,7 +22,7 @@ def load_fixing(module): - if "fix_load" in dir(module): + if hasattr(module, "fix_load"): module.fix_load() td = h5py.File(cnfg.test_data, "r") diff --git a/adv/predict/predict_ape.py b/adv/predict/predict_ape.py index d296392..5c969d7 100644 --- a/adv/predict/predict_ape.py +++ b/adv/predict/predict_ape.py @@ -22,7 +22,7 @@ def load_fixing(module): - if "fix_load" in dir(module): + if hasattr(module, "fix_load"): module.fix_load() td = h5py.File(cnfg.test_data, "r") diff --git a/adv/rank/doc/para/rank_loss_para.py b/adv/rank/doc/para/rank_loss_para.py index 03743f2..a3bd653 100644 --- a/adv/rank/doc/para/rank_loss_para.py +++ b/adv/rank/doc/para/rank_loss_para.py @@ -29,7 +29,7 @@ def load_fixing(module): - if "fix_load" in dir(module): + if hasattr(module, "fix_load"): module.fix_load() td = h5py.File(sys.argv[2], "r") diff --git a/adv/rank/doc/rank_loss_sent.py b/adv/rank/doc/rank_loss_sent.py index f1c8494..a126bf7 100644 --- a/adv/rank/doc/rank_loss_sent.py +++ b/adv/rank/doc/rank_loss_sent.py @@ -29,7 +29,7 @@ def load_fixing(module): - if "fix_load" in dir(module): + if hasattr(module, "fix_load"): module.fix_load() td = h5py.File(sys.argv[2], "r") diff --git a/adv/train/doc/para/train_doc_para.py b/adv/train/doc/para/train_doc_para.py index 38533c1..3e2b21d 100644 --- a/adv/train/doc/para/train_doc_para.py +++ b/adv/train/doc/para/train_doc_para.py @@ -14,7 +14,7 @@ from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb -from lrsch import GoogleLR +from lrsch import GoogleLR as LRScheduler from loss.base import LabelSmoothingLoss from random import shuffle @@ -176,7 +176,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): def 
init_fixing(module): - if "fix_init" in dir(module): + if hasattr(module, "fix_init"): module.fix_init() rid = cnfg.run_id @@ -280,7 +280,7 @@ def init_fixing(module): logger.info("Load optimizer state from: " + fine_tune_state) optimizer.load_state_dict(h5load(fine_tune_state)) -lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) num_checkpoint = cnfg.num_checkpoint cur_checkid = 0 diff --git a/adv/train/train_ape.py b/adv/train/train_ape.py index 3cb2e66..0d35118 100644 --- a/adv/train/train_ape.py +++ b/adv/train/train_ape.py @@ -14,7 +14,7 @@ from utils.fmt.base import tostr, save_states, load_states, pad_id from utils.fmt.base4torch import parse_cuda, load_emb -from lrsch import GoogleLR +from lrsch import GoogleLR as LRScheduler from loss.base import LabelSmoothingLoss from random import shuffle @@ -174,7 +174,7 @@ def hook_lr_update(optm, flags=None): def init_fixing(module): - if "fix_init" in dir(module): + if hasattr(module, "fix_init"): module.fix_init() rid = cnfg.run_id @@ -270,7 +270,7 @@ def init_fixing(module): logger.info("Load optimizer state from: " + fine_tune_state) optimizer.load_state_dict(h5load(fine_tune_state)) -lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) num_checkpoint = cnfg.num_checkpoint cur_checkid = 0 diff --git a/adv/train/train_dynb.py b/adv/train/train_dynb.py index 3636986..d800e8c 100644 --- a/adv/train/train_dynb.py +++ b/adv/train/train_dynb.py @@ -16,7 +16,7 @@ from utils.fmt.base4torch import parse_cuda, load_emb -from lrsch import GoogleLR +from lrsch import GoogleLR as LRScheduler from loss.base import LabelSmoothingLoss from random import shuffle @@ -195,7 +195,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): def init_fixing(module): - if "fix_init" in dir(module): + if 
hasattr(module, "fix_init"): module.fix_init() rid = cnfg.run_id @@ -291,7 +291,7 @@ def init_fixing(module): logger.info("Load optimizer state from: " + fine_tune_state) optimizer.load_state_dict(h5load(fine_tune_state)) -lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) num_checkpoint = cnfg.num_checkpoint cur_checkid = 0 diff --git a/cnfg/hyp.py b/cnfg/hyp.py index d23f27d..782dec0 100644 --- a/cnfg/hyp.py +++ b/cnfg/hyp.py @@ -17,6 +17,9 @@ use_k_relative_position = 0 disable_std_pemb = False +# using fast implementation of label smoothing loss, but it cannot exclude the negative impact of special tokens, like <pad>, on training. `forbidden_indexes` in `cnfg/base.py` shall be set to None to enable. +use_fast_loss = False + # configure maximum batch size w.r.t GPU memory max_sentences_gpu = 768 max_tokens_gpu = 4608 diff --git a/loss/base.py b/loss/base.py index 31d2338..3accf3f 100644 --- a/loss/base.py +++ b/loss/base.py @@ -6,7 +6,32 @@ from torch.nn.functional import kl_div, nll_loss -from utils.base import clear_pad_mask +from utils.base import clear_pad_mask, eq_indexes + +from cnfg.ihyp import * + +# Faster implementation from fairseq: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py#L33-L50, but does not support fbil. 
+def fast_label_smoothing_loss(input, target, ignore_index, conf, smoothing_value, reduction): + + _target = target.unsqueeze(-1) + nll_loss = -input.gather(dim=-1, index=_target) + smooth_loss = -input.sum(dim=-1, keepdim=True) + if isinstance(ignore_index, (list, tuple)): + pad_mask = eq_indexes(_target, ignore_index) + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + elif ignore_index >= 0: + pad_mask = _target == ignore_index + nll_loss.masked_fill_(pad_mask, 0.0) + smooth_loss.masked_fill_(pad_mask, 0.0) + if reduction != "none": + nll_loss = nll_loss.sum() + smooth_loss = smooth_loss.sum() + loss = conf * nll_loss + smoothing_value * smooth_loss + if reduction == "mean": + loss = loss / float(target.numel()) + + return loss """ from: Rethinking the Inception Architecture for Computer Vision (https://arxiv.org/abs/1512.00567) With label smoothing, KL-divergence between q_{smoothed ground truth prob.}(w) and p_{prob. computed by model}(w) is minimized. 
@@ -14,19 +39,14 @@ class LabelSmoothingLoss(_Loss): - def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction='mean', forbidden_index=-1): + # enable fast_mode will ignore forbidden_index + def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction='mean', forbidden_index=-1, fast_mode=use_fast_loss): super(LabelSmoothingLoss, self).__init__() - fbil = set() - if isinstance(forbidden_index, (list, tuple)): - for fi in forbidden_index: - if (fi >= 0) and (fi not in fbil): - fbil.add(fi) - else: - if forbidden_index is not None and forbidden_index >= 0: - fbil.add(forbidden_index) + self.fast_mode, self.reduction = fast_mode, reduction + fbil = set() if isinstance(ignore_index, (list, tuple)): tmp = [] for _tmp in ignore_index: @@ -36,10 +56,7 @@ def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction='mean fbil.add(_tmp) _nid = len(tmp) if _nid > 0: - if _nid > 1: - self.ignore_index = tuple(tmp) - else: - self.ignore_index = tmp[0] + self.ignore_index = tuple(tmp) if _nid > 1 else tmp[0] else: self.ignore_index = ignore_index[0] if len(ignore_index) > 0 else -1 else: @@ -47,13 +64,24 @@ def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction='mean if (ignore_index >= 0) and (ignore_index not in fbil): fbil.add(ignore_index) - smoothing_value = label_smoothing / (nclass - 1 - len(fbil)) - weight = torch.full((nclass,), smoothing_value) - weight.index_fill_(0, torch.tensor(tuple(fbil), dtype=torch.long, device=weight.device), 0.0) - self.register_buffer("weight", weight.unsqueeze(0)) + if fast_mode: + self.smoothing_value = label_smoothing / (nclass - 1) + self.conf = 1.0 - label_smoothing - self.smoothing_value + else: + if isinstance(forbidden_index, (list, tuple)): + for fi in forbidden_index: + if (fi >= 0) and (fi not in fbil): + fbil.add(fi) + else: + if forbidden_index is not None and forbidden_index >= 0: + fbil.add(forbidden_index) + + smoothing_value = label_smoothing / (nclass - 1 - 
len(fbil)) - self.reduction = reduction - self.conf = 1.0 - label_smoothing + weight = torch.full((nclass,), smoothing_value) + weight.index_fill_(0, torch.tensor(tuple(fbil), dtype=torch.long, device=weight.device), 0.0) + self.register_buffer("weight", weight.unsqueeze(0)) + self.conf = 1.0 - label_smoothing # input: (batch size, num_classes) # target: (batch size) @@ -61,21 +89,23 @@ def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction='mean def forward(self, input, target): - _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input - - _target = target.view(-1, 1) + if self.fast_mode: - model_prob = self.weight.repeat(_target.size(0), 1) - model_prob.scatter_(1, _target, self.conf) + return fast_label_smoothing_loss(input, target, self.ignore_index, self.conf, self.smoothing_value, self.reduction) + else: + _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input + _target = target.view(-1, 1) + model_prob = self.weight.repeat(_target.size(0), 1) + model_prob.scatter_(1, _target, self.conf) - if isinstance(self.ignore_index, (list, tuple)): - model_prob.masked_fill_(torch.stack([_target == _tmp for _tmp in self.ignore_index]).int().sum(0).gt(0), 0.0) - elif self.ignore_index >= 0: - model_prob.masked_fill_(_target == self.ignore_index, 0.0) + if isinstance(self.ignore_index, (list, tuple)): + model_prob.masked_fill_(eq_indexes(_target, self.ignore_index), 0.0) + elif self.ignore_index >= 0: + model_prob.masked_fill_(_target == self.ignore_index, 0.0) - rs = kl_div(_input, model_prob, reduction=self.reduction) + rs = kl_div(_input, model_prob, reduction=self.reduction) - return rs.view(input.size()) if self.reduction == 'none' and target.dim() > 1 else rs + return rs.view(input.size()) if self.reduction == 'none' and target.dim() > 1 else rs class NLLLoss(NLLLossBase): @@ -99,80 +129,82 @@ def forward(self, input, target): class MultiLabelSmoothingLoss(_Loss): - def __init__(self, nclass, label_smoothing=0.1, 
ignore_index=-1, reduction='mean', forbidden_index=-1): + def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction='mean', forbidden_index=-1, fast_mode=use_fast_loss): super(MultiLabelSmoothingLoss, self).__init__() - fbil = [] - for fbilu in forbidden_index: - tmp = set() - if isinstance(fbilu, (list, tuple)): - for fi in fbilu: - if (fi >= 0) and (fi not in tmp): - tmp.add(fi) - else: - if fbilu is not None and fbilu >= 0: - tmp.add(forbidden_index) - fbil.append(tmp) + self.fast_mode, self.reduction = fast_mode, reduction + fbil_common = set() if isinstance(ignore_index, (list, tuple)): tmp = [] for _tmp in ignore_index: if (_tmp >= 0) and (_tmp not in tmp): tmp.append(_tmp) - for fbilu in fbil: - if _tmp not in fbilu: - fbilu.add(_tmp) + if _tmp not in fbil_common: + fbil_common.add(_tmp) _nid = len(tmp) if _nid > 0: - if _nid > 1: - self.ignore_index = tuple(tmp) - else: - self.ignore_index = tmp[0] + self.ignore_index = tuple(tmp) if _nid > 1 else tmp[0] else: self.ignore_index = ignore_index[0] if len(ignore_index) > 0 else -1 else: self.ignore_index = ignore_index - if (ignore_index >= 0): - for fbilu in fbil: - if ignore_index not in fbilu: - fbilu.add(ignore_index) - - _weight = [] - for fbilu in fbil: - smoothing_value = label_smoothing / (nclass - 1 - len(fbilu)) - _tmp_w = torch.full((nclass,), smoothing_value) - _tmp_w.index_fill_(0, torch.tensor(tuple(fbilu), dtype=torch.long, device=_tmp_w.device), 0.0) - _weight.append(_tmp_w) - self.register_buffer("weight", torch.stack(_weight, 0).unsqueeze(1)) + if (ignore_index >= 0) and (ignore_index not in fbil_common): + fbil_common.add(ignore_index) - self.reduction = reduction - - self.conf = 1.0 - label_smoothing + if fast_mode: + self.smoothing_value = label_smoothing / (nclass - 1) + self.conf = 1.0 - label_smoothing - self.smoothing_value + else: + fbil = [] + for fbilu in forbidden_index: + tmp = set() + if isinstance(fbilu, (list, tuple)): + for fi in fbilu: + if (fi >= 0) and (fi 
not in tmp): + tmp.add(fi) + else: + if fbilu is not None and fbilu >= 0: + tmp.add(forbidden_index) + tmp |= fbil_common + fbil.append(tmp) + + _weight = [] + for fbilu in fbil: + smoothing_value = label_smoothing / (nclass - 1 - len(fbilu)) + _tmp_w = torch.full((nclass,), smoothing_value) + _tmp_w.index_fill_(0, torch.tensor(tuple(fbilu), dtype=torch.long, device=_tmp_w.device), 0.0) + _weight.append(_tmp_w) + self.register_buffer("weight", torch.stack(_weight, 0).unsqueeze(1)) + self.conf = 1.0 - label_smoothing def forward(self, input, target, lang_id=0): - _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input + if self.fast_mode: - _target = target.view(-1, 1) + return fast_label_smoothing_loss(input, target, self.ignore_index, self.conf, self.smoothing_value, self.reduction) + else: + _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input + _target = target.view(-1, 1) - model_prob = self.weight[lang_id].repeat(_target.size(0), 1) - model_prob.scatter_(1, _target, self.conf) + model_prob = self.weight[lang_id].repeat(_target.size(0), 1) + model_prob.scatter_(1, _target, self.conf) - if isinstance(self.ignore_index, (list, tuple)): - model_prob.masked_fill_(torch.stack([_target == _tmp for _tmp in self.ignore_index]).int().sum(0).gt(0), 0.0) - elif self.ignore_index >= 0: - model_prob.masked_fill_(_target == self.ignore_index, 0.0) + if isinstance(self.ignore_index, (list, tuple)): + model_prob.masked_fill_(eq_indexes(_target, self.ignore_index), 0.0) + elif self.ignore_index >= 0: + model_prob.masked_fill_(_target == self.ignore_index, 0.0) - rs = kl_div(_input, model_prob, reduction=self.reduction) + rs = kl_div(_input, model_prob, reduction=self.reduction) - return rs.view(input.size()) if self.reduction == 'none' and target.dim() > 1 else rs + return rs.view(input.size()) if self.reduction == 'none' and target.dim() > 1 else rs class ReducedLabelSmoothingLoss(LabelSmoothingLoss): - def __init__(self, nclass, 
label_smoothing=0.1, ignore_index=-1, reduction='mean', forbidden_index=-1, reduce_dim=None): + def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction='mean', forbidden_index=-1, fast_mode=use_fast_loss, reduce_dim=None): - super(ReducedLabelSmoothingLoss, self).__init__(nclass, label_smoothing, ignore_index, reduction, forbidden_index) + super(ReducedLabelSmoothingLoss, self).__init__(nclass, label_smoothing=label_smoothing, ignore_index=ignore_index, reduction=reduction, forbidden_index=forbidden_index, fast_mode=fast_mode) self.reduce_dim = reduce_dim @@ -181,18 +213,21 @@ def forward(self, input, target): if self.reduce_dim is not None: input, target = clear_pad_mask([input, target], target.eq(0), [self.reduce_dim - 1, self.reduce_dim], mask_dim=self.reduce_dim, return_contiguous=True)[0] - _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input + if self.fast_mode: - _target = target.view(-1, 1) + return fast_label_smoothing_loss(input, target, self.ignore_index, self.conf, self.smoothing_value, self.reduction) + else: + _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input + _target = target.view(-1, 1) - model_prob = self.weight.repeat(_target.size(0), 1) - model_prob.scatter_(1, _target, self.conf) + model_prob = self.weight.repeat(_target.size(0), 1) + model_prob.scatter_(1, _target, self.conf) - if isinstance(self.ignore_index, (list, tuple)): - model_prob.masked_fill_(torch.stack([_target == _tmp for _tmp in self.ignore_index]).int().sum(0).gt(0), 0.0) - elif self.ignore_index >= 0: - model_prob.masked_fill_(_target == self.ignore_index, 0.0) + if isinstance(self.ignore_index, (list, tuple)): + model_prob.masked_fill_(eq_indexes(_target, self.ignore_index), 0.0) + elif self.ignore_index >= 0: + model_prob.masked_fill_(_target == self.ignore_index, 0.0) - rs = kl_div(_input, model_prob, reduction=self.reduction) + rs = kl_div(_input, model_prob, reduction=self.reduction) - return rs.view(input.size()) if 
self.reduction == 'none' and target.dim() > 1 else rs + return rs.view(input.size()) if self.reduction == 'none' and target.dim() > 1 else rs diff --git a/lrsch.py b/lrsch.py index b4b6669..7cfb788 100644 --- a/lrsch.py +++ b/lrsch.py @@ -7,22 +7,44 @@ class GoogleLR(_LRScheduler): def __init__(self, optimizer, dmodel, warm_steps, scale=1.0, last_epoch=-1): - self.cur_step = 0 - self.k = 1.0 / sqrt(dmodel) - self.wk = 1.0 / sqrt(warm_steps) / warm_steps - self.warm_steps = warm_steps - self.scale = scale + self.cur_step, self.warm_steps = 0, warm_steps + self.k = scale / sqrt(dmodel) + self.wk = self.k / sqrt(warm_steps) / warm_steps super(GoogleLR, self).__init__(optimizer, last_epoch) def get_lr(self): self.cur_step += 1 - cur_lr = self.k * ((self.cur_step * self.wk) if self.cur_step <= self.warm_steps else (1.0 / sqrt(self.cur_step))) - if self.scale != 1.0: - cur_lr *= self.scale + cur_lr = (self.cur_step * self.wk) if self.cur_step < self.warm_steps else (self.k / sqrt(self.cur_step)) + return [cur_lr for i in range(len(self.base_lrs))] -class ReverseSqrtLR(_LRScheduler): +# inverse square root with warm up, portal from: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py, equal to GoogleLR when warm_end_lr = 1.0 / sqrt(dmodel * warm_steps) +class WarmUpInverseSqrtLR(_LRScheduler): + + def __init__(self, optimizer, warm_end_lr, warm_steps, warm_init_lr=0.0, last_epoch=-1): + + self.cur_step, self.warm_end_lr, self.warm_steps, self.warm_init_lr = 0, warm_end_lr, warm_steps, warm_init_lr + self.lr_step = (warm_end_lr - warm_init_lr) / warm_steps + self.decay_factor = warm_end_lr * sqrt(warm_steps) + + super(WarmUpInverseSqrtLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + + self.cur_step += 1 + cur_lr = (self.warm_init_lr + self.cur_step * self.lr_step) if self.cur_step < self.warm_steps else (self.decay_factor / sqrt(self.cur_step)) + + return [cur_lr for i in range(len(self.base_lrs))] + 
+''' +class GoogleLR(WarmUpInverseSqrtLR): + + def __init__(self, optimizer, dmodel, warm_steps, scale=1.0, last_epoch=-1): + + super(GoogleLR, self).__init__(optimizer, scale / sqrt(dmodel * warm_steps), warm_steps, warm_init_lr=0.0, last_epoch=last_epoch)''' + +class InverseSqrtLR(_LRScheduler): def __init__(self, optimizer, lr=1e-4, scalar=1.0, min_lr=None, last_epoch=-1): @@ -30,10 +52,11 @@ def __init__(self, optimizer, lr=1e-4, scalar=1.0, min_lr=None, last_epoch=-1): self.base_lr = lr self.epoch_steps = scalar self.min_lr = (lr / 512.0) if min_lr is None else min_lr - super(ReverseSqrtLR, self).__init__(optimizer, last_epoch) + super(InverseSqrtLR, self).__init__(optimizer, last_epoch) def get_lr(self): self.cur_step += 1 cur_lr = max(min(1.0, 1.0 / sqrt(self.cur_step / self.epoch_steps)), self.min_lr) * self.base_lr + return [cur_lr for i in range(len(self.base_lrs))] diff --git a/mkcy.py b/mkcy.py index bc17a22..2159175 100644 --- a/mkcy.py +++ b/mkcy.py @@ -17,7 +17,7 @@ def legal(pname, fbl): rs = True for pyst, pyf in fbl: - if pname.startswith(pyst) or pname==pyf: + if pname.startswith(pyst) or pname == pyf: rs = False break diff --git a/modules/base.py b/modules/base.py index 9b84011..f811571 100644 --- a/modules/base.py +++ b/modules/base.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from math import sqrt, log, exp, pi +from math import sqrt, log, exp import torch from torch import nn from torch.nn import functional as nnFunc @@ -9,7 +9,7 @@ from utils.base import reduce_model_list from modules.act import Custom_Act from modules.act import reduce_model as reduce_model_act -from modules.dropout import Dropout, TokenDropout, InfDropout +from modules.dropout import Dropout from modules.dropout import reduce_model as reduce_model_drop from cnfg.ihyp import * diff --git a/predict.py b/predict.py index 8fa6117..9231e54 100644 --- a/predict.py +++ b/predict.py @@ -22,7 +22,7 @@ def load_fixing(module): - if "fix_load" in dir(module): + if hasattr(module, 
"fix_load"): module.fix_load() td = h5py.File(cnfg.test_data, "r") diff --git a/rank_loss.py b/rank_loss.py index 8b464b6..ac4374d 100644 --- a/rank_loss.py +++ b/rank_loss.py @@ -29,7 +29,7 @@ def load_fixing(module): - if "fix_load" in dir(module): + if hasattr(module, "fix_load"): module.fix_load() td = h5py.File(sys.argv[2], "r") diff --git a/tools/check/dynb/report_dynb.py b/tools/check/dynb/report_dynb.py index 7874b4e..8cdb7a7 100644 --- a/tools/check/dynb/report_dynb.py +++ b/tools/check/dynb/report_dynb.py @@ -17,7 +17,7 @@ from utils.fmt.base import tostr, save_states, load_states, pad_id, parse_double_value_tuple from utils.fmt.base4torch import parse_cuda, load_emb -from lrsch import GoogleLR +from lrsch import GoogleLR as LRScheduler from loss.base import LabelSmoothingLoss from random import shuffle, random @@ -217,7 +217,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): def init_fixing(module): - if "fix_init" in dir(module): + if hasattr(module, "fix_init"): module.fix_init() rid = cnfg.run_id @@ -315,7 +315,7 @@ def init_fixing(module): logger.info("Load optimizer state from: " + fine_tune_state) optimizer.load_state_dict(h5load(fine_tune_state)) -lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) num_checkpoint = cnfg.num_checkpoint cur_checkid = 0 diff --git a/tools/check/tspeed.py b/tools/check/tspeed.py index 0afa4a8..dc57da0 100644 --- a/tools/check/tspeed.py +++ b/tools/check/tspeed.py @@ -19,7 +19,7 @@ def load_fixing(module): - if "fix_load" in dir(module): + if hasattr(module, "fix_load"): module.fix_load() td = h5py.File(cnfg.dev_data, "r") diff --git a/tools/doc/para/mkiodata.py b/tools/doc/para/mkiodata.py index f778aad..a34727c 100644 --- a/tools/doc/para/mkiodata.py +++ b/tools/doc/para/mkiodata.py @@ -11,8 +11,8 @@ from cnfg.ihyp import * def handle(finput, ftarget, fvocab_i, fvocab_t, frs, minbsize=1, 
expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False): - vcbi, nwordi = ldvocab(fvocab_i, minfreq, vsize) - vcbt, nwordt = ldvocab(fvocab_t, minfreq, vsize) + vcbi, nwordi = ldvocab(fvocab_i, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbt, nwordt = ldvocab(fvocab_t, minf=minfreq, omit_vsize=vsize, vanilla=False) if expand_for_mulgpu: _bsize = bsize * minbsize _maxtoken = maxtoken * minbsize @@ -24,8 +24,8 @@ def handle(finput, ftarget, fvocab_i, fvocab_t, frs, minbsize=1, expand_for_mulg tgt_grp = rsf.create_group("tgt") curd = {} for i_d, td, nsent in batch_padder(finput, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): - rid = numpy.array(i_d, dtype = numpy.int32) - rtd = numpy.array(td, dtype = numpy.int32) + rid = numpy.array(i_d, dtype=numpy.int32) + rtd = numpy.array(td, dtype=numpy.int32) _nsentgid = str(nsent) _curd = curd.get(nsent, 0) if _curd == 0: @@ -36,11 +36,11 @@ def handle(finput, ftarget, fvocab_i, fvocab_t, frs, minbsize=1, expand_for_mulg tgt_grp[_nsentgid].create_dataset(_curid, data=rtd, **h5datawargs) curd[nsent] = _curd + 1 sents, ndl = dict2pairs(curd) - rsf["nsent"] = numpy.array(sents, dtype = numpy.int32) - rsf["ndata"] = numpy.array(ndl, dtype = numpy.int32) - rsf["nword"] = numpy.array([nwordi, nwordt], dtype = numpy.int32) + rsf["nsent"] = numpy.array(sents, dtype=numpy.int32) + rsf["ndata"] = numpy.array(ndl, dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi, nwordt], dtype=numpy.int32) rsf.close() - print("Number of batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d" % (sum(ndl), nwordi, nwordt)) + print("Number of batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d" % (sum(ndl), nwordi, nwordt,)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], int(sys.argv[6])) diff --git a/tools/doc/para/mktest.py 
b/tools/doc/para/mktest.py index 86c1fa0..c0548b2 100644 --- a/tools/doc/para/mktest.py +++ b/tools/doc/para/mktest.py @@ -11,7 +11,7 @@ from cnfg.ihyp import * def handle(finput, fvocab_i, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False): - vcbi, nwordi = ldvocab(fvocab_i, minfreq, vsize) + vcbi, nwordi = ldvocab(fvocab_i, minf=minfreq, omit_vsize=vsize, vanilla=False) if expand_for_mulgpu: _bsize = bsize * minbsize _maxtoken = maxtoken * minbsize @@ -22,7 +22,7 @@ def handle(finput, fvocab_i, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_ src_grp = rsf.create_group("src") curd = {} for i_d, nsent in batch_padder(finput, vcbi, _bsize, maxpad, maxpart, _maxtoken, minbsize): - rid = numpy.array(i_d, dtype = numpy.int32) + rid = numpy.array(i_d, dtype=numpy.int32) _nsentgid = str(nsent) _curd = curd.get(nsent, 0) if _curd == 0: @@ -30,11 +30,11 @@ def handle(finput, fvocab_i, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_ src_grp[_nsentgid].create_dataset(str(_curd), data=rid, **h5datawargs) curd[nsent] = _curd + 1 sents, ndl = dict2pairs(curd) - rsf["nsent"] = numpy.array(sents, dtype = numpy.int32) - rsf["ndata"] = numpy.array(ndl, dtype = numpy.int32) - rsf["nword"] = numpy.array([nwordi], dtype = numpy.int32) + rsf["nsent"] = numpy.array(sents, dtype=numpy.int32) + rsf["ndata"] = numpy.array(ndl, dtype=numpy.int32) + rsf["nword"] = numpy.array([nwordi], dtype=numpy.int32) rsf.close() - print("Number of batches: %d\nSource Vocabulary Size: %d" % (sum(ndl), nwordi)) + print("Number of batches: %d\nSource Vocabulary Size: %d" % (sum(ndl), nwordi,)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3], int(sys.argv[4])) diff --git a/train.py b/train.py index e56c090..7e873c8 100644 --- a/train.py +++ b/train.py @@ -15,7 +15,7 @@ from utils.fmt.base import tostr, save_states, load_states, pad_id from 
utils.fmt.base4torch import parse_cuda, load_emb -from lrsch import GoogleLR +from lrsch import GoogleLR as LRScheduler from loss.base import LabelSmoothingLoss from random import shuffle @@ -174,7 +174,7 @@ def hook_lr_update(optm, flags=None): def init_fixing(module): - if "fix_init" in dir(module): + if hasattr(module, "fix_init"): module.fix_init() rid = cnfg.run_id @@ -258,7 +258,7 @@ def init_fixing(module): mymodel.to(cuda_device) lossf.to(cuda_device) -# lr will be over written by GoogleLR before used +# lr will be over written by LRScheduler before used optimizer = Optimizer(mymodel.parameters(), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) optimizer.zero_grad(set_to_none=True) @@ -275,7 +275,7 @@ def init_fixing(module): logger.info("Load optimizer state from: " + fine_tune_state) optimizer.load_state_dict(h5load(fine_tune_state)) -lrsch = GoogleLR(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) #lrsch.step() num_checkpoint = cnfg.num_checkpoint diff --git a/transformer/Encoder.py b/transformer/Encoder.py index 271cbd6..ffc1162 100644 --- a/transformer/Encoder.py +++ b/transformer/Encoder.py @@ -128,7 +128,7 @@ def load_base(self, base_encoder): def fix_init(self): - if "fix_load" in dir(self): + if hasattr(self, "fix_load"): self.fix_load() with torch.no_grad(): self.wemb.weight[pad_id].zero_() diff --git a/transformer/EnsembleNMT.py b/transformer/EnsembleNMT.py index 8a8b2f4..7f61de2 100644 --- a/transformer/EnsembleNMT.py +++ b/transformer/EnsembleNMT.py @@ -2,6 +2,9 @@ import torch from torch import nn + +from utils.base import all_done + from transformer.EnsembleEncoder import Encoder # import Decoder from transformer.AGG.Ensemble implementation or transformer.AGG.Ensemble implementation to enable feature combination between layers @@ -73,9 +76,9 @@ def train_greedy_decode(self, inpute, 
mask=None, max_len=512): out = torch.cat((out, wds), -1) # done_trans: (bsize) - done_trans = wds.squeeze(1).eq(2) if done_trans is None else (done_trans + wds.squeeze(1).eq(2)).gt(0) + done_trans = wds.squeeze(1).eq(2) if done_trans is None else (done_trans | wds.squeeze(1).eq(2)) - if done_trans.sum().item() == bsize: + if all_done(done_trans, bsize): break return out.narrow(1, 1, out.size(1) - 1) @@ -142,7 +145,7 @@ def train_beam_decode(self, inpute, mask=None, beam_size=8, max_len=512, length_ out = torch.cat((out.index_select(0, _inds), wds), -1) # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) if done_trans is None else (done_trans.view(real_bsize).index_select(0, _inds) + wds.view(real_bsize).eq(2)).gt(0).view(bsize, beam_size) + done_trans = wds.view(bsize, beam_size).eq(2) if done_trans is None else (done_trans.view(real_bsize).index_select(0, _inds) | wds.view(real_bsize).eq(2)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) @@ -151,12 +154,12 @@ def train_beam_decode(self, inpute, mask=None, beam_size=8, max_len=512, length_ _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True # check beam states(done or not) - if _done or (done_trans.sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break out = out.narrow(1, 1, out.size(1) - 1) diff --git a/transformer/NMT.py b/transformer/NMT.py index 212183c..d2edf8a 100644 --- a/transformer/NMT.py +++ b/transformer/NMT.py @@ -3,6 +3,7 @@ import torch from torch import nn +from utils.base import all_done from utils.relpos import share_rel_pos_cache from utils.fmt.base import parse_double_value_tuple @@ -57,11 +58,11 @@ def forward(self, inpute, inputo, mask=None): def load_base(self, base_nmt): - if "load_base" in dir(self.enc): + if 
hasattr(self.enc, "load_base"): self.enc.load_base(base_nmt.enc) else: self.enc = base_nmt.enc - if "load_base" in dir(self.dec): + if hasattr(self.dec, "load_base"): self.dec.load_base(base_nmt.dec) else: self.dec = base_nmt.dec @@ -109,9 +110,9 @@ def train_greedy_decode(self, inpute, mask=None, max_len=512): out = torch.cat((out, wds), -1) # done_trans: (bsize) - done_trans = wds.squeeze(1).eq(2) if done_trans is None else (done_trans + wds.squeeze(1).eq(2)).gt(0) + done_trans = wds.squeeze(1).eq(2) if done_trans is None else (done_trans | wds.squeeze(1).eq(2)) - if done_trans.sum().item() == bsize: + if all_done(done_trans, bsize): break return out.narrow(1, 1, out.size(1) - 1) @@ -178,7 +179,7 @@ def train_beam_decode(self, inpute, mask=None, beam_size=8, max_len=512, length_ out = torch.cat((out.index_select(0, _inds), wds), -1) # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) if done_trans is None else (done_trans.view(real_bsize).index_select(0, _inds) + wds.view(real_bsize).eq(2)).gt(0).view(bsize, beam_size) + done_trans = wds.view(bsize, beam_size).eq(2) if done_trans is None else (done_trans.view(real_bsize).index_select(0, _inds) | wds.view(real_bsize).eq(2)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) @@ -187,12 +188,12 @@ def train_beam_decode(self, inpute, mask=None, beam_size=8, max_len=512, length_ _done = False if length_penalty > 0.0: lpv = lpv.index_select(0, _inds) - elif (not return_all) and done_trans.select(1, 0).sum().item() == bsize: + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): _done = True # check beam states(done or not) - if _done or (done_trans.sum().item() == real_bsize): + if _done or all_done(done_trans, real_bsize): break out = out.narrow(1, 1, out.size(1) - 1) diff --git a/translator.py b/translator.py index c2a7d24..66aa651 100644 --- a/translator.py +++ b/translator.py @@ -20,7 +20,7 @@ def data_loader(sentences_iter, vcbi, 
minbsize=1, bsize=768, maxpad=16, maxpart= yield torch.tensor(i_d, dtype=torch.long) def load_fixing(module): - if "fix_load" in dir(module): + if hasattr(module, "fix_load"): module.fix_load() def sorti(lin): diff --git a/utils/base.py b/utils/base.py index 1a964d4..d5c2596 100644 --- a/utils/base.py +++ b/utils/base.py @@ -107,7 +107,7 @@ def unfreeze_module(module): def unfreeze_fixing(mod): - if "fix_unfreeze" in dir(mod): + if hasattr(mod, "fix_unfreeze"): mod.fix_unfreeze() for p in module.parameters(): @@ -115,6 +115,16 @@ def unfreeze_fixing(mod): module.apply(unfreeze_fixing) +def eq_indexes(tensor, indexes): + + rs = None + for ind in indexes: + if rs is None: + rs = tensor.eq(ind) + else: + rs |= tensor.eq(ind) + return rs + def getlr(optm): lr = [] @@ -219,7 +229,7 @@ def _worker(model, fname, sub_module=False, logger=None, para_lock=None, log_suc def get_logger(fname): logger = logging.getLogger(__name__) - logger.setLevel(level = logging.INFO) + logger.setLevel(level=logging.INFO) handler = logging.FileHandler(fname) handler.setLevel(logging.INFO) formatter = logging.Formatter('[%(asctime)s %(levelname)s] %(message)s')