[TTS] Restore_buffer bug fix and update NeMo checkpoint URL (#4041)
* restore_buffer bug fix and update NeMo checkpoint URL

Signed-off-by: Subhankar Ghosh <[email protected]>

* skip test

Signed-off-by: ericharper <[email protected]>

* use tacotron2 until new fastpitch model is on ngc

Signed-off-by: ericharper <[email protected]>

Co-authored-by: ericharper <[email protected]>
subhankar-ghosh and ericharper authored Apr 22, 2022
1 parent fe3439f commit 2ef2892
Showing 5 changed files with 197 additions and 195 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/tts/models/fastpitch.py
@@ -479,7 +479,7 @@ def list_available_models(cls) -> 'List[PretrainedModelInfo]':
list_of_models = []
model = PretrainedModelInfo(
pretrained_model_name="tts_en_fastpitch",
- location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.4.0/files/tts_en_fastpitch_align.nemo",
+ location="https://api.ngc.nvidia.com/v2/models/nvidia/nemo/tts_en_fastpitch/versions/1.8.1/files/tts_en_fastpitch_align.nemo",
description="This model is trained on LJSpeech sampled at 22050Hz with and can be used to generate female English voices with an American accent.",
class_=cls,
)
6 changes: 3 additions & 3 deletions nemo/collections/tts/modules/fastpitch.py
@@ -158,8 +158,8 @@ def __init__(
else:
self.speaker_emb = None

- self.register_buffer('max_token_duration', torch.tensor(max_token_duration))
- self.register_buffer('min_token_duration', torch.tensor(0.0))
+ self.register_buffer('max_token_duration', torch.tensor(max_token_duration), persistent=False)
+ self.register_buffer('min_token_duration', torch.tensor(0.0), persistent=False)

self.pitch_emb = torch.nn.Conv1d(
1,
@@ -171,7 +171,7 @@ def __init__(
# Store values precomputed from training data for convenience
self.register_buffer('pitch_mean', torch.zeros(1))
self.register_buffer('pitch_std', torch.zeros(1))
- self.register_buffer('zero_emb', torch.zeros(1))
+ self.register_buffer('zero_emb', torch.zeros(1), persistent=False)

self.proj = torch.nn.Linear(self.decoder.d_model, n_mel_channels, bias=True)
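
Note on the `persistent=False` change above: a minimal sketch of what the flag does in plain PyTorch, which is presumably why it fixes restoring older checkpoints. `TinyModule` and the value `75.0` are made up for illustration and are not taken from NeMo; only the `register_buffer` and `load_state_dict` calls are real PyTorch API.

```python
import torch


class TinyModule(torch.nn.Module):
    """Hypothetical module, used only to illustrate buffer persistence."""

    def __init__(self, max_token_duration: float = 75.0):  # arbitrary example value
        super().__init__()
        # Persistent buffer (the default): written to and expected from state_dict.
        self.register_buffer('pitch_mean', torch.zeros(1))
        # Non-persistent buffer: still follows .to()/.cuda(), but is left out of
        # state_dict, so a checkpoint does not need to contain it.
        self.register_buffer(
            'max_token_duration', torch.tensor(max_token_duration), persistent=False
        )


module = TinyModule()
saved_keys = sorted(module.state_dict().keys())
print(saved_keys)

# A checkpoint without 'max_token_duration' still loads under strict=True,
# because the non-persistent buffer is not part of the expected keys.
old_checkpoint = {'pitch_mean': torch.ones(1)}
result = module.load_state_dict(old_checkpoint, strict=True)
print(result.missing_keys, result.unexpected_keys)
```

With the default `persistent=True`, the same `load_state_dict` call would raise on the missing key, so a checkpoint saved before the buffer existed could not be restored strictly.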

1 change: 1 addition & 0 deletions tests/collections/tts/test_tts_exportables.py
@@ -39,6 +39,7 @@ def hifigan_model():


class TestExportable:
+ @pytest.mark.pleasefixme
@pytest.mark.run_only_on('GPU')
@pytest.mark.unit
def test_FastPitchModel_export_to_onnx(self, fastpitch_model):
169 changes: 86 additions & 83 deletions tutorials/AudioTranslationSample.ipynb
@@ -2,6 +2,9 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "RYGnI-EZp_nK"
},
"source": [
"# Getting Started: Sample Conversational AI application\n",
"This notebook shows how to use NVIDIA NeMo (https://github.com/NVIDIA/NeMo) to construct a toy demo which translate Mandarin audio file into English one.\n",
@@ -12,48 +15,49 @@
"* Transcribe audio with (Mandarin) speech recognition model.\n",
"* Translate text with machine translation model.\n",
"* Generate audio with text-to-speech models."
],
"metadata": {
"id": "RYGnI-EZp_nK"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "V72HXYuQ_p9a"
},
"source": [
"## Installation\n",
"NeMo can be installed via simple pip command.\n",
"This will take about 4 minutes.\n",
"\n",
"(The installation method below should work inside your new Conda environment or in an NVIDIA docker container.)"
],
"metadata": {
"id": "V72HXYuQ_p9a"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "efDmTWf1_iYK"
},
"outputs": [],
"source": [
"BRANCH = 'r1.8.1'\n",
"!python -m pip install git+https://github.com/NVIDIA/NeMo.git@$BRANCH#egg=nemo_toolkit[all]"
],
"outputs": [],
"metadata": {
"id": "efDmTWf1_iYK"
}
]
},
{
"cell_type": "markdown",
"source": [
"## Import all necessary packages"
],
"metadata": {
"id": "EyJ5HiiPrPKA"
}
},
"source": [
"## Import all necessary packages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "tdUqxeUEA8nw"
},
"outputs": [],
"source": [
"# Import NeMo and it's ASR, NLP and TTS collections\n",
"import nemo\n",
@@ -65,14 +69,13 @@
"import nemo.collections.tts as nemo_tts\n",
"# We'll use this to listen to audio\n",
"import IPython"
],
"outputs": [],
"metadata": {
"id": "tdUqxeUEA8nw"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bt2EZyU3A1aq"
},
"source": [
"## Instantiate pre-trained NeMo models\n",
"\n",
@@ -81,56 +84,60 @@
"* ``list_available_models()`` - it will list all models currently available on NGC and their names.\n",
"\n",
"* ``from_pretrained(...)`` API downloads and initialized model directly from the NGC using model name.\n"
],
"metadata": {
"id": "bt2EZyU3A1aq"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "YNNHs5Xjr8ox",
"scrolled": true
},
"outputs": [],
"source": [
"# Here is an example of all CTC-based models:\n",
"nemo_asr.models.EncDecCTCModel.list_available_models()\n",
"# More ASR Models are available - see: nemo_asr.models.ASRModel.list_available_models()"
],
"outputs": [],
"metadata": {
"id": "YNNHs5Xjr8ox",
"scrolled": true
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "1h9nhICjA5Dk",
"scrolled": true
},
"outputs": [],
"source": [
"# Speech Recognition model - Citrinet initially trained on Multilingual LibriSpeech English corpus, and fine-tuned on the open source Aishell-2\n",
"asr_model = nemo_asr.models.EncDecCTCModel.from_pretrained(model_name=\"stt_zh_citrinet_1024_gamma_0_25\").cuda()\n",
"\n",
"# Neural Machine Translation model\n",
"nmt_model = nemo_nlp.models.MTEncDecModel.from_pretrained(model_name='nmt_zh_en_transformer6x6').cuda()\n",
"\n",
"# Spectrogram generator which takes text as an input and produces spectrogram\n",
- "spectrogram_generator = nemo_tts.models.FastPitchModel.from_pretrained(model_name=\"tts_en_fastpitch\").cuda()\n",
+ "spectrogram_generator = nemo_tts.models.Tacotron2Model.from_pretrained(model_name=\"tts_en_tacotron2\").cuda()\n",
"\n",
"# Vocoder model which takes spectrogram and produces actual audio\n",
"vocoder = nemo_tts.models.HifiGanModel.from_pretrained(model_name=\"tts_hifigan\").cuda()"
],
"outputs": [],
"metadata": {
"id": "1h9nhICjA5Dk",
"scrolled": true
}
]
},
{
"cell_type": "markdown",
"source": [
"## Get an audio sample in Mandarin"
],
"metadata": {
"id": "KPota-JtsqSY"
}
},
"source": [
"## Get an audio sample in Mandarin"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7cGCEKkcLr52"
},
"outputs": [],
"source": [
"# Download audio sample which we'll try\n",
"# This is a sample from MCV 6.1 Dev dataset - the model hasn't seen it before\n",
@@ -139,71 +146,71 @@
"!wget 'https://nemo-public.s3.us-east-2.amazonaws.com/zh-samples/common_voice_zh-CN_21347786.mp3'\n",
"# To listen it, click on the play button below\n",
"IPython.display.Audio(audio_sample)"
],
"outputs": [],
"metadata": {
"id": "7cGCEKkcLr52"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BaCdNJhhtBfM"
},
"source": [
"## Transcribe audio file\n",
"We will use speech recognition model to convert audio into text.\n"
],
"metadata": {
"id": "BaCdNJhhtBfM"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "KTA7jM6sL6yC"
},
"outputs": [],
"source": [
"transcribed_text = asr_model.transcribe([audio_sample])\n",
"print(transcribed_text)"
],
"outputs": [],
"metadata": {
"id": "KTA7jM6sL6yC"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BjYb2TMtttCc"
},
"source": [
"## Translate Chinese text into English\n",
"NeMo's NMT models have a handy ``.translate()`` method."
],
"metadata": {
"id": "BjYb2TMtttCc"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "kQTdE4b9Nm9O"
},
"outputs": [],
"source": [
"english_text = nmt_model.translate(transcribed_text)\n",
"print(english_text)"
],
"outputs": [],
"metadata": {
"id": "kQTdE4b9Nm9O"
}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "9Rppc59Ut7uy"
},
"source": [
"## Generate English audio from text\n",
"Speech generation from text typically has two steps:\n",
"* Generate spectrogram from the text. In this example we will use FastPitch model for this.\n",
"* Generate actual audio from the spectrogram. In this example we will use HifiGan model for this.\n"
],
"metadata": {
"id": "9Rppc59Ut7uy"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "wpMYfufgNt15"
},
"outputs": [],
"source": [
"# A helper function which combines FastPitch and HifiGan to go directly from \n",
"# text to audio\n",
@@ -212,24 +219,23 @@
" spectrogram = spectrogram_generator.generate_spectrogram(tokens=parsed)\n",
" audio = vocoder.convert_spectrogram_to_audio(spec=spectrogram)\n",
" return audio.to('cpu').detach().numpy()"
],
"outputs": [],
"metadata": {
"id": "wpMYfufgNt15"
}
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Listen to generated audio in English\n",
"IPython.display.Audio(text_to_audio(english_text[0]), rate=22050)"
],
"outputs": [],
"metadata": {}
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "LiQ_GQpcBYUs"
},
"source": [
"## Next steps\n",
"A demo like this is great for prototyping and experimentation. However, for real production deployment, you would want to use a service like [NVIDIA Riva](https://developer.nvidia.com/riva).\n",
@@ -244,10 +250,7 @@
"\n",
"\n",
"You can find scripts for training and fine-tuning ASR, NLP and TTS models [here](https://github.com/NVIDIA/NeMo/tree/main/examples). "
],
"metadata": {
"id": "LiQ_GQpcBYUs"
}
]
}
],
"metadata": {
@@ -277,4 +280,4 @@
},
"nbformat": 4,
"nbformat_minor": 1
}
}