forked from ibab/tensorflow-wavenet
-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path generate.py
executable file
·397 lines (320 loc) · 12.8 KB
/
generate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
from __future__ import division
from __future__ import print_function
import argparse
from datetime import datetime
import json
import os
import librosa
import numpy as np
import tensorflow as tf
import midi
from wavenet import WaveNetModel, MidiMapper, mu_law_decode, mu_law_encode, audio_reader
TEMPERATURE = 1.0
LOGDIR = './logdir'
WAVENET_PARAMS = './wavenet_params.json'
SAVE_EVERY = None
SILENCE_THRESHOLD = 0.1


def get_args():
    """Parse and validate the command-line options of the generation script.

    Returns:
        argparse.Namespace holding the parsed options.

    Raises:
        SystemExit: raised by argparse for malformed values, e.g. a
            non-positive ``--temperature`` or an unrecognised
            ``--fast-generation`` value.
        ValueError: when global or local conditioning is enabled without the
            options it depends on.
    """
    def _ensure_positive_float(f):
        """Ensure argument is a strictly positive float."""
        value = float(f)
        # The temperature divides log-probabilities during sampling, so zero
        # must be rejected. NaN fails every comparison, so `not value > 0`
        # rejects it as well.
        if not value > 0:
            raise argparse.ArgumentTypeError(
                'Argument must be greater than zero')
        return value

    def _str_to_bool(value):
        """Parse a textual boolean (``type=bool`` would map "False" to True)."""
        lowered = str(value).strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1', 'on'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0', 'off'):
            return False
        raise argparse.ArgumentTypeError(
            'Boolean value expected, got {!r}'.format(value))

    parser = argparse.ArgumentParser(description='WaveNet generation script')
    parser.add_argument('--checkpoint',
                        type=str,
                        help='Which model checkpoint to generate from')
    parser.add_argument('--samples',
                        type=int,
                        default=None,
                        help='How many waveform samples to generate')
    parser.add_argument('--temperature',
                        type=_ensure_positive_float,
                        default=TEMPERATURE,
                        help='Sampling temperature')
    parser.add_argument('--logdir',
                        type=str,
                        default=LOGDIR,
                        help='Directory in which to store the logging information for TensorBoard.')
    parser.add_argument('--wavenet-params',
                        type=str,
                        default=WAVENET_PARAMS,
                        help='JSON file with the network parameters')
    parser.add_argument('--wav-out-path',
                        type=str,
                        default=None,
                        help='Path to output wav file')
    parser.add_argument('--save-every',
                        type=int,
                        default=SAVE_EVERY,
                        help='How many samples before saving in-progress wav')
    # nargs='?' + const=True keeps the old "--fast-generation VALUE" form
    # working and also allows the bare flag.
    parser.add_argument('--fast-generation',
                        type=_str_to_bool,
                        nargs='?',
                        const=True,
                        default=True,
                        help='Use fast generation')
    parser.add_argument('--wav-seed',
                        type=str,
                        default=None,
                        help='The wav file to start generation from')
    # GC params
    parser.add_argument('--gc-channels',
                        type=int,
                        default=None,
                        help='Number of global condition channels. Default: None. Expecting: int')
    parser.add_argument('--gc-cardinality',
                        type=int,
                        default=None,
                        help='Number of categories upon which we globally condition.')
    parser.add_argument('--gc-id',
                        type=int,
                        default=None,
                        help='ID of category to generate, if globally conditioned.')
    # LC params
    parser.add_argument('--initial-lc-channels',
                        type=int,
                        default=None,
                        help="Number of initial local conditioning channels output by the upsampler. Default: None. Expecting: int")
    parser.add_argument('--lc-channels',
                        type=int,
                        default=None,
                        help="Number of local conditioning channels used by the network. Default: None. Expecting: int")
    parser.add_argument('--lc-fileformat',
                        type=str,
                        default=None,
                        help="Extension of files being used for local conditioning. Default: None. Expecting: string")
    parser.add_argument('--lc-filepath',
                        type=str,
                        default=None,
                        help="Path to the file to be used for local condition based generation. Default: None. Expecting: string.")
    parser.add_argument('--sample-rate',
                        type=int,
                        default=16000,
                        help="Default sample rate of the wav file. Used for properly upsampling the LC file. Default: 16000. Expecting: int.")
    args = parser.parse_args()

    if args.gc_channels is not None:
        if args.gc_cardinality is None:
            raise ValueError("Globally conditioning but gc-cardinality not specified.")
        if args.gc_id is None:
            raise ValueError("Globally conditioning enabled but no GC ID specified.")
    if args.lc_channels is not None:
        if args.lc_fileformat is None:
            raise ValueError("Local conditioning enabled with channels but no file format specified.")
        if args.lc_filepath is None:
            raise ValueError("No local conditioning file provided in the LC filepath")
        # The generation script sizes the MIDI mapper with this value and
        # reshapes each LC vector to (1, initial_lc_channels), so it must be set.
        if args.initial_lc_channels is None:
            raise ValueError("Local conditioning enabled but initial-lc-channels not specified.")
    if args.lc_channels and args.samples:
        print("WARNING: LC enabled and number of samples to be generated also given.\n",
              "In this case, the sample number will be ignored. Total number of samples",
              "generated will depend on the LC file.")
    return args
