-
Notifications
You must be signed in to change notification settings - Fork 23
/
Copy pathmodel.py
389 lines (340 loc) · 15.5 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
import math
from collections import OrderedDict
import numpy as np
import scipy.signal
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parameter import Parameter
class stft(nn.Module):
    """Learnable, STFT-shaped front end: one strided Conv2d + BatchNorm2d + Hardtanh(0, 20).

    Args:
        nfft: analysis window length in samples; must be even.
        hop_length: convolution stride in samples (applied to both Conv2d axes).
        window: forwarded to ``_get_stft_kernels`` (``"hanning"`` or rectangular).

    NOTE(review): ``real_kernels`` / ``imag_kernels`` are registered as parameters
    (so they appear in ``state_dict``), but ``forward`` never uses them. The Conv2d
    weights keep PyTorch's default initialisation; the DFT kernels only fix the
    conv's output-channel count and kernel shape.
    """

    def __init__(self, nfft=1024, hop_length=512, window="hanning"):
        super(stft, self).__init__()
        assert nfft % 2 == 0
        self.hop_length = hop_length
        self.n_freq = nfft // 2 + 1
        kernel_pair = _get_stft_kernels(nfft, window)
        self.real_kernels = kernel_pair[0]
        self.imag_kernels = kernel_pair[1]
        self.real_kernels_size = self.real_kernels.size()
        out_channels, _, kernel_height, kernel_width = self.real_kernels_size
        self.conv = nn.Sequential(
            nn.Conv2d(1, out_channels, kernel_size=(kernel_height, kernel_width), stride=self.hop_length),
            nn.BatchNorm2d(out_channels),
            nn.Hardtanh(0, 20, inplace=True),
        )

    def forward(self, sample):
        """Run the conv front end and swap axes 1 and 2 of the result.

        ``sample`` presumably is a ``(batch, 1, samples)`` waveform — TODO confirm
        with callers; it is given an extra channel axis before the Conv2d.
        """
        features = self.conv(sample.unsqueeze(1))
        return features.permute(0, 2, 1, 3)
