import glob
import webrtcvad
import logging
import wavSplit
from stt import Model
from timeit import default_timer as timer

'''
Load the pre-trained model into memory
@param models: Output graph (TFLite model) file
@param scorer: External scorer file
@Retval
Returns a list [STT Object, Model Load Time, Scorer Load Time]
'''
def load_model(models, scorer):
    model_load_start = timer()
    ds = Model(models)
    model_load_end = timer() - model_load_start
    logging.debug("Loaded model in %0.3fs." % (model_load_end))

    scorer_load_start = timer()
    ds.enableExternalScorer(scorer)
    scorer_load_end = timer() - scorer_load_start
    logging.debug('Loaded external scorer in %0.3fs.' % (scorer_load_end))

    return [ds, model_load_end, scorer_load_end]
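
# Example usage (a minimal sketch; the file paths are placeholders):
#   ds, model_load_time, scorer_load_time = load_model('model.tflite', 'model.scorer')
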
'''
Run inference on the input audio
@param ds: STT Model object
@param audio: Input audio for running inference on
@param fs: Sample rate of the input audio file
@Retval:
Returns a list [Inference, Inference Time]
'''
def stt(ds, audio, fs):
    audio_length = len(audio) / fs

    # Run inference with the STT model
    logging.debug('Running inference...')
    inference_start = timer()
    output = ds.stt(audio)
    inference_time = timer() - inference_start
    logging.debug('Inference took %0.3fs for %0.3fs audio file.' % (inference_time, audio_length))

    return [output, inference_time]
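
# Example usage (sketch; assumes `audio` is an int16 numpy array sampled at 16 kHz):
#   transcript, inference_time = stt(ds, audio, 16000)
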
'''
Resolve directory path for the models and fetch each of them.
@param dirName: Path to the directory containing pre-trained models
@Retval:
Returns a tuple containing each of the model files (tflite, scorer)
'''
def resolve_models(dirName):
    pb = glob.glob(dirName + "/*.tflite")[0]
    logging.debug("Found Model: %s" % pb)

    scorer = glob.glob(dirName + "/*.scorer")[0]
    logging.debug("Found scorer: %s" % scorer)

    return pb, scorer
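
# Example usage (sketch; the directory path is a placeholder):
#   model_path, scorer_path = resolve_models('models/en')
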
'''
Generate VAD segments. Filters out non-voiced audio frames.
@param wavFile: Input wav file to run VAD on
@param aggressiveness: How aggressively to filter out non-speech (0-3)
@Retval:
Returns a tuple of
    segments: a bytearray of multiple smaller audio frames
              (the longer audio split into multiple smaller ones)
    sample_rate: Sample rate of the input audio file
    audio_length: Duration of the input audio file
'''
def vad_segment_generator(wavFile, aggressiveness):
    logging.debug("Caught the wav file @: %s" % (wavFile))
    audio, sample_rate, audio_length = wavSplit.read_wave(wavFile)
    assert sample_rate == 16000, "Only 16000Hz input WAV files are supported for now!"

    vad = webrtcvad.Vad(int(aggressiveness))
    frames = wavSplit.frame_generator(30, audio, sample_rate)
    frames = list(frames)
    segments = wavSplit.vad_collector(sample_rate, 30, 300, vad, frames)

    return segments, sample_rate, audio_length
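

if __name__ == '__main__':
    # Minimal end-to-end sketch of how the helpers above fit together.
    # Assumptions (not part of the original module): the model directory
    # holds one *.tflite and one *.scorer file, the WAV is 16 kHz mono,
    # and each VAD segment is raw 16-bit PCM that stt() can consume once
    # converted to an int16 numpy array.
    import sys
    import numpy as np

    logging.basicConfig(level=logging.DEBUG)
    model_dir, wav_file = sys.argv[1], sys.argv[2]

    model_path, scorer_path = resolve_models(model_dir)
    ds, _, _ = load_model(model_path, scorer_path)

    segments, sample_rate, audio_length = vad_segment_generator(wav_file, aggressiveness=1)
    for i, segment in enumerate(segments):
        # Each segment arrives as raw PCM bytes; reinterpret as int16 samples.
        audio = np.frombuffer(segment, dtype=np.int16)
        transcript, inference_time = stt(ds, audio, sample_rate)
        print("Segment %d (%0.3fs inference): %s" % (i, inference_time, transcript))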