def write_wav(waveform, sample_rate, filename):
    """Save ``waveform`` to ``filename`` at ``sample_rate`` and log the path.

    The samples are first converted with ``np.array`` and then handed to
    ``librosa.output.write_wav``.
    """
    # NOTE(review): librosa.output was removed in librosa 0.8; this call only
    # works with older librosa releases.
    audio_array = np.array(waveform)
    librosa.output.write_wav(filename, audio_array, sample_rate)
    status_line = 'Updated wav file at {}'.format(filename)
    print(status_line)
def create_seed(filename,
                sample_rate,
                quantization_channels,
                window_size,
                silence_threshold=SILENCE_THRESHOLD):
    """Build a priming sequence of quantized samples from an audio file.

    The file is loaded as mono at ``sample_rate``, then passed through
    ``audio_reader.trim_silence`` with ``silence_threshold``. The result is
    mu-law encoded into ``quantization_channels`` levels and truncated to at
    most ``window_size`` values.

    Returns:
        A TensorFlow tensor with the leading (at most ``window_size``)
        quantized samples; evaluate it with ``sess.run``.
    """
    mono_audio, _loaded_rate = librosa.load(filename, sr=sample_rate, mono=True)
    trimmed_audio = audio_reader.trim_silence(mono_audio, silence_threshold)
    encoded = mu_law_encode(trimmed_audio, quantization_channels)
    # Keep the whole clip when it is shorter than the window.
    encoded_length = tf.size(encoded)
    window_length = tf.constant(window_size)
    keep_count = tf.cond(encoded_length < window_length,
                         lambda: encoded_length,
                         lambda: window_length)
    return encoded[:keep_count]
def get_generation_length_from_midi(sample_rate, midi_filepath):
    """Return the playing length of a MIDI file in samples and microseconds.

    Only the first track (``pattern[0]``) is walked. NOTE(review): for
    multi-track (format 1) files whose notes live on later tracks, this may
    underestimate the length — confirm against the files used.

    Each event's ``tick`` is treated as a delta time since the previous
    event. It is converted to microseconds as ``tempo * tick / resolution``,
    where ``tempo`` is the value most recently set by a Set Tempo event.

    Args:
        sample_rate: output sample rate in Hz.
        midi_filepath: path of a MIDI file readable by ``midi.read_midifile``.

    Returns:
        ``(samples_to_generate, total_microseconds)``; both may be floats
        because true division is enabled for this module.
    """
    pattern = midi.read_midifile(midi_filepath)
    resolution = pattern.resolution
    track = pattern[0]
    total_microseconds = 0
    # Standard MIDI default tempo: 500000 microseconds per quarter note.
    curr_tempo = 500000
    for event in track:
        # The delta time before *every* event (tempo and end-of-track events
        # included) elapses at the tempo in force before that event, so count
        # it before processing a tempo change.
        total_microseconds += curr_tempo * event.tick / resolution
        if event.name == midi.SetTempoEvent.name:
            # The three data bytes form a big-endian 24-bit tempo value.
            tempo_bytes = event.data
            curr_tempo = (tempo_bytes[0] << 16) | (tempo_bytes[1] << 8) | tempo_bytes[2]
        elif event.name == midi.EndOfTrackEvent.name:
            break
    # 1e6 microseconds per second (the original divided by 10e6, a 10x error).
    samples_to_generate = total_microseconds * sample_rate / 1e6
    return samples_to_generate, total_microseconds
