testDataset.py
import os
import glob

import scipy.io.wavfile
import torch.utils.data as data


class TestDataset(data.Dataset):
    """
    PyTorch dataset for instrument audio files.

    Args:
        root: root dir containing an "audio" directory with wav files.
        transform (callable, optional): A function/transform that takes in
            a sample and returns a transformed version.
    """

    def __init__(self, root, transform=None):
        assert isinstance(root, str)
        self.root = root
        # collect all wav files under <root>/audio
        self.filenames = glob.glob(os.path.join(root, "audio", "*.wav"))
        self.transform = transform

    def __len__(self):
        return len(self.filenames)
    def __getitem__(self, index):
        name = self.filenames[index]
        _, sample = scipy.io.wavfile.read(name)  # load audio as a numpy array
        # derive an identifier from the file name (strip directories and extension)
        file_id = os.path.splitext(os.path.basename(name))[0]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, file_id
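

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original file):
    # "data" is a hypothetical root directory whose "audio/" folder holds
    # .wav files. It instantiates the dataset and reads a single item.
    dataset = TestDataset(root="data")
    print(len(dataset))            # number of .wav files discovered
    sample, file_id = dataset[0]   # raw waveform (numpy array) and file stem
    print(file_id, sample.shape)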