From fb37f10c1fd4ff9e10a98ad3656bca9bce1b16a5 Mon Sep 17 00:00:00 2001 From: Shengqiang Li <49022799+Shengqiang-Li@users.noreply.github.com> Date: Thu, 8 Feb 2024 11:07:18 +0800 Subject: [PATCH] [examples] set num workers via argparse (#205) Co-authored-by: ShengqiangLi --- examples/aishell-3/run.sh | 3 ++- examples/baker/run.sh | 5 +++-- examples/ljspeech/run.sh | 3 ++- examples/multilingual/run.sh | 3 ++- wetts/vits/train.py | 4 ++-- wetts/vits/utils/task.py | 5 +++++ 6 files changed, 16 insertions(+), 7 deletions(-) diff --git a/examples/aishell-3/run.sh b/examples/aishell-3/run.sh index e7e025a..c884bce 100755 --- a/examples/aishell-3/run.sh +++ b/examples/aishell-3/run.sh @@ -63,7 +63,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --train_data $data/train.txt \ --val_data $data/val.txt \ --speaker_table $data/speaker.txt \ - --phone_table $data/phones.txt + --phone_table $data/phones.txt \ + --num_workers 8 fi diff --git a/examples/baker/run.sh b/examples/baker/run.sh index 1b78a72..cc98d88 100755 --- a/examples/baker/run.sh +++ b/examples/baker/run.sh @@ -14,7 +14,7 @@ config=configs/v3.json # Please download data from https://www.data-baker.com/data/index/TNtts, and # set `raw_data_dir` to your data. -raw_data_dir=/mnt/mnt-data-1/binbin.zhang/data/BZNSYP +raw_data_dir=. # path to dataset directory data=data test_audio=test_audio ckpt_step=200000 @@ -56,7 +56,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --train_data $data/train.txt \ --val_data $data/val.txt \ --speaker_table $data/speaker.txt \ - --phone_table $data/phones.txt + --phone_table $data/phones.txt \ + --num_workers 8 fi diff --git a/examples/ljspeech/run.sh b/examples/ljspeech/run.sh index 0f319be..0db62d7 100644 --- a/examples/ljspeech/run.sh +++ b/examples/ljspeech/run.sh @@ -57,7 +57,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --train_data $data/train.txt \ --val_data $data/val.txt \ --speaker_table $data/speaker.txt \ - --phone_table $data/phones.txt + --phone_table $data/phones.txt \ + --num_workers 8 fi diff --git a/examples/multilingual/run.sh b/examples/multilingual/run.sh index 3a8908b..e2b8f0f 100644 --- a/examples/multilingual/run.sh +++ b/examples/multilingual/run.sh @@ -46,7 +46,8 @@ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then --train_data $data/train.txt \ --val_data $data/val.txt \ --speaker_table $data/speaker.txt \ - --phone_table $data/phones.txt + --phone_table $data/phones.txt \ + --num_workers 8 fi if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then diff --git a/wetts/vits/train.py b/wetts/vits/train.py index ec9c553..7a1f75f 100644 --- a/wetts/vits/train.py +++ b/wetts/vits/train.py @@ -66,7 +66,7 @@ def main(): collate_fn = TextAudioSpeakerCollate() train_loader = DataLoader( train_dataset, - num_workers=8, + num_workers=hps.num_workers, shuffle=False, pin_memory=True, collate_fn=collate_fn, @@ -77,7 +77,7 @@ def main(): hps.data) eval_loader = DataLoader( eval_dataset, - num_workers=8, + num_workers=hps.num_workers, shuffle=False, batch_size=hps.train.batch_size, pin_memory=True, diff --git a/wetts/vits/utils/task.py b/wetts/vits/utils/task.py index 2459d1f..ece5423 100644 --- a/wetts/vits/utils/task.py +++ b/wetts/vits/utils/task.py @@ -181,6 +181,10 @@ def get_hparams(init=True): type=str, required=True, help="phone table") + parser.add_argument('--num_workers', + default=8, + type=int, + help='num of subprocess workers for reading') parser.add_argument( "--speaker_table", type=str, @@ -218,6 +222,7 @@ def get_hparams(init=True): hparams = HParams(**config) hparams.model_dir = model_dir + hparams.num_workers = args.num_workers return hparams