-
Notifications
You must be signed in to change notification settings - Fork 52
/
bdlstm_train_sample.py
109 lines (93 loc) · 5.08 KB
/
bdlstm_train_sample.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
'''
Example of a single-layer bidirectional long short-term memory network trained with
connectionist temporal classification to predict character sequences from nFeatures x nFrames
arrays of Mel-Frequency Cepstral Coefficients. This is test code to run on the
8-item data set in the "sample_data" directory, for those without access to TIMIT.
Author: Jon Rein
'''
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from tensorflow.python.ops import ctc_ops as ctc
import numpy as np
from utils import load_batched_data
# Data locations. Both directories end in '/' so they are written consistently
# regardless of whether the loader joins or concatenates paths.
INPUT_PATH = './sample_data/mfcc/'  # directory of MFCC nFeatures x nFrames 2-D array .npy files
TARGET_PATH = './sample_data/char_y/'  # directory of nCharacters 1-D array .npy files

####Learning Parameters
learningRate = 0.001
momentum = 0.9
nEpochs = 300
batchSize = 4

####Network Parameters
nFeatures = 26  # (12 MFCC coefficients + energy) x 2: the values plus their derivatives
nHidden = 128  # LSTM units in each direction of the bidirectional layer
nClasses = 28  # 27 characters, plus the "blank" for CTC
####Load data
print('Loading data')
# Unpack the loader's triple: the minibatches, the number of time steps each
# input is shaped to, and the total sample count.
# NOTE(review): meanings inferred from how the values are used below — confirm
# against utils.load_batched_data.
(batchedData,
 maxTimeSteps,
 totalN) = load_batched_data(INPUT_PATH, TARGET_PATH, batchSize)
####Define graph
print('Defining graph')
graph = tf.Graph()
with graph.as_default():

    ####NOTE: try variable-steps inputs and dynamic bidirectional rnn, when it's implemented in tensorflow

    ####Graph input
    # Time-major input batch: (maxTimeSteps, batchSize, nFeatures).
    inputX = tf.placeholder(tf.float32, shape=(maxTimeSteps, batchSize, nFeatures))
    # static_bidirectional_rnn takes a Python list with one tensor per time step.
    # Unstacking along the time axis gives maxTimeSteps tensors of shape
    # (batchSize, nFeatures) -- the same result as reshaping to 2-D and splitting.
    inputList = tf.unstack(inputX, axis=0)
    # Sparse CTC targets, assembled from the (indices, values, dense_shape)
    # triple fed for each batch.
    targetIxs = tf.placeholder(tf.int64)
    targetVals = tf.placeholder(tf.int32)
    targetShape = tf.placeholder(tf.int64)
    targetY = tf.SparseTensor(targetIxs, targetVals, targetShape)
    # One sequence length per sample, used by the CTC loss and decoder.
    seqLengths = tf.placeholder(tf.int32, shape=(batchSize,))

    ####Weights & biases
    # Merges the forward and backward outputs: every hidden unit gets one
    # weight for the forward direction and one for the backward direction.
    weightsOutH1 = tf.Variable(tf.truncated_normal([2, nHidden],
                                                   stddev=np.sqrt(2.0 / (2*nHidden))))
    biasesOutH1 = tf.Variable(tf.zeros([nHidden]))
    # Projection from the merged hidden layer to per-class scores.
    weightsClasses = tf.Variable(tf.truncated_normal([nHidden, nClasses],
                                                     stddev=np.sqrt(2.0 / nHidden)))
    biasesClasses = tf.Variable(tf.zeros([nClasses]))

    ####Network
    # A single bidirectional layer of peephole LSTM cells.
    forwardH1 = tf.contrib.rnn.LSTMCell(nHidden, use_peepholes=True, state_is_tuple=True)
    backwardH1 = tf.contrib.rnn.LSTMCell(nHidden, use_peepholes=True, state_is_tuple=True)
    fbH1, _, _ = tf.contrib.rnn.static_bidirectional_rnn(forwardH1, backwardH1, inputList,
                                                        dtype=tf.float32, scope='BDLSTM_H1')
    # Each output is the forward and backward results concatenated along axis 1,
    # i.e. (batchSize, 2*nHidden). Reshape it to (batchSize, 2, nHidden) so the
    # two directions can be weighted and summed per hidden unit.
    fbH1rs = [tf.reshape(t, [batchSize, 2, nHidden]) for t in fbH1]
    outH1 = [tf.reduce_sum(tf.multiply(t, weightsOutH1), axis=1) + biasesOutH1 for t in fbH1rs]
    # Per-time-step class scores, each (batchSize, nClasses).
    logits = [tf.matmul(t, weightsClasses) + biasesClasses for t in outH1]

    ####Optimizing
    # Time-major (maxTimeSteps, batchSize, nClasses), the layout ctc_loss uses
    # by default.
    logits3d = tf.stack(logits)
    loss = tf.reduce_mean(ctc.ctc_loss(targetY, logits3d, seqLengths))
    optimizer = tf.train.MomentumOptimizer(learningRate, momentum).minimize(loss)

    ####Evaluating
    # Frame-wise argmax class of the first sample in the batch, for debug printing.
    logitsMaxTest = tf.slice(tf.argmax(logits3d, 2), [0, 0], [seqLengths[0], 1])
    predictions = tf.to_int32(ctc.ctc_beam_search_decoder(logits3d, seqLengths)[0][0])
    # Label error rate: total (unnormalized) edit distance over the batch,
    # divided by the number of target labels in the batch.
    errorRate = tf.reduce_sum(tf.edit_distance(predictions, targetY, normalize=False)) / \
        tf.to_float(tf.size(targetY.values))
####Run session
with tf.Session(graph=graph) as session:
    print('Initializing')
    tf.global_variables_initializer().run()
    # Print minibatch loss / error every `logEvery` minibatches. The previous
    # `batch % 1 == 0` test was always true; 1 keeps that per-batch behavior.
    logEvery = 1
    for epoch in range(nEpochs):
        print('Epoch', epoch+1, '...')
        batchErrors = np.zeros(len(batchedData))
        batchRandIxs = np.random.permutation(len(batchedData))  # randomize batch order
        for batch, batchOrigI in enumerate(batchRandIxs):
            batchInputs, batchTargetSparse, batchSeqLengths = batchedData[batchOrigI]
            batchTargetIxs, batchTargetVals, batchTargetShape = batchTargetSparse
            feedDict = {inputX: batchInputs, targetIxs: batchTargetIxs, targetVals: batchTargetVals,
                        targetShape: batchTargetShape, seqLengths: batchSeqLengths}
            _, batchLoss, er, lmt = session.run([optimizer, loss, errorRate, logitsMaxTest],
                                                feed_dict=feedDict)
            # Unique argmax values of the first sample in the batch; these should
            # be only the blank for a while, then start showing target values.
            print(np.unique(lmt))
            if batch % logEvery == 0:
                # `batchOrigI` is the batch's index in batchedData before shuffling.
                print('Minibatch', batch, '/', batchOrigI, 'loss:', batchLoss)
                print('Minibatch', batch, '/', batchOrigI, 'error rate:', er)
            # Weight the batch's error rate by its sample count so the epoch
            # figure below is averaged over samples.
            batchErrors[batch] = er*len(batchSeqLengths)
        epochErrorRate = batchErrors.sum() / totalN
        print('Epoch', epoch+1, 'error rate:', epochErrorRate)