-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathinfer_file.py
178 lines (143 loc) · 8.45 KB
/
infer_file.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import os, sys
sys.path.insert(0, 'lib')
from dev.infer import infer
from dev.sample_stats import get_stats
from dev.train import train
import dev.deepxi_net as deepxi_net
import numpy as np
import tensorflow as tf
import dev.utils as utils
import argparse
from dev.utils import read_wav
from tqdm import tqdm
import dev.gain as gain
import dev.utils as utils
import dev.xi as xi
import numpy as np
import os
import scipy.io as spio
import librosa
import pickle
from scipy.io.wavfile import write as wav_write
np.set_printoptions(threshold=1e6)
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
def str2bool(s): return s.lower() in ("yes", "true", "t", "1")
def create_args():
with open('data/stats.p', 'rb') as f:
stats = pickle.load(f)
parser = argparse.ArgumentParser()
## OPTIONS (GENERAL)
parser.add_argument('--gpu', default='0', type=str, help='GPU selection')
parser.add_argument('--ver', type=str, help='Model version')
parser.add_argument('--epoch', type=int, help='Epoch to use/retrain from')
parser.add_argument('--train', default=False, type=str2bool, help='Training flag')
parser.add_argument('--infer', default=True, type=str2bool, help='Inference flag')
parser.add_argument('--verbose', default=False, type=str2bool, help='Verbose')
parser.add_argument('--model', default='ResNet', type=str, help='Model type')
## OPTIONS (TRAIN)
parser.add_argument('--cont', default=False, type=str2bool, help='Continue testing from last epoch')
parser.add_argument('--mbatch_size', default=10, type=int, help='Mini-batch size')
parser.add_argument('--sample_size', default=1000, type=int, help='Sample size')
parser.add_argument('--max_epochs', default=250, type=int, help='Maximum number of epochs')
parser.add_argument('--grad_clip', default=True, type=str2bool, help='Gradient clipping')
parser.add_argument('--out_type', default='y', type=str, help='Output type for testing')
## GAIN FUNCTION
parser.add_argument('--gain', default='mmse-lsa', type=str, help='Gain function for testing')
## PATHS
parser.add_argument('--model_path', default='model/3f/epoch-175', type=str, help='Model save path')
parser.add_argument('--set_path', default='set', type=str, help='Path to datasets')
parser.add_argument('--data_path', default='data', type=str, help='Save data path')
parser.add_argument('--test_x_path', default='set/test_noisy_speech', type=str, help='Path to the noisy speech test set')
parser.add_argument('--in_filepath', default='test.wav', type=str, help='Output path')
parser.add_argument('--out_filepath', default='out.wav', type=str, help='Output path')
## FEATURES
parser.add_argument('--min_snr', default=-10, type=int, help='Minimum trained SNR level')
parser.add_argument('--max_snr', default=20, type=int, help='Maximum trained SNR level')
parser.add_argument('--f_s', default=16000, type=int, help='Sampling frequency (Hz)')
parser.add_argument('--T_w', default=32, type=int, help='Window length (ms)')
parser.add_argument('--T_s', default=16, type=int, help='Window shift (ms)')
parser.add_argument('--nconst', default=32768.0, type=float, help='Normalisation constant (see feat.addnoisepad())')
parser.add_argument('--N_w', default=int(16000*32*0.001), type=int, help='window length (samples)')
parser.add_argument('--N_s', default=int(16000*16*0.001), type=int, help='window shift (samples)')
parser.add_argument('--NFFT', default=int(pow(2, np.ceil(np.log2(int(16000*32*0.001))))), type=float, help='number of DFT components')
parser.add_argument('--stats', default=stats)
## NETWORK PARAMETERS
parser.add_argument('--d_in', default=257, type=int, help='Input dimensionality')
parser.add_argument('--d_out', default=257, type=int, help='Ouput dimensionality')
parser.add_argument('--d_model', default=256, type=int, help='Model dimensions')
parser.add_argument('--n_blocks', default=40, type=int, help='Number of blocks')
parser.add_argument('--d_f', default=64, type=int, help='Number of filters')
parser.add_argument('--k_size', default=3, type=int, help='Kernel size')
parser.add_argument('--max_d_rate', default=16, type=int, help='Maximum dilation rate')
parser.add_argument('--norm_type', default='FrameLayerNorm', type=str, help='Normalisation type')
parser.add_argument('--net_height', default=[4], type=list, help='RDL block height')
args = parser.parse_args()
return args
def build_restore_model(model_path, args, config):
## MAKE DEEP XI NNET
print('Start: Build and Restore model!')
sess = tf.Session(config=config)
net = deepxi_net.deepxi_net(args)
net.saver.restore(sess, args.model_path)
print('Done: Build and Restore model!')
return net, sess
# def infer(filename ,net, sess):
# print('Start infer file: {}'.format(filename))
# #(wav, _) = read_wav(args.in_filepath) # read wav from given file path.
# (wav, _) = librosa.load(filename, 16000, mono=True) # read wav from given file path.
# wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)
# print(max(wav), min(wav), np.mean(wav))
# print(wav.shape)
# input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]}) # sample of training set.
# xi_bar_hat = sess.run(
# net.infer_output, feed_dict={net.input_ph: input_feat[0],
# net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
# xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])
# #file_name = filename.split('/')[-1].split('.')
# y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
# y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
# net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
# if np.isnan(y).any(): ValueError('NaN values found in enhanced speech.')
# if np.isinf(y).any(): ValueError('Inf values found in enhanced speech.')
# y = np.asarray(np.multiply(y, 32768.0), dtype=np.int16)
# out_filepath = filename.replace('.'+filename.split('.')[-1], '_pred.wav')
# wav_write(out_filepath, args.f_s, y)
# print('Infer out file: {} done'.format(out_filepath))
# return out_filepath
def get_model():
args = create_args()
## GPU CONFIGURATION
config = tf.ConfigProto()
config.allow_soft_placement=True
config.gpu_options.allow_growth=True
config.log_device_placement=False
net, sess = build_restore_model(args.model_path, args, config)
def infer(filename):
print('Start infer file: {}'.format(filename))
#(wav, _) = read_wav(args.in_filepath) # read wav from given file path.
(wav, _) = librosa.load(filename, 16000, mono=True) # read wav from given file path.
wav = np.asarray(np.multiply(wav, 32768.0), dtype=np.int16)
print(max(wav), min(wav), np.mean(wav))
print(wav.shape)
input_feat = sess.run(net.infer_feat, feed_dict={net.s_ph: [wav], net.s_len_ph: [len(wav)]}) # sample of training set.
xi_bar_hat = sess.run(
net.infer_output, feed_dict={net.input_ph: input_feat[0],
net.nframes_ph: input_feat[1], net.training_ph: False}) # output of network.
xi_hat = xi.xi_hat(xi_bar_hat, args.stats['mu_hat'], args.stats['sigma_hat'])
#file_name = filename.split('/')[-1].split('.')
y_MAG = np.multiply(input_feat[0], gain.gfunc(xi_hat, xi_hat+1, gtype=args.gain))
y = np.squeeze(sess.run(net.y, feed_dict={net.y_MAG_ph: y_MAG,
net.x_PHA_ph: input_feat[2], net.nframes_ph: input_feat[1], net.training_ph: False})) # output of network.
if np.isnan(y).any(): ValueError('NaN values found in enhanced speech.')
if np.isinf(y).any(): ValueError('Inf values found in enhanced speech.')
y = np.asarray(np.multiply(y, 32768.0), dtype=np.int16)
out_filepath = filename.replace('.'+filename.split('.')[-1], '_pred.wav')
wav_write(out_filepath, args.f_s, y)
print('Infer out file: {} done'.format(out_filepath))
return out_filepath
return infer
if __name__ == '__main__':
infer = get_model()
infer('Toàn cảnh phòng chống dịch COVID-19 ngày 18-4-2020 - VTV24.mp3')
#infer('set/test_noisy_speech/198853.wav')