This repository has been archived by the owner on Sep 11, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 83
/
Copy pathtrain.py
262 lines (227 loc) · 9.55 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time
from pathlib import Path
from collections import defaultdict
import numpy as np
from matplotlib import pyplot as plt
import paddle
from paddle import distributed as dist
from paddle.io import DataLoader, DistributedBatchSampler
from parakeet.data import dataset
from parakeet.training.cli import default_argument_parser
from parakeet.training.experiment import ExperimentBase
from parakeet.utils import display, mp_tools
from parakeet.models.tacotron2 import Tacotron2, Tacotron2Loss
from config import get_cfg_defaults
from aishell3 import AiShell3, collate_aishell3_examples
class Experiment(ExperimentBase):
    """Tacotron2 experiment on the AiShell3 dataset.

    Provides the per-step training routine, a validation pass over
    ``self.valid_loader``, a synthesis pass over ``self.test_loader`` and
    the model / data-loader construction hooks.

    Every batch is expected to be a 7-field sequence
    ``(texts, tones, mels, utterance_embeds, text_lens, output_lens,
    stop_tokens)``, as produced by ``collate_aishell3_examples``.
    """

    def compute_losses(self, inputs, outputs):
        """Evaluate ``self.criterion`` for one batch.

        Args:
            inputs: The 7-field batch described in the class docstring.
            outputs: Mapping returned by the model; the keys
                ``"mel_output"``, ``"mel_outputs_postnet"`` and
                ``"alignments"`` are read.

        Returns:
            The value returned by ``self.criterion``. Callers index it with
            ``"loss"`` and iterate over ``.items()``.
        """
        # Only the mel targets and the two length fields feed the criterion;
        # the full unpacking still enforces the 7-field batch layout.
        _, _, mel_targets, _, text_lens, output_lens, _ = inputs
        return self.criterion(
            outputs["mel_output"],
            outputs["mel_outputs_postnet"],
            mel_targets,
            outputs["alignments"],
            output_lens,
            text_lens)

    def _forward_batch(self, batch):
        """Run the forward pass shared by training and validation.

        The target mels and their lengths are passed to the model together
        with the text inputs, the tones and the utterance embeddings.
        """
        texts, tones, mels, utterance_embeds, text_lens, output_lens, _ = batch
        return self.model(
            texts,
            text_lens,
            mels,
            output_lens,
            tones=tones,
            global_condition=utterance_embeds)

    def train_batch(self):
        """Run one optimization step and log the resulting losses.

        Logs the data-loading time and the total step time on every rank,
        and writes the scalar losses to the visualizer on rank 0 only.
        """
        start = time.time()
        batch = self.read_batch()
        data_loader_time = time.time() - start

        self.optimizer.clear_grad()
        # Always (re-)enter training mode: valid() and eval() switch the
        # model to evaluation mode.
        self.model.train()
        outputs = self._forward_batch(batch)
        losses = self.compute_losses(batch, outputs)
        loss = losses["loss"]
        loss.backward()
        self.optimizer.step()
        iteration_time = time.time() - start

        losses_np = {k: float(v) for k, v in losses.items()}

        # logging
        msg = "Rank: {}, ".format(dist.get_rank())
        msg += "step: {}, ".format(self.iteration)
        msg += "time: {:>.3f}s/{:>.3f}s, ".format(data_loader_time,
                                                  iteration_time)
        msg += ', '.join('{}: {:>.6f}'.format(k, v)
                         for k, v in losses_np.items())
        self.logger.info(msg)

        if dist.get_rank() == 0:
            for key, value in losses_np.items():
                self.visualizer.add_scalar(f"train_loss/{key}", value,
                                           self.iteration)

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def valid(self):
        """Compute validation losses and log figures for each batch.

        For every batch of ``self.valid_loader`` the losses are accumulated
        and three figures (alignment, target spectrogram, predicted
        spectrogram) of the first example are written to the visualizer.
        The per-key mean losses are then logged and written as scalars.
        """
        # Validation must not run in training mode: paddle.nn.Layer.eval()
        # switches mode-dependent layers such as Dropout and BatchNorm to
        # their evaluation behavior. train_batch() calls self.model.train()
        # again before the next optimization step.
        self.model.eval()

        valid_losses = defaultdict(list)
        for i, batch in enumerate(self.valid_loader):
            outputs = self._forward_batch(batch)
            losses = self.compute_losses(batch, outputs)
            for key, value in losses.items():
                valid_losses[key].append(float(value))

            mels = batch[2]
            attention_weights = outputs["alignments"]
            self.visualizer.add_figure(
                f"valid_sentence_{i}_alignments",
                display.plot_alignment(attention_weights[0].numpy().T),
                self.iteration)
            self.visualizer.add_figure(
                f"valid_sentence_{i}_target_spectrogram",
                display.plot_spectrogram(mels[0].numpy().T), self.iteration)
            mel_pred = outputs['mel_outputs_postnet']
            self.visualizer.add_figure(
                f"valid_sentence_{i}_predicted_spectrogram",
                display.plot_spectrogram(mel_pred[0].numpy().T),
                self.iteration)

        # write visual log
        valid_losses = {k: np.mean(v) for k, v in valid_losses.items()}

        # logging
        msg = "Valid: "
        msg += "step: {}, ".format(self.iteration)
        msg += ', '.join('{}: {:>.6f}'.format(k, v)
                         for k, v in valid_losses.items())
        self.logger.info(msg)

        for key, value in valid_losses.items():
            self.visualizer.add_scalar(f"valid/{key}", value, self.iteration)

    @mp_tools.rank_zero_only
    @paddle.no_grad()
    def eval(self):
        """Evaluation of Tacotron2 in autoregressive manner.

        For every example of ``self.test_loader`` this calls
        ``self.model.infer`` and saves, under
        ``<output_dir>/eval_<iteration>/``, the alignment plot as
        ``sentence_<i>.png`` and the transposed post-net mel output via
        ``np.save`` under the name ``sentence_<i>``.
        """
        self.model.eval()
        mel_dir = Path(self.output_dir / ("eval_{}".format(self.iteration)))
        mel_dir.mkdir(parents=True, exist_ok=True)
        for i, batch in enumerate(self.test_loader):
            # The target mels and the remaining fields are not needed for
            # inference.
            texts, tones, _, utterance_embeds, *_ = batch
            outputs = self.model.infer(
                texts, tones=tones, global_condition=utterance_embeds)

            display.plot_alignment(outputs["alignments"][0].numpy().T)
            plt.savefig(mel_dir / f"sentence_{i}.png")
            plt.close()
            np.save(mel_dir / f"sentence_{i}",
                    outputs["mel_outputs_postnet"][0].numpy().T)
            print(f"sentence_{i}")

    def setup_model(self):
        """Build the model, optimizer and criterion from ``self.config``.

        Sets ``self.model`` (wrapped in ``paddle.DataParallel`` when
        ``self.parallel`` is true), ``self.optimizer`` (Adam with L2 weight
        decay and global-norm gradient clipping) and ``self.criterion``.
        """
        config = self.config
        model = Tacotron2(
            vocab_size=config.model.vocab_size,
            n_tones=config.model.n_tones,
            d_mels=config.data.d_mels,
            d_encoder=config.model.d_encoder,
            encoder_conv_layers=config.model.encoder_conv_layers,
            encoder_kernel_size=config.model.encoder_kernel_size,
            d_prenet=config.model.d_prenet,
            d_attention_rnn=config.model.d_attention_rnn,
            d_decoder_rnn=config.model.d_decoder_rnn,
            attention_filters=config.model.attention_filters,
            attention_kernel_size=config.model.attention_kernel_size,
            d_attention=config.model.d_attention,
            d_postnet=config.model.d_postnet,
            postnet_kernel_size=config.model.postnet_kernel_size,
            postnet_conv_layers=config.model.postnet_conv_layers,
            reduction_factor=config.model.reduction_factor,
            p_encoder_dropout=config.model.p_encoder_dropout,
            p_prenet_dropout=config.model.p_prenet_dropout,
            p_attention_dropout=config.model.p_attention_dropout,
            p_decoder_dropout=config.model.p_decoder_dropout,
            p_postnet_dropout=config.model.p_postnet_dropout,
            d_global_condition=config.model.d_global_condition,
            use_stop_token=config.model.use_stop_token)

        if self.parallel:
            model = paddle.DataParallel(model)

        grad_clip = paddle.nn.ClipGradByGlobalNorm(
            config.training.grad_clip_thresh)
        optimizer = paddle.optimizer.Adam(
            learning_rate=config.training.lr,
            parameters=model.parameters(),
            weight_decay=paddle.regularizer.L2Decay(
                config.training.weight_decay),
            grad_clip=grad_clip)
        criterion = Tacotron2Loss(
            use_stop_token_loss=config.model.use_stop_token,
            use_guided_attention_loss=config.model.use_guided_attention_loss,
            sigma=config.model.guided_attention_loss_sigma)

        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion

    def setup_dataloader(self):
        """Create the train, validation and test data loaders.

        The AiShell3 dataset at ``self.args.data`` is split with
        ``dataset.split`` using ``config.data.valid_size``. The training
        loader shuffles and drops the last incomplete batch, using a
        ``DistributedBatchSampler`` when ``self.parallel`` is true. The
        validation and test loaders both read the validation split, with
        batch sizes ``config.data.batch_size`` and 1 respectively.
        """
        args = self.args
        config = self.config
        aishell3_dataset = AiShell3(args.data)

        # dataset.split yields the validation split first.
        valid_set, train_set = dataset.split(aishell3_dataset,
                                             config.data.valid_size)
        batch_fn = collate_aishell3_examples

        if not self.parallel:
            self.train_loader = DataLoader(
                train_set,
                batch_size=config.data.batch_size,
                shuffle=True,
                drop_last=True,
                collate_fn=batch_fn)
        else:
            sampler = DistributedBatchSampler(
                train_set,
                batch_size=config.data.batch_size,
                shuffle=True,
                drop_last=True)
            self.train_loader = DataLoader(
                train_set, batch_sampler=sampler, collate_fn=batch_fn)

        self.valid_loader = DataLoader(
            valid_set,
            batch_size=config.data.batch_size,
            shuffle=False,
            drop_last=False,
            collate_fn=batch_fn)

        self.test_loader = DataLoader(
            valid_set,
            batch_size=1,
            shuffle=False,
            drop_last=False,
            collate_fn=batch_fn)
