training_sambert
[TOC]
We recommend using Anaconda to set up your own Python virtual environment.
```bash
# in case of a pip install error, changing the pip source may help
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
# build the virtual environment
conda env create -f environment.yaml
# activate the virtual environment
conda activate maas
```
We support two data structures: mit-style data and general data. Make sure your data is organized according to one of the structures below.
```
# mit-style data
.
├── interval
│   ├── 500001.interval
│   ├── 500002.interval
│   ├── 500003.interval
│   ├── ...
│   └── 600010.interval
├── prosody
│   └── prosody.txt
└── wav
    ├── 500001.wav
    ├── 500002.wav
    ├── ...
    └── 600010.wav
```
```
# general data
.
├── txt
│   └── prosody.txt
└── wav
    ├── 1.wav
    ├── 2.wav
    ├── ...
    └── 9000.wav
```
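Before running preprocessing, it can help to check which of the two layouts a directory follows. The sketch below is a hypothetical helper, not part of KAN-TTS: it classifies a dataset from its relative file paths, assuming the layouts are exactly as drawn above (a mit-style dataset has `interval/`, `prosody/prosody.txt` and `wav/`; a general dataset has `txt/prosody.txt` and `wav/`). For mit-style data it also reports wav files that lack a matching `.interval` file.

```python
from pathlib import Path, PurePosixPath


def detect_layout(rel_paths):
    """Classify a raw dataset from its relative file paths (hypothetical helper).

    Returns "mit-style" or "general"; raises ValueError when the layout is not
    recognised or when a mit-style wav has no matching .interval file.
    """
    paths = [PurePosixPath(p) for p in rel_paths]
    wav_ids = {p.stem for p in paths if p.parent.name == "wav" and p.suffix == ".wav"}
    if not wav_ids:
        raise ValueError("no wav/*.wav files found")
    names = {str(p) for p in paths}
    if "prosody/prosody.txt" in names:
        # mit-style: every wav is expected to have an interval file with the same ID
        interval_ids = {p.stem for p in paths
                        if p.parent.name == "interval" and p.suffix == ".interval"}
        missing = sorted(wav_ids - interval_ids)
        if missing:
            raise ValueError(f"wav files without .interval: {missing[:5]}")
        return "mit-style"
    if "txt/prosody.txt" in names:
        return "general"
    raise ValueError("matches neither the mit-style nor the general layout")


def scan(data_dir):
    """Collect the relative POSIX paths of all files under data_dir."""
    root = Path(data_dir)
    return [p.relative_to(root).as_posix() for p in root.rglob("*") if p.is_file()]
```

Usage: `detect_layout(scan("YOUR_DATA_PATH"))`. Keeping the classification separate from the directory walk makes the check easy to test without touching the file system.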
For a quick start, a demo dataset is available on DAMO.NLS.KAN-TTS.OpenDataset.

Modify the audio config to fit your data; the demo config file is kantts/configs/audio_config_24k.yaml.
Run data_process. You need to pass a speaker name through --speaker; since our demo dataset is unnamed, you can name it as you like (Paige, Lily, ... whatever).
```bash
python kantts/preprocess/data_process.py --voice_input_dir YOUR_DATA_PATH --voice_output_dir OUTPUT_DATA_FEATURE_PATH --audio_config AUDIO_CONFIG_PATH --speaker YOUR_SPEAKER_NAME
```
You then get the following features for training.
```
# mit-style data, training stage
.
├── Script.xml
├── am_train.lst
├── am_valid.lst
├── badlist.txt
├── duration/
├── energy/
├── f0/
├── mel/
├── raw_duration/
├── raw_metafile.txt
├── trim_mel/
├── trim_wav/
└── wav/
```
```
# general data, training stage
.
├── Script.xml
├── am_train.lst
├── am_valid.lst
├── badlist.txt
├── energy/
├── f0/
├── mel/
├── raw_metafile.txt
├── trim_mel/
├── trim_wav/
└── wav/
```
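A quick sanity check on the generated split files is to count the entries in am_train.lst and am_valid.lst and confirm that no entry appears in both. The sketch below is hypothetical and assumes each .lst file holds one entry per line; the exact line format written by data_process is not described here, so adapt `read_ids` if it differs.

```python
from pathlib import Path


def read_ids(path):
    """Read one entry per line, skipping blank lines (assumed .lst format)."""
    text = Path(path).read_text(encoding="utf-8")
    return [line.strip() for line in text.splitlines() if line.strip()]


def split_summary(train_ids, valid_ids):
    """Summarise a train/valid split and list entries present in both."""
    train, valid = set(train_ids), set(valid_ids)
    return {
        "train": len(train),
        "valid": len(valid),
        "overlap": sorted(train & valid),  # should normally be empty
    }
```

Usage: `split_summary(read_ids("OUTPUT_DATA_FEATURE_PATH/am_train.lst"), read_ids("OUTPUT_DATA_FEATURE_PATH/am_valid.lst"))`.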
Our training recipe is config-driven. A default Sambert model config can be found at kantts/configs/sambert_24k.yaml; you can modify that config to create your own Sambert model :)

For general data, you can use kantts/configs/sambert_16k_MAS.yaml as a reference. The difference between the two configs is that in kantts/configs/sambert_16k_MAS.yaml the MAS parameter is set to True.
```yaml
...
postnet_ffn_inner_dim: 512
postnet_dropout: 0.1
postnet_shift: 17
postnet_lstm_units: 128
MAS: True  # <------ Here is the difference
...
```
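To see for yourself which settings differ between sambert_24k.yaml and sambert_16k_MAS.yaml (or between a stock config and your modified copy), you can diff the two configs after loading them as nested dicts. This is a generic, hypothetical helper, not part of KAN-TTS; the dotted key paths it reports depend on how the real config files are nested, which is not shown in full here.

```python
def diff_configs(a, b, prefix=""):
    """Recursively list keys whose values differ between two nested dicts.

    Returns (dotted_key, value_in_a, value_in_b) tuples. A key missing on one
    side is reported with None on that side.
    """
    diffs = []
    for key in sorted(set(a) | set(b), key=str):
        path = f"{prefix}.{key}" if prefix else str(key)
        va, vb = a.get(key), b.get(key)
        if isinstance(va, dict) and isinstance(vb, dict):
            # descend into sub-sections such as model parameter blocks
            diffs.extend(diff_configs(va, vb, path))
        elif va != vb:
            diffs.append((path, va, vb))
    return diffs
```

With PyYAML installed, load each file with `yaml.safe_load(open(path))` and pass the two resulting dicts to `diff_configs`.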
One more thing to note: in kantts/configs/sambert_24k.yaml or kantts/configs/sambert_16k_MAS.yaml, the speaker_list field should be changed to the speaker name of your dataset (the name you passed through --speaker), as below:
```yaml
linguistic_unit:
  cleaners: english_cleaners
  lfeat_type_list: sy,tone,syllable_flag,word_segment,emo_category,speaker_category
  speaker_list: YOUR_SPEAKER_NAME
```
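If you generate configs for several datasets, editing speaker_list by hand gets error-prone. The hypothetical helper below sets it on a config already loaded as a dict. It assumes speaker_list sits under a top-level linguistic_unit key, as in the fragment above, and that multiple speakers are joined with commas in the same way as lfeat_type_list; that comma convention is an assumption, so check it against the example configs shipped with the repository.

```python
def set_speakers(config, speakers):
    """Return a copy of config with linguistic_unit.speaker_list set.

    Assumption: multiple speaker names are comma-separated. The input dict is
    not modified.
    """
    if not speakers:
        raise ValueError("at least one speaker name is required")
    if any("," in s or not s.strip() for s in speakers):
        raise ValueError("speaker names must be non-empty and contain no commas")
    updated = dict(config)
    unit = dict(updated.get("linguistic_unit", {}))
    unit["speaker_list"] = ",".join(s.strip() for s in speakers)
    updated["linguistic_unit"] = unit
    return updated
```

With PyYAML, you can round-trip the file via `yaml.safe_load` and `yaml.safe_dump`; note that `safe_dump` drops comments from the original file.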
Now you have the sword and shield (data and model :-|), so go give it a try.
```bash
# The --root_dir can take multiple args for universal vocoder training
CUDA_VISIBLE_DEVICES=0 python kantts/bin/train_sambert.py --model_config YOUR_MODEL_CONFIG --root_dir OUTPUT_DATA_FEATURE_PATH --stage_dir TRAINING_STAGE_PATH
```
If you have enough GPU devices, you can use distributed training, which is much faster than single-GPU training. For example, assign GPU device indexes with the CUDA_VISIBLE_DEVICES environment variable; --nproc_per_node denotes the number of GPU devices.
```bash
CUDA_VISIBLE_DEVICES=0,1,2,4 python -m torch.distributed.launch --nproc_per_node=4 kantts/bin/train_sambert.py --model_config YOUR_MODEL_CONFIG --root_dir OUTPUT_DATA_FEATURE_PATH --stage_dir TRAINING_STAGE_PATH
```
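A common mistake with distributed training is letting --nproc_per_node drift out of sync with the number of devices listed in CUDA_VISIBLE_DEVICES. The hypothetical helper below derives both from a single list of GPU ids. It assumes --root_dir takes several space-separated paths after a single flag; the text above only says it "can be multiple args", so verify the exact form against the script's argument parser.

```python
import os


def build_launch_cmd(gpu_ids, model_config, root_dirs, stage_dir):
    """Build (env, argv) for a torch.distributed.launch run (hypothetical helper).

    --nproc_per_node is always len(gpu_ids), so it matches the visible devices.
    """
    if not gpu_ids:
        raise ValueError("need at least one GPU id")
    env = dict(os.environ, CUDA_VISIBLE_DEVICES=",".join(str(i) for i in gpu_ids))
    argv = [
        "python", "-m", "torch.distributed.launch",
        f"--nproc_per_node={len(gpu_ids)}",
        "kantts/bin/train_sambert.py",
        "--model_config", model_config,
        "--root_dir", *root_dirs,  # assumption: space-separated paths
        "--stage_dir", stage_dir,
    ]
    return env, argv
```

Usage: `env, argv = build_launch_cmd([0, 1, 2, 3], "YOUR_MODEL_CONFIG", ["OUTPUT_DATA_FEATURE_PATH"], "TRAINING_STAGE_PATH")`, then `subprocess.run(argv, env=env, check=True)`.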
--resume_path can be used to resume training from a pre-trained model, or to continue training from a previous checkpoint.

--resume_bert_path can be used to resume training with a pre-trained BERT model (for BERT model training, see the BERT tutorial).
```bash
# The --root_dir can take multiple args for universal vocoder training
CUDA_VISIBLE_DEVICES=0 python kantts/bin/train_sambert.py --model_config YOUR_MODEL_CONFIG --root_dir OUTPUT_DATA_FEATURE_PATH --stage_dir TRAINING_STAGE_PATH --resume_path CHECKPOINT_PATH --resume_bert_path BERT_CHECKPOINT_PATH
```
After training is done, your TRAINING_STAGE_PATH looks like this:
```
.
├── ckpt/
├── config.yaml
├── log/
└── stdout.log
```
Model checkpoints are stored in the ckpt directory:
```
./ckpt
├── checkpoint_10000.pth
├── checkpoint_12000.pth
├── checkpoint_14000.pth
├── checkpoint_16000.pth
├── checkpoint_18000.pth
├── checkpoint_2000.pth
├── checkpoint_4000.pth
├── checkpoint_6000.pth
└── checkpoint_8000.pth
```
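Because the files are listed in string order, checkpoint_10000.pth sorts before checkpoint_2000.pth, and picking the "last" file by name does not give the newest checkpoint. The hypothetical helper below selects the checkpoint with the highest step number, which is handy for --resume_path or for choosing a --ckpt to test. It only relies on the checkpoint_<step>.pth naming shown above; the newest checkpoint is not necessarily the best-sounding one.

```python
import re
from pathlib import Path

_CKPT_RE = re.compile(r"checkpoint_(\d+)\.pth$")


def latest_checkpoint(names):
    """Return the checkpoint filename with the highest step number, or None.

    Compares the integer step, so checkpoint_10000.pth beats checkpoint_8000.pth
    even though it sorts earlier as a string.
    """
    best, best_step = None, -1
    for name in names:
        m = _CKPT_RE.search(name)
        if m and int(m.group(1)) > best_step:
            best, best_step = name, int(m.group(1))
    return best


def latest_in_dir(ckpt_dir):
    """Full path of the newest checkpoint in a training stage's ckpt/ directory."""
    name = latest_checkpoint(p.name for p in Path(ckpt_dir).glob("checkpoint_*.pth"))
    return None if name is None else str(Path(ckpt_dir) / name)
```

Usage: `latest_in_dir("TRAINING_STAGE_PATH/ckpt")`.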
Time to test your powerful model: prepare a test sequence file. Here we randomly pick 10 sequences from the validation set and use them to generate mel spectrograms.
```bash
shuf -n 10 OUTPUT_DATA_FEATURE_PATH/am_valid.lst > OUTPUT_DATA_FEATURE_PATH/test_sequence.lst
```
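`shuf` picks a different subset on every run. If you want to compare checkpoints on the same test sentences, a seeded selection is easier to reproduce. The sketch below is a hypothetical Python equivalent of `shuf -n 10`: it samples lines without replacement with a fixed seed and, unlike `shuf`, skips blank lines.

```python
import random
from pathlib import Path


def sample_lines(lines, n=10, seed=None):
    """Pick up to n non-blank lines without replacement, like `shuf -n`.

    A fixed seed makes the selection reproducible across runs.
    """
    pool = [line for line in lines if line.strip()]
    rng = random.Random(seed)
    return rng.sample(pool, min(n, len(pool)))


def write_test_list(valid_lst, out_lst, n=10, seed=0):
    """Write a reproducible test sequence file from a validation list."""
    lines = Path(valid_lst).read_text(encoding="utf-8").splitlines()
    picked = sample_lines(lines, n, seed)
    Path(out_lst).write_text("\n".join(picked) + "\n", encoding="utf-8")
    return picked
```

Usage: `write_test_list("OUTPUT_DATA_FEATURE_PATH/am_valid.lst", "OUTPUT_DATA_FEATURE_PATH/test_sequence.lst")`.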
Then run the command below to run Sambert inference.
```bash
CUDA_VISIBLE_DEVICES=0 python kantts/bin/infer_sambert.py --sentence OUTPUT_DATA_FEATURE_PATH/test_sequence.lst --ckpt YOUR_MODEL_CKPT --output_dir OUTPUT_PATH_TO_STORE_MEL
```
You'll then find the predicted mel spectrogram files under OUTPUT_PATH_TO_STORE_MEL.
Our implementation refers to the following repositories and papers.