diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py
index 72d047efa01..496957b2b5d 100644
--- a/optimum/exporters/onnx/model_configs.py
+++ b/optimum/exporters/onnx/model_configs.py
@@ -1842,6 +1842,25 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
             inputs_or_outputs[f"{name}.{i}.encoder.value"] = {2: "encoder_sequence_length_out"}
 
 
+class VitsOnnxConfig(TextEncoderOnnxConfig):
+    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
+    ATOL_FOR_VALIDATION = 1e-4
+
+    @property
+    def inputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "input_ids": {0: "text_batch_size", 1: "sequence_length"},
+            "attention_mask": {0: "text_batch_size", 1: "sequence_length"},
+        }
+
+    @property
+    def outputs(self) -> Dict[str, Dict[int, str]]:
+        return {
+            "waveform": {0: "text_batch_size", 1: "n_samples"},
+            "spectrogram": {0: "text_batch_size", 2: "num_bins"},
+        }
+
+
 class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
     def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
         shape = [self.batch_size, self.sequence_length, self.normalized_config.input_features_per_channel]
diff --git a/optimum/exporters/tasks.py b/optimum/exporters/tasks.py
index 4ec641d6c14..efa782353b4 100644
--- a/optimum/exporters/tasks.py
+++ b/optimum/exporters/tasks.py
@@ -1044,6 +1044,10 @@ class TasksManager:
         "vit": supported_tasks_mapping(
             "feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
         ),
+        "vits": supported_tasks_mapping(
+            "text-to-audio",
+            onnx="VitsOnnxConfig",
+        ),
         "wavlm": supported_tasks_mapping(
             "feature-extraction",
             "automatic-speech-recognition",
diff --git a/tests/exporters/exporters_utils.py b/tests/exporters/exporters_utils.py
index bc1d8a4a289..ab0b8488fb8 100644
--- a/tests/exporters/exporters_utils.py
+++ b/tests/exporters/exporters_utils.py
@@ -149,6 +149,7 @@
     "t5": "hf-internal-testing/tiny-random-t5",
     "table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
     "vit": "hf-internal-testing/tiny-random-vit",
+    "vits": "echarlaix/tiny-random-vits",
     "yolos": "hf-internal-testing/tiny-random-YolosModel",
     "whisper": "openai/whisper-tiny.en",  # hf-internal-testing ones are broken
     "hubert": "hf-internal-testing/tiny-random-HubertModel",
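
A quick way to exercise the new config end to end, as a minimal sketch: this assumes the `main_export` helper exposed by `optimum.exporters.onnx` and uses the tiny test checkpoint registered in the diff above; the output directory name is arbitrary.

from optimum.exporters.onnx import main_export

# Export the tiny VITS checkpoint with the newly registered "text-to-audio" task.
# The exported model should expose the "waveform" and "spectrogram" outputs
# declared in VitsOnnxConfig, validated against the PyTorch model with atol=1e-4.
main_export(
    model_name_or_path="echarlaix/tiny-random-vits",
    output="vits_onnx",
    task="text-to-audio",
)

The equivalent CLI invocation (assuming the standard `optimum-cli export onnx` entry point) would be:

optimum-cli export onnx --model echarlaix/tiny-random-vits --task text-to-audio vits_onnx/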