-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain_trivial.py
43 lines (33 loc) · 1.46 KB
/
main_trivial.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
import torch
from torch import nn
from torch.utils.data.dataset import random_split
from torchaudio import datasets
from trivial.train import fit, get_model
from utils.data import get_data, WrappedDataLoader, preprocess
# Select the default CUDA device when one is available, otherwise fall back to CPU.
dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load VCTK 0.92 from ./data (no download) and carve out three tiny subsets of
# roughly 0.1 % of the corpus each. Cut points are taken at cumulative fractions
# (0.1 %, 0.2 %, 0.3 %). They are non-decreasing and never exceed the corpus size,
# so every split length is non-negative and the four lengths sum to ds_size.
dataset = datasets.VCTK_092(root="data", download=False)
ds_size = len(dataset)
train_end, val_end, test_end = (int(frac * ds_size) for frac in (0.001, 0.002, 0.003))
train_i = train_end
val_i = val_end - train_end
test_i = test_end - val_end
# Everything past the last cut point lands in `other`, which this script does not use.
# The fixed seed makes the split reproducible across runs.
train_ds, val_ds, test_ds, other = random_split(
    dataset,
    lengths=[train_i, val_i, test_i, ds_size - test_end],
    generator=torch.Generator().manual_seed(42),
)
# Audio framing parameters: `ms * SAMPLE_RATE // 1000` converts a duration in
# milliseconds to a sample count (presumably matching the 48 kHz VCTK 0.92
# recordings). This gives a 64 ms FFT window and a 16 ms hop, each padded by
# 4 extra samples. NOTE(review): the reason for the +4 is not visible here; confirm.
SAMPLE_RATE = 48_000
N_FFT = 64 * SAMPLE_RATE // 1000 + 4  # 3076 samples; not used by the code in this script
HOP_LENGTH = 16 * SAMPLE_RATE // 1000 + 4  # 772 samples; forwarded to WrappedDataLoader

# Keyword arguments forwarded unchanged to utils.data.preprocess through
# WrappedDataLoader. Their meaning is defined there (TODO confirm).
zero_q, zero_f = 0.9, 0.2
one_q, one_f = 0.9, 0.5

bs = 8  # batch size handed to get_data
# Build the raw train/val/test loaders, then wrap each with the same
# preprocess function, hop length, device and preprocessing thresholds.
# (WrappedDataLoader presumably applies `preprocess` per batch; see utils.data.)
train_dl, val_dl, test_dl = get_data(train_ds, val_ds, test_ds, bs)
# Shared preprocessing keyword arguments, defined once so the three splits
# cannot drift apart when a threshold is edited. `**` unpacking copies the
# mapping into each call, so no wrapper can mutate another's settings.
preprocess_kwargs = {"zero_q": zero_q, "zero_f": zero_f, "one_q": one_q, "one_f": one_f}
train_dl = WrappedDataLoader(train_dl, preprocess, HOP_LENGTH, dev, **preprocess_kwargs)
val_dl = WrappedDataLoader(val_dl, preprocess, HOP_LENGTH, dev, **preprocess_kwargs)
test_dl = WrappedDataLoader(test_dl, preprocess, HOP_LENGTH, dev, **preprocess_kwargs)
# Build the model/optimizer pair on `dev` and train against a mean-squared-error
# objective using the wrapped train and validation loaders.
model, opt = get_model(dev)
loss_func = nn.MSELoss()
# NOTE(review): the leading 10 is presumably the epoch count (confirm in
# trivial.train.fit). test_dl is built above but never passed here, so the
# test split is not evaluated by this script.
fit(10, model, loss_func, opt, train_dl, val_dl)