Fine tuning embeddings with onnxruntime training: error creating artifacts #22427
riccardopinosio asked this question in Training Q&A · Unanswered
Replies: 1 comment
-
Consider not using the given ONNX loss function; instead, use a loss function implemented in PyTorch or the default one from the Transformers library. I have had problems in the past when passing a …
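For example, one way to follow that suggestion with the Siamese similarity model described in the question below is to compute the loss inside the PyTorch module itself, so that `generate_artifacts` does not need an ONNX-side loss block. This is only a sketch: the input names, the MSE loss, the weight/bias filter, and the file names are assumptions.

```python
import torch
import onnx
from onnxruntime.training import artifacts


class SiameseWithLoss(torch.nn.Module):
    """Wraps the similarity model so the exported graph's only output is the scalar loss."""

    def __init__(self, similarity_model):
        super().__init__()
        self.similarity_model = similarity_model
        self.loss_fn = torch.nn.MSELoss()  # plain PyTorch loss instead of an ONNX loss block

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, label):
        similarity = self.similarity_model(
            input_ids_1, attention_mask_1, input_ids_2, attention_mask_2
        )
        return self.loss_fn(similarity, label)


# After exporting SiameseWithLoss to ONNX, generate the artifacts with loss=None,
# so that no extra loss node is attached and the graph output itself is the loss.
onnx_model = onnx.load("siamese_with_loss.onnx")
artifacts.generate_artifacts(
    onnx_model,
    requires_grad=[init.name for init in onnx_model.graph.initializer
                   if init.name.endswith(("weight", "bias"))],
    loss=None,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
)
```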
-
Hello,
I would like to fine-tune an embedding model like all-MiniLM-L6-v2 or DistilBERT for semantic search, using onnxruntime on-device training (not large model training, as there will be no Python in the runtime) and, say, a cosine similarity metric (see picture at the bottom). The training examples will be tuples like (sentenceA, sentenceB, similarity_label).
I came up with the following simple pytorch model:
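Roughly along these lines (a sketch: it assumes mean pooling over the last hidden state and a single cosine-similarity output; the exact module may differ):

```python
import torch
from transformers import AutoModel


class SiameseSimilarity(torch.nn.Module):
    """Shared-weight encoder that scores two tokenized sentences with cosine similarity."""

    def __init__(self, model_name="sentence-transformers/all-MiniLM-L6-v2"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)

    def _embed(self, input_ids, attention_mask):
        # Mean pooling over token embeddings, ignoring padding positions.
        hidden = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        mask = attention_mask.unsqueeze(-1).to(hidden.dtype)
        return (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)

    def forward(self, input_ids_1, attention_mask_1, input_ids_2, attention_mask_2):
        emb_1 = self._embed(input_ids_1, attention_mask_1)
        emb_2 = self._embed(input_ids_2, attention_mask_2)
        return torch.nn.functional.cosine_similarity(emb_1, emb_2, dim=-1)
```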
The idea here is that the `_1` and `_2` input names correspond to the tokenizations of sentences A and B respectively. I would then fine-tune this model with the on-device training API, and afterwards the learned parameters can be copied back into the original embedding model for inference, which should be possible since the parameters of the model above ought to be the same as those of the original all-MiniLM model.
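For reference, the intended flow with the on-device training API would be roughly the following (a sketch that assumes the training artifacts have already been generated with the default file names, and that the inference output is named "similarity"):

```python
from onnxruntime.training.api import CheckpointState, Module, Optimizer

# Load the artifacts produced by artifacts.generate_artifacts (default file names assumed).
state = CheckpointState.load_checkpoint("training_artifacts/checkpoint")
module = Module(
    "training_artifacts/training_model.onnx",
    state,
    "training_artifacts/eval_model.onnx",
    device="cpu",
)
optimizer = Optimizer("training_artifacts/optimizer_model.onnx", module)

# Placeholder: an iterable of numpy tuples (input_ids_1, attention_mask_1, input_ids_2,
# attention_mask_2, label) produced by the tokenizer.
batches = []

module.train()
for input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, label in batches:
    # The first output of the training model is the loss value.
    loss = module(input_ids_1, attention_mask_1, input_ids_2, attention_mask_2, label)
    optimizer.step()
    module.lazy_reset_grad()

# Bake the learned weights back into an inference-only graph.
module.export_model_for_inferencing("finetuned_similarity.onnx", ["similarity"])
```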
I export the above model with the TorchScript-based torch.onnx.export:
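Roughly like this (a sketch: the input names, the "similarity" output name, the dummy shapes, and the opset are illustrative):

```python
import torch

model = SiameseSimilarity()
model.train()

# Dummy inputs just to trace the graph (batch of 1, sequence length 16).
dummy = {
    "input_ids_1": torch.ones(1, 16, dtype=torch.int64),
    "attention_mask_1": torch.ones(1, 16, dtype=torch.int64),
    "input_ids_2": torch.ones(1, 16, dtype=torch.int64),
    "attention_mask_2": torch.ones(1, 16, dtype=torch.int64),
}

torch.onnx.export(
    model,
    tuple(dummy.values()),
    "siamese_similarity.onnx",
    input_names=list(dummy.keys()),
    output_names=["similarity"],
    dynamic_axes={name: {0: "batch", 1: "sequence"} for name in dummy},
    training=torch.onnx.TrainingMode.TRAINING,  # keep Dropout etc. in training form
    do_constant_folding=False,                  # constant folding is typically disabled in training mode
    export_params=True,
    opset_version=17,
)
```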
The generated ONNX model seems fine, although exporting with `torch.onnx.TrainingMode.TRAINING` gives odd results: in the chunk below, the similarity does not come out as 1.0. Perhaps someone knows why the result is not 1.0 when the model is exported in training mode.
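A sketch of that kind of check (assuming the same tokenized sentence is fed to both branches, so a similarity of 1.0 would be expected):

```python
import numpy as np
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
session = ort.InferenceSession("siamese_similarity.onnx")

enc = tokenizer("the cat sat on the mat", return_tensors="np")
feeds = {
    "input_ids_1": enc["input_ids"].astype(np.int64),
    "attention_mask_1": enc["attention_mask"].astype(np.int64),
    "input_ids_2": enc["input_ids"].astype(np.int64),
    "attention_mask_2": enc["attention_mask"].astype(np.int64),
}

similarity = session.run(None, feeds)[0]
print(similarity)  # comes out noticeably below 1.0 with the TRAINING-mode export
```

One possible explanation (unverified here) is that the TRAINING-mode export keeps the Dropout nodes active, so the two encoder passes are no longer deterministic and identical inputs no longer produce identical embeddings.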
Anyway, when I try to export the model for training (i.e., generate the training artifacts), it complains.
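The artifact-generation step is roughly the following (a sketch: the built-in MSE loss, the weight/bias parameter filter, and the directory name are illustrative):

```python
import onnx
from onnxruntime.training import artifacts

onnx_model = onnx.load("siamese_similarity.onnx")

# Train the encoder's weights and biases; in practice a subset could be frozen instead.
trainable = [init.name for init in onnx_model.graph.initializer
             if init.name.endswith(("weight", "bias"))]
frozen = [init.name for init in onnx_model.graph.initializer if init.name not in trainable]

artifacts.generate_artifacts(
    onnx_model,
    requires_grad=trainable,
    frozen_params=frozen,
    loss=artifacts.LossType.MSELoss,       # built-in ONNX loss attached to the "similarity" output
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory="training_artifacts",
)
```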
It seems to be unhappy about the output node, but it's unclear to me what the issue is. I checked the graph in Netron and the output node has the dimensions I would expect.
With some more experimentation, it seems the problem comes from relying on the pretrained transformer model: it works for inference, but raises this error at artifact generation time for training, apparently because onnxruntime has trouble constructing the gradient graph.
Swapping all-MiniLM for DistilBERT (https://huggingface.co/docs/transformers/en/model_doc/distilbert) gives the same error.