def _get_stft_kernels(nfft, window):
    """Build windowed DFT kernels for an ``nfft``-point real-input STFT.

    Args:
        nfft: FFT size (coerced with ``int``); must be even.
        window: ``"hanning"`` selects a (periodic) Hann window; any other value
            selects a rectangular window.

    Returns:
        ``(real_kernels, imag_kernels)``: two float32 ``nn.Parameter`` tensors of
        shape ``(nfft // 2 + 1, 1, 1, nfft)`` holding the real and imaginary parts
        of ``exp(-2j * pi * f * t / nfft) * window[t]``.
    """
    nfft = int(nfft)
    assert nfft % 2 == 0

    def kernel_fn(freq, time):
        # DFT phase term for frequency bin ``freq`` at sample ``time``.
        return np.exp(-1j * (2 * np.pi * time * freq) / float(nfft))

    kernels = np.fromfunction(kernel_fn, (nfft // 2 + 1, nfft), dtype=np.float64)
    if window == "hanning":
        # "hann" is SciPy's canonical window name. The legacy "hanning" alias is
        # not guaranteed to be accepted by get_window in newer SciPy releases.
        win_cof = scipy.signal.get_window("hann", nfft)[np.newaxis, :]
    else:
        win_cof = np.ones((1, nfft), dtype=np.float64)
    # (freq, nfft) -> (freq, 1, 1, nfft): out_channels x in_channels x H x W.
    kernels = kernels[:, np.newaxis, np.newaxis, :] * win_cof
    real_kernels = nn.Parameter(torch.from_numpy(np.real(kernels)).float())
    imag_kernels = nn.Parameter(torch.from_numpy(np.imag(kernels)).float())
    return real_kernels, imag_kernels
class PCEN(nn.Module):
    """Per-channel energy normalisation with learnable alpha, delta and r.

    ``forward`` computes ``(x * (eps + smoother) ** -alpha + delta) ** r - delta ** r``.

    The parameters start at alpha=0.98, delta=2, r=0.5. Despite the ``log_`` prefix
    in their names, they are used directly (never exponentiated). The names are
    kept because they are the ``state_dict`` keys.
    """

    def __init__(self):
        super(PCEN, self).__init__()
        self.log_alpha = Parameter(torch.FloatTensor([0.98]))
        self.log_delta = Parameter(torch.FloatTensor([2]))
        self.log_r = Parameter(torch.FloatTensor([0.5]))
        # Keeps the base of the negative power away from zero.
        self.eps = 0.000001

    def forward(self, x, smoother):
        """Normalise ``x`` by a smoothed energy estimate ``smoother`` (shape-compatible with ``x``)."""
        alpha, delta, r = self.log_alpha, self.log_delta, self.log_r
        gain = (self.eps + smoother) ** -alpha
        return (x * gain + delta) ** r - delta ** r
class InferenceBatchSoftmax(nn.Module):
    """Identity while training; softmax over the last dimension in eval mode."""

    def forward(self, input_):
        # Training losses consume the raw scores, so only normalise at inference.
        if self.training:
            return input_
        return F.softmax(input_, dim=-1)
class Cov1dBlock(nn.Module):
    """Reflection pad -> Conv1d -> optional BatchNorm1d -> optional clamp(0, 20) -> optional dropout.

    The convolution itself is unpadded; all padding is applied beforehand with
    ``nn.ReflectionPad1d``.

    Args:
        input_size: number of input channels.
        output_size: number of output channels.
        kernal_size: 1-tuple kernel size, e.g. ``(11,)``.
        stride: convolution stride.
        drop_out_prob: dropout probability; ``-1`` (the default) disables dropout.
        dilation: convolution dilation.
        padding: ``'same'``, ``'half'`` (kernel_size total) or ``'invalid'`` (none).
        bn: add a BatchNorm1d after the convolution.
        activationUse: clamp outputs to ``[0, 20]``.

    Raises:
        ValueError: if ``padding`` is not one of the supported modes.
    """

    def __init__(self, input_size, output_size, kernal_size, stride, drop_out_prob=-1.0, dilation=1,
                 padding='same', bn=True, activationUse=True):
        super(Cov1dBlock, self).__init__()
        # Previously an unknown mode left ``addPaddings`` unset and failed later
        # with an unrelated AttributeError.
        if padding not in ('same', 'half', 'invalid'):
            raise ValueError("padding must be 'same', 'half' or 'invalid', got {!r}".format(padding))
        self.input_size = input_size
        self.output_size = output_size
        self.kernal_size = kernal_size
        self.stride = stride
        self.dilation = dilation
        self.drop_out_prob = drop_out_prob
        self.activationUse = activationUse
        # NOTE(review): not read anywhere in this class; kept as a public attribute.
        self.padding = kernal_size[0]

        # TF-style "SAME" padding arithmetic.
        # NOTE(review): ``input_rows`` is the channel count, not the sequence
        # length. For stride 1 the result is (k - 1) * dilation regardless; for
        # stride > 1 it depends on the parity of input_size. Kept unchanged so
        # trained models keep their output lengths.
        input_rows = input_size
        filter_rows = kernal_size[0]
        effective_filter_size_rows = (filter_rows - 1) * dilation + 1
        out_rows = (input_rows + stride - 1) // stride
        self.rows_odd = False
        if padding == 'same':
            self.padding_needed = max(0, (out_rows - 1) * stride + effective_filter_size_rows - input_rows)
            self.padding_rows = self.padding_needed
            self.rows_odd = (self.padding_rows % 2 != 0)
            self.addPaddings = self.padding_rows
        elif padding == 'half':
            self.addPaddings = kernal_size[0]
        else:  # 'invalid'
            self.addPaddings = 0
        # Symmetric padding of addPaddings // 2 per side.
        # NOTE(review): an odd total (``rows_odd``) loses one column; rows_odd is not used.
        self.paddingAdded = nn.ReflectionPad1d(self.addPaddings // 2) if self.addPaddings > 0 else None
        self.conv1 = nn.Sequential(
            nn.Conv1d(in_channels=input_size, out_channels=output_size, kernel_size=kernal_size,
                      stride=stride, padding=0, dilation=dilation),
        )
        # NOTE(review): PyTorch momentum weights the *new* batch statistic
        # (running = (1 - m) * running + m * batch), so 0.90 tracks recent batches
        # closely; a TF-style decay of 0.9 would be momentum=0.1 — confirm intent.
        self.batchNorm = nn.BatchNorm1d(num_features=output_size, momentum=0.90, eps=0.001) if bn else None
        self.drop_out_layer = nn.Dropout(drop_out_prob) if self.drop_out_prob != -1 else None

    def forward(self, xs, hid=None):
        """Apply the block to ``xs`` (Conv1d layout: batch, channels, length). ``hid`` is unused."""
        if self.paddingAdded is not None:
            xs = self.paddingAdded(xs)
        # ``fusedLayer`` (conv with batch norm folded in) may be attached after
        # construction; when present it replaces conv1 + batchNorm.
        fused = getattr(self, 'fusedLayer', None)
        if fused is not None:
            output = fused(xs)
        else:
            output = self.conv1(xs)
            if self.batchNorm is not None:
                output = self.batchNorm(output)
        if self.activationUse:
            output = torch.clamp(input=output, min=0, max=20)
        if self.training and self.drop_out_layer is not None:
            output = self.drop_out_layer(output)
        return output
class WaveToLetter(nn.Module):
def __init__(self,sample_rate,window_size, labels="abc",audio_conf=None,mixed_precision=False):
super(WaveToLetter, self).__init__()
# model metadata needed for serialization/deserialization
if audio_conf is None:
audio_conf = {}
self._version = '0.0.1'
self._audio_conf = audio_conf or {}
self._labels = labels
self._sample_rate=sample_rate
self._window_size=window_size
self.mixed_precision=mixed_precision
nfft = (self._sample_rate * self._window_size)
input_size = 1+int((nfft/2))
hop_length = sample_rate * self._audio_conf.get("window_stride", 0.01)
# self.pcen = PCEN()
self.frontEnd = stft(hop_length=int(hop_length), nfft=int(nfft))
conv1 = Cov1dBlock(input_size=input_size,output_size=256,kernal_size=(11,),stride=2,dilation=1,drop_out_prob=0.2,padding='same')
conv2s = []
conv2s.append(('conv1d_{}'.format(0),conv1))
inputSize = 256
for idx in range(15):
layergroup = idx//3
if (layergroup) == 0:
convTemp = Cov1dBlock(input_size=inputSize,output_size=256,kernal_size=(11,),stride=1,dilation=1,drop_out_prob=0.2,padding='same')
conv2s.append(('conv1d_{}'.format(idx+1),convTemp))
inputSize = 256
elif (layergroup) == 1:
convTemp = Cov1dBlock(input_size=inputSize, output_size=384, kernal_size=(13,), stride=1, dilation=1,
drop_out_prob=0.2)
conv2s.append(('conv1d_{}'.format(idx + 1), convTemp))
inputSize=384
elif (layergroup) ==2:
convTemp = Cov1dBlock(input_size=inputSize, output_size=512, kernal_size=(17,), stride=1, dilation=1,
drop_out_prob=0.2)
conv2s.append(('conv1d_{}'.format(idx + 1), convTemp))
inputSize = 512
elif (layergroup) ==3:
convTemp = Cov1dBlock(input_size=inputSize, output_size=640, kernal_size=(21,), stride=1, dilation=1,
drop_out_prob=0.3)
conv2s.append(('conv1d_{}'.format(idx + 1), convTemp))
inputSize = 640
elif (layergroup) ==4:
convTemp = Cov1dBlock(input_size=inputSize, output_size=768, kernal_size=(25,), stride=1, dilation=1,
drop_out_prob=0.3)
conv2s.append(('conv1d_{}'.format(idx + 1), convTemp))
inputSize = 768
conv1 = Cov1dBlock(input_size=inputSize, output_size=896, kernal_size=(29,), stride=1, dilation=2, drop_out_prob=0.4)
conv2s.append(('conv1d_{}'.format(16), conv1))
conv1 = Cov1dBlock(input_size=896, output_size=1024, kernal_size=(1,), stride=1, dilation=1, drop_out_prob=0.4)
conv2s.append(('conv1d_{}'.format(17), conv1))
conv1 = Cov1dBlock(input_size=1024, output_size=len(self._labels), kernal_size=(1,),stride=1,bn=False,activationUse=False)
conv2s.append(('conv1d_{}'.format(18), conv1))
self.conv1ds = nn.Sequential(OrderedDict(conv2s))
self.inference_softmax = InferenceBatchSoftmax()
def forward(self, x):
x = self.frontEnd(x)
x = x.squeeze(1)
x = self.conv1ds(x)
x = x.transpose(1,2)
x = self.inference_softmax(x)
return x
@classmethod
def load_model(cls, path, cuda=False):
package = torch.load(path, map_location=lambda storage, loc: storage)
model = cls(labels=package['labels'], audio_conf=package['audio_conf'],sample_rate=package["sample_rate"]
,window_size=package["window_size"],mixed_precision=package.get('mixed_precision',False))
# the blacklist parameters are params that were previous erroneously saved by the model
# care should be taken in future versions that if batch_norm on the first rnn is required
# that it be named something else
blacklist = ['rnns.0.batch_norm.module.weight', 'rnns.0.batch_norm.module.bias',
'rnns.0.batch_norm.module.running_mean', 'rnns.0.batch_norm.module.running_var']
for x in blacklist:
if x in package['state_dict']:
del package['state_dict'][x]
# keyNames = package['state_dict'].keys()
#
# for keyname in keyNames:
# if "num_batches_tracked" in keyname:
# del package['state_dict'][keyname]
model.load_state_dict(package['state_dict'])
# for x in model.rnns:
# x.flatten_parameters()
if cuda:
model = torch.nn.DataParallel(model).cuda()
return model
@classmethod
def load_model_package(cls, package, cuda=False):
model = cls(labels=package['labels'], audio_conf=package['audio_conf'],sample_rate=package.get("sample_rate",16000)
,window_size=package.get("window_size",.02),mixed_precision=package.get('mixed_precision',False))
model.load_state_dict(package['state_dict'])
if cuda:
model = torch.nn.DataParallel(model).cuda()
return model
@staticmethod
def serialize(model, optimizer=None, epoch=None, iteration=None, loss_results=None,
cer_results=None, wer_results=None, avg_loss=None, meta=None):
model_is_cuda = next(model.parameters()).is_cuda
# model = model.module if model_is_cuda else model
package = {
'version': model._version,
'audio_conf': model._audio_conf,
'labels': model._labels,
'state_dict': model.state_dict(),
'mixed_precision': model.mixed_precision,
'sample_rate': model._sample_rate,
'window_size': model._window_size
}
if optimizer is not None:
package['optim_dict'] = optimizer.state_dict()
if avg_loss is not None:
package['avg_loss'] = avg_loss
if epoch is not None:
package['epoch'] = epoch + 1 # increment for readability
if iteration is not None:
package['iteration'] = iteration
if loss_results is not None:
package['loss_results'] = loss_results
package['cer_results'] = cer_results
package['wer_results'] = wer_results
if meta is not None:
package['meta'] = meta
return package
@staticmethod
def get_labels(model):
model_is_cuda = next(model.parameters()).is_cuda
return model.module._labels if model_is_cuda else model._labels
@staticmethod
def get_sample_rate(model):
model_is_cuda = next(model.parameters()).is_cuda
return model.module._sample_rate if model_is_cuda else model._sample_rate
@staticmethod
def get_window_size(model):
model_is_cuda = next(model.parameters()).is_cuda
return model.module._window_size if model_is_cuda else model._window_size
@staticmethod
def setAudioConfKey(model,key,value):
model._audio_conf[key] = value
return model
@staticmethod
def get_param_size(model):
params = 0
for p in model.parameters():
tmp = 1
for x in p.size():
tmp *= x
params += tmp
return params
@staticmethod
def get_audio_conf(model):
model_is_cuda = next(model.parameters()).is_cuda
return model.module._audio_conf if model_is_cuda else model._audio_conf
@staticmethod
def get_meta(model):
model_is_cuda = next(model.parameters()).is_cuda
m = model.module if model_is_cuda else model
meta = {
"version": m._version
}
return meta
def fuse_model(self):
for m in self.modules():
if type(m) == Cov1dBlock:
torch.quantization.fuse_modules(m, [ 'conv1','batchNorm'], inplace=True)
# if type(m) == InvertedResidual:
# for idx in range(len(m.conv)):
# if type(m.conv[idx]) == nn.Conv2d:
# torch.quantization.fuse_modules(m.conv, [str(idx), str(idx + 1)], inplace=True)
def convertTensorType(self,dtypeToUse=torch.float16):
module = self._modules
layermodules = module['conv1ds']
# for layermodule in layermodules:
convLayerPrefix = 'conv1d_{}'
for i in range(0,19):
convLayer = getattr(layermodules,convLayerPrefix.format(i))
modulesconv = convLayer._modules
if 'batchNorm' in modulesconv:
fusedLayer = self.fuse_conv_and_bn(modulesconv['conv1']._modules['0'],modulesconv['batchNorm'])
setattr(convLayer,'fusedLayer',fusedLayer)
return
def fuse_conv_and_bn(self,conv, bn):
#
# init
fusedconv = torch.nn.Conv1d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=0,
dilation=conv.dilation,
bias=True
)
#
# prepare filters
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_bn = torch.diag(bn.weight.div(torch.sqrt(bn.eps + bn.running_var)))
fusedconv.weight.copy_(torch.mm(w_bn, w_conv).view(fusedconv.weight.size()))
#
# prepare spatial bias
if conv.bias is not None:
b_conv = conv.bias
else:
b_conv = torch.zeros(conv.weight.size(0))
b_bn = bn.bias - bn.weight.mul(bn.running_mean).div(torch.sqrt(bn.running_var + bn.eps))
fusedconv.bias.copy_(b_conv + b_bn)
#
# we're done
return fusedconv
if __name__ == "__main__":
    # Nothing to run: this module only defines model components.
    pass