# train-glow.py — Glow-TTS training script (forked from coqui-ai/TTS).
import os
# Trainer: Where the ✨️ happens.
# TrainerArgs: Defines the set of arguments of the Trainer.
from trainer import Trainer, TrainerArgs
# GlowTTSConfig: all model related values for training, validating and testing.
from TTS.tts.configs.glow_tts_config import GlowTTSConfig
# BaseDatasetConfig: defines name, formatter and path of the dataset.
from TTS.tts.configs.shared_configs import BaseDatasetConfig
from TTS.tts.datasets import load_tts_samples
from TTS.tts.models.glow_tts import GlowTTS
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.utils.audio import AudioProcessor
from TTS.tts.configs.shared_configs import CharactersConfig
# Training output folder, passed to both GlowTTSConfig and Trainer.
# Defaults to "/content/drive/MyDrive/output" (presumably a mounted Google Drive
# in Colab); set the GLOW_TTS_OUTPUT_PATH environment variable to override it.
output_path = os.environ.get("GLOW_TTS_OUTPUT_PATH", "/content/drive/MyDrive/output")
def formatter(root_path, meta_file, **kwargs):  # pylint: disable=unused-argument
    """Parse a pipe-separated metadata file into a list of TTS sample dicts.

    Each non-blank line of ``root_path/meta_file`` must look like
    ``<wav file name>|<transcript>[|...]``; any columns after the transcript
    are ignored.

    Args:
        root_path: Dataset root directory. Audio files are resolved under
            ``root_path/wavs``.
        meta_file: Metadata file name, relative to ``root_path``.
        **kwargs: Ignored; accepted so the function fits the formatter call
            signature used by ``load_tts_samples``.

    Returns:
        A list of dicts with ``"text"``, ``"audio_file"`` and
        ``"speaker_name"`` keys. Every sample is attributed to the single
        speaker ``"maledataset1"``.
    """
    txt_file = os.path.join(root_path, meta_file)
    speaker_name = "maledataset1"
    items = []
    # "utf-8-sig" reads plain UTF-8 unchanged but also drops a leading BOM,
    # which would otherwise be glued onto the first wav file name.
    with open(txt_file, "r", encoding="utf-8-sig") as ttf:
        for line in ttf:
            # Remove the line terminator so it cannot leak into the transcript
            # when the transcript is the last column.
            line = line.rstrip("\r\n")
            if not line.strip():
                continue  # tolerate blank lines, e.g. a trailing empty line
            cols = line.split("|")
            wav_file = os.path.join(root_path, "wavs", cols[0])
            text = cols[1]
            items.append({"text": text, "audio_file": wav_file, "speaker_name": speaker_name})
    return items
