-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathwavenet_utils.py
42 lines (35 loc) · 2.12 KB
/
wavenet_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
# -*- coding: utf-8 -*-
from __future__ import print_function
import keras.backend as K
from keras.layers import Conv1D
from keras.utils.conv_utils import conv_output_length
def categorical_mean_squared_error(y_true, y_pred):
    """Mean squared error between the argmax class indices of two tensors.

    Both tensors are reduced with ``argmax`` over their last axis and the
    squared difference of the resulting indices is averaged over all
    remaining elements.

    # Arguments
        y_true: tensor whose last axis indexes the classes
            (presumably one-hot targets -- confirm against callers).
        y_pred: tensor of the same shape holding per-class scores.

    # Returns
        Scalar tensor of dtype ``K.floatx()``.

    Note: ``argmax`` has no useful gradient, so this is only meaningful as
    a metric, not as a training loss.
    """
    index_diff = K.argmax(y_true, axis=-1) - K.argmax(y_pred, axis=-1)
    # argmax yields integer indices; cast before averaging so the mean is
    # not truncated by integer arithmetic on backends that preserve the
    # integer dtype in their mean reduction.
    index_diff = K.cast(index_diff, K.floatx())
    return K.mean(K.square(index_diff))
class CausalDilatedConv1D(Conv1D):
    """Dilated 1D convolution with optional causal (left-only) padding.

    Thin adapter that exposes Keras-1 style argument names and forwards
    them to the Keras-2 ``Conv1D`` constructor. When ``causal`` is true the
    input is padded on the left of the time axis with
    ``atrous_rate * (filter_length - 1)`` steps before the convolution is
    applied; nothing is padded on the right.

    # Arguments
        nb_filter: number of output filters (``Conv1D`` ``filters``).
        filter_length: kernel length (``Conv1D`` ``kernel_size``).
        init: kernel initializer, forwarded as ``kernel_initializer``
            unless ``kernel_initializer`` is given explicitly in ``kwargs``.
        activation: forwarded as ``activation``.
        weights: forwarded as ``weights``.
        border_mode: forwarded as ``padding``; must be ``'valid'`` when
            ``causal`` is true.
        subsample_length: forwarded as ``strides``.
        atrous_rate: forwarded as ``dilation_rate``.
        W_regularizer: forwarded as ``kernel_regularizer``.
        b_regularizer: forwarded as ``bias_regularizer``.
        activity_regularizer: forwarded as ``activity_regularizer``.
        W_constraint: forwarded as ``kernel_constraint``.
        b_constraint: forwarded as ``bias_constraint``.
        bias: forwarded as ``use_bias``.
        causal: whether to left-pad the input in ``call``.
        **kwargs: passed through to ``Conv1D`` unchanged.

    # Raises
        ValueError: if ``causal`` is true and ``border_mode`` is not
            ``'valid'``.
    """

    def __init__(self, nb_filter, filter_length, init='glorot_uniform', activation=None, weights=None,
                 border_mode='valid', subsample_length=1, atrous_rate=1, W_regularizer=None, b_regularizer=None,
                 activity_regularizer=None, W_constraint=None, b_constraint=None, bias=True, causal=False, **kwargs):
        # Reject the inconsistent configuration before building the layer.
        if causal and border_mode != 'valid':
            raise ValueError("Causal mode dictates border_mode=valid.")

        # Honour `init`; setdefault keeps callers that already pass the
        # Keras-2 name `kernel_initializer` through **kwargs working.
        kwargs.setdefault('kernel_initializer', init)

        super(CausalDilatedConv1D, self).__init__(
            nb_filter,
            filter_length,
            weights=weights,
            activation=activation,
            padding=border_mode,
            strides=subsample_length,
            dilation_rate=atrous_rate,
            kernel_regularizer=W_regularizer,
            bias_regularizer=b_regularizer,
            activity_regularizer=activity_regularizer,
            kernel_constraint=W_constraint,
            bias_constraint=b_constraint,
            use_bias=bias,
            **kwargs)

        # Keep the legacy-named values; the methods below read these.
        self.causal = causal
        self.nb_filter = nb_filter
        self.atrous_rate = atrous_rate
        self.filter_length = filter_length
        self.subsample_length = subsample_length
        self.border_mode = border_mode

    def _causal_padding(self):
        """Return the number of steps prepended to the time axis in causal mode."""
        return self.atrous_rate * (self.filter_length - 1)

    def compute_output_shape(self, input_shape):
        """Return ``(batch, length, nb_filter)`` for the given input shape.

        In causal mode the left padding added by ``call`` is included in
        the input length before the convolution arithmetic is applied.
        """
        input_length = input_shape[1]
        # An unknown (None) time dimension stays unknown:
        # conv_output_length propagates None, but None + int would raise.
        if self.causal and input_length is not None:
            input_length += self._causal_padding()
        length = conv_output_length(input_length,
                                    self.filter_length,
                                    self.border_mode,
                                    self.strides[0],
                                    dilation=self.atrous_rate)
        return (input_shape[0], length, self.nb_filter)

    def call(self, x, mask=None):
        """Apply the convolution, left-padding ``x`` first in causal mode.

        ``mask`` is accepted for signature compatibility and is not used.
        """
        if self.causal:
            x = K.temporal_padding(x, padding=(self._causal_padding(), 0))
        return super(CausalDilatedConv1D, self).call(x)