# coding: utf-8
"""
Since sample_rate = 16000, generating 48000 samples yields 3 seconds of audio.

> python generate.py --gc_cardinality 2 --gc_id 1 ./logdir-wavenet/train/2018-12-21T22-58-10
> python generate.py --wav_seed ./logdir-wavenet/seed.wav --mel ./logdir-wavenet/mel-son.npy --gc_cardinality 2 --gc_id 1 ./logdir-wavenet/train/2018-12-21T22-58-10   <-- scalar_input = True
> python generate.py --wav_seed ./logdir-wavenet/seed.wav --mel ./logdir-wavenet/mel-moon.npy --gc_cardinality 2 --gc_id 0 ./logdir-wavenet/train/2018-12-21T22-58-10
python generate.py --wav_seed ./logdir-wavenet/seed.wav --mel ./logdir-tacotron/generate/mel-2018-12-25_22-27-50-0.npy --gc_cardinality 2 --gc_id 0 ./logdir-wavenet/train/2018-12-21T22-58-10

gc_id = 0 (moon), 1 (son)
python generate.py --mel ./logdir-wavenet/mel-moon.npy --gc_cardinality 2 --gc_id 0 ./logdir-wavenet/train/2019-03-22T23-08-16
python generate.py --mel ./logdir-wavenet/mel-son.npy --gc_cardinality 2 --gc_id 1 ./logdir-wavenet/train/2019-03-22T23-08-16
"""
import argparse
from datetime import datetime
import json
import os, time
import librosa
import numpy as np
import tensorflow as tf
from wavenet import WaveNetModel, mu_law_decode, mu_law_encode
from hparams import hparams
from utils import load_hparams, load
from utils import audio
from utils import plot
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
def _interp(feats, in_range):
    # Rescales from [-max, max] (or [0, max]) to [0, 1].
    return (feats - in_range[0]) / (in_range[1] - in_range[0])
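# For example (hypothetical values): _interp(np.array([0., 2., 4.]), (0., 4.))
# returns array([0. , 0.5, 1. ]).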
def get_arguments():
    def _str_to_bool(s):
        """Convert string to bool (in argparse context)."""
        if s.lower() not in ['true', 'false']:
            raise ValueError('Argument needs to be a boolean, got {}'.format(s))
        return {'true': True, 'false': False}[s.lower()]

    def _ensure_positive_float(f):
        """Ensure argument is a positive float."""
        if float(f) <= 0:  # zero would cause a division by zero in temperature scaling
            raise argparse.ArgumentTypeError('Argument must be greater than zero')
        return float(f)

    parser = argparse.ArgumentParser(description='WaveNet generation script')
    parser.add_argument('checkpoint_dir', type=str, help='Which model checkpoint to generate from')

    TEMPERATURE = 1.0
    parser.add_argument('--temperature', type=_ensure_positive_float, default=TEMPERATURE, help='Sampling temperature')

    LOGDIR = './logdir-wavenet'
    parser.add_argument('--logdir', type=str, default=LOGDIR, help='Directory in which to store the logging information for TensorBoard.')
    parser.add_argument('--wav_out_path', type=str, default=None, help='Path to output wav file')

    BATCH_SIZE = 1
    parser.add_argument('--batch_size', type=int, default=BATCH_SIZE, help='batch size')
    parser.add_argument('--wav_seed', type=str, default=None, help='The wav file to start generation from')
    parser.add_argument('--mel', type=str, default=None, help='mel input')
    parser.add_argument('--gc_cardinality', type=int, default=None, help='Number of categories upon which we globally condition.')
    parser.add_argument('--gc_id', type=int, default=None, help='ID of category to generate, if globally conditioned.')

    arguments = parser.parse_args()

    if hparams.gc_channels is not None:
        if arguments.gc_cardinality is None:
            raise ValueError("Globally conditioning but gc_cardinality not specified. Use --gc_cardinality=377 for full VCTK corpus.")
        if arguments.gc_id is None:
            raise ValueError("Globally conditioning, but global condition was not specified. Use --gc_id to specify global condition.")

    return arguments
# def write_wav(waveform, sample_rate, filename):
#     y = np.array(waveform)
#     librosa.output.write_wav(filename, y, sample_rate)
#     print('Updated wav file at {}'.format(filename))
def create_seed(filename, sample_rate, quantization_channels, window_size, scalar_input):
    # Use only the beginning of the seed audio.
    seed_audio, _ = librosa.load(filename, sr=sample_rate, mono=True)
    seed_audio = audio.trim_silence(seed_audio, hparams)
    if scalar_input:
        if len(seed_audio) < window_size:
            return seed_audio
        else:
            return seed_audio[:window_size]
    else:
        quantized = mu_law_encode(seed_audio, quantization_channels)

        # If the seed is shorter than window_size, it is returned as-is; it should arguably be padded instead.
        cut_index = tf.cond(tf.size(quantized) < tf.constant(window_size), lambda: tf.size(quantized), lambda: tf.constant(window_size))
        return quantized[:cut_index]
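# For reference, standard mu-law companding (the scheme implemented by
# mu_law_encode / mu_law_decode) maps x in [-1, 1] to
#     f(x) = sign(x) * ln(1 + mu * |x|) / ln(1 + mu),   with mu = quantization_channels - 1,
# and then quantizes f(x) to integer classes in [0, mu]. Decoding inverts both
# steps: x = sign(y) * ((1 + mu)**|y| - 1) / mu for y in [-1, 1].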
def main():
    config = get_arguments()
    started_datestring = "{0:%Y-%m-%dT%H-%M-%S}".format(datetime.now())
    logdir = os.path.join(config.logdir, 'generate', started_datestring)
    if not os.path.exists(logdir):
        os.makedirs(logdir)

    load_hparams(hparams, config.checkpoint_dir)

    with tf.device('/cpu:0'):  # CPU is faster here. Placing this on GPU raises an error, and omitting tf.device entirely is slower.
        sess = tf.Session()
        scalar_input = hparams.scalar_input

        net = WaveNetModel(
            batch_size=config.batch_size,
            dilations=hparams.dilations,
            filter_width=hparams.filter_width,
            residual_channels=hparams.residual_channels,
            dilation_channels=hparams.dilation_channels,
            quantization_channels=hparams.quantization_channels,
            out_channels=hparams.out_channels,
            skip_channels=hparams.skip_channels,
            use_biases=hparams.use_biases,
            scalar_input=hparams.scalar_input,
            global_condition_channels=hparams.gc_channels,
            global_condition_cardinality=config.gc_cardinality,
            local_condition_channels=hparams.num_mels,
            upsample_factor=hparams.upsample_factor,
            legacy=hparams.legacy,
            residual_legacy=hparams.residual_legacy,
            train_mode=False)  # During training, AudioReader determined global_condition_cardinality; at generation time it must be passed in explicitly.
        if scalar_input:
            samples = tf.placeholder(tf.float32, shape=[net.batch_size, None])
        else:
            samples = tf.placeholder(tf.int32, shape=[net.batch_size, None])  # samples: mu-law encoded values, before one-hot conversion. Shape (batch_size, length).

        # The local condition should be (N, T, num_mels), but it is fed one frame at a
        # time, i.e. (N, 1, num_mels); squeezed, that is (N, num_mels).
        upsampled_local_condition = tf.placeholder(tf.float32, shape=[net.batch_size, hparams.num_mels])

        next_sample = net.predict_proba_incremental(samples, upsampled_local_condition, [config.gc_id] * net.batch_size)  # Applies the Fast WaveNet generation algorithm (arXiv:1611.09482).
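        # The incremental predictor caches each dilated layer's past activations in
        # queues, so producing one new sample costs O(1) per layer instead of
        # re-running the convolutions over the entire receptive field.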
        # Build the upsampled local-condition data that will be fed into the
        # upsampled_local_condition placeholder.
        mel_input = np.load(config.mel)
        sample_size = mel_input.shape[0] * hparams.hop_size
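        # Each mel frame covers hop_size audio samples, so the number of samples to
        # generate is (number of mel frames) * hop_size.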
        mel_input = np.tile(mel_input, (config.batch_size, 1, 1))
        with tf.variable_scope('wavenet', reuse=tf.AUTO_REUSE):
            upsampled_local_condition_data = net.create_upsample(mel_input, upsample_type=hparams.upsample_type)

        var_list = [var for var in tf.global_variables() if 'queue' not in var.name]
        saver = tf.train.Saver(var_list)
        print('Restoring model from {}'.format(config.checkpoint_dir))
        load(saver, sess, config.checkpoint_dir)
        sess.run(net.queue_initializer)  # Without this, the queues still hold the values restored from the checkpoint.
        quantization_channels = hparams.quantization_channels
        if config.wav_seed:
            # If wav_seed is shorter than the receptive field, it is returned as-is
            # rather than padded, so a seed that is too short causes an error.
            seed = create_seed(config.wav_seed, hparams.sample_rate, quantization_channels, net.receptive_field, scalar_input)  # mu-law encoded, unless scalar_input.

            if scalar_input:
                waveform = seed.tolist()
            else:
                waveform = sess.run(seed).tolist()  # e.g. [116, 114, 120, 121, 127, ...]

            print('Priming generation...')
            for i, x in enumerate(waveform[-net.receptive_field: -1]):  # The very last sample is fed in the first iteration of the generation loop below.
                if i % 100 == 0:
                    print('Priming sample {}/{}'.format(i, net.receptive_field), end='\r')
                sess.run(next_sample, feed_dict={samples: np.array([x] * net.batch_size).reshape(net.batch_size, 1), upsampled_local_condition: np.zeros([net.batch_size, hparams.num_mels])})
            print('Done.')
            waveform = np.array([waveform[-net.receptive_field:]] * net.batch_size)
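            # waveform now holds the last receptive_field samples per batch entry;
            # generation continues from this context.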
        else:
            # Silence with a single random sample at the end.
            if scalar_input:
                waveform = [0.0] * (net.receptive_field - 1)
                waveform = np.array(waveform * net.batch_size).reshape(net.batch_size, -1)
                waveform = np.concatenate([waveform, 2 * np.random.rand(net.batch_size).reshape(net.batch_size, -1) - 1], axis=-1)  # Append one random number in [-1, 1] at the end.
                # waveform: shape (batch_size, net.receptive_field)
            else:
                waveform = [quantization_channels / 2] * (net.receptive_field - 1)  # Build receptive_field - 1 samples, then append one random sample below.
                waveform = np.array(waveform * net.batch_size).reshape(net.batch_size, -1)
                waveform = np.concatenate([waveform, np.random.randint(quantization_channels, size=net.batch_size).reshape(net.batch_size, -1)], axis=-1)  # Before one-hot conversion. Shape (batch_size, receptive_field), e.g. (batch_size, 5117).
        start_time = time.time()
        upsampled_local_condition_data = sess.run(upsampled_local_condition_data)
        last_sample_timestamp = datetime.now()
        for step in range(sample_size):  # Loop sample_size times to generate the desired length.
            window = waveform[:, -1:]  # Feed only the most recent sample. window: shape (N, 1).

            # Run the WaveNet to predict the next sample.
            # Without fast generation, window would be the full context, e.g. [128.0, 128.0, ..., 128.0, 178, 185];
            # with fast generation, window is a single sample.
            prediction = sess.run(next_sample, feed_dict={samples: window, upsampled_local_condition: upsampled_local_condition_data[:, step, :]})  # samples holds mu-law encoded values; they are converted to one-hot internally --> (batch_size, 256).
            if scalar_input:
                sample = prediction  # Sampled from a logistic distribution, so the output is already stochastic.
            else:
                # Scale the prediction distribution using the temperature.
                # When config.temperature == 1, this just divides each element by the sum;
                # since softmax has already been applied, the sum is 1 and the values are unchanged.
                # When config.temperature != 1, each element's log is divided by the
                # temperature, and the result is rescaled to sum to 1.
                np.seterr(divide='ignore')
                scaled_prediction = np.log(prediction) / config.temperature  # Unchanged when config.temperature == 1.
                scaled_prediction = (scaled_prediction - np.logaddexp.reduce(scaled_prediction, axis=-1, keepdims=True))  # i.e. subtract np.log(np.sum(np.exp(scaled_prediction))).
                scaled_prediction = np.exp(scaled_prediction)
                np.seterr(divide='warn')
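                # Equivalently, scaled_prediction[i] = p[i]**(1/T) / sum_j p[j]**(1/T),
                # computed in log space for numerical stability.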
                # The prediction distribution at temperature=1.0 should be unchanged
                # after scaling.
                if config.temperature == 1.0:
                    np.testing.assert_allclose(prediction, scaled_prediction, atol=1e-5, err_msg='Prediction scaling at temperature=1.0 is not working as intended.')

                # Because we sample instead of taking the argmax, the same input can produce different outputs.
                sample = [[np.random.choice(np.arange(quantization_channels), p=p)] for p in scaled_prediction]  # Choose one sample per batch entry.

            waveform = np.concatenate([waveform, sample], axis=-1)  # sample shape: (N, 1)

            # Show progress only once per second.
            current_sample_timestamp = datetime.now()
            time_since_print = current_sample_timestamp - last_sample_timestamp
            if time_since_print.total_seconds() > 1.:
                duration = time.time() - start_time
                print('Sample {:<3d}/{:<3d}, ({:.3f} sec/step)'.format(step + 1, sample_size, duration / (step + 1)), end='\r')
                last_sample_timestamp = current_sample_timestamp

        # Introduce a newline to clear the carriage return from the progress.
        print()
        # Save the result as a wav file.
        if hparams.input_type == 'raw':
            out = waveform[:, net.receptive_field:]
        elif hparams.input_type == 'mulaw':
            decode = mu_law_decode(samples, quantization_channels, quantization=False)
            out = sess.run(decode, feed_dict={samples: waveform[:, net.receptive_field:]})
        else:  # 'mulaw-quantize'
            decode = mu_law_decode(samples, quantization_channels, quantization=True)
            out = sess.run(decode, feed_dict={samples: waveform[:, net.receptive_field:]})
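        # Presumably, 'mulaw' stores already-companded float values, so decoding only
        # inverts the companding (quantization=False), while 'mulaw-quantize' stores
        # integer class indices that must first be de-quantized (quantization=True).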
        # Save the wav files.
        for i in range(net.batch_size):
            config.wav_out_path = logdir + '/test-{}.wav'.format(i)
            mel_path = config.wav_out_path.replace(".wav", ".png")

            gen_mel_spectrogram = audio.melspectrogram(out[i], hparams).astype(np.float32).T
            audio.save_wav(out[i], config.wav_out_path, hparams.sample_rate)  # save_wav modifies out[i] in place.

            plot.plot_spectrogram(gen_mel_spectrogram, mel_path, title='generated mel spectrogram', target_spectrogram=mel_input[i])
        print('Finished generating.')
if __name__ == '__main__':
    s = time.time()
    main()
    print(time.time() - s, 'sec')