# DEFINE DATASET CONFIG
# Set the "commonvoice-be" data as our target dataset and define its path.
# Samples are parsed by the custom `formatter` above, which is passed to `load_tts_samples` below.
dataset_config = BaseDatasetConfig(
    path="/content/sample_data/be",
    name="commonvoice-be",
    # Presumably resolved relative to `path` when handed to the formatter — confirm.
    meta_file_train="metadata.csv",
)
# Grapheme-level vocabulary: the 32 lowercase Belarusian letters (written as
# escapes, 8 per line, in the original order) followed by an ASCII apostrophe.
# NOTE(review): the test sentences in the training config spell the apostrophe
# as U+2019 (’), which is not in this set; presumably `belarusian_cleaners`
# normalizes it — confirm.
characters = CharactersConfig(
    characters_class="TTS.tts.utils.text.characters.Graphemes",
    characters=(
        "\u0430\u0431\u0432\u0433\u0434\u0435\u0451\u0436"  # а б в г д е ё ж
        "\u0437\u0456\u0439\u043a\u043b\u043c\u043d\u043e"  # з і й к л м н о
        "\u043f\u0440\u0441\u0442\u0443\u045e\u0444\u0445"  # п р с т у ў ф х
        "\u0446\u0447\u0448\u044b\u044c\u044d\u044e\u044f"  # ц ч ш ы ь э ю я
        "'"
    ),
    punctuations="!,.? —",
    # Special-token symbols.
    pad="_",
    eos="~",
    bos="^",
    blank="@",
)
# INITIALIZE THE TRAINING CONFIGURATION
# Build the Glow-TTS configuration. Every config class inherits BaseTTSConfig;
# the keyword arguments below are grouped by concern.
config = GlowTTSConfig(
    # Run bookkeeping and logging.
    output_path=output_path,
    epochs=1000,
    run_eval=True,
    test_delay_epochs=-1,
    print_step=15,
    print_eval=True,
    # Batch sizes and data-loader workers.
    batch_size=16,
    eval_batch_size=8,
    num_loader_workers=2,
    num_eval_loader_workers=2,
    # Numeric precision.
    # NOTE(review): mixed precision is enabled while the grad scaler is
    # disabled; confirm fp16 training is stable with this combination.
    mixed_precision=True,
    use_grad_scaler=False,
    # Checkpointing.
    save_step=5000,
    save_checkpoints=True,
    save_all_best=False,
    save_best_after=5000,
    # Dataset, text front-end and augmentation.
    datasets=[dataset_config],
    characters=characters,
    text_cleaner="belarusian_cleaners",
    use_phonemes=False,  # graphemes only (see characters_class above)
    add_blank=False,
    enable_eos_bos_chars=True,
    use_noise_augment=True,
    # Belarusian test sentences.
    test_sentences=[
        "Тэставы сказ.",
        "У рудога вераб’я ў сховішчы пад фатэлем ляжаць нейкія гаючыя зёлкі",
        "Я жорстка заб’ю проста ў сэрца гэты расквечаны профіль, што ходзіць ля маёй хаты",
    ],
    # Audio / feature-extraction parameters (same keys, values and order).
    audio=dict(
        fft_size=1024,
        win_length=1024,
        hop_length=256,
        frame_shift_ms=None,
        frame_length_ms=None,
        stft_pad_mode="reflect",
        sample_rate=16000,
        resample=False,
        preemphasis=0.0,
        ref_level_db=20,
        do_sound_norm=True,
        log_func="np.log10",
        do_trim_silence=True,
        trim_db=45,
        do_rms_norm=False,
        db_level=None,
        power=1.5,
        griffin_lim_iters=60,
        num_mels=80,
        mel_fmin=50,
        mel_fmax=8000,  # equals the Nyquist frequency for sample_rate=16000
        spec_gain=20,
        do_amp_to_db_linear=True,
        do_amp_to_db_mel=True,
        pitch_fmax=640.0,
        pitch_fmin=0.0,
        signal_norm=True,
        min_level_db=-100,
        symmetric_norm=True,
        max_norm=4.0,
        clip_norm=True,
        stats_path=None,
    ),
)
# config_characters = BaseCharacters(**config.characters)
if __name__ == "__main__":
    # Training is launched only when this file is executed as a script. With
    # num_loader_workers > 0, data-loader workers started via the "spawn" method
    # (the default on Windows and macOS) re-import this module; the guard keeps
    # them from re-running sample loading, model/Trainer construction and fit().
    # Names assigned here remain module-level globals when run as a script.

    # INITIALIZE THE AUDIO PROCESSOR
    # Audio processor is used for feature extraction and audio I/O.
    # It mainly serves the dataloader and the training loggers.
    ap = AudioProcessor.init_from_config(config)

    # INITIALIZE THE TOKENIZER
    # Tokenizer is used to convert text to sequences of token IDs.
    # If characters are not defined in the config, default characters are passed to the config.
    # init_from_config also returns the (possibly updated) config, so rebind it.
    tokenizer, config = TTSTokenizer.init_from_config(config)

    # LOAD DATA SAMPLES
    # Samples are built by the custom `formatter` defined above: dicts with
    # "text", "audio_file" and "speaker_name" keys (load_tts_samples may add more).
    # Check `TTS.tts.datasets.load_tts_samples` for more details.
    train_samples, eval_samples = load_tts_samples(dataset_config, eval_split=True, formatter=formatter)

    # INITIALIZE THE MODEL
    # The model takes the config, the audio processor, the tokenizer and a
    # speaker manager. The speaker manager is used by multi-speaker models, so
    # it is None for this single-speaker setup.
    model = GlowTTS(config, ap, tokenizer, speaker_manager=None)

    # INITIALIZE THE TRAINER
    # Trainer provides a generic API to train all the 🐸TTS models with all its perks like mixed-precision training,
    # distributed training, etc.
    trainer = Trainer(
        TrainerArgs(), config, output_path, model=model, train_samples=train_samples, eval_samples=eval_samples
    )

    # AND... 3,2,1... 🚀
    trainer.fit()