def main_sp(config, args):
    """Run the experiment in the current process.

    Sets up the experiment, resumes or loads a checkpoint, then either
    synthesizes the test set (``args.test`` is truthy) or runs training.

    Args:
        config: Experiment configuration passed to ``Experiment``.
        args: Parsed command-line arguments; ``args.test`` selects the mode.
    """
    experiment = Experiment(config, args)
    experiment.setup()
    experiment.resume_or_load()

    if args.test:
        experiment.eval()
    else:
        experiment.run()
def main(config, args):
    """Launch the experiment, spawning worker processes when requested.

    ``main_sp`` is spawned ``args.nprocs`` times through ``dist.spawn`` only
    when more than one process is requested and the device is ``"gpu"``;
    otherwise it runs directly in the current process.

    Args:
        config: Experiment configuration forwarded to ``main_sp``.
        args: Parsed command-line arguments (``nprocs`` and ``device`` are
            read here).
    """
    multi_process = args.nprocs > 1 and args.device == "gpu"
    if not multi_process:
        main_sp(config, args)
        return

    dist.spawn(main_sp, args=(config, args), nprocs=args.nprocs)
if __name__ == "__main__":
    # Script entry point: build the configuration, parse the command line
    # and launch the experiment.
    config = get_cfg_defaults()
    parser = default_argument_parser()
    # --test makes main_sp() call Experiment.eval() instead of run().
    parser.add_argument("--test", action="store_true")
    args = parser.parse_args()
    # The config file is merged before the command-line opts, so the opts
    # are applied last (presumably overriding the file; confirm against the
    # config object's merge semantics).
    if args.config:
        config.merge_from_file(args.config)
    if args.opts:
        config.merge_from_list(args.opts)
    config.freeze()
    # Echo the effective configuration and arguments before starting.
    print(config)
    print(args)
    main(config, args)