diff --git a/README.md b/README.md index e9137ed3..b5ce7259 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,86 @@ +# **Tacotron-2-Chinese 中文语音合成** + +## **预训练模型下载** + +    [标贝数据集100K步模型(把解压出的 logs-Tacotron-2 文件夹放到 Tacotron-2-Chinese 文件夹中)](https://github.com/JasonWei512/Tacotron-2-Chinese/releases/download/Biaobei_Tacotron-100K/logs-Tacotron-2.zip) + +    仅 Tacotron 频谱预测部分,不含 WaveNet 模型。可用 Griffin-Lim 合成语音(见下)。或用生成的 Mel 频谱通过 [r9y9的WaveNet](https://github.com/JasonWei512/wavenet_vocoder/) 生成高音质语音。 + +    [生成的语音样本](https://github.com/JasonWei512/Tacotron-2-Chinese/issues/7) + +    使用标贝数据集训练,为避免爆显存用了 ffmpeg 把语料的采样率从 48KHz 降到了 36KHz,听感基本无区别。 + +## **安装依赖** + +1. 安装 Python 3 和 Tensorflow 1.10(在 Tensorflow 1.14 上用 WaveNet 会有Bug,在 1.10 上正常)。 + +2. 安装依赖: + + ```Shell + apt-get install -y libasound-dev portaudio19-dev libportaudio2 libportaudiocpp0 ffmpeg libav-tools + ``` + + 若 libav-tools 安装失败则手动安装: + + ```Shell + wget http://launchpadlibrarian.net/339874908/libav-tools_3.3.4-2_all.deb + dpkg -i libav-tools_3.3.4-2_all.deb + ``` + +3. 安装 requirements: + + ```Shell + pip install -r requirements.txt + ``` + +## **训练模型** + +1. 下载 [标贝数据集](https://weixinxcxdb.oss-cn-beijing.aliyuncs.com/gwYinPinKu/BZNSYP.rar),解压至 `Tacotron-2-Chinese` 文件夹根目录。目录结构如下: + + ``` + Tacotron-2-Chinese + |- BZNSYP + |- PhoneLabeling + |- ProsodyLabeling + |- Wave + ``` + +2. 用 ffmpeg 把 `/BZNSYP/Wave/` 中的 wav 的采样率降到36KHz: + + ```Shell + ffmpeg.exe -i 输入.wav -ar 36000 输出.wav + ``` + +3. 预处理数据: + + ```Shell + python preprocess.py --dataset='Biaobei' + ``` + +4. 训练模型(自动从最新 Checkpoint 继续): + + ```Shell + python train.py --model='Tacotron-2' + ``` + +## **合成语音** + +* 用根目录的 `sentences.txt` 中的文本合成语音。 + + ```Shell + python synthesize.py --model='Tacotron-2' --text_list='sentences.txt' + ``` + + 若无 WaveNet 模型,仅有频谱预测模型,则仅由 Griffin-Lim 生成语音,输出至 `/tacotron_output/logs-eval/wavs/` 文件夹中。 + + 若有 WaveNet 模型,则 WaveNet 生成的语音位于 `/wavenet_output/wavs/` 中。 + + 输出的 Mel 频谱位于 `/tacotron_output/eval/` 中。可用 [r9y9的WaveNet](https://github.com/JasonWei512/wavenet_vocoder/) 合成语音。 + +  + +  + # Tacotron-2: Tensorflow implementation of DeepMind's Tacotron-2. A deep neural network architecture described in this paper: [Natural TTS synthesis by conditioning Wavenet on MEL spectogram predictions](https://arxiv.org/pdf/1712.05884.pdf) diff --git a/datasets/preprocessor.py b/datasets/preprocessor.py index 2d200280..c839117d 100644 --- a/datasets/preprocessor.py +++ b/datasets/preprocessor.py @@ -29,16 +29,23 @@ def build_from_path(hparams, input_dirs, mel_dir, linear_dir, wav_dir, n_jobs=12 executor = ProcessPoolExecutor(max_workers=n_jobs) futures = [] index = 1 - for input_dir in input_dirs: - with open(os.path.join(input_dir, 'metadata.csv'), encoding='utf-8') as f: - for line in f: - parts = line.strip().split('|') - basename = parts[0] - wav_path = os.path.join(input_dir, 'wavs', '{}.wav'.format(basename)) - text = parts[2] - futures.append(executor.submit(partial(_process_utterance, mel_dir, linear_dir, wav_dir, basename, wav_path, text, hparams))) - index += 1 + for input_dir in input_dirs: + with open(os.path.join(input_dir, 'ProsodyLabeling', '000001-010000.txt'), encoding='utf-8') as f: + lines = f.readlines() + index = 1 + + sentence_index = '' + sentence_pinyin = '' + + for line in lines: + if line[0].isdigit(): + sentence_index = line[:6] + else: + sentence_pinyin = line.strip() + wav_path = os.path.join(input_dir, 'Wave', '%s.wav' % sentence_index) + futures.append(executor.submit(partial(_process_utterance, mel_dir, linear_dir, wav_dir, sentence_index, wav_path, sentence_pinyin, hparams))) + index = index + 1 return [future.result() for future in tqdm(futures) if future.result() is not None] diff --git a/hparams.py b/hparams.py index b1708425..bb187781 100644 --- a/hparams.py +++ b/hparams.py @@ -5,7 +5,7 @@ hparams = tf.contrib.training.HParams( # Comma-separated list of cleaners to run on text prior to training and eval. For non-English # text, you may want to use "basic_cleaners" or "transliteration_cleaners". - cleaners='english_cleaners', + cleaners='basic_cleaners', #If you only have 1 GPU or want to use only one GPU, please set num_gpus=0 and specify the GPU idx on run. example: @@ -62,7 +62,7 @@ # 6- If audio quality is too metallic or fragmented (or if linear spectrogram plots are showing black silent regions on top), then restart from step 2. num_mels = 80, #Number of mel-spectrogram channels and local conditioning dimensionality num_freq = 1025, # (= n_fft / 2 + 1) only used when adding linear spectrograms post processing network - rescale = True, #Whether to rescale audio prior to preprocessing + rescale = False, #Whether to rescale audio prior to preprocessing rescaling_max = 0.999, #Rescaling value #train samples of lengths between 3sec and 14sec are more than enough to make a model capable of generating consistent speech. @@ -77,14 +77,14 @@ #Mel spectrogram n_fft = 2048, #Extra window size is filled with 0 paddings to match this parameter - hop_size = 275, #For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) - win_size = 1100, #For 22050Hz, 1100 ~= 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) - sample_rate = 22050, #22050 Hz (corresponding to ljspeech dataset) (sox --i ) + hop_size = 450, #For 22050Hz, 275 ~= 12.5 ms (0.0125 * sample_rate) + win_size = 1800, #For 22050Hz, 1100 ~= 50 ms (If None, win_size = n_fft) (0.05 * sample_rate) + sample_rate = 36000, #22050 Hz (corresponding to ljspeech dataset) (sox --i ) frame_shift_ms = None, #Can replace hop_size parameter. (Recommended: 12.5) magnitude_power = 2., #The power of the spectrogram magnitude (1. for energy, 2. for power) #M-AILABS (and other datasets) trim params (there parameters are usually correct for any data, but definitely must be tuned for specific speakers) - trim_silence = True, #Whether to clip silence in Audio (at beginning and end of audio only, not the middle) + trim_silence = False, #Whether to clip silence in Audio (at beginning and end of audio only, not the middle) trim_fft_size = 2048, #Trimming window size trim_hop_size = 512, #Trimmin hop length trim_top_db = 40, #Trimming db difference from reference db (smaller==harder trim.) @@ -105,13 +105,13 @@ preemphasis = 0.97, #filter coefficient. #Limits - min_level_db = -100, + min_level_db = -120, ref_level_db = 20, - fmin = 55, #Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) + fmin = 125, #Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To test depending on dataset. Pitch info: male~[65, 260], female~[100, 525]) fmax = 7600, #To be increased/reduced depending on data. #Griffin Lim - power = 1.5, #Only used in G&L inversion, usually values between 1.2 and 1.5 are a good choice. + power = 1.3, #Only used in G&L inversion, usually values between 1.2 and 1.5 are a good choice. griffin_lim_iters = 60, #Number of G&L iterations, typically 30 is enough but we use 60 to ensure convergence. GL_on_GPU = True, #Whether to use G&L GPU version as part of tensorflow graph. (Usually much faster than CPU but slightly worse quality too). ########################################################################################################################################### @@ -199,12 +199,12 @@ #model parameters #To use Gaussian distribution as output distribution instead of mixture of logistics, set "out_channels = 2" instead of "out_channels = 10 * 3". (UNDER TEST) - out_channels = 2, #This should be equal to quantize channels when input type is 'mulaw-quantize' else: num_distributions * 3 (prob, mean, log_scale). - layers = 20, #Number of dilated convolutions (Default: Simplified Wavenet of Tacotron-2 paper) - stacks = 2, #Number of dilated convolution stacks (Default: Simplified Wavenet of Tacotron-2 paper) - residual_channels = 128, #Number of residual block input/output channels. - gate_channels = 256, #split in 2 in gated convolutions - skip_out_channels = 128, #Number of residual block skip convolution channels. + out_channels = 30, #This should be equal to quantize channels when input type is 'mulaw-quantize' else: num_distributions * 3 (prob, mean, log_scale). + layers = 24, #Number of dilated convolutions (Default: Simplified Wavenet of Tacotron-2 paper) + stacks = 4, #Number of dilated convolution stacks (Default: Simplified Wavenet of Tacotron-2 paper) + residual_channels = 256, #Number of residual block input/output channels. + gate_channels = 512, #split in 2 in gated convolutions + skip_out_channels = 256, #Number of residual block skip convolution channels. kernel_size = 3, #The number of inputs to consider in dilated convolutions. #Upsampling parameters (local conditioning) @@ -218,7 +218,7 @@ #Finally, NearestNeighbor is a non-trainable upsampling layer that just expands each frame (or "pixel") to the equivalent hop size. Ignores all upsampling parameters. upsample_type = 'SubPixel', #Type of the upsampling deconvolution. Can be ('1D' or '2D', 'Resize', 'SubPixel' or simple 'NearestNeighbor'). upsample_activation = 'Relu', #Activation function used during upsampling. Can be ('LeakyRelu', 'Relu' or None) - upsample_scales = [11, 25], #prod(upsample_scales) should be equal to hop_size + upsample_scales = [5, 9, 10], #prod(upsample_scales) should be equal to hop_size freq_axis_kernel_size = 3, #Only used for 2D upsampling types. This is the number of requency bands that are spanned at a time for each frame. leaky_alpha = 0.4, #slope of the negative portion of LeakyRelu (LeakyRelu: y=x if x>0 else y=alpha * x) NN_init = True, #Determines whether we want to initialize upsampling kernels/biases in a way to ensure upsample is initialize to Nearest neighbor upsampling. (Mostly for debug) @@ -251,11 +251,11 @@ #Learning rate schedule tacotron_decay_learning_rate = True, #boolean, determines if the learning rate will follow an exponential decay - tacotron_start_decay = 40000, #Step at which learning decay starts - tacotron_decay_steps = 18000, #Determines the learning rate decay slope (UNDER TEST) + tacotron_start_decay = 30000, #Step at which learning decay starts + tacotron_decay_steps = 10000, #Determines the learning rate decay slope (UNDER TEST) tacotron_decay_rate = 0.5, #learning rate decay rate (UNDER TEST) tacotron_initial_learning_rate = 1e-3, #starting learning rate - tacotron_final_learning_rate = 1e-4, #minimal learning rate + tacotron_final_learning_rate = 1e-5, #minimal learning rate #Optimization parameters tacotron_adam_beta1 = 0.9, #AdamOptimizer beta1 parameter @@ -310,7 +310,7 @@ wavenet_learning_rate = 1e-3, #wavenet initial learning rate wavenet_warmup = float(4000), #Only used with 'noam' scheme. Defines the number of ascending learning rate steps. wavenet_decay_rate = 0.5, #Only used with 'exponential' scheme. Defines the decay rate. - wavenet_decay_steps = 200000, #Only used with 'exponential' scheme. Defines the decay steps. + wavenet_decay_steps = 150000, #Only used with 'exponential' scheme. Defines the decay steps. #Optimization parameters wavenet_adam_beta1 = 0.9, #Adam beta1 @@ -340,30 +340,34 @@ #Eval/Debug parameters #Eval sentences (if no eval text file was specified during synthesis, these sentences are used for eval) sentences = [ - # From July 8, 2017 New York Times: - 'Scientists at the CERN laboratory say they have discovered a new particle.', - 'There\'s a way to measure the acute emotional intelligence that has never gone out of style.', - 'President Trump met with other leaders at the Group of 20 conference.', - 'The Senate\'s bill to repeal and replace the Affordable Care Act is now imperiled.', - # From Google's Tacotron example page: - 'Generative adversarial network or variational auto-encoder.', - 'Basilar membrane and otolaryngology are not auto-correlations.', - 'He has read the whole thing.', - 'He reads books.', - 'He thought it was time to present the present.', - 'Thisss isrealy awhsome.', - 'The big brown fox jumps over the lazy dog.', - 'Did the big brown fox jump over the lazy dog?', - "Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick?", - "She sells sea-shells on the sea-shore. The shells she sells are sea-shells I'm sure.", - "Tajima Airport serves Toyooka.", - #From The web (random long utterance) - # 'On offering to help the blind man, the man who then stole his car, had not, at that precise moment, had any evil intention, quite the contrary, \ - # what he did was nothing more than obey those feelings of generosity and altruism which, as everyone knows, \ - # are the two best traits of human nature and to be found in much more hardened criminals than this one, a simple car-thief without any hope of advancing in his profession, \ - # exploited by the real owners of this enterprise, for it is they who take advantage of the needs of the poor.', - # A final Thank you note! - 'Thank you so much for your support!', + "bai2 jia1 xuan1 hou4 lai2 yin2 yi3 hao2 zhuang4 de shi4 yi4 sheng1 li3 qu3 guo4 qi1 fang2 nv3 ren2 .", + "qu3 tou2 fang2 xi2 fu4 shi2 ta1 gang1 gang1 guo4 shi2 liu4 sui4 sheng1 ri4 .", + "na4 shi4 xi1 yuan2 shang4 gong3 jia1 cun1 da4 hu4 gong3 zeng1 rong2 de tou2 sheng1 nv3 ,", + "bi3 ta1 da4 liang3 sui4 .", + "ta1 zai4 wan2 quan2 wu2 zhi1 huang1 luan4 zhong1 , du4 guo4 le xin1 hun1 zhi1 ye4 ,", + "liu2 xia4 le yong2 yuan3 xiu1 yu2 xiang4 ren2 dao4 ji2 de ke3 xiao4 de sha3 yang4 ,", + "er2 zi4 ji3 que4 yong3 sheng1 nan2 yi3 wang4 ji4 .", + "yi4 nian2 hou4 , zhe4 ge4 nv3 ren2 si3 yu2 nan2 chan3 .", + "di4 er4 fang2 qu3 de shi4 nan2 yuan2 pang2 jia1 cun1 yin1 shi2 ren2 jia1 , pang2 xiu1 rui4 de nai3 gan1 nv3 er2 .", + "zhe4 nv3 zi3 you4 zheng4 hao3 bi3 ta1 xiao2 liang3 sui4 ,", + "mu2 yang4 jun4 xiu4 yan3 jing1 hu1 ling2 er .", + "ta1 wan2 quan2 bu4 zhi1 dao4 jia4 ren2 shi4 zen3 me hui2 shi4 ,", + "er2 ta1 ci3 shi2 yi3 an1 shu2 nan2 nv3 zhi1 jian1 suo2 you3 de yin3 mi4 .", + "ta1 kan4 zhe ta1 de xiu1 qie4 huang1 luan4 er2 xiang3 dao4 zi4 ji3 di4 yi1 ci4 de sha3 yang4 fan3 dao4 jue2 de geng4 fu4 ci4 ji .", + "dang1 ta1 hong1 suo1 zhe ba3 duo2 duo3 shan2 shan3 er2 you4 bu4 gan3 wei2 ao4 ta1 de xiao3 xi2 fu4 guo3 ru4 shen1 xia4 de shi2 hou4 ,", + "ta1 ting1 dao4 le ta1 de bu2 shi4 huan1 le4 er2 shi4 tong4 ku3 de yi4 sheng1 ku1 jiao4 .", + "dang1 ta1 pi2 bei4 de xie1 xi1 xia4 lai2 ,", + "cai2 fa1 jue2 jian1 bang3 nei4 ce4 teng2 tong4 zuan1 xin1 ,", + "ta1 ba3 ta1 yao3 lan4 le .", + "ta1 fu3 shang1 xi1 tong4 de shi2 hou4 ,", + "xin1 li3 jiu4 chao2 qi3 le dui4 zhe4 ge4 jiao1 guan4 de you2 dian3 ren4 xing4 de nai3 gan1 nv3 er de nao2 huo3 .", + "zheng4 yu4 fa1 zuo4 ,", + "ta1 que4 ban1 guo4 ta1 de jian1 bang3 an4 shi4 ta1 zai4 lai2 yi1 ci4 .", + "yi4 dang1 jing1 guo4 nan2 nv3 jian1 de di4 yi1 ci4 jiao1 huan1 ,", + "ta1 jiu4 bian4 de2 mei2 you3 jie2 zhi4 de ren4 xing4 .", + "zhe4 ge4 nv3 ren2 cong2 xia4 jiao4 ding3 zhe hong2 chou2 gai4 jin1 , jin4 ru4 bai2 jia1 men2 lou2 ,", + "dao4 tang3 jin4 yi2 ju4 bao2 ban3 guan1 cai tai2 chu1 zhe4 ge4 men2 lou2 ,", + "shi2 jian1 shang4 bu4 zu2 yi1 nian2 , shi4 hai4 lao2 bing4 si3 de .", ], #Wavenet Debug diff --git a/preprocess.py b/preprocess.py index c3a17a86..df25d56e 100644 --- a/preprocess.py +++ b/preprocess.py @@ -36,7 +36,7 @@ def norm_data(args): merge_books = (args.merge_books=='True') print('Selecting data folders..') - supported_datasets = ['LJSpeech-1.0', 'LJSpeech-1.1', 'M-AILABS'] + supported_datasets = ['LJSpeech-1.0', 'LJSpeech-1.1', 'M-AILABS', 'Biaobei'] if args.dataset not in supported_datasets: raise ValueError('dataset value entered {} does not belong to supported datasets: {}'.format( args.dataset, supported_datasets)) @@ -44,6 +44,8 @@ def norm_data(args): if args.dataset.startswith('LJSpeech'): return [os.path.join(args.base_dir, args.dataset)] + if args.dataset.startswith('Biaobei'): + return [os.path.join(args.base_dir, 'BZNSYP')] if args.dataset == 'M-AILABS': supported_languages = ['en_US', 'en_UK', 'fr_FR', 'it_IT', 'de_DE', 'es_ES', 'ru_RU', @@ -89,7 +91,7 @@ def main(): parser.add_argument('--base_dir', default='') parser.add_argument('--hparams', default='', help='Hyperparameter overrides as a comma-separated list of name=value pairs') - parser.add_argument('--dataset', default='LJSpeech-1.1') + parser.add_argument('--dataset', default='Biaobei') parser.add_argument('--language', default='en_US') parser.add_argument('--voice', default='female') parser.add_argument('--reader', default='mary_ann') diff --git a/requirements.txt b/requirements.txt index 7bc12670..72300c9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,5 +9,5 @@ tqdm==4.11.2 Unidecode==0.4.20 pyaudio==0.2.11 sounddevice==0.3.10 -lws +lws==1.2.0 keras \ No newline at end of file diff --git a/sentences.txt b/sentences.txt index b842e7c9..87d26f77 100644 --- a/sentences.txt +++ b/sentences.txt @@ -1,17 +1,7 @@ -Scientists at the CERN laboratory say they have discovered a new particle. -There's a way to measure the acute emotional intelligence that has never gone out of style. -President Trump met with other leaders at the Group of 20 conference. -The Senate's bill to repeal and replace the Affordable Care Act is now imperiled. -Generative adversarial network or variational auto-encoder. -Basilar membrane and otolaryngology are not auto-correlations. -He has read the whole thing. -He reads books. -He thought it was time to present the present. -Thisss isrealy awhsome. -Punctuation sensitivity, is working. -Punctuation sensitivity is working. -Peter Piper picked a peck of pickled peppers. How many pickled peppers did Peter Piper pick? -She sells sea-shells on the sea-shore. The shells she sells are sea-shells I'm sure. -Tajima Airport serves Toyooka. -On offering to help the blind man, the man who then stole his car, had not, at that precise moment, had any evil intention, quite the contrary, what he did was nothing more than obey those feelings of generosity and altruism which, as everyone knows, are the two best traits of human nature and to be found in much more hardened criminals than this one, a simple car-thief without any hope of advancing in his profession, exploited by the real owners this enterprise, for it is they who take advantage of the needs of the poor. -Thank you so much for your support! \ No newline at end of file +ni3 men5 you3 yi1 ge4 hao3 +quan2 shi4 jie4 pao3 dao4 shen2 me5 di4 fang1 +ni3 men5 bi3 qi2 ta1 de5 xi1 fang1 ji4 zhe3 pao3 de5 hai2 kuai4 +dan4 shi4 wen4 lai2 wen4 qu4 de5 wen4 ti2 dou1 tu1 sen1 po4 +sang1 tai4 na2 yi4 fu5 +gou3 li4 guo2 jia1 sheng1 si3 yi3 +qi3 yin1 huo4 fu2 bi4 qu1 zhi1 \ No newline at end of file diff --git a/tacotron/train.py b/tacotron/train.py index 2d0369d9..54902081 100644 --- a/tacotron/train.py +++ b/tacotron/train.py @@ -180,7 +180,7 @@ def train(log_dir, args, hparams): step = 0 time_window = ValueWindow(100) loss_window = ValueWindow(100) - saver = tf.train.Saver(max_to_keep=20) + saver = tf.train.Saver(max_to_keep=5) log('Tacotron training set to a maximum of {} steps'.format(args.tacotron_train_steps)) diff --git a/tacotron/utils/symbols.py b/tacotron/utils/symbols.py index c5c8f37e..7b7fae07 100644 --- a/tacotron/utils/symbols.py +++ b/tacotron/utils/symbols.py @@ -4,14 +4,10 @@ The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' -from . import cmudict _pad = '_' _eos = '~' -_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!\'\"(),-.:;? ' - -# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): -#_arpabet = ['@' + s for s in cmudict.valid_symbols] +_characters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz1234567890@!\'(),-.:;? ' # Export all symbols: -symbols = [_pad, _eos] + list(_characters) #+ _arpabet +symbols = [_pad, _eos] + list(_characters) diff --git a/wavenet_vocoder/models/modules.py b/wavenet_vocoder/models/modules.py index 7dcc4ef2..73dea354 100644 --- a/wavenet_vocoder/models/modules.py +++ b/wavenet_vocoder/models/modules.py @@ -89,7 +89,7 @@ def __init__(self, layer, init=False, init_scale=1., name=None, **kwargs): self.use_bias = layer.use_bias super(WeightNorm, self).__init__(layer, name=name, **kwargs) - self._track_checkpointable(layer, name='layer') + #self._track_checkpointable(layer, name='layer') def set_mode(self, is_training): self.layer.set_mode(is_training) @@ -227,7 +227,7 @@ def __init__(self, filters, layer = WeightNorm(layer, weight_normalization_init, weight_normalization_init_scale) super(CausalConv1D, self).__init__(layer, name=name, **kwargs) - self._track_checkpointable(layer, name='layer') + #self._track_checkpointable(layer, name='layer') self.kw = kernel_size self.dilation_rate = self.layer.dilation_rate self.scope = 'CausalConv1D' if name is None else name diff --git a/wavenet_vocoder/train.py b/wavenet_vocoder/train.py index 66aa4b5b..9bd22bf6 100644 --- a/wavenet_vocoder/train.py +++ b/wavenet_vocoder/train.py @@ -80,7 +80,7 @@ def create_shadow_saver(model, global_step=None): variables += [global_step] shadow_dict = dict(zip(shadow_variables, variables)) #dict(zip(keys, values)) -> {key1: value1, key2: value2, ...} - return tf.train.Saver(shadow_dict, max_to_keep=20) + return tf.train.Saver(shadow_dict, max_to_keep=5) def load_averaged_model(sess, sh_saver, checkpoint_path): sh_saver.restore(sess, checkpoint_path)