def main():
    """Generate audio from a trained WaveNet checkpoint.

    Parses the CLI options, builds a ``WaveNetModel`` from the JSON
    hyper-parameters and restores ``args.checkpoint``. It then draws one
    quantized sample per step, with optional global and local (MIDI)
    conditioning. The result is logged as a TensorBoard audio summary under
    ``<logdir>/generate/<timestamp>``. It is also written to
    ``--wav-out-path`` when that option is given.

    Raises:
        ValueError: if local conditioning is requested without fast
            generation, or if no sample count is available (no ``--samples``
            and no local conditioning).
    """
    args = get_args()
    lc_enabled = args.lc_channels is not None
    # Reject configurations the generation loop cannot handle before any
    # graph construction or checkpoint loading happens.
    if lc_enabled and not args.fast_generation:
        # predict_proba() below takes no LC input, so this combination would
        # silently generate unconditioned audio.
        raise ValueError("Local conditioning is only supported with fast generation.")
    if not lc_enabled and args.samples is None:
        raise ValueError("Number of samples to generate not specified (use --samples).")

    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(args.logdir, 'generate', started_datestring)
    with open(args.wavenet_params, 'r') as config_file:
        wavenet_params = json.load(config_file)
    quantization_channels = wavenet_params['quantization_channels']

    sess = tf.Session()
    net = WaveNetModel(
        batch_size=1,
        dilations=wavenet_params['dilations'],
        filter_width=wavenet_params['filter_width'],
        residual_channels=wavenet_params['residual_channels'],
        dilation_channels=wavenet_params['dilation_channels'],
        quantization_channels=wavenet_params['quantization_channels'],
        skip_channels=wavenet_params['skip_channels'],
        use_biases=wavenet_params['use_biases'],
        scalar_input=wavenet_params['scalar_input'],
        initial_filter_width=wavenet_params['initial_filter_width'],
        gc_channels=args.gc_channels,
        gc_cardinality=args.gc_cardinality,
        initial_lc_channels=args.initial_lc_channels,
        lc_channels=args.lc_channels)

    # Quantized sample(s) fed both to the network and to the mu-law decoder.
    samples = tf.placeholder(tf.int32)
    # Local-conditioning vector for the current step; only exists for LC runs.
    lc_batch = tf.placeholder(tf.float32) if lc_enabled else None

    lc_embeddings = None
    if lc_enabled:
        # TODO: figure out if this will work without starting the queue runners :/
        midifile = midi.read_midifile(args.lc_filepath)
        mapper = MidiMapper(sample_rate=args.sample_rate,
                            lc_channels=args.initial_lc_channels)
        mapper.set_midi(midifile)
        lc_embeddings = mapper.upsample()
        print("Shape of embeddings is {}".format(np.shape(lc_embeddings)))

    if lc_enabled:
        # One LC vector (lc_embeddings[step]) is consumed per generated sample,
        # so the conditioning file fixes the length; --samples is ignored, as
        # get_args warns. This also prevents indexing past the embeddings.
        sample_count = int(np.shape(lc_embeddings)[0])
    else:
        sample_count = int(args.samples)

    if args.fast_generation:
        if lc_enabled:
            next_sample = net.predict_proba_incremental(samples, args.gc_id, lc_batch)
        else:
            # NOTE(review): assumes the LC argument of predict_proba_incremental
            # is optional (no conditioning when omitted) — confirm in the model.
            next_sample = net.predict_proba_incremental(samples, args.gc_id)
    else:
        # DEPRECATED slow path: feeds the whole receptive-field window per step.
        next_sample = net.predict_proba(samples, args.gc_id)

    if args.fast_generation:
        sess.run(tf.global_variables_initializer())
        # init_ops / push_ops are created by predict_proba_incremental.
        sess.run(net.init_ops)

    # Restore every trained variable; the generator's queue state
    # ('state_buffer' / 'pointer') is not part of the checkpoint.
    variables_to_restore = {
        var.name[:-2]: var for var in tf.global_variables()
        if not ('state_buffer' in var.name or 'pointer' in var.name)}
    saver = tf.train.Saver(variables_to_restore)
    print('Restoring model from {}'.format(args.checkpoint))
    saver.restore(sess, args.checkpoint)

    decode = mu_law_decode(samples, wavenet_params['quantization_channels'])

    if args.wav_seed:
        seed = create_seed(args.wav_seed,
                           wavenet_params['sample_rate'],
                           quantization_channels,
                           net.receptive_field)
        waveform = sess.run(seed).tolist()
    else:
        # Silence (mid-scale value) with a single random sample at the end.
        waveform = [quantization_channels // 2] * (net.receptive_field - 1)
        waveform.append(np.random.randint(quantization_channels))

    # The ops run at every generation step are loop-invariant.
    if args.fast_generation:
        step_outputs = [next_sample] + list(net.push_ops)
    else:
        step_outputs = [next_sample]

    if args.fast_generation and args.wav_seed:
        # Incremental generation keeps its own queues, so every priming sample
        # has to be pushed through once before real generation starts.
        # NOTE(review): with LC enabled, next_sample presumably also needs
        # lc_batch fed here; this path is untested for LC runs.
        print('Priming generation...')
        for i, x in enumerate(waveform[-net.receptive_field: -1]):
            if i % 100 == 0:
                print('Priming sample {}'.format(i))
            sess.run(step_outputs, feed_dict={samples: x})
        print('Done.')

    print('Generating {} samples'.format(sample_count))
    last_sample_timestamp = datetime.now()
    for step in range(sample_count):
        if args.fast_generation:
            # The incremental graph only needs the most recent sample.
            window = waveform[-1]
        else:
            # Slicing returns the whole list when it is shorter than the field.
            window = waveform[-net.receptive_field:]

        feed_dict = {samples: window}
        if lc_enabled:
            feed_dict[lc_batch] = np.reshape(lc_embeddings[step],
                                             (1, args.initial_lc_channels))
        prediction = sess.run(step_outputs, feed_dict=feed_dict)[0]

        # Scale the prediction distribution by the temperature. log(0) -> -inf
        # is expected for zero-probability classes, so silence that warning
        # only for this computation.
        with np.errstate(divide='ignore'):
            scaled_prediction = np.log(prediction) / args.temperature
            scaled_prediction = scaled_prediction - np.logaddexp.reduce(scaled_prediction)
            scaled_prediction = np.exp(scaled_prediction)
        # At temperature 1.0 the distribution must come back unchanged.
        if args.temperature == 1.0:
            np.testing.assert_allclose(
                prediction, scaled_prediction, atol=1e-5,
                err_msg='Prediction scaling at temperature=1.0 '
                        'is not working as intended.')

        sample = np.random.choice(
            np.arange(quantization_channels), p=scaled_prediction)
        waveform.append(sample)

        # Show progress at most once per second.
        current_sample_timestamp = datetime.now()
        time_since_print = current_sample_timestamp - last_sample_timestamp
        if time_since_print.total_seconds() > 1.:
            print('Sample {:d}/{:d}'.format(step + 1, sample_count), end='\r')
            last_sample_timestamp = current_sample_timestamp

        # With partial writing enabled, save the result so far.
        if (args.wav_out_path and args.save_every and
                (step + 1) % args.save_every == 0):
            out = sess.run(decode, feed_dict={samples: waveform})
            write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    # Newline to clear the carriage return left by the progress line.
    print()

    # Save the result as an audio summary; close() flushes it to disk.
    writer = tf.summary.FileWriter(logdir)
    tf.summary.audio('generated', decode, wavenet_params['sample_rate'])
    summaries = tf.summary.merge_all()
    summary_out = sess.run(summaries,
                           feed_dict={samples: np.reshape(waveform, [-1, 1])})
    writer.add_summary(summary_out)
    writer.close()

    if args.wav_out_path:
        out = sess.run(decode, feed_dict={samples: waveform})
        write_wav(out, wavenet_params['sample_rate'], args.wav_out_path)

    sess.close()
    print('Finished generating. The result can be viewed in TensorBoard.')
if __name__ == '__main__':
    # Entry point: generation runs only when the file is executed as a script.
    main()