SpeechT5 ONNX support #1404
Hi @xenova, a long-awaited one =) This PR is still missing tests, documentation and KV cache support, but it is already in a good state. I'll finish it next week. For now I have only implemented the text-to-speech task, following transformers. Working version: Also left to do: align the
Wow this is amazing - thanks so much @fxmarty! I've uploaded my model files here. I'll test it in transformers.js, and I'll update those files when the other options are available. I don't suppose you have any Python code which I can use for testing, something similar to this?
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForSeq2SeqLM
session = ORTModelForSeq2SeqLM.from_pretrained('Xenova/ipt-350m', subfolder='onnx')
tokenizer = AutoTokenizer.from_pretrained('Xenova/ipt-350m')
generator_ort = pipeline(
    task="text-generation",
    model=session,
    tokenizer=tokenizer,
)
generator_ort('La nostra azienda')
# [{'generated_text': "La nostra azienda è specializzata nella vendita di prodotti per l'igiene orale e per la salute."}]
Or will the The speecht5 docs have a nice example here too. Also, I had to downgrade to onnxruntime==1.15.1. 1.16.0 gives this error:
I assume this is because I had onnx<1.14 installed, but just posting here in case.
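When reporting a version-related failure like the one above, it can help to list the installed versions of the packages involved. The following is only a minimal sketch using the standard library; the helper name `installed_version` and the list of packages checked are illustrative assumptions, not part of this PR.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package: str):
    """Return the installed version string of `package`, or None if it is absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


# Packages mentioned in this thread; adjust the list as needed (assumption).
for name in ("onnx", "onnxruntime", "optimum", "transformers"):
    print(f"{name}: {installed_version(name) or 'not installed'}")
```

Pasting this output next to the error makes it easier to tell whether an old onnx install or the onnxruntime upgrade is the likely cause.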
I'll wrap up this PR and add a Python example :)
@xenova something like this (not optimized at all). Does that work for you?
import onnxruntime as ort
import numpy as np
import soundfile as sf
from transformers import SpeechT5Processor
encoder_path = "/path/to/encoder_model.onnx"
decoder_path = "/path/to/decoder_model_merged.onnx"
postnet_and_vocoder_path = "/path/to/decoder_postnet_and_vocoder.onnx"
encoder = ort.InferenceSession(encoder_path, providers=["CPUExecutionProvider"])
decoder = ort.InferenceSession(decoder_path, providers=["CPUExecutionProvider"])
postnet_and_vocoder = ort.InferenceSession(postnet_and_vocoder_path, providers=["CPUExecutionProvider"])
def add_fake_pkv(inputs):
    shape = (1, 12, 0, 64)
    for i in range(6):
        inputs[f"past_key_values.{i}.encoder.key"] = np.zeros(shape).astype(np.float32)
        inputs[f"past_key_values.{i}.encoder.value"] = np.zeros(shape).astype(np.float32)
        inputs[f"past_key_values.{i}.decoder.key"] = np.zeros(shape).astype(np.float32)
        inputs[f"past_key_values.{i}.decoder.value"] = np.zeros(shape).astype(np.float32)
    return inputs
def add_real_pkv(inputs, previous_outputs, cross_attention_pkv):
    for i in range(6):
        inputs[f"past_key_values.{i}.encoder.key"] = cross_attention_pkv[f"present.{i}.encoder.key"]
        inputs[f"past_key_values.{i}.encoder.value"] = cross_attention_pkv[f"present.{i}.encoder.value"]
        inputs[f"past_key_values.{i}.decoder.key"] = previous_outputs[f"present.{i}.decoder.key"]
        inputs[f"past_key_values.{i}.decoder.value"] = previous_outputs[f"present.{i}.decoder.value"]
    return inputs
processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
inputs = processor(text="Hello, my dog is cute", return_tensors="np")
inp = {
    "input_ids": inputs["input_ids"]
}
outputs = encoder.run(None, inp)
outputs = {output_key.name: outputs[idx] for idx, output_key in enumerate(encoder.get_outputs())}
encoder_last_hidden_state = outputs["encoder_outputs"]
encoder_attention_mask = outputs["encoder_attention_mask"]
minlenratio = 0.0
maxlenratio = 20.0
reduction_factor = 2
threshold = 0.5
num_mel_bins = 80
maxlen = int(encoder_last_hidden_state.shape[1] * maxlenratio / reduction_factor)
minlen = int(encoder_last_hidden_state.shape[1] * minlenratio / reduction_factor)
spectrogram = []
cross_attentions = []
past_key_values = None
idx = 0
cross_attention_pkv = None
use_cache_branch = False
speaker_embeddings = np.zeros((1, 512)).astype(np.float32)
while True:
    idx += 1
    decoder_inputs = {}
    decoder_inputs["use_cache_branch"] = np.array([use_cache_branch])
    decoder_inputs["encoder_attention_mask"] = encoder_attention_mask
    decoder_inputs["speaker_embeddings"] = speaker_embeddings
    if not use_cache_branch:
        decoder_inputs = add_fake_pkv(decoder_inputs)
        decoder_inputs["output_sequence"] = np.zeros((1, 1, num_mel_bins)).astype(np.float32)
        use_cache_branch = True
        decoder_inputs["encoder_hidden_states"] = encoder_last_hidden_state
    else:
        decoder_inputs = add_real_pkv(decoder_inputs, decoder_outputs, cross_attention_pkv)
        decoder_inputs["output_sequence"] = decoder_outputs["output_sequence_out"]
        decoder_inputs["encoder_hidden_states"] = np.zeros((1, 0, 768)).astype(np.float32)  # useless when cross-attention KV has already been computed
    decoder_outputs = decoder.run(None, decoder_inputs)
    # Map output names to arrays; `i` avoids reusing the step counter name `idx`.
    decoder_outputs = {output_key.name: decoder_outputs[i] for i, output_key in enumerate(decoder.get_outputs())}
    if idx == 1:  # i.e. the first step, which ran without the cache branch
        cross_attention_pkv = {key: val for key, val in decoder_outputs.items() if ("encoder" in key and "present" in key)}
    prob = decoder_outputs["prob"]
    spectrum = decoder_outputs["spectrum"]
    spectrogram.append(spectrum)
    print("prob", prob)
    # Finished when stop token or maximum length is reached.
    if idx >= minlen and (int(sum(prob >= threshold)) > 0 or idx >= maxlen):
        print("len spectrogram", len(spectrogram))
        spectrogram = np.concatenate(spectrogram)
        vocoder_output = postnet_and_vocoder.run(None, {"spectrogram": spectrogram})
        break
sf.write("speech.wav", vocoder_output[0], samplerate=16000)
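The stopping rule in the loop above can be checked in isolation, without loading any ONNX model. The sketch below restates it with NumPy only; the function names `length_bounds` and `should_stop` are hypothetical, and `np.any` is used as an assumed equivalent of the `int(sum(prob >= threshold)) > 0` test that also tolerates a 2-D `prob`.

```python
import numpy as np


def length_bounds(encoder_len, minlenratio=0.0, maxlenratio=20.0, reduction_factor=2):
    """Minimum and maximum number of decoder steps for a given encoder length."""
    minlen = int(encoder_len * minlenratio / reduction_factor)
    maxlen = int(encoder_len * maxlenratio / reduction_factor)
    return minlen, maxlen


def should_stop(prob, step, minlen, maxlen, threshold=0.5):
    """Stop once `minlen` steps have run and either a stop probability
    reaches `threshold` or `maxlen` steps have been generated."""
    if step < minlen:
        return False
    stop_predicted = bool(np.any(np.asarray(prob) >= threshold))
    return stop_predicted or step >= maxlen


minlen, maxlen = length_bounds(encoder_len=8)
print(minlen, maxlen)
print(should_stop(np.array([[0.1, 0.7]]), step=1, minlen=minlen, maxlen=maxlen))
```

With the default ratios, `maxlen` is ten times the encoder sequence length, so the loop always terminates even if the stop probability never crosses the threshold.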
Works with transformers.js! 🚀 huggingface/transformers.js#345
@echarlaix you would probably prefer to merge the PR for the decoders first? I expect some conflicts between those two.
@echarlaix WDYT?
Yes that would be great, thanks for letting me know. To me we can merge the decoder PR cc @michaelbenayoun
super cool thanks @fxmarty
optimum/exporters/onnx/base.py
Outdated
- # Attempt to merge only if the decoder was exported without/with past
- if self.use_past is True and len(models_and_onnx_configs) == 3:
+ # Attempt to merge only if the decoder was exported without/with past, and ignore seq2seq models exported with text-generation task
+ if len(onnx_files_subpaths) >= 3 and self.use_past is True or self.variant == "with-past":
why do we need to check self.variant?
I'm not sure. I'll need to double check.
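One detail relevant to the question about `self.variant`: in Python, `and` binds more tightly than `or`, so the reviewed condition groups as `(len(...) >= 3 and self.use_past is True) or self.variant == "with-past"`. The sketch below illustrates this with plain stand-in arguments; the function names are hypothetical, and which grouping the PR actually intends is not stated in this thread.

```python
from itertools import product


def merge_condition(num_files, use_past, variant):
    # Written like the reviewed line: no parentheses.
    return num_files >= 3 and use_past is True or variant == "with-past"


def merge_condition_explicit(num_files, use_past, variant):
    # How Python actually groups the expression above.
    return (num_files >= 3 and use_past is True) or variant == "with-past"


def merge_condition_alternative(num_files, use_past, variant):
    # A different grouping, shown only for contrast.
    return num_files >= 3 and (use_past is True or variant == "with-past")


# The unparenthesized form always agrees with the explicit grouping.
for n, past, variant in product((2, 3), (True, False), ("default", "with-past")):
    assert merge_condition(n, past, variant) == merge_condition_explicit(n, past, variant)

# With fewer than three files, variant == "with-past" alone makes it true.
print(merge_condition(2, False, "with-past"))
print(merge_condition_alternative(2, False, "with-past"))
```

If the file-count check is meant to apply to both cases, explicit parentheses as in the alternative form would be needed.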
Accidentally clicked review 😝 meant to just submit a comment
)
model_type = config.model_type.replace("_", "-")
if model_type not in TasksManager._SUPPORTED_MODEL_TYPE:
    custom_architecture = True
@fxmarty this line currently does nothing since it is set to False again in line 381. Do you want to have a look?
Good catch! I'll fix
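The issue raised here, a flag that is set and then unconditionally reset further down, can be reproduced in a few lines. This is only an illustrative sketch: `SUPPORTED_MODEL_TYPES` and both function names are hypothetical stand-ins, not Optimum's actual code.

```python
# Hypothetical stand-in for TasksManager._SUPPORTED_MODEL_TYPE.
SUPPORTED_MODEL_TYPES = {"speecht5", "bert"}


def is_custom_buggy(raw_model_type: str) -> bool:
    model_type = raw_model_type.replace("_", "-")
    custom_architecture = False
    if model_type not in SUPPORTED_MODEL_TYPES:
        custom_architecture = True
    # A later unconditional reassignment discards the result computed above.
    custom_architecture = False
    return custom_architecture


def is_custom_fixed(raw_model_type: str) -> bool:
    model_type = raw_model_type.replace("_", "-")
    # Compute the flag once and do not overwrite it afterwards.
    return model_type not in SUPPORTED_MODEL_TYPES


print(is_custom_buggy("my_new_model"), is_custom_fixed("my_new_model"))
```

In the buggy variant the membership check has no effect on the returned value, which matches the observation in the review comment.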
This PR adds support for SpeechT5 ONNX export.