From 635c75c7d927f035ae1d45ad1f14f77d9877ac26 Mon Sep 17 00:00:00 2001 From: Hongfei Xu Date: Thu, 10 Aug 2023 13:48:20 +0800 Subject: [PATCH] August 2023 update --- README.md | 2 +- adv/eva/eva_probe.py | 29 +- adv/eva/prompt/roberta/eva_single.py | 39 +- adv/examples/plm/bart.py | 12 +- adv/examples/plm/bert.py | 12 +- adv/examples/plm/mbart.py | 47 ++ adv/examples/plm/roberta.py | 12 +- adv/examples/plm/t5.py | 13 +- adv/predict/plm/bart/predict.py | 56 +- adv/predict/plm/roberta/predict.py | 40 +- adv/predict/plm/roberta/predict_reg.py | 106 ++++ adv/predict/plm/t5/predict.py | 56 +- adv/predict/predict_ape.py | 32 +- adv/predict/predict_doc_para.py | 32 +- adv/predict/predict_mulang.py | 32 +- adv/predict/predict_probe_enc.py | 27 +- adv/rank/doc/para/rank_loss_para.py | 35 +- adv/rank/doc/rank_loss_sent.py | 35 +- adv/train/mulang/train_m2o.py | 65 ++- adv/train/mulang/train_mulang.py | 65 ++- adv/train/mulang/train_mulang_robt.py | 71 ++- .../roberta/{train_single.py => train.py} | 64 ++- adv/train/prompt/roberta/train_reg.py | 381 +++++++++++++ adv/train/train_ape.py | 61 ++- adv/train/train_doc_para.py | 67 ++- adv/train/train_dynb.py | 69 +-- adv/train/train_probe.py | 69 +-- cnfg/README.md | 17 +- cnfg/base.py | 3 + cnfg/hyp.py | 23 +- cnfg/ihyp.py | 13 +- cnfg/plm/bart/base.py | 1 + cnfg/plm/bart/hyp.py | 3 + cnfg/plm/bart/ihyp.py | 3 - cnfg/plm/bert/base.py | 1 + cnfg/plm/bert/hyp.py | 3 + cnfg/plm/bert/ihyp.py | 3 - cnfg/plm/mbart/base.py | 26 + cnfg/plm/mbart/hyp.py | 20 + cnfg/plm/mbart/ihyp.py | 17 + cnfg/plm/roberta/base.py | 2 + cnfg/plm/roberta/hyp.py | 3 + cnfg/plm/roberta/ihyp.py | 3 - cnfg/plm/t5/base.py | 1 + cnfg/plm/t5/hyp.py | 12 +- cnfg/plm/t5/ihyp.py | 3 - cnfg/prompt/roberta/base.py | 6 +- cnfg/server.py | 7 + cnfg/vocab/plm/mbart.py | 11 + datautils/bpe.py | 8 +- datautils/moses.py | 20 +- datautils/pymoses.py | 14 +- datautils/zh.py | 8 +- loss/base.py | 45 +- loss/mulang.py | 8 +- loss/regression.py | 23 + lrsch.py | 30 +- mkcy.py | 10 +- modules/LD.py | 10 +- modules/TA.py | 8 +- modules/aan.py | 15 +- modules/act.py | 34 +- modules/attn/rap.py | 17 +- modules/attn/res.py | 11 +- modules/attn/retr.py | 24 +- modules/base.py | 138 ++--- modules/cpp/act/setup.py | 2 +- modules/cpp/base/attn/cross/setup.py | 2 +- modules/cpp/base/attn/self/setup.py | 2 +- modules/cpp/base/attn/setup.py | 2 +- modules/cpp/base/ffn/setup.py | 2 +- modules/cpp/base/resattn/cross/setup.py | 2 +- modules/cpp/base/resattn/self/setup.py | 2 +- modules/cpp/base/resattn/setup.py | 2 +- modules/cpp/group/setup.py | 2 +- modules/cpp/hplstm/setup.py | 3 +- modules/dropout.py | 50 +- modules/group/base.py | 12 +- modules/hplstm/base.py | 32 +- modules/hplstm/hfn.py | 34 +- modules/hplstm/wrapper.py | 56 +- modules/mulang/eff/base.py | 30 +- modules/noise.py | 22 +- modules/paradoc.py | 5 +- modules/plm/bert.py | 38 -- modules/plm/mbart.py | 41 ++ modules/plm/t5.py | 56 +- modules/rnncells.py | 20 +- modules/sampler.py | 11 +- modules/sdu.py | 91 ++++ modules/server/__init__.py | 1 + .../server/transformer.py | 52 +- optm/adabelief.py | 7 +- optm/lookahead.py | 2 +- optm/radam.py | 6 +- optm/ranger.py | 6 +- parallel/base.py | 117 +++- parallel/optm.py | 12 +- parallel/parallelMT.py | 85 ++- predict.py | 37 +- rank_loss.py | 47 +- requirements.opt.txt | 10 +- requirements.txt | 6 +- scripts/README.md | 8 + scripts/ape/bpe/mk.sh | 4 +- scripts/ape/mktest.sh | 7 +- scripts/ape/mktrain.sh | 24 +- scripts/bpe/mk.sh | 4 +- scripts/doc/para/mktest.sh | 5 +- scripts/mktest.sh | 5 +- 
scripts/mktrain.sh | 24 +- scripts/mulang/mktest.sh | 6 +- scripts/mulang/mktrain.sh | 24 +- scripts/plm/bart/mktest.sh | 14 +- scripts/plm/roberta/mktest.sh | 14 +- scripts/plm/roberta/mktrain.sh | 14 +- scripts/plm/t5/mktest.sh | 14 +- scripts/spm/clean.sh | 9 +- scripts/spm/mk.sh | 13 +- server.py | 13 +- tools/ape/mkiodata.py | 5 +- tools/average_model.py | 50 +- tools/char/cnfg | 1 + tools/char/mkiodata.py | 37 ++ tools/char/mktest.py | 33 ++ tools/char/utils | 1 + tools/check/avg_bsize.py | 5 +- tools/check/biratio.py | 4 +- tools/check/charatio.py | 4 +- tools/check/doc/para/epoch_steps.py | 5 +- tools/check/dynb/bsize/bsize.py | 5 +- tools/check/dynb/bsize/extbsize.py | 6 +- tools/check/dynb/bsize/freqbsize.py | 5 +- tools/check/dynb/bsize/stabsize.py | 4 +- tools/check/dynb/consis.py | 5 +- tools/check/dynb/ext.py | 4 +- tools/check/dynb/report_dynb.py | 66 ++- tools/check/dynb/slayer.py | 5 +- tools/check/epoch_steps.py | 5 +- tools/check/ext_emb.py | 6 +- tools/check/fbindexes.py | 7 +- tools/check/lsplitter.py | 5 +- tools/check/mkrandio.py | 29 + tools/check/mulang/eff/epoch_steps.py | 5 +- tools/check/mulang/fbindexes.py | 7 +- tools/check/para.py | 2 +- tools/check/probe/merge_probe.py | 13 +- tools/check/rank.py | 4 +- tools/check/topk/eva.py | 31 +- tools/check/topk/eva_stat.py | 30 +- tools/check/topk/stat.py | 8 +- tools/check/tspeed.py | 23 +- tools/check/vsize/copy.py | 4 +- tools/check/vsize/detail.py | 4 +- tools/check/vsize/mono.py | 9 +- tools/check/vsize/vocab.py | 2 +- tools/clean/ape/chars.py | 4 +- tools/clean/ape/maxkeeper.py | 6 +- tools/clean/ape/vocab.py | 10 +- tools/clean/chars.py | 4 +- tools/clean/cython.py | 2 +- tools/clean/dedup.py | 2 +- tools/clean/doc/para/maxkeeper.py | 6 +- tools/clean/doc/para/vocab.py | 9 +- tools/clean/gold.py | 4 +- tools/clean/maxkeeper.py | 6 +- tools/clean/normu8.py | 19 + tools/clean/rank.py | 4 +- tools/clean/sampler/strict_sampler.py | 6 +- tools/clean/token_repeat.py | 2 +- tools/clean/tokens.py | 4 +- tools/clean/vocab/ratio.py | 13 +- tools/clean/vocab/target.py | 5 +- tools/clean/vocab/unk.py | 4 +- tools/doc/para/mkiodata.py | 6 +- tools/doc/para/mktest.py | 6 +- tools/doc/para/restore.py | 8 +- tools/doc/sort.py | 14 +- tools/h5/compress.py | 4 +- tools/h5/convert.py | 4 +- tools/lang/zh/cnfg | 1 + tools/lang/zh/deseg.py | 19 + tools/lang/zh/utils | 1 + tools/lsort/merge.py | 34 +- tools/lsort/partsort.py | 4 +- tools/mkiodata.py | 5 +- tools/mktest.py | 5 +- tools/mulang/eff/mkiodata.py | 5 +- tools/mulang/eff/mktest.py | 5 +- tools/mulang/eff/sort.py | 14 +- tools/mulang/vocab/char/cnfg | 1 + tools/mulang/vocab/char/share.py | 41 ++ tools/mulang/vocab/char/single.py | 28 + tools/mulang/vocab/char/utils | 1 + tools/mulang/vocab/token/cnfg | 1 + .../{share_vocab.py => vocab/token/share.py} | 7 +- .../{vocab.py => vocab/token/single.py} | 5 +- tools/mulang/vocab/token/utils | 1 + tools/plm/map/bart.py | 4 +- tools/plm/map/bert.py | 4 +- tools/plm/map/mbart.py | 13 + tools/plm/map/mbart50.py | 15 + tools/plm/map/roberta.py | 4 +- tools/plm/map/t5.py | 4 +- tools/plm/mback/bart.py | 4 +- tools/plm/mback/bert.py | 4 +- tools/plm/mback/mbart.py | 13 + tools/plm/mback/mbart50.py | 15 + tools/plm/mback/roberta.py | 4 +- tools/plm/mback/t5.py | 4 +- tools/plm/mkiodata.py | 3 +- tools/plm/mkiodata_reg.py | 36 ++ tools/plm/mktest.py | 3 +- tools/plm/mtyp/bert.py | 4 +- tools/plm/mtyp/roberta.py | 4 +- tools/plm/token/bart.py | 4 +- tools/plm/token/bert.py | 4 +- tools/plm/token/cnfg | 1 + tools/plm/token/mbart.py | 13 + 
tools/plm/token/mbart50.py | 15 + tools/plm/token/roberta.py | 4 +- tools/plm/token/t5.py | 4 +- tools/plm/token/utils | 1 + tools/prune_model_vocab.py | 27 +- tools/restore.py | 5 +- tools/shuffle.py | 5 +- tools/sort.py | 14 +- tools/spm/encode.py | 2 +- tools/spm/train.py | 1 + tools/vocab/char/cnfg | 1 + tools/vocab/char/filter.py | 14 + tools/vocab/char/merge.py | 13 + tools/vocab/char/share.py | 23 + tools/vocab/char/single.py | 22 + tools/vocab/char/utils | 1 + tools/vocab/cnfg | 1 + tools/vocab/map.py | 19 + tools/vocab/token/cnfg | 1 + tools/vocab/token/filter.py | 14 + tools/vocab/token/merge.py | 13 + .../{share_vocab.py => vocab/token/share.py} | 5 +- tools/{vocab.py => vocab/token/single.py} | 5 +- tools/vocab/token/utils | 1 + tools/vocab/utils | 1 + train.py | 76 +-- transformer/AGG/HierDecoder.py | 23 +- transformer/AGG/HierEncoder.py | 22 +- transformer/AGG/InceptDecoder.py | 22 +- transformer/AGG/InceptEncoder.py | 21 +- transformer/APE/Decoder.py | 54 +- transformer/APE/Encoder.py | 52 +- transformer/APE/NMT.py | 28 +- transformer/AvgDecoder.py | 61 +-- transformer/ConstrainedDecoder.py | 28 +- transformer/Decoder.py | 121 +++-- transformer/Doc/Para/Base/Decoder.py | 55 +- transformer/Doc/Para/Base/Encoder.py | 57 +- transformer/Doc/Para/Base/NMT.py | 26 +- transformer/Encoder.py | 55 +- transformer/EnsembleAvgDecoder.py | 35 +- transformer/EnsembleDecoder.py | 44 +- transformer/EnsembleEncoder.py | 4 +- transformer/EnsembleNMT.py | 24 +- transformer/HPLSTM/Decoder.py | 48 +- transformer/HPLSTM/Encoder.py | 24 +- transformer/HPLSTM/FNDecoder.py | 49 +- transformer/LD/AttnEncoder.py | 10 +- transformer/LD/Decoder.py | 67 ++- transformer/LD/Encoder.py | 31 +- transformer/LD/NMT.py | 24 +- transformer/MuLang/Eff/Base/Decoder.py | 74 +-- transformer/MuLang/Eff/Base/Encoder.py | 34 +- transformer/MuLang/Eff/Base/NMT.py | 27 +- transformer/NMT.py | 51 +- transformer/PLM/BART/Decoder.py | 119 ++-- transformer/PLM/BART/Encoder.py | 97 ++-- transformer/PLM/BART/NMT.py | 28 +- transformer/PLM/BERT/Decoder.py | 43 +- transformer/PLM/BERT/Encoder.py | 78 +-- transformer/PLM/BERT/NMT.py | 34 +- transformer/PLM/MBART/Decoder.py | 288 ++++++++++ transformer/PLM/MBART/Encoder.py | 88 +++ transformer/PLM/MBART/NMT.py | 47 ++ transformer/PLM/MBART/__init__.py | 1 + transformer/PLM/NMT.py | 15 + transformer/PLM/RoBERTa/Decoder.py | 19 +- transformer/PLM/RoBERTa/Encoder.py | 52 +- transformer/PLM/RoBERTa/NMT.py | 24 +- transformer/PLM/T5/Decoder.py | 74 +-- transformer/PLM/T5/Encoder.py | 53 +- transformer/PLM/T5/NMT.py | 30 +- transformer/Probe/Decoder.py | 27 +- transformer/Probe/Encoder.py | 13 +- transformer/Probe/NMT.py | 23 +- transformer/Probe/ReDecoder.py | 23 +- transformer/Probe/ReNMT.py | 19 +- transformer/Prompt/RoBERTa/NMT.py | 10 +- transformer/README.md | 22 +- transformer/RNMTDecoder.py | 68 ++- transformer/RealFormer/Decoder.py | 51 +- transformer/RealFormer/Encoder.py | 29 +- transformer/RetrAttn/Decoder.py | 22 +- transformer/RetrAttn/Encoder.py | 21 +- transformer/SC/Decoder.py | 65 ++- transformer/SC/Encoder.py | 18 +- transformer/SC/NMT.py | 28 +- transformer/SDU/Decoder.py | 37 ++ transformer/SDU/Encoder.py | 36 ++ transformer/SDU/__init__.py | 1 + transformer/TA/Decoder.py | 34 +- transformer/TA/Encoder.py | 44 +- transformer/UniEncoder.py | 27 +- utils/README.md | 5 + utils/aan.py | 3 +- utils/angle.py | 12 +- utils/base.py | 506 ++++-------------- utils/comm.py | 3 +- utils/contpara.py | 8 +- utils/decode/__init__.py | 1 + utils/decode/base.py | 26 + 
utils/decode/beam.py | 29 + utils/dynbatch.py | 22 +- utils/fmt/ape/triple.py | 32 +- utils/fmt/base.py | 301 ++++++----- utils/fmt/base4torch.py | 11 +- utils/fmt/char/__init__.py | 1 + utils/fmt/char/dual.py | 8 + utils/fmt/char/single.py | 8 + utils/fmt/diff.py | 60 +++ utils/fmt/doc/base.py | 8 +- utils/fmt/doc/para/dual.py | 28 +- utils/fmt/doc/para/many.py | 72 +++ utils/fmt/doc/para/single.py | 15 +- utils/fmt/dual.py | 32 +- utils/fmt/json.py | 8 +- utils/fmt/lang/__init__.py | 1 + utils/fmt/lang/zh/__init__.py | 1 + utils/fmt/lang/zh/deseg.py | 8 + utils/fmt/lang/zh/t2s.py | 24 + utils/fmt/many.py | 67 +++ utils/fmt/manyvalue.py | 62 +++ utils/fmt/mulang/eff/char/__init__.py | 1 + utils/fmt/mulang/eff/char/dual.py | 53 ++ utils/fmt/mulang/eff/char/single.py | 46 ++ utils/fmt/mulang/eff/dual.py | 36 +- utils/fmt/mulang/eff/many.py | 78 +++ utils/fmt/mulang/eff/single.py | 22 +- utils/fmt/plm/bart/__init__.py | 1 + utils/fmt/plm/bart/dual.py | 50 ++ utils/fmt/plm/bert/base.py | 7 +- utils/fmt/plm/bert/dual.py | 7 +- utils/fmt/plm/bert/single.py | 7 +- utils/fmt/plm/dual.py | 26 +- utils/fmt/plm/dual_reg.py | 46 ++ utils/fmt/plm/mbart/__init__.py | 1 + utils/fmt/plm/mbart/dual.py | 53 ++ utils/fmt/plm/roberta/base.py | 6 +- utils/fmt/plm/roberta/dual.py | 7 +- utils/fmt/plm/roberta/dual_reg.py | 9 + utils/fmt/plm/roberta/single.py | 7 +- utils/fmt/plm/single.py | 16 +- utils/fmt/plm/t5/dual.py | 7 +- utils/fmt/plm/t5/single.py | 7 +- utils/fmt/plm/token.py | 48 +- utils/fmt/raw/__init__.py | 1 + utils/fmt/raw/cachepath.py | 28 + utils/fmt/raw/reader/__init__.py | 1 + utils/fmt/raw/reader/sort/__init__.py | 1 + utils/fmt/raw/reader/sort/many.py | 40 ++ utils/fmt/raw/reader/sort/single.py | 46 ++ utils/fmt/raw/reader/sort/tag.py | 46 ++ utils/fmt/single.py | 22 +- utils/fmt/triple.py | 35 +- utils/fmt/u8.py | 33 ++ utils/fmt/vocab/__init__.py | 1 + utils/fmt/vocab/base.py | 60 +++ utils/fmt/vocab/char.py | 23 + utils/fmt/vocab/token.py | 203 +++++++ utils/func.py | 5 + utils/h5serial.py | 30 +- utils/init/base.py | 20 +- utils/io.py | 128 +++++ utils/math.py | 13 + utils/mulang.py | 35 +- utils/plm/bart.py | 89 +++ utils/plm/base.py | 5 +- utils/plm/t5.py | 3 +- utils/process.py | 28 + utils/relpos/base.py | 15 +- utils/relpos/bucket.py | 4 +- utils/retrattn.py | 3 +- utils/sampler.py | 2 +- utils/server/__init__.py | 1 + utils/server/batcher.py | 132 +++++ utils/state/thrand.py | 2 +- utils/thread.py | 53 ++ utils/torch/__init__.py | 1 + utils/torch/c.py | 26 + utils/torch/comp.py | 208 +++++++ utils/{torch.py => torch/ext.py} | 10 +- utils/{pyctorch.py => torch/pyc.py} | 8 +- utils/train/__init__.py | 1 + utils/train/base.py | 93 ++++ utils/train/dss.py | 19 + 404 files changed, 7343 insertions(+), 3372 deletions(-) create mode 100644 adv/examples/plm/mbart.py create mode 100644 adv/predict/plm/roberta/predict_reg.py rename adv/train/prompt/roberta/{train_single.py => train.py} (88%) create mode 100644 adv/train/prompt/roberta/train_reg.py create mode 100644 cnfg/plm/mbart/base.py create mode 100644 cnfg/plm/mbart/hyp.py create mode 100644 cnfg/plm/mbart/ihyp.py create mode 100755 cnfg/server.py create mode 100644 cnfg/vocab/plm/mbart.py create mode 100644 loss/regression.py delete mode 100644 modules/plm/bert.py create mode 100644 modules/plm/mbart.py create mode 100644 modules/sdu.py create mode 100644 modules/server/__init__.py rename translator.py => modules/server/transformer.py (80%) create mode 120000 tools/char/cnfg create mode 100644 tools/char/mkiodata.py create mode 100644 
tools/char/mktest.py create mode 120000 tools/char/utils create mode 100644 tools/check/mkrandio.py create mode 100644 tools/clean/normu8.py create mode 120000 tools/lang/zh/cnfg create mode 100644 tools/lang/zh/deseg.py create mode 120000 tools/lang/zh/utils create mode 120000 tools/mulang/vocab/char/cnfg create mode 100644 tools/mulang/vocab/char/share.py create mode 100644 tools/mulang/vocab/char/single.py create mode 120000 tools/mulang/vocab/char/utils create mode 120000 tools/mulang/vocab/token/cnfg rename tools/mulang/{share_vocab.py => vocab/token/share.py} (82%) rename tools/mulang/{vocab.py => vocab/token/single.py} (83%) create mode 120000 tools/mulang/vocab/token/utils create mode 100644 tools/plm/map/mbart.py create mode 100644 tools/plm/map/mbart50.py create mode 100644 tools/plm/mback/mbart.py create mode 100644 tools/plm/mback/mbart50.py create mode 100644 tools/plm/mkiodata_reg.py create mode 120000 tools/plm/token/cnfg create mode 100644 tools/plm/token/mbart.py create mode 100644 tools/plm/token/mbart50.py create mode 120000 tools/plm/token/utils create mode 120000 tools/vocab/char/cnfg create mode 100644 tools/vocab/char/filter.py create mode 100644 tools/vocab/char/merge.py create mode 100644 tools/vocab/char/share.py create mode 100644 tools/vocab/char/single.py create mode 120000 tools/vocab/char/utils create mode 120000 tools/vocab/cnfg create mode 100644 tools/vocab/map.py create mode 120000 tools/vocab/token/cnfg create mode 100644 tools/vocab/token/filter.py create mode 100644 tools/vocab/token/merge.py rename tools/{share_vocab.py => vocab/token/share.py} (74%) rename tools/{vocab.py => vocab/token/single.py} (76%) create mode 120000 tools/vocab/token/utils create mode 120000 tools/vocab/utils create mode 100644 transformer/PLM/MBART/Decoder.py create mode 100644 transformer/PLM/MBART/Encoder.py create mode 100644 transformer/PLM/MBART/NMT.py create mode 100644 transformer/PLM/MBART/__init__.py create mode 100644 transformer/PLM/NMT.py create mode 100644 transformer/SDU/Decoder.py create mode 100644 transformer/SDU/Encoder.py create mode 100644 transformer/SDU/__init__.py create mode 100644 utils/README.md create mode 100644 utils/decode/__init__.py create mode 100644 utils/decode/base.py create mode 100644 utils/decode/beam.py create mode 100644 utils/fmt/char/__init__.py create mode 100644 utils/fmt/char/dual.py create mode 100644 utils/fmt/char/single.py create mode 100644 utils/fmt/diff.py create mode 100644 utils/fmt/doc/para/many.py create mode 100644 utils/fmt/lang/__init__.py create mode 100644 utils/fmt/lang/zh/__init__.py create mode 100644 utils/fmt/lang/zh/deseg.py create mode 100644 utils/fmt/lang/zh/t2s.py create mode 100644 utils/fmt/many.py create mode 100644 utils/fmt/manyvalue.py create mode 100644 utils/fmt/mulang/eff/char/__init__.py create mode 100644 utils/fmt/mulang/eff/char/dual.py create mode 100644 utils/fmt/mulang/eff/char/single.py create mode 100644 utils/fmt/mulang/eff/many.py create mode 100644 utils/fmt/plm/bart/__init__.py create mode 100644 utils/fmt/plm/bart/dual.py create mode 100644 utils/fmt/plm/dual_reg.py create mode 100644 utils/fmt/plm/mbart/__init__.py create mode 100644 utils/fmt/plm/mbart/dual.py create mode 100644 utils/fmt/plm/roberta/dual_reg.py create mode 100644 utils/fmt/raw/__init__.py create mode 100644 utils/fmt/raw/cachepath.py create mode 100644 utils/fmt/raw/reader/__init__.py create mode 100644 utils/fmt/raw/reader/sort/__init__.py create mode 100644 utils/fmt/raw/reader/sort/many.py create mode 100644 
utils/fmt/raw/reader/sort/single.py create mode 100644 utils/fmt/raw/reader/sort/tag.py create mode 100644 utils/fmt/u8.py create mode 100644 utils/fmt/vocab/__init__.py create mode 100644 utils/fmt/vocab/base.py create mode 100644 utils/fmt/vocab/char.py create mode 100644 utils/fmt/vocab/token.py create mode 100644 utils/func.py create mode 100644 utils/io.py create mode 100644 utils/plm/bart.py create mode 100644 utils/process.py create mode 100644 utils/server/__init__.py create mode 100755 utils/server/batcher.py create mode 100644 utils/thread.py create mode 100644 utils/torch/__init__.py create mode 100644 utils/torch/c.py create mode 100644 utils/torch/comp.py rename utils/{torch.py => torch/ext.py} (87%) rename utils/{pyctorch.py => torch/pyc.py} (68%) create mode 100644 utils/train/__init__.py create mode 100644 utils/train/base.py create mode 100644 utils/train/dss.py diff --git a/README.md b/README.md index b80e319..1854296 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ Tokenized case-sensitive BLEU measured with [multi-bleu.perl](https://github.com ## Acknowledgments -Hongfei Xu enjoys a doctoral grant from [China Scholarship Council](https://www.csc.edu.cn/) ([2018]3101, 201807040056) while maintaining this project. +Hongfei Xu is partially supported by the Education Department of Henan Province (Grant No. 232300421386) while maintaining this project. Details of this project can be found [here](https://arxiv.org/abs/1903.07402), and please cite it if you enjoy the implementation :) diff --git a/adv/eva/eva_probe.py b/adv/eva/eva_probe.py index a3bf2ed..c39c6d5 100644 --- a/adv/eva/eva_probe.py +++ b/adv/eva/eva_probe.py @@ -1,24 +1,22 @@ #encoding: utf-8 import sys - import torch -from utils.tqdm import tqdm - -from utils.h5serial import h5File - -import cnfg.probe as cnfg -from cnfg.ihyp import * - -from transformer.Probe.NMT import NMT from loss.base import LabelSmoothingLoss from parallel.base import DataParallelCriterion from parallel.parallelMT import DataParallelMT +from transformer.Probe.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base4torch import parse_cuda +from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm -from utils.base import * +import cnfg.probe as cnfg +from cnfg.ihyp import * from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda probe_reorder = cnfg.probe_reorder @@ -37,7 +35,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for i in tqdm(range(nd), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) @@ -48,7 +46,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, ind_shift, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) loss = lossf(output, ot) if multi_gpu: @@ -71,7 +69,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): nword = td["nword"][()].tolist() nwordi, nwordt = nword[0], nword[-1] -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, 
cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) mymodel = load_model_cpu(sys.argv[2], mymodel) mymodel.apply(load_fixing) @@ -100,6 +98,9 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + use_amp = cnfg.use_amp and use_cuda vloss, vprec = eva(td, ntest, mymodel, lossf, cuda_device, multi_gpu, use_amp) diff --git a/adv/eva/prompt/roberta/eva_single.py b/adv/eva/prompt/roberta/eva_single.py index 486ed5e..2629ef4 100644 --- a/adv/eva/prompt/roberta/eva_single.py +++ b/adv/eva/prompt/roberta/eva_single.py @@ -1,25 +1,23 @@ #encoding: utf-8 import sys - import torch -from utils.tqdm import tqdm - -from utils.h5serial import h5File - -import cnfg.prompt.roberta.base as cnfg -from cnfg.prompt.roberta.ihyp import * -from cnfg.vocab.plm.roberta import vocab_size - -from transformer.Prompt.RoBERTa.NMT import NMT from loss.base import NLLLoss from parallel.base import DataParallelCriterion from parallel.parallelMT import DataParallelMT - -from utils.base import * +from transformer.Prompt.RoBERTa.NMT import NMT +from utils.base import set_random_seed from utils.fmt.base4torch import parse_cuda from utils.fmt.plm.base import fix_parameter_name +from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm + +import cnfg.prompt.roberta.base as cnfg +from cnfg.prompt.roberta.ihyp import * +from cnfg.vocab.plm.roberta import vocab_size def load_fixing(module): @@ -31,7 +29,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for i in tqdm(range(nd), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) @@ -40,7 +38,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_batch = seq_batch.to(mv_device, non_blocking=True) seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch) loss = lossf(output, seq_o) if multi_gpu: @@ -55,12 +53,9 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): w = float(w) return sum_loss / w, (w - r) / w * 100.0 -td = h5File(sys.argv[1], "r") - -ntest = td["ndata"][()].item() nwordi = nwordt = vocab_size -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) +mymodel = NMT(cnfg.isize, 
nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) # important to load the pre-trained model, as the load_plm function not only load parameters, but also may introduce new parameters, which affects the parameter alignment. pre_trained_m = cnfg.pre_trained_m @@ -93,10 +88,12 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) -use_amp = cnfg.use_amp and use_cuda +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) -vloss, vprec = eva(td, ntest, mymodel, lossf, cuda_device, multi_gpu, use_amp) +use_amp = cnfg.use_amp and use_cuda -td.close() +with h5File(sys.argv[1], "r") as td: + vloss, vprec = eva(td, td["ndata"][()].item(), mymodel, lossf, cuda_device, multi_gpu, use_amp) print("loss/error: %.3f %.2f" % (vloss, vprec,)) diff --git a/adv/examples/plm/bart.py b/adv/examples/plm/bart.py index 9245d86..d059db9 100644 --- a/adv/examples/plm/bart.py +++ b/adv/examples/plm/bart.py @@ -1,18 +1,16 @@ #encoding: utf-8 import torch +from transformers import BartModel +from transformer.PLM.BART.NMT import NMT from utils.fmt.plm.base import fix_parameter_name -from utils.fmt.plm.roberta.base import ldvocab +from utils.torch.comp import torch_inference_mode import cnfg.plm.bart.base as cnfg from cnfg.plm.bart.ihyp import * from cnfg.vocab.plm.roberta import vocab_size -from transformer.PLM.BART.NMT import NMT - -from transformers import BartModel - def init_fixing(module): if hasattr(module, "fix_init"): @@ -20,7 +18,7 @@ def init_fixing(module): print("load pre-trained models") -tmod = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) +tmod = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) tmod.apply(init_fixing) tmod.load_plm(fix_parameter_name(torch.load("plm/bart-base/pytorch_model.bin", map_location="cpu"))) @@ -34,7 +32,7 @@ def init_fixing(module): tde = torch.as_tensor([0, 100, 50264, 15162, 4, 2], dtype=torch.long).unsqueeze(0) tdo = torch.as_tensor([2, 100, 33, 41, 15162, 4, 2], dtype=torch.long).unsqueeze(0) -with torch.no_grad(): +with torch_inference_mode(): ers = smod(input_ids=tde, decoder_input_ids=tdo).last_hidden_state print("forward for test") trs = tmod(tde, tdo) diff --git a/adv/examples/plm/bert.py b/adv/examples/plm/bert.py index ee23942..2e4d20f 100644 --- a/adv/examples/plm/bert.py +++ 
b/adv/examples/plm/bert.py @@ -1,18 +1,16 @@ #encoding: utf-8 import torch +from transformers import BertModel +from transformer.PLM.BERT.NMT import NMT from utils.fmt.plm.base import fix_parameter_name -from utils.fmt.plm.bert.base import ldvocab +from utils.torch.comp import torch_inference_mode import cnfg.plm.bert.base as cnfg from cnfg.plm.bert.ihyp import * from cnfg.vocab.plm.bert import vocab_size -from transformer.PLM.BERT.NMT import NMT - -from transformers import BertModel - def init_fixing(module): if hasattr(module, "fix_init"): @@ -20,7 +18,7 @@ def init_fixing(module): print("load pre-trained models") -tmod = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) +tmod = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) tmod.apply(init_fixing) tmod.load_plm(fix_parameter_name(torch.load("plm/bert-base-cased/pytorch_model.bin", map_location="cpu"))) @@ -33,7 +31,7 @@ def init_fixing(module): print("forward with transformers") td = torch.as_tensor([101, 146, 1138, 1126, 12075, 119, 102], dtype=torch.long).unsqueeze(0) -with torch.no_grad(): +with torch_inference_mode(): ers = smod(td).last_hidden_state print("forward for test") trs = tmod(td) diff --git a/adv/examples/plm/mbart.py b/adv/examples/plm/mbart.py new file mode 100644 index 0000000..9acec67 --- /dev/null +++ b/adv/examples/plm/mbart.py @@ -0,0 +1,47 @@ +#encoding: utf-8 + +import torch +from transformers import MBartForConditionalGeneration, MBartTokenizerFast as Tokenizer + +from transformer.PLM.MBART.NMT import NMT +from utils.fmt.plm.base import fix_parameter_name +from utils.torch.comp import torch_inference_mode + +import cnfg.plm.mbart.base as cnfg +from cnfg.plm.mbart.ihyp import * +from cnfg.vocab.plm.mbart import vocab_size + +def init_fixing(module): + + if hasattr(module, "fix_init"): + module.fix_init() + +print("load pre-trained models") +tokenizer = Tokenizer(tokenizer_file="plm/mbart-large-cc25/tokenizer.json") + +tmod = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) +tmod.apply(init_fixing) +tmod.load_plm(fix_parameter_name(torch.load("plm/mbart-large-cc25/pytorch_model.bin", map_location="cpu"))) +tmod.eval() + +print("load models with transformers") +smod = MBartForConditionalGeneration.from_pretrained("plm/mbart-large-cc25") +smod.eval() + +print("forward with transformers") +tde = torch.as_tensor([17, 765, 142, 108787, 5, 2, 250004], dtype=torch.long).unsqueeze(0) +tdo = torch.as_tensor([250004, 17, 765, 142, 108787, 5, 2], dtype=torch.long).unsqueeze(0) + +print("forward for test") +with torch_inference_mode(): + ers = smod(input_ids=tde, decoder_input_ids=tdo, 
output_hidden_states=True).decoder_hidden_states[-1] + trs = tmod(tde, tdo) + +print(ers) +print(trs) + +with torch_inference_mode(): + ers = smod.generate(tde, decoder_start_token_id=250004) + trs = tmod.decode(tde, lang_id=250004) +print(tokenizer.convert_ids_to_tokens(ers.squeeze(0))) +print(tokenizer.convert_ids_to_tokens(trs.squeeze(0))) diff --git a/adv/examples/plm/roberta.py b/adv/examples/plm/roberta.py index 02d1b0a..dcd4db3 100644 --- a/adv/examples/plm/roberta.py +++ b/adv/examples/plm/roberta.py @@ -1,18 +1,16 @@ #encoding: utf-8 import torch +from transformers import RobertaModel +from transformer.PLM.RoBERTa.NMT import NMT from utils.fmt.plm.base import fix_parameter_name -from utils.fmt.plm.roberta.base import ldvocab +from utils.torch.comp import torch_inference_mode import cnfg.plm.roberta.base as cnfg from cnfg.plm.roberta.ihyp import * from cnfg.vocab.plm.roberta import vocab_size -from transformer.PLM.RoBERTa.NMT import NMT - -from transformers import RobertaModel - def init_fixing(module): if hasattr(module, "fix_init"): @@ -20,7 +18,7 @@ def init_fixing(module): print("load pre-trained models") -tmod = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) +tmod = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) tmod.apply(init_fixing) tmod.load_plm(fix_parameter_name(torch.load("plm/roberta-base/pytorch_model.bin", map_location="cpu"))) @@ -33,7 +31,7 @@ def init_fixing(module): print("forward with transformers") td = torch.as_tensor([0, 100, 33, 41, 15162, 4, 2], dtype=torch.long).unsqueeze(0) -with torch.no_grad(): +with torch_inference_mode(): ers = smod(td).last_hidden_state print("forward for test") trs = tmod(td) diff --git a/adv/examples/plm/t5.py b/adv/examples/plm/t5.py index 7a6fda4..66aa9e6 100644 --- a/adv/examples/plm/t5.py +++ b/adv/examples/plm/t5.py @@ -1,17 +1,16 @@ #encoding: utf-8 import torch +from transformers import T5ForConditionalGeneration, T5TokenizerFast as Tokenizer +from transformer.PLM.T5.NMT import NMT from utils.fmt.plm.base import fix_parameter_name +from utils.torch.comp import torch_inference_mode import cnfg.plm.t5.base as cnfg from cnfg.plm.t5.ihyp import * from cnfg.vocab.plm.t5 import vocab_size -from transformer.PLM.T5.NMT import NMT - -from transformers import T5ForConditionalGeneration, T5TokenizerFast as Tokenizer - def init_fixing(module): if hasattr(module, "fix_init"): @@ -20,7 +19,7 @@ def init_fixing(module): print("load pre-trained models") tokenizer = Tokenizer(tokenizer_file="plm/t5-base/tokenizer.json") -tmod = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) +tmod = NMT(cnfg.isize, vocab_size, vocab_size, 
cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) tmod.apply(init_fixing) tmod.load_plm(fix_parameter_name(torch.load("plm/t5-base/pytorch_model.bin", map_location="cpu"))) @@ -34,7 +33,7 @@ def init_fixing(module): tde = torch.as_tensor([27, 43, 192, 16981, 5, 1], dtype=torch.long).unsqueeze(0) tdo = torch.as_tensor([0, 531, 25, 241, 80, 58], dtype=torch.long).unsqueeze(0) -with torch.no_grad(): +with torch_inference_mode(): ers = smod(input_ids=tde, decoder_input_ids=tdo, output_hidden_states=True).decoder_hidden_states[-1] print("forward for test") trs = tmod(tde, tdo) @@ -42,7 +41,7 @@ def init_fixing(module): print(trs) tde = torch.as_tensor([27, 43, 32099, 16981, 5, 32098, 241, 80, 58, 1], dtype=torch.long).unsqueeze(0) -with torch.no_grad(): +with torch_inference_mode(): ers = smod.generate(tde) trs = tmod.decode(tde) print(tokenizer.convert_ids_to_tokens(ers.squeeze(0))) diff --git a/adv/predict/plm/bart/predict.py b/adv/predict/plm/bart/predict.py index 8a0e8db..7d52fa1 100644 --- a/adv/predict/plm/bart/predict.py +++ b/adv/predict/plm/bart/predict.py @@ -1,24 +1,24 @@ #encoding: utf-8 import sys - import torch from transformers import BartTokenizerFast as Tokenizer -from utils.tqdm import tqdm +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.PLM.BART.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.plm.base import fix_parameter_name from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.plm.bart.base as cnfg from cnfg.plm.bart.ihyp import * -from cnfg.vocab.plm.roberta import sos_id, eos_id, vocab_size - -from transformer.PLM.BART.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * -from utils.fmt.base4torch import parse_cuda_decode -from utils.fmt.plm.base import fix_parameter_name +from cnfg.vocab.plm.roberta import eos_id, sos_id, vocab_size def init_fixing(module): @@ -30,30 +30,37 @@ def load_fixing(module): if hasattr(module, "fix_load"): module.fix_load() -td = h5File(cnfg.test_data, "r") - -ntest = td["ndata"][()].item() detoken = Tokenizer(tokenizer_file=sys.argv[2]).decode +pre_trained_m = cnfg.pre_trained_m _num_args = len(sys.argv) if _num_args < 4: - mymodel = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + mymodel = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) 
mymodel.apply(init_fixing) - pre_trained_m = cnfg.pre_trained_m if pre_trained_m is not None: print("Load pre-trained model from: " + pre_trained_m) mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) elif _num_args == 4: - mymodel = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + mymodel = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + if pre_trained_m is not None: + print("Load pre-trained model from: " + pre_trained_m) + mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) else: models = [] + if pre_trained_m is not None: + print("Load pre-trained model from: " + pre_trained_m) + _ = fix_parameter_name(torch.load(pre_trained_m, map_location="cpu")) for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + tmp = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + if pre_trained_m is not None: + tmp.load_plm(_) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) models.append(tmp) + if pre_trained_m is not None: + _ = None mymodel = Ensemble(models) mymodel.eval() @@ -69,19 +76,20 @@ def load_fixing(module): if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty ens = "\n".encode("utf-8") - -src_grp = td["src"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): - for i in tqdm(range(ntest), mininterval=tqdm_mininterval): +with sys_open(sys.argv[1], "wb") as f, h5File(cnfg.test_data, "r") as td, torch_inference_mode(): + src_grp = td["src"] + for i in tqdm(range(td["ndata"][()].item()), mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[str(i)][()]) if cuda_device: seq_batch = seq_batch.to(cuda_device, non_blocking=True) seq_batch = seq_batch.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel.decode(seq_batch, beam_size, None, length_penalty) if multi_gpu: tmp = [] @@ -99,5 +107,3 @@ def load_fixing(module): break f.write(detoken(tmp, skip_special_tokens=True, 
clean_up_tokenization_spaces=False).encode("utf-8")) f.write(ens) - -td.close() diff --git a/adv/predict/plm/roberta/predict.py b/adv/predict/plm/roberta/predict.py index 43b3418..5628b49 100644 --- a/adv/predict/plm/roberta/predict.py +++ b/adv/predict/plm/roberta/predict.py @@ -1,24 +1,24 @@ #encoding: utf-8 import sys - import torch -from utils.tqdm import tqdm +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.Prompt.RoBERTa.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.plm.base import fix_parameter_name from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.prompt.roberta.base as cnfg from cnfg.prompt.roberta.ihyp import * from cnfg.vocab.plm.roberta import vocab_size -from transformer.Prompt.RoBERTa.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * -from utils.fmt.base4torch import parse_cuda_decode -from utils.fmt.plm.base import fix_parameter_name - def init_fixing(module): if hasattr(module, "fix_init"): @@ -29,15 +29,12 @@ def load_fixing(module): if hasattr(module, "fix_load"): module.fix_load() -td = h5File(cnfg.test_data, "r") - -ntest = td["ndata"][()].item() nwordi = nwordt = vocab_size pre_trained_m = cnfg.pre_trained_m _num_args = len(sys.argv) if _num_args == 3: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) if pre_trained_m is not None: print("Load pre-trained model from: " + pre_trained_m) mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) @@ -51,7 +48,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[2:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) if pre_trained_m is not None: print("Load pre-trained model from: " + pre_trained_m) mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) @@ -76,16 +73,17 @@ def load_fixing(module): if 
multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) -ens = "\n".encode("utf-8") +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) -src_grp = td["src"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): - for i in tqdm(range(ntest), mininterval=tqdm_mininterval): +ens = "\n".encode("utf-8") +with sys_open(sys.argv[1], "wb") as f, h5File(cnfg.test_data, "r") as td, torch_inference_mode(): + src_grp = td["src"] + for i in tqdm(range(td["ndata"][()].item()), mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[str(i)][()]) if cuda_device: seq_batch = seq_batch.to(cuda_device, non_blocking=True) seq_batch = seq_batch.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel(seq_batch) if multi_gpu: tmp = [] @@ -96,5 +94,3 @@ def load_fixing(module): output = output.argmax(-1).tolist() f.write("\n".join([str(_) for _ in output]).encode("utf-8")) f.write(ens) - -td.close() diff --git a/adv/predict/plm/roberta/predict_reg.py b/adv/predict/plm/roberta/predict_reg.py new file mode 100644 index 0000000..5b3b00d --- /dev/null +++ b/adv/predict/plm/roberta/predict_reg.py @@ -0,0 +1,106 @@ +#encoding: utf-8 + +import sys +import torch +from torch import nn + +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.Prompt.RoBERTa.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.plm.base import fix_parameter_name +from utils.func import identity_func +from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm + +import cnfg.prompt.roberta.base as cnfg +from cnfg.prompt.roberta.ihyp import * +from cnfg.vocab.plm.roberta import vocab_size + +reg_weight, reg_bias = cnfg.reg_weight, cnfg.reg_bias +if reg_weight is None: + scale_func = identity_func if reg_bias is None else (lambda x: x - reg_bias) +else: + scale_func = (lambda x: x / reg_weight) if reg_bias is None else (lambda x: x / reg_weight - reg_bias) + +def init_fixing(module): + + if hasattr(module, "fix_init"): + module.fix_init() + +def load_fixing(module): + + if hasattr(module, "fix_load"): + module.fix_load() + +nwordi = nwordt = vocab_size + +pre_trained_m = cnfg.pre_trained_m +_num_args = len(sys.argv) +if _num_args == 3: + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + mymodel.dec.lsm = nn.Softmax(-1) + if pre_trained_m is not None: + print("Load pre-trained model from: " + pre_trained_m) + mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) + if (cnfg.classifier_indices is not None) and hasattr(mymodel, "update_classifier"): + print("Build new classifier") + mymodel.update_classifier(torch.as_tensor(cnfg.classifier_indices, dtype=torch.long)) + fine_tune_m = sys.argv[2] + print("Load pre-trained model from: " + fine_tune_m) + mymodel = load_model_cpu(fine_tune_m, mymodel) + mymodel.apply(load_fixing) +else: + models = 
[] + for modelf in sys.argv[2:]: + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + tmp.dec.lsm = nn.Softmax(-1) + if pre_trained_m is not None: + print("Load pre-trained model from: " + pre_trained_m) + mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) + if (cnfg.classifier_indices is not None) and hasattr(mymodel, "update_classifier"): + print("Build new classifier") + mymodel.update_classifier(torch.as_tensor(cnfg.classifier_indices, dtype=torch.long)) + print("Load pre-trained model from: " + modelf) + tmp = load_model_cpu(modelf, tmp) + tmp.apply(load_fixing) + models.append(tmp) + mymodel = Ensemble(models) + +mymodel.eval() + +use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda_decode(cnfg.use_cuda, cnfg.gpuid, cnfg.multi_gpu_decoding) +use_amp = cnfg.use_amp and use_cuda + +set_random_seed(cnfg.seed, use_cuda) + +if cuda_device: + mymodel.to(cuda_device, non_blocking=True) + if multi_gpu: + mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) + +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + +ens = "\n".encode("utf-8") +with sys_open(sys.argv[1], "wb") as f, h5File(cnfg.test_data, "r") as td, torch_inference_mode(): + src_grp = td["src"] + for i in tqdm(range(td["ndata"][()].item()), mininterval=tqdm_mininterval): + seq_batch = torch.from_numpy(src_grp[str(i)][()]) + if cuda_device: + seq_batch = seq_batch.to(cuda_device, non_blocking=True) + seq_batch = seq_batch.long() + with torch_autocast(enabled=use_amp): + output = mymodel(seq_batch) + if multi_gpu: + tmp = [] + for ou in output: + tmp.extend(scale_func(ou.select(-1, -1)).tolist()) + output = tmp + else: + output = scale_func(output.select(-1, -1)).tolist() + f.write("\n".join([str(_) for _ in output]).encode("utf-8")) + f.write(ens) diff --git a/adv/predict/plm/t5/predict.py b/adv/predict/plm/t5/predict.py index d4e70ff..9b80a7c 100644 --- a/adv/predict/plm/t5/predict.py +++ b/adv/predict/plm/t5/predict.py @@ -1,24 +1,24 @@ #encoding: utf-8 import sys - import torch from transformers import T5TokenizerFast as Tokenizer -from utils.tqdm import tqdm +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.PLM.T5.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.plm.base import fix_parameter_name from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.plm.t5.base as cnfg from cnfg.plm.t5.ihyp import * -from cnfg.vocab.plm.t5 import sos_id, eos_id, vocab_size - -from transformer.PLM.T5.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * -from utils.fmt.plm.base import fix_parameter_name -from utils.fmt.base4torch import parse_cuda_decode +from cnfg.vocab.plm.t5 import eos_id, vocab_size def init_fixing(module): @@ -30,30 +30,37 @@ def load_fixing(module): if 
hasattr(module, "fix_load"): module.fix_load() -td = h5File(cnfg.test_data, "r") - -ntest = td["ndata"][()].item() detoken = Tokenizer(tokenizer_file=sys.argv[2]).decode +pre_trained_m = cnfg.pre_trained_m _num_args = len(sys.argv) if _num_args < 4: - mymodel = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + mymodel = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) mymodel.apply(init_fixing) - pre_trained_m = cnfg.pre_trained_m if pre_trained_m is not None: print("Load pre-trained model from: " + pre_trained_m) mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) elif _num_args == 4: - mymodel = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + mymodel = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + if pre_trained_m is not None: + print("Load pre-trained model from: " + pre_trained_m) + mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) else: models = [] + if pre_trained_m is not None: + print("Load pre-trained model from: " + pre_trained_m) + _ = fix_parameter_name(torch.load(pre_trained_m, map_location="cpu")) for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + tmp = NMT(cnfg.isize, vocab_size, vocab_size, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + if pre_trained_m is not None: + tmp.load_plm(_) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) models.append(tmp) + if pre_trained_m is not None: + _ = None mymodel = Ensemble(models) mymodel.eval() @@ -68,19 +75,20 @@ def load_fixing(module): if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, 
output_device=cuda_device.index, host_replicate=True, gather_output=False) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty ens = "\n".encode("utf-8") - -src_grp = td["src"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): - for i in tqdm(range(ntest), mininterval=tqdm_mininterval): +with sys_open(sys.argv[1], "wb") as f, h5File(cnfg.test_data, "r") as td, torch_inference_mode(): + src_grp = td["src"] + for i in tqdm(range(td["ndata"][()].item()), mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[str(i)][()]) if cuda_device: seq_batch = seq_batch.to(cuda_device, non_blocking=True) seq_batch = seq_batch.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel.decode(seq_batch, beam_size, None, length_penalty) if multi_gpu: tmp = [] @@ -98,5 +106,3 @@ def load_fixing(module): break f.write(detoken(tmp, skip_special_tokens=False, clean_up_tokenization_spaces=False).encode("utf-8")) f.write(ens) - -td.close() diff --git a/adv/predict/predict_ape.py b/adv/predict/predict_ape.py index 64b1636..6557c6e 100644 --- a/adv/predict/predict_ape.py +++ b/adv/predict/predict_ape.py @@ -1,24 +1,24 @@ #encoding: utf-8 import sys - import torch -from utils.tqdm import tqdm - +from parallel.parallelMT import DataParallelMT +from transformer.APE.NMT import NMT +from transformer.EnsembleNMT import NMT as Ensemble +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.vocab.base import reverse_dict +from utils.fmt.vocab.token import ldvocab from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.base as cnfg from cnfg.ihyp import * - -from transformer.APE.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * -from utils.fmt.base import ldvocab, reverse_dict from cnfg.vocab.base import eos_id -from utils.fmt.base4torch import parse_cuda_decode def load_fixing(module): @@ -33,7 +33,7 @@ def load_fixing(module): vcbt = reverse_dict(vcbt) if len(sys.argv) == 4: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) @@ -41,7 +41,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -62,6 +62,8 @@ def load_fixing(module): if multi_gpu: 
mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty @@ -69,7 +71,7 @@ def load_fixing(module): # using tgt instead of mt since data are processed by tools/mkiodata.py for the mt task src_grp, mt_grp = td["src"], td["tgt"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): +with sys_open(sys.argv[1], "wb") as f, torch_inference_mode(): for i in tqdm(range(ntest), mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[str(i)][()]) seq_mt = torch.from_numpy(mt_grp[str(i)][()]) @@ -77,7 +79,7 @@ def load_fixing(module): seq_batch = seq_batch.to(cuda_device, non_blocking=True) seq_mt = seq_mt.to(cuda_device, non_blocking=True) seq_batch, seq_mt = seq_batch.long(), seq_mt.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel.decode(seq_batch, seq_mt, beam_size, None, length_penalty) if multi_gpu: tmp = [] diff --git a/adv/predict/predict_doc_para.py b/adv/predict/predict_doc_para.py index b323a29..2558657 100644 --- a/adv/predict/predict_doc_para.py +++ b/adv/predict/predict_doc_para.py @@ -1,24 +1,24 @@ #encoding: utf-8 import sys - import torch -from utils.tqdm import tqdm - +from parallel.parallelMT import DataParallelMT +from transformer.Doc.Para.Base.NMT import NMT +from transformer.EnsembleNMT import NMT as Ensemble +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.vocab.base import reverse_dict +from utils.fmt.vocab.token import ldvocab from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.docpara as cnfg from cnfg.ihyp import * - -from transformer.Doc.Para.Base.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * -from utils.fmt.base import ldvocab, reverse_dict from cnfg.vocab.base import eos_id -from utils.fmt.base4torch import parse_cuda_decode def load_fixing(module): @@ -33,7 +33,7 @@ def load_fixing(module): vcbt = reverse_dict(vcbt) if len(sys.argv) == 4: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) @@ -41,7 +41,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, 
cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -62,6 +62,8 @@ def load_fixing(module): if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + #num_prev_sent = cnfg.num_prev_sent beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty @@ -70,7 +72,7 @@ def load_fixing(module): ens_skip = "\n".encode("utf-8")#.join(["\n" for i in range(num_prev_sent)]) src_grp = td["src"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): +with sys_open(sys.argv[1], "wb") as f, torch_inference_mode(): for nsent, i_d in tqdm(tl, mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[nsent][i_d][()]) if cuda_device: @@ -78,7 +80,7 @@ def load_fixing(module): seq_batch = seq_batch.long() bsize, _nsent, seql = seq_batch.size() _nsent_use = _nsent - 1 - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel.decode(seq_batch.narrow(1, 1, _nsent_use).contiguous(), seq_batch.narrow(1, 0, _nsent_use).contiguous(), beam_size, None, length_penalty).view(bsize, _nsent_use, -1) if multi_gpu: tmp = [] diff --git a/adv/predict/predict_mulang.py b/adv/predict/predict_mulang.py index e92318a..a2d3b1b 100644 --- a/adv/predict/predict_mulang.py +++ b/adv/predict/predict_mulang.py @@ -1,24 +1,24 @@ #encoding: utf-8 import sys - import torch -from utils.tqdm import tqdm - +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.MuLang.Eff.Base.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.vocab.base import reverse_dict +from utils.fmt.vocab.token import ldvocab from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.mulang as cnfg from cnfg.ihyp import * - -from transformer.MuLang.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * -from utils.fmt.base import ldvocab, reverse_dict from cnfg.vocab.base import eos_id -from utils.fmt.base4torch import parse_cuda_decode def load_fixing(module): @@ -33,7 +33,7 @@ def load_fixing(module): vcbt = reverse_dict(vcbt) if len(sys.argv) == 4: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask, ngroup=cnfg.ngroup) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask, ngroup=cnfg.ngroup) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) @@ -41,7 +41,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, 
cnfg.forbidden_indexes, ntask=ntask, ngroup=cnfg.ngroup) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask, ngroup=cnfg.ngroup) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -61,6 +61,8 @@ def load_fixing(module): if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty @@ -68,13 +70,13 @@ def load_fixing(module): ntest = [(str(i), _task,) for _nd, _task in zip(ntest, td["taskorder"][()].tolist()) for i in range(_nd)] -with open(sys.argv[1], "wb") as f, torch.no_grad(): +with sys_open(sys.argv[1], "wb") as f, torch_inference_mode(): for i_d, taskid in tqdm(ntest, mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(td[str(taskid)]["src"][i_d][()]) if cuda_device: seq_batch = seq_batch.to(cuda_device, non_blocking=True) seq_batch = seq_batch.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel.decode(seq_batch, taskid, beam_size, None, length_penalty) if multi_gpu: tmp = [] diff --git a/adv/predict/predict_probe_enc.py b/adv/predict/predict_probe_enc.py index e4c02d7..766e570 100644 --- a/adv/predict/predict_probe_enc.py +++ b/adv/predict/predict_probe_enc.py @@ -3,21 +3,22 @@ # usage: python file_name.py test.eva.h5 model.h5 tgt.vcb $rsf.txt import sys - import torch -from utils.tqdm import tqdm - +from transformer.Probe.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.vocab.base import reverse_dict +from utils.fmt.vocab.token import init_vocab, ldvocab from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_inference_mode +from utils.tqdm import tqdm import cnfg.probe as cnfg from cnfg.ihyp import * - -from transformer.Probe.NMT import NMT - -from utils.base import * -from utils.fmt.base import ldvocab, reverse_dict, init_vocab, sos_id, eos_id -from utils.fmt.base4torch import parse_cuda_decode +from cnfg.vocab.base import eos_id, pad_id, sos_id def load_fixing(module): @@ -31,7 +32,7 @@ def load_fixing(module): vcbt, nwordt = ldvocab(sys.argv[3]) vcbt = reverse_dict(vcbt) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) mymodel = load_model_cpu(sys.argv[2], mymodel) mymodel.apply(load_fixing) @@ -56,15 +57,15 @@ def load_fixing(module): src_grp = td["src"] ens = "\n".encode("utf-8") -with open(sys.argv[4], "wb") as fwrt, torch.no_grad(): +with sys_open(sys.argv[4], "wb") as fwrt, torch_inference_mode(): for i in tqdm(range(ntest), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) if cuda_device: seq_batch = seq_batch.to(cuda_device, 
non_blocking=True) seq_batch = seq_batch.long() - _mask = seq_batch.eq(0) - with autocast(enabled=use_amp): + _mask = seq_batch.eq(pad_id) + with torch_autocast(enabled=use_amp): # mask pad/sos/eos_id in output output = classifier(trans(enc(seq_batch, mask=_mask.unsqueeze(1), no_std_out=True))).argmax(-1).masked_fill(_mask | seq_batch.eq(sos_id) | seq_batch.eq(eos_id), 0).tolist() for tran in output: diff --git a/adv/rank/doc/para/rank_loss_para.py b/adv/rank/doc/para/rank_loss_para.py index 06cdbe4..f792878 100644 --- a/adv/rank/doc/para/rank_loss_para.py +++ b/adv/rank/doc/para/rank_loss_para.py @@ -5,26 +5,24 @@ norm_token = True import sys - import torch -from utils.tqdm import tqdm - +from loss.base import LabelSmoothingLoss +from parallel.base import DataParallelCriterion +from parallel.parallelMT import DataParallelMT +from transformer.Doc.Para.Base.NMT import NMT +from transformer.EnsembleNMT import NMT as Ensemble +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.docpara as cnfg from cnfg.ihyp import * - -from transformer.Doc.Para.Base.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT -from parallel.base import DataParallelCriterion - -from loss.base import LabelSmoothingLoss - -from utils.base import * from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda def load_fixing(module): @@ -38,7 +36,7 @@ def load_fixing(module): nwordi, nwordt = nword[0], nword[-1] if len(sys.argv) == 4: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) @@ -46,7 +44,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -71,12 +69,15 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + ens = 
"\n".encode("utf-8") num_prev_sent = cnfg.num_prev_sent src_grp, tgt_grp = td["src"]["4"], td["tgt"]["4"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): +with sys_open(sys.argv[1], "wb") as f, torch_inference_mode(): for i in tqdm(range(ntest), mininterval=tqdm_mininterval): _curid = str(i) seq_batch = torch.from_numpy(src_grp[_curid][()]) @@ -91,7 +92,7 @@ def load_fixing(module): seq_o = seq_o.narrow(1, 1, _nsent_use) oi = seq_o.narrow(-1, 0, lo).contiguous() ot = seq_o.narrow(-1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()).view(bsize, _nsent_use, lo, -1) loss = lossf(output, ot).view(bsize, -1).sum(-1) if norm_token: diff --git a/adv/rank/doc/rank_loss_sent.py b/adv/rank/doc/rank_loss_sent.py index 20b06f2..a21564d 100644 --- a/adv/rank/doc/rank_loss_sent.py +++ b/adv/rank/doc/rank_loss_sent.py @@ -5,26 +5,24 @@ norm_token = True import sys - import torch -from utils.tqdm import tqdm - +from loss.base import LabelSmoothingLoss +from parallel.base import DataParallelCriterion +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.base as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT -from parallel.base import DataParallelCriterion - -from loss.base import LabelSmoothingLoss - -from utils.base import * from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda def load_fixing(module): @@ -38,7 +36,7 @@ def load_fixing(module): nwordi, nwordt = nword[0], nword[-1] if len(sys.argv) == 4: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) @@ -46,7 +44,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -71,10 +69,13 @@ def load_fixing(module): mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, 
replicate_once=True) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + ens = "\n".encode("utf-8") src_grp, tgt_grp = td["src"]["4"], td["tgt"]["4"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): +with sys_open(sys.argv[1], "wb") as f, torch_inference_mode(): for i in tqdm(range(ntest), mininterval=tqdm_mininterval): _curid = str(i) seq_batch = torch.from_numpy(src_grp[_curid][()]) @@ -87,7 +88,7 @@ def load_fixing(module): seq_batch, seq_o = seq_batch.long(), seq_o.long() lo = seq_o.size(-1) - 1 ot = seq_o.narrow(-1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel(seq_batch.view(ebsize, -1), seq_o.narrow(-1, 0, lo).contiguous().view(ebsize, -1)).view(bsize, nsent, lo, -1) loss = lossf(output, ot).view(bsize, -1).sum(-1) if norm_token: diff --git a/adv/train/mulang/train_m2o.py b/adv/train/mulang/train_m2o.py index 782ee72..c3f3cd4 100644 --- a/adv/train/mulang/train_m2o.py +++ b/adv/train/mulang/train_m2o.py @@ -1,43 +1,41 @@ #encoding: utf-8 import torch - +from random import shuffle from torch.optim import Adam as Optimizer +from loss.base import LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states +from utils.mulang import data_sampler from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.fmt.base import tostr -from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda, load_emb -from utils.mulang import data_sampler - -from lrsch import GoogleLR as LRScheduler -from loss.base import LabelSmoothingLoss - -from random import shuffle - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam +from utils.train.dss import dynamic_sample import cnfg.mulang as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT +from cnfg.vocab.base import pad_id def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin model.train() cur_b, _ls = 1, {} if save_loss else None for i_d, taskid in tqdm(tl, mininterval=tqdm_mininterval): @@ -52,7 +50,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, 
mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch, oi) loss = lossf(output, ot) if multi_gpu: @@ -68,7 +66,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[(i_d, taskid,)] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add @@ -98,6 +96,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -105,7 +113,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -124,7 +132,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): r = w = 0 sum_loss = 0.0 model.eval() - with torch.no_grad(): + with torch_inference_mode(): for i_d, taskid in tqdm(nd, mininterval=tqdm_mininterval): task_grp = ed[str(taskid)] seq_batch = torch.from_numpy(task_grp["src"][i_d][()]) @@ -135,7 +143,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) loss = lossf(output, ot) if multi_gpu: @@ -211,6 +219,7 @@ def load_fixing(module): if task_weight_T is None or task_weight_T == 1.0: tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][()].tolist()) for i in range(_nd)] train_sampler = None + ntrain = len(tl) else: train_taskorder = td["taskorder"][()].tolist() _tnd = dict(zip(train_taskorder, ntrain)) @@ -218,10 +227,11 @@ def load_fixing(module): ntrain = [_tnd[i] for i in train_taskorder] _tnd = None train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain)) + ntrain = train_sampler.nsample nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][()].tolist()) for i in range(_nd)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, 
cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) fine_tune_m = cnfg.fine_tune_m @@ -260,6 +270,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -268,7 +281,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,)) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) diff --git a/adv/train/mulang/train_mulang.py b/adv/train/mulang/train_mulang.py index 18a587e..54844d6 100644 --- a/adv/train/mulang/train_mulang.py +++ b/adv/train/mulang/train_mulang.py @@ -1,43 +1,41 @@ #encoding: utf-8 import torch - +from random import shuffle from torch.optim import Adam as Optimizer +from loss.base import MultiLabelSmoothingLoss as LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.MuLang.Eff.Base.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states +from utils.mulang import data_sampler from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.fmt.base import tostr -from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda, load_emb -from utils.mulang import data_sampler - -from lrsch import GoogleLR as LRScheduler -from loss.base import MultiLabelSmoothingLoss as LabelSmoothingLoss - -from random import shuffle - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam +from utils.train.dss import dynamic_sample import cnfg.mulang as cnfg from cnfg.ihyp import * - -from transformer.MuLang.NMT import NMT +from cnfg.vocab.base import pad_id def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, 
num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin model.train() cur_b, _ls = 1, {} if save_loss else None for i_d, taskid in tqdm(tl, mininterval=tqdm_mininterval): @@ -52,7 +50,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch, oi, taskid=taskid) loss = lossf(output, ot, lang_id=taskid) if multi_gpu: @@ -68,7 +66,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[(i_d, taskid,)] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add @@ -98,6 +96,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -105,7 +113,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -124,7 +132,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): r = w = 0 sum_loss = 0.0 model.eval() - with torch.no_grad(): + with torch_inference_mode(): for i_d, taskid in tqdm(nd, mininterval=tqdm_mininterval): task_grp = ed[str(taskid)] seq_batch = torch.from_numpy(task_grp["src"][i_d][()]) @@ -135,7 +143,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo), taskid=taskid) loss = lossf(output, ot, lang_id=taskid) if multi_gpu: @@ -211,6 +219,7 @@ def load_fixing(module): if task_weight_T is None or task_weight_T == 1.0: tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][()].tolist()) for i in range(_nd)] train_sampler = None + ntrain = len(tl) else: train_taskorder = td["taskorder"][()].tolist() _tnd = 
dict(zip(train_taskorder, ntrain)) @@ -218,10 +227,11 @@ def load_fixing(module): ntrain = [_tnd[i] for i in train_taskorder] _tnd = None train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain)) + ntrain = train_sampler.nsample nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][()].tolist()) for i in range(_nd)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask, ngroup=cnfg.ngroup) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask) fine_tune_m = cnfg.fine_tune_m @@ -260,6 +270,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -268,7 +281,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,)) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) diff --git a/adv/train/mulang/train_mulang_robt.py b/adv/train/mulang/train_mulang_robt.py index ca02acd..33c9e0d 100644 --- a/adv/train/mulang/train_mulang_robt.py +++ b/adv/train/mulang/train_mulang_robt.py @@ -1,39 +1,36 @@ #encoding: utf-8 import torch - +from random import randint, shuffle from torch.optim import Adam as Optimizer +from loss.base import MultiLabelSmoothingLoss as LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.MuLang.Eff.Base.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, pad_tensors, set_random_seed from utils.contpara import get_model_parameters +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states +from utils.mulang import data_sampler from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.fmt.base import tostr -from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda, load_emb -from utils.mulang import data_sampler - -from lrsch import GoogleLR as LRScheduler -from loss.base import 
MultiLabelSmoothingLoss as LabelSmoothingLoss - -from random import shuffle, randint - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam +from utils.train.dss import dynamic_sample import cnfg.mulang as cnfg from cnfg.ihyp import * +from cnfg.vocab.base import pad_id -from transformer.MuLang.NMT import NMT - -def back_translate(model, seq_in, taskid, beam_size, multi_gpu, enable_autocast=False, step_bsize=32, step_ntok=640, pivot_bt=True): +def back_translate(model, seq_in, taskid, beam_size, multi_gpu, enable_torch_autocast=False, step_bsize=32, step_ntok=640, pivot_bt=True): rs = [] bsize, seql = seq_in.size() @@ -43,7 +40,7 @@ def back_translate(model, seq_in, taskid, beam_size, multi_gpu, enable_autocast= _g_out = model.gather_output model.gather_output = True sind = 0 - with torch.no_grad(), autocast(enabled=enable_autocast): + with torch_inference_mode(), torch_autocast(enabled=enable_torch_autocast): while sind < bsize: num_narrow = min(_step_bsize, bsize - sind) if pivot_bt and (taskid != 0): @@ -65,6 +62,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin model.train() cur_b, _ls = 1, {} if save_loss else None global ntask, ro_beam_size @@ -79,10 +77,10 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok _bt_taskid = randint(0, t_sample_max_id) if _bt_taskid >= taskid: _bt_taskid += 1 - seq_batch = back_translate(model, seq_o, _bt_taskid, ro_beam_size, multi_gpu, enable_autocast=_use_amp) + seq_batch = back_translate(model, seq_o, _bt_taskid, ro_beam_size, multi_gpu, enable_torch_autocast=_use_amp) oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch, oi, taskid=taskid) loss = lossf(output, ot, lang_id=taskid) if multi_gpu: @@ -98,7 +96,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[(i_d, taskid,)] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add @@ -128,6 +126,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -135,7 +143,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd 
= 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -154,7 +162,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): r = w = 0 sum_loss = 0.0 model.eval() - with torch.no_grad(): + with torch_inference_mode(): for i_d, taskid in tqdm(nd, mininterval=tqdm_mininterval): task_grp = ed[str(taskid)] seq_batch = torch.from_numpy(task_grp["src"][i_d][()]) @@ -165,7 +173,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo), taskid=taskid) loss = lossf(output, ot, lang_id=taskid) if multi_gpu: @@ -242,6 +250,7 @@ def load_fixing(module): if task_weight_T is None or task_weight_T == 1.0: tl = [(str(i), _task,) for _nd, _task in zip(ntrain, td["taskorder"][()].tolist()) for i in range(_nd)] train_sampler = None + ntrain = len(tl) else: train_taskorder = td["taskorder"][()].tolist() _tnd = dict(zip(train_taskorder, ntrain)) @@ -249,10 +258,11 @@ def load_fixing(module): ntrain = [_tnd[i] for i in train_taskorder] _tnd = None train_sampler = data_sampler(ntrain if task_weight is None else task_weight, task_weight_T, ntrain, train_taskorder, nsample=sum(ntrain)) + ntrain = train_sampler.nsample nvalid = [(str(i), _task,) for _nd, _task in zip(nvalid, vd["taskorder"][()].tolist()) for i in range(_nd)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask, ngroup=cnfg.ngroup) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, ntask=ntask) fine_tune_m = cnfg.fine_tune_m @@ -291,6 +301,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -299,7 +312,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,)) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) diff --git a/adv/train/prompt/roberta/train_single.py b/adv/train/prompt/roberta/train.py similarity index 88% 
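(Editor's sketch, not part of the patch hunks.) The data_sampler(...) call above balances the multilingual training tasks through the sampling temperature task_weight_T; its implementation is not included in this diff, so the snippet below is only an assumed illustration of the usual temperature-based scheme (p_i proportional to w_i^(1/T)) that the call signature suggests.

import torch

def temperature_task_probs(sizes, T):
	# reweight per-task sizes (or task_weight entries) by the inverse temperature
	w = torch.as_tensor(sizes, dtype=torch.float64) ** (1.0 / T)
	return w / w.sum()

sizes = [1000000, 50000, 5000]     # sentences per task, illustrative values only
for T in (1.0, 5.0):
	print(T, [round(p, 3) for p in temperature_task_probs(sizes, T).tolist()])
# T=1.0 keeps the natural, size-proportional distribution; a larger T flattens it,
# giving low-resource tasks more sampling mass.
taskid = torch.multinomial(temperature_task_probs(sizes, 5.0), 1).item()  # draw one task id

With task_weight_T left at None or 1.0, the patch bypasses the sampler and simply iterates the concatenated per-task lists, which corresponds to the T=1.0 case above.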
rename from adv/train/prompt/roberta/train_single.py rename to adv/train/prompt/roberta/train.py index ae572d5..4d6dd57 100644 --- a/adv/train/prompt/roberta/train_single.py +++ b/adv/train/prompt/roberta/train.py @@ -1,43 +1,41 @@ #encoding: utf-8 import torch - +from random import shuffle from torch.optim import Adam as Optimizer +from loss.base import NLLLoss +from lrsch import CustLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.Prompt.RoBERTa.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.fmt.plm.base import fix_parameter_name +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.fmt.base import tostr -from utils.fmt.base4torch import parse_cuda, load_emb -from utils.fmt.plm.base import fix_parameter_name - -from lrsch import CustLR as LRScheduler -from loss.base import NLLLoss - -from random import shuffle - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import freeze_module, getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam +from utils.train.dss import dynamic_sample import cnfg.prompt.roberta.base as cnfg from cnfg.prompt.roberta.ihyp import * from cnfg.vocab.plm.roberta import vocab_size -from transformer.Prompt.RoBERTa.NMT import NMT - def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin model.train() cur_b, _ls = 1, {} if save_loss else None src_grp, tgt_grp = td["src"], td["tgt"] @@ -49,7 +47,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch) loss = lossf(output, seq_o) if multi_gpu: @@ -65,7 +63,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[i_d] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add @@ -95,6 +93,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, 
_eeva,)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -102,7 +110,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -122,7 +130,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for i in tqdm(range(nd), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) @@ -131,7 +139,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_batch = seq_batch.to(mv_device, non_blocking=True) seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch) loss = lossf(output, seq_o) if multi_gpu: @@ -203,7 +211,7 @@ def load_fixing(module): tl = [str(i) for i in range(ntrain)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) mymodel = init_model_params(mymodel) mymodel.apply(init_fixing) @@ -229,6 +237,9 @@ def load_fixing(module): if cnfg.tgt_emb is not None: logger.info("Load target embedding from: " + cnfg.tgt_emb) load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) +if cnfg.freeze_word_embedding: + logger.info("Freeze word embedding") + freeze_module(mymodel.enc.wemb) if cuda_device: mymodel.to(cuda_device, non_blocking=True) @@ -249,6 +260,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, lr_func=lambda a, b: (init_lr, b,)) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), 
"thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -257,7 +271,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,)) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) diff --git a/adv/train/prompt/roberta/train_reg.py b/adv/train/prompt/roberta/train_reg.py new file mode 100644 index 0000000..bc49fe6 --- /dev/null +++ b/adv/train/prompt/roberta/train_reg.py @@ -0,0 +1,381 @@ +#encoding: utf-8 + +import torch +from random import shuffle +from torch.optim import Adam as Optimizer + +from loss.regression import KLRegLoss +from lrsch import CustLR as LRScheduler +from parallel.base import DataParallelCriterion +from parallel.optm import MultiGPUGradScaler +from parallel.parallelMT import DataParallelMT +from transformer.Prompt.RoBERTa.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed +from utils.contpara import get_model_parameters +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.fmt.plm.base import fix_parameter_name +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states +from utils.state.holder import Holder +from utils.state.pyrand import PyRandomState +from utils.state.thrand import THRandomState +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm +from utils.train.base import freeze_module, getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam +from utils.train.dss import dynamic_sample + +import cnfg.prompt.roberta.base as cnfg +from cnfg.prompt.roberta.ihyp import * +from cnfg.vocab.plm.roberta import vocab_size + +reg_weight, reg_bias = cnfg.reg_weight, cnfg.reg_bias +_compare_threshold = 0.5 if reg_weight is None else (0.5 / reg_weight) +compare_func = (lambda tgt, ref: ref.gt(_compare_threshold).to(tgt.dtype).eq(tgt)) if reg_bias is None else (lambda tgt, ref: ref.add(reg_bias).gt(_compare_threshold).to(tgt.dtype).eq(tgt)) + +def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): + + sum_loss = part_loss = 0.0 + sum_wd = part_wd = 0 + _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin + model.train() + cur_b, _ls = 1, {} if save_loss else None + src_grp, tgt_grp = td["src"], td["tgt"] + for i_d in tqdm(tl, mininterval=tqdm_mininterval): + seq_batch = torch.from_numpy(src_grp[i_d][()]) + seq_o = torch.from_numpy(tgt_grp[i_d][()]).squeeze(-1) + if mv_device: + seq_batch = seq_batch.to(mv_device, non_blocking=True) + seq_o = seq_o.to(mv_device, non_blocking=True) + seq_batch = seq_batch.long()#, seq_o.long() + + with torch_autocast(enabled=_use_amp): + output = model(seq_batch) + loss = lossf(output, seq_o) + if multi_gpu: + loss = loss.sum() + loss_add = loss.data.item() 
+ + if scaler is None: + loss.backward() + else: + scaler.scale(loss).backward() + + wd_add = seq_o.numel() + loss = output = seq_batch = seq_o = None + sum_loss += loss_add + if save_loss: + _ls[i_d] = loss_add / wd_add + sum_wd += wd_add + _done_tokens += wd_add + + if _done_tokens >= tokens_optm: + optm_step(optm, model=model, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer, zero_grad_none=optm_step_zero_grad_set_none) + _done_tokens = 0 + if _cur_rstep is not None: + if save_checkp_epoch and (save_every is not None) and (_cur_rstep % save_every == 0) and (chkpf is not None) and (_cur_rstep > 0): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + save_model(model, _chkpf, multi_gpu, print_func=logger.info) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + _cur_rstep -= 1 + if _cur_rstep <= 0: + break + lrsch.step() + + if nreport is not None: + part_loss += loss_add + part_wd += wd_add + if cur_b % nreport == 0: + if report_eva: + _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) + logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva + free_cache(mv_device) + model.train() + else: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,)) + part_loss = 0.0 + part_wd = 0 + + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): + if num_checkpoint > 1: + _fend = "_%d.h5" % (_cur_checkid) + _chkpf = chkpf[:-3] + _fend + _cur_checkid = (_cur_checkid + 1) % num_checkpoint + else: + _chkpf = chkpf + save_model(model, _chkpf, multi_gpu, print_func=logger.info) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + cur_b += 1 + if part_wd != 0.0: + logger.info("Average loss over %d tokens: %.3f" % (part_wd, part_loss / part_wd,)) + return sum_loss / sum_wd, _done_tokens, _cur_checkid, _cur_rstep, _ls + +def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): + r = w = 0 + sum_loss = 0.0 + model.eval() + src_grp, tgt_grp = ed["src"], ed["tgt"] + with torch_inference_mode(): + for i in tqdm(range(nd), mininterval=tqdm_mininterval): + bid = str(i) + seq_batch = torch.from_numpy(src_grp[bid][()]) + seq_o = torch.from_numpy(tgt_grp[bid][()]).squeeze(-1) + if mv_device: + seq_batch = seq_batch.to(mv_device, non_blocking=True) + seq_o = seq_o.to(mv_device, non_blocking=True) + seq_batch = seq_batch.long()#, seq_o.long() + with torch_autocast(enabled=use_amp): + output = model(seq_batch) + loss = 
lossf(output, seq_o) + if multi_gpu: + loss = loss.sum() + trans = torch.cat([outu.argmax(-1).to(mv_device, non_blocking=True) for outu in output], 0) + else: + trans = output.argmax(-1) + sum_loss += loss.data.item() + w += seq_o.numel() + r += compare_func(trans, seq_o).int().sum().item() + trans = loss = output = seq_batch = seq_o = None + w = float(w) + return sum_loss / w, (w - r) / w * 100.0 + +def hook_lr_update(optm, flags=None): + + reset_Adam(optm, flags) + +def init_fixing(module): + + if hasattr(module, "fix_init"): + module.fix_init() + +def load_fixing(module): + + if hasattr(module, "fix_load"): + module.fix_load() + +rid = cnfg.run_id +earlystop = cnfg.earlystop +maxrun = cnfg.maxrun +tokens_optm = cnfg.tokens_optm +done_tokens = 0 +batch_report = cnfg.batch_report +report_eva = cnfg.report_eva +use_ams = cnfg.use_ams +cnt_states = cnfg.train_statesf +save_auto_clean = cnfg.save_auto_clean +overwrite_eva = cnfg.overwrite_eva +save_every = cnfg.save_every +start_chkp_save = cnfg.epoch_start_checkpoint_save +epoch_save = cnfg.epoch_save +remain_steps = cnfg.training_steps + +wkdir = "".join((cnfg.exp_dir, cnfg.data_id, "/", cnfg.group_id, "/", rid, "/")) +mkdir(wkdir) + +chkpf = None +statesf = None +if save_every is not None: + chkpf = wkdir + "checkpoint.h5" +if cnfg.save_train_state: + statesf = wkdir + "train.states.t7" + +logger = get_logger(wkdir + "train.log") + +use_cuda, cuda_device, cuda_devices, multi_gpu = parse_cuda(cnfg.use_cuda, cnfg.gpuid) +multi_gpu_optimizer = multi_gpu and cnfg.multi_gpu_optimizer + +set_random_seed(cnfg.seed, use_cuda) + +td = h5File(cnfg.train_data, "r") +vd = h5File(cnfg.dev_data, "r") + +ntrain = td["ndata"][()].item() +nvalid = vd["ndata"][()].item() +nwordi = nwordt = vocab_size + +tl = [str(i) for i in range(ntrain)] + +logger.info("Design models with seed: %d" % torch.initial_seed()) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, fhsize=cnfg.ff_hsize, dropout=cnfg.drop, attn_drop=cnfg.attn_drop, act_drop=cnfg.act_drop, global_emb=cnfg.share_emb, num_head=cnfg.nhead, xseql=cache_len_default, ahsize=cnfg.attn_hsize, norm_output=cnfg.norm_output, bindDecoderEmb=cnfg.bindDecoderEmb, forbidden_index=cnfg.forbidden_indexes, model_name=cnfg.model_name) + +mymodel = init_model_params(mymodel) +mymodel.apply(init_fixing) + +pre_trained_m = cnfg.pre_trained_m +if pre_trained_m is not None: + logger.info("Load pre-trained model from: " + pre_trained_m) + mymodel.load_plm(fix_parameter_name(torch.load(pre_trained_m, map_location="cpu"))) +if (cnfg.classifier_indices is not None) and hasattr(mymodel, "update_classifier"): + logger.info("Build new classifier") + mymodel.update_classifier(torch.as_tensor(cnfg.classifier_indices, dtype=torch.long)) +fine_tune_m = cnfg.fine_tune_m +if fine_tune_m is not None: + logger.info("Load pre-trained model from: " + fine_tune_m) + mymodel = load_model_cpu(fine_tune_m, mymodel) + mymodel.apply(load_fixing) + +lossf = KLRegLoss(weight=reg_weight, bias=reg_bias, reduction="sum")#ignore_index=pad_id + +if cnfg.src_emb is not None: + logger.info("Load source embedding from: " + cnfg.src_emb) + load_emb(cnfg.src_emb, mymodel.enc.wemb.weight, nwordi, cnfg.scale_down_emb, cnfg.freeze_srcemb) +if cnfg.tgt_emb is not None: + logger.info("Load target embedding from: " + cnfg.tgt_emb) + load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) +if cnfg.freeze_word_embedding: + logger.info("Freeze word embedding") + freeze_module(mymodel.enc.wemb) + +if cuda_device: + 
mymodel.to(cuda_device, non_blocking=True) + lossf.to(cuda_device, non_blocking=True) + +use_amp = cnfg.use_amp and use_cuda +scaler = (MultiGPUGradScaler() if multi_gpu_optimizer else GradScaler()) if use_amp else None + +if multi_gpu: + mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) + lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) + +if multi_gpu: + optimizer = mymodel.build_optimizer(Optimizer, lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams, multi_gpu_optimizer=multi_gpu_optimizer, contiguous_parameters=contiguous_parameters) +else: + optimizer = Optimizer(get_model_parameters(mymodel, contiguous_parameters=contiguous_parameters), lr=init_lr, betas=adam_betas_default, eps=ieps_adam_default, weight_decay=cnfg.weight_decay, amsgrad=use_ams) +optimizer.zero_grad(set_to_none=optm_step_zero_grad_set_none) + +lrsch = LRScheduler(optimizer, lr_func=lambda a, b: (init_lr, b,)) + +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + +state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) + +num_checkpoint = cnfg.num_checkpoint +cur_checkid = 0 + +tminerr = inf_default + +minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,)) + +if fine_tune_m is None: + save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) + logger.info("Initial model saved") +else: + if cnt_states is not None: + logger.info("Loading training states") + _remain_states = state_holder.load_state_dict(torch.load(cnt_states)) + remain_steps, cur_checkid = _remain_states["remain_steps"], _remain_states["checkpoint_id"] + if "training_list" in _remain_states: + _ctl = _remain_states["training_list"] + else: + shuffle(tl) + _ctl = tl + tminerr, done_tokens, cur_checkid, remain_steps, _ = train(td, _ctl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, False, False, scaler) + _ctl = _remain_states = None + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: 0, train loss: %.3f, valid loss/error: %.3f %.2f" % (tminerr, vloss, vprec,)) + save_model(mymodel, wkdir + "train_0_%.3f_%.3f_%.2f.h5" % (tminerr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info) + logger.info("New best model saved") + +if cnfg.dss_ws is not None and cnfg.dss_ws > 0.0 and cnfg.dss_ws < 1.0: + dss_ws = int(cnfg.dss_ws * ntrain) + _Dws = {} + _prev_Dws = {} + _crit_inc = {} + if cnfg.dss_rm is not None and cnfg.dss_rm > 0.0 and cnfg.dss_rm < 1.0: + dss_rm = int(cnfg.dss_rm * ntrain * (1.0 - cnfg.dss_ws)) + else: + dss_rm = 0 +else: + dss_ws = 0 + dss_rm = 0 + _Dws = None + +namin = 0 + 
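The train() function above steps the optimizer by token count rather than by batch: gradients accumulate until _done_tokens reaches tokens_optm, at which point optm_step() is called (handling the optional AMP GradScaler and the multi-GPU optimizer), the LR scheduler advances, and the counter resets. A minimal stand-alone sketch of this token-based gradient accumulation follows; the toy model, loss, random data, and the plain optimizer.step() used in place of optm_step() are placeholders for illustration only, not repository classes.

# Illustrative sketch: token-count-based gradient accumulation,
# mirroring the _done_tokens / tokens_optm logic in train() above.
# Model, loss, and data are toy stand-ins.
import torch
from torch import nn

tokens_optm = 1024                              # target tokens per optimizer update
model = nn.Linear(8, 8)                         # toy stand-in for the NMT model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
lossf = nn.MSELoss(reduction="sum")

done_tokens = 0
optimizer.zero_grad(set_to_none=True)
for _ in range(100):                            # micro-batches
	x, y = torch.randn(32, 8), torch.randn(32, 8)
	loss = lossf(model(x), y)
	loss.backward()                             # gradients accumulate across micro-batches
	done_tokens += y.numel()                    # count target tokens, as seq_o.numel() does above
	if done_tokens >= tokens_optm:
		optimizer.step()                        # one update per roughly tokens_optm tokens
		optimizer.zero_grad(set_to_none=True)
		done_tokens = 0

In the training script itself the same counter additionally gates checkpoint saving and the remaining-step budget (_cur_rstep); only the accumulation pattern is shown here.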
+for i in range(1, maxrun + 1): + shuffle(tl) + free_cache(use_cuda) + terr, done_tokens, cur_checkid, remain_steps, _Dws = train(td, tl, vd, nvalid, optimizer, lrsch, mymodel, lossf, cuda_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm, batch_report, save_every, chkpf, state_holder, statesf, num_checkpoint, cur_checkid, report_eva, remain_steps, dss_ws > 0, i >= start_chkp_save, scaler) + vloss, vprec = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) + logger.info("Epoch: %d, train loss: %.3f, valid loss/error: %.3f %.2f" % (i, terr, vloss, vprec,)) + + if (vprec <= minerr) or (vloss <= minloss): + save_model(mymodel, wkdir + "eva_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp="eva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info) + logger.info("New best model saved") + + namin = 0 + + if vprec < minerr: + minerr = vprec + if vloss < minloss: + minloss = vloss + + else: + if terr < tminerr: + tminerr = terr + save_model(mymodel, wkdir + "train_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info, mtyp=("eva" if overwrite_eva else "train") if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info) + elif epoch_save: + save_model(mymodel, wkdir + "epoch_%d_%.3f_%.3f_%.2f.h5" % (i, terr, vloss, vprec,), multi_gpu, print_func=logger.info) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info) + + namin += 1 + if namin >= earlystop: + if done_tokens > 0: + optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + lrsch.step() + done_tokens = 0 + logger.info("early stop") + break + + if remain_steps is not None and remain_steps <= 0: + logger.info("Last training step reached") + break + + if dss_ws > 0: + if _prev_Dws: + for _key, _value in _Dws.items(): + if _key in _prev_Dws: + _ploss = _prev_Dws[_key] + _crit_inc[_key] = (_ploss - _value) / _ploss + tl = dynamic_sample(_crit_inc, dss_ws, dss_rm) + _prev_Dws = _Dws + +if done_tokens > 0: + optm_step(optimizer, model=mymodel, scaler=scaler, multi_gpu=multi_gpu, multi_gpu_optimizer=multi_gpu_optimizer) + lrsch.step() + +save_model(mymodel, wkdir + "last.h5", multi_gpu, print_func=logger.info) +if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": remain_steps, "checkpoint_id": cur_checkid}), statesf, print_func=logger.info) +logger.info("model saved") + +td.close() +vd.close() diff --git a/adv/train/train_ape.py b/adv/train/train_ape.py index 850fa2f..02eb37c 100644 --- a/adv/train/train_ape.py +++ b/adv/train/train_ape.py @@ -1,42 +1,40 @@ #encoding: utf-8 import torch - +from random import shuffle from torch.optim import Adam as Optimizer +from loss.base import LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.APE.NMT import 
NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.fmt.base import tostr -from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda, load_emb - -from lrsch import GoogleLR as LRScheduler -from loss.base import LabelSmoothingLoss - -from random import shuffle - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam +from utils.train.dss import dynamic_sample import cnfg.base as cnfg from cnfg.ihyp import * - -from transformer.APE.NMT import NMT +from cnfg.vocab.base import pad_id def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin model.train() cur_b, _ls = 1, {} if save_loss else None src_grp, mt_grp, tgt_grp = td["src"], td["mt"], td["tgt"] @@ -53,7 +51,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch, seq_mt, oi) loss = lossf(output, ot) if multi_gpu: @@ -69,7 +67,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[i_d] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add @@ -99,6 +97,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -106,7 +114,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not 
None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -126,7 +134,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): sum_loss = 0.0 model.eval() src_grp, mt_grp, tgt_grp = ed["src"], ed["mt"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for i in tqdm(range(nd), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) @@ -139,7 +147,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_mt, seq_o = seq_batch.long(), seq_mt.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_mt, seq_o.narrow(1, 0, lo)) loss = lossf(output, ot) if multi_gpu: @@ -214,7 +222,7 @@ def load_fixing(module): tl = [str(i) for i in range(ntrain)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) fine_tune_m = cnfg.fine_tune_m @@ -253,6 +261,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -261,7 +272,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) +logger.info("".join(("Init lr: ", ",".join(iter_to_str(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) diff --git a/adv/train/train_doc_para.py b/adv/train/train_doc_para.py index 846c8f8..e7ab7bf 100644 --- a/adv/train/train_doc_para.py +++ b/adv/train/train_doc_para.py @@ -1,43 +1,41 @@ #encoding: utf-8 import torch - +from random import shuffle from torch.optim import Adam as Optimizer +from loss.base import LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.Doc.Para.Base.NMT import NMT +from transformer.NMT import NMT as BaseNMT +from utils.base import filter_para_grad, free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from 
utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.fmt.base import tostr -from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda, load_emb - -from lrsch import GoogleLR as LRScheduler -from loss.base import LabelSmoothingLoss - -from random import shuffle - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import freeze_module, getlr, optm_step, optm_step_zero_grad_set_none +from utils.train.dss import dynamic_sample import cnfg.docpara as cnfg from cnfg.ihyp import * - -from transformer.Doc.Para.Base.NMT import NMT -from transformer.NMT import NMT as BaseNMT +from cnfg.vocab.base import pad_id def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin model.train() cur_b, _ls = 1, {} if save_loss else None @@ -56,7 +54,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok seq_o = seq_o.narrow(1, 1, _nsent_use) oi = seq_o.narrow(-1, 0, lo).contiguous() ot = seq_o.narrow(-1, 1, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()) loss = lossf(output, ot) if multi_gpu: @@ -72,7 +70,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[(nsent, i_d,)] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add @@ -102,6 +100,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -109,7 +117,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and 
(chkpf is not None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -130,7 +138,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for nsent, i_d in tqdm(nd, mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[nsent][i_d][()]) seq_o = torch.from_numpy(tgt_grp[nsent][i_d][()]) @@ -145,7 +153,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.narrow(1, 1, _nsent_use) oi = seq_o.narrow(-1, 0, lo).contiguous() ot = seq_o.narrow(-1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch.narrow(1, 1, _nsent_use).contiguous(), oi, seq_batch.narrow(1, 0, _nsent_use).contiguous()) loss = lossf(output, ot) if multi_gpu: @@ -216,7 +224,7 @@ def load_fixing(module): vl = [(str(nsent), str(_curd),) for nsent, ndata in zip(vd["nsent"][()].tolist(), vd["ndata"][()].tolist()) for _curd in range(ndata)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_prev_sent, cnfg.num_layer_context) fine_tune_m = cnfg.fine_tune_m @@ -224,9 +232,9 @@ def load_fixing(module): mymodel.apply(init_fixing) if fine_tune_m is not None: logger.info("Load pre-trained model from: " + fine_tune_m) - _tmpm = BaseNMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + _tmpm = BaseNMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) _tmpm = load_model_cpu(fine_tune_m, _tmpm) - #with torch.no_grad(): + #with torch_inference_mode(): #_tmpm.dec.classifier.bias[_tmpm.dec.classifier.bias.lt(-1e3)] = -1e3 if cnfg.freeze_load_model: freeze_module(_tmpm) @@ -265,6 +273,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -273,7 +284,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, vl, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) +logger.info("".join(("Init 
lr: ", ",".join(iter_to_str(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) diff --git a/adv/train/train_dynb.py b/adv/train/train_dynb.py index 9706fc4..c89286a 100644 --- a/adv/train/train_dynb.py +++ b/adv/train/train_dynb.py @@ -1,38 +1,35 @@ #encoding: utf-8 import torch - +from random import shuffle from torch.optim import Adam as Optimizer +from loss.base import LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from utils.dynbatch import GradientMonitor +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.fmt.parser import parse_double_value_tuple +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.dynbatch import GradientMonitor -from utils.fmt.base import tostr, parse_double_value_tuple -from cnfg.vocab.base import pad_id - -from utils.fmt.base4torch import parse_cuda, load_emb - -from lrsch import GoogleLR as LRScheduler -from loss.base import LabelSmoothingLoss - -from random import shuffle - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import getlr, optm_step, optm_step_zero_grad_set_none +from utils.train.dss import dynamic_sample import cnfg.dynb as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT +from cnfg.vocab.base import pad_id update_angle = cnfg.update_angle enc_layer, dec_layer = parse_double_value_tuple(cnfg.nlayer) @@ -50,11 +47,9 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin, grad_mon, update_angle model.train() cur_b, _ls = 1, {} if save_loss else None - - global grad_mon, update_angle - src_grp, tgt_grp = td["src"], td["tgt"] for i_d in tqdm(tl, mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[i_d][()]) @@ -67,7 +62,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch, oi) loss = lossf(output, ot) if multi_gpu: @@ -83,12 +78,11 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[i_d] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add _perform_dyn_optm_step, _cos_sim = grad_mon.update(model.module if multi_gpu else model) - if _perform_dyn_optm_step or (_done_tokens >= 
tokens_optm): if not _perform_dyn_optm_step: grad_mon.reset() @@ -125,6 +119,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -132,7 +136,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -152,7 +156,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for i in tqdm(range(nd), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) @@ -163,7 +167,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) loss = lossf(output, ot) if multi_gpu: @@ -234,7 +238,7 @@ def load_fixing(module): tl = [str(i) for i in range(ntrain)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) fine_tune_m = cnfg.fine_tune_m @@ -273,6 +277,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -281,7 +288,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" 
".join(tostr(getlr(optimizer))), minloss, minerr,)) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,)) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) diff --git a/adv/train/train_probe.py b/adv/train/train_probe.py index 8f2ca06..4abac68 100644 --- a/adv/train/train_probe.py +++ b/adv/train/train_probe.py @@ -1,48 +1,44 @@ #encoding: utf-8 import torch - +from random import shuffle from torch.optim import Adam as Optimizer +from loss.base import LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.NMT import NMT as NMTBase +from transformer.Probe.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.fmt.base import tostr -from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda, load_emb - -from lrsch import GoogleLR as LRScheduler -from loss.base import LabelSmoothingLoss - -from random import shuffle - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import freeze_module, getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam +from utils.train.dss import dynamic_sample import cnfg.probe as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT as NMTBase -from transformer.Probe.NMT import NMT +from cnfg.vocab.base import pad_id probe_reorder = cnfg.probe_reorder def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): - global probe_reorder - ind_shift = 2 if probe_reorder else 1 - sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin, probe_reorder + ind_shift = 2 if probe_reorder else 1 model.train() cur_b, _ls = 1, {} if save_loss else None src_grp, tgt_grp = td["src"], td["tgt"] @@ -57,7 +53,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, ind_shift, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch, oi) loss = lossf(output, ot) if multi_gpu: @@ -73,7 +69,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[i_d] = loss_add / wd_add 
sum_wd += wd_add _done_tokens += wd_add @@ -103,6 +99,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -110,7 +116,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -134,7 +140,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for i in tqdm(range(nd), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) @@ -145,7 +151,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, ind_shift, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) loss = lossf(output, ot) if multi_gpu: @@ -220,7 +226,7 @@ def load_fixing(module): tl = [str(i) for i in range(ntrain)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) fine_tune_m = cnfg.fine_tune_m @@ -228,7 +234,7 @@ def load_fixing(module): mymodel.apply(init_fixing) if fine_tune_m is not None: logger.info("Load pre-trained model from: " + fine_tune_m) - _tmpm = NMTBase(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + _tmpm = NMTBase(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) _tmpm = init_model_params(_tmpm) _tmpm.apply(init_fixing) _tmpm = 
load_model_cpu(fine_tune_m, _tmpm) @@ -269,6 +275,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -277,7 +286,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("".join(("Init lr: ", ",".join(tostr(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) +logger.info("".join(("Init lr: ", ",".join(iter_to_str(getlr(optimizer))), ", Dev Loss/Error: %.3f %.2f" % (minloss, minerr)))) if fine_tune_m is None: save_model(mymodel.module.dec if multi_gpu else mymodel.dec, wkdir + "init.h5", False, print_func=logger.info) diff --git a/cnfg/README.md b/cnfg/README.md index d971ac6..01bd12c 100644 --- a/cnfg/README.md +++ b/cnfg/README.md @@ -82,6 +82,8 @@ nlayer = 6 drop = 0.1 # dropout rate applied to multi-head attention. attn_drop = drop +# dropout rate applied to the activation of FFN. +act_drop = drop # False for Hier/Incept Models norm_output = True @@ -139,7 +141,7 @@ Configuration of following variables: # reducing the optimization difficulty of models ease_optimization = True -# using lipschitz constraint parameter initialization in [Lipschitz Constrained Parameter Initialization for Deep Transformers](https://www.aclweb.org/anthology/2020.acl-main.38/) +# using lipschitz constraint parameter initialization in [Lipschitz Constrained Parameter Initialization for Deep Transformers](https://aclanthology.org/2020.acl-main.38/) lipschitz_initialization = True # using advanced activation function, choices: None, "GeLU", "Swish", "Sigmoid", "NormSwish" @@ -153,7 +155,7 @@ computation_order = "v2" # default cached sequence length (for positional embedding, etc.) cache_len_default = 256 -# window size (one side) of relative positional embeddings, 0 to disable. 8 and 16 are used in [Self-Attention with Relative Position Representations](https://www.aclweb.org/anthology/N18-2074/) for Transformer Base and Big respectively. relative_position_max_bucket_distance for the bucket relative positional encoding used by T5, [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://www.jmlr.org/papers/v21/20-074.html), which slightly hampers the performance on WMT 14 En-De. disable_std_pemb to disable the standard positional embedding when use the relative position, or to disable only the decoder side with a tuple (False, True,), useful for AAN. +# window size (one side) of relative positional embeddings, 0 to disable. 8 and 16 are used in [Self-Attention with Relative Position Representations](https://aclanthology.org/N18-2074/) for Transformer Base and Big respectively. relative_position_max_bucket_distance for the bucket relative positional encoding used by T5, [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://www.jmlr.org/papers/v21/20-074.html), which slightly hampers the performance on WMT 14 En-De. disable_std_pemb to disable the standard positional embedding when use the relative position, or to disable only the decoder side with a tuple (False, True,), useful for AAN. 
use_k_relative_position = 0 relative_position_max_bucket_distance = 0 disable_std_pemb = False @@ -170,10 +172,14 @@ normal_tokens_vs_pad_tokens = 4 # For BPE (using full vocabulary), the special token will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. use_unk = True +# learning rate, override by the GoogleLR in most cases +init_lr = 1e-4 + # enable tqdm progress bar. enable_tqdm = True -# trade CPU for IO and disk space, see [h5py](http://docs.h5py.org/en/stable/high/dataset.html) for details. +# trade CPU for IO and disk space, see [gzip](https://docs.python.org/3/library/gzip.html) and [h5py](http://docs.h5py.org/en/stable/high/dataset.html) for details. +raw_cache_compression_level = 9 # choices: None, "gzip", "lzf" hdf5_data_compression = "gzip" # choices: 0 to 9, default is 4. None for lzf. @@ -184,6 +190,8 @@ hdf5_model_compression_level = 0 hdf5_perf_over_camp = True # whether to track creation order. hdf5_track_order = False +# the existence of the names of model parameters. (save `named_parameters` or `parameters`) +hdf5_load_parameter_name = hdf5_save_parameter_name = False # prune with length penalty in each beam decoding step clip_beam_with_lp = True @@ -191,6 +199,9 @@ clip_beam_with_lp = True # optimize speed even if it sacrifices reproduction performance_over_reproduction = True +# use torch.inference_mode if supported +use_inference_mode = True + # enable torch checks, only support anomaly detection for the autograd engine currently. enable_torch_check = True diff --git a/cnfg/base.py b/cnfg/base.py index 332af1a..f53d644 100644 --- a/cnfg/base.py +++ b/cnfg/base.py @@ -19,7 +19,9 @@ # "":0, "":1, "":2, "":3 # add 3 to forbidden_indexes if there are tokens in data # must be None if use_fast_loss is set in cnfg/hyp.py + #from fbind import fbl + forbidden_indexes = None#[0, 1] + fbl save_auto_clean = True @@ -55,6 +57,7 @@ drop = 0.1 attn_drop = drop +act_drop = drop # False for Hier/Incept Models norm_output = True diff --git a/cnfg/hyp.py b/cnfg/hyp.py index 54dca0b..f7aee9a 100644 --- a/cnfg/hyp.py +++ b/cnfg/hyp.py @@ -3,6 +3,7 @@ ease_optimization = True lipschitz_initialization = True +lipschitz_scale = 1.0 # choices: None, "GeLU", "Swish", "Sigmoid", "SReLU", "Mish", "NormSwish" advance_activation_function = None @@ -10,12 +11,12 @@ use_glu_ffn = None # choices: "v1", "v2" -computation_order = "v2" +computation_order = "v1" # default cached sequence length (for positional embedding, etc.) cache_len_default = 256 -# window size (one side) of relative positional embeddings, 0 to disable. 8 and 16 are used in [Self-Attention with Relative Position Representations](https://www.aclweb.org/anthology/N18-2074/) for Transformer Base and Big respectively. relative_position_max_bucket_distance for the bucket relative positional encoding used by T5, [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://www.jmlr.org/papers/v21/20-074.html), which slightly hampers the performance on WMT 14 En-De. disable_std_pemb to disable the standard positional embedding when use the relative position, or to disable only the decoder side with a tuple (False, True,), useful for AAN. +# window size (one side) of relative positional embeddings, 0 to disable. 8 and 16 are used in [Self-Attention with Relative Position Representations](https://aclanthology.org/N18-2074/) for Transformer Base and Big respectively. 
relative_position_max_bucket_distance for the bucket relative positional encoding used by T5, [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://www.jmlr.org/papers/v21/20-074.html), which slightly hampers the performance on WMT 14 En-De. disable_std_pemb to disable the standard positional embedding when use the relative position, or to disable only the decoder side with a tuple (False, True,), useful for AAN. use_k_relative_position = 0 relative_position_max_bucket_distance = 0 disable_std_pemb = False @@ -32,10 +33,14 @@ # For BPE (using full vocabulary), the special token will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. use_unk = True +# learning rate, override by the GoogleLR in most cases +init_lr = 1e-4 + # enable tqdm progress bar. -enable_tqdm = True +enable_tqdm = False -# trade CPU for IO and disk space, see [h5py](http://docs.h5py.org/en/stable/high/dataset.html) for details. +# trade CPU for IO and disk space, see [gzip](https://docs.python.org/3/library/gzip.html) and [h5py](http://docs.h5py.org/en/stable/high/dataset.html) for details. +raw_cache_compression_level = 9 # choices: None, "gzip", "lzf" hdf5_data_compression = "gzip" # choices: 0 to 9, default is 4. None for lzf. @@ -46,6 +51,8 @@ hdf5_perf_over_camp = True # whether to track creation order. hdf5_track_order = False +# the existence of the names of model parameters. (save `named_parameters` or `parameters`) +hdf5_load_parameter_name = hdf5_save_parameter_name = True # prune with length penalty in each beam decoding step clip_beam_with_lp = True @@ -53,8 +60,14 @@ # optimize speed even if it sacrifices reproduction performance_over_reproduction = True +# use torch.inference_mode if supported +use_inference_mode = True + +# use torch.compile if supported +use_torch_compile = False + # enable torch checks, only support anomaly detection for the autograd engine currently. -enable_torch_check = True +enable_torch_check = False # accelerate optimizer by using contigous parameters and gradients. Disabling it leads to better performance. 
contiguous_parameters = False diff --git a/cnfg/ihyp.py b/cnfg/ihyp.py index 6925779..bbddd12 100644 --- a/cnfg/ihyp.py +++ b/cnfg/ihyp.py @@ -2,11 +2,11 @@ # this file interprets hyper-parameters assigned in cnfg/hyp.py -from cnfg.hyp import * - from math import inf -from utils.fmt.parser import parse_none, parse_double_value_tuple +from utils.fmt.parser import parse_double_value_tuple, parse_none + +from cnfg.hyp import * # C backend if use_c_backend is None: @@ -19,7 +19,7 @@ # the use of deterministic algorithms use_deterministic = not performance_over_reproduction -allow_fp16_reduction = use_deterministic +allow_tf32 = allow_fp16_reduction = performance_over_reproduction # biases enable_prev_ln_bias_default = enable_proj_bias_default = not ease_optimization @@ -42,8 +42,9 @@ disable_std_pemb_encoder, disable_std_pemb_decoder = parse_double_value_tuple(disable_std_pemb) relpos_reduction_with_zeros = True -# learning rate, override by the GoogleLR in most case -init_lr = 1e-4 +# torch_compile args +torch_compile_args = [] +torch_compile_kwargs = {"fullgraph": False, "dynamic": False} # hyper-parameters inf_default = inf diff --git a/cnfg/plm/bart/base.py b/cnfg/plm/bart/base.py index f268016..48ab5b3 100644 --- a/cnfg/plm/bart/base.py +++ b/cnfg/plm/bart/base.py @@ -21,5 +21,6 @@ drop = 0.1 attn_drop = drop +act_drop = drop norm_output = True diff --git a/cnfg/plm/bart/hyp.py b/cnfg/plm/bart/hyp.py index 1a0f95d..ebea36d 100644 --- a/cnfg/plm/bart/hyp.py +++ b/cnfg/plm/bart/hyp.py @@ -15,3 +15,6 @@ # For BPE (using full vocabulary), the special token will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. use_unk = True + +# learning rate +init_lr = 1e-5 diff --git a/cnfg/plm/bart/ihyp.py b/cnfg/plm/bart/ihyp.py index 22cd1e2..5d20998 100644 --- a/cnfg/plm/bart/ihyp.py +++ b/cnfg/plm/bart/ihyp.py @@ -14,7 +14,4 @@ adv_act = advance_activation_function.lower() if use_adv_act_default else None inplace_after_Custom_Act = use_adv_act_default and (adv_act not in set(["sigmoid"])) -# learning rate -init_lr = 1e-5 - ieps_ln_default = 1e-05 diff --git a/cnfg/plm/bert/base.py b/cnfg/plm/bert/base.py index 299f5f6..d781fda 100644 --- a/cnfg/plm/bert/base.py +++ b/cnfg/plm/bert/base.py @@ -20,5 +20,6 @@ drop = 0.1 attn_drop = drop +act_drop = 0.0 norm_output = True diff --git a/cnfg/plm/bert/hyp.py b/cnfg/plm/bert/hyp.py index 677b0ba..16cdc85 100644 --- a/cnfg/plm/bert/hyp.py +++ b/cnfg/plm/bert/hyp.py @@ -15,3 +15,6 @@ # For BPE (using full vocabulary), the special token will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. 
use_unk = True + +# learning rate +init_lr = 1e-5 diff --git a/cnfg/plm/bert/ihyp.py b/cnfg/plm/bert/ihyp.py index 2aaad0f..3135d18 100644 --- a/cnfg/plm/bert/ihyp.py +++ b/cnfg/plm/bert/ihyp.py @@ -14,7 +14,4 @@ adv_act = advance_activation_function.lower() if use_adv_act_default else None inplace_after_Custom_Act = use_adv_act_default and (adv_act not in set(["sigmoid"])) -# learning rate -init_lr = 1e-5 - ieps_ln_default = 1e-12 diff --git a/cnfg/plm/mbart/base.py b/cnfg/plm/mbart/base.py new file mode 100644 index 0000000..f42f189 --- /dev/null +++ b/cnfg/plm/mbart/base.py @@ -0,0 +1,26 @@ +#encoding: utf-8 + +from cnfg.base import * + +# new configurations for MBART +model_name = ("model.encoder", "model.decoder",) +num_type = None +remove_classifier_bias = True +pre_trained_m = None + +# override standard configurations +bindDecoderEmb = True +share_emb = True + +isize = 1024 +ff_hsize = isize * 4 +nhead = max(1, isize // 64) +attn_hsize = isize + +nlayer = 12 + +drop = 0.1 +attn_drop = 0.0 +act_drop = 0.0 + +norm_output = True diff --git a/cnfg/plm/mbart/hyp.py b/cnfg/plm/mbart/hyp.py new file mode 100644 index 0000000..aac774c --- /dev/null +++ b/cnfg/plm/mbart/hyp.py @@ -0,0 +1,20 @@ +#encoding: utf-8 + +from cnfg.hyp import * + +ease_optimization = False + +# choices: None, "GeLU", "Swish", "Sigmoid", "SReLU", "Mish", "NormSwish" +advance_activation_function = "GeLU" + +# choices: "v1", "v2" +computation_order = "v2" + +# default cached sequence length (for positional embedding, etc.) +cache_len_default = 1026 + +# For BPE (using full vocabulary), the special token will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. +use_unk = True + +# learning rate +init_lr = 1e-5 diff --git a/cnfg/plm/mbart/ihyp.py b/cnfg/plm/mbart/ihyp.py new file mode 100644 index 0000000..13e430d --- /dev/null +++ b/cnfg/plm/mbart/ihyp.py @@ -0,0 +1,17 @@ +#encoding: utf-8 + +from cnfg.ihyp import * +from cnfg.plm.mbart.hyp import * + +# biases +enable_prev_ln_bias_default = enable_proj_bias_default = not ease_optimization + +# computation order +norm_residual_default = not (computation_order.lower() == "v2") + +# activation fucntion +use_adv_act_default = advance_activation_function is not None +adv_act = advance_activation_function.lower() if use_adv_act_default else None +inplace_after_Custom_Act = use_adv_act_default and (adv_act not in set(["sigmoid"])) + +ieps_ln_default = 1e-05 diff --git a/cnfg/plm/roberta/base.py b/cnfg/plm/roberta/base.py index 4f48616..72b76bc 100644 --- a/cnfg/plm/roberta/base.py +++ b/cnfg/plm/roberta/base.py @@ -5,6 +5,7 @@ # new configurations for RoBERTa model_name = "roberta" num_type = 1 +eliminate_type_emb = False pre_trained_m = None # override standard configurations @@ -20,5 +21,6 @@ drop = 0.1 attn_drop = drop +act_drop = 0.0 norm_output = True diff --git a/cnfg/plm/roberta/hyp.py b/cnfg/plm/roberta/hyp.py index b0aaa3b..65928c2 100644 --- a/cnfg/plm/roberta/hyp.py +++ b/cnfg/plm/roberta/hyp.py @@ -15,3 +15,6 @@ # For BPE (using full vocabulary), the special token will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. 
use_unk = True + +# learning rate +init_lr = 1e-5 diff --git a/cnfg/plm/roberta/ihyp.py b/cnfg/plm/roberta/ihyp.py index 36b8169..322a503 100644 --- a/cnfg/plm/roberta/ihyp.py +++ b/cnfg/plm/roberta/ihyp.py @@ -14,7 +14,4 @@ adv_act = advance_activation_function.lower() if use_adv_act_default else None inplace_after_Custom_Act = use_adv_act_default and (adv_act not in set(["sigmoid"])) -# learning rate -init_lr = 1e-5 - ieps_ln_default = 1e-05 diff --git a/cnfg/plm/t5/base.py b/cnfg/plm/t5/base.py index a03f1bf..dfe6baa 100644 --- a/cnfg/plm/t5/base.py +++ b/cnfg/plm/t5/base.py @@ -21,5 +21,6 @@ drop = 0.1 attn_drop = drop +act_drop = drop norm_output = True diff --git a/cnfg/plm/t5/hyp.py b/cnfg/plm/t5/hyp.py index 223c791..2f1ca36 100644 --- a/cnfg/plm/t5/hyp.py +++ b/cnfg/plm/t5/hyp.py @@ -13,11 +13,15 @@ # default cached sequence length (for positional embedding, etc.) cache_len_default = 512 -# For BPE (using full vocabulary), the special token will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. -use_unk = True - -# window size (one side) of relative positional embeddings, 0 to disable. 8 and 16 are used in [Self-Attention with Relative Position Representations](https://www.aclweb.org/anthology/N18-2074/) for Transformer Base and Big respectively. relative_position_max_bucket_distance for the bucket relative positional encoding used by T5, [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://www.jmlr.org/papers/v21/20-074.html), which hampers the performance. disable_std_pemb to disable the standard positional embedding when use the relative position, or to disable only the decoder side with a tuple (False, True,), useful for AAN. +# window size (one side) of relative positional embeddings, 0 to disable. 8 and 16 are used in [Self-Attention with Relative Position Representations](https://aclanthology.org/N18-2074/) for Transformer Base and Big respectively. relative_position_max_bucket_distance for the bucket relative positional encoding used by T5, [Exploring the Limits of Transfer Learning with a Unified Text-to-Text Transformer](https://www.jmlr.org/papers/v21/20-074.html), which hampers the performance. disable_std_pemb to disable the standard positional embedding when use the relative position, or to disable only the decoder side with a tuple (False, True,), useful for AAN. use_k_relative_position = (15, 31,) relative_position_max_bucket_distance = 128 +# cross-attention layers have relative position encoding in the original T5 but not in v1.1. use_k_relative_position_cattn = 15 disable_std_pemb = True + +# For BPE (using full vocabulary), the special token will never appear and thus can be removed from the vocabulary. Otherwise, it should be set to True. 
+use_unk = True + +# learning rate +init_lr = 1e-5 diff --git a/cnfg/plm/t5/ihyp.py b/cnfg/plm/t5/ihyp.py index c231c2f..7351f58 100644 --- a/cnfg/plm/t5/ihyp.py +++ b/cnfg/plm/t5/ihyp.py @@ -22,7 +22,4 @@ disable_std_pemb_encoder, disable_std_pemb_decoder = parse_double_value_tuple(disable_std_pemb) relpos_reduction_with_zeros = True -# learning rate -init_lr = 1e-5 - ieps_ln_default = 1e-06 diff --git a/cnfg/prompt/roberta/base.py b/cnfg/prompt/roberta/base.py index b6631ce..a5f22f4 100644 --- a/cnfg/prompt/roberta/base.py +++ b/cnfg/prompt/roberta/base.py @@ -2,4 +2,8 @@ from cnfg.plm.roberta.base import * -classifier_indices = [6587, 12338] +freeze_word_embedding = True +classifier_indices = [6587, 372] + +reg_weight = 1.0 / 5.0 +reg_bias = None diff --git a/cnfg/server.py b/cnfg/server.py new file mode 100755 index 0000000..3cd4bb4 --- /dev/null +++ b/cnfg/server.py @@ -0,0 +1,7 @@ +#encoding: utf-8 + +thread_keeper_interval = 1.0 + +batcher_wait_interval = 0.015625 +batcher_maintain_interval = 3.0 +batcher_watcher_interval = batcher_wait_interval / 8.0 diff --git a/cnfg/vocab/plm/mbart.py b/cnfg/vocab/plm/mbart.py new file mode 100644 index 0000000..48fdb23 --- /dev/null +++ b/cnfg/vocab/plm/mbart.py @@ -0,0 +1,11 @@ +#encoding: utf-8 + +pad_id, sos_id, eos_id, unk_id, mask_id = 1, 0, 2, 3, 250026#250053 +vocab_size = 250027#250054 +pemb_start_ind = 2 + +shift_target_lang_id = True#False +add_sos_id = None + +lang_id = {"ar_AR": 250001, "cs_CZ": 250002, "de_DE": 250003, "en_XX": 250004, "es_XX": 250005, "et_EE": 250006, "fi_FI": 250007, "fr_XX": 250008, "gu_IN": 250009, "hi_IN": 250010, "it_IT": 250011, "ja_XX": 250012, "kk_KZ": 250013, "ko_KR": 250014, "lt_LT": 250015, "lv_LV": 250016, "my_MM": 250017, "ne_NP": 250018, "nl_XX": 250019, "ro_RO": 250020, "ru_RU": 250021, "si_LK": 250022, "tr_TR": 250023, "vi_VN": 250024, "zh_CN": 250025} +#lang_id = {"ar_AR": 250001, "cs_CZ": 250002, "de_DE": 250003, "en_XX": 250004, "es_XX": 250005, "et_EE": 250006, "fi_FI": 250007, "fr_XX": 250008, "gu_IN": 250009, "hi_IN": 250010, "it_IT": 250011, "ja_XX": 250012, "kk_KZ": 250013, "ko_KR": 250014, "lt_LT": 250015, "lv_LV": 250016, "my_MM": 250017, "ne_NP": 250018, "nl_XX": 250019, "ro_RO": 250020, "ru_RU": 250021, "si_LK": 250022, "tr_TR": 250023, "vi_VN": 250024, "zh_CN": 250025, "af_ZA": 250026, "az_AZ": 250027, "bn_IN": 250028, "fa_IR": 250029, "he_IL": 250030, "hr_HR": 250031, "id_ID": 250032, "ka_GE": 250033, "km_KH": 250034, "mk_MK": 250035, "ml_IN": 250036, "mn_MN": 250037, "mr_IN": 250038, "pl_PL": 250039, "ps_AF": 250040, "pt_XX": 250041, "sv_SE": 250042, "sw_KE": 250043, "ta_IN": 250044, "te_IN": 250045, "th_TH": 250046, "tl_XX": 250047, "uk_UA": 250048, "ur_PK": 250049, "xh_ZA": 250050, "gl_ES": 250051, "sl_SI": 250052} diff --git a/datautils/bpe.py b/datautils/bpe.py index a32ad2d..51cf5cd 100644 --- a/datautils/bpe.py +++ b/datautils/bpe.py @@ -6,7 +6,7 @@ class BPE(object): - def __init__(self, codes, merges=-1, separator="@@", vocab=None, glossaries=None): + def __init__(self, codes, merges=-1, separator="@@", vocab=None, glossaries=None, **kwargs): codes.seek(0) offset=1 @@ -222,7 +222,7 @@ def isolate_glossary(word, glossary): class BPERemover: - def __call__(self, input): + def __call__(self, input, **kwargs): if isinstance(input, (list, tuple,)): rs = [] @@ -234,7 +234,7 @@ def __call__(self, input): class BPEApplier: - def __init__(self, codesf, bpe_vcb=None, vocabulary_threshold=None, separator="@@", merges=-1, glossaries=None): + def __init__(self, codesf, bpe_vcb=None, 
vocabulary_threshold=None, separator="@@", merges=-1, glossaries=None, **kwargs): if bpe_vcb is not None: vocabulary = read_vocabulary(codecs.open(bpe_vcb, encoding="utf-8"), vocabulary_threshold) @@ -244,7 +244,7 @@ def __init__(self, codesf, bpe_vcb=None, vocabulary_threshold=None, separator="@ glossaries = [g.decode("utf-8") for g in glossaries] self.bpe = BPE(codecs.open(codesf, encoding="utf-8"), merges, separator, vocabulary, glossaries) - def __call__(self, input): + def __call__(self, input, **kwargs): if isinstance(input, (list, tuple,)): rs = [] diff --git a/datautils/moses.py b/datautils/moses.py index 65620ac..d94dae0 100644 --- a/datautils/moses.py +++ b/datautils/moses.py @@ -12,7 +12,7 @@ class ProcessWrapper: - def __init__(self, cmd=None): + def __init__(self, cmd=None, **kwargs): self.process = None self.cmd = [] if cmd is None else cmd @@ -30,7 +30,7 @@ def __del__(self): class LineProcessor(ProcessWrapper): - def __call__(self, input): + def __call__(self, input, **kwargs): self.process.stdin.write(("%s\n" % input.strip()).encode("utf-8", "ignore")) self.process.stdin.flush() @@ -39,7 +39,7 @@ def __call__(self, input): class BatchProcessor(ProcessWrapper): - def __call__(self, input): + def __call__(self, input, **kwargs): if isinstance(input, (list, tuple,)): rs = [] @@ -58,14 +58,14 @@ def __call__(self, input): class SentenceSplitter(ProcessWrapper): """Wrapper for standard Moses sentence splitter.""" - def __init__(self, lang): + def __init__(self, lang, **kwargs): ssplit_cmd = moses_scripts + sep.join(("ems", "support", "split-sentences.perl")) self.cmd = [perl_exec, ssplit_cmd, "-b", "-q", "-l", lang] self.process = None self.start() - def __call__(self, input): + def __call__(self, input, **kwargs): self.process.stdin.write((input.strip() + "\n
\n").encode("utf-8", "ignore")) self.process.stdin.flush() @@ -81,7 +81,7 @@ class Pretokenizer(BatchProcessor): """Pretokenizer wrapper. The pretokenizer fixes known issues with the input. """ - def __init__(self, lang): + def __init__(self, lang, **kwargs): pretok_cmd = moses_scripts + sep.join(("tokenizer", "pre-tokenizer.perl")) self.cmd = [perl_exec, pretok_cmd, "-b", "-q", "-l", lang] @@ -93,7 +93,7 @@ class Tokenizer(BatchProcessor): The pretokenizer fixes known issues with the input. """ # default args: ["-a", "-no-escape"] - def __init__(self, lang, args=["-a"]): + def __init__(self, lang, args=["-a"], **kwargs): tok_cmd = moses_scripts + sep.join(("tokenizer", "tokenizer.perl")) self.cmd = [perl_exec, tok_cmd, "-b", "-q", "-l", lang] + args @@ -102,7 +102,7 @@ def __init__(self, lang, args=["-a"]): class Normalizepunctuation(BatchProcessor): - def __init__(self, lang): + def __init__(self, lang, **kwargs): tok_cmd = moses_scripts + sep.join(("tokenizer", "normalize-punctuation.perl")) self.cmd = [perl_exec, tok_cmd, "-b", "-q", "-l", lang] @@ -111,7 +111,7 @@ def __init__(self, lang): class Truecaser(BatchProcessor): """Truecaser wrapper.""" - def __init__(self, model): + def __init__(self, model, **kwargs): truecase_cmd = moses_scripts + sep.join(("recaser", "truecase.perl")) self.cmd = [perl_exec, truecase_cmd, "-b", "--model", model] @@ -130,7 +130,7 @@ def __init__(self): class Detokenizer(BatchProcessor): # default args: ["-a", "-no-escape"] - def __init__(self, lang): + def __init__(self, lang, **kwargs): tok_cmd = moses_scripts + sep.join(("tokenizer", "detokenizer.perl")) self.cmd = [perl_exec, tok_cmd, "-q", "-b", "-l", lang] diff --git a/datautils/pymoses.py b/datautils/pymoses.py index 4033174..5b90d95 100644 --- a/datautils/pymoses.py +++ b/datautils/pymoses.py @@ -1,12 +1,12 @@ #encoding: utf-8 -from sacremoses import MosesPunctNormalizer, MosesTokenizer, MosesDetokenizer, MosesTruecaser, MosesDetruecaser +from sacremoses import MosesDetokenizer, MosesDetruecaser, MosesPunctNormalizer, MosesTokenizer, MosesTruecaser from utils.fmt.base import clean_list class BatchProcessor(): - def __call__(self, input): + def __call__(self, input, **kwargs): return [self.process(inputu) for inputu in input] if isinstance(input, (list, tuple,)) else self.process(input) @@ -17,7 +17,7 @@ def process(self, input): class Tokenizer(BatchProcessor): # default args: ["-a", "-no-escape"] - def __init__(self, lang, args=["-a"]): + def __init__(self, lang, args=["-a"], **kwargs): self.handler = MosesTokenizer(lang=lang) self.escape = not ("-no-escape" in args or "--no-escape" in args) @@ -29,7 +29,7 @@ def process(self, input): class Normalizepunctuation(BatchProcessor): - def __init__(self, lang): + def __init__(self, lang, **kwargs): self.handler = MosesPunctNormalizer(lang=lang) @@ -39,7 +39,7 @@ def process(self, input): class Truecaser(BatchProcessor): - def __init__(self, model): + def __init__(self, model, **kwargs): self.handler = MosesTruecaser(load_from=model) @@ -51,7 +51,7 @@ class Detruecaser(BatchProcessor): def __init__(self): - self.handler = MosesTruecaser() + self.handler = MosesDetruecaser() def process(self, input, is_headline=False): @@ -59,7 +59,7 @@ def process(self, input, is_headline=False): class Detokenizer(BatchProcessor): - def __init__(self, lang): + def __init__(self, lang, **kwargs): self.handler = MosesDetokenizer(lang=lang) diff --git a/datautils/zh.py b/datautils/zh.py index d2ea067..aadb409 100644 --- a/datautils/zh.py +++ b/datautils/zh.py @@ -7,11 +7,11 
@@ class SentenceSplitter: """Wrapper for standard Moses sentence splitter.""" - def __init__(self, splc=splitcode): + def __init__(self, splc=splitcode, **kwargs): self.splc = splc - def __call__(self, input): + def __call__(self, input, **kwargs): rs = [] ind = lind = 0 @@ -41,7 +41,7 @@ def __del__(self): nlpir.Exit() - def __call__(self, input): + def __call__(self, input, **kwargs): def clear_tag(strin): @@ -72,7 +72,7 @@ def clear_tag(strin): class Detokenizer: - def __call__(self, input): + def __call__(self, input, **kwargs): if not isinstance(input, (list, tuple,)): input = [input] diff --git a/loss/base.py b/loss/base.py index fc16b31..4e2211a 100644 --- a/loss/base.py +++ b/loss/base.py @@ -1,13 +1,13 @@ #encoding: utf-8 import torch -from torch.nn.modules.loss import _Loss, NLLLoss as NLLLossBase - -from torch.nn.functional import kl_div, nll_loss +from torch.nn.functional import cross_entropy, kl_div, nll_loss +from torch.nn.modules.loss import CrossEntropyLoss as CrossEntropyLossBase, NLLLoss as NLLLossBase, _Loss from utils.base import clear_pad_mask, eq_indexes from cnfg.ihyp import * +from cnfg.vocab.base import pad_id # ignores forbidden_index class FastLabelSmoothingLoss(_Loss): @@ -19,7 +19,7 @@ def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean self.conf = 1.0 - label_smoothing - self.smoothing_value # Faster implementation from fairseq: https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py#L33-L50, but do not support fbil. - def forward(self, input, target, mask=None): + def forward(self, input, target, mask=None, **kwargs): _tsize = list(input.size()) _tsize[-1] = 1 @@ -52,7 +52,7 @@ def forward(self, input, target, mask=None): class StdLabelSmoothingLoss(_Loss): - def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean", forbidden_index=-1): + def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean", forbidden_index=-1, **kwargs): super(StdLabelSmoothingLoss, self).__init__() @@ -62,10 +62,9 @@ def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean if isinstance(ignore_index, (list, tuple,)): tmp = [] for _tmp in ignore_index: - if (_tmp >= 0) and (_tmp not in tmp): + if (_tmp >= 0) and (_tmp not in fbil): tmp.append(_tmp) - if _tmp not in fbil: - fbil.add(_tmp) + fbil.add(_tmp) _nid = len(tmp) if _nid > 0: self.ignore_index = tuple(tmp) if _nid > 1 else tmp[0] @@ -89,14 +88,14 @@ def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean weight = torch.full((nclass,), smoothing_value) if fbil: weight.index_fill_(0, torch.as_tensor(tuple(fbil), dtype=torch.long, device=weight.device), 0.0) - self.register_buffer("weight", weight.unsqueeze(0)) + self.register_buffer("weight", weight.unsqueeze(0), persistent=False) self.conf = 1.0 - label_smoothing # input: (batch size, num_classes) # target: (batch size) # they will be flattened automatically if the dimension of input is larger than 2. 
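# Illustrative usage sketch (not part of the patch): driving the label-smoothing
# criterion above. The vocabulary size, batch shape, pad index 0 and the
# log-softmax convention are assumptions for this example, not values taken
# from the repository.
import torch
from torch.nn.functional import log_softmax

from loss.base import StdLabelSmoothingLoss

crit = StdLabelSmoothingLoss(nclass=32000, label_smoothing=0.1, ignore_index=0, reduction="sum")
logits = torch.randn(8, 16, 32000)              # (bsize, seql, nclass) raw scores
gold = torch.randint(0, 32000, (8, 16))         # (bsize, seql) gold indices, 0 assumed to be the pad id
loss = crit(log_softmax(logits, dim=-1), gold)  # inputs above 2-D are flattened automatically
# End of sketch.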
- def forward(self, input, target, mask=None): + def forward(self, input, target, mask=None, **kwargs): _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input _target = target.view(-1, 1) @@ -122,17 +121,25 @@ def forward(self, input, target, mask=None): class NLLLoss(NLLLossBase): - def forward(self, input, target): + def forward(self, input, target, **kwargs): rs = nll_loss(input.view(-1, input.size(-1)), target.view(-1), weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction) return rs.view(input.size()) if self.reduction == "none" and target.dim() > 1 else rs +class CrossEntropyLoss(CrossEntropyLossBase): + + def forward(self, input, target, **kwargs): + + rs = cross_entropy(input.view(-1, input.size(-1)), target.view(-1), weight=self.weight, ignore_index=self.ignore_index, reduction=self.reduction)#, label_smoothing=self.label_smoothing + + return rs.view(input.size()) if self.reduction == "none" and target.dim() > 1 else rs + class RankingLoss(_Loss): # input: (batch size) # target: (batch size) - def forward(self, input, target): + def forward(self, input, target, **kwargs): loss = input * target if self.reduction == "mean": @@ -142,7 +149,7 @@ def forward(self, input, target): class MultiLabelSmoothingLoss(_Loss): - def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean", forbidden_index=-1): + def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean", forbidden_index=-1, **kwargs): super(MultiLabelSmoothingLoss, self).__init__() @@ -186,10 +193,10 @@ def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean if fbilu: _tmp_w.index_fill_(0, torch.as_tensor(tuple(fbilu), dtype=torch.long, device=_tmp_w.device), 0.0) _weight.append(_tmp_w) - self.register_buffer("weight", torch.stack(_weight, 0).unsqueeze(1)) + self.register_buffer("weight", torch.stack(_weight, 0).unsqueeze(1), persistent=False) self.conf = 1.0 - label_smoothing - def forward(self, input, target, lang_id=0, mask=None): + def forward(self, input, target, lang_id=0, mask=None, **kwargs): _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input _target = target.view(-1, 1) @@ -214,16 +221,16 @@ def forward(self, input, target, lang_id=0, mask=None): class ReducedLabelSmoothingLoss(StdLabelSmoothingLoss): - def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean", forbidden_index=-1, reduce_dim=None): + def __init__(self, nclass, label_smoothing=0.1, ignore_index=-1, reduction="mean", forbidden_index=-1, reduce_dim=None, pad_id=pad_id, **kwargs): super(ReducedLabelSmoothingLoss, self).__init__(nclass, label_smoothing=label_smoothing, ignore_index=ignore_index, reduction=reduction, forbidden_index=forbidden_index) - self.reduce_dim = reduce_dim + self.reduce_dim, self.pad_id = reduce_dim, pad_id - def forward(self, input, target, mask=None): + def forward(self, input, target, mask=None, pad_id=None, **kwargs): if self.reduce_dim is not None: - input, target = clear_pad_mask([input, target], target.eq(0), [self.reduce_dim - 1, self.reduce_dim], mask_dim=self.reduce_dim, return_contiguous=True)[0] + input, target = clear_pad_mask([input, target], target.eq(self.pad_id if pad_id is None else pad_id), [self.reduce_dim - 1, self.reduce_dim], mask_dim=self.reduce_dim, return_contiguous=True)[0] _input = input.view(-1, input.size(-1)) if input.dim() > 2 else input _target = target.view(-1, 1) diff --git a/loss/mulang.py b/loss/mulang.py index aaff26f..a84b098 100644 --- 
a/loss/mulang.py +++ b/loss/mulang.py @@ -1,20 +1,18 @@ #encoding: utf-8 -import torch from torch.nn.functional import kl_div -from utils.base import eq_indexes - from loss.base import MultiLabelSmoothingLoss as MultiLabelSmoothingLossBase +from utils.base import eq_indexes class MultiLabelSmoothingLoss(MultiLabelSmoothingLossBase): def __init__(self, *inputs, **kwargs): super(MultiLabelSmoothingLoss, self).__init__(*inputs, **kwargs) - self.register_buffer("weight", self.weight.squeeze(1)) + self.register_buffer("weight", self.weight.squeeze(1), persistent=False) - def forward(self, input, target, tinput, mask=None): + def forward(self, input, target, tinput, mask=None, **kwargs): _rsize = list(input.size()) _nclass = _rsize[-1] diff --git a/loss/regression.py b/loss/regression.py new file mode 100644 index 0000000..c6a9ba8 --- /dev/null +++ b/loss/regression.py @@ -0,0 +1,23 @@ +#encoding: utf-8 + +import torch +from torch.nn.functional import kl_div +from torch.nn.modules.loss import _Loss + +class KLRegLoss(_Loss): + + def __init__(self, reg_right=True, weight=None, bias=None, reduction="mean", **kwargs): + + super(KLRegLoss, self).__init__() + self.weight, self.bias, self.reduction = weight, bias, reduction + self.p_func = (lambda x: torch.cat((1.0 - x, x,), dim=-1)) if reg_right else (lambda x: torch.cat((x, 1.0 - x,), dim=-1)) + + def forward(self, input, target, **kwargs): + + _target = target + if self.bias is not None: + _target = _target + self.bias + if self.weight is not None: + _target = _target * self.weight + + return kl_div(input, self.p_func(_target.unsqueeze(-1)), reduction=self.reduction) diff --git a/lrsch.py b/lrsch.py index 6c28e31..15dd558 100644 --- a/lrsch.py +++ b/lrsch.py @@ -1,15 +1,17 @@ #encoding: utf-8 -from torch.optim.lr_scheduler import _LRScheduler from math import sqrt +from torch.optim.lr_scheduler import _LRScheduler + +from cnfg.ihyp import init_lr class GoogleLR(_LRScheduler): - def __init__(self, optimizer, dmodel, warm_steps, scale=1.0, last_epoch=-1): + def __init__(self, optimizer, dmodel, warm_steps, scale=1.0, last_epoch=-1, **kwargs): self.cur_step, self.warm_steps = 0, warm_steps self.k = scale / sqrt(dmodel) - self.wk = self.k / sqrt(warm_steps) / warm_steps + self.wk = self.k / (warm_steps ** 1.5) super(GoogleLR, self).__init__(optimizer, last_epoch=last_epoch) def get_lr(self): @@ -22,7 +24,7 @@ def get_lr(self): # inverse square root with warm up, portal from: https://github.com/pytorch/fairseq/blob/master/fairseq/optim/lr_scheduler/inverse_square_root_schedule.py, equal to GoogleLR when warm_end_lr = 1.0 / sqrt(dmodel * warm_steps) class WarmUpInverseSqrtLR(_LRScheduler): - def __init__(self, optimizer, warm_end_lr, warm_steps, warm_init_lr=0.0, last_epoch=-1): + def __init__(self, optimizer, warm_end_lr, warm_steps, warm_init_lr=0.0, last_epoch=-1, **kwargs): self.cur_step, self.warm_end_lr, self.warm_steps, self.warm_init_lr = 0, warm_end_lr, warm_steps, warm_init_lr self.lr_step = (warm_end_lr - warm_init_lr) / warm_steps @@ -40,13 +42,13 @@ def get_lr(self): """ class GoogleLR(WarmUpInverseSqrtLR): - def __init__(self, optimizer, dmodel, warm_steps, scale=1.0, last_epoch=-1): + def __init__(self, optimizer, dmodel, warm_steps, scale=1.0, last_epoch=-1, **kwargs): super(GoogleLR, self).__init__(optimizer, scale / sqrt(dmodel * warm_steps), warm_steps, warm_init_lr=0.0, last_epoch=last_epoch)""" class InverseSqrtLR(_LRScheduler): - def __init__(self, optimizer, lr=1e-4, scalar=1.0, min_lr=None, last_epoch=-1): + def __init__(self, 
optimizer, lr=1e-4, scalar=1.0, min_lr=None, last_epoch=-1, **kwargs): self.cur_step = 0 self.base_lr = lr @@ -60,3 +62,19 @@ def get_lr(self): cur_lr = max(min(1.0, 1.0 / sqrt(self.cur_step / self.epoch_steps)), self.min_lr) * self.base_lr return [cur_lr for i in range(len(self.base_lrs))] + +class CustLR(_LRScheduler): + + def __init__(self, optimizer, lr_func=lambda a, b: (init_lr, b,), ctx=None, last_epoch=-1, **kwargs): + + self.cur_step = 0 + self.lr_func = lr_func + self.ctx = ctx + super(CustLR, self).__init__(optimizer, last_epoch=last_epoch) + + def get_lr(self): + + self.cur_step += 1 + cur_lr, self.ctx = self.lr_func(self.cur_step, self.ctx) + + return [cur_lr for i in range(len(self.base_lrs))] diff --git a/mkcy.py b/mkcy.py index d3fff0a..4f251a8 100644 --- a/mkcy.py +++ b/mkcy.py @@ -1,12 +1,10 @@ #encoding: utf-8 -from distutils.core import setup -from distutils.extension import Extension - from Cython.Build import cythonize -from Cython.Distutils import build_ext from Cython.Compiler import Options - +from Cython.Distutils import build_ext +from distutils.core import setup +from distutils.extension import Extension from os import walk from os.path import join as pjoin @@ -55,7 +53,7 @@ def walk_path(ptw, eccargs): eccargs = ["-Ofast", "-march=native", "-pipe", "-fomit-frame-pointer"] - baselist = ["lrsch.py", "translator.py"] + baselist = ["lrsch.py"] extlist = [Extension(get_name(pyf), [pyf], extra_compile_args=eccargs) for pyf in baselist] for _mp in ("parallel/", "loss/", "optm/", "modules/", "transformer/", "utils/", "datautils/",): _tmp = walk_path(_mp, eccargs) diff --git a/modules/LD.py b/modules/LD.py index eea9501..b7cf452 100644 --- a/modules/LD.py +++ b/modules/LD.py @@ -3,14 +3,14 @@ import torch from torch import nn -from modules.base import Scorer, Linear, Dropout from modules.act import Custom_Act +from modules.base import Dropout, Linear, Scorer from cnfg.ihyp import * class ATTNCombiner(nn.Module): - def __init__(self, isize, hsize=None, dropout=0.0, custom_act=use_adv_act_default): + def __init__(self, isize, hsize=None, dropout=0.0, custom_act=use_adv_act_default, **kwargs): super(ATTNCombiner, self).__init__() @@ -18,7 +18,7 @@ def __init__(self, isize, hsize=None, dropout=0.0, custom_act=use_adv_act_defaul self.net = nn.Sequential(Linear(isize * 2, _hsize), Dropout(dropout, inplace=True), Custom_Act() if custom_act else nn.Sigmoid(), Scorer(_hsize), nn.Sigmoid()) if dropout > 0.0 else nn.Sequential(Linear(isize * 2, _hsize), Custom_Act() if custom_act else nn.Sigmoid(), Scorer(_hsize), nn.Sigmoid()) - def forward(self, input1, input2, mask=None): + def forward(self, input1, input2, mask=None, **kwargs): scores = self.net(torch.cat((input1.expand_as(input2), input2,), dim=-1)) @@ -37,7 +37,7 @@ def forward(self, input1, input2, mask=None): class DATTNCombiner(nn.Module): - def __init__(self, isize, hsize=None, dropout=0.0, custom_act=use_adv_act_default): + def __init__(self, isize, hsize=None, dropout=0.0, custom_act=use_adv_act_default, **kwargs): super(DATTNCombiner, self).__init__() @@ -48,7 +48,7 @@ def __init__(self, isize, hsize=None, dropout=0.0, custom_act=use_adv_act_defaul # input1: (bsize, 1, isize) # input2: (bsize, seql, isize) # mask: (bsize, seql, 1) - def forward(self, input1, input2, mask=None): + def forward(self, input1, input2, mask=None, **kwargs): # scores: (bsize, seql, 1) scores = self.net(torch.cat((input1.expand_as(input2), input2,), dim=-1)) diff --git a/modules/TA.py b/modules/TA.py index 7541227..a51231a 100644 --- 
a/modules/TA.py +++ b/modules/TA.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from modules.base import ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase, PositionwiseFF as PositionwiseFFBase +from modules.base import PositionwiseFF as PositionwiseFFBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase from cnfg.ihyp import * @@ -49,10 +49,10 @@ class PositionwiseFF(PositionwiseFFBase): # isize: input dimension # hsize: hidden dimension - def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=True, **kwargs): + def __init__(self, isize, hsize=None, dropout=0.0, act_drop=None, norm_residual=True, **kwargs): - super(PositionwiseFF, self).__init__(isize, hsize=hsize, dropout=dropout, norm_residual=True, **kwargs) + super(PositionwiseFF, self).__init__(isize, hsize=hsize, dropout=dropout, act_drop=act_drop, norm_residual=True, **kwargs) - def forward(self, x): + def forward(self, x, **kwargs): return self.normer(self.net(x) + x) diff --git a/modules/aan.py b/modules/aan.py index 4c351d0..1a7152b 100644 --- a/modules/aan.py +++ b/modules/aan.py @@ -2,11 +2,14 @@ import torch from torch import nn -from modules.base import Linear, Dropout, Custom_Act + +from modules.base import Custom_Act, Dropout, Linear +from utils.fmt.parser import parse_none from cnfg.ihyp import * -# Average Attention is proposed in Accelerating Neural Transformer via an Average Attention Network (https://www.aclweb.org/anthology/P18-1166/) +# Average Attention is proposed in Accelerating Neural Transformer via an Average Attention Network (https://aclanthology.org/P18-1166/) + class AverageAttn(nn.Module): # isize: input size of Feed-forward NN @@ -15,14 +18,14 @@ class AverageAttn(nn.Module): # enable_ffn: using FFN to process the average bag-of-words representation # num_pos: maximum length of sentence cached, extended length will be generated while needed and droped immediately after that - def __init__(self, isize, hsize=None, dropout=0.0, enable_ffn=False, num_pos=cache_len_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default): + def __init__(self, isize, hsize=None, dropout=0.0, enable_ffn=False, num_pos=cache_len_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, **kwargs): super(AverageAttn, self).__init__() - _hsize = isize if hsize is None else hsize + _hsize = parse_none(hsize, isize) self.num_pos = num_pos - self.register_buffer("w", torch.Tensor(num_pos, 1)) + self.register_buffer("w", torch.Tensor(num_pos, 1), persistent=False) if enable_ffn: self.ffn = nn.Sequential(Linear(isize, _hsize, bias=enable_bias), nn.LayerNorm(_hsize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), Dropout(dropout, inplace=inplace_after_Custom_Act), Linear(_hsize, isize, bias=enable_proj_bias), Dropout(dropout, inplace=True)) if dropout > 0.0 else nn.Sequential(Linear(isize, _hsize, bias=enable_bias), nn.LayerNorm(_hsize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), Linear(_hsize, isize, bias=enable_proj_bias)) @@ -38,7 +41,7 @@ def __init__(self, isize, hsize=None, dropout=0.0, enable_ffn=False, num_pos=cac # iV: values (bsize, seql, vsize) # decoding: training state or decoding state - def forward(self, iQ, iV, decoding=False): + def forward(self, iQ, iV, decoding=False, **kwargs): if decoding: avg = iV diff --git 
a/modules/act.py b/modules/act.py index 7b3196a..344a619 100644 --- a/modules/act.py +++ b/modules/act.py @@ -1,13 +1,13 @@ #encoding: utf-8 import torch +from math import sqrt from torch import nn from torch.autograd import Function from torch.nn import functional as nnFunc from utils.base import reduce_model_list - -from math import sqrt +from utils.torch.comp import torch_no_grad from cnfg.ihyp import * @@ -21,7 +21,7 @@ def __init__(self): self.k = sqrt(2.0 / pi) - def forward(self, x): + def forward(self, x, **kwargs): return 0.5 * x * (1.0 + (self.k * (x + 0.044715 * x.pow(3.0))).tanh()) @@ -33,7 +33,7 @@ def __init__(self): self.k = sqrt(2.0) - def forward(self, x): + def forward(self, x, **kwargs): return 0.5 * x * (1.0 + (x / self.k).erf()) @@ -46,7 +46,7 @@ def forward(self, x): # GELU is nonmonotonic function that has a shape similar to Swish with beta = 1.4 (https://arxiv.org/abs/1710.05941). class CustSwish(nn.Module): - def __init__(self, beta=1.0, freeze_beta=True, isize=None, dim=-1 if adv_act == "normswish" else None, eps=ieps_default): + def __init__(self, beta=1.0, freeze_beta=True, isize=None, dim=-1 if adv_act == "normswish" else None, eps=ieps_default, **kwargs): super(CustSwish, self).__init__() @@ -58,7 +58,7 @@ def __init__(self, beta=1.0, freeze_beta=True, isize=None, dim=-1 if adv_act == self.beta = nn.Parameter(torch.as_tensor([beta])) if isize is None else nn.Parameter(torch.as_tensor([beta]).repeat(isize)) self.dim, self.eps = dim, eps - def forward(self, x): + def forward(self, x, **kwargs): if self.dim is None: _norm_x = x @@ -70,7 +70,7 @@ def forward(self, x): def fix_init(self): - with torch.no_grad(): + with torch_no_grad(): if self.reset_beta is not None: self.beta.fill_(self.reset_beta) @@ -81,19 +81,19 @@ def fix_init(self): class SReLU(nn.Module): - def __init__(self, inplace=False, k=2.0): + def __init__(self, inplace=False, k=2.0, **kwargs): super(SReLU, self).__init__() self.inplace, self.k = inplace, k - def forward(self, x): + def forward(self, x, **kwargs): return nnFunc.relu(x, inplace=self.inplace).pow(self.k) class CustMish(nn.Module): - def forward(self, x): + def forward(self, x, **kwargs): return x * nnFunc.softplus(x).tanh() @@ -104,13 +104,13 @@ def forward(self, x): class LGLU(nn.Module): - def __init__(self, dim=-1): + def __init__(self, dim=-1, **kwargs): super(LGLU, self).__init__() self.dim = dim - def forward(self, x): + def forward(self, x, **kwargs): _h, _t = x.tensor_split(2, self.dim) @@ -118,14 +118,14 @@ def forward(self, x): class GLU_Act(LGLU): - def __init__(self, act=None, dim=-1): + def __init__(self, act=None, dim=-1, **kwargs): super(GLU_Act, self).__init__() self.dim = dim self.act = nn.Sigmoid() if act is None else act - def forward(self, x): + def forward(self, x, **kwargs): _h, _t = x.tensor_split(2, self.dim) @@ -133,7 +133,7 @@ def forward(self, x): class GEGLU(GLU_Act): - def __init__(self, dim=-1): + def __init__(self, dim=-1, **kwargs): _act = GELU() super(GEGLU, self).__init__(act=_act, dim=dim) @@ -199,12 +199,12 @@ def backward(ctx, grad_output): class Sparsemax(nn.Module): - def __init__(self, dim=-1): + def __init__(self, dim=-1, **kwargs): super(Sparsemax, self).__init__() self.dim = dim - def forward(self, input): + def forward(self, input, **kwargs): return SparsemaxFunction.apply(input, self.dim) diff --git a/modules/attn/rap.py b/modules/attn/rap.py index fa20e2b..9d3943d 100644 --- a/modules/attn/rap.py +++ b/modules/attn/rap.py @@ -1,18 +1,17 @@ #encoding: utf-8 -from math import sqrt - import 
torch +from math import sqrt from torch import nn from torch.autograd import Function -from modules.base import CrossAttn as CrossAttnBase, SelfAttn as SelfAttnBase, ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase +from modules.base import CrossAttn as CrossAttnBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase, SelfAttn as SelfAttnBase from cnfg.ihyp import * class SelfAttn(SelfAttnBase): - def forward(self, iQ, mask=None, states=None): + def forward(self, iQ, mask=None, states=None, **kwargs): bsize, nquery = iQ.size()[:2] nheads = self.num_head @@ -72,12 +71,12 @@ def load_base(self, base_module): if self.rel_pemb is not None: self.k_rel_pos, self.xseql = base_module.k_rel_pos, base_module.xseql self.ref_rel_posm = base_module.ref_rel_posm - self.register_buffer("rel_pos", base_module.rel_pos) - self.register_buffer("rel_pos_cache", base_module.rel_pos_cache) + self.register_buffer("rel_pos", base_module.rel_pos, persistent=False) + self.register_buffer("rel_pos_cache", base_module.rel_pos_cache, persistent=False) class CrossAttn(CrossAttnBase): - def forward(self, iQ, iK, mask=None): + def forward(self, iQ, iK, mask=None, **kwargs): bsize, nquery = iQ.size()[:2] seql = iK.size(1) @@ -85,12 +84,12 @@ def forward(self, iQ, iK, mask=None): adim = self.attn_dim real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) - if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + if (self.real_iK is not None) and self.iK.is_set_to(iK) and self.is_decoding: real_iK, real_iV = self.real_iK, self.real_iV else: real_iK, real_iV = self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim).unbind(2) real_iK, real_iV = real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) - if not self.training: + if self.is_decoding: self.iK, self.real_iK, self.real_iV = iK, real_iK, real_iV scores = real_iQ.matmul(real_iK) / sqrt(adim) diff --git a/modules/attn/res.py b/modules/attn/res.py index c2be122..c5bf856 100644 --- a/modules/attn/res.py +++ b/modules/attn/res.py @@ -1,16 +1,15 @@ #encoding: utf-8 import torch -from torch.nn import functional as nnFunc from math import sqrt -from modules.base import SelfAttn as SelfAttnBase, CrossAttn as CrossAttnBase, ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase +from modules.base import CrossAttn as CrossAttnBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase, SelfAttn as SelfAttnBase from cnfg.ihyp import * class SelfAttn(SelfAttnBase): - def forward(self, iQ, mask=None, states=None, resin=None): + def forward(self, iQ, mask=None, states=None, resin=None, **kwargs): bsize, nquery = iQ.size()[:2] nheads = self.num_head @@ -61,7 +60,7 @@ def forward(self, iQ, mask=None, states=None, resin=None): class CrossAttn(CrossAttnBase): - def forward(self, iQ, iK, mask=None, resin=None): + def forward(self, iQ, iK, mask=None, resin=None, **kwargs): bsize, nquery = iQ.size()[:2] seql = iK.size(1) @@ -69,12 +68,12 @@ def forward(self, iQ, iK, mask=None, resin=None): adim = self.attn_dim real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) - if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + if (self.real_iK is not None) and self.iK.is_set_to(iK) and self.is_decoding: real_iK, real_iV = self.real_iK, self.real_iV else: real_iK, real_iV = self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim).unbind(2) real_iK, real_iV = real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) - if not self.training: + if 
self.is_decoding: self.iK, self.real_iK, self.real_iV = iK, real_iK, real_iV scores = real_iQ.matmul(real_iK) / sqrt(adim) diff --git a/modules/attn/retr.py b/modules/attn/retr.py index 1da5843..df67f2f 100644 --- a/modules/attn/retr.py +++ b/modules/attn/retr.py @@ -1,15 +1,11 @@ #encoding: utf-8 import torch -from torch import nn -from torch.nn import functional as nnFunc -from utils.base import exist_any +from math import sqrt +from modules.base import CrossAttn as CrossAttnBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase, SelfAttn as SelfAttnBase from modules.sampler import Retriever - -from modules.base import SelfAttn as SelfAttnBase, CrossAttn as CrossAttnBase, ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase - -from math import sqrt +from utils.torch.comp import exist_any, torch_no_grad from cnfg.ihyp import * @@ -23,12 +19,12 @@ def __init__(self, *args, xseql=cache_len_default, smoothing=None, use_cumsum=Fa self.smoothing, self.use_cumsum = smoothing if (smoothing is not None) and (smoothing > 0.0) and (smoothing < 1.0) else None, use_cumsum if self.use_cumsum and (self.smoothing is not None): self.num_pos = xseql - self.register_buffer("csum", torch.Tensor(xseql, 1)) + self.register_buffer("csum", torch.Tensor(xseql, 1), persistent=False) self.reset_parameters() else: - self.register_buffer("csum", None) + self.register_buffer("csum", None, persistent=False) - def forward(self, iQ, mask=None, states=None): + def forward(self, iQ, mask=None, states=None, **kwargs): bsize, nquery = iQ.size()[:2] seql = nquery @@ -116,7 +112,7 @@ def reset_parameters(self): def get_ext(self, npos): _rs = torch.arange(npos, dtype=self.csum.dtype, device=self.csum.device) - with torch.no_grad(): + with torch_no_grad(): _rs[0] = 1.0 return _rs.unsqueeze(-1) @@ -130,7 +126,7 @@ def __init__(self, *args, smoothing=None, **kwargs): self.retriever = Retriever() self.smoothing = smoothing if (smoothing is not None) and (smoothing > 0.0) and (smoothing < 1.0) else None - def forward(self, iQ, iK, mask=None): + def forward(self, iQ, iK, mask=None, **kwargs): bsize, nquery = iQ.size()[:2] seql = iK.size(1) @@ -138,12 +134,12 @@ def forward(self, iQ, iK, mask=None): adim = self.attn_dim real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) - if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + if (self.real_iK is not None) and self.iK.is_set_to(iK) and self.is_decoding: real_iK, real_iV = self.real_iK, self.real_iV else: real_iK, real_iV = self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim).unbind(2) real_iK, real_iV = real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2).contiguous() - if not self.training: + if self.is_decoding: self.iK, self.real_iK, self.real_iV = iK, real_iK, real_iV scores = real_iQ.matmul(real_iK) diff --git a/modules/base.py b/modules/base.py index 0dd2ccd..e7bf976 100644 --- a/modules/base.py +++ b/modules/base.py @@ -1,18 +1,19 @@ #encoding: utf-8 -from math import sqrt, log, exp import torch +from math import exp, log, sqrt from torch import nn -from torch.nn import functional as nnFunc from torch.autograd import Function from torch.utils.cpp_extension import load -from utils.base import reduce_model_list, repeat_bsize_for_beam_tensor -from utils.relpos.bucket import build_rel_pos_bucket_map, build_rel_pos_bucket from modules.act import Custom_Act, LGLU, get_act, reduce_model as reduce_model_act from modules.dropout import Dropout, reduce_model as reduce_model_drop - -from 
utils.pyctorch import transfer_CNone_tuple +from utils.base import reduce_model_list +from utils.decode.beam import repeat_bsize_for_beam_tensor +from utils.fmt.parser import parse_none +from utils.relpos.bucket import build_rel_pos_bucket, build_rel_pos_bucket_map +from utils.torch.comp import torch_no_grad +from utils.torch.pyc import transfer_CNone_tuple from cnfg.ihyp import * @@ -23,11 +24,12 @@ class PositionwiseFF(nn.Module): # isize: input dimension # hsize: hidden dimension - def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn): + def __init__(self, isize, hsize=None, dropout=0.0, act_drop=None, norm_residual=norm_residual_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn, **kwargs): super(PositionwiseFF, self).__init__() _hsize = isize * 4 if hsize is None else hsize + _act_drop = parse_none(act_drop, dropout) if (use_glu is not None) and (_hsize % 2 == 1): _hsize += 1 @@ -49,7 +51,8 @@ def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_d _.append(Linear(_hsize // 2, isize, bias=enable_bias)) if dropout > 0.0: _.append(Dropout(dropout, inplace=True)) - _.insert(_drop_ind, Dropout(dropout, inplace=inplace_after_Custom_Act)) + if _act_drop > 0.0: + _.insert(_drop_ind, Dropout(_act_drop, inplace=inplace_after_Custom_Act)) self.net = nn.Sequential(*_) self.normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) @@ -59,7 +62,7 @@ def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_d if self.c_available() and (use_glu is None): self.c_init() - def forward(self, x): + def forward(self, x, **kwargs): _out = self.normer(x) @@ -126,7 +129,7 @@ class PositionalEmb(nn.Module): # pos_offset: initial offset for position # dim_offset: initial offset for dimension - def __init__(self, num_dim, num_pos=cache_len_default, pos_offset=0, dim_offset=0, alpha=1.0): + def __init__(self, num_dim, num_pos=cache_len_default, pos_offset=0, dim_offset=0, alpha=1.0, **kwargs): super(PositionalEmb, self).__init__() @@ -135,12 +138,12 @@ def __init__(self, num_dim, num_pos=cache_len_default, pos_offset=0, dim_offset= self.poff = pos_offset self.doff = dim_offset self.alpha = alpha - self.register_buffer("w", torch.Tensor(num_pos, num_dim)) + self.register_buffer("w", torch.Tensor(num_pos, num_dim), persistent=False) self.reset_parameters() # x: input (bsize, seql) - def forward(self, x, expand=True): + def forward(self, x, expand=True, **kwargs): bsize, seql = x.size() @@ -197,7 +200,7 @@ class MultiHeadAttn(nn.Module): # sparsenorm: using sparse normer or standard softmax # bind_qk: query and key can share a same linear transformation for the Reformer: The Efficient Transformer (https://arxiv.org/abs/2001.04451) paper. 
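# Illustrative usage sketch (not part of the patch): calling the MultiHeadAttn
# module above as generic attention. Tensor shapes follow the shape comments in
# this file; the concrete sizes, the boolean mask convention and the assumption
# that the module returns (bsize, num_query, osize) are illustrative only. Note
# that with this patch the key/value cache is reused only when is_decoding is
# set, rather than whenever the module is in eval mode.
import torch

from modules.base import MultiHeadAttn

attn = MultiHeadAttn(isize=512, hsize=512, osize=512, num_head=8, dropout=0.1)
iQ = torch.randn(4, 7, 512)                     # queries: (bsize, num_query, isize)
iK = torch.randn(4, 11, 512)                    # keys:    (bsize, seql, isize)
iV = torch.randn(4, 11, 512)                    # values:  (bsize, seql, isize)
mask = torch.zeros(4, 7, 11, dtype=torch.bool)  # (bsize, num_query, seql), True = masked out (assumed)
out = attn(iQ, iK, iV, mask=mask)               # expected: (bsize, num_query, osize)
# End of sketch.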
- def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=0, uni_direction_reduction=False, is_left_to_right_reduction=True, zero_reduction=relpos_reduction_with_zeros, max_bucket_distance=0, sparsenorm=False, bind_qk=False, xseql=cache_len_default): + def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=0, uni_direction_reduction=False, is_left_to_right_reduction=True, zero_reduction=relpos_reduction_with_zeros, max_bucket_distance=0, sparsenorm=False, bind_qk=False, xseql=cache_len_default, is_decoding=False, **kwargs): super(MultiHeadAttn, self).__init__() @@ -206,7 +209,7 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v self.num_head = num_head self.query_adaptor = Linear(isize, self.hsize, bias=enable_proj_bias) - _k_isize = isize if k_isize is None else k_isize + _k_isize = parse_none(k_isize, isize) self.key_adaptor = self.query_adaptor if bind_qk and isize == _k_isize else Linear(_k_isize, self.hsize, bias=enable_proj_bias) self.value_adaptor = Linear(_k_isize if v_isize is None else v_isize, self.hsize, bias=enable_proj_bias) @@ -220,8 +223,8 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v if k_rel_pos > 0: self.rel_shift = k_rel_pos if max_bucket_distance > 0: - self.register_buffer("rel_pos_map", build_rel_pos_bucket_map(k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=uni_direction_reduction)) - self.register_buffer("rel_pos", build_rel_pos_bucket(xseql, k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=uni_direction_reduction, dis_map=self.rel_pos_map)) + self.register_buffer("rel_pos_map", build_rel_pos_bucket_map(k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=uni_direction_reduction), persistent=False) + self.register_buffer("rel_pos", build_rel_pos_bucket(xseql, k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=uni_direction_reduction, dis_map=self.rel_pos_map), persistent=False) self.rel_pemb = nn.Embedding((k_rel_pos + 1) if uni_direction else (k_rel_pos + k_rel_pos + 1), self.num_head) self.clamp_max, self.clamp_min = max_bucket_distance, uni_direction_reduction else: @@ -246,19 +249,20 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v self.clamp_min, self.clamp_max = -k_rel_pos, k_rel_pos self.rel_pemb = nn.Embedding(_n_pemb, self.attn_dim, padding_idx=padding_idx) _rpm = torch.arange(0, xseql, dtype=torch.long) - self.register_buffer("rel_pos", (_rpm.unsqueeze(0) - _rpm.unsqueeze(1)).clamp(min=self.clamp_min, max=self.clamp_max) + self.rel_shift) - self.register_buffer("rel_pos_map", None) + self.register_buffer("rel_pos", (_rpm.unsqueeze(0) - _rpm.unsqueeze(1)).clamp(min=self.clamp_min, max=self.clamp_max) + self.rel_shift, persistent=False) + self.register_buffer("rel_pos_map", None, persistent=False) self.xseql = xseql # the buffer can be shared inside the encoder or the decoder across layers for saving memory, by setting self.ref_rel_posm of self attns in deep layers to SelfAttn in layer 0, and sharing corresponding self.rel_pos self.ref_rel_posm = None - self.register_buffer("rel_pos_cache", None) + self.register_buffer("rel_pos_cache", None, persistent=False) else: self.rel_pemb = None - self.register_buffer("real_iK", None) - 
self.register_buffer("real_iV", None) - self.register_buffer("iK", None) - self.register_buffer("iV", None) + self.register_buffer("real_iK", None, persistent=False) + self.register_buffer("real_iV", None, persistent=False) + self.register_buffer("iK", None, persistent=False) + self.register_buffer("iV", None, persistent=False) + self.is_decoding = is_decoding if self.c_available(): self.c_init() @@ -268,7 +272,7 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, v # iV: values (bsize, seql, vsize) # mask (bsize, num_query, seql) - def forward(self, iQ, iK, iV, mask=None, states=None): + def forward(self, iQ, iK, iV, mask=None, states=None, **kwargs): bsize, nquery = iQ.size()[:2] seql = iK.size(1) @@ -281,11 +285,11 @@ def forward(self, iQ, iK, iV, mask=None, states=None): real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) - if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + if (self.real_iK is not None) and self.iK.is_set_to(iK) and self.is_decoding: real_iK = self.real_iK else: real_iK = self.key_adaptor(iK).view(bsize, seql, nheads, adim).permute(0, 2, 3, 1) - if not self.training: + if self.is_decoding: self.iK, self.real_iK = iK, real_iK if (self.real_iV is not None) and self.iV.is_set_to(iV) and (not self.training): real_iV = self.real_iV @@ -428,7 +432,7 @@ def c_process_output(self, rs, iK, iV, states=None): # Accelerated MultiHeadAttn for self attention, use when Q == K == V class SelfAttn(nn.Module): - def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=use_k_relative_position, uni_direction_reduction=False, is_left_to_right_reduction=True, zero_reduction=relpos_reduction_with_zeros, max_bucket_distance=0, sparsenorm=False, xseql=cache_len_default): + def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=use_k_relative_position, uni_direction_reduction=False, is_left_to_right_reduction=True, zero_reduction=relpos_reduction_with_zeros, max_bucket_distance=0, sparsenorm=False, xseql=cache_len_default, **kwargs): super(SelfAttn, self).__init__() @@ -448,8 +452,8 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=ena if k_rel_pos > 0: self.rel_shift = k_rel_pos if max_bucket_distance > 0: - self.register_buffer("rel_pos_map", build_rel_pos_bucket_map(k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=uni_direction_reduction)) - self.register_buffer("rel_pos", build_rel_pos_bucket(xseql, k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=uni_direction_reduction, dis_map=self.rel_pos_map)) + self.register_buffer("rel_pos_map", build_rel_pos_bucket_map(k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=uni_direction_reduction), persistent=False) + self.register_buffer("rel_pos", build_rel_pos_bucket(xseql, k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=uni_direction_reduction, dis_map=self.rel_pos_map), persistent=False) self.rel_pemb = nn.Embedding((k_rel_pos + 1) if uni_direction_reduction else (k_rel_pos + k_rel_pos + 1), self.num_head) self.clamp_max, self.clamp_min = max_bucket_distance, uni_direction_reduction else: @@ -474,19 +478,19 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=ena self.clamp_min, self.clamp_max = -k_rel_pos, k_rel_pos self.rel_pemb = 
nn.Embedding(_n_pemb, self.attn_dim, padding_idx=padding_idx) _rpm = torch.arange(0, xseql, dtype=torch.long) - self.register_buffer("rel_pos", (_rpm.unsqueeze(0) - _rpm.unsqueeze(1)).clamp(min=self.clamp_min, max=self.clamp_max) + self.rel_shift) - self.register_buffer("rel_pos_map", None) + self.register_buffer("rel_pos", (_rpm.unsqueeze(0) - _rpm.unsqueeze(1)).clamp(min=self.clamp_min, max=self.clamp_max) + self.rel_shift, persistent=False) + self.register_buffer("rel_pos_map", None, persistent=False) self.xseql = xseql # the buffer can be shared inside the encoder or the decoder across layers for saving memory, by setting self.ref_rel_posm of self attns in deep layers to SelfAttn in layer 0, and sharing corresponding self.rel_pos self.ref_rel_posm = None - self.register_buffer("rel_pos_cache", None) + self.register_buffer("rel_pos_cache", None, persistent=False) else: self.rel_pemb = None if self.c_available(): self.c_init() - def forward(self, iQ, mask=None, states=None): + def forward(self, iQ, mask=None, states=None, **kwargs): bsize, nquery = iQ.size()[:2] nheads = self.num_head @@ -598,7 +602,7 @@ def c_process_output(self, rs, states=None): # Accelerated MultiHeadAttn for cross attention, use when K == V class CrossAttn(nn.Module): - def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, sparsenorm=False): + def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, sparsenorm=False, is_decoding=False, **kwargs): super(CrossAttn, self).__init__() @@ -617,14 +621,15 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, e self.drop = Dropout(dropout, inplace=sparsenorm) if dropout > 0.0 else None - self.register_buffer("real_iK", None) - self.register_buffer("real_iV", None) - self.register_buffer("iK", None) + self.register_buffer("real_iK", None, persistent=False) + self.register_buffer("real_iV", None, persistent=False) + self.register_buffer("iK", None, persistent=False) + self.is_decoding = is_decoding if self.c_available(): self.c_init() - def forward(self, iQ, iK, mask=None): + def forward(self, iQ, iK, mask=None, **kwargs): bsize, nquery = iQ.size()[:2] seql = iK.size(1) @@ -632,12 +637,12 @@ def forward(self, iQ, iK, mask=None): adim = self.attn_dim real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) - if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + if (self.real_iK is not None) and self.iK.is_set_to(iK) and self.is_decoding: real_iK, real_iV = self.real_iK, self.real_iV else: real_iK, real_iV = self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim).unbind(2) real_iK, real_iV = real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) - if not self.training: + if self.is_decoding: self.iK, self.real_iK, self.real_iV = iK, real_iK, real_iV scores = real_iQ.matmul(real_iK) / sqrt(adim) @@ -658,6 +663,7 @@ def train(self, mode=True): if mode: self.reset_buffer() + self.is_decoding = not mode return self @@ -1011,7 +1017,7 @@ class ResidueCombiner(nn.Module): # isize: input size of Feed-forward NN - def __init__(self, isize, ncomb=2, hsize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): + def __init__(self, isize, ncomb=2, hsize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, **kwargs): 
super(ResidueCombiner, self).__init__() @@ -1022,7 +1028,7 @@ def __init__(self, isize, ncomb=2, hsize=None, dropout=0.0, custom_act=use_adv_a self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - def forward(self, *xl): + def forward(self, *xl, **kwargs): # faster only when len(xl) is very large #out = torch.stack([self.net(torch.cat(xl, -1))] + list(xl), -2).sum(-2) @@ -1034,14 +1040,14 @@ def forward(self, *xl): class Scorer(nn.Module): - def __init__(self, isize, bias=True): + def __init__(self, isize, bias=True, **kwargs): super(Scorer, self).__init__() self.w = nn.Parameter(torch.Tensor(isize).uniform_(- sqrt(1.0 / isize), sqrt(1.0 / isize))) self.bias = nn.Parameter(torch.zeros(1)) if bias else None - def forward(self, x): + def forward(self, x, **kwargs): xsize = x.size() @@ -1054,7 +1060,7 @@ def forward(self, x): class NDWrapper(nn.Module): - def __init__(self, module, num_dim): + def __init__(self, module, num_dim, **kwargs): super(NDWrapper, self).__init__() @@ -1081,7 +1087,7 @@ def forward(ctx, inputs, adv_weight=1.0): @staticmethod def backward(ctx, grad_outputs): - if grad_outputs is not None and ctx.needs_input_grad[0]: + if (grad_outputs is not None) and ctx.needs_input_grad[0]: _adv_weight = ctx.adv_weight return -grad_outputs if _adv_weight == 1.0 else (grad_outputs * -_adv_weight), None else: @@ -1091,13 +1097,13 @@ def backward(ctx, grad_outputs): class GradientReversalLayer(nn.Module): - def __init__(self, adv_weight=1.0): + def __init__(self, adv_weight=1.0, **kwargs): super(GradientReversalLayer, self).__init__() self.adv_weight = adv_weight - def forward(self, *inputs): + def forward(self, *inputs, **kwargs): return tuple(GradientReversalFunc(inputu, self.adv_weight) for inputu in inputs) if len(inputs) > 1 else GradientReversalFunc(inputs[0], self.adv_weight) @@ -1128,18 +1134,18 @@ def __init__(self): super(ACT_Loss, self).__init__() - def forward(self, weight, weight_loss, remain_value): + def forward(self, weight, weight_loss, remain_value, **kwargs): return ACTLossFunction.apply(weight, weight_loss, remain_value) class ApproximateEmb(nn.Module): - def __init__(self, weight): + def __init__(self, weight, **kwargs): super(ApproximateEmb, self).__init__() self.weight = weight - def forward(self, inpute): + def forward(self, inpute, **kwargs): isize = list(inpute.size()) out = inpute.view(-1, isize[-1]) @@ -1152,7 +1158,7 @@ class SparseNormer(nn.Module): # dim: dimension to normalize - def __init__(self, dim=-1, eps=ieps_default): + def __init__(self, dim=-1, eps=ieps_default, **kwargs): super(SparseNormer, self).__init__() @@ -1161,7 +1167,7 @@ def __init__(self, dim=-1, eps=ieps_default): self.act = nn.ReLU(inplace=True) self.eps = eps - def forward(self, x): + def forward(self, x, **kwargs): _tmp = self.act(x + self.bias) _tmp = _tmp * _tmp @@ -1174,7 +1180,7 @@ class MHSparseNormer(nn.Module): # nheads: number of heads # dim: dimension to normalize - def __init__(self, nheads, dim=-1, eps=ieps_default): + def __init__(self, nheads, dim=-1, eps=ieps_default, **kwargs): super(MHSparseNormer, self).__init__() @@ -1184,7 +1190,7 @@ def __init__(self, nheads, dim=-1, eps=ieps_default): self.eps = eps # input should be: (bsize, nheads, nquery, seql) - def forward(self, x): + def forward(self, x, **kwargs): _tmp = self.act(x + self.bias) _tmp = _tmp * _tmp @@ -1194,12 +1200,12 @@ def forward(self, x): def fix_init(self): - with torch.no_grad(): + with torch_no_grad(): self.bias.data.zero_() class 
MHAttnSummer(nn.Module): - def __init__(self, isize, ahsize=None, num_head=8, attn_drop=0.0): + def __init__(self, isize, ahsize=None, num_head=8, attn_drop=0.0, **kwargs): super(MHAttnSummer, self).__init__() @@ -1207,13 +1213,13 @@ def __init__(self, isize, ahsize=None, num_head=8, attn_drop=0.0): self.attn = CrossAttn(isize, isize if ahsize is None else ahsize, isize, num_head, dropout=attn_drop) # x: (bsize, seql, isize) - def forward(self, x): + def forward(self, x, **kwargs): return self.attn(self.w, x).squeeze(1) class FertSummer(nn.Module): - def __init__(self, isize): + def __init__(self, isize, **kwargs): super(FertSummer, self).__init__() @@ -1221,7 +1227,7 @@ def __init__(self, isize): self.normer = nn.Softmax(dim=1) # x: (bsize, seql, isize) - def forward(self, x, mask=None): + def forward(self, x, mask=None, **kwargs): _weight = self.net(x) if mask is not None: @@ -1238,7 +1244,7 @@ class CoordinateEmb(nn.Module): # pos_offset: initial offset for position # dim_offset: initial offset for dimension - def __init__(self, num_dim, num_pos=cache_len_default, num_steps=8, pos_offset=0, dim_offset=0, alpha=1.0): + def __init__(self, num_dim, num_pos=cache_len_default, num_steps=8, pos_offset=0, dim_offset=0, alpha=1.0, **kwargs): super(CoordinateEmb, self).__init__() @@ -1248,12 +1254,12 @@ def __init__(self, num_dim, num_pos=cache_len_default, num_steps=8, pos_offset=0 self.poff = pos_offset self.doff = dim_offset self.alpha = alpha - self.register_buffer("w", torch.Tensor(num_steps, num_pos, num_dim)) + self.register_buffer("w", torch.Tensor(num_steps, num_pos, num_dim), persistent=False) self.reset_parameters() # x: input (bsize, seql) - def forward(self, x, step, expand=True): + def forward(self, x, step, expand=True, **kwargs): bsize, seql = x.size()[:2] @@ -1305,7 +1311,7 @@ def get_pos(self, step, layer): class Temperature(nn.Module): - def __init__(self, isize, minv=0.125): + def __init__(self, isize, minv=0.125, **kwargs): super(Temperature, self).__init__() @@ -1315,7 +1321,7 @@ def __init__(self, isize, minv=0.125): self.k = nn.Parameter(torch.ones(1)) self.minv = minv - def forward(self, x): + def forward(self, x, **kwargs): xsize = x.size() @@ -1328,7 +1334,7 @@ def forward(self, x): def fix_init(self): - with torch.no_grad(): + with torch_no_grad(): self.k.data.fill_(1.0) self.bias.data.zero_() diff --git a/modules/cpp/act/setup.py b/modules/cpp/act/setup.py index 386f16d..a0a5bfb 100644 --- a/modules/cpp/act/setup.py +++ b/modules/cpp/act/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="act_cpp", ext_modules=[cpp_extension.CppExtension("act_cpp", ["modules/cpp/act/act.cpp", "modules/cpp/act/act_func.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/attn/cross/setup.py b/modules/cpp/base/attn/cross/setup.py index a98fa27..ff9c607 100644 --- a/modules/cpp/base/attn/cross/setup.py +++ b/modules/cpp/base/attn/cross/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="cross_attn_cpp", ext_modules=[cpp_extension.CppExtension("cross_attn_cpp", ["modules/cpp/base/attn/cross/attn.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/attn/self/setup.py b/modules/cpp/base/attn/self/setup.py index e900b39..2db210b 100644 --- a/modules/cpp/base/attn/self/setup.py +++ 
b/modules/cpp/base/attn/self/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="self_attn_cpp", ext_modules=[cpp_extension.CppExtension("self_attn_cpp", ["modules/cpp/base/attn/self/attn.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/attn/setup.py b/modules/cpp/base/attn/setup.py index 6bc0d03..a71735c 100644 --- a/modules/cpp/base/attn/setup.py +++ b/modules/cpp/base/attn/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="attn_cpp", ext_modules=[cpp_extension.CppExtension("attn_cpp", ["modules/cpp/base/attn/attn.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/ffn/setup.py b/modules/cpp/base/ffn/setup.py index 3e7dc72..c3feb5e 100644 --- a/modules/cpp/base/ffn/setup.py +++ b/modules/cpp/base/ffn/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="pff_cpp", ext_modules=[cpp_extension.CppExtension("pff_cpp", ["modules/cpp/base/ffn/pff.cpp", "modules/cpp/base/ffn/pff_func.cpp", "modules/cpp/act/act_func.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/resattn/cross/setup.py b/modules/cpp/base/resattn/cross/setup.py index 9bdff98..756511f 100644 --- a/modules/cpp/base/resattn/cross/setup.py +++ b/modules/cpp/base/resattn/cross/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="res_cross_attn_cpp", ext_modules=[cpp_extension.CppExtension("res_cross_attn_cpp", ["modules/cpp/base/resattn/cross/attn.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/resattn/self/setup.py b/modules/cpp/base/resattn/self/setup.py index 5f1d403..19abbff 100644 --- a/modules/cpp/base/resattn/self/setup.py +++ b/modules/cpp/base/resattn/self/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="res_self_attn_cpp", ext_modules=[cpp_extension.CppExtension("res_self_attn_cpp", ["modules/cpp/base/resattn/self/attn.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/cpp/base/resattn/setup.py b/modules/cpp/base/resattn/setup.py index c3af675..1910754 100644 --- a/modules/cpp/base/resattn/setup.py +++ b/modules/cpp/base/resattn/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="res_attn_cpp", ext_modules=[cpp_extension.CppExtension("res_attn_cpp", ["modules/cpp/base/resattn/attn.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/cpp/group/setup.py b/modules/cpp/group/setup.py index 97380c0..644163a 100644 --- a/modules/cpp/group/setup.py +++ b/modules/cpp/group/setup.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="group_cpp", ext_modules=[cpp_extension.CppExtension("group_cpp", ["modules/cpp/group/group.cpp", "modules/cpp/group/group_func.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git 
a/modules/cpp/hplstm/setup.py b/modules/cpp/hplstm/setup.py index 1dc4fd6..75399ed 100644 --- a/modules/cpp/hplstm/setup.py +++ b/modules/cpp/hplstm/setup.py @@ -1,6 +1,7 @@ #encoding: utf-8 -from setuptools import setup, Extension +from setuptools import setup from torch.utils import cpp_extension setup(name="lgate_cpp", ext_modules=[cpp_extension.CppExtension("lgate_cpp", ["modules/cpp/hplstm/lgate.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) +setup(name="lgate_nocx_cpp", ext_modules=[cpp_extension.CppExtension("lgate_nocx_cpp", ["modules/cpp/hplstm/lgate_nocx.cpp"])], cmdclass={"build_ext": cpp_extension.BuildExtension}) diff --git a/modules/dropout.py b/modules/dropout.py index bb57650..cf9d46f 100644 --- a/modules/dropout.py +++ b/modules/dropout.py @@ -1,27 +1,37 @@ #encoding: utf-8 -from torch import nn -from random import random +import torch from math import ceil +from random import random +from torch import nn -from utils.base import mask_tensor_type, reduce_model_list +from utils.base import reduce_model_list +from utils.torch.comp import mask_tensor_type Dropout = nn.Dropout class TokenDropout(Dropout): - def __init__(self, p=0.5, inplace=False, keep_magnitude=True): + def __init__(self, p=0.5, inplace=False, keep_magnitude=True, **kwargs): super(TokenDropout, self).__init__(p=p, inplace=inplace) self.keep_magnitude = (1.0 / (1.0 - self.p)) if keep_magnitude else False + self.register_buffer("pcache", torch.full((1,), self.p), persistent=False) - def forward(self, inpute): + def forward(self, inpute, **kwargs): if self.training: - mask = inpute.new_full(inpute.size()[:-1], self.p, requires_grad=False).bernoulli().to(mask_tensor_type, non_blocking=True).unsqueeze(-1) - out = inpute.masked_fill_(mask, 0.0) if self.inplace else inpute.masked_fill(mask, 0.0) - if self.keep_magnitude: - out = out * self.keep_magnitude + _ = inpute.dim() - 1 + _p = self.pcache.view([1 for i in range(_)]) if _ > 1 else self.pcache + mask = _p.expand(inpute.size()[:-1]).bernoulli().to(mask_tensor_type, non_blocking=True).unsqueeze(-1) + if self.inplace: + out = inpute.masked_fill_(mask, 0.0) + if self.keep_magnitude: + out.mul_(self.keep_magnitude) + else: + out = inpute.masked_fill(mask, 0.0) + if self.keep_magnitude: + out = out * self.keep_magnitude return out else: @@ -46,32 +56,40 @@ def sample(lin): class NGramDropout(Dropout): - def __init__(self, p=0.5, inplace=False, seqdim=1, sample_p=[1.0 / tmpu for tmpu in range(1, 3 + 1)], keep_magnitude=True): + def __init__(self, p=0.5, inplace=False, seqdim=1, sample_p=[1.0 / tmpu for tmpu in range(1, 3 + 1)], keep_magnitude=True, **kwargs): super(NGramDropout, self).__init__(p=p, inplace=inplace) self.seqdim = seqdim self.keep_magnitude = (1.0 / (1.0 - self.p)) if keep_magnitude else False self.sample_p = norm([float(pu) for pu in sample_p]) self.max_n = len(sample_p) + self.register_buffer("pcache", torch.full((1,), self.p), persistent=False) - def forward(self, inpute): + def forward(self, inpute, **kwargs): if self.training: seql = inpute.size(self.seqdim) ngram = sample(self.sample_p if seql > self.max_n else norm(self.sample_p[:seql - 1])) + 1 _msize = list(inpute.size())[:-1] + _ = len(_msize) + _p = self.pcache.view([1 for i in range(_)]) if _ > 1 else self.pcache if ngram > 1: nblock = ceil(float(seql) / float(ngram)) _msize[self.seqdim] = nblock - mask = inpute.new_full(_msize, self.p, requires_grad=False).bernoulli().to(mask_tensor_type, non_blocking=True).repeat([ngram if i == self.seqdim else 1 for i in 
range(len(_msize))]) + mask = _p.expand(_msize).bernoulli().to(mask_tensor_type, non_blocking=True).repeat([ngram if i == self.seqdim else 1 for i in range(len(_msize))]) if ngram * nblock != seql: mask = mask.narrow(self.seqdim, 0, seql) mask = mask.unsqueeze(-1) else: - mask = inpute.new_full(_msize, self.p, requires_grad=False).bernoulli().to(mask_tensor_type, non_blocking=True).unsqueeze(-1) - out = inpute.masked_fill_(mask, 0.0) if self.inplace else inpute.masked_fill(mask, 0.0) - if self.keep_magnitude: - out = out * self.keep_magnitude + mask = _p.expand(_msize).bernoulli().to(mask_tensor_type, non_blocking=True).unsqueeze(-1) + if self.inplace: + out = inpute.masked_fill_(mask, 0.0) + if self.keep_magnitude: + out.mul_(self.keep_magnitude) + else: + out = inpute.masked_fill(mask, 0.0) + if self.keep_magnitude: + out = out * self.keep_magnitude return out else: diff --git a/modules/group/base.py b/modules/group/base.py index 167c84f..f473bd1 100644 --- a/modules/group/base.py +++ b/modules/group/base.py @@ -1,10 +1,12 @@ #encoding: utf-8 -from math import sqrt import torch +from math import sqrt from torch import nn -from cnfg.ihyp import use_c_backend_group, bind_c_forward +from utils.torch.comp import torch_no_grad + +from cnfg.ihyp import bind_c_forward, use_c_backend_group class GroupLinear(nn.Module): @@ -16,7 +18,7 @@ class GroupLinear(nn.Module): # shuffle: shuffle across groups for output # flatten_output: concatenate outputs of groups - def __init__(self, isize, osize, ngroup, bias=True, trans_input=True, shuffle=False, flatten_output=True): + def __init__(self, isize, osize, ngroup, bias=True, trans_input=True, shuffle=False, flatten_output=True, **kwargs): super(GroupLinear, self).__init__() @@ -37,7 +39,7 @@ def __init__(self, isize, osize, ngroup, bias=True, trans_input=True, shuffle=Fa # inputu: (..., isize) - def forward(self, inputu, weight=None, bias=None): + def forward(self, inputu, weight=None, bias=None, **kwargs): _size = list(inputu.size()) _id = inputu.view(-1, self.ngroup, self.isize if self.trans_input else _size[-1]) if (inputu.dim() != 3) or self.trans_input else inputu @@ -64,7 +66,7 @@ def extra_repr(self): def fix_init(self): - with torch.no_grad(): + with torch_no_grad(): self.weight.uniform_(- sqrt(1.0 / self.isize), sqrt(1.0 / self.isize)) if self.bias is not None: self.bias.zero_() diff --git a/modules/hplstm/base.py b/modules/hplstm/base.py index b0dd739..f6578d6 100644 --- a/modules/hplstm/base.py +++ b/modules/hplstm/base.py @@ -2,20 +2,24 @@ import torch from torch import nn -from modules.base import Linear, Dropout -from modules.group.base import GroupLinear + from modules.act import Custom_Act +from modules.base import Dropout, Linear +from modules.group.base import GroupLinear from modules.hplstm.LGate import LGateFunc -from utils.base import float2odd, flip_mask +from utils.base import float2odd +from utils.fmt.parser import parse_none +from utils.torch.comp import flip_mask, torch_no_grad + from cnfg.ihyp import * class MHPLSTMCore(nn.Module): - def __init__(self, isize, num_head=8, osize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): + def __init__(self, isize, num_head=8, osize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, **kwargs): super(MHPLSTMCore, self).__init__() - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) i_head_dim = float2odd(float(isize) / num_head) i_hsize = i_head_dim * num_head @@ -36,7 +40,7 @@ 
def __init__(self, isize, num_head=8, osize=None, dropout=0.0, custom_act=use_ad # states: ((bsize, 1, num_head, head_dim), (bsize, 1, num_head, head_dim),) # head_mask: (bsize, seql, 1, 1) - def forward(self, heads_input, states=None, head_mask=None): + def forward(self, heads_input, states=None, head_mask=None, **kwargs): bsize, seql, nheads, adim = heads_input.size() if states is None: @@ -71,16 +75,16 @@ def forward(self, heads_input, states=None, head_mask=None): def fix_init(self): - with torch.no_grad(): + with torch_no_grad(): self.init_cx.zero_() class HPLSTM(nn.Module): - def __init__(self, isize, num_head=8, osize=None, dropout=0.0, enable_proj_bias=enable_proj_bias_default): + def __init__(self, isize, num_head=8, osize=None, dropout=0.0, enable_proj_bias=enable_proj_bias_default, **kwargs): super(HPLSTM, self).__init__() - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) o_hsize = float2odd(float(_osize) / num_head) * num_head self.head_dim = float2odd(float(isize) / num_head) @@ -91,7 +95,7 @@ def __init__(self, isize, num_head=8, osize=None, dropout=0.0, enable_proj_bias= self.net = MHPLSTMCore(i_hsize, num_head=self.num_head, osize=o_hsize, dropout=dropout) self.trans_output = Linear(o_hsize, _osize, bias=enable_proj_bias) - def forward(self, inpute, states=None, head_mask=None): + def forward(self, inpute, states=None, head_mask=None, **kwargs): bsize, seql = inpute.size()[:2] heads_input = self.trans_input(inpute).view(bsize, seql, self.num_head, self.head_dim) @@ -110,11 +114,11 @@ def forward(self, inpute, states=None, head_mask=None): class BiHPLSTM(nn.Module): - def __init__(self, isize, num_head=8, osize=None, dropout=0.0, enable_proj_bias=enable_proj_bias_default): + def __init__(self, isize, num_head=8, osize=None, dropout=0.0, enable_proj_bias=enable_proj_bias_default, **kwargs): super(BiHPLSTM, self).__init__() - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) o_hsize = float2odd(float(_osize) / num_head) * num_head self.head_dim = float2odd(float(isize) / num_head) @@ -126,10 +130,10 @@ def __init__(self, isize, num_head=8, osize=None, dropout=0.0, enable_proj_bias= self.trans_output = Linear(o_hsize + o_hsize, _osize, bias=enable_proj_bias) # inpute: (bsize, seql, isize) - # mask: (bsize, seql, 1, 1), generated by input.eq(0).view(bsize, seql, 1, 1) + # mask: (bsize, seql, 1, 1), generated by input.eq(pad_id).view(bsize, seql, 1, 1) # pad_reversed_mask: (bsize, seql, nheads * 2, 1), generated by torch.cat((mask.new_zeros(1, 1, 1, 1).expand(bsize, seql, nheads, 1), mask.flip(1).expand(bsize, seql, nheads, 1),), dim=2) - def forward(self, inpute, mask=None, pad_reversed_mask=None): + def forward(self, inpute, mask=None, pad_reversed_mask=None, **kwargs): bsize, seql = inpute.size()[:2] nheads = self.num_head diff --git a/modules/hplstm/hfn.py b/modules/hplstm/hfn.py index 1d8726b..c0bc4eb 100644 --- a/modules/hplstm/hfn.py +++ b/modules/hplstm/hfn.py @@ -2,24 +2,27 @@ import torch from torch import nn -from modules.base import Linear, Dropout -from modules.group.base import GroupLinear + from modules.act import Custom_Act, LGLU, get_act +from modules.base import Dropout +from modules.group.base import GroupLinear from modules.hplstm.LGate import LGateFunc +from modules.hplstm.base import BiHPLSTM as BiHPLSTMBase, HPLSTM as HPLSTMBase from utils.base import float2odd - -from modules.hplstm.base import HPLSTM as HPLSTMBase, BiHPLSTM as BiHPLSTMBase +from utils.fmt.parser import parse_none +from 
utils.torch.comp import torch_no_grad from cnfg.ihyp import * class MHPLSTMCore(nn.Module): # use_glu leads to performance drop with MHPLSTM, disable by default - def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, use_glu=None): + def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, act_drop=None, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, use_glu=None, **kwargs): super(MHPLSTMCore, self).__init__() - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) + _act_drop = parse_none(act_drop, dropout) i_head_dim = float2odd(float(isize) / num_head) i_hsize = i_head_dim * num_head @@ -47,7 +50,8 @@ def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, cust _.append(GroupLinear(_fhsize // 2, o_hsize, num_head, bias=enable_proj_bias, shuffle=False, trans_input=False, flatten_output=False)) if dropout > 0.0: _.append(Dropout(dropout, inplace=True)) - _.insert(_drop_ind, Dropout(dropout, inplace=inplace_after_Custom_Act)) + if _act_drop > 0.0: + _.insert(_drop_ind, Dropout(_act_drop, inplace=inplace_after_Custom_Act)) self.trans_hid = nn.Sequential(*_) self.trans_ifg = GroupLinear(i_hsize + i_hsize, o_hsize + o_hsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False) self.trans_og = nn.Sequential(GroupLinear(i_hsize + o_hsize, o_hsize, num_head, bias=enable_bias, shuffle=False, trans_input=False, flatten_output=False), nn.LayerNorm((num_head, o_head_dim), eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)) @@ -57,7 +61,7 @@ def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, cust self.init_cx = nn.Parameter(torch.zeros(num_head, o_head_dim)) - def forward(self, heads_input, states=None, head_mask=None): + def forward(self, heads_input, states=None, head_mask=None, **kwargs): bsize, seql, nheads, adim = heads_input.size() if states is None: @@ -88,14 +92,14 @@ def forward(self, heads_input, states=None, head_mask=None): def fix_init(self): - with torch.no_grad(): + with torch_no_grad(): self.init_cx.zero_() class HPLSTM(HPLSTMBase): - def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, **kwargs): + def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, act_drop=None, **kwargs): - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) super(HPLSTM, self).__init__(isize, num_head=num_head, osize=_osize, dropout=dropout, **kwargs) @@ -103,13 +107,13 @@ def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, **kw o_hsize = float2odd(float(_osize) / num_head) * num_head _fhsize = float2odd(float(o_hsize * 4 if fhsize is None else fhsize) / num_head) * num_head - self.net = MHPLSTMCore(i_hsize, num_head=self.num_head, osize=o_hsize, fhsize=_fhsize, dropout=dropout) + self.net = MHPLSTMCore(i_hsize, num_head=self.num_head, osize=o_hsize, fhsize=_fhsize, dropout=dropout, act_drop=act_drop) class BiHPLSTM(BiHPLSTMBase): - def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, **kwargs): + def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, act_drop=None, **kwargs): - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) super(BiHPLSTM, self).__init__(isize, num_head=num_head, osize=_osize, dropout=dropout, 
**kwargs) @@ -117,4 +121,4 @@ def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, **kw o_hsize = float2odd(float(_osize) / num_head) * num_head _fhsize = float2odd(float(o_hsize * 4 if fhsize is None else fhsize) / num_head) * num_head - self.net = MHPLSTMCore(i_hsize + i_hsize, num_head=self.num_head + self.num_head, osize=o_hsize + o_hsize, fhsize=_fhsize + _fhsize, dropout=dropout) + self.net = MHPLSTMCore(i_hsize + i_hsize, num_head=self.num_head + self.num_head, osize=o_hsize + o_hsize, fhsize=_fhsize + _fhsize, dropout=dropout, act_drop=act_drop) diff --git a/modules/hplstm/wrapper.py b/modules/hplstm/wrapper.py index 586e6e5..8bd08a5 100644 --- a/modules/hplstm/wrapper.py +++ b/modules/hplstm/wrapper.py @@ -3,7 +3,13 @@ import torch from torch import nn -from modules.rnncells import LSTMCell4RNMT, ATRCell +from modules.act import Custom_Act +from modules.base import Dropout, Linear +from modules.rnncells import ATRCell, LSTMCell4RNMT +from utils.fmt.parser import parse_none +from utils.torch.comp import torch_no_grad + +from cnfg.ihyp import * class LSTM4RNMT(nn.Module): @@ -11,13 +17,13 @@ def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, **kw super(LSTM4RNMT, self).__init__() - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) self.net = LSTMCell4RNMT(isize, osize=_osize, dropout=dropout) self.init_cx = nn.Parameter(torch.zeros(1, _osize)) self.init_hx = nn.Parameter(torch.zeros(1, _osize)) - def forward(self, inpute, states=None, head_mask=None): + def forward(self, inpute, states=None, head_mask=None, **kwargs): if states is None: bsize = inpute.size(0) @@ -43,7 +49,7 @@ def forward(self, inpute, states=None, head_mask=None): def fix_init(self): - with torch.no_grad(): + with torch_no_grad(): self.init_cx.zero_() self.init_hx.zero_() @@ -57,7 +63,7 @@ def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, **kw self.init_hx = nn.Parameter(torch.zeros(1, isize)) - def forward(self, inpute, states=None, head_mask=None): + def forward(self, inpute, states=None, head_mask=None, **kwargs): if states is None: bsize = inpute.size(0) @@ -84,5 +90,43 @@ def forward(self, inpute, states=None, head_mask=None): def fix_init(self): - with torch.no_grad(): + with torch_no_grad(): self.init_hx.zero_() + +class RNN(ATR): + + def __init__(self, isize, num_head=8, osize=None, fhsize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, **kwargs): + + _osize = parse_none(osize, isize) + _hsize = _osize * 4 if fhsize is None else fhsize + + super(RNN, self).__init__(isize, num_head=num_head, osize=_osize, fhsize=_hsize, dropout=dropout) + + self.net = nn.Sequential(Linear(isize + _osize, _hsize, bias=enable_bias), nn.LayerNorm(_hsize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters), Custom_Act() if custom_act else nn.ReLU(inplace=True), Linear(_hsize, isize, bias=enable_bias)) + if dropout > 0.0: + self.net.insert(3, Dropout(dropout, inplace=inplace_after_Custom_Act)) + + def forward(self, inpute, states=None, head_mask=None, **kwargs): + + if states is None: + bsize = inpute.size(0) + _state = self.init_hx.expand(bsize, -1) + out = [] + for tmp in inpute.unbind(1): + _state = self.net(torch.cat((tmp, _state,), dim=-1)) + out.append(_state) + out = torch.stack(out, dim=1) + else: + if states == "init": + bsize = inpute.size(0) + _state = self.init_hx.expand(bsize, -1) + else: + _state = states + _out = self.net(torch.cat((inpute.select(1, -1), 
_state,), dim=-1)) + states_return = _out + out = _out.unsqueeze(1) + + if states is None: + return out + else: + return out, states_return diff --git a/modules/mulang/eff/base.py b/modules/mulang/eff/base.py index 9874e79..db44bda 100644 --- a/modules/mulang/eff/base.py +++ b/modules/mulang/eff/base.py @@ -1,51 +1,51 @@ #encoding: utf-8 import torch +from math import sqrt +from numbers import Integral from torch import nn from torch.nn import functional as nnFunc -from modules.base import ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase, PositionwiseFF as PositionwiseFFBase - -from math import sqrt -from numbers import Integral +from modules.base import PositionwiseFF as PositionwiseFFBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase +from utils.torch.comp import torch_no_grad from cnfg.ihyp import * class MBLinear(nn.Linear): - def __init__(self, in_features, out_features, nbias, bias=True): + def __init__(self, in_features, out_features, nbias, bias=True, **kwargs): super(MBLinear, self).__init__(in_features, out_features, bias=False) if bias: self.bias = nn.Parameter(torch.zeros(nbias, out_features)) - def forward(self, x, taskid): + def forward(self, x, taskid, **kwargs): return nnFunc.linear(x, self.weight, None if self.bias is None else self.bias[taskid]) def fix_init(self): if self.bias is not None: - with torch.no_grad(): + with torch_no_grad(): self.bias.zero_() class MWLinear(MBLinear): - def __init__(self, in_features, out_features, nbias, bias=True): + def __init__(self, in_features, out_features, nbias, bias=True, **kwargs): super(MWLinear, self).__init__(in_features, out_features, nbias, bias=False) self.weight = nn.Parameter(torch.Tensor(nbias, out_features, in_features).uniform_(- sqrt(1.0 / in_features), sqrt(1.0 / in_features))) - def forward(self, x, taskid): + def forward(self, x, taskid, **kwargs): return nnFunc.linear(x, self.weight[taskid], None if self.bias is None else self.bias[taskid]) def fix_init(self): _isize = self.weight.size(-1) - with torch.no_grad(): + with torch_no_grad(): self.weight.data.uniform_(- sqrt(1.0 / _isize), sqrt(1.0 / _isize)) super(MWLinear, self).fix_init() @@ -62,7 +62,7 @@ def __init__(self, normalized_shape, ntask=None, eps=1e-5, elementwise_affine=Tr self.normalized_shape = self.normalized_shape[1:] - def forward(self, input, taskid=None): + def forward(self, input, taskid=None, **kwargs): return nnFunc.layer_norm(input, self.normalized_shape, None if self.weight is None else self.weight[taskid], None if self.bias is None else self.bias[taskid], self.eps) @@ -124,15 +124,13 @@ def forward(self, iQ, iK, *inputs, taskid=None, **kwargs): class PositionwiseFF(PositionwiseFFBase): - def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, ntask=None, **kwargs): - - _hsize = isize * 4 if hsize is None else hsize + def __init__(self, isize, hsize=None, dropout=0.0, act_drop=None, norm_residual=norm_residual_default, ntask=None, **kwargs): - super(PositionwiseFF, self).__init__(isize, hsize=_hsize, dropout=dropout, norm_residual=norm_residual, **kwargs) + super(PositionwiseFF, self).__init__(isize, hsize=hsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual, **kwargs) self.normer = LayerNorm(isize, ntask=ntask, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) - def forward(self, x, taskid=None): + def forward(self, x, taskid=None, **kwargs): _out = self.normer(x, taskid=taskid) diff --git a/modules/noise.py b/modules/noise.py index 
7618618..5efabc9 100644 --- a/modules/noise.py +++ b/modules/noise.py @@ -3,19 +3,21 @@ import torch from torch import nn -from modules.base import ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase, PositionwiseFF as PositionwiseFFBase +from modules.base import PositionwiseFF as PositionwiseFFBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase +from utils.fmt.parser import parse_none +from utils.torch.ext import randint_t_core from cnfg.ihyp import * class GausNoiser(nn.Module): - def __init__(self, power, inplace=False): + def __init__(self, power, inplace=False, **kwargs): super(GausNoiser, self).__init__() self.power, self.inplace = power, inplace # mask: (bsize, seql, 1), otherwise cannot multiply with inpute.size(-1) - def forward(self, inpute, mask=None): + def forward(self, inpute, mask=None, **kwargs): if self.training: _noise = self.get_noise(inpute.detach(), mask=mask) @@ -46,7 +48,7 @@ def get_noise(self, inpute, mask=None): class GausNoiserVec(GausNoiser): - def __init__(self, power, dim=-1, inplace=False, eps=ieps_noise_default): + def __init__(self, power, dim=-1, inplace=False, eps=ieps_noise_default, **kwargs): super(GausNoiserVec, self).__init__(power, inplace=inplace) self.dim, self.eps = dim, eps @@ -75,7 +77,7 @@ def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_res super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) - _noiser = Noiser if custom_noiser is None else custom_noiser + _noiser = parse_none(custom_noiser, Noiser) self.noiser = None if power is None else _noiser(power, inplace=True) def forward(self, iQ, *inputs, noise_mask=None, **kwargs): @@ -107,7 +109,7 @@ def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_res super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) - _noiser = Noiser if custom_noiser is None else custom_noiser + _noiser = parse_none(custom_noiser, Noiser) self.noiser = None if power is None else _noiser(power, inplace=True) def forward(self, iQ, iK, *inputs, noise_mask=None, **kwargs): @@ -135,14 +137,14 @@ def forward(self, iQ, iK, *inputs, noise_mask=None, **kwargs): class PositionwiseFF(PositionwiseFFBase): - def __init__(self, isize, power=None, custom_noiser=None, **kwargs): + def __init__(self, isize, hsize=None, dropout=0.0, act_drop=None, power=None, custom_noiser=None, **kwargs): - super(PositionwiseFF, self).__init__(isize, **kwargs) + super(PositionwiseFF, self).__init__(isize, hsize=hsize, dropout=dropout, act_drop=act_drop, **kwargs) - _noiser = Noiser if custom_noiser is None else custom_noiser + _noiser = parse_none(custom_noiser, Noiser) self.noiser = None if power is None else _noiser(power, inplace=True) - def forward(self, x, mask=None): + def forward(self, x, mask=None, **kwargs): _out = self.normer(x) if self.noiser is not None: diff --git a/modules/paradoc.py b/modules/paradoc.py index 967d645..450df9c 100644 --- a/modules/paradoc.py +++ b/modules/paradoc.py @@ -2,19 +2,20 @@ import torch from torch import nn + from modules.base import Linear class GateResidual(nn.Module): # isize: input dimension - def __init__(self, isize): + def __init__(self, isize, **kwargs): super(GateResidual, self).__init__() self.net = nn.Sequential(Linear(isize * 2, isize), nn.Sigmoid()) - def forward(self, x1, x2): + def forward(self, x1, x2, **kwargs): gate = self.net(torch.cat((x1, x2,), dim=-1)) diff --git 
a/modules/plm/bert.py b/modules/plm/bert.py deleted file mode 100644 index 16e1a19..0000000 --- a/modules/plm/bert.py +++ /dev/null @@ -1,38 +0,0 @@ -#encoding: utf-8 - -from torch import nn -from modules.base import Linear -from modules.act import Custom_Act, LGLU, get_act, GELU -from modules.dropout import Dropout - -from modules.TA import PositionwiseFF as PositionwiseFFBase - -from cnfg.plm.bert.ihyp import * - -class PositionwiseFF(PositionwiseFFBase): - - def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn, **kwargs): - - _hsize = isize * 4 if hsize is None else hsize - - super(PositionwiseFF, self).__init__(isize, hsize=_hsize, dropout=dropout, norm_residual=norm_residual, custom_act=custom_act, enable_bias=enable_bias, use_glu=use_glu, **kwargs) - - if (use_glu is not None) and (_hsize % 2 == 1): - _hsize += 1 - - _ = [Linear(isize, _hsize)] - if use_glu is None: - _.extend([Custom_Act() if custom_act else GELU(), Linear(_hsize, isize, bias=enable_bias)]) - else: - use_glu = use_glu.lower() - if use_glu == "glu": - _.append(nn.GLU()) - else: - _act = get_act(use_glu, None) - if _act is not None: - _.append(_act()) - _.append(LGLU()) - _.append(Linear(_hsize // 2, isize, bias=enable_bias)) - if dropout > 0.0: - _.append(Dropout(dropout, inplace=True)) - self.net = nn.Sequential(*_) diff --git a/modules/plm/mbart.py b/modules/plm/mbart.py new file mode 100644 index 0000000..fe8fa7c --- /dev/null +++ b/modules/plm/mbart.py @@ -0,0 +1,41 @@ +#encoding: utf-8 + +import torch + +from modules.base import CrossAttn as CrossAttnBase, PositionwiseFF as PositionwiseFFBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase, SelfAttn as SelfAttnBase + +from cnfg.plm.mbart.ihyp import * + +class SelfAttn(SelfAttnBase): + + def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=use_k_relative_position, uni_direction_reduction=False, is_left_to_right_reduction=True, zero_reduction=relpos_reduction_with_zeros, max_bucket_distance=0, sparsenorm=False, xseql=cache_len_default, **kwargs): + + super(SelfAttn, self).__init__(isize, hsize, osize, num_head=num_head, dropout=dropout, enable_bias=enable_bias, enable_proj_bias=enable_proj_bias, k_rel_pos=k_rel_pos, uni_direction_reduction=uni_direction_reduction, is_left_to_right_reduction=is_left_to_right_reduction, zero_reduction=zero_reduction, max_bucket_distance=max_bucket_distance, sparsenorm=sparsenorm, xseql=xseql, **kwargs) + +class CrossAttn(CrossAttnBase): + + def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, sparsenorm=False, is_decoding=False, **kwargs): + + super(CrossAttn, self).__init__(isize, hsize, osize, num_head=num_head, dropout=dropout, k_isize=k_isize, enable_bias=enable_bias, enable_proj_bias=enable_proj_bias, sparsenorm=sparsenorm, **kwargs) + +class ResSelfAttn(ResSelfAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.net = SelfAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) + +class ResCrossAttn(ResCrossAttnBase): + + def __init__(self, isize, hsize, 
num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.net = CrossAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) + +class PositionwiseFF(PositionwiseFFBase): + + def __init__(self, isize, hsize=None, dropout=0.0, act_drop=None, norm_residual=norm_residual_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn, **kwargs): + + super(PositionwiseFF, self).__init__(isize, hsize=hsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual, custom_act=custom_act, enable_bias=enable_bias, use_glu=use_glu, **kwargs) diff --git a/modules/plm/t5.py b/modules/plm/t5.py index 7fcd5b5..7339767 100644 --- a/modules/plm/t5.py +++ b/modules/plm/t5.py @@ -1,15 +1,15 @@ #encoding: utf-8 import torch -from torch import nn #from math import sqrt +from torch import nn -from modules.base import Linear, SelfAttn as SelfAttnBase, CrossAttn as CrossAttnBase, ResSelfAttn as ResSelfAttnBase, ResCrossAttn as ResCrossAttnBase, PositionwiseFF as PositionwiseFFBase -from modules.act import Custom_Act, LGLU, get_act, GELU, GEGLU +from modules.act import Custom_Act, GEGLU, LGLU, get_act +from modules.base import CrossAttn as CrossAttnBase, Linear, PositionwiseFF as PositionwiseFFBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase, SelfAttn as SelfAttnBase from modules.dropout import Dropout from modules.norm import RMSNorm as Norm - -from utils.relpos.bucket import build_rel_pos_bucket_map, build_rel_pos_bucket +from utils.fmt.parser import parse_none +from utils.relpos.bucket import build_rel_pos_bucket, build_rel_pos_bucket_map from cnfg.plm.t5.ihyp import * @@ -22,7 +22,7 @@ def __init__(self, isize, hsize, osize, num_head=8, dropout=0.0, enable_bias=ena self.ref_rel_emb = None self.rel_emb_cache = None - def forward(self, iQ, mask=None, states=None): + def forward(self, iQ, mask=None, states=None, **kwargs): bsize, nquery = iQ.size()[:2] nheads = self.num_head @@ -41,7 +41,9 @@ def forward(self, iQ, mask=None, states=None): scores = real_iQ.matmul(real_iK) - if self.rel_pemb is not None: + if self.ref_rel_emb is not None: + scores += self.ref_rel_emb.rel_emb_cache + elif self.rel_pemb is not None: if states is None: self.rel_pos_cache = self.get_rel_pos(nquery).contiguous() if self.ref_rel_posm is None else self.ref_rel_posm.rel_pos_cache self.rel_emb_cache = (real_iQ.permute(2, 0, 1, 3).contiguous().view(nquery, bsize * nheads, adim).bmm(self.rel_pemb(self.rel_pos_cache).transpose(1, 2)).view(nquery, bsize, nheads, nquery).permute(1, 2, 0, 3) if self.rel_pos_map is None else self.rel_pemb(self.rel_pos_cache).permute(2, 0, 1)).contiguous() @@ -49,8 +51,6 @@ def forward(self, iQ, mask=None, states=None): self.rel_pos_cache = self.get_rel_pos(seql).narrow(0, seql - nquery, nquery).contiguous() if self.ref_rel_posm is None else self.ref_rel_posm.rel_pos_cache self.rel_emb_cache = (real_iQ.permute(2, 0, 1, 3).contiguous().view(nquery, bsize * nheads, adim).bmm(self.rel_pemb(self.rel_pos_cache).transpose(1, 2)).view(nquery, bsize, nheads, seql).permute(1, 2, 0, 3) if self.rel_pos_map is None else self.rel_pemb(self.rel_pos_cache).permute(2, 0, 1)).contiguous() scores += self.rel_emb_cache - elif self.ref_rel_emb is not None: - scores += self.ref_rel_emb.rel_emb_cache ## t5 does not scale attention scores #scores = scores / sqrt(adim) @@ -78,19 +78,19 @@ def __init__(self, 
isize, hsize, osize, num_head=8, dropout=0.0, k_isize=None, e if (k_rel_pos > 0) and (max_bucket_distance > 0): self.rel_shift = k_rel_pos - self.register_buffer("rel_pos_map", build_rel_pos_bucket_map(k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=False)) - self.register_buffer("rel_pos", build_rel_pos_bucket(xseql, k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=False, dis_map=self.rel_pos_map)) + self.register_buffer("rel_pos_map", build_rel_pos_bucket_map(k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=False), persistent=False) + self.register_buffer("rel_pos", build_rel_pos_bucket(xseql, k_rel_pos=k_rel_pos, max_len=max_bucket_distance, uni_direction=False, dis_map=self.rel_pos_map), persistent=False) self.rel_pemb = nn.Embedding(k_rel_pos + k_rel_pos + 1, self.num_head) self.clamp_max, self.clamp_min = max_bucket_distance, False self.xseql = xseql self.ref_rel_posm = None - self.register_buffer("rel_pos_cache", None) + self.register_buffer("rel_pos_cache", None, persistent=False) else: self.rel_pemb = None self.ref_rel_emb = None self.rel_emb_cache = None - def forward(self, iQ, iK, mask=None, step=0): + def forward(self, iQ, iK, mask=None, step=0, **kwargs): bsize, nquery = iQ.size()[:2] seql = iK.size(1) @@ -98,22 +98,22 @@ def forward(self, iQ, iK, mask=None, step=0): adim = self.attn_dim real_iQ = self.query_adaptor(iQ).view(bsize, nquery, nheads, adim).transpose(1, 2) - if (self.real_iK is not None) and self.iK.is_set_to(iK) and (not self.training): + if (self.real_iK is not None) and self.iK.is_set_to(iK) and self.is_decoding: real_iK, real_iV = self.real_iK, self.real_iV else: real_iK, real_iV = self.kv_adaptor(iK).view(bsize, seql, 2, nheads, adim).unbind(2) real_iK, real_iV = real_iK.permute(0, 2, 3, 1), real_iV.transpose(1, 2) - if not self.training: + if self.is_decoding: self.iK, self.real_iK, self.real_iV = iK, real_iK, real_iV scores = real_iQ.matmul(real_iK) - if self.rel_pemb is not None: + if self.ref_rel_emb is not None: + scores += self.ref_rel_emb.rel_emb_cache + elif self.rel_pemb is not None: self.rel_pos_cache = (self.get_rel_pos(step, seql).narrow(0, step - nquery, nquery) if step > 0 else self.get_rel_pos(nquery, seql)).contiguous() if self.ref_rel_posm is None else self.ref_rel_posm.rel_pos_cache self.rel_emb_cache = self.rel_pemb(self.rel_pos_cache).permute(2, 0, 1).contiguous() scores += self.rel_emb_cache - elif self.ref_rel_emb is not None: - scores += self.ref_rel_emb.rel_emb_cache # t5 does not scale attention scores #scores = scores / sqrt(adim) @@ -143,29 +143,30 @@ def get_rel_pos(self, length, seql): class ResSelfAttn(ResSelfAttnBase): - def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=use_k_relative_position, zero_reduction=relpos_reduction_with_zeros, xseql=cache_len_default, **kwargs): + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): - super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, enable_bias=enable_bias, enable_proj_bias=enable_proj_bias, k_rel_pos=k_rel_pos, zero_reduction=zero_reduction, xseql=xseql) + super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) - self.net = SelfAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, enable_bias=enable_bias, 
enable_proj_bias=enable_proj_bias, k_rel_pos=k_rel_pos, zero_reduction=zero_reduction, xseql=xseql, **kwargs) + self.net = SelfAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) self.normer = Norm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) class ResCrossAttn(ResCrossAttnBase): - def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, **kwargs): + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): - super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, enable_bias=enable_bias, enable_proj_bias=enable_proj_bias) + super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) - self.net = CrossAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, enable_bias=enable_bias, enable_proj_bias=enable_proj_bias, **kwargs) + self.net = CrossAttn(isize, hsize, isize, num_head=num_head, dropout=dropout, **kwargs) self.normer = Norm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) class PositionwiseFF(PositionwiseFFBase): - def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn, **kwargs): + def __init__(self, isize, hsize=None, dropout=0.0, act_drop=None, norm_residual=norm_residual_default, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn, **kwargs): _hsize = isize * 4 if hsize is None else hsize + _act_drop = parse_none(act_drop, dropout) - super(PositionwiseFF, self).__init__(isize, hsize=_hsize, dropout=dropout, norm_residual=norm_residual, custom_act=custom_act, enable_bias=enable_bias, use_glu=None, **kwargs) + super(PositionwiseFF, self).__init__(isize, hsize=_hsize, dropout=dropout, act_drop=_act_drop, norm_residual=norm_residual, custom_act=custom_act, enable_bias=enable_bias, use_glu=None, **kwargs) if (use_glu is not None) and (_hsize % 2 == 1): _hsize += 1 @@ -189,6 +190,7 @@ def __init__(self, isize, hsize=None, dropout=0.0, norm_residual=norm_residual_d _.append(Linear(_hsize // 2, isize, bias=enable_bias)) if dropout > 0.0: _.append(Dropout(dropout, inplace=True)) - _.insert(_drop_ind, Dropout(dropout, inplace=inplace_after_Custom_Act)) + if _act_drop > 0.0: + _.insert(_drop_ind, Dropout(_act_drop, inplace=inplace_after_Custom_Act)) self.net = nn.Sequential(*_) self.normer = Norm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) diff --git a/modules/rnncells.py b/modules/rnncells.py index e1a34f3..73e7ea1 100644 --- a/modules/rnncells.py +++ b/modules/rnncells.py @@ -2,8 +2,10 @@ import torch from torch import nn -from modules.base import * + from modules.act import Custom_Act +from modules.base import Dropout, Linear +from utils.fmt.parser import parse_none from cnfg.ihyp import * @@ -19,11 +21,11 @@ class LSTMCell4RNMT(nn.Module): # isize: input size of Feed-forward NN # dropout: dropout over hidden units, disabling it and applying dropout to outputs (_out) in most cases - def __init__(self, isize, osize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): + def __init__(self, isize, osize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, **kwargs): 
super(LSTMCell4RNMT, self).__init__() - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) # layer normalization is also applied for the computation of hidden for efficiency. bias might be disabled in case provided by LayerNorm self.trans = Linear(isize + _osize, _osize * 4, bias=enable_bias) @@ -34,7 +36,7 @@ def __init__(self, isize, osize=None, dropout=0.0, custom_act=use_adv_act_defaul self.osize = _osize - def forward(self, inpute, state): + def forward(self, inpute, state, **kwargs): _out, _cell = state @@ -57,11 +59,11 @@ class GRUCell4RNMT(nn.Module): # isize: input size of Feed-forward NN - def __init__(self, isize, osize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default): + def __init__(self, isize, osize=None, dropout=0.0, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, **kwargs): super(GRUCell4RNMT, self).__init__() - _osize = isize if osize is None else osize + _osize = parse_none(osize, isize) self.trans = Linear(isize + _osize, _osize * 2, bias=enable_bias) self.transi = Linear(isize, _osize, bias=enable_bias) @@ -75,7 +77,7 @@ def __init__(self, isize, osize=None, dropout=0.0, custom_act=use_adv_act_defaul self.osize = _osize - def forward(self, inpute, state): + def forward(self, inpute, state, **kwargs): osize = list(state.size()) osize.insert(-1, 2) @@ -99,7 +101,7 @@ class ATRCell(nn.Module): # isize: input size of Feed-forward NN - def __init__(self, isize): + def __init__(self, isize, **kwargs): super(ATRCell, self).__init__() @@ -109,7 +111,7 @@ def __init__(self, isize): # x: input to the cell # cell: cell to update - def forward(self, x, cell): + def forward(self, x, cell, **kwargs): p, q = self.t1(x), self.t2(cell) diff --git a/modules/sampler.py b/modules/sampler.py index 3c70ab3..9e8c7b8 100644 --- a/modules/sampler.py +++ b/modules/sampler.py @@ -3,7 +3,8 @@ import torch from torch import nn from torch.autograd import Function -from utils.torch import multinomial + +from utils.torch.ext import multinomial class SampleMaxFunction(Function): @@ -64,7 +65,7 @@ def backward(ctx, grad_outputs): class Retriever(nn.Module): - def forward(self, attnmat, vmat): + def forward(self, attnmat, vmat, **kwargs): _idim = attnmat.dim() if _idim > 3: @@ -169,18 +170,18 @@ def backward(ctx, grad_outputs): class Sampler(nn.Module): - def __init__(self, dim=-1): + def __init__(self, dim=-1, **kwargs): super(Sampler, self).__init__() self.dim = dim - def forward(self, inputs, bsize=None): + def forward(self, inputs, bsize=None, **kwargs): return SamplerFunc(inputs, self.dim, bsize) class EffSampler(Sampler): - def forward(self, inputs, weight, add_bdim=False): + def forward(self, inputs, weight, add_bdim=False, **kwargs): return EffSamplerFunc(inputs, weight, self.dim, add_bdim) diff --git a/modules/sdu.py b/modules/sdu.py new file mode 100644 index 0000000..9e1fcb5 --- /dev/null +++ b/modules/sdu.py @@ -0,0 +1,91 @@ +#encoding: utf-8 + +from torch import nn + +from modules.base import PositionwiseFF as PositionwiseFFBase, ResCrossAttn as ResCrossAttnBase, ResSelfAttn as ResSelfAttnBase +from modules.dropout import Dropout + +from cnfg.ihyp import * + +class SDU(nn.Sequential): + + def __init__(self, isize, dropout=0.0, **kwargs): + + super(SDU, self).__init__(nn.Linear(isize, isize + isize), nn.GLU()) + + if dropout > 0.0: + self.append(Dropout(dropout, inplace=True)) + +class PositionwiseFF(PositionwiseFFBase): + + def __init__(self, isize, hsize=None, dropout=0.0, act_drop=None, 
norm_residual=norm_residual_default, **kwargs): + + super(PositionwiseFF, self).__init__(isize, hsize=hsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual, **kwargs) + + self.sdu = SDU(isize, dropout=dropout) + + def forward(self, x, **kwargs): + + _out = self.normer(x) + + out = self.net(_out) + self.sdu(_out) + + out = out + (_out if self.norm_residual else x) + + return out + +class ResSelfAttn(ResSelfAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResSelfAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.sdu = SDU(isize, dropout=dropout) + + def forward(self, iQ, *inputs, **kwargs): + + _iQ = self.normer(iQ) + + outs = self.net(_iQ, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + self.sdu(_iQ) + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + self.sdu(_iQ) + (_iQ if self.norm_residual else iQ) + +class ResCrossAttn(ResCrossAttnBase): + + def __init__(self, isize, hsize, num_head=8, dropout=0.0, norm_residual=norm_residual_default, **kwargs): + + super(ResCrossAttn, self).__init__(isize, hsize, num_head=num_head, dropout=dropout, norm_residual=norm_residual, **kwargs) + + self.sdu = SDU(isize, dropout=dropout) + + def forward(self, iQ, iK, *inputs, **kwargs): + + _iQ = self.normer(iQ) + + outs = self.net(_iQ, iK, *inputs, **kwargs) + + if isinstance(outs, tuple): + _out = outs[0] + + if self.drop is not None: + _out = self.drop(_out) + + return _out + self.sdu(_iQ) + (_iQ if self.norm_residual else iQ), *outs[1:] + + else: + if self.drop is not None: + outs = self.drop(outs) + + return outs + self.sdu(_iQ) + (_iQ if self.norm_residual else iQ) diff --git a/modules/server/__init__.py b/modules/server/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/modules/server/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/translator.py b/modules/server/transformer.py similarity index 80% rename from translator.py rename to modules/server/transformer.py index 21daebd..9dc0b2a 100644 --- a/translator.py +++ b/modules/server/transformer.py @@ -2,21 +2,23 @@ import torch -from transformer.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble from parallel.parallelMT import DataParallelMT - -from utils.base import * -from utils.fmt.base import ldvocab, clean_str, reverse_dict, eos_id, clean_liststr_lentok, dict_insert_set, iter_dict_sort +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.NMT import NMT +from utils.fmt.base import clean_str, dict_insert_set, iter_dict_sort from utils.fmt.base4torch import parse_cuda_decode - from utils.fmt.single import batch_padder +from utils.fmt.vocab.base import reverse_dict +from utils.fmt.vocab.token import ldvocab +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from cnfg.ihyp import * +from cnfg.vocab.base import eos_id -def data_loader(sentences_iter, vcbi, minbsize=1, bsize=768, maxpad=16, maxpart=4, maxtoken=3920): +def data_loader(sentences_iter, vcbi, minbsize=1, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu): for i_d in batch_padder(sentences_iter, vcbi, bsize, maxpad, maxpart, maxtoken, minbsize): - yield 
torch.tensor(i_d, dtype=torch.long) + yield torch.as_tensor(i_d, dtype=torch.long) def load_fixing(module): if hasattr(module, "fix_load"): @@ -31,7 +33,8 @@ def sorti(lin): if ls: data = dict_insert_set(data, ls, len(ls.split())) - return list(iter_dict_sort(data, free=True)) + for _ in iter_dict_sort(data, free=True): + yield from _ def restore(src, tsrc, trs): @@ -46,7 +49,7 @@ def restore(src, tsrc, trs): class TranslatorCore: - def __init__(self, modelfs, fvocab_i, fvocab_t, cnfg, minbsize=1, expand_for_mulgpu=True, bsize=64, maxpad=16, maxpart=4, maxtoken=1536, minfreq = False, vsize = False): + def __init__(self, modelfs, fvocab_i, fvocab_t, cnfg, minbsize=1, expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False, **kwargs): vcbi, nwordi = ldvocab(fvocab_i, minf=minfreq, omit_vsize=vsize, vanilla=False) vcbt, nwordt = ldvocab(fvocab_t, minf=minfreq, omit_vsize=vsize, vanilla=False) @@ -65,7 +68,7 @@ def __init__(self, modelfs, fvocab_i, fvocab_t, cnfg, minbsize=1, expand_for_mul if isinstance(modelfs, (list, tuple,)): models = [] for modelf in modelfs: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -74,7 +77,7 @@ def __init__(self, modelfs, fvocab_i, fvocab_t, cnfg, minbsize=1, expand_for_mul model = Ensemble(models) else: - model = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + model = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) model = load_model_cpu(modelfs, model) model.apply(load_fixing) @@ -84,21 +87,22 @@ def __init__(self, modelfs, fvocab_i, fvocab_t, cnfg, minbsize=1, expand_for_mul self.use_cuda, self.cuda_device, cuda_devices, self.multi_gpu = parse_cuda_decode(cnfg.use_cuda, cnfg.gpuid, cnfg.multi_gpu_decoding) if self.use_cuda: - model.to(self.cuda_device) + model.to(self.cuda_device, non_blocking=True) if self.multi_gpu: model = DataParallelMT(model, device_ids=cuda_devices, output_device=self.cuda_device.index, host_replicate=True, gather_output=False) + model = torch_compile(model, *torch_compile_args, **torch_compile_kwargs) self.use_amp = cnfg.use_amp and self.use_cuda self.beam_size = cnfg.beam_size self.length_penalty = cnfg.length_penalty self.net = model - def __call__(self, sentences_iter): + def __call__(self, sentences_iter, **kwargs): rs = [] - with torch.no_grad(): + with torch_inference_mode(): for seq_batch in data_loader(sentences_iter, self.vcbi, self.minbsize, self.bsize, self.maxpad, self.maxpart, self.maxtoken): if self.use_cuda: - seq_batch = seq_batch.to(self.cuda_device) - with autocast(enabled=self.use_amp): + seq_batch = seq_batch.to(self.cuda_device, non_blocking=True) + with torch_autocast(enabled=self.use_amp): output 
= self.net.decode(seq_batch, self.beam_size, None, self.length_penalty) if self.multi_gpu: tmp = [] @@ -120,7 +124,7 @@ def __call__(self, sentences_iter): class Translator: - def __init__(self, trans=None, sent_split=None, tok=None, detok=None, bpe=None, debpe=None, punc_norm=None, truecaser=None, detruecaser=None): + def __init__(self, trans=None, sent_split=None, tok=None, detok=None, bpe=None, debpe=None, punc_norm=None, truecaser=None, detruecaser=None, **kwargs): self.sent_split = sent_split @@ -142,7 +146,7 @@ def __init__(self, trans=None, sent_split=None, tok=None, detok=None, bpe=None, if detok is not None: self.flow.append(detok) - def __call__(self, paragraphs): + def __call__(self, paragraphs, **kwargs): _paras = [clean_str(tmpu.strip()) for tmpu in paragraphs.strip().split("\n") if tmpu] @@ -155,7 +159,7 @@ def __call__(self, paragraphs): for _tmpu in _paras: _tmp.extend(clean_list([clean_str(_tmps) for _tmps in self.sent_split(_tmpu)])) _tmp.append("\n") - _tmp_o = _tmpi = sorti(_tmp) + _tmp_o = _tmpi = list(sorti(_tmp)) for pu in self.flow: _tmp_o = pu(_tmp_o) @@ -164,10 +168,12 @@ def __call__(self, paragraphs): return " ".join(_tmp).replace(" \n", "\n").replace("\n ", "\n") -#import cnfg -#from datautils.moses import SentenceSplitter -#from datautils.pymoses import Tokenizer, Detokenizer, Normalizepunctuation, Truecaser, Detruecaser #from datautils.bpe import BPEApplier, BPERemover +#from datautils.moses import SentenceSplitter +#from datautils.pymoses import Detokenizer, Detruecaser, Normalizepunctuation, Tokenizer, Truecaser + +#import cnfg + #if __name__ == "__main__": #tl = ["28 @-@ jähriger Koch in San Francisco M@@ all tot a@@ u@@ f@@ gefunden", "ein 28 @-@ jähriger Koch , der vor kurzem nach San Francisco gezogen ist , wurde im T@@ r@@ e@@ p@@ p@@ e@@ n@@ haus eines örtlichen E@@ i@@ n@@ k@@ a@@ u@@ f@@ z@@ e@@ n@@ t@@ r@@ u@@ ms tot a@@ u@@ f@@ gefunden .", "der Bruder des O@@ p@@ f@@ e@@ r@@ s sagte aus , dass er sich niemanden vorstellen kann , der ihm schaden wollen würde , " E@@ n@@ d@@ lich ging es bei ihm wieder b@@ e@@ r@@ g@@ auf " .", "der am Mittwoch morgen in der W@@ e@@ s@@ t@@ field M@@ all g@@ e@@ f@@ u@@ n@@ d@@ e@@ n@@ e L@@ e@@ i@@ c@@ h@@ n@@ a@@ m wurde als der 28 Jahre alte Frank G@@ a@@ l@@ i@@ c@@ i@@ a aus San Francisco identifiziert , teilte die g@@ e@@ r@@ i@@ c@@ h@@ t@@ s@@ medizinische Abteilung in San Francisco mit .", "das San Francisco P@@ o@@ l@@ i@@ ce D@@ e@@ p@@ a@@ r@@ t@@ ment sagte , dass der Tod als Mord eingestuft wurde und die Ermittlungen am L@@ a@@ u@@ f@@ en sind .", "der Bruder des O@@ p@@ f@@ e@@ r@@ s , Louis G@@ a@@ l@@ i@@ c@@ i@@ a , teilte dem A@@ B@@ S Sender K@@ GO in San Francisco mit , dass Frank , der früher als Koch in B@@ o@@ s@@ t@@ on gearbeitet hat , vor sechs Monaten seinen T@@ r@@ a@@ u@@ m@@ j@@ ob als Koch im S@@ o@@ n@@ s & D@@ a@@ u@@ g@@ h@@ t@@ e@@ r@@ s Restaurant in San Francisco e@@ r@@ g@@ a@@ t@@ t@@ e@@ r@@ t hatte .", "ein Sprecher des S@@ o@@ n@@ s & D@@ a@@ u@@ g@@ h@@ t@@ e@@ r@@ s sagte , dass sie über seinen Tod " s@@ c@@ h@@ o@@ c@@ k@@ i@@ e@@ r@@ t und am Boden zerstört seien " .", "" wir sind ein kleines Team , das wie eine enge Familie arbeitet und wir werden ihn s@@ c@@ h@@ m@@ e@@ r@@ z@@ lich vermissen " , sagte der Sprecher weiter .", "unsere Gedanken und unser B@@ e@@ i@@ leid sind in dieser schweren Zeit bei F@@ r@@ a@@ n@@ k@@ s Familie und Freunden .", "Louis G@@ a@@ l@@ i@@ c@@ i@@ a gab an , dass Frank zunächst in Hostels lebte , aber dass , " die Dinge für ihn endlich 
b@@ e@@ r@@ g@@ auf gingen " ."] #spl = SentenceSplitter("de") diff --git a/optm/adabelief.py b/optm/adabelief.py index 654c450..e9def8d 100644 --- a/optm/adabelief.py +++ b/optm/adabelief.py @@ -3,20 +3,21 @@ # Portal from: https://github.com/juntang-zhuang/Adabelief-Optimizer import torch +from math import sqrt from torch.optim.optimizer import Optimizer -from math import sqrt +from utils.torch.comp import torch_no_grad class AdaBelief(Optimizer): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, weight_decouple=False, fixed_decay=False, rectify=False): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, weight_decouple=False, fixed_decay=False, rectify=False, **kwargs): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad) super(AdaBelief, self).__init__(params, defaults) self.weight_decouple, self.rectify, self.fixed_decay = weight_decouple, rectify, fixed_decay - @torch.no_grad() + @torch_no_grad() def step(self, closure=None): if closure is None: diff --git a/optm/lookahead.py b/optm/lookahead.py index b2db500..8e513c8 100644 --- a/optm/lookahead.py +++ b/optm/lookahead.py @@ -5,7 +5,7 @@ class Lookahead(Optimizer): - def __init__(self, params, optimizer, steps=5, alpha=0.8, pullback_momentum=None): + def __init__(self, params, optimizer, steps=5, alpha=0.8, pullback_momentum=None, **kwargs): super(Lookahead, self).__init__(params, {}) diff --git a/optm/radam.py b/optm/radam.py index e12f3bb..45d67d0 100644 --- a/optm/radam.py +++ b/optm/radam.py @@ -6,9 +6,11 @@ from math import sqrt from torch.optim.optimizer import Optimizer +from utils.torch.comp import torch_no_grad + class RAdam(Optimizer): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, N_sma_threshhold=5, degenerated_to_sgd=True): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, N_sma_threshhold=5, degenerated_to_sgd=True, **kwargs): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) super(RAdam, self).__init__(params, defaults) @@ -16,7 +18,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 self.N_sma_threshhold = N_sma_threshhold self.degenerated_to_sgd = degenerated_to_sgd - @torch.no_grad() + @torch_no_grad() def step(self, closure=None): if closure is None: diff --git a/optm/ranger.py b/optm/ranger.py index 20ea31d..db88b22 100644 --- a/optm/ranger.py +++ b/optm/ranger.py @@ -4,9 +4,11 @@ from math import sqrt from torch.optim.optimizer import Optimizer +from utils.torch.comp import torch_no_grad + class Ranger(Optimizer): - def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, N_sma_threshhold=5, steps=5, alpha=0.8, degenerated_to_sgd=True): + def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.0, N_sma_threshhold=5, steps=5, alpha=0.8, degenerated_to_sgd=True, **kwargs): defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, buffer=[[None, None, None] for _ in range(10)]) super(Ranger, self).__init__(params, defaults) @@ -18,7 +20,7 @@ def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0 self.alpha = alpha self.steps = steps - @torch.no_grad() + @torch_no_grad() def step(self, closure=None): if closure is None: diff --git a/parallel/base.py b/parallel/base.py index 2b00330..6d3c43b 100644 --- a/parallel/base.py 
+++ b/parallel/base.py @@ -1,23 +1,20 @@ #encoding: utf-8 import torch -from torch import nn import torch.cuda.comm as comm -from utils.comm import secure_broadcast_coalesced -from utils.contpara import get_contiguous_parameters_m, get_all_contiguous_parameters_m, get_contiguous_parameters_p - -from torch.jit import ScriptModule -from torch._C import ScriptMethod from collections import OrderedDict - -from torch.nn import DataParallel - from threading import Lock, Thread - -from utils.base import autocast, is_autocast_enabled, filter_para_grad, divide_para_ind, reorder_by_sort, range_parameter_iter, filter_para_grad_iter -from utils.fmt.base import clean_list +from torch import nn +#from torch._C import ScriptMethod +from torch.jit import ScriptModule +from torch.nn import DataParallel from parallel.optm import MultiGPUOptimizer +from utils.base import divide_para_ind, filter_para_grad, filter_para_grad_iter, range_parameter_iter, reorder_by_sort +from utils.comm import secure_broadcast_coalesced +from utils.contpara import get_all_contiguous_parameters_m, get_contiguous_parameters_m, get_contiguous_parameters_p +from utils.fmt.base import clean_list +from utils.torch.comp import torch_autocast, torch_inference_mode, torch_is_autocast_enabled, torch_is_grad_enabled, torch_is_inference_mode_enabled, torch_no_grad, torch_set_grad_enabled, using_inference_mode """ Example: @@ -35,7 +32,7 @@ def replicate_fixing(module): class DataParallelModel(DataParallel): # host replicates should improve a little bit performance if there are additional calls to update_replicas and collect_gradients in the training scripts. - def __init__(self, module, device_ids=None, output_device=None, dim=0, host_replicate=False, gather_output=True): + def __init__(self, module, device_ids=None, output_device=None, dim=0, host_replicate=False, gather_output=True, **kwargs): super(DataParallelModel, self).__init__(module, device_ids=device_ids, output_device=output_device, dim=dim) @@ -76,7 +73,7 @@ def forward(self, *inputs, **kwargs): def zero_grad(self, set_to_none=True): if self.is_contiguous_parameters: - with torch.no_grad(): + with torch_no_grad(): for para in get_all_contiguous_parameters_m(self.module): para.grad.zero_() if self.nets is not None and self.ngradev > 1: @@ -147,7 +144,7 @@ def update_replicas(self): if self.is_contiguous_parameters: params = [para.data for para in get_all_contiguous_parameters_m(self.module)] param_copies = comm.broadcast_coalesced(params, self.device_ids) - with torch.no_grad(): + with torch_no_grad(): for module, param_copy in zip(self.nets[1:], param_copies[1:]): for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy): mp.data.copy_(para) @@ -176,7 +173,7 @@ def update_replicas(self): else: param_copies = _dev_param_copies if self.is_contiguous_parameters: - with torch.no_grad(): + with torch_no_grad(): for module, param_copy in zip(self.nets, param_copies): for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy): mp.data.copy_(para) @@ -194,7 +191,7 @@ def update_replicas_para(self): if self.is_contiguous_parameters: params = [para.data for para in get_all_contiguous_parameters_m(self.module)] param_copies = comm.broadcast_coalesced(params, self.device_ids) - with torch.no_grad(): + with torch_no_grad(): for module, param_copy in zip(self.nets[1:], param_copies[1:]): for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy): mp.data.copy_(para) @@ -220,7 +217,7 @@ def update_replicas_para(self): else: param_copies = 
_dev_param_copies if self.is_contiguous_parameters: - with torch.no_grad(): + with torch_no_grad(): for module, param_copy in zip(self.nets, param_copies): for mp, para in zip(get_all_contiguous_parameters_m(module), param_copy): mp.data.copy_(para) @@ -292,7 +289,7 @@ def build_optimizer(self, optm_func, *optm_args, multi_gpu_optimizer=False, cont class DataParallelCriterion(DataParallel): # if there is no parameter update in criterion, turn on replicate_once should improve a little bit performance. - def __init__(self, module, device_ids=None, output_device=None, dim=0, replicate_once=False): + def __init__(self, module, device_ids=None, output_device=None, dim=0, replicate_once=False, **kwargs): super(DataParallelCriterion, self).__init__(module, device_ids=device_ids, output_device=output_device, dim=dim) @@ -421,23 +418,23 @@ def clear_gradient(para): return [network] + [module_copies[j][0] for j in range(num_replicas)] -# update these two functions with the update of parallel_apply(https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py) +# update below functions with the update of parallel_apply(https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py) -def parallel_apply(modules, inputs, devices, kwargs_tup=None, lock=None): +def parallel_apply_inference(modules, inputs, devices, kwargs_tup=None, lock=None): if kwargs_tup is None: kwargs_tup = ({},) * len(modules) lock = Lock() if lock is None else lock results = {} - grad_enabled, autocast_enabled = torch.is_grad_enabled(), is_autocast_enabled() + grad_enabled, autocast_enabled, inference_mode_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled(), torch_is_inference_mode_enabled() def _worker(i, module, input, kwargs, device=None): # this also avoids accidental slicing of `input` if it is a Tensor if not isinstance(input, (list, tuple,)): input = (input,) - with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled): + with torch_set_grad_enabled(grad_enabled), torch_inference_mode(inference_mode_enabled), torch.cuda.device(device), torch_autocast(enabled=autocast_enabled): output = module(*input, **kwargs) with lock: results[i] = output @@ -456,14 +453,14 @@ def _worker(i, module, input, kwargs, device=None): return outputs -def criterion_parallel_apply(modules, inputs, targets, devices, kwargs_tup=None, lock=None): +def criterion_parallel_apply_inference(modules, inputs, targets, devices, kwargs_tup=None, lock=None): if kwargs_tup is None: kwargs_tup = ({},) * len(modules) lock = Lock() if lock is None else lock results = {} - grad_enabled, autocast_enabled = torch.is_grad_enabled(), is_autocast_enabled() + grad_enabled, autocast_enabled, inference_mode_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled(), torch_is_inference_mode_enabled() def _worker(i, module, input, target, kwargs, device): @@ -471,7 +468,7 @@ def _worker(i, module, input, target, kwargs, device): input = (input,) if not isinstance(target, (list, tuple,)): target = (target,) - with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled): + with torch_set_grad_enabled(grad_enabled), torch_inference_mode(inference_mode_enabled), torch.cuda.device(device), torch_autocast(enabled=autocast_enabled): output = module(*(input + target), **kwargs) with lock: results[i] = output @@ -489,3 +486,71 @@ def _worker(i, module, input, target, kwargs, device): outputs.append(output) return outputs + +def 
parallel_apply_grad(modules, inputs, devices, kwargs_tup=None, lock=None): + + if kwargs_tup is None: + kwargs_tup = ({},) * len(modules) + + lock = Lock() if lock is None else lock + results = {} + grad_enabled, autocast_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled() + + def _worker(i, module, input, kwargs, device=None): + + if not isinstance(input, (list, tuple,)): + input = (input,) + with torch_set_grad_enabled(grad_enabled), torch.cuda.device(device), torch_autocast(enabled=autocast_enabled): + output = module(*input, **kwargs) + with lock: + results[i] = output + + threads = [Thread(target=_worker, args=(i, module, input, kwargs, device)) for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + outputs = [] + for i in range(len(inputs)): + output = results[i] + outputs.append(output) + + return outputs + +def criterion_parallel_apply_grad(modules, inputs, targets, devices, kwargs_tup=None, lock=None): + + if kwargs_tup is None: + kwargs_tup = ({},) * len(modules) + + lock = Lock() if lock is None else lock + results = {} + grad_enabled, autocast_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled() + + def _worker(i, module, input, target, kwargs, device): + + if not isinstance(input, (list, tuple,)): + input = (input,) + if not isinstance(target, (list, tuple,)): + target = (target,) + with torch_set_grad_enabled(grad_enabled), torch.cuda.device(device), torch_autocast(enabled=autocast_enabled): + output = module(*(input + target), **kwargs) + with lock: + results[i] = output + + threads = [Thread(target=_worker, args=(i, module, input, target, kwargs, device)) for i, (module, input, target, kwargs, device) in enumerate(zip(modules, inputs, targets, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + outputs = [] + for i in range(len(inputs)): + output = results[i] + outputs.append(output) + + return outputs + +parallel_apply, criterion_parallel_apply = (parallel_apply_inference, criterion_parallel_apply_inference,) if using_inference_mode else (parallel_apply_grad, criterion_parallel_apply_grad,) diff --git a/parallel/optm.py b/parallel/optm.py index 674b88e..9cff19d 100644 --- a/parallel/optm.py +++ b/parallel/optm.py @@ -1,15 +1,15 @@ #encoding: utf-8 import torch -from torch.optim.optimizer import Optimizer -from utils.base import GradScaler#, autocast, is_autocast_enabled - from collections import defaultdict from threading import Thread +from torch.optim.optimizer import Optimizer + +from utils.torch.comp import GradScaler#, torch_autocast, torch_inference_mode, torch_is_autocast_enabled, torch_is_grad_enabled, torch_is_inference_mode_enabled, torch_set_grad_enabled, using_inference_mode class MultiGPUOptimizer(Optimizer): - def __init__(self, optms, device_ids=None): + def __init__(self, optms, device_ids=None, **kwargs): torch._C._log_api_usage_once("python.optimizer") self.defaults = optms[0].defaults @@ -51,11 +51,11 @@ def step(self, optimizer, *args, **kwargs): def parallel_apply(optms, closure=None, devices=None): - #grad_enabled, autocast_enabled = torch.is_grad_enabled(), is_autocast_enabled() + #grad_enabled, autocast_enabled, inference_mode_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled(), torch_is_inference_mode_enabled() def _worker(optm, closure=None, device=None): - with torch.cuda.device(device):#, torch.set_grad_enabled(grad_enabled), 
autocast(enabled=autocast_enabled) + with torch.cuda.device(device):#, torch_set_grad_enabled(grad_enabled), torch_inference_mode(), torch_autocast(enabled=autocast_enabled) optm.step(closure=closure) threads = [Thread(target=_worker, args=(optm, closure, device)) for optm, device in zip(optms, devices)] diff --git a/parallel/parallelMT.py b/parallel/parallelMT.py index a2f5477..e685dba 100644 --- a/parallel/parallelMT.py +++ b/parallel/parallelMT.py @@ -1,13 +1,12 @@ #encoding: utf-8 import torch +from threading import Lock, Thread from parallel.base import DataParallelModel - -from utils.base import autocast, is_autocast_enabled, pad_tensors +from utils.base import pad_tensors from utils.fmt.base import clean_list - -from threading import Lock, Thread +from utils.torch.comp import torch_autocast, torch_inference_mode, torch_is_autocast_enabled, torch_is_grad_enabled, torch_is_inference_mode_enabled, torch_set_grad_enabled, using_inference_mode class DataParallelMT(DataParallelModel): @@ -45,22 +44,84 @@ def train_decode(self, *inputs, **kwargs): outputs = parallel_apply_train_decode(replicas, inputs, devices, kwargs, lock=self.lock) return self.gather(pad_tensors(outputs), self.output_device) if self.gather_output else outputs -# update these two functions with the update of parallel_apply(https://github.com/pytorch/pytorch/blob/master/torch/nn/parallel/parallel_apply.py) +def parallel_apply_decode_inference(modules, inputs, devices, kwargs_tup=None, lock=None): + + if kwargs_tup is None: + kwargs_tup = ({},) * len(modules) + + lock = Lock() if lock is None else lock + results = {} + grad_enabled, autocast_enabled, inference_mode_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled(), torch_is_inference_mode_enabled() + + def _worker(i, module, input, kwargs, device=None): + + if not isinstance(input, (list, tuple,)): + input = (input,) + with torch_set_grad_enabled(grad_enabled), torch_inference_mode(inference_mode_enabled), torch.cuda.device(device), torch_autocast(enabled=autocast_enabled): + output = module.decode(*input, **kwargs) + with lock: + results[i] = output + + threads = [Thread(target=_worker, args=(i, module, input, kwargs, device)) for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + outputs = [] + for i in range(len(inputs)): + output = results[i] + outputs.append(output) + + return outputs -def parallel_apply_decode(modules, inputs, devices, kwargs_tup=None, lock=None): +def parallel_apply_train_decode_inference(modules, inputs, devices, kwargs_tup=None, lock=None): if kwargs_tup is None: kwargs_tup = ({},) * len(modules) lock = Lock() if lock is None else lock results = {} - grad_enabled, autocast_enabled = torch.is_grad_enabled(), is_autocast_enabled() + grad_enabled, autocast_enabled, inference_mode_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled(), torch_is_inference_mode_enabled() def _worker(i, module, input, kwargs, device=None): if not isinstance(input, (list, tuple,)): input = (input,) - with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled): + with torch_set_grad_enabled(grad_enabled), torch_inference_mode(inference_mode_enabled), torch.cuda.device(device), torch_autocast(enabled=autocast_enabled): + output = module.train_decode(*input, **kwargs) + with lock: + results[i] = output + + threads = [Thread(target=_worker, args=(i, module, input, kwargs, device)) 
for i, (module, input, kwargs, device) in enumerate(zip(modules, inputs, kwargs_tup, devices))] + + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + outputs = [] + for i in range(len(inputs)): + output = results[i] + outputs.append(output) + + return outputs + +def parallel_apply_decode_grad(modules, inputs, devices, kwargs_tup=None, lock=None): + + if kwargs_tup is None: + kwargs_tup = ({},) * len(modules) + + lock = Lock() if lock is None else lock + results = {} + grad_enabled, autocast_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled() + + def _worker(i, module, input, kwargs, device=None): + + if not isinstance(input, (list, tuple,)): + input = (input,) + with torch_set_grad_enabled(grad_enabled), torch.cuda.device(device), torch_autocast(enabled=autocast_enabled): output = module.decode(*input, **kwargs) with lock: results[i] = output @@ -79,20 +140,20 @@ def _worker(i, module, input, kwargs, device=None): return outputs -def parallel_apply_train_decode(modules, inputs, devices, kwargs_tup=None, lock=None): +def parallel_apply_train_decode_grad(modules, inputs, devices, kwargs_tup=None, lock=None): if kwargs_tup is None: kwargs_tup = ({},) * len(modules) lock = Lock() if lock is None else lock results = {} - grad_enabled, autocast_enabled = torch.is_grad_enabled(), is_autocast_enabled() + grad_enabled, autocast_enabled = torch_is_grad_enabled(), torch_is_autocast_enabled() def _worker(i, module, input, kwargs, device=None): if not isinstance(input, (list, tuple,)): input = (input,) - with torch.set_grad_enabled(grad_enabled), torch.cuda.device(device), autocast(enabled=autocast_enabled): + with torch_set_grad_enabled(grad_enabled), torch.cuda.device(device), torch_autocast(enabled=autocast_enabled): output = module.train_decode(*input, **kwargs) with lock: results[i] = output @@ -110,3 +171,5 @@ def _worker(i, module, input, kwargs, device=None): outputs.append(output) return outputs + +parallel_apply_decode, parallel_apply_train_decode = (parallel_apply_decode_inference, parallel_apply_train_decode_inference,) if using_inference_mode else (parallel_apply_decode_grad, parallel_apply_train_decode_grad,) diff --git a/predict.py b/predict.py index 9095bec..39d3540 100644 --- a/predict.py +++ b/predict.py @@ -1,23 +1,24 @@ #encoding: utf-8 import sys - import torch -from utils.tqdm import tqdm - +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda_decode +from utils.fmt.vocab.base import reverse_dict +from utils.fmt.vocab.token import ldvocab from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.base as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * -from utils.fmt.base import ldvocab, reverse_dict, eos_id -from utils.fmt.base4torch import parse_cuda_decode +from cnfg.vocab.base import eos_id def load_fixing(module): @@ -32,7 +33,7 @@ def load_fixing(module): vcbt = reverse_dict(vcbt) if len(sys.argv) == 4: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, 
cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) @@ -40,7 +41,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -57,23 +58,25 @@ def load_fixing(module): set_random_seed(cnfg.seed, use_cuda) if cuda_device: - mymodel.to(cuda_device) + mymodel.to(cuda_device, non_blocking=True) if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty ens = "\n".encode("utf-8") src_grp = td["src"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): +with sys_open(sys.argv[1], "wb") as f, torch_inference_mode(): for i in tqdm(range(ntest), mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[str(i)][()]) if cuda_device: - seq_batch = seq_batch.to(cuda_device) + seq_batch = seq_batch.to(cuda_device, non_blocking=True) seq_batch = seq_batch.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel.decode(seq_batch, beam_size, None, length_penalty) #output = mymodel.train_decode(seq_batch, beam_size, None, length_penalty) if multi_gpu: diff --git a/rank_loss.py b/rank_loss.py index 2edde71..1937819 100644 --- a/rank_loss.py +++ b/rank_loss.py @@ -5,26 +5,24 @@ norm_token = True import sys - import torch -from utils.tqdm import tqdm - +from loss.base import LabelSmoothingLoss +from parallel.base import DataParallelCriterion +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base import sys_open +from utils.fmt.base4torch import parse_cuda from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.base as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT -from parallel.base import DataParallelCriterion - -from loss.base import LabelSmoothingLoss - -from utils.base import * -from utils.fmt.base import pad_id -from utils.fmt.base4torch import parse_cuda +from cnfg.vocab.base import pad_id def load_fixing(module): @@ -38,7 +36,7 @@ def load_fixing(module): nwordi, nwordt = nword[0], nword[-1] if len(sys.argv) == 4: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, 
cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) @@ -46,7 +44,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -65,31 +63,34 @@ def load_fixing(module): set_random_seed(cnfg.seed, use_cuda) if cuda_device: - mymodel.to(cuda_device) - lossf.to(cuda_device) + mymodel.to(cuda_device, non_blocking=True) + lossf.to(cuda_device, non_blocking=True) if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) lossf = DataParallelCriterion(lossf, device_ids=cuda_devices, output_device=cuda_device.index, replicate_once=True) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + ens = "\n".encode("utf-8") src_grp, tgt_grp = td["src"], td["tgt"] -with open(sys.argv[1], "wb") as f, torch.no_grad(): +with sys_open(sys.argv[1], "wb") as f, torch_inference_mode(): for i in tqdm(range(ntest), mininterval=tqdm_mininterval): _curid = str(i) seq_batch = torch.from_numpy(src_grp[_curid][()]) seq_o = torch.from_numpy(tgt_grp[_curid][()]) if cuda_device: - seq_batch = seq_batch.to(cuda_device) - seq_o = seq_o.to(cuda_device) + seq_batch = seq_batch.to(cuda_device, non_blocking=True) + seq_o = seq_o.to(cuda_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() lo = seq_o.size(1) - 1 ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel(seq_batch, seq_o.narrow(1, 0, lo)) loss = lossf(output, ot).view(ot.size(0), -1).sum(-1) if norm_token: - lenv = ot.ne(pad_id).int().sum(-1).to(loss) + lenv = ot.ne(pad_id).int().sum(-1).to(loss, non_blocking=True) loss = loss / lenv f.write("\n".join([str(rsu) for rsu in loss.tolist()]).encode("utf-8")) f.write(ens) diff --git a/requirements.opt.txt b/requirements.opt.txt index c02272f..4071d4a 100644 --- a/requirements.opt.txt +++ b/requirements.opt.txt @@ -1,8 +1,10 @@ -sentencepiece>=0.1.96 +sentencepiece>=0.1.99 subword-nmt>=0.3.8 sacremoses>=0.0.53 -transformers>=4.22.2 -Cython>=0.29.32 -Flask>=2.2.2 +transformers>=4.31.0 +protobuf>=4.23.3 +Cython>=3.0.0 +Flask>=2.3.2 jieba>=0.42.1 PyNLPIR>=0.6.0 +OpenCC>=1.1.6 diff --git a/requirements.txt b/requirements.txt index 39bbbba..729f73a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,3 @@ -tqdm>=4.64.1 -torch>=1.13.0 -h5py>=3.7.0 +tqdm>=4.65.0 +torch>=2.0.1 +h5py>=3.9.0 diff --git a/scripts/README.md b/scripts/README.md index a8c584f..2e88803 100644 --- a/scripts/README.md +++ b/scripts/README.md @@ -39,6 +39,9 @@ export ngpu=1 # sorting dataset and building vocabularies. 
true for the first time generation, false when only update the .h5 files. export do_sort=true export build_vocab=true + +# additional temporary file extension, support .xz/gz/bz2 compression to trade between I/O and CPU. +export faext=".xz" ``` ## `mktest.sh` @@ -72,6 +75,11 @@ export ngpu=1 export sort_decode=true # merge sub-words export debpe=true +# use spm for bpe (instead of subword-nmt). +export spm_bpe=false + +# additional temporary file extension, support .xz/gz/bz2 compression to trade between I/O and CPU. +export faext=".xz" ``` ## `bpe/` diff --git a/scripts/ape/bpe/mk.sh b/scripts/ape/bpe/mk.sh index 2ed1157..d0e9899 100644 --- a/scripts/ape/bpe/mk.sh +++ b/scripts/ape/bpe/mk.sh @@ -33,8 +33,8 @@ mv $tgtd/src.clean.rtmp $tgtd/src.clean.tmp mv $tgtd/mt.clean.rtmp $tgtd/mt.clean.tmp mv $tgtd/tgt.clean.rtmp $tgtd/tgt.clean.tmp -python tools/vocab.py $tgtd/src.clean.tmp $tgtd/src.full.vcb 1048576 & -python tools/vocab.py $tgtd/tgt.clean.tmp $tgtd/tgt.full.vcb 1048576 & +python tools/vocab/token/single.py $tgtd/src.clean.tmp $tgtd/src.full.vcb 1048576 & +python tools/vocab/token/single.py $tgtd/tgt.clean.tmp $tgtd/tgt.full.vcb 1048576 & wait python tools/clean/ape/vocab.py $tgtd/src.clean.tmp $tgtd/mt.clean.tmp $tgtd/tgt.clean.tmp $tgtd/src.train.tok.clean $tgtd/mt.train.tok.clean $tgtd/tgt.train.tok.clean $tgtd/src.full.vcb $tgtd/tgt.full.vcb $vratio rm -fr $tgtd/src.full.vcb $tgtd/tgt.full.vcb $tgtd/src.clean.tmp $tgtd/mt.clean.tmp $tgtd/tgt.clean.tmp diff --git a/scripts/ape/mktest.sh b/scripts/ape/mktest.sh index 62ec523..f31b45e 100644 --- a/scripts/ape/mktest.sh +++ b/scripts/ape/mktest.sh @@ -20,6 +20,8 @@ export sort_decode=true export debpe=true export spm_bpe=false +export faext=".xz" + export tgtd=$cachedir/$dataid export bpef=out.bpe @@ -35,8 +37,8 @@ fi mkdir -p $rsd if $sort_decode; then - export srt_input_f=$tgtd/$srctf.srt - export srt_input_fm=$tgtd/$srcmf.srt + export srt_input_f=$tgtd/$srctf.srt$faext + export srt_input_fm=$tgtd/$srcmf.srt$faext python tools/sort.py $srcd/$srctf $srcd/$srcmf $srt_input_f $srt_input_fm 1048576 else export srt_input_f=$srcd/$srctf @@ -64,3 +66,4 @@ if $debpe; then else mv $tgtd/$bpef $rsf fi +rm $tgtd/test.h5 diff --git a/scripts/ape/mktrain.sh b/scripts/ape/mktrain.sh index 31e1b85..2dcd8fa 100644 --- a/scripts/ape/mktrain.sh +++ b/scripts/ape/mktrain.sh @@ -28,13 +28,21 @@ export ngpu=1 export do_sort=true export build_vocab=true +export faext=".xz" + export wkd=$cachedir/$dataid mkdir -p $wkd +export stsf=$wkd/src.train.srt$faext +export mtsf=$wkd/mt.train.srt$faext +export ttsf=$wkd/tgt.train.srt$faext +export sdsf=$wkd/src.dev.srt$faext +export mdsf=$wkd/mt.dev.srt$faext +export tdsf=$wkd/tgt.dev.srt$faext if $do_sort; then - python tools/sort.py $srcd/$srctf $srcd/$mttf $srcd/$tgttf $wkd/src.train.srt $wkd/mt.train.srt $wkd/tgt.train.srt $maxtokens & - python tools/sort.py $srcd/$srcvf $srcd/$mtvf $srcd/$tgtvf $wkd/src.dev.srt $wkd/mt.dev.srt $wkd/tgt.dev.srt 1048576 & + python tools/sort.py $srcd/$srctf $srcd/$mttf $srcd/$tgttf $stsf $mtsf $ttsf $maxtokens & + python tools/sort.py $srcd/$srcvf $srcd/$mtvf $srcd/$tgtvf $sdsf $mdsf $tdsf 1048576 & wait fi @@ -42,19 +50,19 @@ if $share_vcb; then export src_vcb=$wkd/common.vcb export tgt_vcb=$src_vcb if $build_vocab; then - python tools/share_vocab.py $wkd/src.train.srt $wkd/tgt.train.srt $wkd/mt.train.srt $src_vcb $vsize - python tools/check/fbindexes.py $tgt_vcb $wkd/tgt.train.srt $wkd/tgt.dev.srt $wkd/fbind.py & + python tools/vocab/token/share.py $stsf $ttsf $mtsf $src_vcb 
$vsize + python tools/check/fbindexes.py $tgt_vcb $ttsf $tdsf $wkd/fbind.py & fi else export src_vcb=$wkd/src.vcb export tgt_vcb=$wkd/tgt.vcb if $build_vocab; then - python tools/vocab.py $wkd/src.train.srt $src_vcb $vsize & - python tools/share_vocab.py $wkd/tgt.train.srt $wkd/mt.train.srt $tgt_vcb $vsize & + python tools/vocab/token/single.py $stsf $src_vcb $vsize & + python tools/vocab/token/share.py $ttsf $mtsf $tgt_vcb $vsize & wait fi fi -python tools/ape/mkiodata.py $wkd/src.train.srt $wkd/mt.train.srt $wkd/tgt.train.srt $src_vcb $tgt_vcb $wkd/$rsf_train $ngpu & -python tools/ape/mkiodata.py $wkd/src.dev.srt $wkd/mt.dev.srt $wkd/tgt.dev.srt $src_vcb $tgt_vcb $wkd/$rsf_dev $ngpu & +python tools/ape/mkiodata.py $stsf $mtsf $ttsf $src_vcb $tgt_vcb $wkd/$rsf_train $ngpu & +python tools/ape/mkiodata.py $sdsf $mdsf $tdsf $src_vcb $tgt_vcb $wkd/$rsf_dev $ngpu & wait diff --git a/scripts/bpe/mk.sh b/scripts/bpe/mk.sh index b2e79d2..b6efeff 100644 --- a/scripts/bpe/mk.sh +++ b/scripts/bpe/mk.sh @@ -30,8 +30,8 @@ python tools/clean/token_repeat.py $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp $tgtd mv $tgtd/src.clean.rtmp $tgtd/src.clean.tmp mv $tgtd/tgt.clean.rtmp $tgtd/tgt.clean.tmp -python tools/vocab.py $tgtd/src.clean.tmp $tgtd/src.full.vcb 1048576 & -python tools/vocab.py $tgtd/tgt.clean.tmp $tgtd/tgt.full.vcb 1048576 & +python tools/vocab/token/single.py $tgtd/src.clean.tmp $tgtd/src.full.vcb 1048576 & +python tools/vocab/token/single.py $tgtd/tgt.clean.tmp $tgtd/tgt.full.vcb 1048576 & wait python tools/clean/vocab/ratio.py $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean $tgtd/src.full.vcb $tgtd/tgt.full.vcb $vratio rm -fr $tgtd/src.full.vcb $tgtd/tgt.full.vcb $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp diff --git a/scripts/doc/para/mktest.sh b/scripts/doc/para/mktest.sh index 8b4a49a..48788ef 100644 --- a/scripts/doc/para/mktest.sh +++ b/scripts/doc/para/mktest.sh @@ -19,6 +19,8 @@ export sort_decode=true export debpe=true export spm_bpe=false +export faext=".xz" + export tgtd=$cachedir/$dataid export bpef=out.bpe @@ -34,7 +36,7 @@ fi mkdir -p $rsd if $sort_decode; then - export srt_input_f=$tgtd/$srctf.srt + export srt_input_f=$tgtd/$srctf.srt$faext python tools/doc/sort.py $srcd/$srctf $srt_input_f 1048576 else export srt_input_f=$srcd/$srctf @@ -61,3 +63,4 @@ if $debpe; then else mv $tgtd/$bpef $rsf fi +rm $tgtd/test.h5 diff --git a/scripts/mktest.sh b/scripts/mktest.sh index 8143a94..0dbb487 100644 --- a/scripts/mktest.sh +++ b/scripts/mktest.sh @@ -19,6 +19,8 @@ export sort_decode=true export debpe=true export spm_bpe=false +export faext=".xz" + export tgtd=$cachedir/$dataid export bpef=out.bpe @@ -34,7 +36,7 @@ fi mkdir -p $rsd if $sort_decode; then - export srt_input_f=$tgtd/$srctf.srt + export srt_input_f=$tgtd/$srctf.srt$faext python tools/sort.py $srcd/$srctf $srt_input_f 1048576 else export srt_input_f=$srcd/$srctf @@ -60,3 +62,4 @@ if $debpe; then else mv $tgtd/$bpef $rsf fi +rm $tgtd/test.h5 diff --git a/scripts/mktrain.sh b/scripts/mktrain.sh index 4c453cd..f419bd1 100644 --- a/scripts/mktrain.sh +++ b/scripts/mktrain.sh @@ -26,15 +26,21 @@ export ngpu=1 export do_sort=true export build_vocab=true +export faext=".xz" + export wkd=$cachedir/$dataid mkdir -p $wkd +export stsf=$wkd/src.train.srt$faext +export ttsf=$wkd/tgt.train.srt$faext +export sdsf=$wkd/src.dev.srt$faext +export tdsf=$wkd/tgt.dev.srt$faext if $do_sort; then - python tools/sort.py $srcd/$srctf $srcd/$tgttf $wkd/src.train.srt $wkd/tgt.train.srt $maxtokens & + python 
tools/sort.py $srcd/$srctf $srcd/$tgttf $stsf $ttsf $maxtokens & # use the following command to sort a very large dataset with limited memory - #bash tools/lsort/sort.sh $srcd/$srctf $srcd/$tgttf $wkd/src.train.srt $wkd/tgt.train.srt $maxtokens & - python tools/sort.py $srcd/$srcvf $srcd/$tgtvf $wkd/src.dev.srt $wkd/tgt.dev.srt 1048576 & + #bash tools/lsort/sort.sh $srcd/$srctf $srcd/$tgttf $stsf $ttsf $maxtokens & + python tools/sort.py $srcd/$srcvf $srcd/$tgtvf $sdsf $tdsf 1048576 & wait fi @@ -42,19 +48,19 @@ if $share_vcb; then export src_vcb=$wkd/common.vcb export tgt_vcb=$src_vcb if $build_vocab; then - python tools/share_vocab.py $wkd/src.train.srt $wkd/tgt.train.srt $src_vcb $vsize - python tools/check/fbindexes.py $tgt_vcb $wkd/tgt.train.srt $wkd/tgt.dev.srt $wkd/fbind.py & + python tools/vocab/token/share.py $stsf $ttsf $src_vcb $vsize + python tools/check/fbindexes.py $tgt_vcb $ttsf $tdsf $wkd/fbind.py & fi else export src_vcb=$wkd/src.vcb export tgt_vcb=$wkd/tgt.vcb if $build_vocab; then - python tools/vocab.py $wkd/src.train.srt $src_vcb $vsize & - python tools/vocab.py $wkd/tgt.train.srt $tgt_vcb $vsize & + python tools/vocab/token/single.py $stsf $src_vcb $vsize & + python tools/vocab/token/single.py $ttsf $tgt_vcb $vsize & wait fi fi -python tools/mkiodata.py $wkd/src.train.srt $wkd/tgt.train.srt $src_vcb $tgt_vcb $wkd/$rsf_train $ngpu & -python tools/mkiodata.py $wkd/src.dev.srt $wkd/tgt.dev.srt $src_vcb $tgt_vcb $wkd/$rsf_dev $ngpu & +python tools/mkiodata.py $stsf $ttsf $src_vcb $tgt_vcb $wkd/$rsf_train $ngpu & +python tools/mkiodata.py $sdsf $tdsf $src_vcb $tgt_vcb $wkd/$rsf_dev $ngpu & wait diff --git a/scripts/mulang/mktest.sh b/scripts/mulang/mktest.sh index fa2f6f1..b7ed71d 100644 --- a/scripts/mulang/mktest.sh +++ b/scripts/mulang/mktest.sh @@ -19,6 +19,8 @@ export sort_decode=true export debpe=true export spm_bpe=false +export faext=".xz" + export tgtd=$cachedir/$dataid export bpef=out.bpe @@ -34,7 +36,7 @@ fi mkdir -p $rsd if $sort_decode; then - export srt_input_f=$tgtd/$srctf.srt + export srt_input_f=$tgtd/$srctf.srt$faext python tools/mulang/eff/sort.py $srcd/$srctf $srt_input_f 1048576 else export srt_input_f=$srcd/$srctf @@ -57,6 +59,8 @@ if $debpe; then else sed -r 's/(@@ )|(@@ ?$)//g' < $tgtd/$bpef > $rsf fi + rm $tgtd/$bpef else mv $tgtd/$bpef $rsf fi +rm $tgtd/test.h5 diff --git a/scripts/mulang/mktrain.sh b/scripts/mulang/mktrain.sh index ee1a62d..f15c8c5 100644 --- a/scripts/mulang/mktrain.sh +++ b/scripts/mulang/mktrain.sh @@ -26,13 +26,19 @@ export ngpu=1 export do_sort=true export build_vocab=true +export faext=".xz" + export wkd=$cachedir/$dataid mkdir -p $wkd +export stsf=$wkd/src.train.srt$faext +export ttsf=$wkd/tgt.train.srt$faext +export sdsf=$wkd/src.dev.srt$faext +export tdsf=$wkd/tgt.dev.srt$faext if $do_sort; then - python tools/mulang/eff/sort.py $srcd/$srctf $srcd/$tgttf $wkd/src.train.srt $wkd/tgt.train.srt $maxtokens & - python tools/mulang/eff/sort.py $srcd/$srcvf $srcd/$tgtvf $wkd/src.dev.srt $wkd/tgt.dev.srt 1048576 & + python tools/mulang/eff/sort.py $srcd/$srctf $srcd/$tgttf $stsf $ttsf $maxtokens & + python tools/mulang/eff/sort.py $srcd/$srcvf $srcd/$tgtvf $sdsf $tdsf 1048576 & wait fi @@ -40,20 +46,20 @@ if $share_vcb; then export src_vcb=$wkd/common.vcb export tgt_vcb=$src_vcb if $build_vocab; then - python tools/mulang/share_vocab.py $wkd/src.train.srt --target $wkd/tgt.train.srt $src_vcb $wkd/lang.vcb $vsize - python tools/check/mulang/fbindexes.py $tgt_vcb $wkd/src.train.srt $wkd/tgt.train.srt $wkd/src.dev.srt 
$wkd/tgt.dev.srt $wkd/lang.vcb $wkd/fbind.py & + python tools/mulang/vocab/token/share.py $stsf --target $ttsf $src_vcb $wkd/lang.vcb $vsize + python tools/check/mulang/fbindexes.py $tgt_vcb $stsf $ttsf $sdsf $tdsf $wkd/lang.vcb $wkd/fbind.py & fi else export src_vcb=$wkd/src.vcb export tgt_vcb=$wkd/tgt.vcb if $build_vocab; then - python tools/mulang/vocab.py $wkd/src.train.srt $src_vcb $wkd/lang.vcb $vsize & - python tools/vocab.py $wkd/tgt.train.srt $tgt_vcb $vsize & + python tools/mulang/vocab/token/single.py $stsf $src_vcb $wkd/lang.vcb $vsize & + python tools/vocab/token/single.py $ttsf $tgt_vcb $vsize & wait - python tools/check/mulang/fbindexes.py $tgt_vcb $wkd/src.train.srt $wkd/tgt.train.srt $wkd/src.dev.srt $wkd/tgt.dev.srt $wkd/lang.vcb $wkd/fbind.py & + python tools/check/mulang/fbindexes.py $tgt_vcb $stsf $ttsf $sdsf $tdsf $wkd/lang.vcb $wkd/fbind.py & fi fi -python tools/mulang/eff/mkiodata.py $wkd/src.train.srt $wkd/tgt.train.srt $src_vcb $tgt_vcb $wkd/lang.vcb $wkd/$rsf_train $ngpu & -python tools/mulang/eff/mkiodata.py $wkd/src.dev.srt $wkd/tgt.dev.srt $src_vcb $tgt_vcb $wkd/lang.vcb $wkd/$rsf_dev $ngpu & +python tools/mulang/eff/mkiodata.py $stsf $ttsf $src_vcb $tgt_vcb $wkd/lang.vcb $wkd/$rsf_train $ngpu & +python tools/mulang/eff/mkiodata.py $sdsf $tdsf $src_vcb $tgt_vcb $wkd/lang.vcb $wkd/$rsf_dev $ngpu & wait diff --git a/scripts/plm/bart/mktest.sh b/scripts/plm/bart/mktest.sh index a28ea7c..a86cc47 100644 --- a/scripts/plm/bart/mktest.sh +++ b/scripts/plm/bart/mktest.sh @@ -16,6 +16,8 @@ export ngpu=1 export sort_decode=true +export faext=".xz" + export tgtd=$cachedir/$dataid export tgt_vcb=$src_vcb @@ -23,20 +25,22 @@ export bpef=out.bpe mkdir -p $rsd -python tools/plm/map/bart.py $srcd/$srctf $src_vcb $tgtd/$srctf.ids +export stif=$tgtd/$srctf.ids$faext +python tools/plm/map/bart.py $srcd/$srctf $src_vcb $stif if $sort_decode; then - export srt_input_f=$tgtd/$srctf.ids.srt - python tools/sort.py $tgtd/$srctf.ids $srt_input_f 1048576 + export srt_input_f=$tgtd/$srctf.ids.srt$faext + python tools/sort.py $stif $srt_input_f 1048576 else - export srt_input_f=$tgtd/$srctf.ids + export srt_input_f=$stif fi python tools/plm/mktest.py $srt_input_f $tgtd/test.h5 $ngpu python predict_bart.py $tgtd/$bpef $tgt_vcb $modelf if $sort_decode; then - python tools/restore.py $tgtd/$srctf.ids $srt_input_f $tgtd/$bpef $rsf + python tools/restore.py $stif $srt_input_f $tgtd/$bpef $rsf rm $srt_input_f $tgtd/$bpef else mv $tgtd/$bpef $rsf fi +rm $stif $tgtd/test.h5 diff --git a/scripts/plm/roberta/mktest.sh b/scripts/plm/roberta/mktest.sh index 81df98d..96de546 100644 --- a/scripts/plm/roberta/mktest.sh +++ b/scripts/plm/roberta/mktest.sh @@ -16,6 +16,8 @@ export ngpu=1 export sort_decode=true +export faext=".xz" + export tgtd=$cachedir/$dataid #export tgt_vcb=$src_vcb @@ -23,20 +25,22 @@ export bpef=out.bpe mkdir -p $rsd -python tools/plm/map/roberta.py $srcd/$srctf $src_vcb $tgtd/$srctf.ids +export stif=$tgtd/$srctf.ids$faext +python tools/plm/map/roberta.py $srcd/$srctf $src_vcb $stif if $sort_decode; then - export srt_input_f=$tgtd/$srctf.ids.srt - python tools/sort.py $tgtd/$srctf.ids $srt_input_f 1048576 + export srt_input_f=$tgtd/$srctf.ids.srt$faext + python tools/sort.py $stif $srt_input_f 1048576 else - export srt_input_f=$tgtd/$srctf.ids + export srt_input_f=$stif fi python tools/plm/mktest.py $srt_input_f $tgtd/test.h5 $ngpu python predict_roberta.py $tgtd/$bpef $modelf if $sort_decode; then - python tools/restore.py $tgtd/$srctf.ids $srt_input_f $tgtd/$bpef $rsf + python 
tools/restore.py $stif $srt_input_f $tgtd/$bpef $rsf rm $srt_input_f $tgtd/$bpef else mv $tgtd/$bpef $rsf fi +rm $stif $tgtd/test.h5 diff --git a/scripts/plm/roberta/mktrain.sh b/scripts/plm/roberta/mktrain.sh index d6f2ec6..0af1d50 100644 --- a/scripts/plm/roberta/mktrain.sh +++ b/scripts/plm/roberta/mktrain.sh @@ -26,14 +26,18 @@ export wkd=$cachedir/$dataid mkdir -p $wkd +export stsf=$wkd/src.train.srt$faext +export ttsf=$wkd/tgt.train.srt$faext +export sdsf=$wkd/src.dev.srt$faext +export tdsf=$wkd/tgt.dev.srt$faext if $do_sort; then - python tools/sort.py $srcd/$srctf $srcd/$tgttf $wkd/src.train.srt $wkd/tgt.train.srt $maxtokens & + python tools/sort.py $srcd/$srctf $srcd/$tgttf $stsf $ttsf $maxtokens & # use the following command to sort a very large dataset with limited memory - #bash tools/lsort/sort.sh $srcd/$srctf $srcd/$tgttf $wkd/src.train.srt $wkd/tgt.train.srt $maxtokens & - python tools/sort.py $srcd/$srcvf $srcd/$tgtvf $wkd/src.dev.srt $wkd/tgt.dev.srt 1048576 & + #bash tools/lsort/sort.sh $srcd/$srctf $srcd/$tgttf $stsf $ttsf $maxtokens & + python tools/sort.py $srcd/$srcvf $srcd/$tgtvf $sdsf $tdsf 1048576 & wait fi -python tools/plm/mkiodata.py $wkd/src.train.srt $wkd/tgt.train.srt $wkd/$rsf_train $ngpu & -python tools/plm/mkiodata.py $wkd/src.dev.srt $wkd/tgt.dev.srt $wkd/$rsf_dev $ngpu & +python tools/plm/mkiodata.py $stsf $ttsf $wkd/$rsf_train $ngpu & +python tools/plm/mkiodata.py $sdsf $tdsf $wkd/$rsf_dev $ngpu & wait diff --git a/scripts/plm/t5/mktest.sh b/scripts/plm/t5/mktest.sh index 5f7f750..e634955 100644 --- a/scripts/plm/t5/mktest.sh +++ b/scripts/plm/t5/mktest.sh @@ -16,6 +16,8 @@ export ngpu=1 export sort_decode=true +export faext=".xz" + export tgtd=$cachedir/$dataid export tgt_vcb=$src_vcb @@ -23,20 +25,22 @@ export bpef=out.bpe mkdir -p $rsd -python tools/plm/map/t5.py $srcd/$srctf $src_vcb $tgtd/$srctf.ids +export stif=$tgtd/$srctf.ids$faext +python tools/plm/map/t5.py $srcd/$srctf $src_vcb $stif if $sort_decode; then - export srt_input_f=$tgtd/$srctf.ids.srt - python tools/sort.py $tgtd/$srctf.ids $srt_input_f 1048576 + export srt_input_f=$tgtd/$srctf.ids.srt$faext + python tools/sort.py $stif $srt_input_f 1048576 else - export srt_input_f=$tgtd/$srctf.ids + export srt_input_f=$stif fi python tools/plm/mktest.py $srt_input_f $tgtd/test.h5 $ngpu python predict_t5.py $tgtd/$bpef $tgt_vcb $modelf if $sort_decode; then - python tools/restore.py $tgtd/$srctf.ids $srt_input_f $tgtd/$bpef $rsf + python tools/restore.py $stif $srt_input_f $tgtd/$bpef $rsf rm $srt_input_f $tgtd/$bpef else mv $tgtd/$bpef $rsf fi +rm $stif $tgtd/test.h5 diff --git a/scripts/spm/clean.sh b/scripts/spm/clean.sh index fb732bc..4a18787 100644 --- a/scripts/spm/clean.sh +++ b/scripts/spm/clean.sh @@ -54,16 +54,15 @@ rm -fr $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp if $share_bpe; then # to learn joint bpe - cat $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean | shuf > $tgtd/bpe.train.txt - spm_train --input=$tgtd/bpe.train.txt --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 + # --max_sentence_length=4096 --input_sentence_size=5000000 --shuffle_input_sentence=true --num_threads=32 --train_extremely_large_corpus=true + spm_train --input=$tgtd/src.train.tok.clean,$tgtd/tgt.train.tok.clean --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --unk_id=3 --bos_id=1 --eos_id=2 --pad_id=0 --unk_piece="" --bos_piece="" --eos_piece="" --unk_surface="" --minloglevel=1 --random_seed=666666 
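The joint-BPE branch of scripts/spm/clean.sh above now trains sentencepiece directly on the comma-separated source/target files (dropping the old `cat ... | shuf > bpe.train.txt` temporary), pins the special-token IDs (pad=0, bos=1, eos=2, unk=3), and fixes the random seed so the learned merges are reproducible; the same flags are applied to scripts/spm/mk.sh below. The empty `--unk_piece=""`/`--bos_piece=""`/`--eos_piece=""`/`--unk_surface=""` values shown above appear to have had their angle-bracketed literals (e.g. `<unk>`) stripped by markup processing and should not be read as literally empty. As a rough illustration only, the equivalent call through the sentencepiece Python API would look like the sketch below; the file names and the 32k vocabulary size are placeholders, not values taken from the scripts.

# Hedged sketch: joint BPE training via the sentencepiece Python API,
# mirroring the pinned special-token IDs and fixed seed of the spm_train call above.
import sentencepiece as spm

spm.SentencePieceTrainer.train(
    input="cache/wmt/src.train.tok.clean,cache/wmt/tgt.train.tok.clean",  # train on both sides jointly
    model_prefix="cache/wmt/bpe",        # writes bpe.model and bpe.vocab
    vocab_size=32000,                    # illustrative; the scripts use $bpeops
    character_coverage=0.9995,           # illustrative; the scripts use $charcov
    model_type="bpe",
    pad_id=0, bos_id=1, eos_id=2, unk_id=3,  # keep IDs consistent with the toolkit's vocabulary conventions
    random_seed=666666,                  # reproducible merges across the shared/independent branches
)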
spm_encode --model=$src_cdsf.model --generate_vocabulary < $tgtd/src.train.tok.clean > $tgtd/src.vcb.bpe & spm_encode --model=$tgt_cdsf.model --generate_vocabulary < $tgtd/tgt.train.tok.clean > $tgtd/tgt.vcb.bpe & wait - rm $tgtd/bpe.train.txt else # to learn independent bpe: - spm_train --input=$tgtd/src.train.tok.clean --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 & - spm_train --input=$tgtd/tgt.train.tok.clean --model_prefix=$tgt_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 & + spm_train --input=$tgtd/src.train.tok.clean --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --unk_id=3 --bos_id=1 --eos_id=2 --pad_id=0 --unk_piece="" --bos_piece="" --eos_piece="" --unk_surface="" --minloglevel=1 --random_seed=666666 & + spm_train --input=$tgtd/tgt.train.tok.clean --model_prefix=$tgt_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --unk_id=3 --bos_id=1 --eos_id=2 --pad_id=0 --unk_piece="" --bos_piece="" --eos_piece="" --unk_surface="" --minloglevel=1 --random_seed=666666 & wait spm_encode --model=$src_cdsf.model --generate_vocabulary < $tgtd/src.train.tok.clean > $tgtd/src.vcb.bpe & spm_encode --model=$tgt_cdsf.model --generate_vocabulary < $tgtd/tgt.train.tok.clean > $tgtd/tgt.vcb.bpe & diff --git a/scripts/spm/mk.sh b/scripts/spm/mk.sh index cf35e75..fd004d4 100644 --- a/scripts/spm/mk.sh +++ b/scripts/spm/mk.sh @@ -34,8 +34,8 @@ python tools/clean/token_repeat.py $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp $tgtd mv $tgtd/src.clean.rtmp $tgtd/src.clean.tmp mv $tgtd/tgt.clean.rtmp $tgtd/tgt.clean.tmp -python tools/vocab.py $tgtd/src.clean.tmp $tgtd/src.full.vcb 1048576 & -python tools/vocab.py $tgtd/tgt.clean.tmp $tgtd/tgt.full.vcb 1048576 & +python tools/vocab/token/single.py $tgtd/src.clean.tmp $tgtd/src.full.vcb 1048576 & +python tools/vocab/token/single.py $tgtd/tgt.clean.tmp $tgtd/tgt.full.vcb 1048576 & wait python tools/clean/vocab/ratio.py $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean $tgtd/src.full.vcb $tgtd/tgt.full.vcb $vratio rm -fr $tgtd/src.full.vcb $tgtd/tgt.full.vcb $tgtd/src.clean.tmp $tgtd/tgt.clean.tmp @@ -44,18 +44,17 @@ if $share_bpe; then # to learn joint bpe export src_cdsf=$tgtd/bpe export tgt_cdsf=$tgtd/bpe - cat $tgtd/src.train.tok.clean $tgtd/tgt.train.tok.clean | shuf > $tgtd/bpe.train.txt - spm_train --input=$tgtd/bpe.train.txt --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 + # --max_sentence_length=4096 --input_sentence_size=5000000 --shuffle_input_sentence=true --num_threads=32 --train_extremely_large_corpus=true + spm_train --input=$tgtd/src.train.tok.clean,$tgtd/tgt.train.tok.clean --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --unk_id=3 --bos_id=1 --eos_id=2 --pad_id=0 --unk_piece="" --bos_piece="" --eos_piece="" --unk_surface="" --minloglevel=1 --random_seed=666666 spm_encode --model=$src_cdsf.model --generate_vocabulary < $tgtd/src.train.tok.clean > $tgtd/src.vcb.bpe & spm_encode --model=$tgt_cdsf.model --generate_vocabulary < $tgtd/tgt.train.tok.clean > $tgtd/tgt.vcb.bpe & wait - rm $tgtd/bpe.train.txt else # to learn independent bpe: export src_cdsf=$tgtd/src export tgt_cdsf=$tgtd/tgt - spm_train --input=$tgtd/src.train.tok.clean --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov 
--model_type=$mtype --minloglevel=1 & - spm_train --input=$tgtd/tgt.train.tok.clean --model_prefix=$tgt_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --minloglevel=1 & + spm_train --input=$tgtd/src.train.tok.clean --model_prefix=$src_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --unk_id=3 --bos_id=1 --eos_id=2 --pad_id=0 --unk_piece="" --bos_piece="" --eos_piece="" --unk_surface="" --minloglevel=1 --random_seed=666666 & + spm_train --input=$tgtd/tgt.train.tok.clean --model_prefix=$tgt_cdsf --vocab_size=$bpeops --character_coverage=$charcov --model_type=$mtype --unk_id=3 --bos_id=1 --eos_id=2 --pad_id=0 --unk_piece="" --bos_piece="" --eos_piece="" --unk_surface="" --minloglevel=1 --random_seed=666666 & wait spm_encode --model=$src_cdsf.model --generate_vocabulary < $tgtd/src.train.tok.clean > $tgtd/src.vcb.bpe & spm_encode --model=$tgt_cdsf.model --generate_vocabulary < $tgtd/tgt.train.tok.clean > $tgtd/tgt.vcb.bpe & diff --git a/server.py b/server.py index fe0ef7e..ecabd32 100644 --- a/server.py +++ b/server.py @@ -1,14 +1,15 @@ #encoding: utf-8 -from flask import Flask, request, render_template, send_from_directory import json +from flask import Flask, render_template, request, send_from_directory -import cnfg.base as cnfg -# import Tokenizer/Detokenizer/SentenceSplitter from datautils.zh for Chinese -from datautils.pymoses import Tokenizer, Detokenizer, Normalizepunctuation, Truecaser, Detruecaser -from datautils.moses import SentenceSplitter from datautils.bpe import BPEApplier, BPERemover -from translator import TranslatorCore, Translator +from datautils.moses import SentenceSplitter +# import Tokenizer/Detokenizer/SentenceSplitter from datautils.zh for Chinese +from datautils.pymoses import Detokenizer, Detruecaser, Normalizepunctuation, Tokenizer, Truecaser +from modules.server.transformer import Translator, TranslatorCore + +import cnfg.base as cnfg """ slang = "de"# source language diff --git a/tools/ape/mkiodata.py b/tools/ape/mkiodata.py index 7736dcc..f412998 100644 --- a/tools/ape/mkiodata.py +++ b/tools/ape/mkiodata.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.h5serial import h5File -from utils.fmt.base import ldvocab from utils.fmt.ape.triple import batch_padder +from utils.fmt.vocab.token import ldvocab +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/average_model.py b/tools/average_model.py index e0d6f7a..283e466 100644 --- a/tools/average_model.py +++ b/tools/average_model.py @@ -6,33 +6,37 @@ import sys -import torch +from utils.h5serial import h5load, h5save +from utils.torch.comp import secure_type_map -from utils.base import secure_type_map -from utils.h5serial import h5save, h5load - -from cnfg.ihyp import * +from cnfg.ihyp import h5zipargs def handle(srcfl, rsf): - rsm = h5load(srcfl[0]) - - src_type = [para.dtype for para in rsm] - map_type = [secure_type_map[para.dtype] if para.dtype in secure_type_map else None for para in rsm] - sec_rsm = [para if typ is None else para.to(typ) for para, typ in zip(rsm, map_type)] - - nmodel = 1 - for modelf in srcfl[1:]: - for basep, mpload, typ in zip(sec_rsm, h5load(modelf), map_type): - basep.add_(mpload if typ is None else mpload.to(typ)) - nmodel += 1 - nmodel = float(nmodel) - for basep in sec_rsm: - basep.div_(nmodel) - - rsm = [para if mtyp is None else para.to(styp) for para, mtyp, styp in zip(sec_rsm, map_type, src_type)] - - h5save(rsm, rsf, 
h5args=h5zipargs) + src_type = {} + map_type = {} + sec_rsm = {} + nmodel = {} + for modelf in srcfl: + _nmp = h5load(modelf, restore_list=False) + for _n, _p in sec_rsm.items(): + if _n in _nmp: + _ = _nmp[_n] + _m_type = map_type[_n] + _p.add_(_ if _m_type is None else _.to(_m_type, non_blocking=True)) + nmodel[_n] += 1 + for _n, _p in _nmp.items(): + if _n not in sec_rsm: + src_type[_n] = _p_dtype = _p.dtype + map_type[_n] = _m_type = secure_type_map.get(_p_dtype, None) + sec_rsm[_n] = _p if _m_type is None else _p.to(_m_type, non_blocking=True) + nmodel[_n] = 1 + _nmp = None + + for _n, _p in sec_rsm.items(): + _p.div_(float(nmodel[_n])) + + h5save({_n: _p if map_type[_n] is None else _p.to(src_type[_n], non_blocking=True) for _n, _p in sec_rsm.items()}, rsf, h5args=h5zipargs) if __name__ == "__main__": handle(sys.argv[2:], sys.argv[1]) diff --git a/tools/char/cnfg b/tools/char/cnfg new file mode 120000 index 0000000..bcd9a88 --- /dev/null +++ b/tools/char/cnfg @@ -0,0 +1 @@ +../../cnfg/ \ No newline at end of file diff --git a/tools/char/mkiodata.py b/tools/char/mkiodata.py new file mode 100644 index 0000000..64bb5eb --- /dev/null +++ b/tools/char/mkiodata.py @@ -0,0 +1,37 @@ +#encoding: utf-8 + +import sys +from numpy import array as np_array, int32 as np_int32 + +from utils.fmt.char.dual import batch_padder +from utils.fmt.vocab.char import ldvocab +from utils.h5serial import h5File + +from cnfg.ihyp import * + +def handle(finput, ftarget, fvocab_i, fvocab_t, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False): + vcbi, nwordi = ldvocab(fvocab_i, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbt, nwordt = ldvocab(fvocab_t, minf=minfreq, omit_vsize=vsize, vanilla=False) + if expand_for_mulgpu: + _bsize = bsize * minbsize + _maxtoken = maxtoken * minbsize + else: + _bsize = bsize + _maxtoken = maxtoken + with h5File(frs, "w", libver=h5_libver) as rsf: + src_grp = rsf.create_group("src") + tgt_grp = rsf.create_group("tgt") + curd = 0 + for i_d, td in batch_padder(finput, ftarget, vcbi, vcbt, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = np_array(i_d, dtype=np_int32) + rtd = np_array(td, dtype=np_int32) + wid = str(curd) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + tgt_grp.create_dataset(wid, data=rtd, **h5datawargs) + curd += 1 + rsf["ndata"] = np_array([curd], dtype=np_int32) + rsf["nword"] = np_array([nwordi, nwordt], dtype=np_int32) + print("Number of batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d" % (curd, nwordi, nwordt,)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], int(sys.argv[6])) diff --git a/tools/char/mktest.py b/tools/char/mktest.py new file mode 100644 index 0000000..02ce276 --- /dev/null +++ b/tools/char/mktest.py @@ -0,0 +1,33 @@ +#encoding: utf-8 + +import sys +from numpy import array as np_array, int32 as np_int32 + +from utils.fmt.char.single import batch_padder +from utils.fmt.vocab.char import ldvocab +from utils.h5serial import h5File + +from cnfg.ihyp import * + +def handle(finput, fvocab_i, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False): + vcbi, nwordi = ldvocab(fvocab_i, minf=minfreq, omit_vsize=vsize, vanilla=False) + if expand_for_mulgpu: + _bsize = bsize * minbsize + _maxtoken = 
maxtoken * minbsize + else: + _bsize = bsize + _maxtoken = maxtoken + with h5File(frs, "w", libver=h5_libver) as rsf: + src_grp = rsf.create_group("src") + curd = 0 + for i_d in batch_padder(finput, vcbi, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = np_array(i_d, dtype=np_int32) + wid = str(curd) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + curd += 1 + rsf["ndata"] = np_array([curd], dtype=np_int32) + rsf["nword"] = np_array([nwordi], dtype=np_int32) + print("Number of batches: %d\nSource Vocabulary Size: %d" % (curd, nwordi,)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], int(sys.argv[4])) diff --git a/tools/char/utils b/tools/char/utils new file mode 120000 index 0000000..7d6b64a --- /dev/null +++ b/tools/char/utils @@ -0,0 +1 @@ +../../utils/ \ No newline at end of file diff --git a/tools/check/avg_bsize.py b/tools/check/avg_bsize.py index 7ff6356..ab767a0 100644 --- a/tools/check/avg_bsize.py +++ b/tools/check/avg_bsize.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - import torch -from utils.h5serial import h5File +from random import seed as rpyseed, shuffle +from utils.h5serial import h5File from utils.tqdm import tqdm -from random import shuffle, seed as rpyseed from cnfg.ihyp import tqdm_mininterval diff --git a/tools/check/biratio.py b/tools/check/biratio.py index 7844c8d..c824794 100644 --- a/tools/check/biratio.py +++ b/tools/check/biratio.py @@ -2,7 +2,7 @@ import sys -from utils.fmt.base import get_bi_ratio +from utils.fmt.base import get_bi_ratio, sys_open def handle(srcfs, srcft): @@ -11,7 +11,7 @@ def handle(srcfs, srcft): bmeanratio = 0.0 omeanratio = 0.0 ndata = 0 - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft: + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft: for sline, tline in zip(fs, ft): sline, tline = sline.strip(), tline.strip() if sline and tline: diff --git a/tools/check/charatio.py b/tools/check/charatio.py index 9061b51..2f66706 100644 --- a/tools/check/charatio.py +++ b/tools/check/charatio.py @@ -2,13 +2,13 @@ import sys -from utils.fmt.base import get_char_ratio +from utils.fmt.base import get_char_ratio, sys_open def handle(srcfs, srcft): def getfratio(fname): - with open(fname, "rb") as fs: + with sys_open(fname, "rb") as fs: maxratioc = 0.0 maxratiob = 0.0 diff --git a/tools/check/doc/para/epoch_steps.py b/tools/check/doc/para/epoch_steps.py index 20710e1..7c39e6d 100644 --- a/tools/check/doc/para/epoch_steps.py +++ b/tools/check/doc/para/epoch_steps.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - import torch -from utils.h5serial import h5File +from random import seed as rpyseed, shuffle +from utils.h5serial import h5File from utils.tqdm import tqdm -from random import shuffle, seed as rpyseed from cnfg.ihyp import tqdm_mininterval diff --git a/tools/check/dynb/bsize/bsize.py b/tools/check/dynb/bsize/bsize.py index 6df9abc..726c48c 100644 --- a/tools/check/dynb/bsize/bsize.py +++ b/tools/check/dynb/bsize/bsize.py @@ -1,9 +1,10 @@ #encoding: utf-8 import sys - from math import floor +from utils.fmt.base import sys_open + def load_log(fname): def legal(clin): @@ -14,7 +15,7 @@ def legal(clin): return False cache = [] - with open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: for line in frd: tmp = line.strip() if tmp: diff --git a/tools/check/dynb/bsize/extbsize.py b/tools/check/dynb/bsize/extbsize.py index de3eaf6..5bd2a64 100644 --- a/tools/check/dynb/bsize/extbsize.py +++ b/tools/check/dynb/bsize/extbsize.py @@ -2,7 +2,7 @@ import sys -from math import floor 
+from utils.fmt.base import sys_open def load_log(fname): @@ -14,7 +14,7 @@ def legal(clin): return False cache = [] - with open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: for line in frd: tmp = line.strip() if tmp: @@ -38,7 +38,7 @@ def legal(clin): def handle(srcf, rsf): ens = "\n".encode("utf-8") - with open(rsf, "wb") as f: + with sys_open(rsf, "wb") as f: for data in load_log(srcf): f.write(data[-1].encode("utf-8")) f.write(ens) diff --git a/tools/check/dynb/bsize/freqbsize.py b/tools/check/dynb/bsize/freqbsize.py index 3ebf33f..10734b5 100644 --- a/tools/check/dynb/bsize/freqbsize.py +++ b/tools/check/dynb/bsize/freqbsize.py @@ -1,9 +1,10 @@ #encoding: utf-8 import sys - from math import floor +from utils.fmt.base import sys_open + def load_log(fname): def legal(clin): @@ -14,7 +15,7 @@ def legal(clin): return False cache = [] - with open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: for line in frd: tmp = line.strip() if tmp: diff --git a/tools/check/dynb/bsize/stabsize.py b/tools/check/dynb/bsize/stabsize.py index a721a01..b78c6a8 100644 --- a/tools/check/dynb/bsize/stabsize.py +++ b/tools/check/dynb/bsize/stabsize.py @@ -2,7 +2,7 @@ import sys -from math import floor +from utils.fmt.base import sys_open def load_log(fname): @@ -14,7 +14,7 @@ def legal(clin): return False cache = [] - with open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: for line in frd: tmp = line.strip() if tmp: diff --git a/tools/check/dynb/consis.py b/tools/check/dynb/consis.py index 9ccdb75..ad40c45 100644 --- a/tools/check/dynb/consis.py +++ b/tools/check/dynb/consis.py @@ -1,9 +1,10 @@ #encoding: utf-8 import sys - from math import floor +from utils.fmt.base import sys_open + def load_log(fname): def legal(clin): @@ -14,7 +15,7 @@ def legal(clin): return False cache = [] - with open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: for line in frd: tmp = line.strip() if tmp: diff --git a/tools/check/dynb/ext.py b/tools/check/dynb/ext.py index c42adb5..e6f114b 100644 --- a/tools/check/dynb/ext.py +++ b/tools/check/dynb/ext.py @@ -2,13 +2,15 @@ import sys +from utils.fmt.base import sys_open + def handle(srcf, rsf, el1=10, el2=9, el3=7): acc_bsize = 0 odd_line = True l1, l2, l3 = [], [], [] ens = "\n".encode("utf-8") - with open(srcf, "rb") as frd, open(rsf, "wb") as fwrt: + with sys_open(srcf, "rb") as frd, sys_open(rsf, "wb") as fwrt: for line in frd: tmp = line.strip() if tmp: diff --git a/tools/check/dynb/report_dynb.py b/tools/check/dynb/report_dynb.py index f38ec04..8b6a249 100644 --- a/tools/check/dynb/report_dynb.py +++ b/tools/check/dynb/report_dynb.py @@ -1,39 +1,35 @@ #encoding: utf-8 import sys - import torch - +from random import random, shuffle from torch.optim import Adam as Optimizer +from loss.base import LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from utils.dynbatch import GradientMonitor +from utils.fmt.base import iter_to_str, parse_double_value_tuple, sys_open +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io 
import load_model_cpu, save_model, save_states from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.dynbatch import GradientMonitor -from utils.fmt.base import tostr, parse_double_value_tuple -from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda, load_emb - -from lrsch import GoogleLR as LRScheduler -from loss.base import LabelSmoothingLoss - -from random import shuffle, random - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import getlr, optm_step, optm_step_zero_grad_set_none +from utils.train.dss import dynamic_sample import cnfg.dynb as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT +from cnfg.vocab.base import pad_id log_dyn_p, max_his, log_dynb = 1.0, 9, True @@ -53,11 +49,12 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin model.train() cur_b, _ls = 1, {} if save_loss else None global grad_mon, update_angle, enc_layer, log_dyn_p, log_dynb, wkdir - _log_f_dynbatch = open(wkdir+"dynbatch.log", "ab") + _log_f_dynbatch = sys_open(wkdir+"dynbatch.log", "ab") _log_f_dynbatch.write("ES\n".encode("utf-8")) src_grp, tgt_grp = td["src"], td["tgt"] @@ -72,7 +69,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch, oi) loss = lossf(output, ot) if multi_gpu: @@ -88,7 +85,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[i_d] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add @@ -143,6 +140,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -151,7 +158,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = 
chkpf[:-3] + _fend @@ -175,7 +182,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for i in tqdm(range(nd), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) @@ -186,7 +193,7 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) loss = lossf(output, ot) if multi_gpu: @@ -260,7 +267,7 @@ def load_fixing(module): tl = [str(i) for i in range(ntrain)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) fine_tune_m = cnfg.fine_tune_m @@ -299,6 +306,9 @@ def load_fixing(module): lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -307,7 +317,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,)) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) diff --git a/tools/check/dynb/slayer.py b/tools/check/dynb/slayer.py index 997ee4f..32fc244 100644 --- a/tools/check/dynb/slayer.py +++ b/tools/check/dynb/slayer.py @@ -1,9 +1,10 @@ #encoding: utf-8 import sys - from math import floor +from utils.fmt.base import sys_open + def load_log(fname): def legal(clin): @@ -14,7 +15,7 @@ def legal(clin): return False cache = [] - with open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: for line in frd: tmp = line.strip() if tmp: diff --git a/tools/check/epoch_steps.py b/tools/check/epoch_steps.py index 66393e8..ca5118e 100644 --- a/tools/check/epoch_steps.py +++ b/tools/check/epoch_steps.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - import torch -from utils.h5serial import h5File +from random import seed as rpyseed, shuffle +from utils.h5serial import h5File from utils.tqdm import tqdm -from random import shuffle, seed as rpyseed from cnfg.ihyp import tqdm_mininterval diff --git a/tools/check/ext_emb.py b/tools/check/ext_emb.py index dbf9db1..4ed1c5e 100644 --- a/tools/check/ext_emb.py +++ b/tools/check/ext_emb.py @@ -5,12 +5,12 @@ """ import sys - import torch -from utils.fmt.base import ldvocab, reverse_dict from 
utils.fmt.base4torch import load_emb_txt -from utils.h5serial import h5save, h5load +from utils.fmt.vocab.base import reverse_dict +from utils.fmt.vocab.token import ldvocab +from utils.h5serial import h5save def handle(vcbf, embf, rsf): diff --git a/tools/check/fbindexes.py b/tools/check/fbindexes.py index de9c560..c4f4a74 100644 --- a/tools/check/fbindexes.py +++ b/tools/check/fbindexes.py @@ -2,7 +2,8 @@ import sys -from utils.fmt.base import ldvocab, init_vocab +from utils.fmt.base import sys_open +from utils.fmt.vocab.token import init_vocab, ldvocab def handle(vcbf, srcfl, rsf, minfreq=False, vsize=False): @@ -11,7 +12,7 @@ def handle(vcbf, srcfl, rsf, minfreq=False, vsize=False): fvcb = set(init_vocab.keys()) for srcf in srcfl: - with open(srcf, "rb") as f: + with sys_open(srcf, "rb") as f: for line in f: tmp = line.strip() if tmp: @@ -24,7 +25,7 @@ def handle(vcbf, srcfl, rsf, minfreq=False, vsize=False): if wd not in fvcb: rsl.append(ind) - with open(rsf, "wb") as f: + with sys_open(rsf, "wb") as f: f.write("#encoding: utf-8\n\nfbl = ".encode("utf-8")) f.write(repr(rsl).encode("utf-8")) f.write("\n".encode("utf-8")) diff --git a/tools/check/lsplitter.py b/tools/check/lsplitter.py index b25ee18..80de803 100644 --- a/tools/check/lsplitter.py +++ b/tools/check/lsplitter.py @@ -1,7 +1,8 @@ #encoding: utf-8 import sys -from utils.fmt.base import FileList + +from utils.fmt.base import FileList, sys_open interval = 15 @@ -21,7 +22,7 @@ def handle(srcf, osfl): wid = lts // interval if wid not in fwd: sind = str(wid) + "_" - fwrtl = [open(sind + srcf, "wb") for srcf in osfl] + fwrtl = [sys_open(sind + srcf, "wb") for srcf in osfl] fwd[wid] = fwrtl else: fwrtl = fwd[wid] diff --git a/tools/check/mkrandio.py b/tools/check/mkrandio.py new file mode 100644 index 0000000..11c0f54 --- /dev/null +++ b/tools/check/mkrandio.py @@ -0,0 +1,29 @@ +#encoding: utf-8 + +import sys +from numpy import array as np_array, int32 as np_int32 +from numpy.random import randint as np_randint + +from utils.h5serial import h5File + +from cnfg.ihyp import h5_libver, h5datawargs + +def handle(bsize, seql, nword, frs, ndata=1): + + _bsize = bsize + with h5File(frs, "w", libver=h5_libver) as rsf: + src_grp = rsf.create_group("src") + tgt_grp = rsf.create_group("tgt") + for curd in range(ndata): + wid = str(curd) + _size = (_bsize, seql,) + src_grp.create_dataset(wid, data=np_randint(0, high=nword, size=_size, dtype=np_int32), **h5datawargs) + tgt_grp.create_dataset(wid, data=np_randint(0, high=nword, size=_size, dtype=np_int32), **h5datawargs) + _bsize += 1 + curd += 1 + rsf["ndata"] = np_array([ndata], dtype=np_int32) + rsf["nword"] = np_array([nword, nword], dtype=np_int32) + print("Number of batches: %d\nSource Vocabulary Size: %d\nTarget Vocabulary Size: %d" % (curd, nword, nword,)) + +if __name__ == "__main__": + handle(int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3]), sys.argv[4], int(sys.argv[-1])) if len(sys.argv) > 5 else handle(int(sys.argv[1]), int(sys.argv[2]), int(sys.argv[3]), sys.argv[4]) diff --git a/tools/check/mulang/eff/epoch_steps.py b/tools/check/mulang/eff/epoch_steps.py index 848dba5..a3e7a4c 100644 --- a/tools/check/mulang/eff/epoch_steps.py +++ b/tools/check/mulang/eff/epoch_steps.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - import torch -from utils.h5serial import h5File +from random import seed as rpyseed, shuffle +from utils.h5serial import h5File from utils.tqdm import tqdm -from random import shuffle, seed as rpyseed from cnfg.ihyp import tqdm_mininterval diff --git 
a/tools/check/mulang/fbindexes.py b/tools/check/mulang/fbindexes.py index c1ed9c9..9e120ea 100644 --- a/tools/check/mulang/fbindexes.py +++ b/tools/check/mulang/fbindexes.py @@ -2,7 +2,8 @@ import sys -from utils.fmt.base import ldvocab, init_vocab +from utils.fmt.base import sys_open +from utils.fmt.vocab.token import init_vocab, ldvocab def handle(vcbf, srcfl, fvocab_task, rsf, minfreq=False, vsize=False): @@ -12,7 +13,7 @@ def handle(vcbf, srcfl, fvocab_task, rsf, minfreq=False, vsize=False): fvcb = {} for srcf, tgtf in zip(srcfl[0::2], srcfl[1::2]): - with open(srcf, "rb") as fsrc, open(tgtf, "rb") as ftgt: + with sys_open(srcf, "rb") as fsrc, sys_open(tgtf, "rb") as ftgt: for lsrc, ltgt in zip(fsrc, ftgt): tsrc, ttgt = lsrc.strip(), ltgt.strip() if tsrc and ttgt: @@ -33,7 +34,7 @@ def handle(vcbf, srcfl, fvocab_task, rsf, minfreq=False, vsize=False): tmp.append(ind) rsl.append(tmp) - with open(rsf, "wb") as f: + with sys_open(rsf, "wb") as f: f.write("#encoding: utf-8\n\nfbl = ".encode("utf-8")) f.write(repr(rsl).encode("utf-8")) f.write("\n".encode("utf-8")) diff --git a/tools/check/para.py b/tools/check/para.py index b83f54a..84af84b 100644 --- a/tools/check/para.py +++ b/tools/check/para.py @@ -5,8 +5,8 @@ """ import sys - from h5py import Dataset + from utils.h5serial import h5File def handle_group(srcg): diff --git a/tools/check/probe/merge_probe.py b/tools/check/probe/merge_probe.py index e2334e4..93f3c0a 100644 --- a/tools/check/probe/merge_probe.py +++ b/tools/check/probe/merge_probe.py @@ -2,24 +2,23 @@ import sys -from utils.base import * -from utils.init.base import init_model_params +from transformer.NMT import NMT as NMTBase +from transformer.Probe.NMT import NMT from utils.h5serial import h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model import cnfg.probe as cnfg from cnfg.ihyp import * -from transformer.NMT import NMT as NMTBase -from transformer.Probe.NMT import NMT - def handle(cnfg, srcmtf, decf, rsf): with h5File(cnfg.dev_data, "r") as tdf: nwordi, nwordt = tdf["nword"][()].tolist() - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes, cnfg.num_layer_fwd) init_model_params(mymodel) - _tmpm = NMTBase(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + _tmpm = NMTBase(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) _tmpm = init_model_params(_tmpm) _tmpm = load_model_cpu(srcmtf, _tmpm) mymodel.load_base(_tmpm) diff --git a/tools/check/rank.py b/tools/check/rank.py index 203bd7d..663a3f1 100644 --- a/tools/check/rank.py +++ b/tools/check/rank.py @@ -4,11 +4,13 @@ import sys +from utils.fmt.base import sys_open + def handle(rankf, dkeep): scores = [] - with open(rankf, "rb") as f: + with sys_open(rankf, "rb") as f: for line in f: tmp = line.strip() if tmp: diff 
--git a/tools/check/topk/eva.py b/tools/check/topk/eva.py index 4665ee8..0e3108d 100644 --- a/tools/check/topk/eva.py +++ b/tools/check/topk/eva.py @@ -5,26 +5,23 @@ """ import sys - -from numpy import array as np_array, int32 as np_int32 - import torch +from numpy import array as np_array, int32 as np_int32 from torch import nn -from utils.tqdm import tqdm - +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base4torch import parse_cuda_decode from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.base as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda_decode def load_fixing(module): @@ -40,7 +37,7 @@ def load_fixing(module): nwordi, nwordt = nword[0], nword[-1] if len(sys.argv) == 5: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(sys.argv[4], mymodel) mymodel.apply(load_fixing) @@ -49,7 +46,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[4:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -71,11 +68,13 @@ def load_fixing(module): if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=True) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty src_grp, tgt_grp = td["src"], td["tgt"] -with h5File(sys.argv[3], "w", libver=h5_libver) as rsf, torch.no_grad(): +with h5File(sys.argv[3], "w", libver=h5_libver) as rsf, torch_inference_mode(): p_grp = rsf.create_group("p") ind_grp = rsf.create_group("ind") for i in tqdm(range(ntest), mininterval=tqdm_mininterval): @@ -87,7 +86,7 @@ def load_fixing(module): seq_batch = seq_batch.to(cuda_device, non_blocking=True) seq_o = seq_o.to(cuda_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel(seq_batch, seq_o.narrow(1, 0, lo)) p, ind = output.masked_fill_(seq_o.narrow(1, 1, lo).eq(pad_id).unsqueeze(-1), 0.0).topk(k, dim=-1) ind = ind.int() diff --git a/tools/check/topk/eva_stat.py b/tools/check/topk/eva_stat.py index 61d97c6..afe0dbc 100644 --- 
a/tools/check/topk/eva_stat.py +++ b/tools/check/topk/eva_stat.py @@ -5,26 +5,22 @@ """ import sys - -from numpy import array as np_array, int32 as np_int32 - import torch from torch import nn -from utils.tqdm import tqdm - +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.NMT import NMT +from utils.base import set_random_seed +from utils.fmt.base4torch import parse_cuda_decode from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.base as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import * from cnfg.vocab.base import pad_id -from utils.fmt.base4torch import parse_cuda_decode def load_fixing(module): @@ -40,7 +36,7 @@ def load_fixing(module): nwordi, nwordt = nword[0], nword[-1] if len(sys.argv) == 4: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(sys.argv[3], mymodel) mymodel.apply(load_fixing) @@ -49,7 +45,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[3:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -71,6 +67,8 @@ def load_fixing(module): if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=True) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty @@ -82,7 +80,7 @@ def load_fixing(module): else: cum_p = torch.zeros(k, dtype=torch.double) m_ind = torch.zeros(k, dtype=torch.long) -with torch.no_grad(): +with torch_inference_mode(): for i in tqdm(range(ntest), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) @@ -92,7 +90,7 @@ def load_fixing(module): seq_batch = seq_batch.to(cuda_device, non_blocking=True) seq_o = seq_o.to(cuda_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = mymodel(seq_batch, seq_o.narrow(1, 0, lo)) tgt = seq_o.narrow(1, 1, lo) mask = tgt.eq(pad_id).unsqueeze(-1) diff --git a/tools/check/topk/stat.py b/tools/check/topk/stat.py index 9521c32..d149948 100644 --- a/tools/check/topk/stat.py +++ b/tools/check/topk/stat.py @@ -5,18 +5,18 @@ """ import sys - import torch -from utils.tqdm import tqdm from utils.h5serial import h5File -from cnfg.vocab.base import pad_id +from 
utils.torch.comp import torch_inference_mode +from utils.tqdm import tqdm from cnfg.ihyp import * +from cnfg.vocab.base import pad_id def handle(srcf, ref): - with h5File(srcf, "r") as fs, h5File(ref, "r") as fr, torch.no_grad(): + with h5File(srcf, "r") as fs, h5File(ref, "r") as fr, torch_inference_mode(): p_grp, ind_grp, tgt_grp = fs["p"], fs["ind"], fr["tgt"] ndata = fs["ndata"][()].item() nword = 0 diff --git a/tools/check/tspeed.py b/tools/check/tspeed.py index 117b888..9efca3c 100644 --- a/tools/check/tspeed.py +++ b/tools/check/tspeed.py @@ -1,22 +1,19 @@ #encoding: utf-8 import sys - import torch -from utils.tqdm import tqdm - +from parallel.parallelMT import DataParallelMT +from transformer.EnsembleNMT import NMT as Ensemble +from transformer.NMT import NMT from utils.h5serial import h5File +from utils.io import load_model_cpu +from utils.torch.comp import torch_compile, torch_inference_mode +from utils.tqdm import tqdm import cnfg.base as cnfg from cnfg.ihyp import * -from transformer.NMT import NMT -from transformer.EnsembleNMT import NMT as Ensemble -from parallel.parallelMT import DataParallelMT - -from utils.base import load_model_cpu - def load_fixing(module): if hasattr(module, "fix_load"): @@ -29,7 +26,7 @@ def load_fixing(module): nwordi, nwordt = nword[0], nword[-1] if len(sys.argv) == 2: - mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(sys.argv[1], mymodel) mymodel.apply(load_fixing) @@ -37,7 +34,7 @@ def load_fixing(module): else: models = [] for modelf in sys.argv[1:]: - tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + tmp = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) tmp = load_model_cpu(modelf, tmp) tmp.apply(load_fixing) @@ -76,11 +73,13 @@ def load_fixing(module): if multi_gpu: mymodel = DataParallelMT(mymodel, device_ids=cuda_devices, output_device=cuda_device.index, host_replicate=True, gather_output=False) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) + beam_size = cnfg.beam_size length_penalty = cnfg.length_penalty src_grp = td["src"] -with torch.no_grad(): +with torch_inference_mode(): for i in tqdm(range(ntest), mininterval=tqdm_mininterval): seq_batch = torch.from_numpy(src_grp[str(i)][()]) if cuda_device: diff --git a/tools/check/vsize/copy.py b/tools/check/vsize/copy.py index 022b26c..4658aca 100644 --- a/tools/check/vsize/copy.py +++ b/tools/check/vsize/copy.py @@ -2,13 +2,13 @@ import sys -from utils.fmt.base import clean_list +from utils.fmt.base import clean_list, sys_open def handle(srcfl, tgtfl): nsrc = ntgt = ncopy = 0 for srcf, tgtf in zip(srcfl, tgtfl): - with open(srcf, "rb") as fsrc, open(tgtf, "rb") as ftgt: + with sys_open(srcf, "rb") as fsrc, sys_open(tgtf, "rb") as ftgt: for srcl, tgtl in zip(fsrc, ftgt): srcl, tgtl = 
srcl.strip(), tgtl.strip() if srcl or tgtl: diff --git a/tools/check/vsize/detail.py b/tools/check/vsize/detail.py index f1ac174..8f84e4c 100644 --- a/tools/check/vsize/detail.py +++ b/tools/check/vsize/detail.py @@ -2,13 +2,13 @@ import sys -from utils.fmt.base import clean_list_iter +from utils.fmt.base import clean_list_iter, sys_open def collect(fl): vcb = set() for srcf in fl: - with open(srcf, "rb") as f: + with sys_open(srcf, "rb") as f: for line in f: tmp = line.strip() if tmp: diff --git a/tools/check/vsize/mono.py b/tools/check/vsize/mono.py index 04abd02..3dd0340 100644 --- a/tools/check/vsize/mono.py +++ b/tools/check/vsize/mono.py @@ -2,16 +2,14 @@ import sys -from utils.fmt.base import init_normal_token_id, clean_list_iter +from utils.fmt.base import clean_list_iter, sys_open def handle(srcfl): - global init_normal_token_id - vocab = set() for srcf in srcfl: - with open(srcf, "rb") as f: + with sys_open(srcf, "rb") as f: for line in f: tmp = line.strip() if tmp: @@ -20,9 +18,8 @@ def handle(srcfl): vocab.add(token) nvcb = len(vocab) - nvcb += init_normal_token_id - print("The size of the vocabulary is: %d (with special tokens counted)" % (nvcb)) + print("The size of the vocabulary is: %d (special tokens discounted)" % (nvcb)) if __name__ == "__main__": handle(sys.argv[1:]) diff --git a/tools/check/vsize/vocab.py b/tools/check/vsize/vocab.py index f065b20..ec3c3cc 100644 --- a/tools/check/vsize/vocab.py +++ b/tools/check/vsize/vocab.py @@ -2,7 +2,7 @@ import sys -from utils.fmt.base import ldvocab +from utils.fmt.vocab.token import ldvocab def handle(srcfl): diff --git a/tools/clean/ape/chars.py b/tools/clean/ape/chars.py index 164e356..6522079 100644 --- a/tools/clean/ape/chars.py +++ b/tools/clean/ape/chars.py @@ -9,6 +9,8 @@ # oratio: same as pratio but before bpe processing # num_rules_drop: choose from [1, 6], fewer data will be droped with larger value, none data would be droped if it was set to 6 +from utils.fmt.base import sys_open + def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, cratio=0.8, bratio=5.0, sratio=0.8, pratio=3.0, oratio=3.0, num_rules_drop=1): def legal_mono(strin, cratio, bratio, sratio): @@ -50,7 +52,7 @@ def ratio_bilingual(ls, lt): ens = "\n".encode("utf-8") - with open(srcfs, "rb") as fs, open(srcfm, "rb") as fm, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtfm, "wb") as fmw, open(tgtft, "wb") as ftw: + with sys_open(srcfs, "rb") as fs, sys_open(srcfm, "rb") as fm, sys_open(srcft, "rb") as ft, sys_open(tgtfs, "wb") as fsw, sys_open(tgtfm, "wb") as fmw, sys_open(tgtft, "wb") as ftw: total = keep = 0 if num_rules_drop > 0: for ls, lm, lt in zip(fs, fm, ft): diff --git a/tools/clean/ape/maxkeeper.py b/tools/clean/ape/maxkeeper.py index 14520b1..efaf74d 100644 --- a/tools/clean/ape/maxkeeper.py +++ b/tools/clean/ape/maxkeeper.py @@ -2,7 +2,7 @@ import sys -from utils.fmt.base import clean_liststr_lentok +from utils.fmt.base import clean_liststr_lentok, sys_open def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, max_len=256): @@ -10,7 +10,7 @@ def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, max_len=256): data = {} - with open(srcfs, "rb") as fs, open(srcfm, "rb") as fm, open(srcft, "rb") as ft: + with sys_open(srcfs, "rb") as fs, sys_open(srcfm, "rb") as fm, sys_open(srcft, "rb") as ft: for ls, lm, lt in zip(fs, fm, ft): ls, lm, lt = ls.strip(), lm.strip(), lt.strip() if ls and lt: @@ -49,7 +49,7 @@ def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, max_len=256): data = _clean ens = "\n".encode("utf-8") - with open(tgtfs, "wb") 
as fs, open(tgtfm, "wb") as fm, open(tgtft, "wb") as ft: + with sys_open(tgtfs, "wb") as fs, sys_open(tgtfm, "wb") as fm, sys_open(tgtft, "wb") as ft: for (lm, lt,), v in data.items(): if len(v) > 1: rls = [] diff --git a/tools/clean/ape/vocab.py b/tools/clean/ape/vocab.py index b1992d9..8cde8b3 100644 --- a/tools/clean/ape/vocab.py +++ b/tools/clean/ape/vocab.py @@ -2,14 +2,18 @@ import sys -from utils.fmt.base import ldvocab_list, legal_vocab +from utils.fmt.parser import parse_none +from utils.fmt.vocab.base import legal_vocab +from utils.fmt.vocab.token import ldvocab_list # vratio: percentages of vocabulary size of retrieved words of least frequencies # dratio: a datum will be dropped who contains high frequency words less than this ratio +from utils.fmt.base import sys_open + def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, vcbfs, vcbft, vratio, dratio=None): - _dratio = vratio if dratio is None else dratio + _dratio = parse_none(dratio, vratio) ens = "\n".encode("utf-8") @@ -18,7 +22,7 @@ def handle(srcfs, srcfm, srcft, tgtfs, tgtfm, tgtft, vcbfs, vcbft, vratio, drati ilgs = set(vcbs[int(float(nvs) * (1.0 - vratio)):]) ilgt = set(vcbt[int(float(nvt) * (1.0 - vratio)):]) - with open(srcfs, "rb") as fs, open(srcfm, "rb") as fm, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtfm, "wb") as fmw, open(tgtft, "wb") as ftw: + with sys_open(srcfs, "rb") as fs, sys_open(srcfm, "rb") as fm, sys_open(srcft, "rb") as ft, sys_open(tgtfs, "wb") as fsw, sys_open(tgtfm, "wb") as fmw, sys_open(tgtft, "wb") as ftw: total = keep = 0 for ls, lm, lt in zip(fs, fm, ft): ls, lm, lt = ls.strip(), lm.strip(), lt.strip() diff --git a/tools/clean/chars.py b/tools/clean/chars.py index e724a66..1fa54ac 100644 --- a/tools/clean/chars.py +++ b/tools/clean/chars.py @@ -9,6 +9,8 @@ # oratio: same as pratio but before bpe processing # num_rules_drop: choose from [1, 6], fewer data will be droped with larger value, none data would be droped if it was set to 6 +from utils.fmt.base import sys_open + def handle(srcfs, srcft, tgtfs, tgtft, cratio=0.8, bratio=5.0, sratio=0.8, pratio=3.0, oratio=3.0, num_rules_drop=1): def legal_mono(strin, cratio, bratio, sratio): @@ -50,7 +52,7 @@ def ratio_bilingual(ls, lt): ens = "\n".encode("utf-8") - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtft, "wb") as ftw: + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft, sys_open(tgtfs, "wb") as fsw, sys_open(tgtft, "wb") as ftw: total = keep = 0 if num_rules_drop > 0: for ls, lt in zip(fs, ft): diff --git a/tools/clean/cython.py b/tools/clean/cython.py index a219577..e78d166 100644 --- a/tools/clean/cython.py +++ b/tools/clean/cython.py @@ -1,7 +1,7 @@ #encoding: utf-8 import sys -from os import walk, remove +from os import remove, walk from os.path import join as pjoin def walk_path(ptws): diff --git a/tools/clean/dedup.py b/tools/clean/dedup.py index 99fc037..8acfb2d 100644 --- a/tools/clean/dedup.py +++ b/tools/clean/dedup.py @@ -2,7 +2,7 @@ import sys -from utils.fmt.base import clean_liststr_lentok, all_le, FileList +from utils.fmt.base import FileList, all_le, clean_liststr_lentok def handle(srcfl, tgtfl, max_len=256, drop_tail=False): diff --git a/tools/clean/doc/para/maxkeeper.py b/tools/clean/doc/para/maxkeeper.py index 9ca415d..d14a769 100644 --- a/tools/clean/doc/para/maxkeeper.py +++ b/tools/clean/doc/para/maxkeeper.py @@ -2,7 +2,7 @@ import sys -from utils.fmt.base import clean_liststr_lentok +from utils.fmt.base import clean_liststr_lentok, sys_open def 
handle(srcfs, srcft, tgtfs, tgtft, max_len=256): @@ -11,7 +11,7 @@ def handle(srcfs, srcft, tgtfs, tgtft, max_len=256): data = {} cache_s, cache_t = [], [] - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft: + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft: for ls, lt in zip(fs, ft): ls, lt = ls.strip(), lt.strip() if ls and lt: @@ -70,7 +70,7 @@ def handle(srcfs, srcft, tgtfs, tgtft, max_len=256): data = _clean ens = "\n\n".encode("utf-8") - with open(tgtfs, "wb") as fs, open(tgtft, "wb") as ft: + with sys_open(tgtfs, "wb") as fs, sys_open(tgtft, "wb") as ft: for lt, v in data.items(): if len(v) > 1: rls = [] diff --git a/tools/clean/doc/para/vocab.py b/tools/clean/doc/para/vocab.py index 192cc13..ce94ec5 100644 --- a/tools/clean/doc/para/vocab.py +++ b/tools/clean/doc/para/vocab.py @@ -2,15 +2,18 @@ import sys -from utils.fmt.base import ldvocab_list from utils.fmt.doc.base import legal_vocab +from utils.fmt.parser import parse_none +from utils.fmt.vocab.token import ldvocab_list # vratio: percentages of vocabulary size of retrieved words of least frequencies # dratio: a datum will be dropped who contains high frequency words less than this ratio +from utils.fmt.base import sys_open + def handle(srcfs, srcft, tgtfs, tgtft, vcbfs, vcbft, vratio, dratio=None): - _dratio = vratio if dratio is None else dratio + _dratio = parse_none(dratio, vratio) ens = "\n\n".encode("utf-8") @@ -20,7 +23,7 @@ def handle(srcfs, srcft, tgtfs, tgtft, vcbfs, vcbft, vratio, dratio=None): ilgt = set(vcbt[int(float(nvt) * (1.0 - vratio)):]) cache_s, cache_t = [], [] - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtft, "wb") as ftw: + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft, sys_open(tgtfs, "wb") as fsw, sys_open(tgtft, "wb") as ftw: total = keep = 0 for ls, lt in zip(fs, ft): ls, lt = ls.strip(), lt.strip() diff --git a/tools/clean/gold.py b/tools/clean/gold.py index 038c5d4..59f61eb 100644 --- a/tools/clean/gold.py +++ b/tools/clean/gold.py @@ -2,13 +2,13 @@ import sys -from utils.fmt.base import clean_str +from utils.fmt.base import clean_str, sys_open def handle(srcfs, srcft, srcfg, tgtfs, tgtft, tgtfg): ens = "\n".encode("utf-8") - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft, open(srcfg, "rb") as fg, open(tgtfs, "wb") as fsw, open(tgtft, "wb") as ftw, open(tgtfg, "wb") as fgw: + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft, sys_open(srcfg, "rb") as fg, sys_open(tgtfs, "wb") as fsw, sys_open(tgtft, "wb") as ftw, sys_open(tgtfg, "wb") as fgw: total = keep = 0 for ls, lt, lg in zip(fs, ft, fg): ls, lt, lg = ls.strip(), lt.strip(), lg.strip() diff --git a/tools/clean/maxkeeper.py b/tools/clean/maxkeeper.py index 88df76e..0979aa3 100644 --- a/tools/clean/maxkeeper.py +++ b/tools/clean/maxkeeper.py @@ -2,7 +2,7 @@ import sys -from utils.fmt.base import clean_liststr_lentok +from utils.fmt.base import clean_liststr_lentok, sys_open def handle(srcfs, srcft, tgtfs, tgtft, max_len=256): @@ -10,7 +10,7 @@ def handle(srcfs, srcft, tgtfs, tgtft, max_len=256): data = {} - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft: + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft: for ls, lt in zip(fs, ft): ls, lt = ls.strip(), lt.strip() if ls and lt: @@ -48,7 +48,7 @@ def handle(srcfs, srcft, tgtfs, tgtft, max_len=256): data = _clean ens = "\n".encode("utf-8") - with open(tgtfs, "wb") as fs, open(tgtft, "wb") as ft: + with sys_open(tgtfs, "wb") as fs, sys_open(tgtft, "wb") as ft: for lt, v in 
data.items(): if len(v) > 1: rls = [] diff --git a/tools/clean/normu8.py b/tools/clean/normu8.py new file mode 100644 index 0000000..30e8a7c --- /dev/null +++ b/tools/clean/normu8.py @@ -0,0 +1,19 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import sys_open +from utils.fmt.u8 import norm_u8_byte, uni_normer + +def handle(srcf, rsf, uni_normer=uni_normer): + + ens="\n".encode("utf-8") + with sys_open(srcf, "rb") as frd, sys_open(rsf, "wb") as fwrt: + for line in frd: + tmp = line.strip() + if tmp: + fwrt.write(norm_u8_byte(tmp)) + fwrt.write(ens) + +if __name__ == "__main__": + handle(*sys.argv[1:]) diff --git a/tools/clean/rank.py b/tools/clean/rank.py index e02c2a4..26a845d 100644 --- a/tools/clean/rank.py +++ b/tools/clean/rank.py @@ -4,11 +4,11 @@ import sys -from utils.fmt.base import clean_str +from utils.fmt.base import clean_str, sys_open def handle(srcf, tgtf, rankf, rssf, rstf, threshold): - with open(srcf, "rb") as frs, open(tgtf, "rb") as frt, open(rankf, "rb") as fs, open(rssf, "wb") as fws, open(rstf, "wb") as fwt: + with sys_open(srcf, "rb") as frs, sys_open(tgtf, "rb") as frt, sys_open(rankf, "rb") as fs, sys_open(rssf, "wb") as fws, sys_open(rstf, "wb") as fwt: ndata = nkeep = 0 diff --git a/tools/clean/sampler/strict_sampler.py b/tools/clean/sampler/strict_sampler.py index 52d97ee..11e3486 100644 --- a/tools/clean/sampler/strict_sampler.py +++ b/tools/clean/sampler/strict_sampler.py @@ -4,9 +4,9 @@ # python tools/clean/sampler/strict_sampler.py srcf1 ... srcfn tgtf1 ... tgtfn keep_ratio import sys -from random import shuffle, seed as rpyseed +from random import seed as rpyseed, shuffle -from utils.fmt.base import FileList +from utils.fmt.base import FileList, sys_open def handle(srcfl, tgtfl, ratio): @@ -23,7 +23,7 @@ def handle(srcfl, tgtfl, ratio): ens = "\n".encode("utf-8") for data, tgtf in zip(rs, tgtfl): - with open(tgtf, "wb") as f: + with sys_open(tgtf, "wb") as f: # following 3 lines for memory #for line in data: #f.write(line) diff --git a/tools/clean/token_repeat.py b/tools/clean/token_repeat.py index 37f3d08..bcb63cd 100644 --- a/tools/clean/token_repeat.py +++ b/tools/clean/token_repeat.py @@ -2,7 +2,7 @@ import sys -from utils.fmt.base import clean_list, all_gt, FileList +from utils.fmt.base import FileList, all_gt, clean_list def handle(srcfl, tgtfl, r=0.4): diff --git a/tools/clean/tokens.py b/tools/clean/tokens.py index b663db7..22f04b2 100644 --- a/tools/clean/tokens.py +++ b/tools/clean/tokens.py @@ -2,13 +2,13 @@ import sys -from utils.fmt.base import clean_liststr_lentok +from utils.fmt.base import clean_liststr_lentok, sys_open def handle(srcfs, srcft, tgtfs, tgtft, maxlen=256): ens = "\n".encode("utf-8") - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtft, "wb") as ftw: + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft, sys_open(tgtfs, "wb") as fsw, sys_open(tgtft, "wb") as ftw: total = keep = 0 for ls, lt in zip(fs, ft): ls, lt = ls.strip(), lt.strip() diff --git a/tools/clean/vocab/ratio.py b/tools/clean/vocab/ratio.py index 79826b7..d3e4d34 100644 --- a/tools/clean/vocab/ratio.py +++ b/tools/clean/vocab/ratio.py @@ -2,23 +2,26 @@ import sys -from utils.fmt.base import ldvocab_list, legal_vocab +from utils.fmt.parser import parse_none +from utils.fmt.vocab.base import legal_vocab +from utils.fmt.vocab.token import ldvocab_list # vratio: percentages of vocabulary size of retrieved words of least frequencies # dratio: a datum will be dropped who contains high frequency words less 
than this ratio -def handle(srcfs, srcft, tgtfs, tgtft, vcbfs, vcbft, vratio, dratio=None): +from utils.fmt.base import sys_open - _dratio = vratio if dratio is None else dratio +def handle(srcfs, srcft, tgtfs, tgtft, vcbfs, vcbft, vratio, dratio=None): - ens = "\n".encode("utf-8") + _dratio = parse_none(dratio, vratio) vcbs, nvs = ldvocab_list(vcbfs) vcbt, nvt = ldvocab_list(vcbft) ilgs = set(vcbs[int(float(nvs) * (1.0 - vratio)):]) ilgt = set(vcbt[int(float(nvt) * (1.0 - vratio)):]) - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtft, "wb") as ftw: + ens = "\n".encode("utf-8") + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft, sys_open(tgtfs, "wb") as fsw, sys_open(tgtft, "wb") as ftw: total = keep = 0 for ls, lt in zip(fs, ft): ls, lt = ls.strip(), lt.strip() diff --git a/tools/clean/vocab/target.py b/tools/clean/vocab/target.py index 4468c3a..956508c 100644 --- a/tools/clean/vocab/target.py +++ b/tools/clean/vocab/target.py @@ -2,7 +2,8 @@ import sys -from utils.fmt.base import ldvocab_list, all_in +from utils.fmt.base import all_in, sys_open +from utils.fmt.vocab.token import ldvocab_list def handle(srcfs, srcft, tgtfs, tgtft, vcbft): @@ -11,7 +12,7 @@ def handle(srcfs, srcft, tgtfs, tgtft, vcbft): vcbt, nvt = ldvocab_list(vcbft) vcbt = set(vcbt) - with open(srcfs, "rb") as fs, open(srcft, "rb") as ft, open(tgtfs, "wb") as fsw, open(tgtft, "wb") as ftw: + with sys_open(srcfs, "rb") as fs, sys_open(srcft, "rb") as ft, sys_open(tgtfs, "wb") as fsw, sys_open(tgtft, "wb") as ftw: total = keep = 0 for ls, lt in zip(fs, ft): ls, lt = ls.strip(), lt.strip() diff --git a/tools/clean/vocab/unk.py b/tools/clean/vocab/unk.py index 82c5f75..bf587cc 100644 --- a/tools/clean/vocab/unk.py +++ b/tools/clean/vocab/unk.py @@ -2,11 +2,13 @@ import sys +from utils.fmt.base import sys_open + def handle(srcf, rsf): ens = "\n".encode("utf-8") - with open(srcf, "rb") as frd, open(rsf, "wb") as fwt: + with sys_open(srcf, "rb") as frd, sys_open(rsf, "wb") as fwt: for ls in frd: ls = ls.strip() if ls: diff --git a/tools/doc/para/mkiodata.py b/tools/doc/para/mkiodata.py index 5631281..7234f81 100644 --- a/tools/doc/para/mkiodata.py +++ b/tools/doc/para/mkiodata.py @@ -1,12 +1,12 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.h5serial import h5File -from utils.fmt.base import ldvocab, dict2pairs +from utils.fmt.base import dict2pairs from utils.fmt.doc.para.dual import batch_padder +from utils.fmt.vocab.token import ldvocab +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/doc/para/mktest.py b/tools/doc/para/mktest.py index 03b6d29..1a21317 100644 --- a/tools/doc/para/mktest.py +++ b/tools/doc/para/mktest.py @@ -1,12 +1,12 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.h5serial import h5File -from utils.fmt.base import ldvocab, dict2pairs +from utils.fmt.base import dict2pairs from utils.fmt.doc.para.single import batch_padder +from utils.fmt.vocab.token import ldvocab +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/doc/para/restore.py b/tools/doc/para/restore.py index 8537c7e..3fbbacc 100644 --- a/tools/doc/para/restore.py +++ b/tools/doc/para/restore.py @@ -2,20 +2,20 @@ import sys -from utils.fmt.base import clean_str +from utils.fmt.base import clean_str, sys_open def handle(srcfs, srtsf_base, srttf_base, srtsf, srttf, tgtf): data = {} - with open(srtsf_base, "rb") as fs, open(srttf_base, 
"rb") as ft: + with sys_open(srtsf_base, "rb") as fs, sys_open(srttf_base, "rb") as ft: for sl, tl in zip(fs, ft): _sl, _tl = sl.strip(), tl.strip() if _sl and _tl: _sl = clean_str(_sl.decode("utf-8")) _tl = clean_str(_tl.decode("utf-8")) data[_sl] = _tl - with open(srtsf, "rb") as fs, open(srttf, "rb") as ft: + with sys_open(srtsf, "rb") as fs, sys_open(srttf, "rb") as ft: for sl, tl in zip(fs, ft): _sl, _tl = sl.strip(), tl.strip() if _sl and _tl: @@ -25,7 +25,7 @@ def handle(srcfs, srtsf_base, srttf_base, srtsf, srttf, tgtf): ens = "\n".encode("utf-8") - with open(srcfs, "rb") as fs, open(tgtf, "wb") as ft: + with sys_open(srcfs, "rb") as fs, sys_open(tgtf, "wb") as ft: for line in fs: tmp = line.strip() if tmp: diff --git a/tools/doc/sort.py b/tools/doc/sort.py index 9992da6..4a5a78f 100644 --- a/tools/doc/sort.py +++ b/tools/doc/sort.py @@ -1,15 +1,15 @@ #encoding: utf-8 import sys -from random import seed as rpyseed +from random import seed as rpyseed, shuffle -from utils.fmt.base import clean_liststr_lentok, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList +from utils.fmt.base import FileList, all_le, clean_liststr_lentok, dict_insert_list, dict_insert_set, iter_dict_sort, maxfreq_filter # remove_same: reduce same data in the corpus # shuf: shuffle the data of same source/target length # max_remove: if one source has several targets, only keep those with highest frequency -def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=False): +def handle(srcfl, tgtfl, max_len=256, remove_same=True, shuf=True, max_remove=False): _max_len = max(1, max_len - 2) @@ -51,13 +51,13 @@ def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=F ens = "\n\n".encode("utf-8") with FileList(tgtfl, "wb") as fl: for tmp in iter_dict_sort(data, free=True): - lines = zip(*tmp) + tmp = list(tmp) if len(tmp) > 1: if max_remove: - lines = maxfreq_filter(*lines) + tmp = maxfreq_filter(tmp) if shuf: - lines = shuffle_pair(*lines) - for du, f in zip(lines, fl): + shuffle(tmp) + for du, f in zip(zip(*tmp), fl): f.write(ens.join(du)) f.write(ens) diff --git a/tools/h5/compress.py b/tools/h5/compress.py index 1fe84af..06bbbd1 100644 --- a/tools/h5/compress.py +++ b/tools/h5/compress.py @@ -1,9 +1,9 @@ #encoding: utf-8 import sys - from h5py import Dataset -from utils.h5serial import h5save, h5load, h5File + +from utils.h5serial import h5File, h5load, h5save from cnfg.ihyp import * diff --git a/tools/h5/convert.py b/tools/h5/convert.py index 1270aad..69dc422 100644 --- a/tools/h5/convert.py +++ b/tools/h5/convert.py @@ -1,9 +1,9 @@ #encoding: utf-8 import sys - import torch -from utils.h5serial import h5save, h5load + +from utils.h5serial import h5load, h5save from cnfg.ihyp import * diff --git a/tools/lang/zh/cnfg b/tools/lang/zh/cnfg new file mode 120000 index 0000000..a958f86 --- /dev/null +++ b/tools/lang/zh/cnfg @@ -0,0 +1 @@ +../../../cnfg/ \ No newline at end of file diff --git a/tools/lang/zh/deseg.py b/tools/lang/zh/deseg.py new file mode 100644 index 0000000..d711a47 --- /dev/null +++ b/tools/lang/zh/deseg.py @@ -0,0 +1,19 @@ +#encoding: utf-8 + +# portal from: https://data.statmt.org/wmt18/translation-task/preprocessed/zh-en/deseg.py + +import sys + +from utils.fmt.base import sys_open +from utils.fmt.lang.zh.deseg import deseg as map_func + +def handle(srcf, rsf): + + ens = "\n".encode("utf-8") + with sys_open(srcf, "rb") as frd, sys_open(rsf, "wb") as fwrt: + for _ in frd: + 
fwrt.write(map_func(_.decode("utf-8").rstrip("\r\n")).encode("utf-8")) + fwrt.write(ens) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2]) diff --git a/tools/lang/zh/utils b/tools/lang/zh/utils new file mode 120000 index 0000000..256f914 --- /dev/null +++ b/tools/lang/zh/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/tools/lsort/merge.py b/tools/lsort/merge.py index b84fd96..ca5e7e8 100644 --- a/tools/lsort/merge.py +++ b/tools/lsort/merge.py @@ -1,18 +1,17 @@ #encoding: utf-8 import sys - -from utils.fmt.base import clean_liststr_lentok, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList - -from random import seed as rpyseed from os import walk from os.path import join as pjoin +from random import seed as rpyseed, shuffle + +from utils.fmt.base import FileList, clean_liststr_lentok, maxfreq_filter # remove_same: reduce same data in the corpus # shuf: shuffle the data of same source/target length # max_remove: if one source has several targets, only keep those with highest frequency -def handle(cached, tgtfl, remove_same=False, shuf=True, max_remove=True): +def handle(cached, tgtfl, remove_same=True, shuf=True, max_remove=False): def paral_reader(srcfl): @@ -35,11 +34,8 @@ def open_files(cache_dir, num_files): if curfid not in opened: pg = paral_reader([pjoin(cache_dir, "%d.%s.txt" % (i, curfid,)) for i in range(num_files)]) opened.add(curfid) - try: - prd = next(pg) - except StopIteration: - prd = None - if prd: + prd = next(pg, None) + if prd is not None: rs.append(pg) query.append((prd[0], prd[1:],)) @@ -55,24 +51,24 @@ def update_query(fl, query): min_len = lens rs = (du, lens,) rid = ind - try: - _next_v = next(fl[rid]) - query[rid] = (_next_v[0], _next_v[1:],) - except StopIteration: + _next_v = next(fl[rid], None) + if _next_v is None: del query[rid] del fl[rid] + else: + query[rid] = (_next_v[0], _next_v[1:],) return rs, query, fl def write_data(data, wfl, ens, shuf=True, max_remove=False): - lines = zip(*data) - if len(data) > 1: + lines = list(data) + if len(lines) > 1: if max_remove: - lines = maxfreq_filter(*lines) + lines = maxfreq_filter(lines) if shuf: - lines = shuffle_pair(*lines) - for du, f in zip(lines, wfl): + shuffle(lines) + for du, f in zip(zip(*lines), wfl): f.write(ens.join(du)) f.write(ens) diff --git a/tools/lsort/partsort.py b/tools/lsort/partsort.py index 13a62e7..6dc12b7 100644 --- a/tools/lsort/partsort.py +++ b/tools/lsort/partsort.py @@ -3,9 +3,9 @@ import sys from os.path import join as pjoin -from utils.fmt.base import clean_liststr_lentok, all_le, iter_dict_sort, dict_insert_list, dict_insert_set, FileList +from utils.fmt.base import FileList, all_le, clean_liststr_lentok, dict_insert_list, dict_insert_set, iter_dict_sort -def handle(srcfl, tgtd, max_len=256, remove_same=False, cache_token=500000000): +def handle(srcfl, tgtd, max_len=256, remove_same=True, cache_token=500000000): def save_cache(cache, tgtfl): diff --git a/tools/mkiodata.py b/tools/mkiodata.py index a24a424..40c281a 100644 --- a/tools/mkiodata.py +++ b/tools/mkiodata.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.h5serial import h5File -from utils.fmt.base import ldvocab from utils.fmt.dual import batch_padder +from utils.fmt.vocab.token import ldvocab +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/mktest.py b/tools/mktest.py index bde3540..376097b 100644 --- a/tools/mktest.py +++ 
b/tools/mktest.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.fmt.base import ldvocab -from utils.h5serial import h5File from utils.fmt.single import batch_padder +from utils.fmt.vocab.token import ldvocab +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/mulang/eff/mkiodata.py b/tools/mulang/eff/mkiodata.py index dc1afaa..df07ddd 100644 --- a/tools/mulang/eff/mkiodata.py +++ b/tools/mulang/eff/mkiodata.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.h5serial import h5File -from utils.fmt.base import ldvocab from utils.fmt.mulang.eff.dual import batch_padder +from utils.fmt.vocab.token import ldvocab +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/mulang/eff/mktest.py b/tools/mulang/eff/mktest.py index e08b6a7..95d32f3 100644 --- a/tools/mulang/eff/mktest.py +++ b/tools/mulang/eff/mktest.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.h5serial import h5File -from utils.fmt.base import ldvocab from utils.fmt.mulang.eff.single import batch_padder +from utils.fmt.vocab.token import ldvocab +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/mulang/eff/sort.py b/tools/mulang/eff/sort.py index 06c7033..b19c8e1 100644 --- a/tools/mulang/eff/sort.py +++ b/tools/mulang/eff/sort.py @@ -1,15 +1,15 @@ #encoding: utf-8 import sys -from random import seed as rpyseed +from random import seed as rpyseed, shuffle -from utils.fmt.base import clean_liststr_lentok, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList +from utils.fmt.base import FileList, all_le, clean_liststr_lentok, dict_insert_list, dict_insert_set, iter_dict_sort, maxfreq_filter # remove_same: reduce same data in the corpus # shuf: shuffle the data of same source/target length # max_remove: if one source has several targets, only keep those with highest frequency -def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=False): +def handle(srcfl, tgtfl, max_len=256, remove_same=True, shuf=True, max_remove=False): _max_len = max(1, max_len - 2) @@ -30,13 +30,13 @@ def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=F with FileList(tgtfl, "wb") as fl: for tmp in iter_dict_sort(data, free=True): - lines = zip(*tmp) + tmp = list(tmp) if len(tmp) > 1: if max_remove: - lines = maxfreq_filter(*lines) + tmp = maxfreq_filter(tmp) if shuf: - lines = shuffle_pair(*lines) - for du, f in zip(lines, fl): + shuffle(tmp) + for du, f in zip(zip(*tmp), fl): f.write(ens.join(du)) f.write(ens) diff --git a/tools/mulang/vocab/char/cnfg b/tools/mulang/vocab/char/cnfg new file mode 120000 index 0000000..2f54778 --- /dev/null +++ b/tools/mulang/vocab/char/cnfg @@ -0,0 +1 @@ +../../../../cnfg/ \ No newline at end of file diff --git a/tools/mulang/vocab/char/share.py b/tools/mulang/vocab/char/share.py new file mode 100644 index 0000000..0d1434e --- /dev/null +++ b/tools/mulang/vocab/char/share.py @@ -0,0 +1,41 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import sys_open +from utils.fmt.vocab.char import save_vocab + +def handle(srcfl, rsf, rslangf, vsize=65532): + + vocab = {} + lang_vocab = {} + + curid = 0 + for srcf in srcfl: + if srcf == "--target": + break + with sys_open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + tokens = 
tmp.decode("utf-8") + _ = tokens.find(" ") + for token in tokens[_ + 1:]: + vocab[token] = vocab.get(token, 0) + 1 + token = tokens[:_] + lang_vocab[token] = lang_vocab.get(token, 0) + 1 + curid += 1 + + for srcf in srcfl[curid+1:]: + with sys_open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + for token in tmp.decode("utf-8"): + vocab[token] = vocab.get(token, 0) + 1 + + save_vocab(vocab, rsf, omit_vsize=vsize) + save_vocab(lang_vocab, rslangf, omit_vsize=False) + +if __name__ == "__main__": + handle(sys.argv[1:-3], sys.argv[-3], sys.argv[-2], int(sys.argv[-1])) diff --git a/tools/mulang/vocab/char/single.py b/tools/mulang/vocab/char/single.py new file mode 100644 index 0000000..e4527ec --- /dev/null +++ b/tools/mulang/vocab/char/single.py @@ -0,0 +1,28 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import sys_open +from utils.fmt.vocab.char import save_vocab + +def handle(srcf, rsf, rslangf, vsize=65532): + + vocab = {} + lang_vocab = {} + + with sys_open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + tokens = tmp.decode("utf-8") + _ = tokens.find(" ") + for token in tokens[_ + 1:]: + vocab[token] = vocab.get(token, 0) + 1 + token = tokens[:_] + lang_vocab[token] = lang_vocab.get(token, 0) + 1 + + save_vocab(vocab, rsf, omit_vsize=vsize) + save_vocab(lang_vocab, rslangf, omit_vsize=False) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3]) if len(sys.argv) == 4 else handle(sys.argv[1], sys.argv[2], sys.argv[3], int(sys.argv[-1])) diff --git a/tools/mulang/vocab/char/utils b/tools/mulang/vocab/char/utils new file mode 120000 index 0000000..c2519a9 --- /dev/null +++ b/tools/mulang/vocab/char/utils @@ -0,0 +1 @@ +../../../../utils/ \ No newline at end of file diff --git a/tools/mulang/vocab/token/cnfg b/tools/mulang/vocab/token/cnfg new file mode 120000 index 0000000..2f54778 --- /dev/null +++ b/tools/mulang/vocab/token/cnfg @@ -0,0 +1 @@ +../../../../cnfg/ \ No newline at end of file diff --git a/tools/mulang/share_vocab.py b/tools/mulang/vocab/token/share.py similarity index 82% rename from tools/mulang/share_vocab.py rename to tools/mulang/vocab/token/share.py index 8499754..f31c93e 100644 --- a/tools/mulang/share_vocab.py +++ b/tools/mulang/vocab/token/share.py @@ -2,7 +2,8 @@ import sys -from utils.fmt.base import clean_list, clean_list_iter, save_vocab +from utils.fmt.base import clean_list, clean_list_iter, sys_open +from utils.fmt.vocab.token import save_vocab def handle(srcfl, rsf, rslangf, vsize=65532): @@ -13,7 +14,7 @@ def handle(srcfl, rsf, rslangf, vsize=65532): for srcf in srcfl: if srcf == "--target": break - with open(srcf, "rb") as f: + with sys_open(srcf, "rb") as f: for line in f: tmp = line.strip() if tmp: @@ -25,7 +26,7 @@ def handle(srcfl, rsf, rslangf, vsize=65532): curid += 1 for srcf in srcfl[curid+1:]: - with open(srcf, "rb") as f: + with sys_open(srcf, "rb") as f: for line in f: tmp = line.strip() if tmp: diff --git a/tools/mulang/vocab.py b/tools/mulang/vocab/token/single.py similarity index 83% rename from tools/mulang/vocab.py rename to tools/mulang/vocab/token/single.py index 2d1425c..5cf145a 100644 --- a/tools/mulang/vocab.py +++ b/tools/mulang/vocab/token/single.py @@ -2,14 +2,15 @@ import sys -from utils.fmt.base import clean_list, save_vocab +from utils.fmt.base import clean_list, sys_open +from utils.fmt.vocab.token import save_vocab def handle(srcf, rsf, rslangf, vsize=65532): vocab = {} lang_vocab = {} - with open(srcf, "rb") as f: + with sys_open(srcf, "rb") as f: for line 
in f: tmp = line.strip() if tmp: diff --git a/tools/mulang/vocab/token/utils b/tools/mulang/vocab/token/utils new file mode 120000 index 0000000..c2519a9 --- /dev/null +++ b/tools/mulang/vocab/token/utils @@ -0,0 +1 @@ +../../../../utils/ \ No newline at end of file diff --git a/tools/plm/map/bart.py b/tools/plm/map/bart.py index ab74225..e5fa953 100644 --- a/tools/plm/map/bart.py +++ b/tools/plm/map/bart.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/map/bert.py b/tools/plm/map/bert.py index 9debca6..07bbfd9 100644 --- a/tools/plm/map/bert.py +++ b/tools/plm/map/bert.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/map/mbart.py b/tools/plm/map/mbart.py new file mode 100644 index 0000000..722007f --- /dev/null +++ b/tools/plm/map/mbart.py @@ -0,0 +1,13 @@ +#encoding: utf-8 + +import sys +from transformers import MBartTokenizerFast as Tokenizer + +from utils.fmt.plm.token import map_file as map_func + +def handle(fsrc, vcb, frs, lang): + + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb, src_lang=lang)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/plm/map/mbart50.py b/tools/plm/map/mbart50.py new file mode 100644 index 0000000..e66cf4a --- /dev/null +++ b/tools/plm/map/mbart50.py @@ -0,0 +1,15 @@ +#encoding: utf-8 + +# PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python is required to load spm files with protobuf when tokenizer.json is not available. 
+ +import sys +from transformers import MBart50TokenizerFast as Tokenizer + +from utils.fmt.plm.token import map_file as map_func + +def handle(fsrc, vcb, frs, lang): + + return map_func(fsrc, frs, processor=Tokenizer.from_pretrained(vcb, src_lang=lang)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/plm/map/roberta.py b/tools/plm/map/roberta.py index 0f09dc6..39cc6f3 100644 --- a/tools/plm/map/roberta.py +++ b/tools/plm/map/roberta.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/map/t5.py b/tools/plm/map/t5.py index 2217463..07b8401 100644 --- a/tools/plm/map/t5.py +++ b/tools/plm/map/t5.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/mback/bart.py b/tools/plm/mback/bart.py index 8bb58ab..ba6f930 100644 --- a/tools/plm/mback/bart.py +++ b/tools/plm/mback/bart.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_back_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb).decode) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/mback/bert.py b/tools/plm/mback/bert.py index 6048eb0..eda5b95 100644 --- a/tools/plm/mback/bert.py +++ b/tools/plm/mback/bert.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_back_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb).decode) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/mback/mbart.py b/tools/plm/mback/mbart.py new file mode 100644 index 0000000..148f743 --- /dev/null +++ b/tools/plm/mback/mbart.py @@ -0,0 +1,13 @@ +#encoding: utf-8 + +import sys +from transformers import MBartTokenizerFast as Tokenizer + +from utils.fmt.plm.token import map_back_file as map_func + +def handle(fsrc, vcb, frs, lang): + + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb, src_lang=lang).decode) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/plm/mback/mbart50.py b/tools/plm/mback/mbart50.py new file mode 100644 index 0000000..a2e427a --- /dev/null +++ b/tools/plm/mback/mbart50.py @@ -0,0 +1,15 @@ +#encoding: utf-8 + +# PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python is required to load spm files with protobuf when tokenizer.json is not available. 
+ +import sys +from transformers import MBart50TokenizerFast as Tokenizer + +from utils.fmt.plm.token import map_back_file as map_func + +def handle(fsrc, vcb, frs, lang): + + return map_func(fsrc, frs, processor=Tokenizer.from_pretrained(vcb, src_lang=lang).decode) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/plm/mback/roberta.py b/tools/plm/mback/roberta.py index d9a8dd0..49eba8d 100644 --- a/tools/plm/mback/roberta.py +++ b/tools/plm/mback/roberta.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_back_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb).decode) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/mback/t5.py b/tools/plm/mback/t5.py index eb4a3c5..ee4ac13 100644 --- a/tools/plm/mback/t5.py +++ b/tools/plm/mback/t5.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_back_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb).decode) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/mkiodata.py b/tools/plm/mkiodata.py index 78376b6..8c05d88 100644 --- a/tools/plm/mkiodata.py +++ b/tools/plm/mkiodata.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.h5serial import h5File # import batch_padder of the corresponding model for different padding indices. from utils.fmt.plm.roberta.dual import batch_padder +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/plm/mkiodata_reg.py b/tools/plm/mkiodata_reg.py new file mode 100644 index 0000000..5aabaf9 --- /dev/null +++ b/tools/plm/mkiodata_reg.py @@ -0,0 +1,36 @@ +#encoding: utf-8 + +import sys +from numpy import array as np_array, float32 as np_float32, int32 as np_int32 + +# import batch_padder of the corresponding model for different padding indices. 
+ +from utils.fmt.plm.roberta.dual_reg import batch_padder +from utils.h5serial import h5File + +from cnfg.ihyp import * + +def handle(finput, ftarget, frs, minbsize=1, expand_for_mulgpu=True, bsize=max_sentences_gpu, maxpad=max_pad_tokens_sentence, maxpart=normal_tokens_vs_pad_tokens, maxtoken=max_tokens_gpu, minfreq=False, vsize=False): + + if expand_for_mulgpu: + _bsize = bsize * minbsize + _maxtoken = maxtoken * minbsize + else: + _bsize = bsize + _maxtoken = maxtoken + with h5File(frs, "w", libver=h5_libver) as rsf: + src_grp = rsf.create_group("src") + tgt_grp = rsf.create_group("tgt") + curd = 0 + for i_d, td in batch_padder(finput, ftarget, _bsize, maxpad, maxpart, _maxtoken, minbsize): + rid = np_array(i_d, dtype=np_int32) + rtd = np_array(td, dtype=np_float32) + wid = str(curd) + src_grp.create_dataset(wid, data=rid, **h5datawargs) + tgt_grp.create_dataset(wid, data=rtd, **h5datawargs) + curd += 1 + rsf["ndata"] = np_array([curd], dtype=np_int32) + print("Number of batches: %d" % curd) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], int(sys.argv[4])) diff --git a/tools/plm/mktest.py b/tools/plm/mktest.py index ed7da75..e2a1e73 100644 --- a/tools/plm/mktest.py +++ b/tools/plm/mktest.py @@ -1,12 +1,11 @@ #encoding: utf-8 import sys - from numpy import array as np_array, int32 as np_int32 -from utils.h5serial import h5File # import batch_padder of the corresponding model for different padding indices. from utils.fmt.plm.roberta.single import batch_padder +from utils.h5serial import h5File from cnfg.ihyp import * diff --git a/tools/plm/mtyp/bert.py b/tools/plm/mtyp/bert.py index 02ea652..41bc563 100644 --- a/tools/plm/mtyp/bert.py +++ b/tools/plm/mtyp/bert.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_file_with_token_type as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frsi, frst): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frsi, frst, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/plm/mtyp/roberta.py b/tools/plm/mtyp/roberta.py index e9e074b..00b2738 100644 --- a/tools/plm/mtyp/roberta.py +++ b/tools/plm/mtyp/roberta.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import map_file_with_token_type as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frsi, frst): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frsi, frst, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/plm/token/bart.py b/tools/plm/token/bart.py index dad07b0..1b75ada 100644 --- a/tools/plm/token/bart.py +++ b/tools/plm/token/bart.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import tokenize_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/token/bert.py b/tools/plm/token/bert.py index 5c21a28..e44006f 100644 --- a/tools/plm/token/bert.py +++ b/tools/plm/token/bert.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import tokenize_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == 
"__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/token/cnfg b/tools/plm/token/cnfg new file mode 120000 index 0000000..a958f86 --- /dev/null +++ b/tools/plm/token/cnfg @@ -0,0 +1 @@ +../../../cnfg/ \ No newline at end of file diff --git a/tools/plm/token/mbart.py b/tools/plm/token/mbart.py new file mode 100644 index 0000000..8f4ac4b --- /dev/null +++ b/tools/plm/token/mbart.py @@ -0,0 +1,13 @@ +#encoding: utf-8 + +import sys +from transformers import MBartTokenizerFast as Tokenizer + +from utils.fmt.plm.token import tokenize_file as map_func + +def handle(fsrc, vcb, frs, lang): + + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb, src_lang=lang)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/plm/token/mbart50.py b/tools/plm/token/mbart50.py new file mode 100644 index 0000000..01ee99b --- /dev/null +++ b/tools/plm/token/mbart50.py @@ -0,0 +1,15 @@ +#encoding: utf-8 + +# PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python is required to load spm files with protobuf when tokenizer.json is not available. + +import sys +from transformers import MBart50TokenizerFast as Tokenizer + +from utils.fmt.plm.token import tokenize_file as map_func + +def handle(fsrc, vcb, frs, lang): + + return map_func(fsrc, frs, processor=Tokenizer.from_pretrained(vcb, src_lang=lang)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4]) diff --git a/tools/plm/token/roberta.py b/tools/plm/token/roberta.py index a8d005e..e9de7cb 100644 --- a/tools/plm/token/roberta.py +++ b/tools/plm/token/roberta.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import tokenize_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/token/t5.py b/tools/plm/token/t5.py index 29e762b..087f459 100644 --- a/tools/plm/token/t5.py +++ b/tools/plm/token/t5.py @@ -5,9 +5,9 @@ from utils.fmt.plm.token import tokenize_file as map_func -def handle(*inputs, **kwargs): +def handle(fsrc, vcb, frs): - return map_func(*inputs, **kwargs, Tokenizer=Tokenizer) + return map_func(fsrc, frs, processor=Tokenizer(tokenizer_file=vcb)) if __name__ == "__main__": handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/plm/token/utils b/tools/plm/token/utils new file mode 120000 index 0000000..256f914 --- /dev/null +++ b/tools/plm/token/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/tools/prune_model_vocab.py b/tools/prune_model_vocab.py index b9cdd97..15487b5 100644 --- a/tools/prune_model_vocab.py +++ b/tools/prune_model_vocab.py @@ -1,40 +1,43 @@ #encoding: utf-8 """ this file aims at pruning source/target vocabulary of the trained model using a shared vocabulary. It depends on the model implementation, and has to be executed at the root path of the project. 
Usage: - python prune_model_vocab.py path/to/common.vcb path/to/src.vcb path/to/tgt.vcb path/to/model.h5 path/to/pruned_model.h5 + python prune_model_vocab.py path/to/common.vcb path/to/common.vcb path/to/src.vcb path/to/tgt.vcb path/to/model.h5 path/to/pruned_model.h5 """ import sys - import torch -from utils.base import load_model_cpu, save_model -from utils.fmt.base import ldvocab, reverse_dict + from transformer.NMT import NMT +from utils.fmt.vocab.base import reverse_dict +from utils.fmt.vocab.token import ldvocab +from utils.io import load_model_cpu, save_model import cnfg.base as cnfg from cnfg.ihyp import * +from cnfg.vocab.base import pad_id -def handle(common, src, tgt, srcm, rsm, minfreq=False, vsize=False): +def handle(vsrc, vtgt, src, tgt, srcm, rsm, minfreq=False, vsize=False): - vcbc, nwordf = ldvocab(common, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbi, nwordi = ldvocab(vsrc, minf=minfreq, omit_vsize=vsize, vanilla=False) + vcbt, nwordt = ldvocab(vtgt, minf=minfreq, omit_vsize=vsize, vanilla=False) - if src == common: + if src == vsrc: src_indices = None else: vcbw, nword = ldvocab(src, minf=minfreq, omit_vsize=vsize, vanilla=False) vcbw = reverse_dict(vcbw) - src_indices = torch.tensor([vcbc.get(vcbw[i], 0) for i in range(nword)], dtype=torch.long) - if tgt == common: + src_indices = torch.as_tensor([vcbi.get(vcbw[i], pad_id) for i in range(nword)], dtype=torch.long) + if tgt == vtgt: tgt_indices = None else: vcbw, nword = ldvocab(tgt, minf=minfreq, omit_vsize=vsize, vanilla=False) vcbw = reverse_dict(vcbw) - tgt_indices = torch.tensor([vcbc.get(vcbw[i], 0) for i in range(nword)], dtype=torch.long) + tgt_indices = torch.as_tensor([vcbt.get(vcbw[i], pad_id) for i in range(nword)], dtype=torch.long) - mymodel = NMT(cnfg.isize, nwordf, nwordf, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) + mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) mymodel = load_model_cpu(srcm, mymodel) mymodel.update_vocab(src_indices=src_indices, tgt_indices=tgt_indices) save_model(mymodel, rsm, sub_module=False, h5args=h5zipargs) if __name__ == "__main__": - handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5]) + handle(sys.argv[1], sys.argv[2], sys.argv[3], sys.argv[4], sys.argv[5], sys.argv[6]) diff --git a/tools/restore.py b/tools/restore.py index 3152fc8..4056e0a 100644 --- a/tools/restore.py +++ b/tools/restore.py @@ -3,7 +3,8 @@ import sys # WARNING: all() might be too strict in some cases which may use any() -from utils.fmt.base import clean_str, FileList + +from utils.fmt.base import FileList, clean_str, sys_open # srtfl: (k - 1) source + 1 target def handle(srcfl, srtfl, tgtf): @@ -18,7 +19,7 @@ def handle(srcfl, srtfl, tgtf): data[lines[:-1]] = lines[-1].encode("utf-8") ens = "\n".encode("utf-8") - with FileList(srcfl, "rb") as fs, open(tgtf, "wb") as ft: + with FileList(srcfl, "rb") as fs, sys_open(tgtf, "wb") as ft: for lines in zip(*fs): lines = tuple(line.strip() for line in lines) if all(lines): diff --git a/tools/shuffle.py b/tools/shuffle.py index c6b78ad..c140399 100644 --- a/tools/shuffle.py +++ b/tools/shuffle.py @@ -1,10 +1,9 @@ #encoding: utf-8 import sys - from random import seed as rpyseed, shuffle -from utils.fmt.base import 
clean_str, FileList +from utils.fmt.base import FileList, clean_str, sys_open def handle(srcfl, rsfl): @@ -17,7 +16,7 @@ def handle(srcfl, rsfl): ens = "\n".encode("utf-8") for du, rsf in zip(zip(*data), rsfl): - with open(rsf, "wb") as fwrt: + with sys_open(rsf, "wb") as fwrt: fwrt.write("\n".join(du).encode("utf-8")) fwrt.write(ens) diff --git a/tools/sort.py b/tools/sort.py index 7e70909..68e4a4b 100644 --- a/tools/sort.py +++ b/tools/sort.py @@ -1,15 +1,15 @@ #encoding: utf-8 import sys -from random import seed as rpyseed +from random import seed as rpyseed, shuffle -from utils.fmt.base import clean_liststr_lentok, all_le, maxfreq_filter, shuffle_pair, iter_dict_sort, dict_insert_list, dict_insert_set, FileList +from utils.fmt.base import FileList, all_le, clean_liststr_lentok, dict_insert_list, dict_insert_set, iter_dict_sort, maxfreq_filter # remove_same: reduce same data in the corpus # shuf: shuffle the data of same source/target length # max_remove: if one source has several targets, only keep those with highest frequency -def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=False): +def handle(srcfl, tgtfl, max_len=256, remove_same=True, shuf=True, max_remove=False): _max_len = max(1, max_len - 2) @@ -29,13 +29,13 @@ def handle(srcfl, tgtfl, max_len=256, remove_same=False, shuf=True, max_remove=F with FileList(tgtfl, "wb") as fl: for tmp in iter_dict_sort(data, free=True): - lines = zip(*tmp) + tmp = list(tmp) if len(tmp) > 1: if max_remove: - lines = maxfreq_filter(*lines) + tmp = maxfreq_filter(tmp) if shuf: - lines = shuffle_pair(*lines) - for du, f in zip(lines, fl): + shuffle(tmp) + for du, f in zip(zip(*tmp), fl): f.write(ens.join(du)) f.write(ens) diff --git a/tools/spm/encode.py b/tools/spm/encode.py index d2c1a2c..f739c75 100644 --- a/tools/spm/encode.py +++ b/tools/spm/encode.py @@ -3,8 +3,8 @@ # portal from fairseq: https://github.com/pytorch/fairseq/blob/master/scripts/spm_encode.py import sys -from contextlib import ExitStack from argparse import ArgumentParser +from contextlib import ExitStack from sentencepiece import SentencePieceProcessor def main(): diff --git a/tools/spm/train.py b/tools/spm/train.py index 0b6fb43..b352073 100644 --- a/tools/spm/train.py +++ b/tools/spm/train.py @@ -3,6 +3,7 @@ # portal from fairseq: https://github.com/pytorch/fairseq/blob/master/scripts/spm_train.py import sys + from sentencepiece import SentencePieceTrainer if __name__ == "__main__": diff --git a/tools/vocab/char/cnfg b/tools/vocab/char/cnfg new file mode 120000 index 0000000..a958f86 --- /dev/null +++ b/tools/vocab/char/cnfg @@ -0,0 +1 @@ +../../../cnfg/ \ No newline at end of file diff --git a/tools/vocab/char/filter.py b/tools/vocab/char/filter.py new file mode 100644 index 0000000..2ba6d9c --- /dev/null +++ b/tools/vocab/char/filter.py @@ -0,0 +1,14 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.lang.zh.t2s import vcb_filter_func as filter_func +from utils.fmt.parser import parse_none +from utils.fmt.vocab.char import ldvocab_freq, save_vocab + +def handle(srcf, rsf, vsize=65532, omit_vsize=None): + + save_vocab(filter_func(ldvocab_freq(srcf, omit_vsize=vsize)[0]), rsf, omit_vsize=parse_none(omit_vsize, vsize)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], int(sys.argv[3]), int(sys.argv[4])) diff --git a/tools/vocab/char/merge.py b/tools/vocab/char/merge.py new file mode 100644 index 0000000..1e457f5 --- /dev/null +++ b/tools/vocab/char/merge.py @@ -0,0 +1,13 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.vocab.base import
merge_vocab +from utils.fmt.vocab.char import ldvocab_freq, save_vocab + +def handle(srcfl, rsf, vsize=65532): + + save_vocab(merge_vocab(*[ldvocab_freq(_)[0] for _ in srcfl]), rsf, omit_vsize=vsize) + +if __name__ == "__main__": + handle(sys.argv[1:-2], sys.argv[-2], int(sys.argv[-1])) diff --git a/tools/vocab/char/share.py b/tools/vocab/char/share.py new file mode 100644 index 0000000..b1d541a --- /dev/null +++ b/tools/vocab/char/share.py @@ -0,0 +1,23 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import sys_open +from utils.fmt.vocab.char import save_vocab + +def handle(srcfl, rsf, vsize=65532): + + vocab = {} + + for srcf in srcfl: + with sys_open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + for token in tmp.decode("utf-8"): + vocab[token] = vocab.get(token, 0) + 1 + + save_vocab(vocab, rsf, omit_vsize=vsize) + +if __name__ == "__main__": + handle(sys.argv[1:-2], sys.argv[-2], int(sys.argv[-1])) diff --git a/tools/vocab/char/single.py b/tools/vocab/char/single.py new file mode 100644 index 0000000..4077783 --- /dev/null +++ b/tools/vocab/char/single.py @@ -0,0 +1,22 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import sys_open +from utils.fmt.vocab.char import save_vocab + +def handle(srcf, rsf, vsize=65532): + + vocab = {} + + with sys_open(srcf, "rb") as f: + for line in f: + tmp = line.strip() + if tmp: + for token in tmp.decode("utf-8"): + vocab[token] = vocab.get(token, 0) + 1 + + save_vocab(vocab, rsf, omit_vsize=vsize) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2]) if len(sys.argv) == 3 else handle(sys.argv[1], sys.argv[2], int(sys.argv[-1])) diff --git a/tools/vocab/char/utils b/tools/vocab/char/utils new file mode 120000 index 0000000..256f914 --- /dev/null +++ b/tools/vocab/char/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/tools/vocab/cnfg b/tools/vocab/cnfg new file mode 120000 index 0000000..bcd9a88 --- /dev/null +++ b/tools/vocab/cnfg @@ -0,0 +1 @@ +../../cnfg/ \ No newline at end of file diff --git a/tools/vocab/map.py b/tools/vocab/map.py new file mode 100644 index 0000000..c149f3e --- /dev/null +++ b/tools/vocab/map.py @@ -0,0 +1,19 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import clean_list, iter_to_str, loop_file_so +from utils.fmt.vocab.base import map_instance, no_unk_mapper +from utils.fmt.vocab.token import ldvocab + +from cnfg.vocab.base import eos_id, init_normal_token_id, init_vocab, sos_id, unk_id, use_unk + +def handle(srcf, vcbf, rsf, add_sp_tokens=True, minfreq=False, vsize=False): + + _vcb = ldvocab(vcbf, minf=minfreq, omit_vsize=vsize, vanilla=False, init_vocab=init_vocab, init_normal_token_id=init_normal_token_id)[0] + map_line = (lambda lin, vcb: " ".join(iter_to_str(map_instance(clean_list(lin.split()), vcb, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id)))) if add_sp_tokens else ((lambda lin, vcb: " ".join(iter_to_str(vcb.get(wd, unk_id) for wd in clean_list(lin.split())))) if use_unk else (lambda lin, vcb: " ".join(iter_to_str(no_unk_mapper(vcb, clean_list(lin.split())))))) + + return loop_file_so(srcf, rsf, process_func=map_line, processor=_vcb) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], sys.argv[3]) diff --git a/tools/vocab/token/cnfg b/tools/vocab/token/cnfg new file mode 120000 index 0000000..a958f86 --- /dev/null +++ b/tools/vocab/token/cnfg @@ -0,0 +1 @@ +../../../cnfg/ \ No newline at end of file diff --git a/tools/vocab/token/filter.py b/tools/vocab/token/filter.py new file mode 100644 index 
0000000..b4c2ad9 --- /dev/null +++ b/tools/vocab/token/filter.py @@ -0,0 +1,14 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.lang.zh.t2s import vcb_filter_func as filter_func +from utils.fmt.parser import parse_none +from utils.fmt.vocab.token import ldvocab_freq, save_vocab + +def handle(srcf, rsf, vsize=65532, omit_vsize=None): + + save_vocab(filter_func(ldvocab_freq(srcf, omit_vsize=vsize)[0]), rsf, omit_vsize=parse_none(omit_vsize, vsize)) + +if __name__ == "__main__": + handle(sys.argv[1], sys.argv[2], int(sys.argv[3]), int(sys.argv[4])) diff --git a/tools/vocab/token/merge.py b/tools/vocab/token/merge.py new file mode 100644 index 0000000..240ea61 --- /dev/null +++ b/tools/vocab/token/merge.py @@ -0,0 +1,13 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.vocab.base import merge_vocab +from utils.fmt.vocab.token import ldvocab_freq, save_vocab + +def handle(srcfl, rsf, vsize=65532): + + save_vocab(merge_vocab(*[ldvocab_freq(_)[0] for _ in srcfl]), rsf, omit_vsize=vsize) + +if __name__ == "__main__": + handle(sys.argv[1:-2], sys.argv[-2], int(sys.argv[-1])) diff --git a/tools/share_vocab.py b/tools/vocab/token/share.py similarity index 74% rename from tools/share_vocab.py rename to tools/vocab/token/share.py index 047e649..52ecdc1 100644 --- a/tools/share_vocab.py +++ b/tools/vocab/token/share.py @@ -2,14 +2,15 @@ import sys -from utils.fmt.base import clean_list_iter, save_vocab +from utils.fmt.base import clean_list_iter, sys_open +from utils.fmt.vocab.token import save_vocab def handle(srcfl, rsf, vsize=65532): vocab = {} for srcf in srcfl: - with open(srcf, "rb") as f: + with sys_open(srcf, "rb") as f: for line in f: tmp = line.strip() if tmp: diff --git a/tools/vocab.py b/tools/vocab/token/single.py similarity index 76% rename from tools/vocab.py rename to tools/vocab/token/single.py index 179fbf0..8c8bf0b 100644 --- a/tools/vocab.py +++ b/tools/vocab/token/single.py @@ -2,13 +2,14 @@ import sys -from utils.fmt.base import clean_list_iter, save_vocab +from utils.fmt.base import clean_list_iter, sys_open +from utils.fmt.vocab.token import save_vocab def handle(srcf, rsf, vsize=65532): vocab = {} - with sys.stdin.buffer if srcf == "-" else open(srcf, "rb") as f: + with sys_open(srcf, "rb") as f: for line in f: tmp = line.strip() if tmp: diff --git a/tools/vocab/token/utils b/tools/vocab/token/utils new file mode 120000 index 0000000..256f914 --- /dev/null +++ b/tools/vocab/token/utils @@ -0,0 +1 @@ +../../../utils/ \ No newline at end of file diff --git a/tools/vocab/utils b/tools/vocab/utils new file mode 120000 index 0000000..7d6b64a --- /dev/null +++ b/tools/vocab/utils @@ -0,0 +1 @@ +../../utils/ \ No newline at end of file diff --git a/train.py b/train.py index b98ce51..407ee78 100644 --- a/train.py +++ b/train.py @@ -1,42 +1,41 @@ #encoding: utf-8 import torch +from random import shuffle #from torch import nn - from torch.optim import Adam as Optimizer +from loss.base import LabelSmoothingLoss +from lrsch import GoogleLR as LRScheduler from parallel.base import DataParallelCriterion -from parallel.parallelMT import DataParallelMT from parallel.optm import MultiGPUGradScaler - -from utils.base import * -from utils.init.base import init_model_params +from parallel.parallelMT import DataParallelMT +from transformer.NMT import NMT +from utils.base import free_cache, get_logger, mkdir, set_random_seed from utils.contpara import get_model_parameters +from utils.fmt.base import iter_to_str +from utils.fmt.base4torch import load_emb, parse_cuda +from utils.h5serial import 
h5File +from utils.init.base import init_model_params +from utils.io import load_model_cpu, save_model, save_states from utils.state.holder import Holder from utils.state.pyrand import PyRandomState from utils.state.thrand import THRandomState -from utils.fmt.base import tostr, pad_id -from utils.fmt.base4torch import parse_cuda, load_emb - -from lrsch import GoogleLR as LRScheduler -from loss.base import LabelSmoothingLoss - -from random import shuffle - +from utils.torch.comp import torch_autocast, torch_compile, torch_inference_mode from utils.tqdm import tqdm - -from utils.h5serial import h5File +from utils.train.base import getlr, optm_step, optm_step_zero_grad_set_none, reset_Adam +from utils.train.dss import dynamic_sample import cnfg.base as cnfg from cnfg.ihyp import * - -from transformer.NMT import NMT +from cnfg.vocab.base import pad_id def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tokens, multi_gpu, multi_gpu_optimizer, tokens_optm=32768, nreport=None, save_every=None, chkpf=None, state_holder=None, statesf=None, num_checkpoint=1, cur_checkid=0, report_eva=True, remain_steps=None, save_loss=False, save_checkp_epoch=False, scaler=None): sum_loss = part_loss = 0.0 sum_wd = part_wd = 0 _done_tokens, _cur_checkid, _cur_rstep, _use_amp = done_tokens, cur_checkid, remain_steps, scaler is not None + global minerr, minloss, wkdir, save_auto_clean, namin model.train() cur_b, _ls = 1, {} if save_loss else None src_grp, tgt_grp = td["src"], td["tgt"] @@ -45,13 +44,13 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok seq_o = torch.from_numpy(tgt_grp[i_d][()]) lo = seq_o.size(1) - 1 if mv_device: - seq_batch = seq_batch.to(mv_device) - seq_o = seq_o.to(mv_device) + seq_batch = seq_batch.to(mv_device, non_blocking=True) + seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() oi = seq_o.narrow(1, 0, lo) ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=_use_amp): + with torch_autocast(enabled=_use_amp): output = model(seq_batch, oi) loss = lossf(output, ot) if multi_gpu: @@ -69,7 +68,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok loss = output = oi = ot = seq_batch = seq_o = None sum_loss += loss_add if save_loss: - _ls[(i_d, t_d)] = loss_add / wd_add + _ls[i_d] = loss_add / wd_add sum_wd += wd_add _done_tokens += wd_add @@ -99,6 +98,16 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok if report_eva: _leva, _eeva = eva(ed, nd, model, lossf, mv_device, multi_gpu, _use_amp) logger.info("Average loss over %d tokens: %.3f, valid loss/error: %.3f %.2f" % (part_wd, part_loss / part_wd, _leva, _eeva,)) + if (_eeva < minerr) or (_leva < minloss): + save_model(model, wkdir + "eva_%.3f_%.2f.h5" % (_leva, _eeva,), multi_gpu, print_func=logger.info, mtyp="ieva" if save_auto_clean else None) + if statesf is not None: + save_states(state_holder.state_dict(update=False, **{"remain_steps": _cur_rstep, "checkpoint_id": _cur_checkid, "training_list": tl[cur_b - 1:]}), statesf, print_func=logger.info) + logger.info("New best model saved") + namin = 0 + if _eeva < minerr: + minerr = _eeva + if _leva < minloss: + minloss = _leva free_cache(mv_device) model.train() else: @@ -106,7 +115,7 @@ def train(td, tl, ed, nd, optm, lrsch, model, lossf, mv_device, logger, done_tok part_loss = 0.0 part_wd = 0 - if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not 
None) and (cur_b < ndata): + if save_checkp_epoch and (_cur_rstep is None) and (save_every is not None) and (cur_b % save_every == 0) and (chkpf is not None) and (cur_b < ntrain): if num_checkpoint > 1: _fend = "_%d.h5" % (_cur_checkid) _chkpf = chkpf[:-3] + _fend @@ -127,23 +136,23 @@ def eva(ed, nd, model, lossf, mv_device, multi_gpu, use_amp=False): sum_loss = 0.0 model.eval() src_grp, tgt_grp = ed["src"], ed["tgt"] - with torch.no_grad(): + with torch_inference_mode(): for i in tqdm(range(nd), mininterval=tqdm_mininterval): bid = str(i) seq_batch = torch.from_numpy(src_grp[bid][()]) seq_o = torch.from_numpy(tgt_grp[bid][()]) lo = seq_o.size(1) - 1 if mv_device: - seq_batch = seq_batch.to(mv_device) - seq_o = seq_o.to(mv_device) + seq_batch = seq_batch.to(mv_device, non_blocking=True) + seq_o = seq_o.to(mv_device, non_blocking=True) seq_batch, seq_o = seq_batch.long(), seq_o.long() ot = seq_o.narrow(1, 1, lo).contiguous() - with autocast(enabled=use_amp): + with torch_autocast(enabled=use_amp): output = model(seq_batch, seq_o.narrow(1, 0, lo)) loss = lossf(output, ot) if multi_gpu: loss = loss.sum() - trans = torch.cat([outu.argmax(-1).to(mv_device) for outu in output], 0) + trans = torch.cat([outu.argmax(-1).to(mv_device, non_blocking=True) for outu in output], 0) else: trans = output.argmax(-1) sum_loss += loss.data.item() @@ -213,7 +222,7 @@ def load_fixing(module): tl = [str(i) for i in range(ntrain)] logger.info("Design models with seed: %d" % torch.initial_seed()) -mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) +mymodel = NMT(cnfg.isize, nwordi, nwordt, cnfg.nlayer, cnfg.ff_hsize, cnfg.drop, cnfg.attn_drop, cnfg.act_drop, cnfg.share_emb, cnfg.nhead, cache_len_default, cnfg.attn_hsize, cnfg.norm_output, cnfg.bindDecoderEmb, cnfg.forbidden_indexes) fine_tune_m = cnfg.fine_tune_m @@ -235,8 +244,8 @@ def load_fixing(module): load_emb(cnfg.tgt_emb, mymodel.dec.wemb.weight, nwordt, cnfg.scale_down_emb, cnfg.freeze_tgtemb) if cuda_device: - mymodel.to(cuda_device) - lossf.to(cuda_device) + mymodel.to(cuda_device, non_blocking=True) + lossf.to(cuda_device, non_blocking=True) use_amp = cnfg.use_amp and use_cuda scaler = (MultiGPUGradScaler() if multi_gpu_optimizer else GradScaler()) if use_amp else None @@ -256,6 +265,9 @@ def load_fixing(module): # lrsch.step() will be automatically called with the constructor lrsch = LRScheduler(optimizer, cnfg.isize, cnfg.warm_step, scale=cnfg.lr_scale) +mymodel = torch_compile(mymodel, *torch_compile_args, **torch_compile_kwargs) +lossf = torch_compile(lossf, *torch_compile_args, **torch_compile_kwargs) + state_holder = None if statesf is None and cnt_states is None else Holder(**{"optm": optimizer, "lrsch": lrsch, "pyrand": PyRandomState(), "thrand": THRandomState(use_cuda=use_cuda)}) num_checkpoint = cnfg.num_checkpoint @@ -264,7 +276,7 @@ def load_fixing(module): tminerr = inf_default minloss, minerr = eva(vd, nvalid, mymodel, lossf, cuda_device, multi_gpu, use_amp) -logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(tostr(getlr(optimizer))), minloss, minerr,)) +logger.info("Init lr: %s, Dev Loss/Error: %.3f %.2f" % (" ".join(iter_to_str(getlr(optimizer))), minloss, minerr,)) if fine_tune_m is None: save_model(mymodel, wkdir + "init.h5", multi_gpu, print_func=logger.info) @@ -361,7 +373,7 @@ def load_fixing(module): #lrsch.step(terr) #newlr = getlr(optimizer) 
#if updated_lr(oldlr, newlr): - #logger.info("".join(("lr update from: ", ",".join(tostr(oldlr)), ", to: ", ",".join(tostr(newlr))))) + #logger.info("".join(("lr update from: ", ",".join(iter_to_str(oldlr)), ", to: ", ",".join(iter_to_str(newlr))))) #hook_lr_update(optimizer, use_ams) if done_tokens > 0: diff --git a/transformer/AGG/HierDecoder.py b/transformer/AGG/HierDecoder.py index e5a13a1..a685f8d 100644 --- a/transformer/AGG/HierDecoder.py +++ b/transformer/AGG/HierDecoder.py @@ -1,32 +1,31 @@ #encoding: utf-8 -import torch from torch import nn -from modules.base import * - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from modules.base import ResidueCombiner +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase from utils.base import align_modules_by_type +from utils.fmt.parser import parse_none from cnfg.ihyp import * class DecoderLayer(nn.Module): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, num_sub=1, comb_input=True): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, num_sub=1, comb_input=True, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize super(DecoderLayer, self).__init__() - self.nets = nn.ModuleList([DecoderLayerBase(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_sub)]) + self.nets = nn.ModuleList([DecoderLayerBase(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_sub)]) self.combiner = ResidueCombiner(isize, num_sub + 1 if comb_input else num_sub, _fhsize) self.comb_input = comb_input - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): outs = [] if query_unit is None: @@ -56,15 +55,15 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=False, bindemb=False, forbidden_index=None, num_sub=1, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=False, bindemb=False, forbidden_index=None, num_sub=1, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, num_sub, i != 0) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize, num_sub, i != 0) for i in 
range(num_layer)]) def load_base(self, base_decoder): diff --git a/transformer/AGG/HierEncoder.py b/transformer/AGG/HierEncoder.py index 1d3b665..2abd739 100644 --- a/transformer/AGG/HierEncoder.py +++ b/transformer/AGG/HierEncoder.py @@ -1,31 +1,31 @@ #encoding: utf-8 from torch import nn -from modules.base import * - -from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from modules.base import ResidueCombiner +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase from utils.base import align_modules_by_type +from utils.fmt.parser import parse_none from cnfg.ihyp import * class EncoderLayer(nn.Module): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, num_sub=1, comb_input=True): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, num_sub=1, comb_input=True, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize super(EncoderLayer, self).__init__() - self.nets = nn.ModuleList([EncoderLayerBase(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_sub)]) + self.nets = nn.ModuleList([EncoderLayerBase(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_sub)]) self.combiner = ResidueCombiner(isize, num_sub + 1 if comb_input else num_sub, _fhsize) self.comb_input = comb_input - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): out = inputs outs = [out] if self.comb_input else [] @@ -37,15 +37,15 @@ def forward(self, inputs, mask=None): class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=False, num_sub=1, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=False, num_sub=1, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, **kwargs) - self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, num_sub, i != 0) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize, num_sub, i != 0) for i in range(num_layer)]) def load_base(self, base_encoder): diff --git a/transformer/AGG/InceptDecoder.py b/transformer/AGG/InceptDecoder.py index 32d1130..9dd510a 100644 --- a/transformer/AGG/InceptDecoder.py +++ b/transformer/AGG/InceptDecoder.py @@ -1,29 +1,29 @@ #encoding: utf-8 -import torch from torch import nn -from modules.base import * -from transformer.Decoder import DecoderLayer as DecoderLayerBase +from modules.base import ResidueCombiner from transformer.AGG.HierDecoder import Decoder as DecoderBase +from transformer.Decoder import DecoderLayer as DecoderLayerBase +from utils.fmt.parser import parse_none from 
cnfg.ihyp import * class DecoderLayer(nn.Module): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, num_sub=1): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, num_sub=1, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize super(DecoderLayer, self).__init__() - self.nets = nn.ModuleList([DecoderLayerBase(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_sub)]) + self.nets = nn.ModuleList([DecoderLayerBase(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_sub)]) self.combiner = ResidueCombiner(isize, num_sub, _fhsize) - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): outs = [] if query_unit is None: @@ -49,12 +49,12 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=False, bindemb=False, forbidden_index=None, num_sub=1, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=False, bindemb=False, forbidden_index=None, num_sub=1, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, num_sub) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize, num_sub) for i in range(num_layer)]) diff --git a/transformer/AGG/InceptEncoder.py b/transformer/AGG/InceptEncoder.py index 5c7ed72..581f1e2 100644 --- a/transformer/AGG/InceptEncoder.py +++ b/transformer/AGG/InceptEncoder.py @@ -1,28 +1,29 @@ #encoding: utf-8 from torch import nn -from modules.base import * -from transformer.Encoder import EncoderLayer as EncoderLayerBase +from modules.base import ResidueCombiner from transformer.AGG.HierEncoder import Encoder as EncoderBase +from transformer.Encoder import EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class EncoderLayer(nn.Module): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, num_sub=1): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, num_sub=1, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize super(EncoderLayer, 
self).__init__() - self.nets = nn.ModuleList([EncoderLayerBase(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_sub)]) + self.nets = nn.ModuleList([EncoderLayerBase(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_sub)]) self.combiner = ResidueCombiner(isize, num_sub, _fhsize) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): out = inputs outs = [] @@ -34,12 +35,12 @@ def forward(self, inputs, mask=None): class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=False, num_sub=1, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=False, num_sub=1, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, **kwargs) - self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, num_sub) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize, num_sub) for i in range(num_layer)]) diff --git a/transformer/APE/Decoder.py b/transformer/APE/Decoder.py index c687ce0..970a999 100644 --- a/transformer/APE/Decoder.py +++ b/transformer/APE/Decoder.py @@ -1,29 +1,30 @@ #encoding: utf-8 import torch -from torch import nn -from modules.base import ResCrossAttn - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase - -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_, mask_tensor_type from math import sqrt +from torch import nn -from cnfg.vocab.base import pad_id +from modules.base import ResCrossAttn +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) - super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) self.cross_attn_mt = ResCrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual) - def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None, tgt_pad_mask=None, 
query_unit=None): + def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: context = self.self_attn(inputo, mask=tgt_pad_mask) @@ -43,20 +44,20 @@ def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None, t class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs) if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) - def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None): + def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None, **kwargs): nquery = inputo.size(-1) @@ -79,11 +80,11 @@ def forward(self, inpute, inputm, inputo, src_pad_mask=None, mt_pad_mask=None): return out - def decode(self, inpute, inputm, src_pad_mask, mt_pad_mask, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + def decode(self, inpute, inputm, src_pad_mask, mt_pad_mask, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, **kwargs): - return self.beam_decode(inpute, inputm, src_pad_mask, mt_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, inputm, src_pad_mask, mt_pad_mask, max_len, fill_pad=fill_pad) + return self.beam_decode(inpute, inputm, src_pad_mask, mt_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, **kwargs) if beam_size > 1 else self.greedy_decode(inpute, inputm, src_pad_mask, mt_pad_mask, max_len, fill_pad=fill_pad, **kwargs) - def greedy_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -109,7 +110,7 @@ def greedy_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, max trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, 
max_len): @@ -131,13 +132,13 @@ def greedy_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, max trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] mtl = inputm.size(1) @@ -147,8 +148,7 @@ def beam_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, beam_ real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) @@ -178,7 +178,7 @@ def beam_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, beam_ _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) #inputm = inputm.repeat(1, beam_size, 1).view(real_bsize, mtl, isize) self.repeat_cross_attn_buffer(beam_size) @@ -227,7 +227,7 @@ def beam_decode(self, inpute, inputm, src_pad_mask=None, mt_pad_mask=None, beam_ trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: diff --git a/transformer/APE/Encoder.py b/transformer/APE/Encoder.py index 2ddeb05..cbf5af7 100644 --- a/transformer/APE/Encoder.py +++ b/transformer/APE/Encoder.py @@ -1,20 +1,21 @@ #encoding: utf-8 +from math import sqrt from torch import nn -from modules.base import Dropout, PositionalEmb -from utils.fmt.base import parse_double_value_tuple -from cnfg.vocab.base import pad_id -from transformer.Encoder import Encoder as EncoderBase +from modules.base import Dropout, PositionalEmb from transformer.Decoder import DecoderLayer as MSEncoderLayerBase - -from math import sqrt +from transformer.Encoder import Encoder as EncoderBase +from utils.fmt.base import parse_double_value_tuple +from utils.fmt.parser import parse_none +from utils.torch.comp import torch_no_grad from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class MSEncoderLayer(MSEncoderLayerBase): - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, **kwargs): context = self.self_attn(inputo, mask=tgt_pad_mask) @@ -26,11 +27,11 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None): class MSEncoder(nn.Module): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, emb_w=None, share_layer=False, 
disable_pemb=disable_std_pemb_encoder): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, emb_w=None, share_layer=False, disable_pemb=disable_std_pemb_encoder, **kwargs): super(MSEncoder, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None @@ -41,20 +42,18 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. self.pemb = None if disable_pemb else PositionalEmb(isize, xseql, 0, 0) if share_layer: - _shared_layer = MSEncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = MSEncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([MSEncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([MSEncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, **kwargs): nquery = inputo.size(-1) out = self.wemb(inputo) - - out = out * sqrt(out.size(-1)) if self.pemb is not None: out = self.pemb(inputo, expand=False).add(out, alpha=sqrt(out.size(-1))) @@ -71,34 +70,33 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None): class Encoder(nn.Module): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, global_emb=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, global_emb=False, **kwargs): super(Encoder, self).__init__() nwd_src, nwd_tgt = parse_double_value_tuple(nwd) - self.src_enc = EncoderBase(isize, nwd_src, num_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, **kwargs) + self.src_enc = EncoderBase(isize, nwd_src, num_layer, fhsize, dropout, attn_drop, act_drop, num_head, xseql, ahsize, norm_output, **kwargs) emb_w = self.src_enc.wemb.weight if global_emb else None - self.tgt_enc = MSEncoder(isize, nwd_tgt, num_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, emb_w, **kwargs) + self.tgt_enc = MSEncoder(isize, nwd_tgt, num_layer, fhsize, dropout, attn_drop, act_drop, num_head, xseql, ahsize, norm_output, emb_w, **kwargs) - def forward(self, inpute, inputo, src_mask=None, tgt_mask=None): + def forward(self, inpute, inputo, src_mask=None, tgt_mask=None, **kwargs): enc_src = self.src_enc(inpute, src_mask) return enc_src, self.tgt_enc(enc_src, inputo, src_mask, tgt_mask) + def get_embedding_weight(self): + + return self.enc_src.get_embedding_weight() + def update_vocab(self, indices): _bind_emb = self.src_enc.wemb.weight.is_set_to(self.tgt_enc.wemb.weight) - _swemb = nn.Embedding(len(indices), self.src_enc.wemb.weight.size(-1), padding_idx=pad_id) - _twemb = nn.Embedding(len(indices), self.tgt_enc.wemb.weight.size(-1), padding_idx=pad_id) - with 
torch.no_grad(): - _swemb.weight.copy_(self.src_enc.wemb.weight.index_select(0, indices)) + _ = self.src_enc.update_vocab(indices) if _bind_emb: - _twemb.weight = _swemb.weight - else: - with torch.no_grad(): - _twemb.weight.copy_(self.tgt_enc.wemb.weight.index_select(0, indices)) - self.src_enc.wemb, self.tgt_enc.wemb = _swemb, _twemb + self.tgt_enc.wemb.weight = _ + + return _ diff --git a/transformer/APE/NMT.py b/transformer/APE/NMT.py index 1894e78..0eeb35e 100644 --- a/transformer/APE/NMT.py +++ b/transformer/APE/NMT.py @@ -1,44 +1,44 @@ #encoding: utf-8 -from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple - -from transformer.APE.Encoder import Encoder from transformer.APE.Decoder import Decoder +from transformer.APE.Encoder import Encoder from transformer.NMT import NMT as NMTBase +from utils.fmt.parser import parse_double_value_tuple +from utils.relpos.base import share_rel_pos_cache from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, **kwargs): enc_layer, dec_layer = parse_double_value_tuple(num_layer) - super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) - self.enc = Encoder(isize, (snwd, tnwd,), enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, global_emb=global_emb) + self.enc = Encoder(isize, (snwd, tnwd,), enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, global_emb=global_emb) emb_w = self.enc.tgt_enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, inputm, inputo, src_mask=None, mt_mask=None): + def forward(self, inpute, inputm, inputo, src_mask=None, mt_mask=None, **kwargs): - _src_mask = inpute.eq(0).unsqueeze(1) if src_mask is None else src_mask - _mt_mask = inputm.eq(0).unsqueeze(1) if mt_mask is None else mt_mask + _src_mask = inpute.eq(pad_id).unsqueeze(1) if src_mask is None else 
src_mask + _mt_mask = inputm.eq(pad_id).unsqueeze(1) if mt_mask is None else mt_mask enc_src, enc_mt = self.enc(inpute, inputm, _src_mask, _mt_mask) return self.dec(enc_src, enc_mt, inputo, _src_mask, _mt_mask) - def decode(self, inpute, inputm, beam_size=1, max_len=None, length_penalty=0.0): + def decode(self, inpute, inputm, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): - src_mask = inpute.eq(0).unsqueeze(1) - mt_mask = inputm.eq(0).unsqueeze(1) + src_mask = inpute.eq(pad_id).unsqueeze(1) + mt_mask = inputm.eq(pad_id).unsqueeze(1) _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len diff --git a/transformer/AvgDecoder.py b/transformer/AvgDecoder.py index e8a8407..14a1133 100644 --- a/transformer/AvgDecoder.py +++ b/transformer/AvgDecoder.py @@ -1,20 +1,22 @@ #encoding: utf-8 import torch +from math import sqrt from torch import nn + from modules.aan import AverageAttn -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase from utils.aan import share_aan_cache -from math import sqrt - -from cnfg.vocab.base import pad_id - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.sampler import SampleMax +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id -# Average Decoder is proposed in Accelerating Neural Transformer via an Average Attention Network (https://www.aclweb.org/anthology/P18-1166/) +# Average Decoder is proposed in Accelerating Neural Transformer via an Average Attention Network (https://aclanthology.org/P18-1166/) class DecoderLayer(DecoderLayerBase): @@ -24,12 +26,12 @@ class DecoderLayer(DecoderLayerBase): # num_head: number of heads in MultiHeadAttention # ahsize: hidden size of MultiHeadAttention - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) self.drop, self.layer_normer1, self.norm_residual = self.self_attn.drop, self.self_attn.normer, self.self_attn.norm_residual self.self_attn = AverageAttn(isize, _fhsize, dropout) @@ -37,11 +39,11 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a # inpute: encoded representation from encoder (bsize, seql, isize) # inputo: embedding of decoded translation (bsize, nquery, isize) during training, layer normed summed previous states for decoding # src_pad_mask: mask for given encoding source sentence (bsize, nquery, seql), see Encoder, expanded after generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # query_unit: single query to decode, used to support decoding for given step # step: current 
decoding step, used to average over the sum. - def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None, step=1): + def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None, step=1, **kwargs): if query_unit is None: _inputo = self.layer_normer1(inputo) @@ -93,18 +95,18 @@ class Decoder(DecoderBase): # ahsize: number of hidden units for MultiHeadAttention # bindemb: bind embedding and classifier weight - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) self.mask = None @@ -113,9 +115,9 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
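# Illustrative sketch, not part of the patch itself: the share_layer branch above
# keeps a single DecoderLayer instance and stores it num_layer times in the
# ModuleList, so every position in the stack reuses the same parameters. A
# minimal stand-alone demonstration, with nn.Linear standing in for DecoderLayer:
from torch import nn

shared = nn.Linear(4, 4)
nets = nn.ModuleList([shared for _ in range(6)])
assert all(layer is shared for layer in nets)                     # six references, one module
assert sum(p.numel() for p in nets.parameters()) == 4 * 4 + 4     # parameters are counted once, not duplicated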
# inpute: encoded representation from encoder (bsize, seql, isize) # inputo: decoded translation (bsize, nquery) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): bsize, nquery = inputo.size() @@ -157,10 +159,10 @@ def load_base(self, base_decoder): # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # max_len: maximum length to generate - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -191,7 +193,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # done_trans: (bsize, 1) - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(2, max_len + 1): @@ -216,7 +218,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break @@ -224,11 +226,11 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: beam size # max_len: maximum length to generate - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -237,7 +239,6 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: # lpv: length penalty vector for each beam (bsize * beam_size, 1) @@ -245,7 +246,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: out = self.drop(out) @@ -277,7 +278,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) @@ -354,7 +355,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | 
wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) diff --git a/transformer/ConstrainedDecoder.py b/transformer/ConstrainedDecoder.py index 89a47cf..4e7aaa2 100644 --- a/transformer/ConstrainedDecoder.py +++ b/transformer/ConstrainedDecoder.py @@ -1,28 +1,29 @@ #encoding: utf-8 import torch -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ from math import sqrt -from cnfg.vocab.base import pad_id - from transformer.Decoder import Decoder as DecoderBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.sampler import SampleMax +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class Decoder(DecoderBase): - def decode(self, inpute, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, cons_inds=None): + def decode(self, inpute, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, cons_inds=None, **kwargs): if cons_inds is None: _cons_inds = None else: _cons_inds = cons_inds.view(1, 1, -1).expand(inpute.size(0), 1, -1) if cons_inds.dim() == 1 else cons_inds.unsqueeze(1) - return self.beam_decode(inpute, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, cons_inds=_cons_inds) if beam_size > 1 else self.greedy_decode(inpute, src_pad_mask, max_len, fill_pad=fill_pad, cons_inds=_cons_inds) + return self.beam_decode(inpute, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, cons_inds=_cons_inds, **kwargs) if beam_size > 1 else self.greedy_decode(inpute, src_pad_mask, max_len, fill_pad=fill_pad, cons_inds=_cons_inds, **kwargs) - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, cons_inds=None): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, cons_inds=None, **kwargs): bsize = inpute.size(0) @@ -50,7 +51,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -74,13 +75,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, cons_inds=None): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, cons_inds=None, **kwargs): bsize, seql = inpute.size()[:2] @@ -89,14 +90,13 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: out = self.drop(out) @@ -123,7 
+123,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) self.repeat_cross_attn_buffer(beam_size) @@ -177,7 +177,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: diff --git a/transformer/Decoder.py b/transformer/Decoder.py index 4294e79..3689c83 100644 --- a/transformer/Decoder.py +++ b/transformer/Decoder.py @@ -1,15 +1,18 @@ #encoding: utf-8 import torch -from torch import nn -from modules.base import * -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_, mask_tensor_type from math import sqrt +from torch import nn -from cnfg.vocab.base import pad_id +from modules.base import CrossAttn, Dropout, Linear, MultiHeadAttn, PositionalEmb, PositionwiseFF, ResCrossAttn, ResSelfAttn +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.sampler import SampleMax +from utils.torch.comp import all_done, mask_tensor_type, torch_no_grad from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(nn.Module): @@ -21,26 +24,25 @@ class DecoderLayer(nn.Module): # norm_residual: residue with layer normalized representation # k_rel_pos: window size (one side) of relative positional embeddings in self attention - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, **kwargs): super(DecoderLayer, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True, max_bucket_distance=max_bucket_distance) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) - - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) # inpute: encoded representation from encoder (bsize, seql, isize) # inputo: embedding of decoded translation (bsize, nquery, isize) # src_pad_mask: mask for given encoding source 
sentence (bsize, nquery, seql), see Encoder, expanded after generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # tgt_pad_mask: mask to hide the future input # query_unit: single query to decode, used to support decoding for given step - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: context = self.self_attn(inputo, mask=tgt_pad_mask) @@ -59,12 +61,12 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un # Not used, keep this class to remind the DecoderLayer implementation before v0.3.5. class NAWDecoderLayer(DecoderLayer): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(NAWDecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) + super(NAWDecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) self.layer_normer1, self.drop, self.norm_residual = self.self_attn.normer, self.self_attn.drop, self.self_attn.norm_residual self.self_attn = self.self_attn.net @@ -72,13 +74,13 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.cross_attn = self.cross_attn.net #self.self_attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos, uni_direction_reduction=True, max_bucket_distance=max_bucket_distance) #self.cross_attn = CrossAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop) - #self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) + #self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) #self.layer_normer1 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) #self.layer_normer2 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) #self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None #self.norm_residual = norm_residual - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: _inputo = self.layer_normer1(inputo) @@ -130,17 +132,17 @@ class Decoder(nn.Module): # share_layer: using one shared decoder layer # disable_pemb: disable the standard positional embedding, can be enabled when use relative postional embeddings in self attention or AAN - def __init__(self, isize, nwd, num_layer, fhsize=None, 
dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, **kwargs): super(Decoder, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None self.xseql = xseql - self.register_buffer("mask", torch.ones(xseql, xseql, dtype=mask_tensor_type).triu(1).unsqueeze(0)) + self.register_buffer("mask", torch.ones(xseql, xseql, dtype=mask_tensor_type).triu(1).unsqueeze(0), persistent=False) self.wemb = nn.Embedding(nwd, isize, padding_idx=pad_id) if emb_w is not None: @@ -148,10 +150,10 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. self.pemb = None if disable_pemb else PositionalEmb(isize, xseql, 0, 0) if share_layer: - _shared_layer = DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize) + _shared_layer = DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) for i in range(num_layer)]) self.classifier = Linear(isize, nwd) # be careful since this line of code is trying to share the weight of the wemb and the classifier, which may cause problems if torch.nn updates @@ -167,9 +169,9 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. # inpute: encoded representation from encoder (bsize, seql, isize) # inputo: decoded translation (bsize, nquery) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): nquery = inputo.size(-1) @@ -184,7 +186,7 @@ def forward(self, inpute, inputo, src_pad_mask=None): # the following line of code is to mask for the decoder, # which I think is useless, since only may pay attention to previous tokens, whos loss will be omitted by the loss function. 
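# Illustrative sketch, not part of the patch itself: how the two decoder-side
# masks relate. The buffered causal mask built from torch.ones(xseql, xseql).triu(1)
# hides future positions, while the commented-out line above would additionally
# hide padded target positions via inputo.eq(pad_id). pad_id is assumed to be 0
# here, matching the eq(0) calls this patch replaces; torch.bool stands in for
# mask_tensor_type.
import torch

pad_id = 0
inputo = torch.tensor([[5, 7, 9, pad_id]])                                       # (bsize, nquery) target ids
nquery = inputo.size(-1)
causal_mask = torch.ones(nquery, nquery, dtype=torch.bool).triu(1).unsqueeze(0)  # (1, nquery, nquery)
pad_mask = inputo.eq(pad_id).unsqueeze(1)                                        # (bsize, 1, nquery)
combined = causal_mask | pad_mask                                                # what torch.gt(_mask + ..., 0) amounted to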
- #_mask = torch.gt(_mask + inputo.eq(0).unsqueeze(1), 0) + #_mask = torch.gt(_mask + inputo.eq(pad_id).unsqueeze(1), 0) for net in self.nets: out = net(inpute, out, src_pad_mask, _mask) @@ -228,21 +230,21 @@ def repeat_cross_attn_buffer(self, beam_size): # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, seql), see Encoder, get by: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: the beam size for beam search # max_len: maximum length to generate - def decode(self, inpute, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + def decode(self, inpute, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, **kwargs): - return self.beam_decode(inpute, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, src_pad_mask, max_len, fill_pad=fill_pad) + return self.beam_decode(inpute, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, **kwargs) if beam_size > 1 else self.greedy_decode(inpute, src_pad_mask, max_len, fill_pad=fill_pad, **kwargs) # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # max_len: maximum length to generate # sample: for back translation - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -275,7 +277,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # done_trans: (bsize, 1) - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -297,7 +299,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break @@ -305,11 +307,11 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: beam size # max_len: maximum length to generate - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -318,7 +320,6 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: # lpv: length penalty vector for each beam (bsize * beam_size, 1) @@ -326,7 +327,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = 
sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: @@ -359,7 +360,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) # instead of update inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) with the following line, we only update cross-attention buffers. #inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) @@ -439,7 +440,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) @@ -487,34 +488,49 @@ def get_sos_emb(self, inpute, bsize=None): def fix_init(self): self.fix_load() - with torch.no_grad(): + with torch_no_grad(): #self.wemb.weight[pad_id].zero_() self.classifier.weight[pad_id].zero_() def fix_load(self): if (self.fbl is not None) and (self.classifier.bias is not None): - with torch.no_grad(): + with torch_no_grad(): self.classifier.bias.index_fill_(0, torch.as_tensor(self.fbl, dtype=torch.long, device=self.classifier.bias.device), -inf_default) + def unbind_emb(self): + + _bind_classifier_weight = self.classifier.weight.is_set_to(self.wemb.weight) + with torch_no_grad(): + self.wemb.weight = nn.Parameter(self.wemb.weight.clone()) + if _bind_classifier_weight: + self.classifier.weight = self.wemb.weight + def unbind_classifier_weight(self): if self.classifier.weight.is_set_to(self.wemb.weight): _tmp = self.classifier.weight _new_w = nn.Parameter(torch.Tensor(_tmp.size())) - with torch.no_grad(): + with torch_no_grad(): _new_w.data.copy_(_tmp.data) self.classifier.weight = _new_w - # this function will untie the decoder embedding from the encoder + def get_embedding_weight(self): + + return self.wemb.weight - def update_vocab(self, indices): + # this function will untie the decoder embedding from the encoder if wemb_weight is None + + def update_vocab(self, indices, wemb_weight=None): _nwd = indices.numel() _wemb = nn.Embedding(_nwd, self.wemb.weight.size(-1), padding_idx=self.wemb.padding_idx) _classifier = Linear(self.classifier.weight.size(-1), _nwd, bias=self.classifier.bias is not None) - with torch.no_grad(): - _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) + with torch_no_grad(): + if wemb_weight is None: + _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) + else: + _wemb.weight = wemb_weight if self.classifier.weight.is_set_to(self.wemb.weight): _classifier.weight = _wemb.weight else: @@ -529,7 +545,7 @@ def update_classifier(self, indices): _nwd = indices.numel() _classifier = Linear(self.classifier.weight.size(-1), _nwd, bias=self.classifier.bias is not None) - with torch.no_grad(): + with torch_no_grad(): _classifier.weight.copy_(self.classifier.weight.index_select(0, indices)) if self.classifier.bias is not None: _classifier.bias.copy_(self.classifier.bias.index_select(0, indices)) @@ -550,7 +566,7 @@ def set_emb(self, emb_weight): # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source 
sentence (bsize, seql), see Encoder, get by: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: the beam size for beam search # max_len: maximum length to generate @@ -560,7 +576,7 @@ def decode_clip(self, inpute, src_pad_mask, beam_size=1, max_len=512, length_pen # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # max_len: maximum length to generate def greedy_decode_clip(self, inpute, src_pad_mask=None, max_len=512, return_mat=True): @@ -621,7 +637,7 @@ def greedy_decode_clip(self, inpute, src_pad_mask=None, max_len=512, return_mat= trans.append(wds) # done_trans: (bsize) - done_trans = wds.squeeze(1).eq(2) + done_trans = wds.squeeze(1).eq(eos_id) _ndone = done_trans.int().sum().item() if _ndone == bsize: @@ -653,7 +669,7 @@ def greedy_decode_clip(self, inpute, src_pad_mask=None, max_len=512, return_mat= # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: beam size # max_len: maximum length to generate @@ -666,7 +682,6 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: # lpv: length penalty vector for each beam (bsize * beam_size, 1) @@ -674,7 +689,7 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: out = self.drop(out) @@ -704,7 +719,7 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) #inpute = inpute.repeat(1, beam_size, 1).view(real_bsize, seql, isize) @@ -788,7 +803,7 @@ def beam_decode_clip(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, trans = torch.cat((trans.index_select(0, _inds), wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) diff --git a/transformer/Doc/Para/Base/Decoder.py b/transformer/Doc/Para/Base/Decoder.py index 33a2875..4541d12 100644 --- a/transformer/Doc/Para/Base/Decoder.py +++ b/transformer/Doc/Para/Base/Decoder.py @@ -1,26 +1,28 @@ #encoding: utf-8 import torch -from torch import nn -from modules.base import * -from utils.sampler import SampleMax -from modules.paradoc import GateResidual -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ from math import sqrt +from torch import nn -from cnfg.vocab.base import pad_id - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from modules.base import CrossAttn, Dropout +from modules.paradoc import 
GateResidual +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.sampler import SampleMax +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, ncross=2, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, ncross=2, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) - super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) self.cattns = nn.ModuleList([CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) for i in range(ncross)]) self.cattn_ln = nn.ModuleList([nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) for i in range(ncross)]) @@ -28,7 +30,7 @@ def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, a self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None self.norm_residual = self.cross_attn.norm_residual - def forward(self, inpute, inputo, inputc, src_pad_mask=None, tgt_pad_mask=None, context_mask=None, query_unit=None): + def forward(self, inpute, inputo, inputc, src_pad_mask=None, tgt_pad_mask=None, context_mask=None, query_unit=None, **kwargs): if query_unit is None: context = self.self_attn(inputo, mask=tgt_pad_mask) @@ -62,17 +64,17 @@ def load_base(self, base_decoder_layer): class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, nprev_context=2, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, nprev_context=2, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, nprev_context) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize, nprev_context) for i in range(num_layer)]) - def forward(self, inpute, inputo, inputc, src_pad_mask=None, context_mask=None): + def forward(self, inpute, inputo, inputc, src_pad_mask=None, context_mask=None, 
**kwargs): bsize, nsent, nquery = inputo.size() _inputo = inputo.view(-1, nquery) @@ -113,11 +115,11 @@ def load_base(self, base_decoder): self.out_normer = None if self.out_normer is None else base_decoder.out_normer - def decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + def decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, **kwargs): - return self.beam_decode(inpute, inputc, src_pad_mask, context_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, inputc, src_pad_mask, context_mask, max_len, fill_pad=fill_pad) + return self.beam_decode(inpute, inputc, src_pad_mask, context_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, **kwargs) if beam_size > 1 else self.greedy_decode(inpute, inputc, src_pad_mask, context_mask, max_len, fill_pad=fill_pad, **kwargs) - def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -143,7 +145,7 @@ def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, ma trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -165,13 +167,13 @@ def greedy_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, ma trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -180,14 +182,13 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: out = self.drop(out) @@ -211,7 +212,7 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) self.repeat_cross_attn_buffer(beam_size) @@ -263,7 +264,7 @@ def beam_decode(self, inpute, inputc, src_pad_mask=None, context_mask=None, beam trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = 
(done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: diff --git a/transformer/Doc/Para/Base/Encoder.py b/transformer/Doc/Para/Base/Encoder.py index 22f43b4..9a070e8 100644 --- a/transformer/Doc/Para/Base/Encoder.py +++ b/transformer/Doc/Para/Base/Encoder.py @@ -1,30 +1,30 @@ #encoding: utf-8 import torch -from torch import nn -from modules.base import * -from modules.paradoc import GateResidual from math import sqrt +from torch import nn -from utils.base import mask_tensor_type - -from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from modules.base import CrossAttn +from modules.paradoc import GateResidual +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none +from utils.torch.comp import mask_tensor_type, torch_no_grad from cnfg.ihyp import * class CrossEncoderLayer(EncoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, ncross=2, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, ncross=2, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) - super(CrossEncoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize) + super(CrossEncoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) self.cattns = nn.ModuleList([CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) for i in range(ncross)]) self.cattn_ln = nn.ModuleList([nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) for i in range(ncross)]) self.grs = nn.ModuleList([GateResidual(isize) for i in range(ncross)]) - def forward(self, inputs, inputc, mask=None, context_mask=None): + def forward(self, inputs, inputc, mask=None, context_mask=None, **kwargs): _inputs = self.layer_normer(inputs) context = self.attn(_inputs, mask=mask) @@ -54,22 +54,21 @@ def load_base(self, base_encoder_layer): class CrossEncoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, nprev_context=2, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, nprev_context=2, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(CrossEncoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output) + super(CrossEncoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output) - self.nets = nn.ModuleList([CrossEncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([CrossEncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) - def forward(self, inputs, inputc, mask=None, context_mask=None): + def forward(self, inputs, inputc, 
mask=None, context_mask=None, **kwargs): out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) if self.drop is not None: out = self.drop(out) @@ -94,22 +93,22 @@ def load_base(self, base_encoder): class Encoder(nn.Module): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, nprev_context=2, num_layer_context=1): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, nprev_context=2, num_layer_context=1, **kwargs): super(Encoder, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - self.context_enc = EncoderBase(isize, nwd, num_layer if num_layer_context is None else num_layer_context, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output) - self.enc = CrossEncoder(isize, nwd, num_layer, _fhsize, dropout, attn_drop, num_head, xseql, _ahsize, norm_output, nprev_context) + self.context_enc = EncoderBase(isize, nwd, num_layer if num_layer_context is None else num_layer_context, _fhsize, dropout, attn_drop, act_drop, num_head, xseql, _ahsize, norm_output) + self.enc = CrossEncoder(isize, nwd, num_layer, _fhsize, dropout, attn_drop, act_drop, num_head, xseql, _ahsize, norm_output, nprev_context) _tmp_pad = torch.zeros(xseql, dtype=torch.long) _tmp_pad[0] = 1 _tmp_pad = _tmp_pad.view(1, 1, xseql).repeat(1, nprev_context - 1, 1) - self.register_buffer("pad", _tmp_pad) - self.register_buffer("pad_mask", (1 - _tmp_pad).to(mask_tensor_type, non_blocking=True).unsqueeze(1)) + self.register_buffer("pad", _tmp_pad, persistent=False) + self.register_buffer("pad_mask", (1 - _tmp_pad).to(mask_tensor_type, non_blocking=True).unsqueeze(1), persistent=False) self.xseql = xseql self.nprev_context = nprev_context @@ -117,9 +116,9 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. # inputs: (bsize, _nsent, seql), nprev_context, ... , nsent - 1 # inputc: (bsize, _nsentc, seql), 0, 1, ... 
, nsent - 2 # mask: (bsize, 1, _nsent, seql), generated with: - # mask = inputs.eq(0).unsqueeze(1) + # mask = inputs.eq(pad_id).unsqueeze(1) # where _nsent = nsent - self.nprev_context, _nsentc = nsent - 1 - def forward(self, inputs, inputc, mask=None, context_mask=None): + def forward(self, inputs, inputc, mask=None, context_mask=None, **kwargs): bsize, nsentc, seql = inputc.size() _inputc = torch.cat((self.get_pad(seql).expand(bsize, -1, seql), inputc,), dim=1) @@ -139,7 +138,7 @@ def forward(self, inputs, inputc, mask=None, context_mask=None): def load_base(self, base_encoder): self.enc.load_base(base_encoder) - with torch.no_grad(): + with torch_no_grad(): self.context_enc.wemb.weight.copy_(base_encoder.wemb.weight) def get_pad(self, seql): @@ -150,6 +149,10 @@ def get_padmask(self, seql): return self.pad_mask.narrow(-1, 0, seql) if seql <= self.xseql else torch.cat((self.pad_mask, self.pad_mask.new_ones(1, 1, self.nprev_context - 1, seql - self.xseql),), dim=-1) + def get_embedding_weight(self): + + return self.enc.get_embedding_weight() + def update_vocab(self, indices): - self.context_enc.update_vocab(indices) + return self.enc.update_vocab(indices) diff --git a/transformer/Doc/Para/Base/NMT.py b/transformer/Doc/Para/Base/NMT.py index b699809..fba2c97 100644 --- a/transformer/Doc/Para/Base/NMT.py +++ b/transformer/Doc/Para/Base/NMT.py @@ -2,43 +2,43 @@ from torch import nn -from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple - -from transformer.Doc.Para.Base.Encoder import Encoder from transformer.Doc.Para.Base.Decoder import Decoder +from transformer.Doc.Para.Base.Encoder import Encoder +from utils.fmt.parser import parse_double_value_tuple +from utils.relpos.base import share_rel_pos_cache from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class NMT(nn.Module): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, nprev_context=2, num_layer_context=1): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, nprev_context=2, num_layer_context=1, **kwargs): super(NMT, self).__init__() enc_layer, dec_layer = parse_double_value_tuple(num_layer) - self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, nprev_context, num_layer_context) + self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, act_drop, num_head, xseql, ahsize, norm_output, nprev_context, num_layer_context) emb_w = self.enc.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index, nprev_context) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, act_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index, nprev_context) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, inputo, inputc, mask=None, context_mask=None): + def forward(self, inpute, inputo, inputc, mask=None, context_mask=None, **kwargs): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask - _context_mask = inputc.eq(0).unsqueeze(1) if context_mask is None else context_mask + 
_mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask + _context_mask = inputc.eq(pad_id).unsqueeze(1) if context_mask is None else context_mask ence, contexts, context_masks = self.enc(inpute, inputc, _mask, _context_mask) return self.dec(ence, inputo, contexts, _mask, context_masks) - def decode(self, inpute, inputc, beam_size=1, max_len=None, length_penalty=0.0): + def decode(self, inpute, inputc, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): - mask = inpute.eq(0).unsqueeze(1) - context_mask = inputc.eq(0).unsqueeze(1) + mask = inpute.eq(pad_id).unsqueeze(1) + context_mask = inputc.eq(pad_id).unsqueeze(1) bsize, nsent, seql = inpute.size() _max_len = (seql + max(64, seql // 4)) if max_len is None else max_len diff --git a/transformer/Encoder.py b/transformer/Encoder.py index cd5dec4..c0870b9 100644 --- a/transformer/Encoder.py +++ b/transformer/Encoder.py @@ -1,13 +1,14 @@ #encoding: utf-8 -import torch -from torch import nn -from modules.base import * from math import sqrt +from torch import nn -from cnfg.vocab.base import pad_id +from modules.base import Dropout, PositionalEmb, PositionwiseFF, ResSelfAttn +from utils.fmt.parser import parse_none +from utils.torch.comp import torch_no_grad from cnfg.ihyp import * +from cnfg.vocab.base import pad_id # vocabulary: # :0 @@ -27,20 +28,19 @@ class EncoderLayer(nn.Module): # norm_residual: residue with layer normalized representation # k_rel_pos: window size (one side) of relative positional embeddings in self attention - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, **kwargs): super(EncoderLayer, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) - - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) # inputs: input of this layer (bsize, seql, isize) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): context = self.attn(inputs, mask=mask) @@ -51,22 +51,22 @@ def forward(self, inputs, mask=None): # Not used, keep this class to remind the EncoderLayer implementation before v0.3.5. 
class NAWEncoderLayer(EncoderLayer): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(NAWEncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) + super(NAWEncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) #self.attn = SelfAttn(isize, _ahsize, isize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) - #self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) + #self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) self.layer_normer, self.drop, self.norm_residual = self.attn.normer, self.attn.drop, self.attn.norm_residual self.attn = self.attn.net #self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) #self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None #self.norm_residual = norm_residual - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): _inputs = self.layer_normer(inputs) context = self.attn(_inputs, mask=mask) @@ -93,11 +93,11 @@ class Encoder(nn.Module): # share_layer: using one shared encoder layer # disable_pemb: disable the standard positional embedding, enable when use relative postional embeddings in self attention - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, disable_pemb=disable_std_pemb_encoder, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, disable_pemb=disable_std_pemb_encoder, **kwargs): super(Encoder, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None @@ -106,23 +106,22 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
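# Illustrative sketch, not part of the patch itself: throughout this diff,
# expressions of the form `isize if ahsize is None else ahsize` are replaced by
# parse_none(ahsize, isize). The helper from utils.fmt.parser is not shown in
# the patch; based on the expressions it replaces, it is assumed to behave like:
def parse_none(value, default):
    # fall back to the default only when the argument was left as None
    return default if value is None else value

assert parse_none(None, 512) == 512      # ahsize unset: use isize
assert parse_none(1024, 512) == 1024     # ahsize given: keep it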
self.pemb = None if disable_pemb else PositionalEmb(isize, xseql, 0, 0) if share_layer: - _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize) + _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) for i in range(num_layer)]) self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None # inputs: (bsize, seql) # mask: (bsize, 1, seql), generated with: - # mask = inputs.eq(0).unsqueeze(1) + # mask = inputs.eq(pad_id).unsqueeze(1) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) if self.drop is not None: out = self.drop(out) @@ -146,16 +145,22 @@ def load_base(self, base_encoder): self.out_normer = None if self.out_normer is None else base_encoder.out_normer + def get_embedding_weight(self): + + return self.wemb.weight + def update_vocab(self, indices): - _wemb = nn.Embedding(len(indices), self.wemb.weight.size(-1), padding_idx=self.wemb.padding_idx) - with torch.no_grad(): + _wemb = nn.Embedding(indices.numel(), self.wemb.weight.size(-1), padding_idx=self.wemb.padding_idx) + with torch_no_grad(): _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) self.wemb = _wemb + return self.wemb.weight + def fix_init(self): if hasattr(self, "fix_load"): self.fix_load() - #with torch.no_grad(): + #with torch_no_grad(): # self.wemb.weight[pad_id].zero_() diff --git a/transformer/EnsembleAvgDecoder.py b/transformer/EnsembleAvgDecoder.py index 4957258..484d23f 100644 --- a/transformer/EnsembleAvgDecoder.py +++ b/transformer/EnsembleAvgDecoder.py @@ -1,26 +1,27 @@ #encoding: utf-8 import torch -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ from math import sqrt -from cnfg.vocab.base import pad_id - from transformer.EnsembleDecoder import Decoder as DecoderBase +from utils.base import index_tensors +from utils.decode.beam import expand_bsize_for_beam +from utils.sampler import SampleMax +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id -# Average Decoder is proposed in Accelerating Neural Transformer via an Average Attention Network (https://www.aclweb.org/anthology/P18-1166/) +# Average Decoder is proposed in Accelerating Neural Transformer via an Average Attention Network (https://aclanthology.org/P18-1166/) class Decoder(DecoderBase): # inpute: encoded representation from encoders [(bsize, seql, isize)...] 
# inputo: decoded translation (bsize, nquery) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): bsize, nquery = inputo.size() @@ -47,14 +48,14 @@ def forward(self, inpute, inputo, src_pad_mask=None): # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # max_len: maximum length to generate - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize, seql, isize = inpute[0].size() - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(self.nets[0].wemb.weight.size(-1)) outs = [] @@ -89,7 +90,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # done_trans: (bsize, 1) - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for step in range(2, max_len + 1): @@ -118,7 +119,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break @@ -126,11 +127,11 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # inpute: encoded representation from encoders [(bsize, seql, isize)...] 
# src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: beam size # max_len: maximum length to generate - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql, isize = inpute[0].size() @@ -138,7 +139,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt bsizeb2 = bsize * beam_size2 real_bsize = bsize * beam_size - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(self.nets[0].wemb.weight.size(-1)) if length_penalty > 0.0: # lpv: length penalty vector for each beam (bsize * beam_size, 1) @@ -187,7 +188,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) @@ -269,7 +270,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) diff --git a/transformer/EnsembleDecoder.py b/transformer/EnsembleDecoder.py index b066961..a59c800 100644 --- a/transformer/EnsembleDecoder.py +++ b/transformer/EnsembleDecoder.py @@ -1,20 +1,22 @@ #encoding: utf-8 import torch -from torch import nn -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ from math import sqrt +from torch import nn -from cnfg.vocab.base import pad_id +from utils.base import index_tensors +from utils.decode.beam import expand_bsize_for_beam +from utils.sampler import SampleMax +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class Decoder(nn.Module): # models: list of decoders - def __init__(self, models): + def __init__(self, models, **kwargs): super(Decoder, self).__init__() @@ -23,9 +25,9 @@ def __init__(self, models): # inpute: encoded representation from encoders [(bsize, seql, isize)...] # inputo: decoded translation (bsize, nquery) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): bsize, nquery = inputo.size() @@ -35,7 +37,7 @@ def forward(self, inpute, inputo, src_pad_mask=None): # the following line of code is to mask for the decoder, # which I think is useless, since only may pay attention to previous tokens, whos loss will be omitted by the loss function. 
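# A hedged illustration of what the commented-out target-side mask just below
# would compute: the causal (subsequent-position) mask ORed with the target
# padding mask. pad_id = 0, the toy batch, and the locally built causal mask are
# assumptions; the real decoder reuses its cached mask buffer.
import torch

pad_id = 0
inputo = torch.tensor([[5, 6, 7, pad_id]])                                    # (bsize, nquery)
nquery = inputo.size(1)
causal = torch.ones(nquery, nquery, dtype=torch.uint8).triu(1).unsqueeze(0)  # block future positions
pad = inputo.eq(pad_id).unsqueeze(1)                                          # (bsize, 1, nquery), block pads
_mask = torch.gt(causal + pad, 0)                                             # True where attention is forbidden
# As the comment above argues, the padding part is arguably redundant: padded
# target keys can only be attended to by queries at or after them, and those
# positions' loss is omitted anyway.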
- #_mask = torch.gt(_mask + inputo.eq(0).unsqueeze(1), 0) + #_mask = torch.gt(_mask + inputo.eq(pad_id).unsqueeze(1), 0) for model, inputu in zip(self.nets, inpute): @@ -58,24 +60,24 @@ def forward(self, inpute, inputo, src_pad_mask=None): # inpute: encoded representation from encoders [(bsize, seql, isize)...] # src_pad_mask: mask for given encoding source sentence (bsize, seql), see Encoder, get by: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: the beam size for beam search # max_len: maximum length to generate - def decode(self, inpute, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + def decode(self, inpute, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, **kwargs): - return self.beam_decode(inpute, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, src_pad_mask, max_len, fill_pad=fill_pad) + return self.beam_decode(inpute, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, **kwargs) if beam_size > 1 else self.greedy_decode(inpute, src_pad_mask, max_len, fill_pad=fill_pad, **kwargs) # inpute: encoded representation from encoders [(bsize, seql, isize)...] # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # max_len: maximum length to generate - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize, seql, isize = inpute[0].size() - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(self.nets[0].wemb.weight.size(-1)) outs = [] @@ -110,7 +112,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # done_trans: (bsize, 1) - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -139,7 +141,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break @@ -147,11 +149,11 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # inpute: encoded representation from encoders [(bsize, seql, isize)...] 
# src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: beam size # max_len: maximum length to generate - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql, isize = inpute[0].size() @@ -159,7 +161,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt bsizeb2 = bsize * beam_size2 real_bsize = bsize * beam_size - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(self.nets[0].wemb.weight.size(-1)) if length_penalty > 0.0: # lpv: length penalty vector for each beam (bsize * beam_size, 1) @@ -207,7 +209,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) @@ -289,7 +291,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) diff --git a/transformer/EnsembleEncoder.py b/transformer/EnsembleEncoder.py index 60b6bd5..b85aa21 100644 --- a/transformer/EnsembleEncoder.py +++ b/transformer/EnsembleEncoder.py @@ -4,14 +4,14 @@ class Encoder(nn.Module): - def __init__(self, models): + def __init__(self, models, **kwargs): super(Encoder, self).__init__() self.nets = nn.ModuleList(models) # inputs: (bsize, seql) # mask: (bsize, 1, seql), generated with: - # mask = inputs.eq(0).unsqueeze(1) + # mask = inputs.eq(pad_id).unsqueeze(1) def forward(self, *inputs, **kwargs): diff --git a/transformer/EnsembleNMT.py b/transformer/EnsembleNMT.py index bc6af1a..bc4354a 100644 --- a/transformer/EnsembleNMT.py +++ b/transformer/EnsembleNMT.py @@ -3,19 +3,19 @@ import torch from torch import nn -from utils.base import all_done, select_zero_ - from transformer.EnsembleEncoder import Encoder - # switch the comment between the following two lines to choose standard decoder or average decoder from transformer.EnsembleDecoder import Decoder #from transformer.EnsembleAvgDecoder import Decoder +from utils.base import select_zero_ +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class NMT(nn.Module): - def __init__(self, models): + def __init__(self, models, **kwargs): super(NMT, self).__init__() @@ -25,11 +25,11 @@ def __init__(self, models): # inpute: source sentences from encoder (bsize, seql) # inputo: decoded translation (bsize, nquery) # mask: user specified mask, otherwise it will be: - # inpute.eq(0).unsqueeze(1) + # inpute.eq(pad_id).unsqueeze(1) - def forward(self, inpute, inputo, mask=None): + def forward(self, inpute, inputo, mask=None, **kwargs): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + 
_mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask return self.dec(self.enc(inpute, _mask), inputo, _mask) @@ -37,9 +37,9 @@ def forward(self, inpute, inputo, mask=None): # beam_size: the beam size for beam search # max_len: maximum length to generate - def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0): + def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): - mask = inpute.eq(0).unsqueeze(1) + mask = inpute.eq(pad_id).unsqueeze(1) _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len @@ -47,7 +47,7 @@ def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0): def train_decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, mask=None): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len @@ -76,7 +76,7 @@ def train_greedy_decode(self, inpute, mask=None, max_len=512): out = torch.cat((out, wds), -1) # done_trans: (bsize) - done_trans = wds.squeeze(1).eq(2) if done_trans is None else (done_trans | wds.squeeze(1).eq(2)) + done_trans = wds.squeeze(1).eq(eos_id) if done_trans is None else (done_trans | wds.squeeze(1).eq(eos_id)) if all_done(done_trans, bsize): break @@ -146,7 +146,7 @@ def train_beam_decode(self, inpute, mask=None, beam_size=8, max_len=512, length_ out = torch.cat((out.index_select(0, _inds), wds), -1) # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) if done_trans is None else (done_trans.view(real_bsize).index_select(0, _inds) | wds.view(real_bsize).eq(2)).view(bsize, beam_size) + done_trans = wds.view(bsize, beam_size).eq(eos_id) if done_trans is None else (done_trans.view(real_bsize).index_select(0, _inds) | wds.view(real_bsize).eq(eos_id)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) diff --git a/transformer/HPLSTM/Decoder.py b/transformer/HPLSTM/Decoder.py index f1e1a74..d4b89fb 100644 --- a/transformer/HPLSTM/Decoder.py +++ b/transformer/HPLSTM/Decoder.py @@ -2,37 +2,40 @@ import torch from torch import nn -from modules.base import ResCrossAttn, Dropout -from modules.hplstm.hfn import HPLSTM -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ # shall we keep scaling embeddings? 
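# A hedged sketch of two idioms from the EnsembleNMT.decode hunk above: the
# (bsize, 1, seql) source padding mask and the default decode length of
# seql + max(64, seql // 4) used when max_len is None. pad_id = 0 and the toy
# source are assumptions for illustration.
import torch

pad_id = 0
inpute = torch.tensor([[4, 9, 7, pad_id, pad_id]])   # (bsize, seql)
mask = inpute.eq(pad_id).unsqueeze(1)                # (bsize, 1, seql), True at padding
seql = inpute.size(1)
_max_len = seql + max(64, seql // 4)                 # 5 + 64 = 69 for this toy source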
-#from math import sqrt -from cnfg.vocab.base import pad_id +#from math import sqrt +from modules.base import Dropout, ResCrossAttn +from modules.hplstm.hfn import HPLSTM from transformer.Decoder import Decoder as DecoderBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.sampler import SampleMax +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(nn.Module): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, **kwargs): super(DecoderLayer, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - self.net = HPLSTM(isize, num_head=num_head, osize=isize, fhsize=_fhsize, dropout=dropout) + self.net = HPLSTM(isize, num_head=num_head, osize=isize, fhsize=_fhsize, dropout=dropout, act_drop=act_drop) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) self.layer_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: @@ -60,22 +63,22 @@ def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None): class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=True, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=True, **kwargs) self.mask = None if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = 
nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): out = self.wemb(inputo) @@ -92,7 +95,7 @@ def forward(self, inpute, inputo, src_pad_mask=None): return out - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -115,7 +118,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -136,13 +139,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -151,7 +154,6 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) @@ -179,7 +181,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) self.repeat_cross_attn_buffer(beam_size) @@ -225,7 +227,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: diff --git a/transformer/HPLSTM/Encoder.py b/transformer/HPLSTM/Encoder.py index ad37c3d..3b3914b 100644 --- a/transformer/HPLSTM/Encoder.py +++ b/transformer/HPLSTM/Encoder.py @@ -2,28 +2,28 @@ import torch from torch import nn + from modules.base import Dropout from modules.hplstm.hfn import BiHPLSTM - from transformer.Encoder import Encoder as EncoderBase - -from utils.base import flip_mask +from utils.fmt.parser import parse_none +from utils.torch.comp import flip_mask from cnfg.ihyp import * class EncoderLayer(nn.Module): - def __init__(self, isize, fhsize=None, dropout=0.0, num_head=8): + def __init__(self, isize, fhsize=None, dropout=0.0, act_drop=None, num_head=8, **kwargs): super(EncoderLayer, self).__init__() _fhsize = isize * 4 if fhsize is None else fhsize - self.net = BiHPLSTM(isize, num_head=num_head, osize=isize, 
fhsize=_fhsize, dropout=dropout) + self.net = BiHPLSTM(isize, num_head=num_head, osize=isize, fhsize=_fhsize, dropout=dropout, act_drop=act_drop) self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - def forward(self, inputs, reversed_mask=None): + def forward(self, inputs, reversed_mask=None, **kwargs): context = self.net(inputs, reversed_mask=reversed_mask) @@ -34,20 +34,20 @@ def forward(self, inputs, reversed_mask=None): class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, disable_pemb=True, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, disable_pemb=True, **kwargs) if share_layer: - _shared_layer = EncoderLayer(isize, _fhsize, dropout, num_head) + _shared_layer = EncoderLayer(isize, _fhsize, dropout, act_drop, num_head) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, num_head) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, act_drop, num_head) for i in range(num_layer)]) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): if mask is None: _rmask = None diff --git a/transformer/HPLSTM/FNDecoder.py b/transformer/HPLSTM/FNDecoder.py index 497a7af..548e988 100644 --- a/transformer/HPLSTM/FNDecoder.py +++ b/transformer/HPLSTM/FNDecoder.py @@ -2,33 +2,35 @@ import torch from torch import nn -from modules.base import ResCrossAttn, Dropout, PositionwiseFF -from modules.hplstm.hfn import HPLSTM -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ - -from cnfg.vocab.base import pad_id +from modules.base import Dropout, PositionwiseFF, ResCrossAttn +from modules.hplstm.hfn import HPLSTM from transformer.Decoder import Decoder as DecoderBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.sampler import SampleMax +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(nn.Module): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, **kwargs): super(DecoderLayer, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - self.net = HPLSTM(isize, num_head=num_head, 
osize=isize, fhsize=_fhsize, dropout=dropout) + self.net = HPLSTM(isize, num_head=num_head, osize=isize, fhsize=_fhsize, dropout=dropout, act_drop=act_drop) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: @@ -58,22 +60,22 @@ def forward(self, inpute, inputo, src_pad_mask=None, query_unit=None): class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=True, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=True, **kwargs) self.mask = None if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): out = self.wemb(inputo) @@ -90,7 +92,7 @@ def forward(self, inpute, inputo, src_pad_mask=None): return out - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -113,7 +115,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -134,13 +136,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = 
done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -149,7 +151,6 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) @@ -177,7 +178,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) self.repeat_cross_attn_buffer(beam_size) @@ -223,7 +224,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: diff --git a/transformer/LD/AttnEncoder.py b/transformer/LD/AttnEncoder.py index cc8f9de..be55a2a 100644 --- a/transformer/LD/AttnEncoder.py +++ b/transformer/LD/AttnEncoder.py @@ -1,11 +1,10 @@ #encoding: utf-8 import torch +from math import ceil, sqrt from torch import nn -from modules.LD import ATTNCombiner - -from math import sqrt, ceil +from modules.LD import ATTNCombiner from transformer.LD.Encoder import Encoder as EncoderBase from cnfg.ihyp import * @@ -19,7 +18,7 @@ def __init__(self, isize, nwd, num_layer, *inputs, **kwargs): self.attn_emb = ATTNCombiner(isize, isize, dropout) self.attns = nn.ModuleList([ATTNCombiner(isize, isize, dropout) for i in range(num_layer)]) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): def transform(lin, w, drop): @@ -67,9 +66,8 @@ def build_chunk_mean(atm, rept, bsize, nchk, ntok, npad, mask=None, rmask=None, _rmask = _nmask.ge(_ntok) out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) #if _rmask is not None: #_nele = (_ntok - _nmask).masked_fill(_nmask.eq(_ntok), 1).view(bsize, _nchk, 1).to(out, non_blocking=True) diff --git a/transformer/LD/Decoder.py b/transformer/LD/Decoder.py index a7d8e5a..d9c2a91 100644 --- a/transformer/LD/Decoder.py +++ b/transformer/LD/Decoder.py @@ -1,37 +1,37 @@ #encoding: utf-8 import torch +from math import sqrt from torch import nn +from modules.TA import PositionwiseFF from modules.base import CrossAttn, ResidueCombiner +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam, 
repeat_bsize_for_beam_tensor +from utils.fmt.parser import parse_none from utils.sampler import SampleMax -from modules.TA import PositionwiseFF - -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_, repeat_bsize_for_beam_tensor -from math import sqrt - -from cnfg.vocab.base import pad_id - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): - - _ahsize = isize if ahsize is None else ahsize + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, **kwargs): + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + self.layer_normer1 = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) self.cattn = CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) - self.ff = PositionwiseFF(isize, _fhsize, dropout) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop) self.scff = ResidueCombiner(isize, 2, _fhsize, dropout) - def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: context = self.self_attn(inputo, mask=tgt_pad_mask) @@ -53,21 +53,21 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = 
nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) - def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None): + def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None, **kwargs): bsize, nquery = inputo.size() @@ -89,7 +89,7 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None, chk_pad_mask=None): return out - def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -106,7 +106,7 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, ma states = {} for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): - out, _state = net(inputu, inputhu, None, src_pad_mask, chk_pad_mask, None, out, True) + out, _state = net(inputu, inputhu, (None, None,), src_pad_mask, chk_pad_mask, None, out) states[_tmp] = _state out = self.classifier(out) @@ -114,7 +114,7 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, ma trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -127,7 +127,7 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, ma out = self.out_normer(out) for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): - out, _state = net(inputu, inputhu, states[_tmp], src_pad_mask, chk_pad_mask, None, out, True) + out, _state = net(inputu, inputhu, states[_tmp], src_pad_mask, chk_pad_mask, None, out) states[_tmp] = _state out = self.classifier(out) @@ -135,13 +135,13 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, ma trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -150,14 +150,13 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: @@ -168,7 +167,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam states = {} for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): - out, _state = net(inputu, inputhu, None, src_pad_mask, chk_pad_mask, None, out, True) + out, _state = net(inputu, inputhu, (None, None,), src_pad_mask, chk_pad_mask, None, out) 
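# A hedged note on lpv_base = 6.0 ** length_penalty in the beam_decode hunk above:
# the constant matches the denominator of a GNMT-style length penalty
# lp(n) = ((5 + n) / 6) ** alpha, since ((5 + n) / 6) ** alpha equals
# (5 + n) ** alpha / 6 ** alpha. The loop below only checks that identity
# numerically; how lpv is updated per decoding step is not shown in these hunks.
length_penalty = 0.6                                  # assumed alpha, for illustration
lpv_base = 6.0 ** length_penalty
for n in (1, 10, 50):
    lhs = ((5.0 + n) / 6.0) ** length_penalty
    rhs = ((5.0 + n) ** length_penalty) / lpv_base
    assert abs(lhs - rhs) < 1e-9                      # identical up to rounding, 1.0 at n == 1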
states[_tmp] = _state out = self.lsm(self.classifier(out)) @@ -181,7 +180,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) #inputh = repeat_bsize_for_beam_tensor(inputh, beam_size) self.repeat_cross_attn_buffer(beam_size) @@ -202,7 +201,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam out = self.out_normer(out) for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): - out, _state = net(inputu, inputhu, states[_tmp], _src_pad_mask, _chk_pad_mask, None, out, True) + out, _state = net(inputu, inputhu, states[_tmp], _src_pad_mask, _chk_pad_mask, None, out) states[_tmp] = _state out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1) @@ -229,7 +228,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: @@ -255,6 +254,6 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, chk_pad_mask=None, beam return trans.view(bsize, beam_size, -1).select(1, 0) - def decode(self, inpute, inputh, src_pad_mask, chk_pad_mask, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + def decode(self, inpute, inputh, src_pad_mask, chk_pad_mask, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, **kwargs): - return self.beam_decode(inpute, inputh, src_pad_mask, chk_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, inputh, src_pad_mask, chk_pad_mask, max_len, fill_pad=fill_pad) + return self.beam_decode(inpute, inputh, src_pad_mask, chk_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, **kwargs) if beam_size > 1 else self.greedy_decode(inpute, inputh, src_pad_mask, chk_pad_mask, max_len, fill_pad=fill_pad, **kwargs) diff --git a/transformer/LD/Encoder.py b/transformer/LD/Encoder.py index ac6b5be..23f9ab7 100644 --- a/transformer/LD/Encoder.py +++ b/transformer/LD/Encoder.py @@ -1,28 +1,28 @@ #encoding: utf-8 import torch +from math import ceil, sqrt from torch import nn -from modules.base import CrossAttn, Dropout, ResidueCombiner - -from math import sqrt, ceil -from transformer.TA.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from modules.base import CrossAttn, Dropout, ResidueCombiner +from transformer.TA.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class EncoderLayer(EncoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, 
isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) self.cattn = CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) self.scff = ResidueCombiner(isize, 2, _fhsize, dropout) - def forward(self, inputs, sumr, mask=None, rmask=None): + def forward(self, inputs, sumr, mask=None, rmask=None, **kwargs): #_bsize, _seql, _isize = inputs.size() #_rep1, _rep2 = self.cattn(inputs.repeat(2, 1, 1), sumr, rmask).view(2, _bsize, _seql, _isize).unbind(0) @@ -37,19 +37,19 @@ def forward(self, inputs, sumr, mask=None, rmask=None): class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, num_layer_dec=6, max_chunk_tokens=8, min_chunks=4, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, num_layer_dec=6, max_chunk_tokens=8, min_chunks=4, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, num_layer_dec=num_layer_dec, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, num_layer_dec=num_layer_dec, **kwargs) if share_layer: - _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) self.sc_tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(1.0 / (num_layer + 1)), sqrt(1.0 / (num_layer + 1)))) self.sc_tattn_drop = Dropout(dropout) if dropout > 0.0 else None @@ -57,7 +57,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
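# Throughout the patch, "_ahsize = isize if ahsize is None else ahsize" becomes
# "_ahsize = parse_none(ahsize, isize)". Judging only from those replacements,
# parse_none(value, fallback) returns the fallback when value is None; the
# stand-in below is an assumption for illustration (the real helper lives in
# utils.fmt.parser).
def parse_none(value, fallback):
    return fallback if value is None else value

assert parse_none(None, 512) == 512      # ahsize unset -> fall back to isize
assert parse_none(1024, 512) == 1024     # an explicit ahsize wins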
self.mxct = max_chunk_tokens self.mnck = float(min_chunks) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): def transform(lin, w, drop): @@ -95,9 +95,8 @@ def build_chunk_mean(rept, bsize, nchk, ntok, npad, mask=None, rmask=None, nele= _rmask = _nmask.ge(_ntok) out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) #if _rmask is not None: #_nele = (_ntok - _nmask).view(bsize, _nchk, 1).to(out, non_blocking=True) diff --git a/transformer/LD/NMT.py b/transformer/LD/NMT.py index fa44193..80f5fc6 100644 --- a/transformer/LD/NMT.py +++ b/transformer/LD/NMT.py @@ -1,34 +1,34 @@ #encoding: utf-8 -from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple - -from transformer.LD.Encoder import Encoder from transformer.LD.Decoder import Decoder +from transformer.LD.Encoder import Encoder from transformer.NMT import NMT as NMTBase +from utils.fmt.parser import parse_double_value_tuple +from utils.relpos.base import share_rel_pos_cache from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, **kwargs): enc_layer, dec_layer = parse_double_value_tuple(num_layer) - super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) - self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, dec_layer) + self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, act_drop, num_head, xseql, ahsize, norm_output, dec_layer) emb_w = self.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, act_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, inputo, mask=None): + def forward(self, inpute, inputo, mask=None, **kwargs): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask ence, ench, hmask = self.enc(inpute, _mask) return self.dec(ence, ench, inputo, _mask, hmask) @@ -37,9 +37,9 @@ def forward(self, inpute, inputo, mask=None): # beam_size: the beam size for beam search # max_len: maximum length to generate - def decode(self, inpute, beam_size=1, 
max_len=None, length_penalty=0.0): + def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): - mask = inpute.eq(0).unsqueeze(1) + mask = inpute.eq(pad_id).unsqueeze(1) _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len diff --git a/transformer/MuLang/Eff/Base/Decoder.py b/transformer/MuLang/Eff/Base/Decoder.py index 4c61171..731b79e 100644 --- a/transformer/MuLang/Eff/Base/Decoder.py +++ b/transformer/MuLang/Eff/Base/Decoder.py @@ -1,32 +1,34 @@ #encoding: utf-8 import torch -from torch import nn -from modules.mulang.eff.base import LayerNorm, MBLinear, ResSelfAttn, ResCrossAttn, PositionwiseFF -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ from math import sqrt +from torch import nn -from cnfg.vocab.base import pad_id - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from modules.mulang.eff.base import LayerNorm, MBLinear, PositionwiseFF, ResCrossAttn, ResSelfAttn +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.sampler import SampleMax +from utils.torch.comp import all_done, torch_no_grad from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, ntask=None, k_rel_pos=use_k_relative_position_decoder, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, ntask=None, k_rel_pos=use_k_relative_position_decoder, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=k_rel_pos, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=k_rel_pos, **kwargs) self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.self_attn.norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True, ntask=ntask) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual, ntask=ntask) - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=self.ff.norm_residual, ntask=ntask) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=self.ff.norm_residual, ntask=ntask) - def forward(self, inpute, inputo, taskid=None, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, taskid=None, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: context = self.self_attn(inputo, taskid=taskid, mask=tgt_pad_mask) @@ -44,12 +46,12 @@ def forward(self, inpute, inputo, taskid=None, src_pad_mask=None, tgt_pad_mask=N class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, ntask=None, 
task_emb_w=None, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, ntask=None, task_emb_w=None, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=None, share_layer=share_layer, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=None, share_layer=share_layer, **kwargs) self.task_emb = nn.Embedding(ntask, isize, padding_idx=None) if task_emb_w is not None: @@ -58,17 +60,17 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. if bindemb: self.classifier.weight = self.wemb.weight if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, ntask=ntask) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize, ntask=ntask) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, ntask=ntask) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize, ntask=ntask) for i in range(num_layer)]) self.out_normer = LayerNorm(isize, ntask=ntask, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None if forbidden_index is not None: self.fbl = [tuple(set(fblu)) for fblu in forbidden_index] - def forward(self, inpute, inputo, taskid=None, src_pad_mask=None): + def forward(self, inpute, inputo, taskid=None, src_pad_mask=None, **kwargs): nquery = inputo.size(-1) @@ -91,18 +93,18 @@ def forward(self, inpute, inputo, taskid=None, src_pad_mask=None): return out - def decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + def decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, **kwargs): - return self.beam_decode(inpute, taskid, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, taskid, src_pad_mask, max_len, fill_pad=fill_pad) + return self.beam_decode(inpute, taskid, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, **kwargs) if beam_size > 1 else self.greedy_decode(inpute, taskid, src_pad_mask, max_len, fill_pad=fill_pad, **kwargs) - def greedy_decode(self, inpute, taskid=None, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, taskid=None, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) out = self.get_sos_emb(inpute) _task_emb = self.task_emb.weight[taskid] - out = sos_emb + _task_emb + out = out + _task_emb if self.pemb is not None: sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) @@ -122,7 +124,7 @@ def 
greedy_decode(self, inpute, taskid=None, src_pad_mask=None, max_len=512, fil wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -144,13 +146,13 @@ def greedy_decode(self, inpute, taskid=None, src_pad_mask=None, max_len=512, fil trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -159,16 +161,15 @@ def beam_decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=8, max_l real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) _task_emb = self.task_emb.weight[taskid] if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) lpv_base = 6.0 ** length_penalty - out = sos_emb + _task_emb + out = out + _task_emb if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: out = self.drop(out) @@ -192,7 +193,7 @@ def beam_decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=8, max_l _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) self.repeat_cross_attn_buffer(beam_size) @@ -239,7 +240,7 @@ def beam_decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=8, max_l trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: @@ -268,17 +269,20 @@ def beam_decode(self, inpute, taskid=None, src_pad_mask=None, beam_size=8, max_l def fix_load(self): if self.fbl is not None: - with torch.no_grad(): + with torch_no_grad(): for ind, fblu in enumerate(self.fbl): self.classifier.bias[ind].index_fill_(0, torch.as_tensor(fblu, dtype=torch.long, device=self.classifier.bias.device), -inf_default) - def update_vocab(self, indices): + def update_vocab(self, indices, wemb_weight=None): - _nwd = len(indices) + _nwd = indices.numel() _wemb = nn.Embedding(_nwd, self.wemb.weight.size(-1), padding_idx=self.wemb.padding_idx) _classifier = MBLinear(self.classifier.weight.size(-1), _nwd, self.classifier.bias.size(0)) - with torch.no_grad(): - _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) + with torch_no_grad(): + if wemb_weight is None: + _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) + else: + _wemb.weight = wemb_weight if self.classifier.weight.is_set_to(self.wemb.weight): _classifier.weight = _wemb.weight 
else: diff --git a/transformer/MuLang/Eff/Base/Encoder.py b/transformer/MuLang/Eff/Base/Encoder.py index 3f1baa1..cf3c1cf 100644 --- a/transformer/MuLang/Eff/Base/Encoder.py +++ b/transformer/MuLang/Eff/Base/Encoder.py @@ -1,26 +1,27 @@ #encoding: utf-8 -from torch import nn -from modules.mulang.eff.base import LayerNorm, MWLinear, ResSelfAttn, PositionwiseFF from math import sqrt +from torch import nn -from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from modules.mulang.eff.base import LayerNorm, MWLinear, PositionwiseFF, ResSelfAttn +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class EncoderLayer(EncoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, ntask=None, k_rel_pos=use_k_relative_position_encoder, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, ntask=None, k_rel_pos=use_k_relative_position_encoder, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.attn.norm_residual, k_rel_pos=k_rel_pos, ntask=ntask) - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=self.ff.norm_residual, ntask=ntask) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=self.ff.norm_residual, ntask=ntask) - def forward(self, inputs, taskid=None, mask=None): + def forward(self, inputs, taskid=None, mask=None, **kwargs): context = self.attn(inputs, taskid=taskid, mask=mask) @@ -30,30 +31,29 @@ def forward(self, inputs, taskid=None, mask=None): class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, ntask=None, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, ntask=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) self.task_emb = nn.Embedding(ntask, isize, padding_idx=None) self.transo = MWLinear(isize, isize, ntask, bias=enable_proj_bias_default) if share_layer: - _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, ntask=ntask) + _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, 
num_head, _ahsize, ntask=ntask) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize, ntask=ntask) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize, ntask=ntask) for i in range(num_layer)]) self.out_normer = LayerNorm(isize, ntask=ntask, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None - def forward(self, inputs, taskid=None, mask=None): + def forward(self, inputs, taskid=None, mask=None, **kwargs): - out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) + self.task_emb.weight[taskid] + out = self.wemb(inputs) + self.task_emb.weight[taskid] if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) if self.drop is not None: out = self.drop(out) diff --git a/transformer/MuLang/Eff/Base/NMT.py b/transformer/MuLang/Eff/Base/NMT.py index a26ff3c..710f9a2 100644 --- a/transformer/MuLang/Eff/Base/NMT.py +++ b/transformer/MuLang/Eff/Base/NMT.py @@ -1,24 +1,23 @@ #encoding: utf-8 -from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple - -from transformer.NMT import NMT as NMTBase - -from transformer.MuLang.Eff.Base.Encoder import Encoder from transformer.MuLang.Eff.Base.Decoder import Decoder +from transformer.MuLang.Eff.Base.Encoder import Encoder +from transformer.NMT import NMT as NMTBase +from utils.fmt.parser import parse_double_value_tuple +from utils.relpos.base import share_rel_pos_cache from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, ntask=None, **kwargs): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, ntask=None, **kwargs): enc_layer, dec_layer = parse_double_value_tuple(num_layer) - super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=None, **kwargs) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=None, **kwargs) - self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, ntask=ntask) + self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, ntask=ntask) if global_emb: emb_w = self.enc.wemb.weight @@ -26,20 +25,20 @@ def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_ else: emb_w = task_emb_w = None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, 
dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index, ntask=ntask, task_emb_w=task_emb_w) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index, ntask=ntask, task_emb_w=task_emb_w) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, inputo, taskid=None, mask=None): + def forward(self, inpute, inputo, taskid=None, mask=None, **kwargs): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask return self.dec(self.enc(inpute, taskid=taskid, mask=_mask), inputo, taskid=taskid, src_pad_mask=_mask) - def decode(self, inpute, taskid=None, beam_size=1, max_len=None, length_penalty=0.0): + def decode(self, inpute, taskid=None, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): - mask = inpute.eq(0).unsqueeze(1) + mask = inpute.eq(pad_id).unsqueeze(1) _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len return self.dec.decode(self.enc(inpute, taskid=taskid, mask=mask), taskid=taskid, src_pad_mask=mask, beam_size=beam_size, max_len=_max_len, length_penalty=length_penalty) diff --git a/transformer/NMT.py b/transformer/NMT.py index f34c811..e67df3c 100644 --- a/transformer/NMT.py +++ b/transformer/NMT.py @@ -3,17 +3,17 @@ import torch from torch import nn -from utils.base import all_done, select_zero_ -from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple - from transformer.Encoder import Encoder - # switch the comment between the following two lines to choose standard decoder or average decoder. Using transformer.TA.Decoder for Transparent Decoder. 
from transformer.Decoder import Decoder #from transformer.AvgDecoder import Decoder +from utils.base import select_zero_ +from utils.fmt.parser import parse_double_value_tuple +from utils.relpos.base import share_rel_pos_cache +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class NMT(nn.Module): @@ -28,18 +28,18 @@ class NMT(nn.Module): # xseql: maxmimum length of sequence # ahsize: number of hidden units for MultiHeadAttention - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, **kwargs): super(NMT, self).__init__() enc_layer, dec_layer = parse_double_value_tuple(num_layer) - self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output) + self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output) emb_w = self.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index) - #self.dec = Decoder(isize, tnwd, dec_layer, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index)# for RNMT + self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index) + #self.dec = Decoder(isize, tnwd, dec_layer, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index)# for RNMT if rel_pos_enabled: share_rel_pos_cache(self) @@ -47,11 +47,11 @@ def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_ # inpute: source sentences from encoder (bsize, seql) # inputo: decoded translation (bsize, nquery) # mask: user specified mask, otherwise it will be: - # inpute.eq(0).unsqueeze(1) + # inpute.eq(pad_id).unsqueeze(1) - def forward(self, inpute, inputo, mask=None): + def forward(self, inpute, inputo, mask=None, **kwargs): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask return self.dec(self.enc(inpute, _mask), inputo, _mask) @@ -59,9 +59,9 @@ def forward(self, inpute, inputo, mask=None): # beam_size: the beam size for beam search # max_len: maximum length to generate - def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0): + def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): - mask = inpute.eq(0).unsqueeze(1) + mask = inpute.eq(pad_id).unsqueeze(1) _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len 
is None else max_len @@ -80,18 +80,25 @@ def load_base(self, base_nmt): def update_vocab(self, src_indices=None, tgt_indices=None): - if (src_indices is not None) and hasattr(self.enc, "update_vocab"): - self.enc.update_vocab(src_indices) - if (tgt_indices is not None) and hasattr(self.dec, "update_vocab"): - self.dec.update_vocab(tgt_indices) + _share_emb, _sembw = False, None + _update_src, _update_tgt = src_indices is not None, tgt_indices is not None + if _update_src and _update_tgt and src_indices.equal(tgt_indices) and hasattr(self.enc, "get_embedding_weight") and hasattr(self.dec, "get_embedding_weight"): + _share_emb = self.enc.get_embedding_weight().is_set_to(self.dec.get_embedding_weight()) + if _update_src and hasattr(self.enc, "update_vocab"): + _ = self.enc.update_vocab(src_indices) + if _share_emb: + _sembw = _ + if _update_tgt and hasattr(self.dec, "update_vocab"): + self.dec.update_vocab(tgt_indices, wemb_weight=_sembw) def update_classifier(self, *args, **kwargs): - self.dec.update_classifier(*args, **kwargs) + if hasattr(self.dec, "update_classifier"): + self.dec.update_classifier(*args, **kwargs) def train_decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, mask=None): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len @@ -120,7 +127,7 @@ def train_greedy_decode(self, inpute, mask=None, max_len=512): out = torch.cat((out, wds), -1) # done_trans: (bsize) - done_trans = wds.squeeze(1).eq(2) if done_trans is None else (done_trans | wds.squeeze(1).eq(2)) + done_trans = wds.squeeze(1).eq(eos_id) if done_trans is None else (done_trans | wds.squeeze(1).eq(eos_id)) if all_done(done_trans, bsize): break @@ -190,7 +197,7 @@ def train_beam_decode(self, inpute, mask=None, beam_size=8, max_len=512, length_ out = torch.cat((out.index_select(0, _inds), wds), -1) # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) if done_trans is None else (done_trans.view(real_bsize).index_select(0, _inds) | wds.view(real_bsize).eq(2)).view(bsize, beam_size) + done_trans = wds.view(bsize, beam_size).eq(eos_id) if done_trans is None else (done_trans.view(real_bsize).index_select(0, _inds) | wds.view(real_bsize).eq(eos_id)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) diff --git a/transformer/PLM/BART/Decoder.py b/transformer/PLM/BART/Decoder.py index 73664ab..083bd78 100644 --- a/transformer/PLM/BART/Decoder.py +++ b/transformer/PLM/BART/Decoder.py @@ -1,105 +1,61 @@ #encoding: utf-8 import torch -from torch import nn -from modules.dropout import Dropout -from modules.TA import ResSelfAttn, ResCrossAttn, PositionwiseFF -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_, mask_tensor_type -from utils.plm.base import copy_plm_parameter -from cnfg.vocab.plm.roberta import pad_id, eos_id, pemb_start_ind from math import sqrt +from torch import nn -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from modules.TA import PositionwiseFF, ResCrossAttn, ResSelfAttn +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.plm.bart import load_plm_decoder_layer 
+from utils.plm.base import copy_plm_parameter +from utils.sampler import SampleMax +from utils.torch.comp import all_done, torch_no_grad from cnfg.plm.bart.base import remove_classifier_bias from cnfg.plm.bart.ihyp import * +from cnfg.vocab.plm.roberta import eos_id, pad_id, pemb_start_ind class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, model_name="decoder", **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, model_name="decoder", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) self.model_name = model_name - self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default, k_rel_pos=k_rel_pos, uni_direction_reduction=True, max_bucket_distance=max_bucket_distance, xseql=cache_len_default) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, enable_bias=enable_prev_ln_bias_default, enable_proj_bias=enable_proj_bias_default) - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) - - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): - - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): - copy_plm_parameter(self.self_attn.net.adaptor.weight, plm_parameters, ["%s.layers.%d.self_attn.q_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.k_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.v_proj.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) - _bias_key = "%s.layers.%d.self_attn.q_proj.bias" % (_model_name, layer_idx,) - if self.self_attn.net.adaptor.bias is None and (_bias_key in plm_parameters): - self.self_attn.net.adaptor.bias = nn.Parameter(torch.zeros(self.self_attn.net.adaptor.weight.size(0))) - if self.self_attn.net.adaptor.bias is not None: - copy_plm_parameter(self.self_attn.net.adaptor.bias, plm_parameters, [_bias_key, "%s.layers.%d.self_attn.k_proj.bias" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.v_proj.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) - copy_plm_parameter(self.self_attn.net.outer.weight, plm_parameters, "%s.layers.%d.self_attn.out_proj.weight" % (_model_name, layer_idx,)) - _bias_key = "%s.layers.%d.self_attn.out_proj.bias" % (_model_name, layer_idx,) - if 
self.self_attn.net.outer.bias is None and (_bias_key in plm_parameters): - self.self_attn.net.outer.bias = nn.Parameter(torch.zeros(self.self_attn.net.outer.weight.size(0))) - if self.self_attn.net.outer.bias is not None: - copy_plm_parameter(self.self_attn.net.outer.bias, plm_parameters, _bias_key) - copy_plm_parameter(self.self_attn.normer.weight, plm_parameters, "%s.layers.%d.self_attn_layer_norm.weight" % (_model_name, layer_idx,)) - copy_plm_parameter(self.self_attn.normer.bias, plm_parameters, "%s.layers.%d.self_attn_layer_norm.bias" % (_model_name, layer_idx,)) - copy_plm_parameter(self.cross_attn.net.query_adaptor.weight, plm_parameters, "%s.layers.%d.encoder_attn.q_proj.weight" % (_model_name, layer_idx,)) - _bias_key = "%s.layers.%d.encoder_attn.q_proj.bias" % (_model_name, layer_idx,) - if self.cross_attn.net.query_adaptor.bias is None and (_bias_key in plm_parameters): - self.cross_attn.net.query_adaptor.bias = nn.Parameter(torch.zeros(self.cross_attn.net.query_adaptor.weight.size(0))) - if self.cross_attn.net.query_adaptor.bias is not None: - copy_plm_parameter(self.cross_attn.net.query_adaptor.bias, plm_parameters, _bias_key) - copy_plm_parameter(self.cross_attn.net.kv_adaptor.weight, plm_parameters, ["%s.layers.%d.encoder_attn.k_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.encoder_attn.v_proj.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) - _bias_key = "%s.layers.%d.encoder_attn.k_proj.bias" % (_model_name, layer_idx,) - if self.cross_attn.net.kv_adaptor.bias is None and (_bias_key in plm_parameters): - self.cross_attn.net.kv_adaptor.bias = nn.Parameter(torch.zeros(self.cross_attn.net.kv_adaptor.weight.size(0))) - if self.cross_attn.net.kv_adaptor.bias is not None: - copy_plm_parameter(self.cross_attn.net.kv_adaptor.bias, plm_parameters, [_bias_key, "%s.layers.%d.encoder_attn.v_proj.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) - copy_plm_parameter(self.cross_attn.net.outer.weight, plm_parameters, "%s.layers.%d.encoder_attn.out_proj.weight" % (_model_name, layer_idx,)) - _bias_key = "%s.layers.%d.encoder_attn.out_proj.bias" % (_model_name, layer_idx,) - if self.cross_attn.net.outer.bias is None and (_bias_key in plm_parameters): - self.cross_attn.net.outer.bias = nn.Parameter(torch.zeros(self.cross_attn.net.outer.weight.size(0))) - if self.cross_attn.net.outer.bias is not None: - copy_plm_parameter(self.cross_attn.net.outer.bias, plm_parameters, _bias_key) - copy_plm_parameter(self.cross_attn.normer.weight, plm_parameters, "%s.layers.%d.encoder_attn_layer_norm.weight" % (_model_name, layer_idx,)) - copy_plm_parameter(self.cross_attn.normer.bias, plm_parameters, "%s.layers.%d.encoder_attn_layer_norm.bias" % (_model_name, layer_idx,)) - copy_plm_parameter(self.ff.net[0].weight, plm_parameters, "%s.layers.%d.fc1.weight" % (_model_name, layer_idx,)) - copy_plm_parameter(self.ff.net[0].bias, plm_parameters, "%s.layers.%d.fc1.bias" % (_model_name, layer_idx,)) - _l = self.ff.net[-2] if isinstance(self.ff.net[-1], Dropout) else self.ff.net[-1] - copy_plm_parameter(_l.weight, plm_parameters, "%s.layers.%d.fc2.weight" % (_model_name, layer_idx,)) - _bias_key = "%s.layers.%d.fc2.bias" % (_model_name, layer_idx,) - if _l.bias is None and (_bias_key in plm_parameters): - _l.bias = nn.Parameter(torch.zeros(_l.weight.size(0))) - if _l.bias is not None: - copy_plm_parameter(_l.bias, plm_parameters, _bias_key) - copy_plm_parameter(self.ff.normer.weight, plm_parameters, "%s.layers.%d.final_layer_norm.weight" % 
(_model_name, layer_idx,)) - copy_plm_parameter(self.ff.normer.bias, plm_parameters, "%s.layers.%d.final_layer_norm.bias" % (_model_name, layer_idx,)) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) + + def load_plm(self, plm_parameters, model_name=None, layer_idx=None, **kwargs): + + load_plm_decoder_layer(self, plm_parameters, model_name=model_name, layer_idx=layer_idx, **kwargs) class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, model_name="decoder", **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, model_name="decoder", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) self.model_name = model_name self.wemb.padding_idx = pad_id self.pemb = None if disable_pemb else nn.Parameter(torch.Tensor(xseql, isize).uniform_(- sqrt(2.0 / (isize + xseql)), sqrt(2.0 / (isize + xseql)))) if share_layer: - _shared_layer = DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) + _shared_layer = DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) for i in range(num_layer)]) - def forward(self, inpute, inputo, src_pad_mask=None, word_prediction=False): + def forward(self, inpute, inputo, src_pad_mask=None, word_prediction=False, **kwargs): nquery = inputo.size(-1) @@ -121,7 +77,7 @@ def forward(self, inpute, inputo, src_pad_mask=None, word_prediction=False): return out - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -170,7 +126,7 @@ def greedy_decode(self, inpute, 
src_pad_mask=None, max_len=512, fill_pad=False, return torch.cat(trans, 1) - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -207,7 +163,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) self.repeat_cross_attn_buffer(beam_size) @@ -253,7 +209,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: @@ -289,20 +245,27 @@ def get_sos_emb(self, inpute, bsize=None): def fix_init(self): self.fix_load() - with torch.no_grad(): + with torch_no_grad(): #self.wemb.weight[pad_id].zero_() self.classifier.weight[pad_id].zero_() + if self.pemb is not None: + _ = sqrt(2.0 / sum(self.pemb.size())) + self.pemb.uniform_(- _, _) - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): copy_plm_parameter(self.wemb.weight, plm_parameters, "%s.embed_tokens.weight" % _model_name) copy_plm_parameter(self.pemb, plm_parameters, "%s.embed_positions.weight" % _model_name) copy_plm_parameter(self.out_normer.weight, plm_parameters, "%s.layernorm_embedding.weight" % _model_name) copy_plm_parameter(self.out_normer.bias, plm_parameters, "%s.layernorm_embedding.bias" % _model_name) + if (not remove_classifier_bias) and ("final_logits_bias" in plm_parameters): + if self.classifier.bias is None: + self.classifier.bias = nn.Parameter(torch.zeros(self.classifier.weight.size(0))) + copy_plm_parameter(self.classifier.bias, plm_parameters, "final_logits_bias") for i, net in enumerate(self.nets): - net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i) + net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i, **kwargs) # BART does NOT have the bias vector in the classifier if remove_classifier_bias: self.classifier.bias = None diff --git a/transformer/PLM/BART/Encoder.py b/transformer/PLM/BART/Encoder.py index ee09636..a2f0607 100644 --- a/transformer/PLM/BART/Encoder.py +++ b/transformer/PLM/BART/Encoder.py @@ -1,92 +1,85 @@ #encoding: utf-8 import torch +from math import sqrt from torch import nn -from modules.dropout import Dropout + from modules.TA import PositionwiseFF +from transformer.Encoder import Encoder as EncoderBase +from transformer.PLM.BERT.Encoder import EncoderLayer as EncoderLayerBase +from utils.fmt.parser import 
parse_none +from utils.plm.bart import load_plm_encoder_layer from utils.plm.base import copy_plm_parameter -from cnfg.vocab.plm.roberta import pad_id, pemb_start_ind -from cnfg.plm.bart.ihyp import * +from utils.torch.comp import torch_no_grad -from transformer.PLM.BERT.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from cnfg.plm.bart.ihyp import * +from cnfg.vocab.plm.roberta import pad_id, pemb_start_ind class EncoderLayer(EncoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, model_name="encoder", **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, model_name="encoder", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, model_name=model_name, **kwargs) - - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) - - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): - - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): - copy_plm_parameter(self.attn.net.adaptor.weight, plm_parameters, ["%s.layers.%d.self_attn.q_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.k_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.v_proj.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) - _bias_key = "%s.layers.%d.self_attn.q_proj.bias" % (_model_name, layer_idx,) - if self.attn.net.adaptor.bias is None and (_bias_key in plm_parameters): - self.attn.net.adaptor.bias = nn.Parameter(torch.zeros(self.attn.net.adaptor.weight.size(0))) - if self.attn.net.adaptor.bias is not None: - copy_plm_parameter(self.attn.net.adaptor.bias, plm_parameters, [_bias_key, "%s.layers.%d.self_attn.k_proj.bias" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.v_proj.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) - copy_plm_parameter(self.attn.net.outer.weight, plm_parameters, "%s.layers.%d.self_attn.out_proj.weight" % (_model_name, layer_idx,)) - _bias_key = "%s.layers.%d.self_attn.out_proj.bias" % (_model_name, layer_idx,) - if self.attn.net.outer.bias is None and (_bias_key in plm_parameters): - self.attn.net.outer.bias = nn.Parameter(torch.zeros(self.attn.net.outer.weight.size(0))) - if self.attn.net.outer.bias is not None: - copy_plm_parameter(self.attn.net.outer.bias, plm_parameters, _bias_key) - copy_plm_parameter(self.attn.normer.weight, plm_parameters, "%s.layers.%d.self_attn_layer_norm.weight" % (_model_name, layer_idx,)) - copy_plm_parameter(self.attn.normer.bias, plm_parameters, "%s.layers.%d.self_attn_layer_norm.bias" % (_model_name, layer_idx,)) - copy_plm_parameter(self.ff.net[0].weight, plm_parameters, "%s.layers.%d.fc1.weight" % (_model_name, layer_idx,)) - 
copy_plm_parameter(self.ff.net[0].bias, plm_parameters, "%s.layers.%d.fc1.bias" % (_model_name, layer_idx,)) - _l = self.ff.net[-2] if isinstance(self.ff.net[-1], Dropout) else self.ff.net[-1] - copy_plm_parameter(_l.weight, plm_parameters, "%s.layers.%d.fc2.weight" % (_model_name, layer_idx,)) - _bias_key = "%s.layers.%d.fc2.bias" % (_model_name, layer_idx,) - if _l.bias is None and (_bias_key in plm_parameters): - _l.bias = nn.Parameter(torch.zeros(_l.weight.size(0))) - if _l.bias is not None: - copy_plm_parameter(_l.bias, plm_parameters, _bias_key) - copy_plm_parameter(self.ff.normer.weight, plm_parameters, "%s.layers.%d.final_layer_norm.weight" % (_model_name, layer_idx,)) - copy_plm_parameter(self.ff.normer.bias, plm_parameters, "%s.layers.%d.final_layer_norm.bias" % (_model_name, layer_idx,)) + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, model_name=model_name, **kwargs) + + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) + + def load_plm(self, plm_parameters, model_name=None, layer_idx=None, **kwargs): + + load_plm_encoder_layer(self, plm_parameters, model_name=model_name, layer_idx=layer_idx, **kwargs) class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, share_layer=False, model_name="encoder", **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, share_layer=False, disable_pemb=disable_std_pemb_encoder, model_name="encoder", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, share_layer=share_layer, model_name=model_name, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) + self.model_name = model_name + self.pemb = None if disable_pemb else nn.Parameter(torch.Tensor(xseql, isize).uniform_(- sqrt(2.0 / (isize + xseql)), sqrt(2.0 / (isize + xseql)))) self.wemb.padding_idx = pad_id - self.temb = None if share_layer: - _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) + _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) for i in range(num_layer)]) + 
self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) for i in range(num_layer)]) def forward(self, inputs, mask=None, **kwargs): seql = inputs.size(1) - out = self.drop(self.out_normer(self.pemb.narrow(0, pemb_start_ind, seql) + self.wemb(inputs))) + out = self.wemb(inputs) + if self.pemb is not None: + out = out + self.pemb.narrow(0, pemb_start_ind, seql) + if self.out_normer is not None: + out = self.out_normer(out) + if self.drop is not None: + out = self.drop(out) - _mask = inputs.eq(pad_id).unsqueeze(1) if mask is None else mask for net in self.nets: - out = net(out, _mask) + out = net(out, mask) return out - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): copy_plm_parameter(self.wemb.weight, plm_parameters, "%s.embed_tokens.weight" % _model_name) copy_plm_parameter(self.pemb, plm_parameters, "%s.embed_positions.weight" % _model_name) copy_plm_parameter(self.out_normer.weight, plm_parameters, "%s.layernorm_embedding.weight" % _model_name) copy_plm_parameter(self.out_normer.bias, plm_parameters, "%s.layernorm_embedding.bias" % _model_name) for i, net in enumerate(self.nets): - net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i) + net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i, **kwargs) + + def fix_init(self): + + super(Encoder, self).fix_init() + if self.pemb is not None: + with torch_no_grad(): + _ = sqrt(2.0 / sum(self.pemb.size())) + self.pemb.uniform_(- _, _) diff --git a/transformer/PLM/BART/NMT.py b/transformer/PLM/BART/NMT.py index 65f22c0..85afdf4 100644 --- a/transformer/PLM/BART/NMT.py +++ b/transformer/PLM/BART/NMT.py @@ -1,45 +1,45 @@ #encoding: utf-8 +from transformer.PLM.BART.Decoder import Decoder +from transformer.PLM.BART.Encoder import Encoder +from transformer.PLM.NMT import NMT as NMTBase +from utils.fmt.parser import parse_double_value_tuple, parse_none from utils.plm.base import set_ln_ieps from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple -from cnfg.vocab.plm.roberta import pad_id - -from transformer.PLM.BART.Encoder import Encoder -from transformer.PLM.BART.Decoder import Decoder -from transformer.PLM.BERT.NMT import NMT as NMTBase from cnfg.plm.bart.ihyp import * +from cnfg.vocab.plm.roberta import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name=("encoder", "decoder",)): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name=("encoder", "decoder",), **kwargs): enc_layer, dec_layer = parse_double_value_tuple(num_layer) - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, 
num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=model_name) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=model_name, **kwargs) - enc_model_name, dec_model_name = parse_double_value_tuple(model_name) - self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) + self.model_name = model_name + enc_model_name, dec_model_name = parse_double_value_tuple(self.model_name) + self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) emb_w = self.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=dec_model_name) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=dec_model_name) set_ln_ieps(self, ieps_ln_default) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, inputo, mask=None, word_prediction=False): + def forward(self, inpute, inputo, mask=None, word_prediction=False, **kwargs): _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask return self.dec(self.enc(inpute, _mask), inputo, _mask, word_prediction=word_prediction) - def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0): + def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): mask = inpute.eq(pad_id).unsqueeze(1) _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len diff --git a/transformer/PLM/BERT/Decoder.py b/transformer/PLM/BERT/Decoder.py index 28e51bf..e8383fc 100644 --- a/transformer/PLM/BERT/Decoder.py +++ b/transformer/PLM/BERT/Decoder.py @@ -2,23 +2,25 @@ import torch from torch import nn -from modules.base import Linear + from modules.act import Custom_Act, GELU +from modules.base import Linear from modules.dropout import Dropout - +from utils.fmt.parser import parse_none from utils.plm.base import copy_plm_parameter -#from cnfg.vocab.plm.bert import pad_id +from utils.torch.comp import torch_no_grad from cnfg.plm.bert.ihyp import * +#from cnfg.vocab.plm.bert import pad_id class Decoder(nn.Module): - def __init__(self, isize, nwd, num_layer=None, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, model_name="bert", **kwargs): + def __init__(self, isize, nwd, num_layer=None, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, model_name="bert", **kwargs): super(Decoder, self).__init__() + self.model_name = model_name self.drop = Dropout(dropout, inplace=True) if dropout > 0.0 else None - self.ff = nn.Sequential(Linear(isize, isize), Custom_Act() if use_adv_act_default else 
GELU(), nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters)) self.classifier = Linear(isize, nwd) self.lsm = nn.LogSoftmax(-1) @@ -27,21 +29,21 @@ def __init__(self, isize, nwd, num_layer=None, fhsize=None, dropout=0.0, attn_dr self.rel_classifier = Linear(isize, 2) self.pooler = nn.Sequential(Linear(isize, isize), nn.Tanh()) - def forward(self, inpute, *args, word_prediction=False, **kwargs): + def forward(self, inpute, *args, mlm_mask=None, word_prediction=False, **kwargs): - out = self.ff(inpute) + out = self.ff(inpute if mlm_mask is None else inpute[mlm_mask]) if word_prediction: out = self.lsm(self.classifier(out)) return out - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): copy_plm_parameter(self.ff[0].weight, plm_parameters, "cls.predictions.transform.dense.weight") _bias_key = "cls.predictions.transform.dense.bias" - if self.ff[0].bias is None and (_bias_key in plm_parameters): + if (self.ff[0].bias is None) and (_bias_key in plm_parameters): self.ff[0].bias = nn.Parameter(torch.zeros(self.ff[0].weight.size(0))) if self.ff[0].bias is not None: copy_plm_parameter(self.ff[0].bias, plm_parameters, _bias_key) @@ -54,11 +56,28 @@ def load_plm(self, plm_parameters, model_name=None, layer_idx=None): copy_plm_parameter(self.pooler[0].weight, plm_parameters, "%s.pooler.dense.weight" % _model_name) copy_plm_parameter(self.pooler[0].bias, plm_parameters, "%s.pooler.dense.bias" % _model_name) + def get_embedding_weight(self): + + return self.classifier.weight + + def update_vocab(self, indices, wemb_weight=None): + + _nwd = indices.numel() + _classifier = Linear(self.classifier.weight.size(-1), _nwd, bias=self.classifier.bias is not None) + with torch_no_grad(): + if wemb_weight is None: + _classifier.weight.copy_(self.classifier.weight.index_select(0, indices)) + else: + _classifier.weight = wemb_weight + if self.classifier.bias is not None: + _classifier.bias.copy_(self.classifier.bias.index_select(0, indices)) + self.classifier = _classifier + def update_classifier(self, indices): _nwd = indices.numel() _classifier = Linear(self.classifier.weight.size(-1), _nwd, bias=self.classifier.bias is not None) - with torch.no_grad(): + with torch_no_grad(): _classifier.weight.copy_(self.classifier.weight.index_select(0, indices)) if self.classifier.bias is not None: _classifier.bias.copy_(self.classifier.bias.index_select(0, indices)) diff --git a/transformer/PLM/BERT/Encoder.py b/transformer/PLM/BERT/Encoder.py index f8517e6..0f8f6ae 100644 --- a/transformer/PLM/BERT/Encoder.py +++ b/transformer/PLM/BERT/Encoder.py @@ -1,44 +1,41 @@ #encoding: utf-8 import torch -from torch import nn -from modules.dropout import Dropout -from modules.plm.bert import PositionwiseFF from math import sqrt +from torch import nn +from modules.dropout import Dropout +from transformer.Encoder import Encoder as EncoderBase +from transformer.TA.Encoder import EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none from utils.plm.base import copy_plm_parameter -from cnfg.vocab.plm.bert import pad_id +from utils.torch.comp import torch_no_grad + from cnfg.plm.bert.base import num_type from cnfg.plm.bert.ihyp import * - -from transformer.TA.Encoder import EncoderLayer as EncoderLayerBase -from transformer.Encoder import 
Encoder as EncoderBase +from cnfg.vocab.plm.bert import pad_id class EncoderLayer(EncoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, model_name="bert", **kwargs): - - _ahsize = isize if ahsize is None else ahsize - _fhsize = _ahsize * 4 if fhsize is None else fhsize + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, model_name="bert", **kwargs): - super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + super(EncoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) self.model_name = model_name - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, layer_idx=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): copy_plm_parameter(self.attn.net.adaptor.weight, plm_parameters, ["%s.encoder.layer.%d.attention.self.query.weight" % (_model_name, layer_idx,), "%s.encoder.layer.%d.attention.self.key.weight" % (_model_name, layer_idx,), "%s.encoder.layer.%d.attention.self.value.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) _bias_key = "%s.encoder.layer.%d.attention.self.query.bias" % (_model_name, layer_idx,) - if self.attn.net.adaptor.bias is None and (_bias_key in plm_parameters): + if (self.attn.net.adaptor.bias is None) and (_bias_key in plm_parameters): self.attn.net.adaptor.bias = nn.Parameter(torch.zeros(self.attn.net.adaptor.weight.size(0))) if self.attn.net.adaptor.bias is not None: copy_plm_parameter(self.attn.net.adaptor.bias, plm_parameters, [_bias_key, "%s.encoder.layer.%d.attention.self.key.bias" % (_model_name, layer_idx,), "%s.encoder.layer.%d.attention.self.value.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) copy_plm_parameter(self.attn.net.outer.weight, plm_parameters, "%s.encoder.layer.%d.attention.output.dense.weight" % (_model_name, layer_idx,)) _bias_key = "%s.encoder.layer.%d.attention.output.dense.bias" % (_model_name, layer_idx,) - if self.attn.net.outer.bias is None and (_bias_key in plm_parameters): + if (self.attn.net.outer.bias is None) and (_bias_key in plm_parameters): self.attn.net.outer.bias = nn.Parameter(torch.zeros(self.attn.net.outer.weight.size(0))) if self.attn.net.outer.bias is not None: copy_plm_parameter(self.attn.net.outer.bias, plm_parameters, _bias_key) @@ -49,7 +46,7 @@ def load_plm(self, plm_parameters, model_name=None, layer_idx=None): _l = self.ff.net[-2] if isinstance(self.ff.net[-1], Dropout) else self.ff.net[-1] 
copy_plm_parameter(_l.weight, plm_parameters, "%s.encoder.layer.%d.output.dense.weight" % (_model_name, layer_idx,)) _bias_key = "%s.encoder.layer.%d.output.dense.bias" % (_model_name, layer_idx,) - if _l.bias is None and (_bias_key in plm_parameters): + if (_l.bias is None) and (_bias_key in plm_parameters): _l.bias = nn.Parameter(torch.zeros(_l.weight.size(0))) if _l.bias is not None: copy_plm_parameter(_l.bias, plm_parameters, _bias_key) @@ -58,50 +55,59 @@ def load_plm(self, plm_parameters, model_name=None, layer_idx=None): class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, num_type=num_type, share_layer=False, model_name="bert", **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, num_type=num_type, share_layer=False, disable_pemb=disable_std_pemb_encoder, model_name="bert", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) self.model_name = model_name - self.pemb = nn.Parameter(torch.Tensor(xseql, isize).uniform_(- sqrt(2.0 / (isize + xseql)), sqrt(2.0 / (isize + xseql)))) + self.pemb = None if disable_pemb else nn.Parameter(torch.Tensor(xseql, isize).uniform_(- sqrt(2.0 / (isize + xseql)), sqrt(2.0 / (isize + xseql)))) self.temb = nn.Embedding(num_type, isize) self.wemb.padding_idx = pad_id if share_layer: - _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) + _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) for i in range(num_layer)]) - def forward(self, inputs, token_types=None, mask=None): + def forward(self, inputs, token_types=None, mask=None, **kwargs): seql = inputs.size(1) - out = self.drop(self.out_normer(self.pemb.narrow(0, 0, seql) + (self.temb.weight[0] if token_types is None else self.temb(token_types)) + self.wemb(inputs))) + out = None if self.pemb is None else self.pemb.narrow(0, 0, seql) + if self.temb is not None: + _ = self.temb.weight[0] if token_types is None else self.temb(token_types) + out = _ if out is None else (out + _) + _ = self.wemb(inputs) + out = _ if out is None else (out + _) + if self.out_normer is not None: + out = self.out_normer(out) + if 
self.drop is not None: + out = self.drop(out) - _mask = inputs.eq(pad_id).unsqueeze(1) if mask is None else mask for net in self.nets: - out = net(out, _mask) + out = net(out, mask) return out - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): copy_plm_parameter(self.wemb.weight, plm_parameters, "%s.embeddings.word_embeddings.weight" % _model_name) copy_plm_parameter(self.pemb, plm_parameters, "%s.embeddings.position_embeddings.weight" % _model_name) copy_plm_parameter(self.temb.weight, plm_parameters, "%s.embeddings.token_type_embeddings.weight" % _model_name) copy_plm_parameter(self.out_normer.weight, plm_parameters, "%s.embeddings.LayerNorm.weight" % _model_name) copy_plm_parameter(self.out_normer.bias, plm_parameters, "%s.embeddings.LayerNorm.bias" % _model_name) for i, net in enumerate(self.nets): - net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i) + net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i, **kwargs) def fix_init(self): super(Encoder, self).fix_init() - with torch.no_grad(): - _ = sqrt(2.0 / sum(self.pemb.size())) - self.pemb.uniform_(- _, _) + if self.pemb is not None: + with torch_no_grad(): + _ = sqrt(2.0 / sum(self.pemb.size())) + self.pemb.uniform_(- _, _) diff --git a/transformer/PLM/BERT/NMT.py b/transformer/PLM/BERT/NMT.py index 450d2be..5f230cb 100644 --- a/transformer/PLM/BERT/NMT.py +++ b/transformer/PLM/BERT/NMT.py @@ -1,52 +1,40 @@ #encoding: utf-8 +from transformer.PLM.BERT.Decoder import Decoder +from transformer.PLM.BERT.Encoder import Encoder +from transformer.PLM.NMT import NMT as NMTBase +from utils.fmt.parser import parse_double_value_tuple, parse_none from utils.plm.base import set_ln_ieps from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple -from cnfg.vocab.plm.bert import pad_id - -from transformer.PLM.BERT.Encoder import Encoder -from transformer.PLM.BERT.Decoder import Decoder -from transformer.NMT import NMT as NMTBase from cnfg.plm.bert.ihyp import * +from cnfg.vocab.plm.bert import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name="bert"): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name="bert", **kwargs): enc_layer, dec_layer = parse_double_value_tuple(num_layer) - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=_ahsize, 
norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=model_name, **kwargs) self.model_name = model_name enc_model_name, dec_model_name = parse_double_value_tuple(self.model_name) - self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) + self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) emb_w = self.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, model_name=dec_model_name)#, num_layer=dec_layer, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index + self.dec = Decoder(isize, tnwd, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, model_name=dec_model_name)#, num_layer=dec_layer, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index set_ln_ieps(self, ieps_ln_default) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, token_types=None, mask=None, word_prediction=False): + def forward(self, inpute, token_types=None, mask=None, word_prediction=False, **kwargs): _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask return self.dec(self.enc(inpute, token_types=token_types, mask=_mask), word_prediction=word_prediction) - - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): - - _model_name = self.model_name if model_name is None else model_name - enc_model_name, dec_model_name = parse_double_value_tuple(_model_name) - self.enc.load_plm(plm_parameters, model_name=enc_model_name, layer_idx=layer_idx) - self.dec.load_plm(plm_parameters, model_name=dec_model_name, layer_idx=layer_idx) - - def update_classifier(self, *args, **kwargs): - - self.dec.update_classifier(*args, **kwargs) diff --git a/transformer/PLM/MBART/Decoder.py b/transformer/PLM/MBART/Decoder.py new file mode 100644 index 0000000..37e9b86 --- /dev/null +++ b/transformer/PLM/MBART/Decoder.py @@ -0,0 +1,288 @@ +#encoding: utf-8 + +import torch +from math import sqrt +from torch import nn + +from modules.plm.mbart import PositionwiseFF, ResCrossAttn, ResSelfAttn +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none +from utils.plm.bart import load_plm_decoder_layer +from utils.plm.base import copy_plm_parameter +from utils.sampler import SampleMax +from utils.torch.comp import all_done, torch_no_grad + +from cnfg.plm.mbart.base import remove_classifier_bias +from cnfg.plm.mbart.ihyp import * +from cnfg.vocab.plm.mbart import eos_id, pad_id, pemb_start_ind + +class DecoderLayer(DecoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, model_name="model.decoder", **kwargs): + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + 
super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + + self.model_name = model_name + self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True, max_bucket_distance=max_bucket_distance) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) + + def load_plm(self, plm_parameters, model_name=None, layer_idx=None, **kwargs): + + load_plm_decoder_layer(self, plm_parameters, model_name=model_name, layer_idx=layer_idx, **kwargs) + +class Decoder(DecoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, model_name="model.decoder", **kwargs): + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) + + self.model_name = model_name + self.wemb.padding_idx = pad_id + self.pemb = None if disable_pemb else nn.Parameter(torch.Tensor(xseql, isize).uniform_(- sqrt(2.0 / (isize + xseql)), sqrt(2.0 / (isize + xseql)))) + self.emb_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + + if share_layer: + _shared_layer = DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) for i in range(num_layer)]) + + def forward(self, inpute, inputo, src_pad_mask=None, word_prediction=False, **kwargs): + + nquery = inputo.size(-1) + + out = self.wemb(inputo) + if self.pemb is not None: + out = self.pemb.narrow(0, pemb_start_ind, nquery).add(out, alpha=sqrt(out.size(-1))) + out = self.emb_normer(out) + if self.drop is not None: + out = self.drop(out) + + _mask = self._get_subsequent_mask(nquery) + + for net in self.nets: + out = net(inpute, out, src_pad_mask, _mask) + + if self.out_normer is not None: + out = self.out_normer(out) + + if word_prediction: + out = self.lsm(self.classifier(out)) + + return out + + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, lang_id=None, **kwargs): + + bsize = inpute.size(0) + + out = self.get_sos_emb(inpute, lang_id=lang_id) + if self.pemb is not None: + sqrt_isize = sqrt(out.size(-1)) + out = self.pemb[pemb_start_ind].add(out, alpha=sqrt_isize) + out = self.emb_normer(out) + if self.drop is not None: + out = self.drop(out) + + states = {} + + for _tmp, net in enumerate(self.nets): + out, _state = 
net(inpute, (None, None,), src_pad_mask, None, out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) + + trans = [wds] + done_trans = wds.eq(eos_id) + + for i in range(1, max_len): + + out = self.wemb(wds) + if self.pemb is not None: + out = self.pemb[pemb_start_ind + i].add(out, alpha=sqrt_isize) + out = self.emb_normer(out) + if self.drop is not None: + out = self.drop(out) + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, states[_tmp], src_pad_mask, None, out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.classifier(out) + wds = SampleMax(out.softmax(-1), dim=-1, keepdim=False) if sample else out.argmax(dim=-1) + + trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) + + done_trans = done_trans | wds.eq(eos_id) + if all_done(done_trans, bsize): + break + + return torch.cat(trans, 1) + + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, lang_id=None, **kwargs): + + bsize, seql = inpute.size()[:2] + + beam_size2 = beam_size * beam_size + bsizeb2 = bsize * beam_size2 + real_bsize = bsize * beam_size + + out = self.get_sos_emb(inpute, lang_id=lang_id) + + if length_penalty > 0.0: + lpv = out.new_ones(real_bsize, 1) + lpv_base = 6.0 ** length_penalty + + if self.pemb is not None: + sqrt_isize = sqrt(out.size(-1)) + out = self.pemb[pemb_start_ind].add(out, alpha=sqrt_isize) + out = self.emb_normer(out) + if self.drop is not None: + out = self.drop(out) + + states = {} + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, (None, None,), src_pad_mask, None, out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.lsm(self.classifier(out)) + + scores, wds = out.topk(beam_size, dim=-1) + scores = scores.squeeze(1) + sum_scores = scores + wds = wds.view(real_bsize, 1) + trans = wds + _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) + _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) + + done_trans = wds.view(bsize, beam_size).eq(eos_id) + + self.repeat_cross_attn_buffer(beam_size) + + _src_pad_mask = None if src_pad_mask is None else src_pad_mask.repeat(1, beam_size, 1).view(real_bsize, 1, seql) + + states = expand_bsize_for_beam(states, beam_size=beam_size) + + for step in range(1, max_len): + + out = self.wemb(wds) + if self.pemb is not None: + out = self.pemb[pemb_start_ind + step].add(out, alpha=sqrt_isize) + out = self.emb_normer(out) + if self.drop is not None: + out = self.drop(out) + + for _tmp, net in enumerate(self.nets): + out, _state = net(inpute, states[_tmp], _src_pad_mask, None, out) + states[_tmp] = _state + + if self.out_normer is not None: + out = self.out_normer(out) + + out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1) + + _scores, _wds = out.topk(beam_size, dim=-1) + _done_trans_unsqueeze = done_trans.unsqueeze(2) + _scores = (_scores.masked_fill(_done_trans_unsqueeze.expand(bsize, beam_size, beam_size), 0.0) + sum_scores.unsqueeze(2).repeat(1, 1, beam_size).masked_fill_(select_zero_(_done_trans_unsqueeze.repeat(1, 1, beam_size), -1, 0), -inf_default)) + + if length_penalty > 
0.0: + lpv.masked_fill_(~done_trans.view(real_bsize, 1), ((step + 6.0) ** length_penalty) / lpv_base) + + if clip_beam and (length_penalty > 0.0): + scores, _inds = (_scores.view(real_bsize, beam_size) / lpv.expand(real_bsize, beam_size)).view(bsize, beam_size2).topk(beam_size, dim=-1) + _tinds = (_inds + _inds_add_beam2).view(real_bsize) + sum_scores = _scores.view(bsizeb2).index_select(0, _tinds).view(bsize, beam_size) + else: + scores, _inds = _scores.view(bsize, beam_size2).topk(beam_size, dim=-1) + _tinds = (_inds + _inds_add_beam2).view(real_bsize) + sum_scores = scores + + wds = _wds.view(bsizeb2).index_select(0, _tinds).view(real_bsize, 1) + + _inds = (_inds // beam_size + _inds_add_beam).view(real_bsize) + + trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) + + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) + + _done = False + if length_penalty > 0.0: + lpv = lpv.index_select(0, _inds) + elif (not return_all) and all_done(done_trans.select(1, 0), bsize): + _done = True + + if _done or all_done(done_trans, real_bsize): + break + + states = index_tensors(states, indices=_inds, dim=0) + + if (not clip_beam) and (length_penalty > 0.0): + scores = scores / lpv.view(bsize, beam_size) + scores, _inds = scores.topk(beam_size, dim=-1) + _inds = (_inds + _inds_add_beam).view(real_bsize) + trans = trans.view(real_bsize, -1).index_select(0, _inds) + + if return_all: + + return trans.view(bsize, beam_size, -1), scores + else: + + return trans.view(bsize, beam_size, -1).select(1, 0) + + # MBART starts decoding with the language ID + def get_sos_emb(self, inpute, bsize=None, lang_id=None): + + if isinstance(lang_id, torch.Tensor): + return self.wemb(lang_id).unsqueeze(1) + else: + bsize = inpute.size(0) if bsize is None else bsize + return self.wemb.weight[lang_id].view(1, 1, -1).expand(bsize, 1, -1) + + def fix_init(self): + + self.fix_load() + with torch_no_grad(): + #self.wemb.weight[pad_id].zero_() + self.classifier.weight[pad_id].zero_() + if self.pemb is not None: + _ = sqrt(2.0 / sum(self.pemb.size())) + self.pemb.uniform_(- _, _) + + def load_plm(self, plm_parameters, model_name=None, **kwargs): + + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): + copy_plm_parameter(self.wemb.weight, plm_parameters, "%s.embed_tokens.weight" % _model_name) + copy_plm_parameter(self.pemb, plm_parameters, "%s.embed_positions.weight" % _model_name) + copy_plm_parameter(self.emb_normer.weight, plm_parameters, "%s.layernorm_embedding.weight" % _model_name) + copy_plm_parameter(self.emb_normer.bias, plm_parameters, "%s.layernorm_embedding.bias" % _model_name) + copy_plm_parameter(self.out_normer.weight, plm_parameters, "%s.layer_norm.weight" % _model_name) + copy_plm_parameter(self.out_normer.bias, plm_parameters, "%s.layer_norm.bias" % _model_name) + if (not remove_classifier_bias) and ("final_logits_bias" in plm_parameters): + if self.classifier.bias is None: + self.classifier.bias = nn.Parameter(torch.zeros(self.classifier.weight.size(0))) + copy_plm_parameter(self.classifier.bias, plm_parameters, "final_logits_bias") + for i, net in enumerate(self.nets): + net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i, **kwargs) + # MBART does NOT have the bias vector in the classifier + if remove_classifier_bias: + self.classifier.bias = None diff --git a/transformer/PLM/MBART/Encoder.py b/transformer/PLM/MBART/Encoder.py new 
file mode 100644 index 0000000..9f9f3ef --- /dev/null +++ b/transformer/PLM/MBART/Encoder.py @@ -0,0 +1,88 @@ +#encoding: utf-8 + +import torch +from math import sqrt +from torch import nn + +from modules.plm.mbart import PositionwiseFF, ResSelfAttn +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none +from utils.plm.bart import load_plm_encoder_layer +from utils.plm.base import copy_plm_parameter +from utils.torch.comp import torch_no_grad + +from cnfg.plm.mbart.ihyp import * +from cnfg.vocab.plm.mbart import pad_id, pemb_start_ind + +class EncoderLayer(EncoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, model_name="model.encoder", **kwargs): + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + + self.model_name = model_name + self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) + + def load_plm(self, plm_parameters, model_name=None, layer_idx=None, **kwargs): + + load_plm_encoder_layer(self, plm_parameters, model_name=model_name, layer_idx=layer_idx, **kwargs) + +class Encoder(EncoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, share_layer=False, disable_pemb=disable_std_pemb_encoder, model_name="model.encoder", **kwargs): + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) + + self.model_name = model_name + self.pemb = None if disable_pemb else nn.Parameter(torch.Tensor(xseql, isize).uniform_(- sqrt(2.0 / (isize + xseql)), sqrt(2.0 / (isize + xseql)))) + self.wemb.padding_idx = pad_id + self.emb_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) + + if share_layer: + _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) for i in range(num_layer)]) + + def forward(self, inputs, mask=None, **kwargs): + + seql = inputs.size(1) + out = self.wemb(inputs) + if self.pemb is not None: + out = self.pemb.narrow(0, pemb_start_ind, seql).add(out, alpha=sqrt(out.size(-1))) + out = 
self.emb_normer(out) + if self.drop is not None: + out = self.drop(out) + + for net in self.nets: + out = net(out, mask) + + return out if self.out_normer is None else self.out_normer(out) + + def load_plm(self, plm_parameters, model_name=None, **kwargs): + + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): + copy_plm_parameter(self.wemb.weight, plm_parameters, "%s.embed_tokens.weight" % _model_name) + copy_plm_parameter(self.pemb, plm_parameters, "%s.embed_positions.weight" % _model_name) + copy_plm_parameter(self.emb_normer.weight, plm_parameters, "%s.layernorm_embedding.weight" % _model_name) + copy_plm_parameter(self.emb_normer.bias, plm_parameters, "%s.layernorm_embedding.bias" % _model_name) + copy_plm_parameter(self.out_normer.weight, plm_parameters, "%s.layer_norm.weight" % _model_name) + copy_plm_parameter(self.out_normer.bias, plm_parameters, "%s.layer_norm.bias" % _model_name) + for i, net in enumerate(self.nets): + net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i, **kwargs) + + def fix_init(self): + + super(Encoder, self).fix_init() + if self.pemb is not None: + with torch_no_grad(): + _ = sqrt(2.0 / sum(self.pemb.size())) + self.pemb.uniform_(- _, _) diff --git a/transformer/PLM/MBART/NMT.py b/transformer/PLM/MBART/NMT.py new file mode 100644 index 0000000..bb1a860 --- /dev/null +++ b/transformer/PLM/MBART/NMT.py @@ -0,0 +1,47 @@ +#encoding: utf-8 + +from transformer.PLM.MBART.Decoder import Decoder +from transformer.PLM.MBART.Encoder import Encoder +from transformer.PLM.NMT import NMT as NMTBase +from utils.fmt.parser import parse_double_value_tuple, parse_none +from utils.plm.base import set_ln_ieps +from utils.relpos.base import share_rel_pos_cache + +from cnfg.plm.mbart.ihyp import * +from cnfg.vocab.plm.mbart import pad_id + +class NMT(NMTBase): + + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name=("model.encoder", "model.decoder",), **kwargs): + + enc_layer, dec_layer = parse_double_value_tuple(num_layer) + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=model_name, **kwargs) + + self.model_name = model_name + enc_model_name, dec_model_name = parse_double_value_tuple(self.model_name) + self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) + + emb_w = self.enc.wemb.weight if global_emb else None + + self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=dec_model_name) + + set_ln_ieps(self, ieps_ln_default) + if rel_pos_enabled: + share_rel_pos_cache(self) + + def forward(self, inpute, inputo, mask=None, word_prediction=False, **kwargs): + + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask + + 
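# note: _mask is True at padding positions with shape (bsize, 1, seql), so it broadcasts over attention heads and query steps; + # the same source-side mask feeds the encoder self-attention and is reused as src_pad_mask for the decoder cross-attention. +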
return self.dec(self.enc(inpute, _mask), inputo, _mask, word_prediction=word_prediction) + + def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, lang_id=None, **kwargs): + + mask = inpute.eq(pad_id).unsqueeze(1) + _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len + + return self.dec.decode(self.enc(inpute, mask), mask, beam_size, _max_len, length_penalty, lang_id=lang_id) diff --git a/transformer/PLM/MBART/__init__.py b/transformer/PLM/MBART/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/transformer/PLM/MBART/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/transformer/PLM/NMT.py b/transformer/PLM/NMT.py new file mode 100644 index 0000000..ff28cef --- /dev/null +++ b/transformer/PLM/NMT.py @@ -0,0 +1,15 @@ +#encoding: utf-8 + +from transformer.NMT import NMT as NMTBase +from utils.fmt.parser import parse_double_value_tuple, parse_none + +class NMT(NMTBase): + + def load_plm(self, plm_parameters, model_name=None, **kwargs): + + _model_name = parse_none(model_name, self.model_name) + enc_model_name, dec_model_name = parse_double_value_tuple(_model_name) + if hasattr(self, "enc") and hasattr(self.enc, "load_plm"): + self.enc.load_plm(plm_parameters, model_name=enc_model_name, **kwargs) + if hasattr(self, "dec") and hasattr(self.dec, "load_plm"): + self.dec.load_plm(plm_parameters, model_name=dec_model_name, **kwargs) diff --git a/transformer/PLM/RoBERTa/Decoder.py b/transformer/PLM/RoBERTa/Decoder.py index e9fa0ed..ee992d4 100644 --- a/transformer/PLM/RoBERTa/Decoder.py +++ b/transformer/PLM/RoBERTa/Decoder.py @@ -3,28 +3,29 @@ import torch from torch import nn -from utils.plm.base import copy_plm_parameter -#from cnfg.vocab.plm.roberta import pad_id - from transformer.PLM.BERT.Decoder import Decoder as DecoderBase +from utils.fmt.parser import parse_none +from utils.plm.base import copy_plm_parameter +from utils.torch.comp import torch_no_grad #from cnfg.plm.roberta.ihyp import * +#from cnfg.vocab.plm.roberta import pad_id class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer=None, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, model_name="roberta", **kwargs): + def __init__(self, isize, nwd, num_layer=None, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, model_name="roberta", **kwargs): - super(Decoder, self).__init__(isize, nwd, num_layer=num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, model_name=model_name, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer=num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, model_name=model_name, **kwargs) self.rel_classifier = None - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): copy_plm_parameter(self.ff[0].weight, plm_parameters, "lm_head.dense.weight") _bias_key = "lm_head.dense.bias" - if self.ff[0].bias is None and (_bias_key in plm_parameters): + if (self.ff[0].bias is None) and (_bias_key in plm_parameters): self.ff[0].bias = nn.Parameter(torch.zeros(self.ff[0].weight.size(0))) if self.ff[0].bias is not None: copy_plm_parameter(self.ff[0].bias, plm_parameters, _bias_key) diff --git 
a/transformer/PLM/RoBERTa/Encoder.py b/transformer/PLM/RoBERTa/Encoder.py index 927e454..8c8d286 100644 --- a/transformer/PLM/RoBERTa/Encoder.py +++ b/transformer/PLM/RoBERTa/Encoder.py @@ -1,26 +1,58 @@ #encoding: utf-8 -from cnfg.vocab.plm.roberta import pad_id, pemb_start_ind -from cnfg.plm.roberta.base import num_type -from cnfg.plm.roberta.ihyp import * - from transformer.PLM.BERT.Encoder import Encoder as EncoderBase +from utils.fmt.parser import parse_none +from utils.plm.base import copy_plm_parameter +from utils.torch.comp import torch_all, torch_no_grad + +from cnfg.plm.roberta.base import eliminate_type_emb, num_type +from cnfg.plm.roberta.ihyp import * +from cnfg.vocab.plm.roberta import pad_id, pemb_start_ind class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, num_type=num_type, share_layer=False, model_name="roberta", **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, num_type=num_type, share_layer=False, model_name="roberta", eliminate_type_emb=eliminate_type_emb, **kwargs): - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, num_type=num_type, share_layer=share_layer, model_name=model_name, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, num_type=num_type, share_layer=share_layer, model_name=model_name, **kwargs) self.wemb.padding_idx = pad_id + self.eliminate_type_emb = eliminate_type_emb - def forward(self, inputs, token_types=None, mask=None): + def forward(self, inputs, token_types=None, mask=None, **kwargs): seql = inputs.size(1) - out = self.drop(self.out_normer(self.pemb.narrow(0, pemb_start_ind, seql) + (self.temb.weight[0] if token_types is None else self.temb(token_types)) + self.wemb(inputs))) + out = None if self.pemb is None else self.pemb.narrow(0, pemb_start_ind, seql) + if self.temb is not None: + _ = self.temb.weight[0] if token_types is None else self.temb(token_types) + out = _ if out is None else (out + _) + _ = self.wemb(inputs) + out = _ if out is None else (out + _) + if self.out_normer is not None: + out = self.out_normer(out) + if self.drop is not None: + out = self.drop(out) - _mask = inputs.eq(pad_id).unsqueeze(1) if mask is None else mask for net in self.nets: - out = net(out, _mask) + out = net(out, mask) return out + + def load_plm(self, plm_parameters, model_name=None, **kwargs): + + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): + copy_plm_parameter(self.wemb.weight, plm_parameters, "%s.embeddings.word_embeddings.weight" % _model_name) + copy_plm_parameter(self.pemb, plm_parameters, "%s.embeddings.position_embeddings.weight" % _model_name) + _temb_key = "%s.embeddings.token_type_embeddings.weight" % _model_name + if self.eliminate_type_emb and (self.temb.weight.size(0) == 1): + _temb_w = plm_parameters[_temb_key] + if not torch_all(_temb_w.eq(0.0)).item(): + self.wemb.weight.add_(_temb_w) + self.wemb.weight[pad_id].sub_(_temb_w) + self.temb = None + else: + 
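# reached when type-embedding folding is disabled or more than one token type is configured: + # keep self.temb and copy the pretrained token type table unchanged. +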
copy_plm_parameter(self.temb.weight, plm_parameters, _temb_key) + copy_plm_parameter(self.out_normer.weight, plm_parameters, "%s.embeddings.LayerNorm.weight" % _model_name) + copy_plm_parameter(self.out_normer.bias, plm_parameters, "%s.embeddings.LayerNorm.bias" % _model_name) + for i, net in enumerate(self.nets): + net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i, **kwargs) diff --git a/transformer/PLM/RoBERTa/NMT.py b/transformer/PLM/RoBERTa/NMT.py index 3bc18fc..ff1c8f1 100644 --- a/transformer/PLM/RoBERTa/NMT.py +++ b/transformer/PLM/RoBERTa/NMT.py @@ -1,39 +1,39 @@ #encoding: utf-8 +from transformer.PLM.NMT import NMT as NMTBase +from transformer.PLM.RoBERTa.Decoder import Decoder +from transformer.PLM.RoBERTa.Encoder import Encoder +from utils.fmt.parser import parse_double_value_tuple, parse_none from utils.plm.base import set_ln_ieps from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple -from cnfg.vocab.plm.roberta import pad_id - -from transformer.PLM.RoBERTa.Encoder import Encoder -from transformer.PLM.RoBERTa.Decoder import Decoder -from transformer.PLM.BERT.NMT import NMT as NMTBase from cnfg.plm.roberta.ihyp import * +from cnfg.vocab.plm.roberta import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name="roberta"): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name="roberta", **kwargs): enc_layer, dec_layer = parse_double_value_tuple(num_layer) - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=model_name) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=model_name, **kwargs) + self.model_name = model_name enc_model_name, dec_model_name = parse_double_value_tuple(self.model_name) - self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) + self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) emb_w = self.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, model_name=dec_model_name)#, num_layer=dec_layer, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index + self.dec = Decoder(isize, tnwd, fhsize=fhsize, dropout=dropout, 
attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, model_name=dec_model_name)#, num_layer=dec_layer, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index set_ln_ieps(self, ieps_ln_default) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, token_types=None, mask=None, word_prediction=False): + def forward(self, inpute, token_types=None, mask=None, word_prediction=False, **kwargs): _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask diff --git a/transformer/PLM/T5/Decoder.py b/transformer/PLM/T5/Decoder.py index 2eb0b4e..f2c8e22 100644 --- a/transformer/PLM/T5/Decoder.py +++ b/transformer/PLM/T5/Decoder.py @@ -1,38 +1,40 @@ #encoding: utf-8 import torch +from math import sqrt from torch import nn + from modules.dropout import Dropout from modules.norm import RMSNorm as Norm -from modules.plm.t5 import ResSelfAttn, ResCrossAttn, PositionwiseFF -from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_, mask_tensor_type +from modules.plm.t5 import PositionwiseFF, ResCrossAttn, ResSelfAttn +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none from utils.plm.base import copy_plm_parameter from utils.plm.t5 import reorder_pemb -from cnfg.vocab.plm.t5 import pad_id, eos_id, sos_id -from math import sqrt - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from utils.sampler import SampleMax +from utils.torch.comp import all_done, torch_no_grad from cnfg.plm.t5.base import remove_classifier_bias from cnfg.plm.t5.ihyp import * +from cnfg.vocab.plm.t5 import eos_id, pad_id, sos_id class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, k_rel_pos_cattn=use_k_relative_position_cattn, max_bucket_distance_cattn=relative_position_max_bucket_distance_cattn, model_name="decoder", **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, k_rel_pos_cattn=use_k_relative_position_cattn, max_bucket_distance_cattn=relative_position_max_bucket_distance_cattn, model_name="decoder", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) self.model_name = model_name - self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True, 
max_bucket_distance=max_bucket_distance) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos_cattn, max_bucket_distance=max_bucket_distance_cattn) - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: context = self.self_attn(inputo, mask=tgt_pad_mask) @@ -48,19 +50,19 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un else: return context, states_return - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, layer_idx=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): copy_plm_parameter(self.self_attn.net.adaptor.weight, plm_parameters, ["%s.block.%d.layer.0.SelfAttention.q.weight" % (_model_name, layer_idx,), "%s.block.%d.layer.0.SelfAttention.k.weight" % (_model_name, layer_idx,), "%s.block.%d.layer.0.SelfAttention.v.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) _bias_key = "%s.block.%d.layer.0.SelfAttention.q.bias" % (_model_name, layer_idx,) - if self.self_attn.net.adaptor.bias is None and (_bias_key in plm_parameters): + if (self.self_attn.net.adaptor.bias is None) and (_bias_key in plm_parameters): self.self_attn.net.adaptor.bias = nn.Parameter(torch.zeros(self.self_attn.net.adaptor.weight.size(0))) if self.self_attn.net.adaptor.bias is not None: copy_plm_parameter(self.self_attn.net.adaptor.bias, plm_parameters, [_bias_key, "%s.block.%d.layer.0.SelfAttention.k.bias" % (_model_name, layer_idx,), "%s.block.%d.layer.0.SelfAttention.v.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) copy_plm_parameter(self.self_attn.net.outer.weight, plm_parameters, "%s.block.%d.layer.0.SelfAttention.o.weight" % (_model_name, layer_idx,)) _bias_key = "%s.block.%d.layer.0.SelfAttention.o.bias" % (_model_name, layer_idx,) - if self.self_attn.net.outer.bias is None and (_bias_key in plm_parameters): + if (self.self_attn.net.outer.bias is None) and (_bias_key in plm_parameters): self.self_attn.net.outer.bias = nn.Parameter(torch.zeros(self.self_attn.net.outer.weight.size(0))) if self.self_attn.net.outer.bias is not None: copy_plm_parameter(self.self_attn.net.outer.bias, plm_parameters, _bias_key) @@ -73,19 +75,19 @@ def load_plm(self, plm_parameters, model_name=None, layer_idx=None): copy_plm_parameter(self.self_attn.normer.bias, plm_parameters, _bias_key) copy_plm_parameter(self.cross_attn.net.query_adaptor.weight, plm_parameters, "%s.block.%d.layer.1.EncDecAttention.q.weight" % (_model_name, layer_idx,)) _bias_key = "%s.block.%d.layer.1.EncDecAttention.q.bias" % (_model_name, layer_idx,) - if self.cross_attn.net.query_adaptor.bias is None and (_bias_key in plm_parameters): + if (self.cross_attn.net.query_adaptor.bias is None) and (_bias_key in plm_parameters):
self.cross_attn.net.query_adaptor.bias = nn.Parameter(torch.zeros(self.cross_attn.net.query_adaptor.weight.size(0))) if self.cross_attn.net.query_adaptor.bias is not None: copy_plm_parameter(self.cross_attn.net.query_adaptor.bias, plm_parameters, _bias_key) copy_plm_parameter(self.cross_attn.net.kv_adaptor.weight, plm_parameters, ["%s.block.%d.layer.1.EncDecAttention.k.weight" % (_model_name, layer_idx,), "%s.block.%d.layer.1.EncDecAttention.v.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) _bias_key = "%s.block.%d.layer.1.EncDecAttention.k.bias" % (_model_name, layer_idx,) - if self.cross_attn.net.kv_adaptor.bias is None and (_bias_key in plm_parameters): + if (self.cross_attn.net.kv_adaptor.bias is None) and (_bias_key in plm_parameters): self.cross_attn.net.kv_adaptor.bias = nn.Parameter(torch.zeros(self.cross_attn.net.kv_adaptor.weight.size(0))) if self.cross_attn.net.kv_adaptor.bias is not None: copy_plm_parameter(self.cross_attn.net.kv_adaptor.bias, plm_parameters, [_bias_key, "%s.block.%d.layer.1.EncDecAttention.v.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) copy_plm_parameter(self.cross_attn.net.outer.weight, plm_parameters, "%s.block.%d.layer.1.EncDecAttention.o.weight" % (_model_name, layer_idx,)) _bias_key = "%s.block.%d.layer.1.EncDecAttention.o.bias" % (_model_name, layer_idx,) - if self.cross_attn.net.outer.bias is None and (_bias_key in plm_parameters): + if (self.cross_attn.net.outer.bias is None) and (_bias_key in plm_parameters): self.cross_attn.net.outer.bias = nn.Parameter(torch.zeros(self.cross_attn.net.outer.weight.size(0))) if self.cross_attn.net.outer.bias is not None: copy_plm_parameter(self.cross_attn.net.outer.bias, plm_parameters, _bias_key) @@ -120,25 +122,24 @@ def load_plm(self, plm_parameters, model_name=None, layer_idx=None): class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, model_name="decoder", **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, disable_pemb=disable_std_pemb_decoder, model_name="decoder", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) self.model_name = model_name self.wemb.padding_idx = pad_id if share_layer: - _shared_layer = DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) + _shared_layer = DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, 
attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - # Cross-attention layers have relative position encoding in the original T5 but not in v1.1. - self.nets = nn.ModuleList([DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=use_k_relative_position_decoder if i == 0 else 0, max_bucket_distance=relative_position_max_bucket_distance_decoder if i==0 else 0, k_rel_pos_cattn=use_k_relative_position_cattn if i == 0 else 0, max_bucket_distance_cattn=relative_position_max_bucket_distance_cattn if i==0 else 0, model_name=model_name) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, k_rel_pos_cattn=use_k_relative_position_cattn, max_bucket_distance_cattn=relative_position_max_bucket_distance_cattn, model_name=model_name) for i in range(num_layer)])# if i == 0 else 0 self.out_normer = Norm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None - def forward(self, inpute, inputo, src_pad_mask=None, word_prediction=False): + def forward(self, inpute, inputo, src_pad_mask=None, word_prediction=False, **kwargs): nquery = inputo.size(-1) @@ -166,7 +167,7 @@ def forward(self, inpute, inputo, src_pad_mask=None, word_prediction=False): return out - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -222,7 +223,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, return torch.cat(trans, 1) - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -231,13 +232,12 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) lpv_base = 6.0 ** length_penalty - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) if self.pemb is not None: out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: @@ -348,14 +348,14 @@ def get_sos_emb(self, inpute, bsize=None): def fix_init(self): self.fix_load() - with torch.no_grad(): + with torch_no_grad(): #self.wemb.weight[pad_id].zero_() self.classifier.weight[pad_id].zero_() - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): if "lm_head.weight" in plm_parameters: copy_plm_parameter(self.classifier.weight, plm_parameters, "lm_head.weight") _ = "%s.embed_tokens.weight" % _model_name @@ -365,7 +365,7 @@ def load_plm(self, plm_parameters, model_name=None, 
layer_idx=None): if (self.out_normer.bias is not None) and (_bias_key in plm_parameters): copy_plm_parameter(self.out_normer.bias, plm_parameters, _bias_key) for i, net in enumerate(self.nets): - net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i) + net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i, **kwargs) # T5 does NOT have the bias vector in the classifier if remove_classifier_bias: self.classifier.bias = None diff --git a/transformer/PLM/T5/Encoder.py b/transformer/PLM/T5/Encoder.py index 41e1a7b..a547197 100644 --- a/transformer/PLM/T5/Encoder.py +++ b/transformer/PLM/T5/Encoder.py @@ -2,41 +2,45 @@ import torch from torch import nn + from modules.dropout import Dropout from modules.norm import RMSNorm as Norm -from modules.plm.t5 import ResSelfAttn, PositionwiseFF +from modules.plm.t5 import PositionwiseFF, ResSelfAttn +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none from utils.plm.base import copy_plm_parameter from utils.plm.t5 import reorder_pemb -from cnfg.vocab.plm.t5 import pad_id -from cnfg.plm.t5.ihyp import * +from utils.torch.comp import torch_no_grad -from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from cnfg.plm.t5.ihyp import * +from cnfg.vocab.plm.t5 import pad_id class EncoderLayer(EncoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, model_name="encoder", **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, model_name="encoder", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + self.model_name = model_name self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) - self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual, custom_act=use_adv_act_default, enable_bias=enable_prev_ln_bias_default, use_glu=use_glu_ffn) - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, layer_idx=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): copy_plm_parameter(self.attn.net.adaptor.weight, 
plm_parameters, ["%s.block.%d.layer.0.SelfAttention.q.weight" % (_model_name, layer_idx,), "%s.block.%d.layer.0.SelfAttention.k.weight" % (_model_name, layer_idx,), "%s.block.%d.layer.0.SelfAttention.v.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) _bias_key = "%s.block.%d.layer.0.SelfAttention.q.bias" % (_model_name, layer_idx,) - if self.attn.net.adaptor.bias is None and (_bias_key in plm_parameters): + if (self.attn.net.adaptor.bias is None) and (_bias_key in plm_parameters): self.attn.net.adaptor.bias = nn.Parameter(torch.zeros(self.attn.net.adaptor.weight.size(0))) if self.attn.net.adaptor.bias is not None: copy_plm_parameter(self.attn.net.adaptor.bias, plm_parameters, [_bias_key, "%s.block.%d.layer.0.SelfAttention.k.bias" % (_model_name, layer_idx,), "%s.block.%d.layer.0.SelfAttention.v.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) copy_plm_parameter(self.attn.net.outer.weight, plm_parameters, "%s.block.%d.layer.0.SelfAttention.o.weight" % (_model_name, layer_idx,)) _bias_key = "%s.block.%d.layer.0.SelfAttention.o.bias" % (_model_name, layer_idx,) - if self.attn.net.outer.bias is None and (_bias_key in plm_parameters): + if (self.attn.net.outer.bias is None) and (_bias_key in plm_parameters): self.attn.net.outer.bias = nn.Parameter(torch.zeros(self.attn.net.outer.weight.size(0))) if self.attn.net.outer.bias is not None: copy_plm_parameter(self.attn.net.outer.bias, plm_parameters, _bias_key) @@ -71,28 +75,29 @@ def load_plm(self, plm_parameters, model_name=None, layer_idx=None): class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, share_layer=False, disable_pemb=disable_std_pemb_encoder, model_name="encoder", **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, share_layer=False, disable_pemb=disable_std_pemb_encoder, model_name="encoder", **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, share_layer=share_layer, disable_pemb=disable_pemb, **kwargs) + self.model_name = model_name self.wemb.padding_idx = pad_id if share_layer: - _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) + _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, model_name=model_name) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=use_k_relative_position_encoder if i == 0 else 0, 
max_bucket_distance=relative_position_max_bucket_distance_encoder if i == 0 else 0, model_name=model_name) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, model_name=model_name) for i in range(num_layer)])# if i == 0 else 0 self.out_normer = Norm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): out = self.wemb(inputs) if self.pemb is not None: - out = out * sqrt(out.size(-1)) + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) if self.drop is not None: out = self.drop(out) @@ -106,10 +111,10 @@ def forward(self, inputs, mask=None): return out - def load_plm(self, plm_parameters, model_name=None, layer_idx=None): + def load_plm(self, plm_parameters, model_name=None, **kwargs): - _model_name = self.model_name if model_name is None else model_name - with torch.no_grad(): + _model_name = parse_none(model_name, self.model_name) + with torch_no_grad(): _ = "%s.embed_tokens.weight" % _model_name copy_plm_parameter(self.wemb.weight, plm_parameters, _ if _ in plm_parameters else "shared.weight") copy_plm_parameter(self.out_normer.weight, plm_parameters, "%s.final_layer_norm.weight" % _model_name) @@ -117,4 +122,4 @@ def load_plm(self, plm_parameters, model_name=None, layer_idx=None): if (self.out_normer.bias is not None) and (_bias_key in plm_parameters): copy_plm_parameter(self.out_normer.bias, plm_parameters, _bias_key) for i, net in enumerate(self.nets): - net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i) + net.load_plm(plm_parameters, model_name=_model_name, layer_idx=i, **kwargs) diff --git a/transformer/PLM/T5/NMT.py b/transformer/PLM/T5/NMT.py index 0df852f..17ac327 100644 --- a/transformer/PLM/T5/NMT.py +++ b/transformer/PLM/T5/NMT.py @@ -1,47 +1,47 @@ #encoding: utf-8 +from transformer.PLM.NMT import NMT as NMTBase +from transformer.PLM.T5.Decoder import Decoder +from transformer.PLM.T5.Encoder import Encoder +from utils.fmt.parser import parse_double_value_tuple, parse_none from utils.plm.base import set_ln_ieps from utils.plm.t5 import extend_rel_emb from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple -from cnfg.vocab.plm.t5 import pad_id - -from transformer.PLM.T5.Encoder import Encoder -from transformer.PLM.T5.Decoder import Decoder -from transformer.PLM.BERT.NMT import NMT as NMTBase from cnfg.plm.t5.ihyp import * +from cnfg.vocab.plm.t5 import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name=("encoder", "decoder",)): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, model_name=("encoder", "decoder",), **kwargs): enc_layer, dec_layer = parse_double_value_tuple(num_layer) - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None 
else fhsize - super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=model_name) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=model_name, **kwargs) - enc_model_name, dec_model_name = parse_double_value_tuple(model_name) - self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) + self.model_name = model_name + enc_model_name, dec_model_name = parse_double_value_tuple(self.model_name) + self.enc = Encoder(isize, snwd, enc_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, model_name=enc_model_name) emb_w = self.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=dec_model_name) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindemb=bindDecoderEmb, forbidden_index=forbidden_index, model_name=dec_model_name) set_ln_ieps(self, ieps_ln_default) if rel_pos_enabled: - share_rel_pos_cache(self) + share_rel_pos_cache(self, share_emb=True) extend_rel_emb(self) - def forward(self, inpute, inputo, mask=None, word_prediction=False): + def forward(self, inpute, inputo, mask=None, word_prediction=False, **kwargs): _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask return self.dec(self.enc(inpute, _mask), inputo, _mask, word_prediction=word_prediction) - def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0): + def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): mask = inpute.eq(pad_id).unsqueeze(1) _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len diff --git a/transformer/Probe/Decoder.py b/transformer/Probe/Decoder.py index dce87f6..a7840b0 100644 --- a/transformer/Probe/Decoder.py +++ b/transformer/Probe/Decoder.py @@ -1,28 +1,27 @@ #encoding: utf-8 import torch +from math import sqrt from torch import nn -from modules.base import Linear, Dropout from modules.attn.rap import ResCrossAttn - -from math import sqrt - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from modules.base import Dropout, Linear +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, **kwargs): - _ahsize = isize if ahsize 
is None else ahsize + _ahsize = parse_none(ahsize, isize) - super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual) - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, compute_ffn=True): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, compute_ffn=True, **kwargs): if query_unit is None: context = self.self_attn(inputo, mask=tgt_pad_mask) @@ -55,21 +54,21 @@ def load_base(self, base_decoder_layer): class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) self.tattn_w = nn.Parameter(torch.Tensor(num_layer * num_head).uniform_(- sqrt(1.0 / (num_layer * num_head)), sqrt(1.0 / (num_layer * num_head)))) self.tattn_drop = Dropout(dropout) if dropout > 0.0 else None self.trans = Linear(isize, isize, bias=False) - def forward(self, inpute, inputo, inputea, src_pad_mask=None): + def forward(self, inpute, inputo, inputea, src_pad_mask=None, **kwargs): bsize, nquery = inputo.size() diff --git a/transformer/Probe/Encoder.py b/transformer/Probe/Encoder.py index 17e4527..7e9030d 100644 --- a/transformer/Probe/Encoder.py +++ b/transformer/Probe/Encoder.py @@ -1,25 +1,24 @@ #encoding: utf-8 -from transformer.Encoder import Encoder as EncoderBase - from math import sqrt +from transformer.Encoder import Encoder as EncoderBase + from cnfg.ihyp import * class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, num_layer_ana=0, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, num_layer_ana=0, **kwargs): - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, 
num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, **kwargs) self.num_layer_ana = num_layer_ana - def forward(self, inputs, mask=None, no_std_out=False): + def forward(self, inputs, mask=None, no_std_out=False, **kwargs): out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) if self.drop is not None: out = self.drop(out) diff --git a/transformer/Probe/NMT.py b/transformer/Probe/NMT.py index f98f146..8d6e221 100644 --- a/transformer/Probe/NMT.py +++ b/transformer/Probe/NMT.py @@ -1,35 +1,34 @@ #encoding: utf-8 -from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple - -from transformer.Probe.Encoder import Encoder -from transformer.Probe.Decoder import Decoder - from transformer.NMT import NMT as NMTBase +from transformer.Probe.Decoder import Decoder +from transformer.Probe.Encoder import Encoder +from utils.fmt.parser import parse_double_value_tuple +from utils.relpos.base import share_rel_pos_cache from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, num_layer_ana=0): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, num_layer_ana=0, **kwargs): - super(NMT, self).__init__(isize, snwd, tnwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + super(NMT, self).__init__(isize, snwd, tnwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) enc_layer, dec_layer = parse_double_value_tuple(num_layer) - self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, num_layer_ana) + self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, act_drop, num_head, xseql, ahsize, norm_output, num_layer_ana) emb_w = self.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, act_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, inputo, mask=None): + def forward(self, inpute, inputo, mask=None, **kwargs): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask ence, ence_layer = self.enc(inpute, _mask) diff --git a/transformer/Probe/ReDecoder.py 
b/transformer/Probe/ReDecoder.py index 5e09348..c12bc4c 100644 --- a/transformer/Probe/ReDecoder.py +++ b/transformer/Probe/ReDecoder.py @@ -1,26 +1,25 @@ #encoding: utf-8 import torch -from torch import nn - from math import sqrt +from torch import nn from modules.base import Linear - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, **kwargs): - super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=ahsize, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=ahsize, **kwargs) self.perform_self_attn = True self.perform_cross_attn = True - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: if self.perform_self_attn: @@ -56,19 +55,19 @@ def load_base(self, base_decoder_layer): class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, num_layer_ana=None, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, num_layer_ana=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) self.trans = Linear(isize, isize, bias=False) _num_layer_ana = num_layer if num_layer_ana is None else num_layer_ana self.nets = nn.ModuleList(list(self.nets[:_num_layer_ana])) if _num_layer_ana > 0 else None - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): nquery = inputo.size(-1) diff --git a/transformer/Probe/ReNMT.py b/transformer/Probe/ReNMT.py index 5af6871..cdb9f65 100644 --- a/transformer/Probe/ReNMT.py +++ b/transformer/Probe/ReNMT.py @@ -1,25 +1,24 @@ #encoding: utf-8 -from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import 
parse_double_value_tuple - -from transformer.Probe.ReDecoder import Decoder - from transformer.NMT import NMT as NMTBase +from transformer.Probe.ReDecoder import Decoder +from utils.fmt.parser import parse_double_value_tuple +from utils.relpos.base import share_rel_pos_cache from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, num_layer_ana=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, num_layer_ana=None, **kwargs): - super(NMT, self).__init__(isize, snwd, tnwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + super(NMT, self).__init__(isize, snwd, tnwd, num_layer, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) emb_w = self.enc.wemb.weight if global_emb else None _, dec_layer = parse_double_value_tuple(num_layer) - self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index, num_layer_ana) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, act_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index, num_layer_ana) if num_layer_ana <= 0: self.enc = None @@ -27,12 +26,12 @@ def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_ if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, inputo, mask=None): + def forward(self, inpute, inputo, mask=None, **kwargs): if self.enc is None: return self.dec(None, inputo, None) else: - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask return self.dec(self.enc(inpute, _mask), inputo, _mask) def load_base(self, base_nmt): diff --git a/transformer/Prompt/RoBERTa/NMT.py b/transformer/Prompt/RoBERTa/NMT.py index 2424f2e..36aa4ab 100644 --- a/transformer/Prompt/RoBERTa/NMT.py +++ b/transformer/Prompt/RoBERTa/NMT.py @@ -1,20 +1,16 @@ #encoding: utf-8 -import torch -from modules.base import Linear -from cnfg.vocab.plm.roberta import pad_id, mask_id - from transformer.PLM.RoBERTa.NMT import NMT as NMTBase from cnfg.plm.roberta.ihyp import * +from cnfg.vocab.plm.roberta import mask_id, pad_id class NMT(NMTBase): - def forward(self, inpute, token_types=None, mask=None, word_prediction=True): + def forward(self, inpute, token_types=None, mask=None, word_prediction=True, **kwargs): _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask out = self.enc(inpute, token_types=token_types, mask=_mask) _bsize, _, _hsize = out.size() - _pm = inpute.eq(mask_id).unsqueeze(-1).expand(-1, -1, _hsize) - return self.dec(out[_pm].view(_bsize, _hsize), word_prediction=word_prediction) + return self.dec(out[inpute.eq(mask_id)].view(_bsize, _hsize), word_prediction=word_prediction) diff --git a/transformer/README.md 
b/transformer/README.md index 09b0c60..acdc0df 100644 --- a/transformer/README.md +++ b/transformer/README.md @@ -2,7 +2,7 @@ ## `NMT.py` -The transformer model encapsulates encoder and decoder. Set [these lines](NMT.py#L10-L14) to make a choice between the standard encoder / decoder and the others. +The transformer model encapsulates encoder and decoder. Set [these lines](NMT.py#L6-L9) to make a choice between the standard encoder / decoder and the others. ## `Encoder.py` @@ -14,11 +14,11 @@ The standard decoder of transformer. ## `AvgDecoder.py` -The average decoder of transformer proposed by [Accelerating Neural Transformer via an Average Attention Network](https://www.aclweb.org/anthology/P18-1166/). +The average decoder of transformer proposed by [Accelerating Neural Transformer via an Average Attention Network](https://aclanthology.org/P18-1166/). ## `EnsembleNMT.py` -A model encapsulates several NMT models to do ensemble decoding. Configure [these lines](EnsembleNMT.py#L8-L12) to make a choice between the standard decoder and the average decoder. +A model encapsulates several NMT models to do ensemble decoding. Configure [these lines](EnsembleNMT.py#L6-L9) to make a choice between the standard decoder and the average decoder. ## `EnsembleEncoder.py` @@ -30,7 +30,7 @@ A model encapsulates several standard decoders for ensemble decoding. ## `EnsembleAvgDecoder.py` -A model encapsulates several average decoders proposed by [Accelerating Neural Transformer via an Average Attention Network](https://www.aclweb.org/anthology/P18-1166/) for ensemble decoding. +A model encapsulates several average decoders proposed by [Accelerating Neural Transformer via an Average Attention Network](https://aclanthology.org/P18-1166/) for ensemble decoding. ## `AGG/` @@ -38,15 +38,19 @@ Implementation of aggregation models. ### `Hier*.py` -Hierarchical aggregation proposed in [Exploiting Deep Representations for Neural Machine Translation](https://www.aclweb.org/anthology/D18-1457/). +Hierarchical aggregation proposed in [Exploiting Deep Representations for Neural Machine Translation](https://aclanthology.org/D18-1457/). ## `TA/` -Implementation of transparent attention proposed in [Training Deeper Neural Machine Translation Models with Transparent Attention](https://aclweb.org/anthology/D18-1338/). +Implementation of transparent attention proposed in [Training Deeper Neural Machine Translation Models with Transparent Attention](https://aclanthology.org/D18-1338/). ## `SC/` -Implementation of sentential context proposed in [Exploiting Sentential Context for Neural Machine Translation](https://www.aclweb.org/anthology/P19-1624/). +Implementation of sentential context proposed in [Exploiting Sentential Context for Neural Machine Translation](https://aclanthology.org/P19-1624/). + +## `SDU/` + +Implementation of self-dependency units proposed in [Highway Transformer: Self-Gating Enhanced Self-Attentive Networks](https://aclanthology.org/2020.acl-main.616/). ## `RealFormer/` @@ -54,11 +58,11 @@ Implementation of [RealFormer: Transformer Likes Residual Attention](https://arx ## `LD/` -Implementation of NMT with phrase representations proposed in [Learning Source Phrase Representations for Neural Machine Translation](https://www.aclweb.org/anthology/2020.acl-main.37/). +Implementation of NMT with phrase representations proposed in [Learning Source Phrase Representations for Neural Machine Translation](https://aclanthology.org/2020.acl-main.37/). 
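Most of the encoder/decoder variants listed in this README are updated in the same way by this patch: their constructors gain an `act_drop` (activation dropout) argument, and the inline default `isize if ahsize is None else ahsize` is replaced by `utils.fmt.parser.parse_none`. A minimal sketch of that pattern follows; the helper's behaviour is inferred from the idiom it replaces, not taken from the repository, so the actual implementation may differ.

```python
# Assumed semantics of parse_none, inferred from the idiom it replaces
# (`_ahsize = isize if ahsize is None else ahsize`).
def parse_none(value, default):
	return default if value is None else value

# Illustrative constructor defaults as they look after this patch.
isize, ahsize, fhsize, act_drop = 512, None, None, None
_ahsize = parse_none(ahsize, isize)                  # 512
_fhsize = _ahsize * 4 if fhsize is None else fhsize  # 2048
```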
## `Doc/` -Implementation of context-aware Transformer proposed in [Improving the Transformer Translation Model with Document-Level Context](https://www.aclweb.org/anthology/D18-1049/). +Implementation of context-aware Transformer proposed in [Improving the Transformer Translation Model with Document-Level Context](https://aclanthology.org/D18-1049/). ## `APE/` diff --git a/transformer/RNMTDecoder.py b/transformer/RNMTDecoder.py index 8c7fcfe..c2a054a 100644 --- a/transformer/RNMTDecoder.py +++ b/transformer/RNMTDecoder.py @@ -5,37 +5,36 @@ import torch from torch import nn -from modules.base import * -from utils.sampler import SampleMax -from utils.base import all_done -from modules.rnncells import * - -from cnfg.vocab.base import pad_id - +from modules.base import CrossAttn, Dropout, Linear +from modules.rnncells import LSTMCell4RNMT, prepare_initState from transformer.Decoder import Decoder as DecoderBase +from utils.fmt.parser import parse_none +from utils.sampler import SampleMax +from utils.torch.comp import all_done#, torch_no_grad from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class FirstLayer(nn.Module): # isize: input size # osize: output size - def __init__(self, isize, osize=None, dropout=0.0): + def __init__(self, isize, osize=None, dropout=0.0, **kwargs): super(FirstLayer, self).__init__() - osize = isize if osize is None else osize + _osize = parse_none(osize, isize) - self.net = LSTMCell4RNMT(isize, osize) - self.init_hx = nn.Parameter(torch.zeros(1, osize)) - self.init_cx = nn.Parameter(torch.zeros(1, osize)) + self.net = LSTMCell4RNMT(isize, _osize) + self.init_hx = nn.Parameter(torch.zeros(1, _osize)) + self.init_cx = nn.Parameter(torch.zeros(1, _osize)) self.drop = Dropout(dropout, inplace=False) if dropout > 0.0 else None # inputo: embedding of decoded translation (bsize, nquery, isize) # query_unit: single query to decode, used to support decoding for given step - def forward(self, inputo, states=None, first_step=False): + def forward(self, inputo, states=None, first_step=False, **kwargs): if states is None: hx, cx = prepare_initState(self.init_hx, self.init_cx, inputo.size(0)) @@ -62,15 +61,15 @@ class DecoderLayer(nn.Module): # isize: input size # osize: output size - def __init__(self, isize, osize=None, dropout=0.0, residual=True): + def __init__(self, isize, osize=None, dropout=0.0, residual=True, **kwargs): super(DecoderLayer, self).__init__() - osize = isize if osize is None else osize + _osize = parse_none(osize, isize) - self.net = LSTMCell4RNMT(isize + osize, osize) - self.init_hx = nn.Parameter(torch.zeros(1, osize)) - self.init_cx = nn.Parameter(torch.zeros(1, osize)) + self.net = LSTMCell4RNMT(isize + _osize, _osize) + self.init_hx = nn.Parameter(torch.zeros(1, _osize)) + self.init_cx = nn.Parameter(torch.zeros(1, _osize)) self.drop = Dropout(dropout, inplace=False) if dropout > 0.0 else None @@ -79,7 +78,7 @@ def __init__(self, isize, osize=None, dropout=0.0, residual=True): # inputo: embedding of decoded translation (bsize, nquery, isize) # query_unit: single query to decode, used to support decoding for given step - def forward(self, inputo, attn, states=None, first_step=False): + def forward(self, inputo, attn, states=None, first_step=False, **kwargs): if states is None: hx, cx = prepare_initState(self.init_hx, self.init_cx, inputo.size(0)) @@ -117,11 +116,11 @@ class Decoder(DecoderBase): # ahsize: number of hidden units for MultiHeadAttention # bindemb: bind embedding and classifier weight - def __init__(self, isize, nwd, 
num_layer, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, projector=True, **kwargs): + def __init__(self, isize, nwd, num_layer, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, projector=True, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=isize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=isize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, **kwargs) self.flayer = FirstLayer(isize, osize=isize, dropout=dropout) @@ -141,9 +140,9 @@ def __init__(self, isize, nwd, num_layer, dropout=0.0, attn_drop=0.0, emb_w=None # inpute: encoded representation from encoder (bsize, seql, isize) # inputo: decoded translation (bsize, nquery) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): out = self.wemb(inputo) @@ -159,7 +158,7 @@ def forward(self, inpute, inputo, src_pad_mask=None): # the following line of code is to mask for the decoder, # which I think is useless, since one may only pay attention to previous tokens, whose loss will be omitted by the loss function.
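The comment above argues that masking padded target positions is unnecessary here: the decoder can only attend to previous tokens, and the loss at padded positions is dropped anyway. A standalone sketch of the second point, assuming `pad_id` is 0 (consistent with the `eq(0)`-to-`eq(pad_id)` replacements elsewhere in this patch) and using plain cross-entropy rather than the repository's loss classes:

```python
import torch
import torch.nn.functional as F

pad_id = 0  # assumed padding index, matching the eq(0) -> eq(pad_id) changes in this patch

# toy logits: batch of 2 sequences, 4 decoding steps, vocabulary of 5 tokens
logits = torch.randn(2, 4, 5)
# the second sequence is padded at its last two steps
target = torch.tensor([[3, 1, 4, 2], [2, 4, pad_id, pad_id]])

# ignore_index drops the padded positions from the loss entirely, so whatever
# the decoder computes at those steps never influences training
loss = F.cross_entropy(logits.transpose(1, 2), target, ignore_index=pad_id)
print(loss)
```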
- #_mask = torch.gt(_mask + inputo.eq(0).unsqueeze(1), 0) + #_mask = torch.gt(_mask + inputo.eq(pad_id).unsqueeze(1), 0) for net in self.nets: out = net(out, attn) @@ -173,10 +172,10 @@ def forward(self, inpute, inputo, src_pad_mask=None): # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # max_len: maximum length to generate - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -212,7 +211,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # done_trans: (bsize) - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -237,7 +236,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break @@ -245,11 +244,11 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, # inpute: encoded representation from encoder (bsize, seql, isize) # src_pad_mask: mask for given encoding source sentence (bsize, 1, seql), see Encoder, generated with: - # src_pad_mask = input.eq(0).unsqueeze(1) + # src_pad_mask = input.eq(pad_id).unsqueeze(1) # beam_size: beam size # max_len: maximum length to generate - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -258,7 +257,6 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: # lpv: length penalty vector for each beam (bsize * beam_size, 1) @@ -302,7 +300,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt # done_trans: (bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) # inpute: (bsize, seql, isize) => (bsize * beam_size, seql, isize) @@ -384,7 +382,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), (wds.masked_fill(done_trans.view(real_bsize), pad_id) if fill_pad else wds).unsqueeze(1)), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) & wds.eq(2)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) & wds.eq(eos_id)).view(bsize, beam_size) # check early stop for beam search # done_trans: (bsize, beam_size) @@ -425,5 +423,5 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt """def fix_load(self): if self.fbl is not None: - with torch.no_grad(): + with torch_no_grad(): list(self.classifier.modules())[-1].bias.index_fill_(0, torch.as_tensor(self.fbl, dtype=torch.long, device=self.classifier.bias.device), -inf_default)""" diff --git a/transformer/RealFormer/Decoder.py 
b/transformer/RealFormer/Decoder.py index 5808c21..7e2bb82 100644 --- a/transformer/RealFormer/Decoder.py +++ b/transformer/RealFormer/Decoder.py @@ -1,33 +1,33 @@ #encoding: utf-8 import torch +from math import sqrt from torch import nn -from modules.attn.res import ResSelfAttn, ResCrossAttn - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase - +from modules.attn.res import ResCrossAttn, ResSelfAttn +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.fmt.parser import parse_none from utils.sampler import SampleMax -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_ -from math import sqrt - -from cnfg.vocab.base import pad_id +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True, max_bucket_distance=max_bucket_distance) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) - def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, resin=None): + def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, resin=None, **kwargs): if resin is None: sresin = cresin = None @@ -50,20 +50,20 @@ def forward(self, inpute, inputo, src_pad_mask=None, tgt_pad_mask=None, query_un class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, 
dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): nquery = inputo.size(-1) @@ -86,7 +86,7 @@ def forward(self, inpute, inputo, src_pad_mask=None): return out - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -112,7 +112,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -135,13 +135,13 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -150,14 +150,13 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: out = self.drop(out) @@ -181,7 +180,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) self.repeat_cross_attn_buffer(beam_size) @@ -229,7 +228,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) 
if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: diff --git a/transformer/RealFormer/Encoder.py b/transformer/RealFormer/Encoder.py index a0cd16e..9cac6e4 100644 --- a/transformer/RealFormer/Encoder.py +++ b/transformer/RealFormer/Encoder.py @@ -1,26 +1,26 @@ #encoding: utf-8 -from torch import nn from math import sqrt +from torch import nn from modules.attn.res import ResSelfAttn - -from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class EncoderLayer(EncoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) - def forward(self, inputs, mask=None, resin=None): + def forward(self, inputs, mask=None, resin=None, **kwargs): context, resout = self.attn(inputs, mask=mask, resin=resin) @@ -30,25 +30,24 @@ def forward(self, inputs, mask=None, resin=None): class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) if share_layer: - _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + 
_shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) if self.drop is not None: out = self.drop(out) diff --git a/transformer/RetrAttn/Decoder.py b/transformer/RetrAttn/Decoder.py index 65aaf2d..ce264c0 100644 --- a/transformer/RetrAttn/Decoder.py +++ b/transformer/RetrAttn/Decoder.py @@ -2,34 +2,34 @@ from torch import nn -from modules.attn.retr import ResSelfAttn, ResCrossAttn - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from modules.attn.retr import ResCrossAttn, ResSelfAttn +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, k_rel_pos=use_k_relative_position_decoder, smoothing=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, k_rel_pos=use_k_relative_position_decoder, smoothing=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) - super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=k_rel_pos, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=k_rel_pos, **kwargs) self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.self_attn.norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True, smoothing=smoothing, use_cumsum=True) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual, smoothing=smoothing) class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, 
attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) diff --git a/transformer/RetrAttn/Encoder.py b/transformer/RetrAttn/Encoder.py index f95df71..6df73a7 100644 --- a/transformer/RetrAttn/Encoder.py +++ b/transformer/RetrAttn/Encoder.py @@ -1,32 +1,33 @@ #encoding: utf-8 from torch import nn -from modules.attn.retr import ResSelfAttn -from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from modules.attn.retr import ResSelfAttn +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class EncoderLayer(EncoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, k_rel_pos=use_k_relative_position_encoder, smoothing=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, k_rel_pos=use_k_relative_position_encoder, smoothing=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize - super(EncoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=k_rel_pos, **kwargs) + _ahsize = parse_none(ahsize, isize) + super(EncoderLayer, self).__init__(isize, fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, k_rel_pos=k_rel_pos, **kwargs) self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.attn.norm_residual, k_rel_pos=k_rel_pos, smoothing=smoothing, use_cumsum=False) class Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) if share_layer: - _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in 
range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) diff --git a/transformer/SC/Decoder.py b/transformer/SC/Decoder.py index 2d0b948..e4c69f9 100644 --- a/transformer/SC/Decoder.py +++ b/transformer/SC/Decoder.py @@ -1,37 +1,37 @@ #encoding: utf-8 import torch +from math import sqrt from torch import nn +from modules.TA import PositionwiseFF, ResCrossAttn from modules.base import ResidueCombiner +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam, repeat_bsize_for_beam_tensor +from utils.fmt.parser import parse_none from utils.sampler import SampleMax -from modules.TA import ResCrossAttn, PositionwiseFF - -from utils.base import all_done, index_tensors, expand_bsize_for_beam, select_zero_, repeat_bsize_for_beam_tensor -from math import sqrt - -from cnfg.vocab.base import pad_id - -from transformer.Decoder import DecoderLayer as DecoderLayerBase, Decoder as DecoderBase +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class DecoderLayer(DecoderLayerBase): - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=self.cross_attn.norm_residual) - self.ff = PositionwiseFF(isize, _fhsize, dropout) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=self.ff.norm_residual) self.scff = ResidueCombiner(isize, 2, _fhsize) self.drop, self.layer_normer1 = self.self_attn.drop, self.self_attn.drop.normer self.self_attn = self.self_attn.net - def forward(self, inpute, inputh, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None): + def forward(self, inpute, inputh, inputo, src_pad_mask=None, tgt_pad_mask=None, query_unit=None, **kwargs): if query_unit is None: @@ -68,21 +68,21 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None, tgt_pad_mask=None, class Decoder(DecoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, share_layer=False, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=False, forbidden_index=None, share_layer=False, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Decoder, self).__init__(isize, 
nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) + super(Decoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) if share_layer: - _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([DecoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) - def forward(self, inpute, inputh, inputo, src_pad_mask=None): + def forward(self, inpute, inputh, inputo, src_pad_mask=None, **kwargs): bsize, nquery = inputo.size() @@ -104,7 +104,7 @@ def forward(self, inpute, inputh, inputo, src_pad_mask=None): return out - def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -121,7 +121,7 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad states = {} for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): - out, _state = net(inputu, inputhu, None, src_pad_mask, None, out, True) + out, _state = net(inputu, inputhu, (None, None,), src_pad_mask, None, out) states[_tmp] = _state out = self.classifier(out) @@ -129,7 +129,7 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -142,7 +142,7 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad out = self.out_normer(out) for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): - out, _state = net(inputu, inputhu, states[_tmp], src_pad_mask, None, out, True) + out, _state = net(inputu, inputhu, states[_tmp], src_pad_mask, None, out) states[_tmp] = _state out = self.classifier(out) @@ -150,13 +150,13 @@ def greedy_decode(self, inpute, inputh, src_pad_mask=None, max_len=512, fill_pad trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) - def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -165,14 +165,13 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: 
lpv = out.new_ones(real_bsize, 1) lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: out = self.drop(out) @@ -182,7 +181,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 states = {} for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): - out, _state = net(inputu, inputhu, None, src_pad_mask, None, out, True) + out, _state = net(inputu, inputhu, (None, None,), src_pad_mask, None, out) states[_tmp] = _state out = self.lsm(self.classifier(out)) @@ -195,7 +194,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) inputh = repeat_bsize_for_beam_tensor(inputh, beam_size) self.repeat_cross_attn_buffer(beam_size) @@ -215,7 +214,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 out = self.out_normer(out) for _tmp, (net, inputu, inputhu) in enumerate(zip(self.nets, inpute.unbind(dim=-1), inputh.unbind(dim=-1))): - out, _state = net(inputu, inputhu, states[_tmp], _src_pad_mask, None, out, True) + out, _state = net(inputu, inputhu, states[_tmp], _src_pad_mask, None, out) states[_tmp] = _state out = self.lsm(self.classifier(out)).view(bsize, beam_size, -1) @@ -242,7 +241,7 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: @@ -268,6 +267,6 @@ def beam_decode(self, inpute, inputh, src_pad_mask=None, beam_size=8, max_len=51 return trans.view(bsize, beam_size, -1).select(1, 0) - def decode(self, inpute, inputh, src_pad_mask, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False): + def decode(self, inpute, inputh, src_pad_mask, beam_size=1, max_len=512, length_penalty=0.0, fill_pad=False, **kwargs): - return self.beam_decode(inpute, inputh, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad) if beam_size > 1 else self.greedy_decode(inpute, inputh, src_pad_mask, max_len, fill_pad=fill_pad) + return self.beam_decode(inpute, inputh, src_pad_mask, beam_size, max_len, length_penalty, fill_pad=fill_pad, **kwargs) if beam_size > 1 else self.greedy_decode(inpute, inputh, src_pad_mask, max_len, fill_pad=fill_pad, **kwargs) diff --git a/transformer/SC/Encoder.py b/transformer/SC/Encoder.py index 1278d1b..bd68f61 100644 --- a/transformer/SC/Encoder.py +++ b/transformer/SC/Encoder.py @@ -1,22 +1,23 @@ #encoding: utf-8 import torch -from torch import nn -from modules.base import * from math import sqrt +from torch import nn +from modules.base import CrossAttn, Dropout from transformer.TA.Encoder import Encoder as EncoderBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * class 
Encoder(EncoderBase): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, num_layer_dec=6, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, num_layer_dec=6, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, num_layer_dec=num_layer_dec, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, num_layer_dec=num_layer_dec, **kwargs) self.attns = nn.ModuleList([CrossAttn(isize, _ahsize, isize, num_head, dropout=attn_drop) for i in range(num_layer)]) @@ -25,9 +26,9 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. # inputs: (bsize, seql) # mask: (bsize, 1, seql), generated with: - # mask = inputs.eq(0).unsqueeze(1) + # mask = inputs.eq(pad_id).unsqueeze(1) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): def transform(lin, w, drop): @@ -41,9 +42,8 @@ def transform(lin, w, drop): bsize, seql = inputs.size() out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) if self.drop is not None: out = self.drop(out) diff --git a/transformer/SC/NMT.py b/transformer/SC/NMT.py index 95efab0..4dda243 100644 --- a/transformer/SC/NMT.py +++ b/transformer/SC/NMT.py @@ -1,36 +1,34 @@ #encoding: utf-8 -from utils.relpos.base import share_rel_pos_cache -from utils.fmt.parser import parse_double_value_tuple - -from transformer.SC.Encoder import Encoder -from transformer.SC.Decoder import Decoder from transformer.NMT import NMT as NMTBase - -from math import sqrt +from transformer.SC.Decoder import Decoder +from transformer.SC.Encoder import Encoder +from utils.fmt.parser import parse_double_value_tuple +from utils.relpos.base import share_rel_pos_cache from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class NMT(NMTBase): - def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None): + def __init__(self, isize, snwd, tnwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, global_emb=False, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindDecoderEmb=True, forbidden_index=None, **kwargs): enc_layer, dec_layer = parse_double_value_tuple(num_layer) - super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) + super(NMT, self).__init__(isize, snwd, tnwd, (enc_layer, dec_layer,), fhsize=fhsize, dropout=dropout, attn_drop=attn_drop, 
act_drop=act_drop, global_emb=global_emb, num_head=num_head, xseql=xseql, ahsize=ahsize, norm_output=norm_output, bindDecoderEmb=bindDecoderEmb, forbidden_index=forbidden_index) - self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, num_head, xseql, ahsize, norm_output, num_layer) + self.enc = Encoder(isize, snwd, enc_layer, fhsize, dropout, attn_drop, act_drop, num_head, xseql, ahsize, norm_output, num_layer) emb_w = self.enc.wemb.weight if global_emb else None - self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) + self.dec = Decoder(isize, tnwd, dec_layer, fhsize, dropout, attn_drop, act_drop, emb_w, num_head, xseql, ahsize, norm_output, bindDecoderEmb, forbidden_index) if rel_pos_enabled: share_rel_pos_cache(self) - def forward(self, inpute, inputo, mask=None): + def forward(self, inpute, inputo, mask=None, **kwargs): - _mask = inpute.eq(0).unsqueeze(1) if mask is None else mask + _mask = inpute.eq(pad_id).unsqueeze(1) if mask is None else mask return self.dec(*self.enc(inpute, _mask), inputo, _mask) @@ -38,9 +36,9 @@ def forward(self, inpute, inputo, mask=None): # beam_size: the beam size for beam search # max_len: maximum length to generate - def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0): + def decode(self, inpute, beam_size=1, max_len=None, length_penalty=0.0, **kwargs): - mask = inpute.eq(0).unsqueeze(1) + mask = inpute.eq(pad_id).unsqueeze(1) _max_len = (inpute.size(1) + max(64, inpute.size(1) // 4)) if max_len is None else max_len diff --git a/transformer/SDU/Decoder.py b/transformer/SDU/Decoder.py new file mode 100644 index 0000000..8002f97 --- /dev/null +++ b/transformer/SDU/Decoder.py @@ -0,0 +1,37 @@ +#encoding: utf-8 + +from torch import nn + +from modules.sdu import PositionwiseFF, ResCrossAttn, ResSelfAttn +from transformer.Decoder import Decoder as DecoderBase, DecoderLayer as DecoderLayerBase +from utils.fmt.parser import parse_none + +from cnfg.ihyp import * + +class DecoderLayer(DecoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_decoder, max_bucket_distance=relative_position_max_bucket_distance_decoder, **kwargs): + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(DecoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + + self.self_attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, uni_direction_reduction=True, max_bucket_distance=max_bucket_distance) + self.cross_attn = ResCrossAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) + +class Decoder(DecoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, emb_w=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, bindemb=True, forbidden_index=None, share_layer=False, **kwargs): + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Decoder, 
self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, emb_w=emb_w, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, bindemb=bindemb, forbidden_index=forbidden_index, share_layer=share_layer, **kwargs) + + if share_layer: + _shared_layer = DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([DecoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) for i in range(num_layer)]) diff --git a/transformer/SDU/Encoder.py b/transformer/SDU/Encoder.py new file mode 100644 index 0000000..2232345 --- /dev/null +++ b/transformer/SDU/Encoder.py @@ -0,0 +1,36 @@ +#encoding: utf-8 + +from torch import nn + +from modules.sdu import PositionwiseFF, ResSelfAttn +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none + +from cnfg.ihyp import * + +class EncoderLayer(EncoderLayerBase): + + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, norm_residual=norm_residual_default, k_rel_pos=use_k_relative_position_encoder, max_bucket_distance=relative_position_max_bucket_distance_encoder, **kwargs): + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance, **kwargs) + + self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop, norm_residual=norm_residual, k_rel_pos=k_rel_pos, max_bucket_distance=max_bucket_distance) + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop, norm_residual=norm_residual) + +class Encoder(EncoderBase): + + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, **kwargs): + + _ahsize = parse_none(ahsize, isize) + _fhsize = _ahsize * 4 if fhsize is None else fhsize + + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) + + if share_layer: + _shared_layer = EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) + self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) + else: + self.nets = nn.ModuleList([EncoderLayer(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize) for i in range(num_layer)]) diff --git a/transformer/SDU/__init__.py b/transformer/SDU/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/transformer/SDU/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/transformer/TA/Decoder.py b/transformer/TA/Decoder.py index 05b13ea..26a08d4 100644 --- a/transformer/TA/Decoder.py +++ b/transformer/TA/Decoder.py @@ -1,20 +1,21 @@ #encoding: utf-8 import torch -from utils.sampler import SampleMax -from utils.base import 
all_done, index_tensors, expand_bsize_for_beam, select_zero_ from math import sqrt -from cnfg.vocab.base import pad_id - from transformer.Decoder import Decoder as DecoderBase +from utils.base import index_tensors, select_zero_ +from utils.decode.beam import expand_bsize_for_beam +from utils.sampler import SampleMax +from utils.torch.comp import all_done from cnfg.ihyp import * +from cnfg.vocab.base import eos_id, pad_id class Decoder(DecoderBase): # inpute: encoded representation from encoder (bsize, seql, isize, num_layer) - def forward(self, inpute, inputo, src_pad_mask=None): + def forward(self, inpute, inputo, src_pad_mask=None, **kwargs): bsize, nquery = inputo.size() @@ -38,7 +39,7 @@ def forward(self, inpute, inputo, src_pad_mask=None): return out # inpute: encoded representation from encoder (bsize, seql, isize, num_layer) - def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False): + def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, sample=False, **kwargs): bsize = inpute.size(0) @@ -53,7 +54,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, states = {} for _tmp, (net, inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))): - out, _state = net(inputu, None, src_pad_mask, None, out, True) + out, _state = net(inputu, (None, None,), src_pad_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -64,7 +65,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans = [wds] - done_trans = wds.eq(2) + done_trans = wds.eq(eos_id) for i in range(1, max_len): @@ -75,7 +76,7 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, out = self.drop(out) for _tmp, (net, inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))): - out, _state = net(inputu, states[_tmp], src_pad_mask, None, out, True) + out, _state = net(inputu, states[_tmp], src_pad_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -86,14 +87,14 @@ def greedy_decode(self, inpute, src_pad_mask=None, max_len=512, fill_pad=False, trans.append(wds.masked_fill(done_trans, pad_id) if fill_pad else wds) - done_trans = done_trans | wds.eq(2) + done_trans = done_trans | wds.eq(eos_id) if all_done(done_trans, bsize): break return torch.cat(trans, 1) # inpute: encoded representation from encoder (bsize, seql, isize, num_layer) - def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False): + def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, length_penalty=0.0, return_all=False, clip_beam=clip_beam_with_lp, fill_pad=False, **kwargs): bsize, seql = inpute.size()[:2] @@ -102,14 +103,13 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt real_bsize = bsize * beam_size out = self.get_sos_emb(inpute) - isize = out.size(-1) if length_penalty > 0.0: lpv = out.new_ones(real_bsize, 1) lpv_base = 6.0 ** length_penalty if self.pemb is not None: - sqrt_isize = sqrt(isize) + sqrt_isize = sqrt(out.size(-1)) out = self.pemb.get_pos(0).add(out, alpha=sqrt_isize) if self.drop is not None: out = self.drop(out) @@ -117,7 +117,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt states = {} for _tmp, (net, inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))): - out, _state = net(inputu, None, src_pad_mask, None, out, True) + out, _state = net(inputu, (None, None,), 
src_pad_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -133,7 +133,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt _inds_add_beam2 = torch.arange(0, bsizeb2, beam_size2, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) _inds_add_beam = torch.arange(0, real_bsize, beam_size, dtype=wds.dtype, device=wds.device).unsqueeze(1).expand(bsize, beam_size) - done_trans = wds.view(bsize, beam_size).eq(2) + done_trans = wds.view(bsize, beam_size).eq(eos_id) self.repeat_cross_attn_buffer(beam_size) @@ -150,7 +150,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt out = self.drop(out) for _tmp, (net, inputu) in enumerate(zip(self.nets, inpute.unbind(dim=-1))): - out, _state = net(inputu, states[_tmp], _src_pad_mask, None, out, True) + out, _state = net(inputu, states[_tmp], _src_pad_mask, None, out) states[_tmp] = _state if self.out_normer is not None: @@ -180,7 +180,7 @@ def beam_decode(self, inpute, src_pad_mask=None, beam_size=8, max_len=512, lengt trans = torch.cat((trans.index_select(0, _inds), wds.masked_fill(done_trans.view(real_bsize, 1), pad_id) if fill_pad else wds), 1) - done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(2).squeeze(1)).view(bsize, beam_size) + done_trans = (done_trans.view(real_bsize).index_select(0, _inds) | wds.eq(eos_id).squeeze(1)).view(bsize, beam_size) _done = False if length_penalty > 0.0: diff --git a/transformer/TA/Encoder.py b/transformer/TA/Encoder.py index fab4bb6..d979eaf 100644 --- a/transformer/TA/Encoder.py +++ b/transformer/TA/Encoder.py @@ -1,12 +1,13 @@ #encoding: utf-8 import torch -from torch import nn -from modules.base import Dropout -from modules.TA import ResSelfAttn, PositionwiseFF from math import sqrt +from torch import nn -from transformer.Encoder import EncoderLayer as EncoderLayerBase, Encoder as EncoderBase +from modules.TA import PositionwiseFF, ResSelfAttn +from modules.base import Dropout +from transformer.Encoder import Encoder as EncoderBase, EncoderLayer as EncoderLayerBase +from utils.fmt.parser import parse_none from cnfg.ihyp import * @@ -18,25 +19,15 @@ class EncoderLayer(EncoderLayerBase): # num_head: number of heads in MultiHeadAttention # ahsize: hidden size of MultiHeadAttention - def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, ahsize=None, **kwargs): + def __init__(self, isize, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, ahsize=None, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, ahsize=_ahsize, **kwargs) + super(EncoderLayer, self).__init__(isize, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, ahsize=_ahsize, **kwargs) self.attn = ResSelfAttn(isize, _ahsize, num_head=num_head, dropout=attn_drop) - self.ff = PositionwiseFF(isize, _fhsize, dropout) - - # inputs: input of this layer (bsize, seql, isize) - - def forward(self, inputs, mask=None): - - context = self.attn(inputs, mask=mask) - - context = self.ff(context) - - return context + self.ff = PositionwiseFF(isize, hsize=_fhsize, dropout=dropout, act_drop=act_drop) class Encoder(EncoderBase): @@ -49,34 +40,33 @@ class Encoder(EncoderBase): # xseql: maxmimum length of sequence # ahsize: number of hidden units for 
MultiHeadAttention - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, num_layer_dec=6, **kwargs): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, share_layer=False, num_layer_dec=6, **kwargs): - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize - super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) + super(Encoder, self).__init__(isize, nwd, num_layer, fhsize=_fhsize, dropout=dropout, attn_drop=attn_drop, act_drop=act_drop, num_head=num_head, xseql=xseql, ahsize=_ahsize, norm_output=norm_output, share_layer=share_layer, **kwargs) if share_layer: - _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + _shared_layer = EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.nets = nn.ModuleList([_shared_layer for i in range(num_layer)]) else: - self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) for i in range(num_layer)]) + self.nets = nn.ModuleList([EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) for i in range(num_layer)]) self.tattn_w = nn.Parameter(torch.Tensor(num_layer + 1, num_layer_dec).uniform_(- sqrt(1.0 / (num_layer + 1)), sqrt(1.0 / (num_layer + 1)))) self.tattn_drop = Dropout(dropout) if dropout > 0.0 else None # inputs: (bsize, seql) # mask: (bsize, 1, seql), generated with: - # mask = inputs.eq(0).unsqueeze(1) + # mask = inputs.eq(pad_id).unsqueeze(1) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): bsize, seql = inputs.size() out = self.wemb(inputs) - out = out * sqrt(out.size(-1)) if self.pemb is not None: - out = out + self.pemb(inputs, expand=False) + out = self.pemb(inputs, expand=False).add(out, alpha=sqrt(out.size(-1))) if self.drop is not None: out = self.drop(out) diff --git a/transformer/UniEncoder.py b/transformer/UniEncoder.py index 318c75d..87be2c1 100644 --- a/transformer/UniEncoder.py +++ b/transformer/UniEncoder.py @@ -2,21 +2,22 @@ import torch from torch import nn -from modules.base import * - -from cnfg.vocab.base import pad_id +from modules.base import ACT_Loss, CoordinateEmb, Dropout, Scorer from transformer.Encoder import EncoderLayer +from utils.fmt.parser import parse_none +from utils.torch.comp import torch_no_grad from cnfg.ihyp import * +from cnfg.vocab.base import pad_id class Encoder(nn.Module): - def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True): + def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0.0, act_drop=None, num_head=8, xseql=cache_len_default, ahsize=None, norm_output=True, **kwargs): super(Encoder, self).__init__() - _ahsize = isize if ahsize is None else ahsize + _ahsize = parse_none(ahsize, isize) _fhsize = _ahsize * 4 if fhsize is None else fhsize self.num_layer = num_layer @@ -26,7 +27,7 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. 
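Aside: the encoder embedding change above replaces the two-step "out = out * sqrt(isize); out = out + pemb" with a single scaled add on the positional embedding. A minimal sketch with made-up tensors (not repository code) showing the two forms agree, since tensor.add(other, alpha=s) computes tensor + s * other:

import torch
from math import sqrt

emb = torch.randn(2, 5, 512)   # token embeddings (bsize, seql, isize)
pos = torch.randn(1, 5, 512)   # positional embeddings, broadcast over the batch
d = emb.size(-1)

fused = pos.add(emb, alpha=sqrt(d))      # new form: pos + sqrt(d) * emb
two_step = emb * sqrt(d) + pos           # old form: scale first, then add
print(torch.allclose(fused, two_step, rtol=1e-4, atol=1e-4))  # True
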
self.wemb = nn.Embedding(nwd, isize, padding_idx=pad_id) self.pemb = CoordinateEmb(isize, xseql, num_layer, 0, 0) - self.net = EncoderLayer(isize, _fhsize, dropout, attn_drop, num_head, _ahsize) + self.net = EncoderLayer(isize, _fhsize, dropout, attn_drop, act_drop, num_head, _ahsize) self.halter = nn.Sequential(Scorer(isize), nn.Sigmoid()) self.out_normer = nn.LayerNorm(isize, eps=ieps_ln_default, elementwise_affine=enable_ln_parameters) if norm_output else None @@ -35,9 +36,9 @@ def __init__(self, isize, nwd, num_layer, fhsize=None, dropout=0.0, attn_drop=0. # inputs: (bsize, seql) # mask: (bsize, 1, seql), generated with: - # mask = inputs.eq(0).unsqueeze(1) + # mask = inputs.eq(pad_id).unsqueeze(1) - def forward(self, inputs, mask=None): + def forward(self, inputs, mask=None, **kwargs): bsize, seql = inputs.size() out = self.wemb(inputs) @@ -89,9 +90,15 @@ def forward(self, inputs, mask=None): else: return out, loss_act + def get_embedding_weight(self): + + return self.wemb.weight + def update_vocab(self, indices): - _wemb = nn.Embedding(len(indices), self.wemb.weight.size(-1), padding_idx=pad_id) - with torch.no_grad(): + _wemb = nn.Embedding(indices.numel(), self.wemb.weight.size(-1), padding_idx=self.wemb.padding_idx) + with torch_no_grad(): _wemb.weight.copy_(self.wemb.weight.index_select(0, indices)) self.wemb = _wemb + + return self.wemb.weight diff --git a/utils/README.md b/utils/README.md new file mode 100644 index 0000000..ca50a7a --- /dev/null +++ b/utils/README.md @@ -0,0 +1,5 @@ +# HP-LSTM + +## `ctorch.py` + +Extension functions for pytorch, depends on C++ implementation under `cpp/`, which can be installed (via `python setup.py install`). diff --git a/utils/aan.py b/utils/aan.py index 391e69c..249834b 100644 --- a/utils/aan.py +++ b/utils/aan.py @@ -1,6 +1,7 @@ #encoding: utf-8 from torch.nn import ModuleList + from modules.aan import AverageAttn def share_aan_cache(netin): @@ -14,6 +15,6 @@ def share_aan_cache(netin): if _cache is None: _cache = layer.w else: - layer.w = _cache + layer.register_buffer("w", _cache, persistent=False) return netin diff --git a/utils/angle.py b/utils/angle.py index 3e0d11b..a909bfe 100644 --- a/utils/angle.py +++ b/utils/angle.py @@ -3,14 +3,18 @@ import torch from math import sqrt +from utils.torch.comp import torch_no_grad + +from cnfg.ihyp import ieps_default + def prep_cos(og, ng): return (og * ng).sum(), og.pow(2).sum(), ng.pow(2).sum() -def cos_acc_pg(old_pg, new_pg): +def cos_acc_pg(old_pg, new_pg, ieps=ieps_default): - with torch.no_grad(): + with torch_no_grad(): on, o, n = zip(*[prep_cos(ou, nu) for ou, nu in zip(old_pg, new_pg)]) - sim = (torch.stack(on, 0).sum() / (torch.stack(o, 0).sum() * torch.stack(n, 0).sum()).sqrt()).item() + sim = (torch.stack(on, 0).sum() / (torch.stack(o, 0).sum() * torch.stack(n, 0).sum()).sqrt().add(ieps)).item() - return sim + return min(max(-1.0, sim), 1.0) diff --git a/utils/base.py b/utils/base.py index 03e26b2..bf610ab 100644 --- a/utils/base.py +++ b/utils/base.py @@ -1,170 +1,16 @@ #encoding: utf-8 +import logging import torch -from torch import Tensor -from torch.nn import ModuleDict -from os import makedirs, remove -from os.path import exists as fs_check -from threading import Thread from functools import wraps -from random import sample, seed as rpyseed from math import ceil +from os import makedirs +from os.path import exists as fs_check +from random import seed as rpyseed +from torch import Tensor +from torch.nn import ModuleDict -import logging - -from utils.h5serial import h5save, h5load 
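Aside: the cos_acc_pg change above adds a small epsilon to the denominator and clamps the result to [-1, 1], since floating-point rounding can push the cosine ratio slightly out of range and make a later acos() return NaN. A standalone sketch with made-up gradient tensors (not the library function) of the guarded computation:

import torch
from math import acos, degrees

og = [torch.randn(3, 4), torch.randn(8)]            # "old" per-parameter gradients
ng = [g + 0.01 * torch.randn_like(g) for g in og]   # "new" per-parameter gradients

num = sum((o * n).sum() for o, n in zip(og, ng))
den = (sum(o.pow(2).sum() for o in og) * sum(n.pow(2).sum() for n in ng)).sqrt()
sim = min(max(-1.0, (num / (den + 1e-9)).item()), 1.0)  # epsilon + clamp, as in cos_acc_pg
print(degrees(acos(sim)))                               # angle between the two gradient directions
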
- -from cnfg.ihyp import h5modelwargs, optm_step_zero_grad_set_none, n_keep_best, use_deterministic, enable_torch_check - -try: - torch.autograd.set_detect_anomaly(enable_torch_check) -except Exception as e: - print(e) - -secure_type_map = {torch.float16: torch.float64, torch.float32: torch.float64, torch.uint8: torch.int64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64} - -def all_done_bool(stat, *inputs, **kwargs): - - return stat.all().item() - -def all_done_byte(stat, bsize=None, **kwargs): - - return stat.int().sum().item() == (stat.numel() if bsize is None else bsize) - -def exist_any_bool(stat): - - return stat.any().item() - -def exist_any_byte(stat): - - return stat.int().sum().item() > 0 - -def torch_all_bool_wodim(x, *inputs, **kwargs): - - return x.all(*inputs, **kwargs) - -def torch_all_byte_wodim(x, *inputs, **kwargs): - - return x.int().sum(*inputs, **kwargs).eq(x.numel()) - -def torch_all_bool_dim(x, dim, *inputs, **kwargs): - - return x.all(dim, *inputs, **kwargs) - -def torch_all_byte_dim(x, dim, *inputs, **kwargs): - - return x.int().sum(*inputs, dim=dim, **kwargs).eq(x.size(dim)) - -def torch_all_bool(x, *inputs, dim=None, **kwargs): - - return x.all(*inputs, **kwargs) if dim is None else x.all(dim, *inputs, **kwargs) - -def torch_all_byte(x, *inputs, dim=None, **kwargs): - - return x.int().sum(*inputs, **kwargs).eq(x.numel()) if dim is None else x.int().sum(*inputs, dim=dim, **kwargs).eq(x.size(dim)) - -def torch_any_bool_wodim(x, *inputs, **kwargs): - - return x.any(*inputs, **kwargs) - -def torch_any_byte_wodim(x, *inputs, **kwargs): - - return x.int().sum(*inputs, **kwargs).gt(0) - -def torch_any_bool_dim(x, dim, *inputs, **kwargs): - - return x.any(dim, *inputs, **kwargs) - -def torch_any_byte_dim(x, dim, *inputs, **kwargs): - - return x.int().sum(*inputs, dim=dim, **kwargs).gt(0) - -def torch_any_bool(x, *inputs, dim=None, **kwargs): - - return x.any(*inputs, **kwargs) if dim is None else x.any(dim, *inputs, **kwargs) - -def torch_any_byte(x, *inputs, dim=None, **kwargs): - - return x.int().sum(*inputs, **kwargs).gt(0) if dim is None else x.int().sum(*inputs, dim=dim, **kwargs).gt(0) - -def flip_mask_bool(mask, dim): - - return mask.to(torch.uint8, non_blocking=True).flip(dim).to(mask.dtype, non_blocking=True) - -def flip_mask_byte(mask, dim): - - return mask.flip(dim) - -class EmptyAutocast: - - def __init__(self, *inputs, **kwargs): - - self.args, self.kwargs = inputs, kwargs - - def __enter__(self): - - return self - - def __exit__(self, *inputs, **kwargs): - - pass - -class EmptyGradScaler: - - def __init__(self, *args, **kwargs): - - self.args, self.kwargs = args, kwargs - - def scale(self, outputs): - - return outputs - - def step(self, optimizer, *args, **kwargs): - - return optimizer.step(*args, **kwargs) - - def update(self, *args, **kwargs): - - pass - -def is_autocast_enabled_empty(*args, **kwargs): - - return False - -# handling torch.bool -try: - mask_tensor_type = torch.bool - secure_type_map[mask_tensor_type] = torch.int64 - nccl_type_map = {torch.bool:torch.uint8} - all_done = all_done_bool - exist_any = exist_any_bool - torch_all = torch_all_bool - torch_all_dim = torch_all_bool_dim - torch_all_wodim = torch_all_bool_wodim - torch_any = torch_any_bool - torch_any_dim = torch_any_bool_dim - torch_any_wodim = torch_any_bool_wodim - flip_mask = flip_mask_bool -except Exception as e: - mask_tensor_type = torch.uint8 - nccl_type_map = None - all_done = all_done_byte - exist_any = exist_any_byte - torch_all = torch_all_byte 
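Aside: the block being deleted here is the torch.bool capability probe; judging by the new imports elsewhere in this patch, the equivalents are expected to live under utils/torch/comp. A condensed sketch of the dispatch pattern (only all_done shown, other helpers follow the same scheme):

import torch

def all_done_bool(stat, bsize=None, **kwargs):
    # torch.bool masks support .all() directly
    return stat.all().item()

def all_done_byte(stat, bsize=None, **kwargs):
    # uint8 fallback: count the set entries instead
    return stat.int().sum().item() == (stat.numel() if bsize is None else bsize)

try:
    mask_tensor_type = torch.bool
    all_done = all_done_bool
except Exception:
    mask_tensor_type = torch.uint8
    all_done = all_done_byte

done = torch.tensor([True, False])
print(all_done(done))  # False: at least one sequence is still decoding
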
- torch_all_dim = torch_all_byte_dim - torch_all_wodim = torch_all_byte_wodim - torch_any = torch_any_byte - torch_any_dim = torch_any_byte_dim - torch_any_wodim = torch_any_byte_wodim - flip_mask = flip_mask_byte - -# handling torch.cuda.amp, fp16 will NOT be really enabled if torch.cuda.amp does not exist (for early versions) -try: - from torch.cuda.amp import autocast, GradScaler - is_autocast_enabled = torch.is_autocast_enabled - fp16_supported = True -except Exception as e: - autocast, GradScaler, is_autocast_enabled, fp16_supported = EmptyAutocast, EmptyGradScaler, is_autocast_enabled_empty, False +from cnfg.vocab.base import pad_id def pad_tensors(tensor_list, dim=-1): @@ -183,9 +29,9 @@ def get_pad_size(tsize, stdlen, dim=-1): return [tensor if tensor.size(dim) == maxlen else torch.cat((tensor, tensor.new_zeros(get_pad_size(tensor.size(), maxlen))), dim) for tensor in tensor_list] -def clear_pad(batch_in, mask=None, dim=-1): +def clear_pad(batch_in, mask=None, dim=-1, pad_id=pad_id): - _mask = batch_in.eq(0) if mask is None else mask + _mask = batch_in.eq(pad_id) if mask is None else mask npad = _mask.int().sum(dim).min().item() if npad > 0: return batch_in.narrow(dim, 0, batch_in.size(dim) - npad) @@ -204,24 +50,6 @@ def clear_pad_mask(batch_list, mask, dims, mask_dim=-1, return_contiguous=True): else: return batch_list, mask -def freeze_module(module): - - for p in module.parameters(): - if p.requires_grad: - p.requires_grad_(False) - -def unfreeze_module(module): - - def unfreeze_fixing(mod): - - if hasattr(mod, "fix_unfreeze"): - mod.fix_unfreeze() - - for p in module.parameters(): - p.requires_grad_(True) - - module.apply(unfreeze_fixing) - def eq_indexes(tensor, indexes): rs = None @@ -232,134 +60,6 @@ def eq_indexes(tensor, indexes): rs |= tensor.eq(ind) return rs -def getlr(optm): - - lr = [] - for i, param_group in enumerate(optm.param_groups): - lr.append(float(param_group["lr"])) - - return lr - -def updated_lr(oldlr, newlr): - - rs = False - for olr, nlr in zip(oldlr, newlr): - if olr != nlr: - rs = True - break - - return rs - -def reset_Adam(optm, amsgrad=False): - - for group in optm.param_groups: - for p in group["params"]: - state = optm.state[p] - if len(state) != 0: - state["step"] = 0 - state["exp_avg"].zero_() - state["exp_avg_sq"].zero_() - if amsgrad: - state["max_exp_avg_sq"].zero_() - -def reinit_Adam(optm, amsgrad=False): - - for group in optm.param_groups: - for p in group["params"]: - optm.state[p].clear() - -def dynamic_sample(incd, dss_ws, dss_rm): - - rd = {} - for k, v in incd.items(): - if v in rd: - rd[v].append(k) - else: - rd[v] = [k] - incs = list(rd.keys()) - incs.sort(reverse=True) - _full_rl = [] - for v in incs: - _full_rl.extend(rd[v]) - - return _full_rl[:dss_ws] + sample(_full_rl[dss_ws:], dss_rm) if dss_rm > 0 else _full_rl[:dss_ws] - -def load_model_cpu(modf, base_model): - - mpg = h5load(modf) - - for para, mp in zip(base_model.parameters(), mpg): - para.data = mp.data - - return base_model - -def load_model_cpu_old(modf, base_model): - - base_model.load_state_dict(h5load(modf)) - - return base_model - -class SaveModelCleaner: - - def __init__(self): - - self.holder = {} - - def __call__(self, fname, typename): - - if typename in self.holder: - self.holder[typename].update(fname) - else: - self.holder[typename] = bestfkeeper(fnames=[fname]) - -save_model_cleaner = SaveModelCleaner() - -def save_model(model, fname, sub_module=False, print_func=print, mtyp=None, h5args=h5modelwargs): - - _msave = model.module if sub_module else model - 
try: - h5save([t.data for t in _msave.parameters()], fname, h5args=h5args) - if mtyp is not None: - save_model_cleaner(fname, mtyp) - except Exception as e: - if print_func is not None: - print_func(str(e)) - -def async_save_model(model, fname, sub_module=False, print_func=print, mtyp=None, h5args=h5modelwargs, para_lock=None, log_success=None): - - def _worker(model, fname, sub_module=False, print_func=print, mtyp=None, para_lock=None, log_success=None): - - success = True - _msave = model.module if sub_module else model - try: - if para_lock is None: - h5save([t.data for t in _msave.parameters()], fname, h5args=h5args) - if mtyp is not None: - save_model_cleaner(fname, mtyp) - else: - with para_lock: - h5save([t.data for t in _msave.parameters()], fname, h5args=h5args) - if mtyp is not None: - save_model_cleaner(fname, mtyp) - except Exception as e: - if print_func is not None: - print_func(str(e)) - success = False - if success and (print_func is not None) and (log_success is not None): - print_func(str(log_success)) - - Thread(target=_worker, args=(model, fname, sub_module, print_func, mtyp, para_lock, log_success)).start() - -def save_states(state_dict, fname, print_func=print, mtyp=None): - - try: - torch.save(state_dict, fname) - if mtyp is not None: - save_model_cleaner(fname, mtyp) - except Exception as e: - if print_func is not None: - print_func(str(e)) - def get_logger(fname): logger = logging.getLogger(__name__) @@ -384,52 +84,6 @@ def set_random_seed(seed, set_cuda=False): torch.manual_seed(_rseed) if set_cuda: torch.cuda.manual_seed_all(_rseed) - try: - torch.backends.cuda.matmul.allow_tf32 = use_deterministic - torch.backends.cudnn.allow_tf32 = use_deterministic - torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow_fp16_reduction - except: - pass - # Make cudnn methods deterministic according to: https://pytorch.org/docs/stable/notes/randomness.html#cudnn - try: - torch.use_deterministic_algorithms(use_deterministic) - except: - torch.backends.cudnn.deterministic = use_deterministic - torch.backends.cudnn.benchmark = False - -def module_train(netin, module, mode=True): - - for net in netin.modules(): - if isinstance(net, module): - net.train(mode=mode) - - return netin - -def repeat_bsize_for_beam_tensor(tin, beam_size): - - _tsize = list(tin.size()) - _rarg = [1 for i in range(len(_tsize))] - _rarg[1] = beam_size - _tsize[0] *= beam_size - - return tin.repeat(*_rarg).view(_tsize) - -def expand_bsize_for_beam(*inputs, beam_size=1): - - outputs = [] - for inputu in inputs: - if isinstance(inputu, Tensor): - outputs.append(repeat_bsize_for_beam_tensor(inputu, beam_size)) - elif isinstance(inputu, dict): - outputs.append({k: expand_bsize_for_beam(v, beam_size=beam_size) for k, v in inputu.items()}) - elif isinstance(inputu, tuple): - outputs.append(tuple(expand_bsize_for_beam(tmpu, beam_size=beam_size) for tmpu in inputu)) - elif isinstance(inputu, list): - outputs.append([expand_bsize_for_beam(tmpu, beam_size=beam_size) for tmpu in inputu]) - else: - outputs.append(inputu) - - return outputs[0] if len(inputs) == 1 else tuple(outputs) def index_tensors(*inputs, indices=None, dim=0): @@ -477,20 +131,99 @@ def ModuleList2Dict(modin): return ModuleDict(zip([str(i) for i in range(len(modin))], modin)) -def add_module(m, strin, m_add): +def get_module_nl(m, nl): + + _m, _success = m, True + for _tmp in nl: + # update _modules with pytorch: https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.add_module + if _tmp in _m._modules: + _m = 
_m._modules[_tmp] + else: + _success = False + break + + return _m, _success + +def add_module(m, strin, m_add, print_func=print, **kwargs): _name_list = strin.split(".") if len(_name_list) == 1: m.add_module(strin, m_add) else: - _m = m - # update _modules with pytorch: https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.add_module - for _tmp in _name_list[:-1]: - _m = _m._modules[_tmp] - _m.add_module(_name_list[-1], m_add) + _m, _success = get_module_nl(m, _name_list[:-1]) + if _success: + _m.add_module(_name_list[-1], m_add) + elif print_func is not None: + print_func(strin) + + return m + +def add_parameter(m, strin, p_add, print_func=print, **kwargs): + + _name_list = strin.split(".") + if len(_name_list) == 1: + m.register_parameter(strin, p_add) + else: + _m, _success = get_module_nl(m, _name_list[:-1]) + if _success: + _m.register_parameter(_name_list[-1], p_add) + elif print_func is not None: + print_func(strin) return m +def add_buffer(m, strin, b_add, persistent=True, print_func=print, **kwargs): + + _name_list = strin.split(".") + if len(_name_list) == 1: + m.register_buffer(strin, b_add, persistent=persistent) + else: + _m, _success = get_module_nl(m, _name_list[:-1]) + if _success: + _m.register_buffer(_name_list[-1], b_add, persistent=persistent) + elif print_func is not None: + print_func(strin) + + return m + +def is_buffer_persistent(m, strin, persistent=True, print_func=print, **kwargs): + + _name_list = strin.split(".") + rs = persistent + if len(_name_list) == 1: + # update _non_persistent_buffers_set with pytorch: https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module.register_buffer + if hasattr(m, "_non_persistent_buffers_set"): + rs = strin not in m._non_persistent_buffers_set + else: + _m, _success = get_module_nl(m, _name_list[:-1]) + if _success: + if hasattr(_m, "_non_persistent_buffers_set"): + rs = _name_list[-1] not in _m._non_persistent_buffers_set + elif print_func is not None: + print_func(strin) + + return rs + +def bind_module_parameter(srcm, tgtm, **kwargs): + + _ = tgtm + for _n, _p in srcm.named_parameters(): + _ = add_parameter(_, _n, _p, **kwargs) + + return _ + +def bind_module_buffer(srcm, tgtm, persistent=None, **kwargs): + + _ = tgtm + for _n, _b in srcm.named_buffers(): + _ = add_buffer(_, _n, _b, persistent=is_buffer_persistent(srcm, _n) if persistent is None else persistent, **kwargs) + + return _ + +def bind_module_parabuf(srcm, tgtm, persistent=None, **kwargs): + + return bind_module_buffer(srcm, bind_module_parameter(srcm, tgtm, **kwargs), persistent=persistent, **kwargs) + def reduce_model_core(modin, redm, attr_func=None): if attr_func is None: @@ -572,32 +305,6 @@ def iternext(iterin): return rs -def optm_step_std(optm, model=None, scaler=None, closure=None, multi_gpu=False, multi_gpu_optimizer=False, zero_grad_none=optm_step_zero_grad_set_none): - - if multi_gpu: - model.collect_gradients() - if scaler is None: - optm.step(closure=closure) - else: - scaler.step(optm, closure=closure) - scaler.update() - if not multi_gpu_optimizer: - optm.zero_grad(set_to_none=zero_grad_none) - if multi_gpu: - model.update_replicas() - -def optm_step_wofp16(optm, model=None, scaler=None, closure=None, multi_gpu=False, multi_gpu_optimizer=False, zero_grad_none=optm_step_zero_grad_set_none): - - if multi_gpu: - model.collect_gradients() - optm.step(closure=closure) - if not multi_gpu_optimizer: - optm.zero_grad(set_to_none=zero_grad_none) - if multi_gpu: - model.update_replicas() - -optm_step = 
optm_step_std if fp16_supported else optm_step_wofp16 - def divide_para_ind(para_list, ngroup, return_np=False): elel = [pu.numel() for pu in para_list] @@ -706,28 +413,3 @@ def get_hold(self, k, sv=None): def __exit__(self, *inputs, **kwargs): pass - -class bestfkeeper: - - def __init__(self, fnames=None, k=n_keep_best): - - self.fnames, self.k = [] if fnames is None else fnames, k - self.clean() - - def update(self, fname=None): - - self.fnames.append(fname) - self.clean(last_fname=fname) - - def clean(self, last_fname=None): - - _n_files = len(self.fnames) - _last_fname = (self.fnames[-1] if self.fnames else None) if last_fname is None else last_fname - while _n_files > self.k: - fname = self.fnames.pop(0) - if (fname is not None) and (fname != _last_fname) and fs_check(fname): - try: - remove(fname) - except Exception as e: - print(e) - _n_files -= 1 diff --git a/utils/comm.py b/utils/comm.py index dcdc999..13bc19b 100644 --- a/utils/comm.py +++ b/utils/comm.py @@ -1,7 +1,8 @@ #encoding: utf-8 import torch.cuda.comm as comm -from utils.base import nccl_type_map + +from utils.torch.comp import nccl_type_map def secure_broadcast_coalesced(tensors, devices, buffer_size=10485760): diff --git a/utils/contpara.py b/utils/contpara.py index 2fbff2e..b48d949 100644 --- a/utils/contpara.py +++ b/utils/contpara.py @@ -2,14 +2,14 @@ # WARNING: this file may create _contiguous_parameters to the model -import torch from torch import nn from utils.base import filter_para_grad +from utils.torch.comp import torch_no_grad class ContiguousParams(nn.Module): - def __init__(self, parameters=None, init_tensors=None): + def __init__(self, parameters=None, init_tensors=None, **kwargs): super(ContiguousParams, self).__init__() @@ -56,7 +56,7 @@ def allocate(self, parameters=None, init_tensors=None): def bind(self, update=True): - with torch.no_grad(): + with torch_no_grad(): for pl, weight in zip(self.pll, self.weights): if len(pl) > 1: lind = 0 @@ -73,7 +73,7 @@ def bind(self, update=True): def bind_data(self, update=True): - with torch.no_grad(): + with torch_no_grad(): for pl, weight in zip(self.pll, self.weights): if len(pl) > 1: lind = 0 diff --git a/utils/decode/__init__.py b/utils/decode/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/decode/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/decode/base.py b/utils/decode/base.py new file mode 100644 index 0000000..768db77 --- /dev/null +++ b/utils/decode/base.py @@ -0,0 +1,26 @@ +#encoding: utf-8 + +def set_is_decoding(m, mode): + + for _ in m.modules(): + if hasattr(_, "is_decoding"): + if isinstance(_.is_decoding, bool): + _.is_decoding = mode + else: + _.is_decoding(mode) + + return m + +class model_decoding: + + def __init__(self, net, **kwargs): + + self.net = net + + def __enter__(self): + + return set_is_decoding(self.net, True) + + def __exit__(self, *inputs, **kwargs): + + set_is_decoding(self.net, False) diff --git a/utils/decode/beam.py b/utils/decode/beam.py new file mode 100644 index 0000000..9fb646e --- /dev/null +++ b/utils/decode/beam.py @@ -0,0 +1,29 @@ +#encoding: utf-8 + +from torch import Tensor + +def repeat_bsize_for_beam_tensor(tin, beam_size): + + _tsize = list(tin.size()) + _rarg = [1 for i in range(len(_tsize))] + _rarg[1] = beam_size + _tsize[0] *= beam_size + + return tin.repeat(*_rarg).view(_tsize) + +def expand_bsize_for_beam(*inputs, beam_size=1): + + outputs = [] + for inputu in inputs: + if isinstance(inputu, Tensor): + outputs.append(repeat_bsize_for_beam_tensor(inputu, 
beam_size)) + elif isinstance(inputu, dict): + outputs.append({k: expand_bsize_for_beam(v, beam_size=beam_size) for k, v in inputu.items()}) + elif isinstance(inputu, tuple): + outputs.append(tuple(expand_bsize_for_beam(tmpu, beam_size=beam_size) for tmpu in inputu)) + elif isinstance(inputu, list): + outputs.append([expand_bsize_for_beam(tmpu, beam_size=beam_size) for tmpu in inputu]) + else: + outputs.append(inputu) + + return outputs[0] if len(inputs) == 1 else tuple(outputs) diff --git a/utils/dynbatch.py b/utils/dynbatch.py index 9a7b646..78bb44b 100644 --- a/utils/dynbatch.py +++ b/utils/dynbatch.py @@ -1,13 +1,13 @@ #encoding: utf-8 -import torch - -from math import log2, exp, pi, acos +from math import acos, exp, log2, pi from random import random -from utils.angle import prep_cos, cos_acc_pg + +from utils.angle import cos_acc_pg from utils.random import multinomial +from utils.torch.comp import torch_no_grad -# comment the following line and uncomment the 4 lines following it to load para_group_select_alpha from cnfg.dynb +# comment the following line and uncomment the 4 lines below it to load para_group_select_alpha from cnfg.dynb para_group_select_alpha = 3.0 """try: from cnfg.dynb import select_alpha as para_group_select_alpha @@ -36,19 +36,21 @@ def pos_norm(lin, alpha=1.0): if alpha != 1.0: tmp = [tmpu ** alpha for tmpu in tmp] _mv = sum(tmp) + if _mv == 0.0: + _mv = 1.0 return [tmpu / _mv for tmpu in tmp] def backup_para_grad(plin): - with torch.no_grad(): + with torch_no_grad(): rs = [pu.grad.clone() for pu in plin] return rs class EffRecorder: - def __init__(self, num_choice, num_his=50, init_value=180.0): + def __init__(self, num_choice, num_his=50, init_value=180.0, **kwargs): self.his = [[init_value] for i in range(num_choice)] self.num_his = num_his @@ -67,7 +69,7 @@ def get_w(self): class MvAvgRecorder: - def __init__(self, num_choice, beta=None, num_his=50, init_value=180.0): + def __init__(self, num_choice, beta=None, num_his=50, init_value=180.0, **kwargs): self.beta = (0.9 if num_his is None else (0.5 ** (1.0 / num_his))) if beta is None else beta self.his = [(init_value * (1.0 - self.beta)) for i in range(num_choice)] @@ -92,7 +94,7 @@ class GradientMonitor: # num_his_gm: cache num_his_gm gradients into a history, and return this number of angle changes. # returns: (update_r, angle_r), update_r: to performing an optimization step, angle_r: the angle change in current step. 
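Aside: repeat_bsize_for_beam_tensor, relocated to utils/decode/beam.py above, duplicates each sample's row beam_size times consecutively so per-sample state can follow the expanded batch through beam search. A small usage sketch with a toy tensor:

import torch

def repeat_bsize_for_beam_tensor(tin, beam_size):
    _tsize = list(tin.size())
    _rarg = [1 for _ in _tsize]
    _rarg[1] = beam_size
    _tsize[0] *= beam_size
    return tin.repeat(*_rarg).view(_tsize)

x = torch.tensor([[1, 2, 3], [4, 5, 6]])   # (bsize=2, seql=3)
print(repeat_bsize_for_beam_tensor(x, 2))
# tensor([[1, 2, 3],
#         [1, 2, 3],
#         [4, 5, 6],
#         [4, 5, 6]])   -> (bsize * beam_size, seql), beam copies of sample 0 first
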
- def __init__(self, num_group, select_func, module=None, angle_alpha=1.1, num_tol_amin=3, num_his_record=50, num_his_gm=1): + def __init__(self, num_group, select_func, module=None, angle_alpha=1.1, num_tol_amin=3, num_his_record=50, num_his_gm=1, **kwargs): self.scale = 180.0 / pi self.num_group = num_group @@ -160,6 +162,8 @@ def get_delta(lin): def get_delta_norm(lin): _mv = max(*lin) + if _mv == 0.0: + _mv = 1.0 return (_mv - min(*lin)) / _mv diff --git a/utils/fmt/ape/triple.py b/utils/fmt/ape/triple.py index 40fcd78..cd14a6e 100644 --- a/utils/fmt/ape/triple.py +++ b/utils/fmt/ape/triple.py @@ -1,23 +1,28 @@ #encoding: utf-8 -from utils.fmt.base import list_reader, get_bsize, map_batch, pad_batch from math import ceil -def batch_loader(finput, fmt, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, list_reader as file_reader, pad_batch +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +def batch_loader(finput, fmt, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): _f_maxpart = float(maxpart) rsi = [] rsm = [] rst = [] nd = maxlen = mlen_i = mlen_m = mlen_t = 0 - for i_d, md, td in zip(list_reader(finput, keep_empty_line=True), list_reader(fmt, keep_empty_line=True), list_reader(ftarget, keep_empty_line=True)): + for i_d, md, td in zip(file_reader(finput, keep_empty_line=True), file_reader(fmt, keep_empty_line=True), file_reader(ftarget, keep_empty_line=True)): lid = len(i_d) lmd = len(md) ltd = len(td) lgth = lid + lmd + ltd if maxlen == 0: - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(lgth + _maxpad * 3, maxtoken, bsize) if (nd < minbsize) or (lgth <= maxlen and nd < _bsize): rsi.append(i_d) rsm.append(md) @@ -37,23 +42,22 @@ def batch_loader(finput, fmt, ftarget, bsize, maxpad, maxpart, maxtoken, minbsiz mlen_i = lid mlen_m = lmd mlen_t = ltd - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(lgth + _maxpad * 3, maxtoken, bsize) nd = 1 if rsi: yield rsi, rsm, rst, mlen_i, mlen_m, mlen_t -def batch_mapper(finput, fmt, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None): +def batch_mapper(finput, fmt, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, md, td, mlen_i, mlen_m, mlen_t in _batch_loader(finput, fmt, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, md, td, mlen_i, mlen_m, mlen_t in batch_loader(finput, fmt, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): rsi, extok_i = map_batch(i_d, vocabi) rsm, extok_m = map_batch(md, vocabt) rst, extok_t = map_batch(td, vocabt) yield rsi, rsm, rst, mlen_i + extok_i, mlen_m + extok_m, mlen_t + extok_t -def batch_padder(finput, fmt, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, custom_batch_mapper=None): +def batch_padder(finput, fmt, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper, pad_id=pad_id, **kwargs): - _batch_mapper = batch_mapper if 
custom_batch_mapper is None else custom_batch_mapper - for i_d, md, td, mlen_i, mlen_m, mlen_t in _batch_mapper(finput, fmt, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader): - yield pad_batch(i_d, mlen_i), pad_batch(md, mlen_m), pad_batch(td, mlen_t) + for i_d, md, td, mlen_i, mlen_m, mlen_t in batch_mapper(finput, fmt, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield pad_batch(i_d, mlen_i, pad_id=pad_id), pad_batch(md, mlen_m, pad_id=pad_id), pad_batch(td, mlen_t, pad_id=pad_id) diff --git a/utils/fmt/base.py b/utils/fmt/base.py index 364bacf..decaf46 100644 --- a/utils/fmt/base.py +++ b/utils/fmt/base.py @@ -1,21 +1,38 @@ #encoding: utf-8 import sys - +from bz2 import open as bz_open +from gzip import open as gz_open +from lzma import open as xz_open from random import shuffle -from cnfg.vocab.base import * +from cnfg.hyp import raw_cache_compression_level +from cnfg.vocab.base import pad_id serial_func, deserial_func = repr, eval -tostr = lambda lin: [str(lu) for lu in lin] -toint = lambda lin: [int(lu) for lu in lin] -tofloat = lambda lin: [float(lu) for lu in lin] +iter_to_str = lambda lin: map(str, lin) +iter_to_int = lambda lin: map(int, lin) +iter_to_float = lambda lin: map(float, lin) + +def sys_open(fname, mode="r", compresslevel=raw_cache_compression_level, **kwargs): + + if fname == "-": + return ((sys.stdin.buffer if "r" in mode else sys.stdout.buffer) if "b" in mode else (sys.stdin if "r" in mode else sys.stdout)) + else: + if fname.endswith(".gz"): + return gz_open(fname, mode=mode, compresslevel=compresslevel, **kwargs) + elif fname.endswith(".bz2"): + return bz_open(fname, mode=mode, compresslevel=compresslevel, **kwargs) + elif fname.endswith(".xz"): + return xz_open(fname, mode=mode, **kwargs) + else: + return open(fname, mode=mode, **kwargs) def save_objects(fname, *inputs): ens = "\n".encode("utf-8") - with sys.stdout.buffer if fname == "-" else open(fname, "wb") as f: + with sys_open(fname, "wb") as f: for tmpu in inputs: f.write(serial_func(tmpu).encode("utf-8")) f.write(ens) @@ -23,7 +40,7 @@ def save_objects(fname, *inputs): def load_objects(fname): rs = [] - with sys.stdin.buffer if fname == "-" else open(fname, "rb") as f: + with sys_open(fname, "rb") as f: for line in f: tmp = line.strip() if tmp: @@ -34,7 +51,7 @@ def load_objects(fname): def load_states(fname): rs = [] - with sys.stdin.buffer if fname == "-" else open(fname, "rb") as f: + with sys_open(fname, "rb") as f: for line in f: tmp = line.strip() if tmp: @@ -44,13 +61,13 @@ def load_states(fname): return rs -def list_reader(fname, keep_empty_line=True, print_func=print): +def list_reader(fname, keep_empty_line=True, sep=None, print_func=print): - with sys.stdin.buffer if fname == "-" else open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: for line in frd: tmp = line.strip() if tmp: - tmp = clean_list(tmp.decode("utf-8").split()) + tmp = clean_list(tmp.decode("utf-8").split(sep=sep)) yield tmp else: if print_func is not None: @@ -60,115 +77,66 @@ def list_reader(fname, keep_empty_line=True, print_func=print): def line_reader(fname, keep_empty_line=True, print_func=print): - with sys.stdin.buffer if fname == "-" else open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: for line in frd: tmp = line.strip() if tmp: yield tmp.decode("utf-8") else: if print_func is not None: - print_func("Reminder: encounter an empty line, which shall not be the case.") + print_func("Reminder: encounter 
an empty line, which may not be the case.") if keep_empty_line: yield "" -def ldvocab(vfile, minf=False, omit_vsize=False, vanilla=False, init_vocab=init_vocab, init_normal_token_id=init_normal_token_id): +def line_char_reader(fname, keep_empty_line=True, print_func=print): - if vanilla: - rs, cwd = {}, 0 - else: - rs, cwd = init_vocab.copy(), init_normal_token_id - if omit_vsize: - vsize = omit_vsize - else: - vsize = False - for data in list_reader(vfile, keep_empty_line=False): - freq = int(data[0]) - if (not minf) or freq > minf: - if vsize: - ndata = len(data) - 1 - if vsize >= ndata: - for wd in data[1:]: - rs[wd] = cwd - cwd += 1 - else: - for wd in data[1:vsize + 1]: - rs[wd] = cwd - cwd += 1 - ndata = vsize - break - vsize -= ndata - if vsize <= 0: - break + with sys_open(fname, "rb") as frd: + for line in frd: + tmp = line.strip() + if tmp: + yield list(tmp.decode("utf-8")) else: - for wd in data[1:]: - rs[wd] = cwd - cwd += 1 - else: - break - return rs, cwd - -def save_vocab(vcb_dict, fname, omit_vsize=False): - - r_vocab = {} - for k, v in vcb_dict.items(): - if v not in r_vocab: - r_vocab[v]=[str(v), k] - else: - r_vocab[v].append(k) - - freqs = list(r_vocab.keys()) - freqs.sort(reverse=True) - - ens = "\n".encode("utf-8") - remain = omit_vsize - with sys.stdout.buffer if fname == "-" else open(fname, "wb") as f: - for freq in freqs: - cdata = r_vocab[freq] - ndata = len(cdata) - 1 - if remain and (remain < ndata): - cdata = cdata[:remain + 1] - ndata = remain - f.write(" ".join(cdata).encode("utf-8")) - f.write(ens) - if remain: - remain -= ndata - if remain <= 0: - break + if print_func is not None: + print_func("Reminder: encounter an empty line, which may not be the case.") + if keep_empty_line: + yield [] -def reverse_dict(din): +def list_reader_wst(fname, keep_empty_line=True, sep=None, print_func=print): - return {v:k for k, v in din.items()} + with sys_open(fname, "rb") as frd: + for line in frd: + tmp = line.strip(b"\r\n") + if tmp: + tmp = clean_list(tmp.decode("utf-8").split(sep=sep)) + yield tmp + else: + if print_func is not None: + print_func("Reminder: encounter an empty line, which may not be the case.") + if keep_empty_line: + yield [] -def ldvocab_list(vfile, minf=False, omit_vsize=False): +def line_reader_wst(fname, keep_empty_line=True, print_func=print): - rs = [] - if omit_vsize: - vsize = omit_vsize - else: - vsize = False - cwd = 0 - for data in list_reader(vfile, keep_empty_line=False): - freq = int(data[0]) - if (not minf) or freq > minf: - if vsize: - ndata = len(data) - 1 - if vsize >= ndata: - rs.extend(data[1:]) - cwd += ndata - else: - rs.extend(data[1:vsize + 1]) - cwd += vsize - break - vsize -= ndata - if vsize <= 0: - break + with sys_open(fname, "rb") as frd: + for line in frd: + tmp = line.strip(b"\r\n") + if tmp: + yield tmp.decode("utf-8") else: - rs.extend(data[1:]) - cwd += len(data) - 1 - else: - break + if print_func is not None: + print_func("Reminder: encounter an empty line, which may not be the case.") + if keep_empty_line: + yield "" - return rs, cwd +def loop_file_so(fsrc, frs, process_func=None, processor=None): + + ens = "\n".encode("utf-8") + with sys_open(fsrc, "rb") as frd, sys_open(frs, "wb") as fwrt: + for line in frd: + tmp = line.strip() + if tmp: + fwrt.write(process_func(tmp.decode("utf-8"), processor).encode("utf-8")) + fwrt.write(ens) def clean_str(strin): @@ -190,7 +158,60 @@ def clean_liststr_lentok(lin): return " ".join(rs), len(rs) -def maxfreq_filter_core(ls, lt): +def maxfreq_filter_many(inputs): + + tmp = {} 
+ for _ in inputs: + us, ut = tuple(_[:-1]), _[-1] + if us in tmp: + tmp[us][ut] = tmp[us].get(ut, 0) + 1 + else: + tmp[us] = {ut: 1} + + rs = [] + for tus, tlt in tmp.items(): + _rs = [] + _maxf = 0 + for key, value in tlt.items(): + if value > _maxf: + _maxf = value + _rs = [key] + elif value == _maxf: + _rs.append(key) + for tut in _rs: + rs.append((*tus, tut,)) + + return rs + +def maxfreq_filter_bi(inputs): + + tmp = {} + for us, ut in inputs: + if us in tmp: + tmp[us][ut] = tmp[us].get(ut, 0) + 1 + else: + tmp[us] = {ut: 1} + + rs = [] + for tus, tlt in tmp.items(): + _rs = [] + _maxf = 0 + for key, value in tlt.items(): + if value > _maxf: + _maxf = value + _rs = [key] + elif value == _maxf: + _rs.append(key) + for tut in _rs: + rs.append((tus, tut,)) + + return rs + +def maxfreq_filter(inputs): + + return maxfreq_filter_many(inputs) if len(inputs[0]) > 2 else maxfreq_filter_bi(inputs) + +def maxfreq_filter_core_pair(ls, lt): tmp = {} for us, ut in zip(ls, lt): @@ -215,14 +236,14 @@ def maxfreq_filter_core(ls, lt): return rls, rlt -def maxfreq_filter(*inputs): +def maxfreq_filter_pair(*inputs): if len(inputs) > 2: # here we assume that we only have one target and it is at the last position - rsh, rst = maxfreq_filter_core(tuple(zip(*inputs[0:-1])), inputs[-1]) + rsh, rst = maxfreq_filter_core_pair(tuple(zip(*inputs[:-1])), inputs[-1]) return *zip(*rsh), rst else: - return maxfreq_filter_core(*inputs) + return maxfreq_filter_core_pair(*inputs) def shuffle_pair(*inputs): @@ -239,19 +260,6 @@ def get_bsize(maxlen, maxtoken, maxbsize): return min(rs, maxbsize) -def no_unk_mapper(vcb, ltm, print_func=None): - - if print_func is None: - return [vcb[wd] for wd in ltm if wd in vcb] - else: - rs = [] - for wd in ltm: - if wd in vcb: - rs.append(vcb[wd]) - else: - print_func("Error mapping: "+ wd) - return rs - def list2dict(lin, kfunc=None): return {k: lu for k, lu in enumerate(lin)} if kfunc is None else {kfunc(k): lu for k, lu in enumerate(lin)} @@ -315,17 +323,23 @@ def dict_insert_list(dict_in, value, *keys): return dict_in -def legal_vocab(sent, ilgset, ratio): +def seperate_list_iter(lin, k): - total = ilg = 0 - for tmpu in sent.split(): - if tmpu: - if tmpu in ilgset: - ilg += 1 - total += 1 - rt = float(ilg) / float(total) + i = 0 + _ = [] + for lu in lin: + _.append(lu) + i += 1 + if i >= k: + yield _ + _ = [] + i = 0 + if _: + yield _ + +def seperate_list(lin, k): - return False if rt > ratio else True + return list(seperate_list_iter(lin, k)) def all_in(lin, setin): @@ -365,20 +379,6 @@ def get_bi_ratio(ls, lt): else: return float(lt) / float(ls) -def map_batch_core(i_d, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs): - - if isinstance(i_d[0], (tuple, list,)): - return [map_batch_core(idu, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs) for idu in i_d] - else: - rsi = [sos_id] - rsi.extend([vocabi.get(wd, unk_id) for wd in i_d] if use_unk else no_unk_mapper(vocabi, i_d))#[vocabi[wd] for wd in i_d if wd in vocabi] - rsi.append(eos_id) - return rsi - -def map_batch(i_d, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs): - - return map_batch_core(i_d, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs), 2 - def pad_batch(i_d, mlen_i, pad_id=pad_id): if isinstance(i_d[0], (tuple, list,)): @@ -393,7 +393,7 @@ class FileList(list): def __init__(self, files, *inputs, **kwargs): - super(FileList, self).__init__(open(fname, *inputs, **kwargs) for fname in files) + 
super(FileList, self).__init__(sys_open(fname, *inputs, **kwargs) for fname in files) def __enter__(self): @@ -408,9 +408,8 @@ def multi_line_reader(fname, *inputs, num_line=1, **kwargs): _i = 0 rs = [] - _enc = ("rb" in inputs) or ("rb" in kwargs.values()) - ens = "\n".encode("utf-8") if _enc else "\n" - with (sys.stdin.buffer if _enc else sys.stdin) if fname == "-" else open(fname, *inputs, **kwargs) as frd: + ens = "\n".encode("utf-8") if ("rb" in inputs) or ("rb" in kwargs.values()) else "\n" + with sys_open(fname, *inputs, **kwargs) as frd: for line in frd: tmp = line.rstrip() rs.append(tmp) @@ -421,3 +420,11 @@ def multi_line_reader(fname, *inputs, num_line=1, **kwargs): _i = 0 if rs: yield ens.join(rs) + +def read_lines(fin, num_lines): + + _last_ind = num_lines - 1 + for i, _ in enumerate(fin, 1): + yield _ + if i > _last_ind: + break diff --git a/utils/fmt/base4torch.py b/utils/fmt/base4torch.py index 64943f8..3042734 100644 --- a/utils/fmt/base4torch.py +++ b/utils/fmt/base4torch.py @@ -1,10 +1,11 @@ #encoding: utf-8 import torch - from math import sqrt + from utils.fmt.base import list_reader -from utils.h5serial import h5save, h5load +from utils.h5serial import h5load +from utils.torch.comp import torch_no_grad def parse_cuda(use_cuda_arg, gpuid=None): @@ -24,7 +25,7 @@ def parse_cuda(use_cuda_arg, gpuid=None): multi_gpu = False torch.cuda.set_device(cuda_device.index) else: - use_cuda, cuda_device, cuda_devices, multi_gpu = False, False, None, False + use_cuda, cuda_device, cuda_devices, multi_gpu = False, None, None, False return use_cuda, cuda_device, cuda_devices, multi_gpu @@ -47,7 +48,7 @@ def parse_cuda_decode(use_cuda_arg, gpuid=None, multi_gpu_decoding=False): multi_gpu = False torch.cuda.set_device(cuda_device.index) else: - use_cuda, cuda_device, cuda_devices, multi_gpu = False, False, None, False + use_cuda, cuda_device, cuda_devices, multi_gpu = False, None, None, False return use_cuda, cuda_device, cuda_devices, multi_gpu @@ -68,7 +69,7 @@ def load_emb(embf, embt, nword, scale_down_emb, freeze_emb): _emb = _emb.narrow(0, 0, nword).contiguous() if scale_down_emb: _emb.div_(sqrt(embt.size(-1))) - with torch.no_grad(): + with torch_no_grad(): embt.copy_(_emb) if freeze_emb: embt.requires_grad_(False) diff --git a/utils/fmt/char/__init__.py b/utils/fmt/char/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/char/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/char/dual.py b/utils/fmt/char/dual.py new file mode 100644 index 0000000..2d8e6e4 --- /dev/null +++ b/utils/fmt/char/dual.py @@ -0,0 +1,8 @@ +#encoding: utf-8 + +from utils.fmt.base import line_char_reader as file_reader +from utils.fmt.dual import batch_padder as batch_padder_base + +def batch_padder(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, file_reader=file_reader, **kwargs): + + return batch_padder_base(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, file_reader=file_reader, **kwargs) diff --git a/utils/fmt/char/single.py b/utils/fmt/char/single.py new file mode 100644 index 0000000..ae3cb45 --- /dev/null +++ b/utils/fmt/char/single.py @@ -0,0 +1,8 @@ +#encoding: utf-8 + +from utils.fmt.base import line_char_reader as file_reader +from utils.fmt.single import batch_padder as batch_padder_base + +def batch_padder(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, file_reader=file_reader, **kwargs): + + return batch_padder_base(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, 
file_reader=file_reader, **kwargs) diff --git a/utils/fmt/diff.py b/utils/fmt/diff.py new file mode 100644 index 0000000..85e0997 --- /dev/null +++ b/utils/fmt/diff.py @@ -0,0 +1,60 @@ +#encoding: utf-8 + +from difflib import SequenceMatcher + +def seq_diff(x, ref): + + for tag, xsi, xei, rsi, rei in SequenceMatcher(None, x, ref, autojunk=False).get_opcodes(): + _tc = tag[0] + if _tc == "d": + for _ in x[xsi:xei]: + yield _tc, _ + elif _tc == "e": + for _ in x[xsi:xei]: + yield _tc, _ + elif _tc == "i": + for _ in ref[rsi:rei]: + yield _tc, _ + else: + for _ in x[xsi:xei]: + yield "d", _ + for _ in ref[rsi:rei]: + yield "i", _ + +def reorder_insert(seqin): + + _d_cache = [] + for _du in seqin: + _op = _du[0] + if _op == "d": + _d_cache.append(_du) + else: + if (_op == "e") and _d_cache: + yield from _d_cache + _d_cache = [] + yield _du + if _d_cache: + yield from _d_cache + +def seq_diff_reorder_insert(x, ref): + + for tag, xsi, xei, rsi, rei in SequenceMatcher(None, x, ref, autojunk=False).get_opcodes(): + _tc = tag[0] + if _tc == "d": + for _ in x[xsi:xei]: + yield _tc, _ + elif _tc == "e": + for _ in x[xsi:xei]: + yield _tc, _ + elif _tc == "i": + for _ in ref[rsi:rei]: + yield _tc, _ + else: + for _ in ref[rsi:rei]: + yield "i", _ + for _ in x[xsi:xei]: + yield "d", _ + +seq_diff_ratio = lambda x, ref: SequenceMatcher(None, x, ref, autojunk=False).ratio() +seq_diff_ratio_ub = lambda x, ref: SequenceMatcher(None, x, ref, autojunk=False).quick_ratio() +seq_diff_ratio_ub_fast = lambda x, ref: SequenceMatcher(None, x, ref, autojunk=False).real_quick_ratio() diff --git a/utils/fmt/doc/base.py b/utils/fmt/doc/base.py index 62e7f1d..c44f3ed 100644 --- a/utils/fmt/doc/base.py +++ b/utils/fmt/doc/base.py @@ -2,17 +2,17 @@ import sys -from utils.fmt.base import clean_list +from utils.fmt.base import clean_list, sys_open -def doc_reader(fname): +def doc_reader(fname, sep=None): - with sys.stdin.buffer if fname == "-" else open(fname, "rb") as frd: + with sys_open(fname, "rb") as frd: cache = [] max_tok = 0 for line in frd: tmp = line.strip() if tmp: - tmp = clean_list(tmp.decode("utf-8").split()) + tmp = clean_list(tmp.decode("utf-8").split(sep=sep)) _ld = len(tmp) if _ld > max_tok: max_tok = _ld diff --git a/utils/fmt/doc/para/dual.py b/utils/fmt/doc/para/dual.py index c400934..63ecb04 100644 --- a/utils/fmt/doc/para/dual.py +++ b/utils/fmt/doc/para/dual.py @@ -1,23 +1,27 @@ #encoding: utf-8 -from utils.fmt.base import get_bsize, map_batch, pad_batch -from utils.fmt.doc.base import doc_reader from math import ceil -def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, pad_batch +from utils.fmt.doc.base import doc_reader as file_reader +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): _f_maxpart = float(maxpart) rsi = [] rst = [] nd = maxlen = minlen = mlen_i = mlen_t = nsent = 0 - for (i_d, i_lgth), (td, t_lgth) in zip(doc_reader(finput), doc_reader(ftarget)): + for (i_d, i_lgth), (td, t_lgth) in zip(file_reader(finput), file_reader(ftarget)): cur_nsent = len(i_d) lgth = i_lgth + t_lgth if maxlen == 0: _maxpad = max(1, min(maxpad, ceil(lgth / _f_maxpart)) // 2) maxlen = lgth + _maxpad minlen = lgth - _maxpad - _bsize = max(1, get_bsize(maxlen, maxtoken, bsize) // cur_nsent) + _bsize = max(1, get_bsize(maxlen + _maxpad, maxtoken, bsize) // cur_nsent) nsent = 
cur_nsent if (cur_nsent == nsent) and ((nd < minbsize) or (lgth <= maxlen and lgth >= minlen and nd < _bsize)): rsi.append(i_d) @@ -37,21 +41,19 @@ def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): _maxpad = max(1, min(maxpad, ceil(lgth / _f_maxpart)) // 2) maxlen = lgth + _maxpad minlen = lgth - _maxpad - _bsize = max(1, get_bsize(maxlen, maxtoken, bsize) // cur_nsent) + _bsize = max(1, get_bsize(maxlen + _maxpad, maxtoken, bsize) // cur_nsent) nd = 1 if rsi: yield rsi, rst, mlen_i, mlen_t, nsent -def batch_mapper(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None): +def batch_mapper(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, td, mlen_i, mlen_t, nsent in _batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, td, mlen_i, mlen_t, nsent in batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): rsi, extok_i = map_batch(i_d, vocabi) rst, extok_t = map_batch(td, vocabt) yield rsi, rst, mlen_i + extok_i, mlen_t + extok_t, nsent -def batch_padder(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, custom_batch_mapper=None): +def batch_padder(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper, pad_id=pad_id, **kwargs): - _batch_mapper = batch_mapper if custom_batch_mapper is None else custom_batch_mapper - for i_d, td, mlen_i, mlen_t, nsent in _batch_mapper(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader): - yield pad_batch(i_d, mlen_i), pad_batch(td, mlen_t), nsent + for i_d, td, mlen_i, mlen_t, nsent in batch_mapper(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield pad_batch(i_d, mlen_i, pad_id=pad_id), pad_batch(td, mlen_t, pad_id=pad_id), nsent diff --git a/utils/fmt/doc/para/many.py b/utils/fmt/doc/para/many.py new file mode 100644 index 0000000..4788ea1 --- /dev/null +++ b/utils/fmt/doc/para/many.py @@ -0,0 +1,72 @@ +#encoding: utf-8 + +from math import ceil + +from utils.fmt.base import get_bsize, pad_batch +from utils.fmt.doc.base import doc_reader as file_reader +from utils.fmt.doc.para.single import batch_padder as batch_padder_single +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +def batch_loader_many(filelist, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): + + _f_maxpart = float(maxpart) + rs = [[] for i in range(len(filelist))] + nd = maxlen = minlen = nsent = 0 + mlen = None + for linelens in zip(*[file_reader(f) for f in filelist]): + lines, lens = zip(*linelens) + cur_nsent = len(lines[0]) + lgth = sum(lens) + if maxlen == 0: + _maxpad = max(1, min(maxpad, ceil(lgth / _f_maxpart)) // 2) + maxlen = lgth + _maxpad + minlen = lgth - _maxpad + _bsize = max(1, get_bsize(lgth + _maxpad * len(lens), maxtoken, bsize) // cur_nsent) + mlen = lens + if (cur_nsent == nsent) and ((nd < minbsize) or (lgth <= maxlen and lgth >= minlen and nd < _bsize)): + for line, rsu in zip(lines, rs): + rsu.append(line) + for cur_len, (i, mlenu,) in zip(lens, enumerate(mlen)): + if cur_len > mlenu: + mlen[i] = cur_len + nd += 1 + else: + yield rs, mlen, nsent 
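# A minimal worked example of the padding window computed by the document-level
# batch loaders above. The limits and document length below are invented for
# illustration only; get_bsize() is not reproduced here because its definition
# is outside this hunk.
from math import ceil

maxpad, maxpart, lgth = 16, 8, 100   # hypothetical limits and current document length
_maxpad = max(1, min(maxpad, ceil(lgth / float(maxpart))) // 2)
maxlen, minlen = lgth + _maxpad, lgth - _maxpad
assert (_maxpad, maxlen, minlen) == (6, 106, 94)
# A document joins the current batch only if it keeps the same sentence count
# and its total length stays within [minlen, maxlen]; the batch size itself is
# further capped by get_bsize(lgth + _maxpad * len(lens), maxtoken, bsize) // cur_nsent.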
+ rs = [[line] for line in lines] + mlen = lens + nsent = cur_nsent + _maxpad = max(1, min(maxpad, ceil(lgth / _f_maxpart)) // 2) + maxlen = lgth + _maxpad + minlen = lgth - _maxpad + _bsize = max(1, get_bsize(lgth + _maxpad * len(lens), maxtoken, bsize) // cur_nsent) + nd = 1 + if rs: + yield rs, mlen, nsent + +def batch_mapper_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader_many, **kwargs): + + for _rs, _mlen, nsent in batch_loader(filelist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + rs = [] + mlen = [] + for rsu, mlenu, vocab in zip(_rs, _mlen, vocablist): + _rs, extok = map_batch(rsu, vocab) + rs.append(_rs) + mlen.append(mlenu + extok) + yield rs, mlen, nsent + +def batch_padder_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper_many, pad_id=pad_id, **kwargs): + + for rs, mlen, nsent in batch_mapper(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield *tuple(pad_batch(rsu, mlenu, pad_id=pad_id) for rsu, mlenu in zip(rs, mlen)), nsent + +def batch_padder(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + + if isinstance(filelist, (list, tuple,)): + if len(filelist) > 1: + return batch_padder_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) + else: + return batch_padder_single(filelist[0], vocablist[0], bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) + else: + return batch_padder_single(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) diff --git a/utils/fmt/doc/para/single.py b/utils/fmt/doc/para/single.py index befc626..019a95c 100644 --- a/utils/fmt/doc/para/single.py +++ b/utils/fmt/doc/para/single.py @@ -1,15 +1,8 @@ #encoding: utf-8 -from utils.fmt.base import map_batch -from utils.fmt.doc.mono.single import batch_loader, batch_padder as batch_padder_base +from utils.fmt.doc.mono.single import batch_padder as batch_padder_base +from utils.fmt.vocab.base import map_batch -def batch_mapper(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None): +def batch_padder(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, mlen_i, nsent in _batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): - rsi, extok_i = map_batch(i_d, vocabi) - yield rsi, mlen_i + extok_i, nsent - -def batch_padder(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, custom_batch_mapper=None): - - return batch_padder_base(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader, custom_batch_mapper=batch_mapper if custom_batch_mapper is None else custom_batch_mapper) + return batch_padder_base(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, **kwargs) diff --git a/utils/fmt/dual.py b/utils/fmt/dual.py index 42dcc69..3547450 100644 --- a/utils/fmt/dual.py +++ b/utils/fmt/dual.py @@ -1,21 +1,26 @@ #encoding: utf-8 -from utils.fmt.base import list_reader, get_bsize, map_batch, pad_batch from math import ceil -def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, list_reader as file_reader, pad_batch +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +def batch_loader(finput, 
ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): _f_maxpart = float(maxpart) rsi = [] rst = [] nd = maxlen = mlen_i = mlen_t = 0 - for i_d, td in zip(list_reader(finput, keep_empty_line=True), list_reader(ftarget, keep_empty_line=True)): + for i_d, td in zip(file_reader(finput, keep_empty_line=True), file_reader(ftarget, keep_empty_line=True)): lid = len(i_d) ltd = len(td) lgth = lid + ltd if maxlen == 0: - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) if (nd < minbsize) or (lgth <= maxlen and nd < _bsize): rsi.append(i_d) rst.append(td) @@ -30,22 +35,21 @@ def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): rst = [td] mlen_i = lid mlen_t = ltd - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) nd = 1 if rsi: yield rsi, rst, mlen_i, mlen_t -def batch_mapper(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None): +def batch_mapper(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, td, mlen_i, mlen_t in _batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, td, mlen_i, mlen_t in batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): rsi, extok_i = map_batch(i_d, vocabi) rst, extok_t = map_batch(td, vocabt) yield rsi, rst, mlen_i + extok_i, mlen_t + extok_t -def batch_padder(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, custom_batch_mapper=None): +def batch_padder(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper, pad_id=pad_id, **kwargs): - _batch_mapper = batch_mapper if custom_batch_mapper is None else custom_batch_mapper - for i_d, td, mlen_i, mlen_t in _batch_mapper(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader): - yield pad_batch(i_d, mlen_i), pad_batch(td, mlen_t) + for i_d, td, mlen_i, mlen_t in batch_mapper(finput, ftarget, vocabi, vocabt, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield pad_batch(i_d, mlen_i, pad_id=pad_id), pad_batch(td, mlen_t, pad_id=pad_id) diff --git a/utils/fmt/json.py b/utils/fmt/json.py index 95b4c9a..9135a30 100644 --- a/utils/fmt/json.py +++ b/utils/fmt/json.py @@ -1,16 +1,18 @@ #encoding: utf-8 -from json import loads, dumps +from json import dumps, loads + +from utils.fmt.base import sys_open def dumpf(obj, fname): _ = dumps(obj) - with open(fname, "wb") as f: + with sys_open(fname, "wb") as f: f.write(_.encode("utf-8")) def loadf(fname, print_func=print): - with open(fname, "rb") as f: + with sys_open(fname, "rb") as f: _ = f.read() try: return loads(_.decode("utf-8")) diff --git a/utils/fmt/lang/__init__.py b/utils/fmt/lang/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/lang/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/lang/zh/__init__.py 
b/utils/fmt/lang/zh/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/lang/zh/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/lang/zh/deseg.py b/utils/fmt/lang/zh/deseg.py new file mode 100644 index 0000000..f574216 --- /dev/null +++ b/utils/fmt/lang/zh/deseg.py @@ -0,0 +1,8 @@ +#encoding: utf-8 + +from re import compile + +re_space = compile(r"(? mlenu: + mlen[i] = cur_len + nd += 1 + else: + yield rs, mlen + rs = [[line] for line in lines] + mlen = lens + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(lgth + _maxpad * len(lens), maxtoken, bsize) + nd = 1 + if rs: + yield rs, mlen + +def batch_mapper_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader_many, **kwargs): + + for _rs, _mlen in batch_loader(filelist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + rs = [] + mlen = [] + for rsu, mlenu, vocab in zip(_rs, _mlen, vocablist): + _rs, extok = map_batch(rsu, vocab) + rs.append(_rs) + mlen.append(mlenu + extok) + yield rs, mlen + +def batch_padder_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper_many, pad_id=pad_id, **kwargs): + + for rs, mlen in batch_mapper(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield tuple(pad_batch(rsu, mlenu, pad_id=pad_id) for rsu, mlenu in zip(rs, mlen)) + +def batch_padder(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + + if isinstance(filelist, (list, tuple,)): + if len(filelist) > 1: + return batch_padder_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) + else: + return batch_padder_single(filelist[0], vocablist[0], bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) + else: + return batch_padder_single(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) diff --git a/utils/fmt/manyvalue.py b/utils/fmt/manyvalue.py new file mode 100644 index 0000000..db551be --- /dev/null +++ b/utils/fmt/manyvalue.py @@ -0,0 +1,62 @@ +#encoding: utf-8 + +from math import ceil + +from utils.fmt.base import get_bsize, line_reader, list_reader, pad_batch +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +file_reader = (list_reader, line_reader,) + +def batch_loader_many(filelist, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): + + _f_maxpart = float(maxpart) + rs = [[] for i in range(len(filelist))] + nd = maxlen = 0 + mlen = None + _list_reader, _line_reader = file_reader + for lines in zip(*([_list_reader(f, keep_empty_line=True) for f in filelist[:-1]] + [_line_reader(filelist[-1], keep_empty_line=True)])): + lens = [len(line) for line in lines[:-1]] + lgth = sum(lens) + if maxlen == 0: + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(lgth + _maxpad * len(lens), maxtoken, bsize) + mlen = lens + if (nd < minbsize) or (lgth <= maxlen and nd < _bsize): + for line, rsu in zip(lines[:-1], rs): + rsu.append(line) + rs[-1].append(float(lines[-1])) + for cur_len, (i, mlenu,) in zip(lens, enumerate(mlen)): + if cur_len > mlenu: + mlen[i] = cur_len + nd += 1 + else: + yield rs, mlen + rs = [[line] for line in lines[:-1]] + rs.append([float(lines[-1])]) + mlen = lens + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(lgth + _maxpad * len(lens), 
maxtoken, bsize) + nd = 1 + if rs: + yield rs, mlen + +def batch_mapper_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader_many, **kwargs): + + for _rs, _mlen in batch_loader(filelist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + rs = [] + mlen = [] + for rsu, mlenu, vocab in zip(_rs, _mlen, vocablist): + _rs, extok = map_batch(rsu, vocab) + rs.append(_rs) + mlen.append(mlenu + extok) + rs.append(_rs[-1]) + yield rs, mlen + +def batch_padder(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper_many, pad_id=pad_id, **kwargs): + + for rs, mlen in batch_mapper(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield *tuple(pad_batch(rsu, mlenu, pad_id=pad_id) for rsu, mlenu in zip(rs, mlen)), rs[-1] diff --git a/utils/fmt/mulang/eff/char/__init__.py b/utils/fmt/mulang/eff/char/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/mulang/eff/char/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/mulang/eff/char/dual.py b/utils/fmt/mulang/eff/char/dual.py new file mode 100644 index 0000000..c812a10 --- /dev/null +++ b/utils/fmt/mulang/eff/char/dual.py @@ -0,0 +1,53 @@ +#encoding: utf-8 + +from math import ceil + +from utils.fmt.base import get_bsize, line_reader as file_reader +from utils.fmt.mulang.eff.dual import batch_padder as batch_padder_base + +def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): + + _f_maxpart = float(maxpart) + rsi = [] + rst = [] + rstask = None + nd = maxlen = mlen_i = mlen_t = 0 + for i_d, td in zip(file_reader(finput, keep_empty_line=True), file_reader(ftarget, keep_empty_line=True)): + _ind = i_d.find(" ") + lid = len(i_d) - _ind - 1 + ltd = len(td) + lgth = lid + ltd + _task = i_d[:_ind] + # uncomment the following 2 lines to filter out empty data (e.g. in OPUS-100). 
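# A small sketch of how the character-level multilingual loader above splits an
# input line into its task tag and character sequence. The input string is
# invented for illustration; in real data the tag would be a language pair and
# the remainder of the line the raw characters.
i_d = "de-en Hallo"
_ind = i_d.find(" ")
_task, chars = i_d[:_ind], list(i_d[_ind + 1:])
lid = len(i_d) - _ind - 1
assert _task == "de-en" and chars == ["H", "a", "l", "l", "o"] and lid == 5
# lid is the character count used for length-based batching, and the check
# right below (already active in this new file) drops pairs where either side
# is empty, e.g. in OPUS-100.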
+ if (lid <= 0) or (ltd <= 0): + continue + if maxlen == 0: + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) + rstask = _task + if (rstask == _task) and ((nd < minbsize) or (lgth <= maxlen and nd < _bsize)): + rsi.append(list(i_d[_ind + 1:])) + rst.append(list(td)) + if lid > mlen_i: + mlen_i = lid + if ltd > mlen_t: + mlen_t = ltd + nd += 1 + else: + yield rsi, rst, rstask, mlen_i, mlen_t + rsi = [list(i_d[_ind + 1:])] + rstask = _task + rst = [list(td)] + mlen_i = lid + mlen_t = ltd + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) + nd = 1 + if rsi: + yield rsi, rst, rstask, mlen_i, mlen_t + +def batch_padder(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, batch_loader=batch_loader, **kwargs): + + return batch_padder_base(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, batch_loader=batch_loader, **kwargs) diff --git a/utils/fmt/mulang/eff/char/single.py b/utils/fmt/mulang/eff/char/single.py new file mode 100644 index 0000000..d2606c1 --- /dev/null +++ b/utils/fmt/mulang/eff/char/single.py @@ -0,0 +1,46 @@ +#encoding: utf-8 + +from math import ceil + +from utils.fmt.base import get_bsize, line_reader as file_reader +from utils.fmt.mulang.eff.single import batch_padder as batch_padder_base + +def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): + + _f_maxpart = float(maxpart) + rsi = [] + rstask = None + nd = maxlen = minlen = mlen_i = 0 + for i_d in file_reader(finput, keep_empty_line=True): + _ind = i_d.find(" ") + lgth = len(i_d) - _ind - 1 + _task = i_d[:_ind] + #if lgth <= 0: + #continue + if maxlen == 0: + _maxpad = max(1, min(maxpad, ceil(lgth / _f_maxpart)) // 2) + maxlen = lgth + _maxpad + minlen = lgth - _maxpad + _bsize = get_bsize(maxlen, maxtoken, bsize) + rstask = _task + if (rstask == _task) and ((nd < minbsize) or (lgth <= maxlen and lgth >= minlen and nd < _bsize)): + rsi.append(list(i_d[_ind + 1:])) + if lgth > mlen_i: + mlen_i = lgth + nd += 1 + else: + yield rsi, rstask, mlen_i + rsi = [list(i_d[_ind + 1:])] + rstask = _task + mlen_i = lgth + _maxpad = max(1, min(maxpad, ceil(lgth / _f_maxpart)) // 2) + maxlen = lgth + _maxpad + minlen = lgth - _maxpad + _bsize = get_bsize(maxlen, maxtoken, bsize) + nd = 1 + if rsi: + yield rsi, rstask, mlen_i + +def batch_padder(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, batch_loader=batch_loader, **kwargs): + + return batch_padder_base(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, batch_loader=batch_loader, **kwargs) diff --git a/utils/fmt/mulang/eff/dual.py b/utils/fmt/mulang/eff/dual.py index 1d2a761..e678f56 100644 --- a/utils/fmt/mulang/eff/dual.py +++ b/utils/fmt/mulang/eff/dual.py @@ -1,26 +1,31 @@ #encoding: utf-8 -from utils.fmt.base import list_reader, get_bsize, map_batch, pad_batch from math import ceil -def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, list_reader as file_reader, pad_batch +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): _f_maxpart = float(maxpart) rsi = [] rst = [] rstask = None nd = maxlen = mlen_i = mlen_t 
= 0 - for i_d, td in zip(list_reader(finput, keep_empty_line=True), list_reader(ftarget, keep_empty_line=True)): + for i_d, td in zip(file_reader(finput, keep_empty_line=True), file_reader(ftarget, keep_empty_line=True)): lid = len(i_d) - 1 ltd = len(td) lgth = lid + ltd _task = i_d[0] # uncomment the following 2 lines to filter out empty data (e.g. in OPUS-100). - #if (lid <= 0) or (ltd <= 0): - #continue + if (lid <= 0) or (ltd <= 0): + continue if maxlen == 0: - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) rstask = _task if (rstask == _task) and ((nd < minbsize) or (lgth <= maxlen and nd < _bsize)): rsi.append(i_d[1:]) @@ -37,22 +42,21 @@ def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): rst = [td] mlen_i = lid mlen_t = ltd - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) nd = 1 if rsi: yield rsi, rst, rstask, mlen_i, mlen_t -def batch_mapper(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None): +def batch_mapper(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, td, taskd, mlen_i, mlen_t in _batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, td, taskd, mlen_i, mlen_t in batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): rsi, extok_i = map_batch(i_d, vocabi) rst, extok_t = map_batch(td, vocabt) yield rsi, rst, vocabtask[taskd], mlen_i + extok_i, mlen_t + extok_t -def batch_padder(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, custom_batch_mapper=None): +def batch_padder(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper, pad_id=pad_id, **kwargs): - _batch_mapper = batch_mapper if custom_batch_mapper is None else custom_batch_mapper - for i_d, td, taskd, mlen_i, mlen_t in _batch_mapper(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader): - yield pad_batch(i_d, mlen_i), pad_batch(td, mlen_t), taskd + for i_d, td, taskd, mlen_i, mlen_t in batch_mapper(finput, ftarget, vocabi, vocabt, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield pad_batch(i_d, mlen_i, pad_id=pad_id), pad_batch(td, mlen_t, pad_id=pad_id), taskd diff --git a/utils/fmt/mulang/eff/many.py b/utils/fmt/mulang/eff/many.py new file mode 100644 index 0000000..f48613d --- /dev/null +++ b/utils/fmt/mulang/eff/many.py @@ -0,0 +1,78 @@ +#encoding: utf-8 + +from math import ceil + +from utils.fmt.base import get_bsize, list_reader as file_reader, pad_batch +from utils.fmt.mulang.eff.single import batch_padder as batch_padder_single +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +def batch_loader_many(filelist, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): + + _f_maxpart = float(maxpart) + 
rs = [[] for i in range(len(filelist))] + nd = maxlen = 0 + mlen = rstask = None + for lines in zip(*[file_reader(f, keep_empty_line=True) for f in filelist]): + lens = [len(line) for line in lines] + lens[0] -= 1 + lgth = sum(lens) + src_line = lines[0] + _task = src_line[0] + # uncomment the following 2 lines to filter out empty data (e.g. in OPUS-100). + if any(_len <= 0 for _len in lens): + continue + if maxlen == 0: + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(lgth + _maxpad * len(lens), maxtoken, bsize) + mlen = lens + rstask = _task + if (rstask == _task) and ((nd < minbsize) or (lgth <= maxlen and nd < _bsize)): + rs[0].append(src_line[1:]) + for line, rsu in zip(lines[1:], rs[1:]): + rsu.append(line) + for cur_len, (i, mlenu,) in zip(lens, enumerate(mlen)): + if cur_len > mlenu: + mlen[i] = cur_len + nd += 1 + else: + yield rs, rstask, mlen + rs = [[src_line[1:]]] + rs.extend([[line] for line in lines[1:]]) + mlen = lens + rstask = _task + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(lgth + _maxpad * len(lens), maxtoken, bsize) + nd = 1 + if rs: + yield rs, rstask, mlen + +def batch_mapper_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader_many, **kwargs): + + vocabtask = vocablist[-1] + for _rs, taskd, _mlen in batch_loader(filelist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + rs = [] + mlen = [] + for rsu, mlenu, vocab in zip(_rs, _mlen, vocablist): + _rs, extok = map_batch(rsu, vocab) + rs.append(_rs) + mlen.append(mlenu + extok) + yield rs, vocabtask[taskd], mlen + +def batch_padder_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper_many, pad_id=pad_id, **kwargs): + + for rs, taskd, mlen in batch_mapper(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield *tuple(pad_batch(rsu, mlenu, pad_id=pad_id) for rsu, mlenu in zip(rs, mlen)), taskd + +def batch_padder(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + + if isinstance(filelist, (list, tuple,)): + if len(filelist) > 1: + return batch_padder_many(filelist, vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) + else: + return batch_padder_single(filelist[0], *vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) + else: + return batch_padder_single(filelist, *vocablist, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs) diff --git a/utils/fmt/mulang/eff/single.py b/utils/fmt/mulang/eff/single.py index 24d8cb8..ed30f73 100644 --- a/utils/fmt/mulang/eff/single.py +++ b/utils/fmt/mulang/eff/single.py @@ -1,15 +1,19 @@ #encoding: utf-8 -from utils.fmt.base import list_reader, get_bsize, map_batch, pad_batch from math import ceil -def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, list_reader as file_reader, pad_batch +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): _f_maxpart = float(maxpart) rsi = [] rstask = None nd = maxlen = minlen = mlen_i = 0 - for i_d in list_reader(finput, keep_empty_line=True): + for i_d in file_reader(finput, keep_empty_line=True): lgth = len(i_d) - 1 _task = i_d[0] #if lgth <= 0: @@ -38,15 +42,13 @@ def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, 
minbsize): if rsi: yield rsi, rstask, mlen_i -def batch_mapper(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None): +def batch_mapper(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, taskd, mlen_i in _batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, taskd, mlen_i in batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): rsi, extok_i = map_batch(i_d, vocabi) yield rsi, vocabtask[taskd], mlen_i + extok_i -def batch_padder(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, custom_batch_mapper=None): +def batch_padder(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper, pad_id=pad_id, **kwargs): - _batch_mapper = batch_mapper if custom_batch_mapper is None else custom_batch_mapper - for i_d, taskd, mlen_i in _batch_mapper(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader): - yield pad_batch(i_d, mlen_i), taskd + for i_d, taskd, mlen_i in batch_mapper(finput, vocabi, vocabtask, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield pad_batch(i_d, mlen_i, pad_id=pad_id), taskd diff --git a/utils/fmt/plm/bart/__init__.py b/utils/fmt/plm/bart/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/plm/bart/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/plm/bart/dual.py b/utils/fmt/plm/bart/dual.py new file mode 100644 index 0000000..75d61ec --- /dev/null +++ b/utils/fmt/plm/bart/dual.py @@ -0,0 +1,50 @@ +#encoding: utf-8 + +from math import ceil + +from utils.fmt.base import get_bsize, iter_to_int, list_reader as file_reader +from utils.fmt.plm.dual import batch_padder as batch_padder_base + +from cnfg.vocab.plm.roberta import eos_id, pad_id + +def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, tgt_start_id=eos_id, **kwargs): + + _f_maxpart = float(maxpart) + rsi = [] + rst = [] + nd = maxlen = mlen_i = mlen_t = 0 + for i_d, td in zip(file_reader(finput, keep_empty_line=True), file_reader(ftarget, keep_empty_line=True)): + i_d, td = list(iter_to_int(i_d)), list(iter_to_int(td)) + if tgt_start_id is not None: + td[0] = tgt_start_id + lid = len(i_d) + ltd = len(td) + lgth = lid + ltd + if maxlen == 0: + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) + if (nd < minbsize) or (lgth <= maxlen and nd < _bsize): + rsi.append(i_d) + rst.append(td) + if lid > mlen_i: + mlen_i = lid + if ltd > mlen_t: + mlen_t = ltd + nd += 1 + else: + yield rsi, rst, mlen_i, mlen_t + rsi = [i_d] + rst = [td] + mlen_i = lid + mlen_t = ltd + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) + nd = 1 + if rsi: + yield rsi, rst, mlen_i, mlen_t + +def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, batch_loader=batch_loader, pad_id=pad_id, **kwargs): + + return batch_padder_base(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, batch_loader=batch_loader, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/bert/base.py b/utils/fmt/plm/bert/base.py index 92d3f4e..2d5b039 
100644 --- a/utils/fmt/plm/bert/base.py +++ b/utils/fmt/plm/bert/base.py @@ -2,9 +2,10 @@ import sys -from utils.fmt.base import line_reader, reverse_dict +from utils.fmt.base import line_reader, sys_open +from utils.fmt.vocab.base import reverse_dict -from cnfg.vocab.plm.bert import * +from cnfg.vocab.plm.bert import eos_id, mask_id, pad_id, sos_id, unk_id, vocab_size def ldvocab(vfile, *args, **kwargs): @@ -37,7 +38,7 @@ def save_vocab(vcb_dict, fname): freqs.sort() _ = "\n".join([r_vocab[_key] for _key in freqs]) - with sys.stdout.buffer if fname == "-" else open(fname, "wb") as f: + with sys_open(fname, "wb") as f: f.write(_.encode("utf-8")) f.write("\n".encode("utf-8")) diff --git a/utils/fmt/plm/bert/dual.py b/utils/fmt/plm/bert/dual.py index e9b93d8..93a2ba8 100644 --- a/utils/fmt/plm/bert/dual.py +++ b/utils/fmt/plm/bert/dual.py @@ -1,8 +1,9 @@ #encoding: utf-8 -from cnfg.vocab.plm.bert import pad_id from utils.fmt.plm.dual import batch_padder as batch_padder_base -def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, pad_id=pad_id, **kwargs): +from cnfg.vocab.plm.bert import pad_id + +def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs): - return batch_padder_base(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader, pad_id=pad_id, **kwargs) + return batch_padder_base(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/bert/single.py b/utils/fmt/plm/bert/single.py index 75d1114..61cd7ea 100644 --- a/utils/fmt/plm/bert/single.py +++ b/utils/fmt/plm/bert/single.py @@ -1,8 +1,9 @@ #encoding: utf-8 -from cnfg.vocab.plm.bert import pad_id from utils.fmt.plm.single import batch_padder as batch_padder_base -def batch_padder(finput, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, pad_id=pad_id, **kwargs): +from cnfg.vocab.plm.bert import pad_id + +def batch_padder(finput, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs): - return batch_padder_base(finput, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader, pad_id=pad_id, **kwargs) + return batch_padder_base(finput, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/dual.py b/utils/fmt/plm/dual.py index 30b3549..d233769 100644 --- a/utils/fmt/plm/dual.py +++ b/utils/fmt/plm/dual.py @@ -1,22 +1,26 @@ #encoding: utf-8 -from utils.fmt.base import list_reader, get_bsize, pad_batch, toint, pad_id from math import ceil -def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, iter_to_int, list_reader as file_reader, pad_batch + +from cnfg.vocab.base import pad_id + +def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, iter_to_int=iter_to_int, **kwargs): _f_maxpart = float(maxpart) rsi = [] rst = [] nd = maxlen = mlen_i = mlen_t = 0 - for i_d, td in zip(list_reader(finput, keep_empty_line=True), list_reader(ftarget, keep_empty_line=True)): - i_d, td = toint(i_d), toint(td) + for i_d, td in zip(file_reader(finput, keep_empty_line=True), file_reader(ftarget, keep_empty_line=True)): + i_d, td = list(iter_to_int(i_d)), list(iter_to_int(td)) lid = len(i_d) ltd = len(td) lgth = lid + ltd if maxlen == 0: - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, 
bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) if (nd < minbsize) or (lgth <= maxlen and nd < _bsize): rsi.append(i_d) rst.append(td) @@ -31,14 +35,14 @@ def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): rst = [td] mlen_i = lid mlen_t = ltd - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) nd = 1 if rsi: yield rsi, rst, mlen_i, mlen_t -def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, pad_id=pad_id, **kwargs): +def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_loader=batch_loader, pad_id=pad_id, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, td, mlen_i, mlen_t in _batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, td, mlen_i, mlen_t in batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): yield pad_batch(i_d, mlen_i, pad_id=pad_id), pad_batch(td, mlen_t, pad_id=pad_id) diff --git a/utils/fmt/plm/dual_reg.py b/utils/fmt/plm/dual_reg.py new file mode 100644 index 0000000..e5f23ec --- /dev/null +++ b/utils/fmt/plm/dual_reg.py @@ -0,0 +1,46 @@ +#encoding: utf-8 + +from math import ceil + +from utils.fmt.base import get_bsize, iter_to_int, line_reader, list_reader, pad_batch + +from cnfg.vocab.base import pad_id + +file_reader = (list_reader, line_reader,) + +def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, iter_to_int=iter_to_int, **kwargs): + + _f_maxpart = float(maxpart) + rsi = [] + rst = [] + nd = maxlen = mlen_i = 0 + _list_reader, _line_reader = file_reader + for i_d, td in zip(_list_reader(finput, keep_empty_line=True), _line_reader(ftarget, keep_empty_line=True)): + i_d, td = list(iter_to_int(i_d)), float(td) + lgth = len(i_d) + if maxlen == 0: + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen, maxtoken, bsize) + if (nd < minbsize) or (lgth <= maxlen and nd < _bsize): + rsi.append(i_d) + rst.append(td) + if lgth > mlen_i: + mlen_i = lgth + nd += 1 + else: + yield rsi, rst, mlen_i + rsi = [i_d] + rst = [td] + mlen_i = lgth + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen, maxtoken, bsize) + nd = 1 + if rsi: + yield rsi, rst, mlen_i + +def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_loader=batch_loader, pad_id=pad_id, **kwargs): + + for i_d, td, mlen_i in batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield pad_batch(i_d, mlen_i, pad_id=pad_id), td diff --git a/utils/fmt/plm/mbart/__init__.py b/utils/fmt/plm/mbart/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/plm/mbart/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/plm/mbart/dual.py b/utils/fmt/plm/mbart/dual.py new file mode 100644 index 0000000..985ea51 --- /dev/null +++ b/utils/fmt/plm/mbart/dual.py @@ -0,0 +1,53 @@ +#encoding: utf-8 + +from math import ceil + +from utils.fmt.base import get_bsize, iter_to_int, list_reader as file_reader +from utils.fmt.plm.dual import 
batch_padder as batch_padder_base + +from cnfg.vocab.plm.mbart import add_sos_id, pad_id, shift_target_lang_id + +def batch_loader(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, shift_target_lang_id=shift_target_lang_id, add_sos_id=add_sos_id, **kwargs): + + _f_maxpart = float(maxpart) + rsi = [] + rst = [] + nd = maxlen = mlen_i = mlen_t = 0 + _add_sos_id = isinstance(add_sos_id, int) + for i_d, td in zip(file_reader(finput, keep_empty_line=True), file_reader(ftarget, keep_empty_line=True)): + i_d, td = list(iter_to_int(i_d)), list(iter_to_int(td)) + if _add_sos_id: + td.insert(0, add_sos_id) + if shift_target_lang_id: + td.insert(0, td.pop(-1)) + lid = len(i_d) + ltd = len(td) + lgth = lid + ltd + if maxlen == 0: + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) + if (nd < minbsize) or (lgth <= maxlen and nd < _bsize): + rsi.append(i_d) + rst.append(td) + if lid > mlen_i: + mlen_i = lid + if ltd > mlen_t: + mlen_t = ltd + nd += 1 + else: + yield rsi, rst, mlen_i, mlen_t + rsi = [i_d] + rst = [td] + mlen_i = lid + mlen_t = ltd + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) + nd = 1 + if rsi: + yield rsi, rst, mlen_i, mlen_t + +def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, batch_loader=batch_loader, pad_id=pad_id, **kwargs): + + return batch_padder_base(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, batch_loader=batch_loader, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/roberta/base.py b/utils/fmt/plm/roberta/base.py index d95e6ac..1ba745d 100644 --- a/utils/fmt/plm/roberta/base.py +++ b/utils/fmt/plm/roberta/base.py @@ -1,9 +1,9 @@ #encoding: utf-8 -from utils.fmt.base import reverse_dict -from utils.fmt.json import loadf, dumpf +from utils.fmt.json import dumpf, loadf +from utils.fmt.vocab.base import reverse_dict -from cnfg.vocab.plm.roberta import * +from cnfg.vocab.plm.roberta import eos_id, mask_id, pad_id, sos_id, unk_id, vocab_size def ldvocab(vfile, *args, **kwargs): diff --git a/utils/fmt/plm/roberta/dual.py b/utils/fmt/plm/roberta/dual.py index 7afd8ab..9a0f734 100644 --- a/utils/fmt/plm/roberta/dual.py +++ b/utils/fmt/plm/roberta/dual.py @@ -1,8 +1,9 @@ #encoding: utf-8 -from cnfg.vocab.plm.roberta import pad_id from utils.fmt.plm.dual import batch_padder as batch_padder_base -def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, pad_id=pad_id, **kwargs): +from cnfg.vocab.plm.roberta import pad_id + +def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs): - return batch_padder_base(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader, pad_id=pad_id, **kwargs) + return batch_padder_base(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/roberta/dual_reg.py b/utils/fmt/plm/roberta/dual_reg.py new file mode 100644 index 0000000..82f1cef --- /dev/null +++ b/utils/fmt/plm/roberta/dual_reg.py @@ -0,0 +1,9 @@ +#encoding: utf-8 + +from utils.fmt.plm.dual_reg import batch_padder as batch_padder_base + +from cnfg.vocab.plm.roberta import pad_id + +def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs): + + return batch_padder_base(finput, ftarget, bsize, maxpad, 
maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/roberta/single.py b/utils/fmt/plm/roberta/single.py index c40d3f9..e1a6541 100644 --- a/utils/fmt/plm/roberta/single.py +++ b/utils/fmt/plm/roberta/single.py @@ -1,8 +1,9 @@ #encoding: utf-8 -from cnfg.vocab.plm.roberta import pad_id from utils.fmt.plm.single import batch_padder as batch_padder_base -def batch_padder(finput, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, pad_id=pad_id, **kwargs): +from cnfg.vocab.plm.roberta import pad_id + +def batch_padder(finput, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs): - return batch_padder_base(finput, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader, pad_id=pad_id, **kwargs) + return batch_padder_base(finput, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/single.py b/utils/fmt/plm/single.py index 329cd77..5f60929 100644 --- a/utils/fmt/plm/single.py +++ b/utils/fmt/plm/single.py @@ -1,15 +1,18 @@ #encoding: utf-8 -from utils.fmt.base import list_reader, get_bsize, pad_batch, toint, pad_id from math import ceil -def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, iter_to_int, list_reader as file_reader, pad_batch + +from cnfg.vocab.base import pad_id + +def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, iter_to_int=iter_to_int, **kwargs): _f_maxpart = float(maxpart) rsi = [] nd = maxlen = minlen = mlen_i = 0 - for i_d in list_reader(finput, keep_empty_line=True): - i_d = toint(i_d) + for i_d in file_reader(finput, keep_empty_line=True): + i_d = list(iter_to_int(i_d)) lgth = len(i_d) if maxlen == 0: _maxpad = max(1, min(maxpad, ceil(lgth / _f_maxpart)) // 2) @@ -33,8 +36,7 @@ def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): if rsi: yield rsi, mlen_i -def batch_padder(finput, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, pad_id=pad_id, **kwargs): +def batch_padder(finput, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_loader=batch_loader, pad_id=pad_id, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, mlen_i in _batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, mlen_i in batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): yield pad_batch(i_d, mlen_i, pad_id=pad_id) diff --git a/utils/fmt/plm/t5/dual.py b/utils/fmt/plm/t5/dual.py index b97a8dc..b8a3e2c 100644 --- a/utils/fmt/plm/t5/dual.py +++ b/utils/fmt/plm/t5/dual.py @@ -1,8 +1,9 @@ #encoding: utf-8 -from cnfg.vocab.plm.t5 import pad_id from utils.fmt.plm.dual import batch_padder as batch_padder_base -def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, pad_id=pad_id, **kwargs): +from cnfg.vocab.plm.t5 import pad_id + +def batch_padder(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs): - return batch_padder_base(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader, pad_id=pad_id, **kwargs) + return batch_padder_base(finput, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/t5/single.py b/utils/fmt/plm/t5/single.py index 04357c2..c1105e6 100644 --- a/utils/fmt/plm/t5/single.py +++ 
b/utils/fmt/plm/t5/single.py @@ -1,8 +1,9 @@ #encoding: utf-8 -from cnfg.vocab.plm.t5 import pad_id from utils.fmt.plm.single import batch_padder as batch_padder_base -def batch_padder(finput, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, pad_id=pad_id, **kwargs): +from cnfg.vocab.plm.t5 import pad_id + +def batch_padder(finput, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs): - return batch_padder_base(finput, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader, pad_id=pad_id, **kwargs) + return batch_padder_base(finput, bsize, maxpad, maxpart, maxtoken, minbsize, pad_id=pad_id, **kwargs) diff --git a/utils/fmt/plm/token.py b/utils/fmt/plm/token.py index 97e1eaf..ac476e4 100644 --- a/utils/fmt/plm/token.py +++ b/utils/fmt/plm/token.py @@ -2,50 +2,42 @@ import sys -from utils.fmt.base import tostr +from utils.fmt.base import iter_to_int, iter_to_str, loop_file_so, sys_open -tokenize_line = lambda lin, processor: " ".join(processor.convert_ids_to_tokens(processor(lin, return_token_type_ids=True, return_attention_mask=False, return_offsets_mapping=True).input_ids)) -map_line = lambda lin, processor: " ".join(tostr(processor(*lin.split("\t"), return_token_type_ids=True, return_attention_mask=False, return_offsets_mapping=True).input_ids)) -detokenize_line = lambda lin, processor: processor(lin, skip_special_tokens=False, clean_up_tokenization_spaces=False) +tokenize_line = lambda lin, processor: " ".join(processor.convert_ids_to_tokens(processor(lin, return_token_type_ids=False, return_attention_mask=False, return_offsets_mapping=False).input_ids)) +map_line = lambda lin, processor: " ".join(iter_to_str(processor(*lin.split("\t"), return_token_type_ids=False, return_attention_mask=False, return_offsets_mapping=False).input_ids)) +detokenize_line = lambda lin, processor: processor(list(iter_to_int(lin.split())), skip_special_tokens=False, clean_up_tokenization_spaces=False) def map_line_with_token_type(lin, processor): - _ = processor(*tmp.decode("utf-8").split("\t"), return_token_type_ids=True, return_attention_mask=False, return_offsets_mapping=True) + _ = processor(*lin.split("\t"), return_token_type_ids=True, return_attention_mask=False, return_offsets_mapping=False) - return " ".join(tostr(_.input_ids)), " ".join(tostr(_.token_type_ids)) + return " ".join(iter_to_str(_.input_ids)), " ".join(iter_to_str(_.token_type_ids)) -def loop_file_so(fsrc, vcb, frs, process_func=None, processor=None): +def tokenize_file(fsrc, frs, processor=None, process_func=tokenize_line): - ens = "\n".encode("utf-8") - with sys.stdin.buffer if fsrc == "-" else open(fsrc, "rb") as frd, sys.stdout.buffer if frs == "-" else open(frs, "wb") as fwrt: - for line in frd: - tmp = line.strip() - if tmp: - fwrt.write(process_func(tmp.decode("utf-8"), processor).encode("utf-8")) - fwrt.write(ens) - -def tokenize_file(fsrc, vcb, frs, Tokenizer=None): - - return loop_file_so(fsrc, vcb, frs, process_func=tokenize_line, processor=Tokenizer(tokenizer_file=vcb)) + return loop_file_so(fsrc, frs, process_func=process_func, processor=processor) -def map_file(fsrc, vcb, frs, Tokenizer=None): +def map_file(fsrc, frs, processor=None, process_func=map_line): - return loop_file_so(fsrc, vcb, frs, process_func=map_line, processor=Tokenizer(tokenizer_file=vcb)) + return loop_file_so(fsrc, frs, process_func=process_func, processor=processor) -def map_file_with_token_type(fsrc, vcb, frsi, frst, Tokenizer=None): +def map_file_with_token_type(fsrc, frsi, frst, 
processor=None, process_func=map_line_with_token_type): - tokenizer = Tokenizer(tokenizer_file=vcb) ens = "\n".encode("utf-8") - with sys.stdin.buffer if fsrc == "-" else open(fsrc, "rb") as frd, sys.stdout.buffer if frsi == "-" else open(frsi, "wb") as fwrti, sys.stdout.buffer if frst == "-" else open(frst, "wb") as fwrtt: + with sys_open(fsrc, "rb") as frd, sys_open(frsi, "wb") as fwrti, sys_open(frst, "wb") as fwrtt: for line in frd: tmp = line.strip() if tmp: - _input_ids, _token_type_ids = map_line_with_token_type(tmp.decode("utf-8"), tokenizer) + _input_ids, _token_type_ids = process_func(tmp.decode("utf-8"), processor) fwrti.write(_input_ids.encode("utf-8")) + fwrti.write(ens) fwrtt.write(_token_type_ids.encode("utf-8")) - fwrti.write(ens) - fwrtt.write(ens) + fwrtt.write(ens) + else: + fwrti.write(ens) + fwrtt.write(ens) -def map_back_file(fsrc, vcb, frs, Tokenizer=None): +def map_back_file(fsrc, frs, processor=None, process_func=detokenize_line): - return loop_file_so(fsrc, vcb, frs, process_func=detokenize_line, processor=Tokenizer(tokenizer_file=vcb).decode) + return loop_file_so(fsrc, frs, process_func=process_func, processor=processor) diff --git a/utils/fmt/raw/__init__.py b/utils/fmt/raw/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/raw/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/raw/cachepath.py b/utils/fmt/raw/cachepath.py new file mode 100644 index 0000000..fc8f515 --- /dev/null +++ b/utils/fmt/raw/cachepath.py @@ -0,0 +1,28 @@ +#encoding: utf-8 + +from uuid import uuid4 as uuid_func + +from utils.base import mkdir + +cache_file_prefix = "train" + +def get_cache_path(*fnames): + + _cache_path = None + for _t in fnames: + _ = _t.rfind("/") + 1 + if _ > 0: + _cache_path = _t[:_] + break + _uuid = uuid_func().hex + if _cache_path is None: + _cache_path = "cache/floader/%s/" % _uuid + else: + _cache_path = "%sfloader/%s/" % (_cache_path, _uuid,) + mkdir(_cache_path) + + return _cache_path + +def get_cache_fname(fpath, i=0, fprefix=cache_file_prefix): + + return "%s%s.%d.h5" % (fpath, fprefix, i,) diff --git a/utils/fmt/raw/reader/__init__.py b/utils/fmt/raw/reader/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/raw/reader/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/raw/reader/sort/__init__.py b/utils/fmt/raw/reader/sort/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/raw/reader/sort/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/raw/reader/sort/many.py b/utils/fmt/raw/reader/sort/many.py new file mode 100644 index 0000000..78622f0 --- /dev/null +++ b/utils/fmt/raw/reader/sort/many.py @@ -0,0 +1,40 @@ +#encoding: utf-8 + +from random import shuffle + +from utils.fmt.base import dict_insert_set, iter_dict_sort, read_lines +from utils.fmt.parser import parse_none + +def sort_list_reader(x, *args, clear_input=True, **kwargs): + + _d = {} + for mi in x: + lens = [len(_) for _ in mi] + lgth = sum(lens) + _d = dict_insert_set(_d, mi, lgth, *reversed(lens[1:])) + if clear_input and hasattr(x, "clear"): + x.clear() + for tmp in iter_dict_sort(_d, free=True): + _v = list(tmp) + shuffle(_v) + yield from _v + +class sort_lines_reader: + + def __init__(self, line_read=None): + + self.line_read = line_read + + def __call__(self, x, *args, line_read=None, **kwargs): + + _line_read = parse_none(line_read, self.line_read) + _data_iter = x if _line_read is None else read_lines(x, _line_read) + _d = {} + for mi in 
_data_iter: + lens = [len(_) for _ in mi] + lgth = sum(lens) + _d = dict_insert_set(_d, mi, lgth, *reversed(lens[1:])) + for tmp in iter_dict_sort(_d, free=True): + _v = list(tmp) + shuffle(_v) + yield from _v diff --git a/utils/fmt/raw/reader/sort/single.py b/utils/fmt/raw/reader/sort/single.py new file mode 100644 index 0000000..8d55979 --- /dev/null +++ b/utils/fmt/raw/reader/sort/single.py @@ -0,0 +1,46 @@ +#encoding: utf-8 + +from random import shuffle + +from utils.fmt.base import read_lines +from utils.fmt.parser import parse_none + +def sort_list_reader(x, *args, clear_input=True, **kwargs): + + _d = {} + for _ in x: + _k = len(_) + if _k in _d: + if _ not in _d[_k]: + _d[_k].add(_) + else: + _d[_k] = set([_]) + if clear_input and hasattr(x, "clear"): + x.clear() + for _k in sorted(_d.keys()): + _v = list(_d.pop(_k)) + shuffle(_v) + yield from _v + +class sort_lines_reader: + + def __init__(self, line_read=None): + + self.line_read = line_read + + def __call__(self, x, *args, line_read=None, **kwargs): + + _line_read = parse_none(line_read, self.line_read) + _data_iter = x if _line_read is None else read_lines(x, _line_read) + _d = {} + for _ in _data_iter: + _k = len(_) + if _k in _d: + if _ not in _d[_k]: + _d[_k].add(_) + else: + _d[_k] = set([_]) + for _k in sorted(_d.keys()): + _v = list(_d.pop(_k)) + shuffle(_v) + yield from _v diff --git a/utils/fmt/raw/reader/sort/tag.py b/utils/fmt/raw/reader/sort/tag.py new file mode 100644 index 0000000..1c33889 --- /dev/null +++ b/utils/fmt/raw/reader/sort/tag.py @@ -0,0 +1,46 @@ +#encoding: utf-8 + +from random import shuffle + +from utils.fmt.base import read_lines +from utils.fmt.parser import parse_none + +def sort_list_reader(x, *args, clear_input=True, **kwargs): + + _d = {} + for _ in x: + _k = len(_[0]) + if _k in _d: + if _ not in _d[_k]: + _d[_k].add(_) + else: + _d[_k] = set([_]) + if clear_input and hasattr(x, "clear"): + x.clear() + for _k in sorted(_d.keys()): + _v = list(_d.pop(_k)) + shuffle(_v) + yield from _v + +class sort_lines_reader: + + def __init__(self, line_read=None): + + self.line_read = line_read + + def __call__(self, x, *args, line_read=None, **kwargs): + + _line_read = parse_none(line_read, self.line_read) + _data_iter = x if _line_read is None else read_lines(x, _line_read) + _d = {} + for _ in _data_iter: + _k = len(_[0]) + if _k in _d: + if _ not in _d[_k]: + _d[_k].add(_) + else: + _d[_k] = set([_]) + for _k in sorted(_d.keys()): + _v = list(_d.pop(_k)) + shuffle(_v) + yield from _v diff --git a/utils/fmt/single.py b/utils/fmt/single.py index 157afc7..8dd5128 100644 --- a/utils/fmt/single.py +++ b/utils/fmt/single.py @@ -1,14 +1,18 @@ #encoding: utf-8 -from utils.fmt.base import list_reader, get_bsize, map_batch, pad_batch from math import ceil -def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, list_reader as file_reader, pad_batch +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): _f_maxpart = float(maxpart) rsi = [] nd = maxlen = minlen = mlen_i = 0 - for i_d in list_reader(finput, keep_empty_line=True): + for i_d in file_reader(finput, keep_empty_line=True): lgth = len(i_d) if maxlen == 0: _maxpad = max(1, min(maxpad, ceil(lgth / _f_maxpart)) // 2) @@ -32,15 +36,13 @@ def batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): if rsi: yield rsi, mlen_i -def 
batch_mapper(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None): +def batch_mapper(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader, **kwargs): - _batch_loader = batch_loader if custom_batch_loader is None else custom_batch_loader - for i_d, mlen_i in _batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, mlen_i in batch_loader(finput, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): rsi, extok_i = map_batch(i_d, vocabi) yield rsi, mlen_i + extok_i -def batch_padder(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, custom_batch_mapper=None): +def batch_padder(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper, pad_id=pad_id, **kwargs): - _batch_mapper = batch_mapper if custom_batch_mapper is None else custom_batch_mapper - for i_d, mlen_i in _batch_mapper(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader): - yield pad_batch(i_d, mlen_i) + for i_d, mlen_i in batch_mapper(finput, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield pad_batch(i_d, mlen_i, pad_id=pad_id) diff --git a/utils/fmt/triple.py b/utils/fmt/triple.py index 89f4559..0506b01 100644 --- a/utils/fmt/triple.py +++ b/utils/fmt/triple.py @@ -1,22 +1,30 @@ #encoding: utf-8 -from utils.fmt.base import list_reader, line_reader, get_bsize, map_batch, pad_batch from math import ceil -def batch_loader(finput, fref, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): +from utils.fmt.base import get_bsize, line_reader, list_reader, pad_batch +from utils.fmt.vocab.base import map_batch + +from cnfg.vocab.base import pad_id + +file_reader = (list_reader, line_reader,) + +def batch_loader(finput, fref, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, get_bsize=get_bsize, file_reader=file_reader, **kwargs): _f_maxpart = float(maxpart) rsi = [] rsr = [] rst = [] nd = maxlen = mlen_i = mlen_r = 0 - for i_d, rd, td in zip(list_reader(finput, keep_empty_line=True), list_reader(fref, keep_empty_line=True), line_reader(ftarget, keep_empty_line=True)): + _list_reader, _line_reader = file_reader + for i_d, rd, td in zip(_list_reader(finput, keep_empty_line=True), _list_reader(fref, keep_empty_line=True), _line_reader(ftarget, keep_empty_line=True)): lid = len(i_d) lrd = len(rd) lgth = lid + lrd if maxlen == 0: - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) if (nd < minbsize) or (lgth <= maxlen and nd < _bsize): rsi.append(i_d) rsr.append(rd) @@ -33,22 +41,21 @@ def batch_loader(finput, fref, ftarget, bsize, maxpad, maxpart, maxtoken, minbsi rst = [float(td)] mlen_i = lid mlen_r = lrd - maxlen = lgth + min(maxpad, ceil(lgth / _f_maxpart)) - _bsize = get_bsize(maxlen, maxtoken, bsize) + _maxpad = min(maxpad, ceil(lgth / _f_maxpart)) + maxlen = lgth + _maxpad + _bsize = get_bsize(maxlen + _maxpad, maxtoken, bsize) nd = 1 if rsi: yield rsi, rsr, rst, mlen_i, mlen_r -def batch_mapper(finput, fref, ftarget, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None): +def batch_mapper(finput, fref, ftarget, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, map_batch=map_batch, batch_loader=batch_loader, **kwargs): - _batch_loader = batch_loader if custom_batch_loader 
is None else custom_batch_loader - for i_d, rd, td, mlen_i, mlen_t in _batch_loader(finput, fref, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize): + for i_d, rd, td, mlen_i, mlen_t in batch_loader(finput, fref, ftarget, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): rsi, extok_i = map_batch(i_d, vocabi) rsr, extok_r = map_batch(rd, vocabi) yield rsi, rsr, td, mlen_i + extok_i, mlen_t + extok_r -def batch_padder(finput, fref, ftarget, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=None, custom_batch_mapper=None): +def batch_padder(finput, fref, ftarget, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, pad_batch=pad_batch, batch_mapper=batch_mapper, pad_id=pad_id, **kwargs): - _batch_mapper = batch_mapper if custom_batch_mapper is None else custom_batch_mapper - for i_d, rd, td, mlen_i, mlen_t in _batch_mapper(finput, fref, ftarget, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, custom_batch_loader=custom_batch_loader): - yield pad_batch(i_d, mlen_i), pad_batch(rd, mlen_t), td + for i_d, rd, td, mlen_i, mlen_t in batch_mapper(finput, fref, ftarget, vocabi, bsize, maxpad, maxpart, maxtoken, minbsize, **kwargs): + yield pad_batch(i_d, mlen_i, pad_id=pad_id), pad_batch(rd, mlen_t, pad_id=pad_id), td diff --git a/utils/fmt/u8.py b/utils/fmt/u8.py new file mode 100644 index 0000000..a4f9c73 --- /dev/null +++ b/utils/fmt/u8.py @@ -0,0 +1,33 @@ +#encoding: utf-8 + +from html import unescape +from unicodedata import normalize as uni_norm_func + +#"NFC", "NFD", "NFKD" +uni_normer = "NFKC" + +def clean_sp_char(istr): + + rs = [] + for c in istr: + num = ord(c) + if num == 12288: + rs.append(" ") + elif (num > 65280) and (num < 65375): + rs.append(chr(num - 65248)) + elif not ((num < 32 and num != 9) or (num > 126 and num < 161) or (num > 8202 and num < 8206) or (num > 57343 and num < 63744) or (num > 64975 and num < 65008) or (num > 65519)): + rs.append(c) + + return "".join(rs) + +def norm_u8_str(x, uni_normer=uni_normer): + + return unescape(clean_sp_char(uni_norm_func(uni_normer, x))) + +def norm_u8_byte(x, uni_normer=uni_normer): + + return unescape(clean_sp_char(uni_norm_func(uni_normer, x.decode("utf-8")))).encode("utf-8") + +def norm_u8(x, uni_normer=uni_normer): + + return norm_u8_byte(x, uni_normer=uni_normer) if isinstance(x, bytes) else norm_u8_str(x, uni_normer=uni_normer) diff --git a/utils/fmt/vocab/__init__.py b/utils/fmt/vocab/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/fmt/vocab/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/fmt/vocab/base.py b/utils/fmt/vocab/base.py new file mode 100644 index 0000000..cdea678 --- /dev/null +++ b/utils/fmt/vocab/base.py @@ -0,0 +1,60 @@ +#encoding: utf-8 + +from cnfg.vocab.base import eos_id, sos_id, unk_id, use_unk + +def reverse_dict(din): + + return {v:k for k, v in din.items()} + +def merge_vocab(*vcbin): + + rs = {} + for _ in vcbin: + for k, v in _.items(): + rs[k] = rs.get(k, 0) + v + + return rs + +def legal_vocab(sent, ilgset, ratio): + + total = ilg = 0 + for tmpu in sent.split(): + if tmpu: + if tmpu in ilgset: + ilg += 1 + total += 1 + rt = float(ilg) / float(total) + + return rt < ratio + +def no_unk_mapper(vcb, ltm, print_func=None): + + if print_func is None: + return [vcb[wd] for wd in ltm if wd in vcb] + else: + rs = [] + for wd in ltm: + if wd in vcb: + rs.append(vcb[wd]) + else: + print_func("Error mapping: "+ wd) + return rs + +def map_instance(i_d, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs): + + 
rsi = [sos_id] + rsi.extend([vocabi.get(wd, unk_id) for wd in i_d] if use_unk else no_unk_mapper(vocabi, i_d))#[vocabi[wd] for wd in i_d if wd in vocabi] + rsi.append(eos_id) + + return rsi + +def map_batch_core(i_d, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs): + + if isinstance(i_d[0], (tuple, list,)): + return [map_batch_core(idu, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs) for idu in i_d] + else: + return map_instance(i_d, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs) + +def map_batch(i_d, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs): + + return map_batch_core(i_d, vocabi, use_unk=use_unk, sos_id=sos_id, eos_id=eos_id, unk_id=unk_id, **kwargs), 2 diff --git a/utils/fmt/vocab/char.py b/utils/fmt/vocab/char.py new file mode 100644 index 0000000..f1f5cf3 --- /dev/null +++ b/utils/fmt/vocab/char.py @@ -0,0 +1,23 @@ +#encoding: utf-8 + +from utils.fmt.base import list_reader_wst as file_reader +from utils.fmt.vocab.token import ldvocab as ldvocab_base, ldvocab_freq as ldvocab_freq_base, ldvocab_list as ldvocab_list_base, save_vocab as save_vocab_base + +# use tab as separator to keep the space token in the vocab +sep_load = sep_save = "\t" + +def ldvocab(*args, sep=sep_load, file_reader=file_reader, **kwargs): + + return ldvocab_base(*args, sep=sep, file_reader=file_reader, **kwargs) + +def save_vocab(*args, sep=sep_save, **kwargs): + + return save_vocab_base(*args, sep=sep, **kwargs) + +def ldvocab_list(*args, sep=sep_load, file_reader=file_reader, **kwargs): + + return ldvocab_list_base(*args, sep=sep, file_reader=file_reader, **kwargs) + +def ldvocab_freq(*args, sep=sep_load, file_reader=file_reader, **kwargs): + + return ldvocab_freq_base(*args, sep=sep, file_reader=file_reader, **kwargs) diff --git a/utils/fmt/vocab/token.py b/utils/fmt/vocab/token.py new file mode 100644 index 0000000..722c9f3 --- /dev/null +++ b/utils/fmt/vocab/token.py @@ -0,0 +1,203 @@ +#encoding: utf-8 + +import sys + +from utils.fmt.base import list_reader as file_reader, sys_open + +from cnfg.vocab.base import init_normal_token_id, init_vocab + +sep_load, sep_save = None, " " + +def ldvocab(vfile, minf=False, omit_vsize=False, vanilla=False, init_vocab=init_vocab, init_normal_token_id=init_normal_token_id, sep=sep_load, file_reader=file_reader, print_func=print): + + if vanilla: + rs, cwd = {}, 0 + else: + rs, cwd = init_vocab.copy(), len(init_vocab) if init_normal_token_id is None else init_normal_token_id + if omit_vsize: + vsize = max(0, omit_vsize - cwd) + else: + vsize = False + dkeys = set() + for data in file_reader(vfile, keep_empty_line=False, sep=sep): + freq = int(data[0]) + if (not minf) or (freq > minf): + if vsize: + ndata = len(data) - 1 + if vsize > ndata: + for wd in data[1:]: + if wd in rs: + if wd not in dkeys: + dkeys.add(wd) + ndata -= 1 + else: + rs[wd] = cwd + cwd += 1 + vsize -= ndata + else: + _break = False + for wd in data[1:]: + if wd in rs: + if wd not in dkeys: + dkeys.add(wd) + else: + rs[wd] = cwd + cwd += 1 + vsize -= 1 + _break = (vsize <= 0) + if _break: + break + if _break: + break + else: + for wd in data[1:]: + if wd in rs: + if wd not in dkeys: + dkeys.add(wd) + else: + rs[wd] = cwd + cwd += 1 + else: + break + if (print_func is not None) and dkeys: + print_func("duplicated vocab keys: %s" % str(dkeys)) + + return rs, cwd + +def save_vocab(vcb_dict, fname, omit_vsize=False, sep=sep_save): + + r_vocab = {} + for k, v in vcb_dict.items():
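+ # group the tokens by frequency; each bucket keeps the frequency string as its leading field so the line can be written out directly below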
+ if v not in r_vocab: + r_vocab[v]=[str(v), k] + else: + r_vocab[v].append(k) + + freqs = list(r_vocab.keys()) + freqs.sort(reverse=True) + + ens = "\n".encode("utf-8") + remain = omit_vsize + with sys_open(fname, "wb") as f: + for freq in freqs: + cdata = r_vocab[freq] + ndata = len(cdata) - 1 + if remain and (remain < ndata): + cdata = cdata[:remain + 1] + ndata = remain + f.write(sep.join(cdata).encode("utf-8")) + f.write(ens) + if remain: + remain -= ndata + if remain <= 0: + break + +def ldvocab_list(vfile, minf=False, omit_vsize=False, sep=sep_load, file_reader=file_reader, print_func=print): + + rs = [] + if omit_vsize: + vsize = omit_vsize + else: + vsize = False + cwd, lwd, dkeys = 0, set(), set() + for data in file_reader(vfile, keep_empty_line=False, sep=sep): + freq = int(data[0]) + if (not minf) or (freq > minf): + if vsize: + ndata = len(data) - 1 + if vsize > ndata: + for wd in data[1:]: + if wd in lwd: + if wd not in dkeys: + dkeys.add(wd) + ndata -= 1 + else: + rs.append(wd) + lwd.add(wd) + cwd += 1 + vsize -= ndata + else: + _break = False + for wd in data[1:]: + if wd in lwd: + if wd not in dkeys: + dkeys.add(wd) + else: + rs.append(wd) + lwd.add(wd) + cwd += 1 + vsize -= 1 + _break = (vsize <= 0) + if _break: + break + if _break: + break + else: + for wd in data[1:]: + if wd in lwd: + if wd not in dkeys: + dkeys.add(wd) + else: + rs.append(wd) + lwd.add(wd) + cwd += 1 + else: + break + if (print_func is not None) and dkeys: + print_func("duplicated vocab keys: %s" % str(dkeys)) + + return rs, cwd + +def ldvocab_freq(vfile, minf=False, omit_vsize=False, sep=sep_load, file_reader=file_reader, print_func=print): + + rs = {} + if omit_vsize: + vsize = omit_vsize + else: + vsize = False + cwd = 0 + dkeys = set() + for data in file_reader(vfile, keep_empty_line=False, sep=sep): + freq = int(data[0]) + if (not minf) or (freq > minf): + if vsize: + ndata = len(data) - 1 + if vsize > ndata: + for wd in data[1:]: + if wd in rs: + if wd not in dkeys: + dkeys.add(wd) + ndata -= 1 + else: + rs[wd] = freq + cwd += ndata + vsize -= ndata + else: + _break = False + for wd in data[1:]: + if wd in rs: + if wd not in dkeys: + dkeys.add(wd) + else: + rs[wd] = freq + cwd += 1 + vsize -= 1 + _break = (vsize <= 0) + if _break: + break + if _break: + break + else: + for wd in data[1:]: + if wd in rs: + if wd not in dkeys: + dkeys.add(wd) + else: + rs[wd] = freq + cwd += 1 + else: + break + if (print_func is not None) and dkeys: + print_func("duplicated vocab keys: %s" % str(dkeys)) + + return rs, cwd diff --git a/utils/func.py b/utils/func.py new file mode 100644 index 0000000..acaaddd --- /dev/null +++ b/utils/func.py @@ -0,0 +1,5 @@ +#encoding: utf-8 + +def identity_func(x, *args, **kwargs): + + return x diff --git a/utils/h5serial.py b/utils/h5serial.py index 8d8a7fc..ec3b6a2 100644 --- a/utils/h5serial.py +++ b/utils/h5serial.py @@ -1,30 +1,30 @@ #encoding: utf-8 -import torch import h5py -from h5py import Dataset +import torch from collections.abc import Iterator -from utils.fmt.base import list2dict, dict_is_list +from utils.fmt.base import dict_is_list, list2dict -from cnfg.ihyp import * +from cnfg.ihyp import h5_libver, h5modelwargs, hdf5_track_order, list_key_func try: h5py.get_config().track_order = hdf5_track_order except Exception as e: pass -class h5File_ctx(h5py.File): - - def __enter__(self): +if hasattr(h5py.File, "__enter__") and hasattr(h5py.File, "__exit__"): + h5File = h5py.File +else: + class h5File(h5py.File): - return self + def __enter__(self): - def __exit__(self, 
*inputs, **kwargs): + return self - self.close() + def __exit__(self, *inputs, **kwargs): -h5File = h5py.File if hasattr(h5py.File, "__enter__") and hasattr(h5py.File, "__exit__") else h5File_ctx + self.close() def h5write_dict(gwrt, dtw, h5args=h5modelwargs): @@ -37,10 +37,8 @@ def h5write_dict(gwrt, dtw, h5args=h5modelwargs): gwrt.create_group(k) h5write_list(gwrt[k], _v, h5args=h5args) else: - if _v.device.type == "cpu": - gwrt.create_dataset(k, data=_v.numpy(), **h5args) - else: - gwrt.create_dataset(k, data=_v.cpu().numpy(), **h5args) + _ = _v if _v.device.type == "cpu" else _v.cpu() + gwrt.create_dataset(k, data=_.numpy(), **h5args) def h5write_list(gwrt, ltw, h5args=h5modelwargs): @@ -72,7 +70,7 @@ def h5load_group(grd): rsd = {} for k, v in grd.items(): - if isinstance(v, Dataset): + if isinstance(v, h5py.Dataset): rsd[k] = torch.from_numpy(v[()]) else: rsd[k] = h5load_group(v) diff --git a/utils/init/base.py b/utils/init/base.py index cfc49fc..903d8a8 100644 --- a/utils/init/base.py +++ b/utils/init/base.py @@ -1,12 +1,12 @@ #encoding: utf-8 -import torch -from torch.nn import Embedding, Linear, LayerNorm +from math import sqrt +from torch.nn import Embedding, LayerNorm, Linear from torch.nn.init import _calculate_fan_in_and_fan_out -from math import sqrt +from utils.torch.comp import torch_no_grad -from cnfg.hyp import lipschitz_initialization +from cnfg.hyp import lipschitz_initialization, lipschitz_scale def xavier_uniform_(tensor, gain=1.0): @@ -15,7 +15,7 @@ def xavier_uniform_(tensor, gain=1.0): _scale *= gain if tensor.requires_grad and (tensor.dim() > 1): - with torch.no_grad(): + with torch_no_grad(): _fin, _fo = _calculate_fan_in_and_fan_out(tensor) _bound = _scale / sqrt(float(_fin + _fo)) tensor.uniform_(-_bound, _bound) @@ -29,7 +29,7 @@ def kaiming_uniform_(tensor, gain=1.0): _scale *= gain if tensor.requires_grad and (tensor.dim() > 1): - with torch.no_grad(): + with torch_no_grad(): _fin, _ = _calculate_fan_in_and_fan_out(tensor) _bound = _scale / sqrt(float(_fin)) tensor.uniform_(-_bound, _bound) @@ -41,7 +41,7 @@ def init_model_params_glorot(modin, gain=1.0): _scale = sqrt(6.0) if gain is not None and gain > 0.0 and gain != 1.0: _scale *= gain - with torch.no_grad(): + with torch_no_grad(): for p in modin.parameters(): if p.requires_grad and (p.dim() > 1): _fin, _fo = _calculate_fan_in_and_fan_out(p) @@ -55,7 +55,7 @@ def init_model_params_kaiming(modin, gain=1.0): _scale = sqrt(3.0) if gain is not None and gain > 0.0 and gain != 1.0: _scale *= gain - with torch.no_grad(): + with torch_no_grad(): for p in modin.parameters(): if p.requires_grad and (p.dim() > 1): _fin, _ = _calculate_fan_in_and_fan_out(p) @@ -64,11 +64,11 @@ def init_model_params_kaiming(modin, gain=1.0): return modin -def init_model_params_lipschitz(modin, gain_glorot=sqrt(1.0/3.0), gain_kaiming=sqrt(1.0/3.0)): +def init_model_params_lipschitz(modin, gain_glorot=sqrt(1.0/3.0) * lipschitz_scale, gain_kaiming=sqrt(1.0/3.0) * lipschitz_scale): _tmpm = init_model_params_kaiming(modin, gain=gain_kaiming) - with torch.no_grad(): + with torch_no_grad(): for _m in _tmpm.modules(): if isinstance(_m, Embedding): init_model_params_glorot(_m, gain=gain_glorot) diff --git a/utils/io.py b/utils/io.py new file mode 100644 index 0000000..8cc859d --- /dev/null +++ b/utils/io.py @@ -0,0 +1,128 @@ +#encoding: utf-8 + +import torch +from os import remove +from os.path import exists as fs_check +from threading import Thread + +from utils.h5serial import h5load, h5save +from utils.torch.comp import torch_no_grad + 
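+# checkpoint I/O helpers: HDF5-based saving/loading of model parameters, plus bookkeeping (bestfkeeper / SaveModelCleaner) that keeps at most n_keep_best checkpoint files per checkpoint type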
+from cnfg.ihyp import h5modelwargs, hdf5_save_parameter_name, n_keep_best#, hdf5_load_parameter_name + +def load_model_cpu_p(modf, base_model, mp=None, **kwargs): + + with torch_no_grad(): + for para, mp in zip(base_model.parameters(), h5load(modf, restore_list=True) if mp is None else mp): + para.copy_(mp) + + return base_model + +def load_model_cpu_np(modf, base_model, mp=None, strict=False, print_func=print, **kwargs): + + _ = base_model.load_state_dict(h5load(modf, restore_list=False) if mp is None else mp, strict=strict, **kwargs) + if (print_func is not None) and (_ is not None): + for _msg in _: + if _msg: + print_func(_msg) + + return base_model + +def load_model_cpu_auto(modf, base_model, mp=None, **kwargs): + + _mp = h5load(modf, restore_list=True) if mp is None else mp + _load_model_func = load_model_cpu_p if isinstance(_mp, list) else load_model_cpu_np + + return _load_model_func(modf, base_model, mp=_mp, **kwargs) + +mp_func_p = lambda m: [_t.data for _t in m.parameters()] +mp_func_np = lambda m: {_k: _t.data for _k, _t in m.named_parameters()} + +load_model_cpu = load_model_cpu_auto#load_model_cpu_np if hdf5_load_parameter_name else load_model_cpu_p +mp_func = mp_func_np if hdf5_save_parameter_name else mp_func_p + +class bestfkeeper: + + def __init__(self, fnames=None, k=n_keep_best, **kwargs): + + self.fnames, self.k = [] if fnames is None else fnames, k + self.clean() + + def update(self, fname=None): + + self.fnames.append(fname) + self.clean(last_fname=fname) + + def clean(self, last_fname=None): + + _n_files = len(self.fnames) + _last_fname = (self.fnames[-1] if self.fnames else None) if last_fname is None else last_fname + while _n_files > self.k: + fname = self.fnames.pop(0) + if (fname is not None) and (fname != _last_fname) and fs_check(fname): + try: + remove(fname) + except Exception as e: + print(e) + _n_files -= 1 + +class SaveModelCleaner: + + def __init__(self): + + self.holder = {} + + def __call__(self, fname, typename, **kwargs): + + if typename in self.holder: + self.holder[typename].update(fname) + else: + self.holder[typename] = bestfkeeper(fnames=[fname]) + +save_model_cleaner = SaveModelCleaner() + +def save_model(model, fname, sub_module=False, print_func=print, mtyp=None, h5args=h5modelwargs): + + _msave = model.module if sub_module else model + try: + h5save(mp_func(_msave), fname, h5args=h5args) + if mtyp is not None: + save_model_cleaner(fname, mtyp) + except Exception as e: + if print_func is not None: + print_func(str(e)) + +def async_save_model(model, fname, sub_module=False, print_func=print, mtyp=None, h5args=h5modelwargs, para_lock=None, log_success=None): + + def _worker(model, fname, sub_module=False, print_func=print, mtyp=None, para_lock=None, log_success=None): + + success = True + _msave = model.module if sub_module else model + try: + if para_lock is None: + h5save(mp_func(_msave), fname, h5args=h5args) + if mtyp is not None: + save_model_cleaner(fname, mtyp) + else: + with para_lock: + h5save(mp_func(_msave), fname, h5args=h5args) + if mtyp is not None: + save_model_cleaner(fname, mtyp) + except Exception as e: + if print_func is not None: + print_func(str(e)) + success = False + if success and (print_func is not None) and (log_success is not None): + print_func(str(log_success)) + + Thread(target=_worker, args=(model, fname, sub_module, print_func, mtyp, para_lock, log_success)).start() + +def save_states(state_dict, fname, print_func=print, mtyp=None): + + try: + torch.save(state_dict, fname) + if mtyp is not None: + 
save_model_cleaner(fname, mtyp) + except Exception as e: + if print_func is not None: + print_func(str(e)) diff --git a/utils/math.py b/utils/math.py index 9a1bebb..cb8fe53 100644 --- a/utils/math.py +++ b/utils/math.py @@ -1,7 +1,20 @@ #encoding: utf-8 +from itertools import accumulate from math import log +def pos_norm(x): + + _s = sum(x) + if _s == 0.0: + _s = 1.0 + + return [_ / _s for _ in x] + +def cumsum(*args, **kwargs): + + return list(accumulate(*args, **kwargs)) + def arcsigmoid(x): return -log((1.0 / x) - 1.0) diff --git a/utils/mulang.py b/utils/mulang.py index f09074d..b3369e8 100644 --- a/utils/mulang.py +++ b/utils/mulang.py @@ -1,7 +1,7 @@ #encoding: utf-8 -from utils.random import multinomial from utils.data import inf_data_generator +from utils.random import multinomial def T_normalize(wl, T): @@ -23,7 +23,7 @@ def sample_iter(wl, T, ntrain, taskl): class data_sampler: - def __init__(self, task_weight, task_weight_T, ntrain, train_taskl, nsample=None): + def __init__(self, task_weight, task_weight_T, ntrain, train_taskl, nsample=None, **kwargs): self.generator = sample_iter(task_weight, task_weight_T, ntrain, train_taskl) self.nsample = nsample @@ -31,3 +31,34 @@ def __init__(self, task_weight, task_weight_T, ntrain, train_taskl, nsample=None def generate(self, nsample=None): return [next(self.generator) for i in range(self.nsample if nsample is None else nsample)] + +class balance_loader: + + def __init__(self, tls, sfunc=min): + + self.tls = tls + self.imax = len(tls) + self.imin = - (self.imax + 1) + self.ndata = self.imax * sfunc(len(_) for _ in self.tls) + self.dg = [inf_data_generator(_) for _ in self.tls] + self.c = [0 for _ in range(self.imax)] + + def get_one(self): + + _im, _vm = 0, self.c[0] + for _i, _v in enumerate(self.c): + if _v < _vm: + _im, _vm = _i, _v + + return _im, next(self.dg[_im]) + + def __call__(self, ndata=None): + + for _ in range(self.ndata if ndata is None else ndata): + yield self.get_one() + + def update(self, i, v=0): + + if (i < self.imax) and (i > self.imin) and (v > 0): + _ = self.c[i] + v + self.c = [0 if _i == i else (_v - _) for _i, _v in enumerate(self.c)] diff --git a/utils/plm/bart.py b/utils/plm/bart.py new file mode 100644 index 0000000..ada1472 --- /dev/null +++ b/utils/plm/bart.py @@ -0,0 +1,89 @@ +#encoding: utf-8 + +import torch +from torch import nn + +from modules.dropout import Dropout +from utils.fmt.parser import parse_none +from utils.plm.base import copy_plm_parameter +from utils.torch.comp import torch_no_grad + +def load_plm_encoder_layer(layer, plm_parameters, model_name=None, layer_idx=None, **kwargs): + + _model_name = parse_none(model_name, layer.model_name) + with torch_no_grad(): + copy_plm_parameter(layer.attn.net.adaptor.weight, plm_parameters, ["%s.layers.%d.self_attn.q_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.k_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.v_proj.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) + _bias_key = "%s.layers.%d.self_attn.q_proj.bias" % (_model_name, layer_idx,) + if (layer.attn.net.adaptor.bias is None) and (_bias_key in plm_parameters): + layer.attn.net.adaptor.bias = nn.Parameter(torch.zeros(layer.attn.net.adaptor.weight.size(0))) + if layer.attn.net.adaptor.bias is not None: + copy_plm_parameter(layer.attn.net.adaptor.bias, plm_parameters, [_bias_key, "%s.layers.%d.self_attn.k_proj.bias" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.v_proj.bias" % (_model_name, layer_idx,)], func=torch.cat, 
func_kwargs={"dim": 0}) + copy_plm_parameter(layer.attn.net.outer.weight, plm_parameters, "%s.layers.%d.self_attn.out_proj.weight" % (_model_name, layer_idx,)) + _bias_key = "%s.layers.%d.self_attn.out_proj.bias" % (_model_name, layer_idx,) + if (layer.attn.net.outer.bias is None) and (_bias_key in plm_parameters): + layer.attn.net.outer.bias = nn.Parameter(torch.zeros(layer.attn.net.outer.weight.size(0))) + if layer.attn.net.outer.bias is not None: + copy_plm_parameter(layer.attn.net.outer.bias, plm_parameters, _bias_key) + copy_plm_parameter(layer.attn.normer.weight, plm_parameters, "%s.layers.%d.self_attn_layer_norm.weight" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.attn.normer.bias, plm_parameters, "%s.layers.%d.self_attn_layer_norm.bias" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.ff.net[0].weight, plm_parameters, "%s.layers.%d.fc1.weight" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.ff.net[0].bias, plm_parameters, "%s.layers.%d.fc1.bias" % (_model_name, layer_idx,)) + _l = layer.ff.net[-2] if isinstance(layer.ff.net[-1], Dropout) else layer.ff.net[-1] + copy_plm_parameter(_l.weight, plm_parameters, "%s.layers.%d.fc2.weight" % (_model_name, layer_idx,)) + _bias_key = "%s.layers.%d.fc2.bias" % (_model_name, layer_idx,) + if (_l.bias is None) and (_bias_key in plm_parameters): + _l.bias = nn.Parameter(torch.zeros(_l.weight.size(0))) + if _l.bias is not None: + copy_plm_parameter(_l.bias, plm_parameters, _bias_key) + copy_plm_parameter(layer.ff.normer.weight, plm_parameters, "%s.layers.%d.final_layer_norm.weight" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.ff.normer.bias, plm_parameters, "%s.layers.%d.final_layer_norm.bias" % (_model_name, layer_idx,)) + +def load_plm_decoder_layer(layer, plm_parameters, model_name=None, layer_idx=None, **kwargs): + + _model_name = parse_none(model_name, layer.model_name) + with torch_no_grad(): + copy_plm_parameter(layer.self_attn.net.adaptor.weight, plm_parameters, ["%s.layers.%d.self_attn.q_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.k_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.v_proj.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) + _bias_key = "%s.layers.%d.self_attn.q_proj.bias" % (_model_name, layer_idx,) + if (layer.self_attn.net.adaptor.bias is None) and (_bias_key in plm_parameters): + layer.self_attn.net.adaptor.bias = nn.Parameter(torch.zeros(layer.self_attn.net.adaptor.weight.size(0))) + if layer.self_attn.net.adaptor.bias is not None: + copy_plm_parameter(layer.self_attn.net.adaptor.bias, plm_parameters, [_bias_key, "%s.layers.%d.self_attn.k_proj.bias" % (_model_name, layer_idx,), "%s.layers.%d.self_attn.v_proj.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) + copy_plm_parameter(layer.self_attn.net.outer.weight, plm_parameters, "%s.layers.%d.self_attn.out_proj.weight" % (_model_name, layer_idx,)) + _bias_key = "%s.layers.%d.self_attn.out_proj.bias" % (_model_name, layer_idx,) + if (layer.self_attn.net.outer.bias is None) and (_bias_key in plm_parameters): + layer.self_attn.net.outer.bias = nn.Parameter(torch.zeros(layer.self_attn.net.outer.weight.size(0))) + if layer.self_attn.net.outer.bias is not None: + copy_plm_parameter(layer.self_attn.net.outer.bias, plm_parameters, _bias_key) + copy_plm_parameter(layer.self_attn.normer.weight, plm_parameters, "%s.layers.%d.self_attn_layer_norm.weight" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.self_attn.normer.bias, plm_parameters, 
"%s.layers.%d.self_attn_layer_norm.bias" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.cross_attn.net.query_adaptor.weight, plm_parameters, "%s.layers.%d.encoder_attn.q_proj.weight" % (_model_name, layer_idx,)) + _bias_key = "%s.layers.%d.encoder_attn.q_proj.bias" % (_model_name, layer_idx,) + if (layer.cross_attn.net.query_adaptor.bias is None) and (_bias_key in plm_parameters): + layer.cross_attn.net.query_adaptor.bias = nn.Parameter(torch.zeros(layer.cross_attn.net.query_adaptor.weight.size(0))) + if layer.cross_attn.net.query_adaptor.bias is not None: + copy_plm_parameter(layer.cross_attn.net.query_adaptor.bias, plm_parameters, _bias_key) + copy_plm_parameter(layer.cross_attn.net.kv_adaptor.weight, plm_parameters, ["%s.layers.%d.encoder_attn.k_proj.weight" % (_model_name, layer_idx,), "%s.layers.%d.encoder_attn.v_proj.weight" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) + _bias_key = "%s.layers.%d.encoder_attn.k_proj.bias" % (_model_name, layer_idx,) + if (layer.cross_attn.net.kv_adaptor.bias is None) and (_bias_key in plm_parameters): + layer.cross_attn.net.kv_adaptor.bias = nn.Parameter(torch.zeros(layer.cross_attn.net.kv_adaptor.weight.size(0))) + if layer.cross_attn.net.kv_adaptor.bias is not None: + copy_plm_parameter(layer.cross_attn.net.kv_adaptor.bias, plm_parameters, [_bias_key, "%s.layers.%d.encoder_attn.v_proj.bias" % (_model_name, layer_idx,)], func=torch.cat, func_kwargs={"dim": 0}) + copy_plm_parameter(layer.cross_attn.net.outer.weight, plm_parameters, "%s.layers.%d.encoder_attn.out_proj.weight" % (_model_name, layer_idx,)) + _bias_key = "%s.layers.%d.encoder_attn.out_proj.bias" % (_model_name, layer_idx,) + if (layer.cross_attn.net.outer.bias is None) and (_bias_key in plm_parameters): + layer.cross_attn.net.outer.bias = nn.Parameter(torch.zeros(layer.cross_attn.net.outer.weight.size(0))) + if layer.cross_attn.net.outer.bias is not None: + copy_plm_parameter(layer.cross_attn.net.outer.bias, plm_parameters, _bias_key) + copy_plm_parameter(layer.cross_attn.normer.weight, plm_parameters, "%s.layers.%d.encoder_attn_layer_norm.weight" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.cross_attn.normer.bias, plm_parameters, "%s.layers.%d.encoder_attn_layer_norm.bias" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.ff.net[0].weight, plm_parameters, "%s.layers.%d.fc1.weight" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.ff.net[0].bias, plm_parameters, "%s.layers.%d.fc1.bias" % (_model_name, layer_idx,)) + _l = layer.ff.net[-2] if isinstance(layer.ff.net[-1], Dropout) else layer.ff.net[-1] + copy_plm_parameter(_l.weight, plm_parameters, "%s.layers.%d.fc2.weight" % (_model_name, layer_idx,)) + _bias_key = "%s.layers.%d.fc2.bias" % (_model_name, layer_idx,) + if (_l.bias is None) and (_bias_key in plm_parameters): + _l.bias = nn.Parameter(torch.zeros(_l.weight.size(0))) + if _l.bias is not None: + copy_plm_parameter(_l.bias, plm_parameters, _bias_key) + copy_plm_parameter(layer.ff.normer.weight, plm_parameters, "%s.layers.%d.final_layer_norm.weight" % (_model_name, layer_idx,)) + copy_plm_parameter(layer.ff.normer.bias, plm_parameters, "%s.layers.%d.final_layer_norm.bias" % (_model_name, layer_idx,)) diff --git a/utils/plm/base.py b/utils/plm/base.py index 41f5986..91ea0d8 100644 --- a/utils/plm/base.py +++ b/utils/plm/base.py @@ -29,10 +29,11 @@ def copy_plm_parameter(src, plm_parameters, keys, func=None, func_args=None, fun _mdl.append(_i) _src.copy_(_tgt) if _mdl and (print_func is not None): - print("size mismatch for 
%s at dimension(s) %s" % (_p_k, ",".join([str(_) for _ in _mdl]),)) - print(_s_size, _t_size) + print_func("size mismatch for %s at dimension(s) %s" % (_p_k, ",".join([str(_) for _ in _mdl]),)) + print_func(_s_size, _t_size) elif print_func is not None: print_func("dimension mismatch for %s" % _p_k) + print_func(_s_size, _t_size) elif print_func is not None: print_func("%s does not exist" % _p_k) diff --git a/utils/plm/t5.py b/utils/plm/t5.py index f9dfe75..3e0b539 100644 --- a/utils/plm/t5.py +++ b/utils/plm/t5.py @@ -2,7 +2,8 @@ import torch from torch.nn import ModuleList -from modules.plm.t5 import SelfAttn, CrossAttn + +from modules.plm.t5 import CrossAttn, SelfAttn def reorder_pemb(w): diff --git a/utils/process.py b/utils/process.py new file mode 100644 index 0000000..4061d9f --- /dev/null +++ b/utils/process.py @@ -0,0 +1,28 @@ +#encoding: utf-8 + +from multiprocessing import Process +from time import sleep + +def start_process(*args, **kwargs): + + _ = Process(*args, **kwargs) + _.start() + + return _ + +def process_keeper_core(t, sleep_secs, *args, **kwargs): + + if t.is_alive(): + sleep(sleep_secs) + else: + t.join() + t.close() + t = start_process(*args, **kwargs) + + return t + +def process_keeper(condition, sleep_secs, *args, **kwargs): + + _t = start_process(*args, **kwargs) + while condition.value: + _t = process_keeper_core(_t, sleep_secs, *args, **kwargs) diff --git a/utils/relpos/base.py b/utils/relpos/base.py index 385c2dd..cf01e3f 100644 --- a/utils/relpos/base.py +++ b/utils/relpos/base.py @@ -1,9 +1,10 @@ #encoding: utf-8 from torch.nn import ModuleList -from modules.base import MultiHeadAttn, SelfAttn -def share_rel_pos_cache(netin): +from modules.base import CrossAttn, MultiHeadAttn, SelfAttn + +def share_rel_pos_cache(netin, share_emb=False): rel_cache_d = {} rel_map_cache_d = {} @@ -11,22 +12,24 @@ def share_rel_pos_cache(netin): if isinstance(net, ModuleList): base_nets = {} for layer in net.modules(): - if isinstance(layer, (SelfAttn, MultiHeadAttn,)): - if layer.rel_pemb is not None: + if isinstance(layer, (SelfAttn, MultiHeadAttn, CrossAttn,)): + if hasattr(layer, "rel_pemb") and (layer.rel_pemb is not None): _key_rel_pos_map = None if layer.rel_pos_map is None else layer.rel_pos_map.size() _key = (layer.clamp_min, layer.clamp_max, layer.rel_shift, _key_rel_pos_map,) if _key in base_nets: layer.ref_rel_posm = base_nets[_key] + if share_emb: + layer.rel_pemb = base_nets[_key].rel_pemb else: base_nets[_key] = layer if _key_rel_pos_map is not None: if _key in rel_map_cache_d: - layer.rel_pos_map = rel_map_cache_d[_key] + layer.register_buffer("rel_pos_map", rel_map_cache_d[_key], persistent=False) else: rel_map_cache_d[_key] = layer.rel_pos_map _key = (layer.clamp_min, layer.clamp_max, layer.rel_shift, _key_rel_pos_map, layer.rel_pos.size(),) if _key in rel_cache_d: - layer.rel_pos = rel_cache_d[_key] + layer.register_buffer("rel_pos", rel_cache_d[_key], persistent=False) else: rel_cache_d[_key] = layer.rel_pos diff --git a/utils/relpos/bucket.py b/utils/relpos/bucket.py index c27b7f2..ed02101 100644 --- a/utils/relpos/bucket.py +++ b/utils/relpos/bucket.py @@ -3,7 +3,7 @@ import torch from math import log -''' +""" relative postional encoding of T5, implementation of the transformers library for reference: https://github.com/huggingface/transformers/blob/v4.21.2/src/transformers/models/t5/modeling_t5.py#L374-L419 def _relative_position_bucket(length, num_buckets=32, max_distance=128, bidirectional=True): @@ -25,7 +25,7 @@ def _relative_position_bucket(length, 
num_buckets=32, max_distance=128, bidirect relative_buckets += torch.where(is_small, relative_position, relative_position_if_large) return relative_buckets -''' +""" def build_rel_pos_bucket_map(k_rel_pos=32, max_len=128, uni_direction=False, device=None): diff --git a/utils/retrattn.py b/utils/retrattn.py index bcf0844..0b2aec6 100644 --- a/utils/retrattn.py +++ b/utils/retrattn.py @@ -1,6 +1,7 @@ #encoding: utf-8 from torch.nn import ModuleList + from modules.attn.retr import SelfAttn def share_retrattn_cache(netin): @@ -15,6 +16,6 @@ def share_retrattn_cache(netin): if _cache is None: _cache = layer.csum else: - layer.csum = _cache + layer.register_buffer("csum", _cache, persistent=False) return netin diff --git a/utils/sampler.py b/utils/sampler.py index be7077c..a49c6e3 100644 --- a/utils/sampler.py +++ b/utils/sampler.py @@ -1,6 +1,6 @@ #encoding: utf-8 -from utils.torch import multinomial +from utils.torch.ext import multinomial def SampleMax(x, dim=-1, keepdim=False): diff --git a/utils/server/__init__.py b/utils/server/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/server/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/server/batcher.py b/utils/server/batcher.py new file mode 100755 index 0000000..8637803 --- /dev/null +++ b/utils/server/batcher.py @@ -0,0 +1,132 @@ +#encoding: utf-8 + +from asyncio import sleep as asleep +from threading import Lock +from time import sleep + +from utils.thread import start_thread_with_keeper + +from cnfg.server import batcher_maintain_interval as maintain_interval, batcher_wait_interval as wait_interval, batcher_watcher_interval as watcher_interval, thread_keeper_interval + +class BatchWrapper: + + def __init__(self, handler, wait_interval=wait_interval, maintain_interval=maintain_interval, watcher_interval=watcher_interval): + + self.handler, self.wait_interval, self.maintain_interval, self.watcher_interval = handler, wait_interval, maintain_interval, watcher_interval + self.ipool, self.opool, self.rids, self.idpool, self.cid, self.mids, self.ipool_lck, self.opool_lck, self.rids_lck, self.idpool_lck, self.cid_lck, self.mids_lck = {}, {}, set(), set(), 1, set(), Lock(), Lock(), Lock(), Lock(), Lock(), Lock() + self.opids = tuple() + self.running = True + self.t_process, self.t_mnt = start_thread_with_keeper([self.is_running], None, thread_keeper_interval, target=self.processor), start_thread_with_keeper([self.is_running], None, thread_keeper_interval, target=self.maintainer) + + async def __call__(self, x): + + _watcher_interval = self.watcher_interval + _id = None + _rs = x + try: + _id = self.get_id() + with self.ipool_lck: + self.ipool[_id] = x + while _id in self.ipool: + await asleep(_watcher_interval) + while _id in self.rids: + await asleep(_watcher_interval) + if _id in self.opool: + with self.opool_lck: + _rs = self.opool.pop(_id) + finally: + if _id is not None: + with self.ipool_lck: + if _id in self.ipool: + del self.ipool[_id] + with self.rids_lck: + if _id in self.rids: + self.rids.remove(_id) + with self.opool_lck: + if _id in self.opool: + del self.opool[_id] + with self.idpool_lck: + self.idpool.add(_id) + if self.mids and (_id in self.mids): + with self.mids_lck: + if _id in self.mids: + self.mids.remove(_id) + + return _rs + + def processor(self): + + while self.running: + if self.ipool: + with self.ipool_lck: + _hd = self.ipool + with self.rids_lck: + self.rids |= set(_hd.keys()) + self.ipool = {} + _i = list(set(_ for _v in _hd.values() for _ in _v)) + _map = {_k: _v for _k, _v 
in zip(_i, self.handler(_i))} + _rs = {_k: [_map.get(_iu, _iu) for _iu in _i] for _k, _i in _hd.items()} + with self.opool_lck: + self.opool |= _rs + with self.rids_lck: + self.rids -= set(_hd.keys()) + else: + sleep(self.wait_interval) + + def maintainer(self): + + while self.running: + if self.opool: + if self.opids: + _gc = set() + with self.opool_lck: + for _ in self.opids: + if _ in self.opool: + del self.opool[_] + _gc.add(_) + self.opids = tuple(self.opool.keys()) + if _gc: + with self.idpool_lck: + self.idpool |= _gc + else: + if (not self.rids) and (not self.ipool): + if self.mids: + with self.mids_lck, self.idpool_lck: + self.idpool |= self.mids + self.mids.clear() + else: + with self.idpool_lck, self.cid_lck, self.ipool_lck, self.rids_lck, self.opool_lck: + _mids = set(range(1, self.cid)) - self.idpool - set(self.ipool.keys()) - self.rids - set(self.opool.keys()) + if _mids: + with self.mids_lck: + self.mids |= _mids + sleep(self.maintain_interval) + + def get_id_core(self): + + rs = None + with self.idpool_lck: + if self.idpool: + rs = self.idpool.pop() + if rs is None: + with self.cid_lck: + rs = self.cid + self.cid += 1 + + return rs + + def get_id(self): + + rs = self.get_id_core() + while (rs in self.ipool) or (rs in self.rids) or (rs in self.opool): + rs = self.get_id_core() + + return rs + + def status(self, mode): + + self.running = mode + + def is_running(self): + + return self.running diff --git a/utils/state/thrand.py b/utils/state/thrand.py index 9be294f..d34f0da 100644 --- a/utils/state/thrand.py +++ b/utils/state/thrand.py @@ -4,7 +4,7 @@ class THRandomState: - def __init__(self, use_cuda=True): + def __init__(self, use_cuda=True, **kwargs): self.use_cuda = torch.cuda.is_available() and use_cuda diff --git a/utils/thread.py b/utils/thread.py new file mode 100644 index 0000000..a74829d --- /dev/null +++ b/utils/thread.py @@ -0,0 +1,53 @@ +#encoding: utf-8 + +from threading import Lock, Thread +from time import sleep + +class LockHolder: + + def __init__(self, value=None): + + self.value = value + self.lck = Lock() + + def __call__(self, *args): + + if args: + with self.lck: + self.value = args[0] + else: + with self.lck: + return self.value + +def start_thread(*args, **kwargs): + + _ = Thread(*args, **kwargs) + _.start() + + return _ + +def thread_keeper_core(t, sleep_secs, *args, **kwargs): + + if t.is_alive(): + sleep(sleep_secs) + else: + t.join() + t = start_thread(*args, **kwargs) + + return t + +def thread_keeper(conditions, func, sleep_secs, *args, **kwargs): + + _conditions = tuple(conditions) + _t = start_thread(*args, **kwargs) + if len(_conditions) > 1: + while func(_() for _ in _conditions): + _t = thread_keeper_core(_t, sleep_secs, *args, **kwargs) + else: + _condition = _conditions[0] + while _condition(): + _t = thread_keeper_core(_t, sleep_secs, *args, **kwargs) + +def start_thread_with_keeper(conditions, func, sleep_secs, *args, **kwargs): + + return start_thread(target=thread_keeper, args=[conditions, func, sleep_secs, *args], kwargs=kwargs) diff --git a/utils/torch/__init__.py b/utils/torch/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/torch/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/torch/c.py b/utils/torch/c.py new file mode 100644 index 0000000..87c33e7 --- /dev/null +++ b/utils/torch/c.py @@ -0,0 +1,26 @@ +#encoding: utf-8 + +from torch.autograd import Function +from torch.utils.cpp_extension import load + +try: + import movavg_cpp +except Exception as e: + movavg_cpp = 
load(name="movavg_cpp", sources=["utils/cpp/movavg.cpp"]) + +class MovAvgFunction(Function): + + @staticmethod + def forward(ctx, x, dim=None, beta=0.9, inplace=False): + + out = movavg_cpp.forward(x, dim, beta, inplace) + ctx.dim, ctx.beta = dim, beta + + return out + + @staticmethod + def backward(ctx, grad_out): + + return movavg_cpp.backward(grad_out, ctx.dim, ctx.beta), None, None, None + +MovAvgFunc = MovAvgFunction.apply diff --git a/utils/torch/comp.py b/utils/torch/comp.py new file mode 100644 index 0000000..223dbe4 --- /dev/null +++ b/utils/torch/comp.py @@ -0,0 +1,208 @@ +#encoding: utf-8 + +import torch + +from utils.func import identity_func + +from cnfg.ihyp import allow_fp16_reduction, allow_tf32, enable_torch_check, use_deterministic, use_inference_mode, use_torch_compile + +secure_type_map = {torch.float16: torch.float64, torch.float32: torch.float64, torch.uint8: torch.int64, torch.int8: torch.int64, torch.int16: torch.int64, torch.int32: torch.int64} + +try: + if hasattr(torch, "set_float32_matmul_precision"): + torch.set_float32_matmul_precision("medium" if allow_fp16_reduction else ("high" if allow_tf32 else "highest")) + torch.backends.cuda.matmul.allow_tf32 = allow_tf32 + torch.backends.cudnn.allow_tf32 = allow_tf32 + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = allow_fp16_reduction + torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = allow_fp16_reduction +except Exception as e: + print(e) + +if hasattr(torch.autograd, "set_multithreading_enabled"): + try: + torch.autograd.set_multithreading_enabled(True) + except Exception as e: + print(e) + +# Make cudnn methods deterministic according to: https://pytorch.org/docs/stable/notes/randomness.html#cudnn +_config_cudnn_deterministic_variable = True +if hasattr(torch, "use_deterministic_algorithms"): + try: + torch.use_deterministic_algorithms(use_deterministic, warn_only=True) + _config_cudnn_deterministic_variable = False + except Exception as e: + print(e) +if _config_cudnn_deterministic_variable: + torch.backends.cudnn.deterministic = use_deterministic + +torch.backends.cudnn.benchmark = False + +if hasattr(torch, "autograd") and hasattr(torch.autograd, "set_detect_anomaly"): + try: + torch.autograd.set_detect_anomaly(enable_torch_check) + except Exception as e: + print(e) + +def all_done_bool(stat, *inputs, **kwargs): + + return stat.all().item() + +def all_done_byte(stat, bsize=None, **kwargs): + + return stat.int().sum().item() == (stat.numel() if bsize is None else bsize) + +def exist_any_bool(stat): + + return stat.any().item() + +def exist_any_byte(stat): + + return stat.int().sum().item() > 0 + +def torch_all_bool_wodim(x, *inputs, **kwargs): + + return x.all(*inputs, **kwargs) + +def torch_all_byte_wodim(x, *inputs, **kwargs): + + return x.int().sum(*inputs, **kwargs).eq(x.numel()) + +def torch_all_bool_dim(x, dim, *inputs, **kwargs): + + return x.all(dim, *inputs, **kwargs) + +def torch_all_byte_dim(x, dim, *inputs, **kwargs): + + return x.int().sum(*inputs, dim=dim, **kwargs).eq(x.size(dim)) + +def torch_all_bool(x, *inputs, dim=None, **kwargs): + + return x.all(*inputs, **kwargs) if dim is None else x.all(dim, *inputs, **kwargs) + +def torch_all_byte(x, *inputs, dim=None, **kwargs): + + return x.int().sum(*inputs, **kwargs).eq(x.numel()) if dim is None else x.int().sum(*inputs, dim=dim, **kwargs).eq(x.size(dim)) + +def torch_any_bool_wodim(x, *inputs, **kwargs): + + return x.any(*inputs, **kwargs) + +def torch_any_byte_wodim(x, *inputs, **kwargs): + + return 
x.int().sum(*inputs, **kwargs).gt(0) + +def torch_any_bool_dim(x, dim, *inputs, **kwargs): + + return x.any(dim, *inputs, **kwargs) + +def torch_any_byte_dim(x, dim, *inputs, **kwargs): + + return x.int().sum(*inputs, dim=dim, **kwargs).gt(0) + +def torch_any_bool(x, *inputs, dim=None, **kwargs): + + return x.any(*inputs, **kwargs) if dim is None else x.any(dim, *inputs, **kwargs) + +def torch_any_byte(x, *inputs, dim=None, **kwargs): + + return x.int().sum(*inputs, **kwargs).gt(0) if dim is None else x.int().sum(*inputs, dim=dim, **kwargs).gt(0) + +def flip_mask_bool(mask, dim): + + return mask.to(torch.uint8, non_blocking=True).flip(dim).to(mask.dtype, non_blocking=True) + +def flip_mask_byte(mask, dim): + + return mask.flip(dim) + +class EmptyAutocast: + + def __init__(self, *inputs, **kwargs): + + self.args, self.kwargs = inputs, kwargs + + def __enter__(self): + + return self + + def __exit__(self, *inputs, **kwargs): + + pass + +class EmptyGradScaler: + + def __init__(self, *args, **kwargs): + + self.args, self.kwargs = args, kwargs + + def scale(self, outputs): + + return outputs + + def step(self, optimizer, *args, **kwargs): + + return optimizer.step(*args, **kwargs) + + def update(self, *args, **kwargs): + + pass + +def torch_is_autocast_enabled_empty(*args, **kwargs): + + return False + +# handling torch.bool +if hasattr(torch, "bool"): + mask_tensor_type = torch.bool + secure_type_map[mask_tensor_type] = torch.int64 + nccl_type_map = {torch.bool:torch.uint8} + all_done = all_done_bool + exist_any = exist_any_bool + torch_all = torch_all_bool + torch_all_dim = torch_all_bool_dim + torch_all_wodim = torch_all_bool_wodim + torch_any = torch_any_bool + torch_any_dim = torch_any_bool_dim + torch_any_wodim = torch_any_bool_wodim + flip_mask = flip_mask_bool +else: + mask_tensor_type = torch.uint8 + nccl_type_map = None + all_done = all_done_byte + exist_any = exist_any_byte + torch_all = torch_all_byte + torch_all_dim = torch_all_byte_dim + torch_all_wodim = torch_all_byte_wodim + torch_any = torch_any_byte + torch_any_dim = torch_any_byte_dim + torch_any_wodim = torch_any_byte_wodim + flip_mask = flip_mask_byte + +# handling torch.cuda.amp, fp16 will NOT be really enabled if torch.cuda.amp does not exist (for early versions) +_config_torch_cuda_amp = True +if hasattr(torch, "cuda") and hasattr(torch.cuda, "amp"): + try: + from torch.cuda.amp import GradScaler, autocast as torch_autocast + torch_is_autocast_enabled = torch.is_autocast_enabled + is_fp16_supported = True + _config_torch_cuda_amp = False + except Exception as e: + print(e) +if _config_torch_cuda_amp: + torch_autocast, GradScaler, torch_is_autocast_enabled, is_fp16_supported = EmptyAutocast, EmptyGradScaler, torch_is_autocast_enabled_empty, False + +# inference mode for torch >= 1.9.0 +using_inference_mode = use_inference_mode and hasattr(torch, "inference_mode") and hasattr(torch, "is_inference_mode_enabled") +if using_inference_mode: + torch_is_inference_mode_enabled, torch_inference_mode = torch.is_inference_mode_enabled, torch.inference_mode +else: + def torch_is_inference_mode_enabled(): + + return not torch.is_grad_enabled() + + def torch_inference_mode(mode=True): + + return torch.set_grad_enabled(not mode) +torch_is_grad_enabled, torch_set_grad_enabled, torch_no_grad = torch.is_grad_enabled, torch.set_grad_enabled, torch.no_grad + +torch_compile = torch.compile if hasattr(torch, "compile") and use_torch_compile else identity_func diff --git a/utils/torch.py b/utils/torch/ext.py similarity index 87% rename from 
utils/torch.py rename to utils/torch/ext.py index a7e2b85..713367c 100644 --- a/utils/torch.py +++ b/utils/torch/ext.py @@ -3,7 +3,7 @@ import torch from numbers import Number -from cnfg.ihyp import ieps_upper_bound_default +from cnfg.ihyp import ieps_ln_default, ieps_upper_bound_default upper_one = 1.0 - ieps_upper_bound_default @@ -61,6 +61,10 @@ def comb_grow(start, end, k, alpha=0.5): return exp_grow(start, end, k).mul_(alpha).add_(linear_grow(start, end, k).mul_(1.0 - alpha)) +def cosim(a, b, dim=-1, keepdim=False, eps=ieps_ln_default): + + return a.mul(b).sum(dim=dim, keepdim=keepdim).div_(a.norm(p=2, dim=dim, keepdim=keepdim).mul(b.norm(p=2, dim=dim, keepdim=keepdim)).add_(eps)) + def arcsigmoid(x): return ((1.0 / x) - 1.0).log().neg() @@ -87,7 +91,7 @@ def ensure_num_interop_threads(n): class num_threads: - def __init__(self, n): + def __init__(self, n, **kwargs): self.num_threads_exe = n @@ -102,7 +106,7 @@ def __exit__(self, *inputs, **kwargs): class num_interop_threads: - def __init__(self, n): + def __init__(self, n, **kwargs): self.num_threads_exe = n diff --git a/utils/pyctorch.py b/utils/torch/pyc.py similarity index 68% rename from utils/pyctorch.py rename to utils/torch/pyc.py index e114f49..48313fd 100644 --- a/utils/pyctorch.py +++ b/utils/torch/pyc.py @@ -2,15 +2,17 @@ import torch +from utils.fmt.parser import parse_none + non_tensor = torch.Tensor() def transfer_CNone_tuple(lin): - return tuple(non_tensor if lu is None else lu for lu in lin) + return tuple(parse_none(_, non_tensor) for _ in lin) def transfer_CNone_list(lin): - return [non_tensor if lu is None else lu for lu in lin] + return [parse_none(_, non_tensor) for _ in lin] def transfer_CNone(din): @@ -21,4 +23,4 @@ def transfer_CNone(din): elif isinstance(din, dict): return {k: transfer_CNone(du) for k, du in din.items()} else: - return non_tensor if din is None else din + return parse_none(din, non_tensor) diff --git a/utils/train/__init__.py b/utils/train/__init__.py new file mode 100644 index 0000000..8fb0d7c --- /dev/null +++ b/utils/train/__init__.py @@ -0,0 +1 @@ +#encoding: utf-8 diff --git a/utils/train/base.py b/utils/train/base.py new file mode 100644 index 0000000..1fa4096 --- /dev/null +++ b/utils/train/base.py @@ -0,0 +1,93 @@ +#encoding: utf-8 + +from utils.torch.comp import is_fp16_supported + +from cnfg.ihyp import optm_step_zero_grad_set_none + +def freeze_module(module): + + for p in module.parameters(): + if p.requires_grad: + p.requires_grad_(False) + +def unfreeze_module(module): + + def unfreeze_fixing(mod): + + if hasattr(mod, "fix_unfreeze"): + mod.fix_unfreeze() + + for p in module.parameters(): + p.requires_grad_(True) + + module.apply(unfreeze_fixing) + +def getlr(optm): + + lr = [] + for i, param_group in enumerate(optm.param_groups): + lr.append(float(param_group["lr"])) + + return lr + +def updated_lr(oldlr, newlr): + + rs = False + for olr, nlr in zip(oldlr, newlr): + if olr != nlr: + rs = True + break + + return rs + +def reset_Adam(optm, amsgrad=False): + + for group in optm.param_groups: + for p in group["params"]: + state = optm.state[p] + if len(state) != 0: + state["step"] = 0 + state["exp_avg"].zero_() + state["exp_avg_sq"].zero_() + if amsgrad: + state["max_exp_avg_sq"].zero_() + +def reinit_Adam(optm, amsgrad=False): + + for group in optm.param_groups: + for p in group["params"]: + optm.state[p].clear() + +def module_train(netin, module, mode=True): + + for net in netin.modules(): + if isinstance(net, module): + net.train(mode=mode) + + return netin + +def 
optm_step_std(optm, model=None, scaler=None, closure=None, multi_gpu=False, multi_gpu_optimizer=False, zero_grad_none=optm_step_zero_grad_set_none): + + if multi_gpu: + model.collect_gradients() + if scaler is None: + optm.step(closure=closure) + else: + scaler.step(optm, closure=closure) + scaler.update() + if not multi_gpu_optimizer: + optm.zero_grad(set_to_none=zero_grad_none) + if multi_gpu: + model.update_replicas() + +def optm_step_wofp16(optm, model=None, scaler=None, closure=None, multi_gpu=False, multi_gpu_optimizer=False, zero_grad_none=optm_step_zero_grad_set_none): + + if multi_gpu: + model.collect_gradients() + optm.step(closure=closure) + if not multi_gpu_optimizer: + optm.zero_grad(set_to_none=zero_grad_none) + if multi_gpu: + model.update_replicas() + +optm_step = optm_step_std if is_fp16_supported else optm_step_wofp16 diff --git a/utils/train/dss.py b/utils/train/dss.py new file mode 100644 index 0000000..bcf535b --- /dev/null +++ b/utils/train/dss.py @@ -0,0 +1,19 @@ +#encoding: utf-8 + +from random import sample + +def dynamic_sample(incd, dss_ws, dss_rm): + + rd = {} + for k, v in incd.items(): + if v in rd: + rd[v].append(k) + else: + rd[v] = [k] + incs = list(rd.keys()) + incs.sort(reverse=True) + _full_rl = [] + for v in incs: + _full_rl.extend(rd[v]) + + return _full_rl[:dss_ws] + sample(_full_rl[dss_ws:], dss_rm) if dss_rm > 0 else _full_rl[:dss_ws]
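
A minimal usage sketch of the dynamic sampling helper added above (utils/train/dss.py), assuming incd maps batch identifiers to ranking scores; what those scores represent is left to the caller and is not fixed by this patch. The dss_ws highest-scoring entries are always kept, and dss_rm additional entries are drawn at random from the remainder.

    from utils.train.dss import dynamic_sample

    # hypothetical scores for four batches: keep the top 2 and sample 1 more from the rest
    scores = {"b0": 3, "b1": 1, "b2": 5, "b3": 2}
    chosen = dynamic_sample(scores, dss_ws=2, dss_rm=1)
    # chosen begins with ["b2", "b0"], followed by one id drawn from ["b3", "b1"]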