Skip to content

Commit

Permalink
Add onnx export for VITS architecture (#1607)
Browse files Browse the repository at this point in the history
* add onnx export for VITS architecture

* fix style

* set task
  • Loading branch information
echarlaix authored May 2, 2024
1 parent e3fd277 commit 189dd25
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 0 deletions.
19 changes: 19 additions & 0 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1842,6 +1842,25 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire
inputs_or_outputs[f"{name}.{i}.encoder.value"] = {2: "encoder_sequence_length_out"}


class VitsOnnxConfig(TextEncoderOnnxConfig):
    """ONNX export configuration for the VITS architecture.

    Declares two text inputs (``input_ids`` and ``attention_mask``) and two
    outputs (``waveform`` and ``spectrogram``), each mapped to the axis
    indices that carry a symbolic name.
    """

    NORMALIZED_CONFIG_CLASS = NormalizedTextConfig
    # NOTE(review): presumably the absolute tolerance applied when comparing
    # exported-model outputs against the reference model -- confirm against
    # the base config class.
    ATOL_FOR_VALIDATION = 1e-4

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        """Return the named axes of each model input.

        Both inputs share the same layout: axis 0 is ``text_batch_size`` and
        axis 1 is ``sequence_length``. A separate axis mapping is created for
        every input name so callers never receive aliased dictionaries.
        """
        text_input_names = ("input_ids", "attention_mask")
        return {
            input_name: {0: "text_batch_size", 1: "sequence_length"}
            for input_name in text_input_names
        }

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        """Return the named axes of each model output.

        ``waveform`` names axes 0 and 1; ``spectrogram`` names axes 0 and 2
        (its axis 1 is left unnamed).
        """
        axes_by_output: Dict[str, Dict[int, str]] = {}
        axes_by_output["waveform"] = {0: "text_batch_size", 1: "n_samples"}
        axes_by_output["spectrogram"] = {0: "text_batch_size", 2: "num_bins"}
        return axes_by_output


class Speech2TextDummyAudioInputGenerator(DummyAudioInputGenerator):
def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
shape = [self.batch_size, self.sequence_length, self.normalized_config.input_features_per_channel]
Expand Down
4 changes: 4 additions & 0 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1044,6 +1044,10 @@ class TasksManager:
"vit": supported_tasks_mapping(
"feature-extraction", "image-classification", "masked-im", onnx="ViTOnnxConfig"
),
"vits": supported_tasks_mapping(
"text-to-audio",
onnx="VitsOnnxConfig",
),
"wavlm": supported_tasks_mapping(
"feature-extraction",
"automatic-speech-recognition",
Expand Down
1 change: 1 addition & 0 deletions tests/exporters/exporters_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@
"t5": "hf-internal-testing/tiny-random-t5",
"table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
"vit": "hf-internal-testing/tiny-random-vit",
"vits": "echarlaix/tiny-random-vits",
"yolos": "hf-internal-testing/tiny-random-YolosModel",
"whisper": "openai/whisper-tiny.en", # hf-internal-testing ones are broken
"hubert": "hf-internal-testing/tiny-random-HubertModel",
Expand Down

0 comments on commit 189dd25

Please sign in to comment.