v3.2.0 - ONNX and OpenVINO backends offering 2-3x speedup; Static Embeddings offering 50x-500x speedups at ~10-20% performance cost
This release introduces two new efficient computing backends for SentenceTransformer models, ONNX and OpenVINO, plus optimization and quantization of those models, allowing for speedups of up to 2x-3x; static embeddings via Model2Vec, allowing for lightning-fast models (i.e., 50x-500x speedups) at a ~10%-20% performance cost; and various small improvements and fixes.
Install this version with
# Training + Inference
pip install sentence-transformers[train]==3.2.0
# Inference only, use one of:
pip install sentence-transformers==3.2.0
pip install sentence-transformers[onnx-gpu]==3.2.0
pip install sentence-transformers[onnx]==3.2.0
pip install sentence-transformers[openvino]==3.2.0
Faster ONNX and OpenVINO Backends for SentenceTransformer (#2712)
Introducing a new backend keyword argument to the SentenceTransformer initialization, allowing values of "torch" (default), "onnx", and "openvino".
These come with new installations:
pip install sentence-transformers[onnx-gpu]
# or ONNX for CPU only:
pip install sentence-transformers[onnx]
# or
pip install sentence-transformers[openvino]
It's as simple as:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer("all-MiniLM-L6-v2", backend="onnx")
sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
If you specify a backend and your model repository or directory contains an ONNX/OpenVINO model file, it will automatically be used! And if your model repository or directory doesn't have one already, an ONNX/OpenVINO model will be automatically exported. Just remember to model.push_to_hub or model.save_pretrained into the same model repository or directory to avoid having to re-export the model every time.
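For instance, here is a minimal sketch of exporting once and saving the result so subsequent loads reuse the exported ONNX file; the local directory name and the commented-out repository name are just illustrative:
from sentence_transformers import SentenceTransformer

# If the repository has no ONNX file yet, one is exported automatically
model = SentenceTransformer("all-MiniLM-L6-v2", backend="onnx")

# Save the model (including the exported ONNX file) locally ...
model.save_pretrained("local-all-MiniLM-L6-v2-onnx")
# ... or push it to a repository you have write access to, e.g.:
# model.push_to_hub("your-username/all-MiniLM-L6-v2-onnx")

# Later loads pick up the saved ONNX file and skip the export step
model = SentenceTransformer("local-all-MiniLM-L6-v2-onnx", backend="onnx")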
All keyword arguments passed via model_kwargs will be passed on to ORTModel.from_pretrained or OVBaseModel.from_pretrained. The most useful arguments are:
- provider: (Only if backend="onnx") ONNX Runtime provider to use for loading the model, e.g. "CPUExecutionProvider". See https://onnxruntime.ai/docs/execution-providers/ for possible providers. If not specified, the strongest available provider (e.g. "CUDAExecutionProvider") will be used.
- file_name: The name of the ONNX file to load. If not specified, will default to "model.onnx" or otherwise "onnx/model.onnx" for ONNX, and "openvino_model.xml" or otherwise "openvino/openvino_model.xml" for OpenVINO. This argument is useful for specifying optimized or quantized models.
- export: A boolean flag specifying whether the model will be exported. If not provided, export will be set to True if the model repository or directory does not already contain an ONNX or OpenVINO model.
For example:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(
"all-MiniLM-L6-v2",
backend="onnx",
model_kwargs={
"file_name": "model_O3.onnx",
"provider": "CPUExecutionProvider",
}
)
sentences = ["This is an example sentence", "Each sentence is converted"]
embeddings = model.encode(sentences)
Benchmarks
We ran benchmarks for CPU and GPU, averaging findings across 4 models of various sizes, 3 datasets, and numerous batch sizes. These findings resulted in the following recommendations:
For GPU, you can expect 2x speedup with fp16 at no cost, and for CPU you can expect ~2.5x speedup at a cost of 0.4% accuracy.
ONNX Optimization and Quantization
In addition to exporting default ONNX and OpenVINO models, we also introduce 2 helper methods for optimizing and quantizing ONNX models:
Optimization
export_optimized_onnx_model: This function uses Optimum to implement several optimizations in the ONNX model, ranging from basic optimizations to approximations and mixed precision. Read about the 4 default options here. This function accepts:
- model: A SentenceTransformer model loaded with backend="onnx".
- optimization_config: "O1", "O2", "O3", or "O4" from 🤗 Optimum, or a custom OptimizationConfig instance.
- model_name_or_path: The directory or model repository where the optimized model will be saved.
- push_to_hub: Whether to push the exported model to the hub with model_name_or_path as the repository name. If False, the model will be saved in the directory specified with model_name_or_path.
- create_pr: If push_to_hub, then this denotes whether a pull request is created rather than pushing the model directly to the repository. Very useful for optimizing models of repositories that you don't have write access to.
- file_suffix: The suffix to add to the optimized model file name. Will use the optimization_config string or "optimized" if not set.
The usage is like this:
from sentence_transformers import SentenceTransformer, export_optimized_onnx_model
onnx_model = SentenceTransformer("BAAI/bge-large-en-v1.5", backend="onnx")
export_optimized_onnx_model(
model=onnx_model,
optimization_config="O4",
model_name_or_path="BAAI/bge-large-en-v1.5",
push_to_hub=True,
create_pr=True,
)
After which you can load the model with:
from sentence_transformers import SentenceTransformer
pull_request_nr = 2 # TODO: Update this to the number of your pull request
model = SentenceTransformer(
"BAAI/bge-large-en-v1.5",
backend="onnx",
model_kwargs={"file_name": "onnx/model_O4.onnx"},
revision=f"refs/pr/{pull_request_nr}"
)
or when it gets merged:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(
"BAAI/bge-large-en-v1.5",
backend="onnx",
model_kwargs={"file_name": "onnx/model_O4.onnx"},
)
Quantization
export_dynamic_quantized_onnx_model: This function uses Optimum to quantize the ONNX model to int8, also allowing for hardware-specific optimizations. This results in impressive speedups for CPUs. In my findings, each of the default quantization configuration options gave approximately the same performance improvements. This function accepts:
- model: A SentenceTransformer model loaded with backend="onnx".
- quantization_config: "arm64", "avx2", "avx512", or "avx512_vnni", representing quantization configurations from AutoQuantizationConfig, or a QuantizationConfig instance.
- model_name_or_path: The directory or model repository where the quantized model will be saved.
- push_to_hub: Whether to push the exported model to the hub with model_name_or_path as the repository name. If False, the model will be saved in the directory specified with model_name_or_path.
- create_pr: If push_to_hub, then this denotes whether a pull request is created rather than pushing the model directly to the repository. Very useful for quantizing models of repositories that you don't have write access to.
- file_suffix: The suffix to add to the quantized model file name. Will use the quantization_config string or e.g. "int8_quantized" if not set.
The usage is like this:
from sentence_transformers import SentenceTransformer, export_dynamic_quantized_onnx_model
onnx_model = SentenceTransformer("BAAI/bge-large-en-v1.5", backend="onnx")
export_dynamic_quantized_onnx_model(
model=onnx_model,
quantization_config="avx512",
model_name_or_path="BAAI/bge-large-en-v1.5",
push_to_hub=True,
create_pr=True,
)
After which you can load the model with:
from sentence_transformers import SentenceTransformer
pull_request_nr = 2 # TODO: Update this to the number of your pull request
model = SentenceTransformer(
"BAAI/bge-large-en-v1.5",
backend="onnx",
model_kwargs={"file_name": "onnx/model_qint8_avx512.onnx"},
revision=f"refs/pr/{pull_request_nr}"
)
or when it gets merged:
from sentence_transformers import SentenceTransformer
model = SentenceTransformer(
"BAAI/bge-large-en-v1.5",
backend="onnx",
model_kwargs={"file_name": "onnx/model_qint8_avx512.onnx"},
)
Lightning-Fast Static Embeddings via Model2Vec (#2961)
If ONNX or OpenVINO isn't fast enough for you yet, then perhaps you'll enjoy Static Embeddings. These embeddings are a bit akin to GloVe or word2vec, i.e. they're bags of token embeddings that are summed together to create text embeddings, allowing for lightning-fast embeddings that don't require any neural networks.
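To make the idea concrete, here is a purely illustrative sketch (the vocabulary and embedding table below are made up; this is not the library's implementation) of what a static embedding model computes:
import numpy as np

# Toy vocabulary and embedding table; real models use a trained tokenizer and trained vectors
vocab = {"paris": 0, "is": 1, "the": 2, "capital": 3, "of": 4, "france": 5}
embedding_table = np.random.rand(len(vocab), 256)

def embed(text: str) -> np.ndarray:
    token_ids = [vocab[token] for token in text.lower().split() if token in vocab]
    # Summing the token vectors is the entire "forward pass"; no neural network is involved
    return embedding_table[token_ids].sum(axis=0)

print(embed("Paris is the capital of France").shape)  # (256,)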
However, these Static Embeddings are created in different ways. For example:
- Distillation via the Model2Vec technique. This project allows you to distill any Sentence Transformer model into Static Embeddings. For example, distilling BAAI/bge-base-en-v1.5 resulted in a Static Embeddings Sentence Transformer model that reaches 87.5% of the performance of all-MiniLM-L6-v2 on MTEB (+ PEARL & WordSim) and 97.4% of the performance of all-MiniLM-L6-v2 on various classification benchmarks.
  You can initialize Static Embeddings via Model2Vec in two ways:
  - from_model2vec: You can load one of the pretrained Model2Vec models:
# note: `pip install model2vec` is needed, but not for inference
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding

# Initialize a Sentence Transformer model with a static embedding from a pretrained model2vec model
static_embedding = StaticEmbedding.from_model2vec("minishlab/M2V_multilingual_output")
model = SentenceTransformer(modules=[static_embedding])

# Encode some texts
queries = ["What is the capital of France?", "How many people live in the Netherlands?"]
documents = ["Paris is the capital of France", "The Netherlands has 17 million inhabitants"]
query_embeddings = model.encode(queries)
document_embeddings = model.encode(documents)

# Compute similarities
scores = model.similarity(query_embeddings, document_embeddings)
print(scores)
"""
tensor([[0.8170, 0.3843],
        [0.3929, 0.5818]])
"""
  - from_distillation: You can use the name of any Sentence Transformer model alongside some parameters (see the docs for more information) to perform the distillation yourself, without needing any dataset. On my device, this takes ~4 seconds on a GPU and ~2 minutes on a CPU:
# note: `pip install model2vec` is needed, but not for inference
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding

# Initialize a Sentence Transformer model with a static embedding by distilling via model2vec
static_embedding = StaticEmbedding.from_distillation(
    "mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    pca_dims=256,
    apply_zipf=True,
)
model = SentenceTransformer(modules=[static_embedding])

# Encode some texts
queries = ["What is the capital of France?", "How many people live in the Netherlands?"]
documents = ["Paris is the capital of France", "The Netherlands has 17 million inhabitants"]
query_embeddings = model.encode(queries)
document_embeddings = model.encode(documents)

# Compute similarities
scores = model.similarity(query_embeddings, document_embeddings)
print(scores)
"""
tensor([[0.8430, 0.3271],
        [0.3213, 0.5861]])
"""
- Random initialization: Although this initialization needs finetuning, finetuning a Sentence Transformers model backed by StaticEmbedding is extremely fast. For example, I was able to finetune tomaarsen/static-bert-uncased-gooaq with MatryoshkaLoss & MultipleNegativesRankingLoss on the entire (3 million pairs) gooaq dataset in just 7 minutes. This model reaches an NDCG@10 of 79.33 on a hold-out set of 10k samples from gooaq, whereas e.g. BAAI/bge-base-en-v1.5 reaches 85.01 NDCG@10. In short, only 6.6% less performance for a model that's about 500x faster.
  That's not a typo: I can compute embeddings for about 14,000 stsb sentences per second on CPU, compared to about 24 with BAAI/bge-base-en-v1.5, a.k.a. 625x faster.
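As a sketch of the random-initialization route (assuming the StaticEmbedding constructor accepts a tokenizers.Tokenizer plus an embedding_dim for a randomly initialized embedding matrix; the tokenizer and dimension below are illustrative):
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding
from tokenizers import Tokenizer

# Reuse an existing tokenizer; the embedding matrix itself starts out random
tokenizer = Tokenizer.from_pretrained("google-bert/bert-base-uncased")
static_embedding = StaticEmbedding(tokenizer, embedding_dim=1024)
model = SentenceTransformer(modules=[static_embedding])

# This model produces embeddings immediately, but it should be finetuned
# (e.g. with MultipleNegativesRankingLoss) before the embeddings are useful
embeddings = model.encode(["This is an example sentence"])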
Note
You can save_pretrained and load these models like any other Sentence Transformer model; the StaticEmbedding initialization is only necessary when you're creating a new model.
- Creation:
from sentence_transformers import SentenceTransformer
from sentence_transformers.models import StaticEmbedding

# Initialize a Sentence Transformer model with a static embedding by distilling via model2vec
static_embedding = StaticEmbedding.from_distillation(
    "mixedbread-ai/mxbai-embed-large-v1",
    device="cuda",
    pca_dims=256,
    apply_zipf=True,
)
model = SentenceTransformer(modules=[static_embedding])

model.save_pretrained("static-mxbai-embed-large-v1")
# or
# model.push_to_hub("tomaarsen/static-mxbai-embed-large-v1")
- Inference:
from sentence_transformers import SentenceTransformer

# Initialize a Sentence Transformer model with a static embedding
model = SentenceTransformer("static-mxbai-embed-large-v1")

model.encode([...])
Small changes
- The InformationRetrievalEvaluator now accepts query_prompt, query_prompt_name, corpus_prompt, and corpus_prompt_name arguments, useful if your model requires specific prompts for queries and/or documents for the best performance (see the evaluator example after this list). (#2951)
- The mine_hard_negatives function now accepts anchor_column_name and positive_column_name for specifying which dataset columns will be used. If not specified, the first two columns are used, respectively. Additionally, the min_score parameter is added, ensuring that all mined negatives have a similarity score of at least min_score according to the chosen SentenceTransformer or CrossEncoder model (see the mining example after this list). (#2977)
- If you're using multiple evaluators during training via SequentialEvaluator, e.g. multiple evaluators for different Matryoshka dimensions, then the order is now preserved in the training logs in the model card. Previously, they were sorted by name, resulting in weird orderings (e.g. "gooaq-1024", "gooaq-128", "gooaq-256", "gooaq-32", "gooaq-512", "gooaq-64"). (#2963)
- CachedGISTEmbedLoss has been improved to support multiple negatives per sample, i.e. the loss now accepts data in the (anchor, positive, negative_1, …, negative_n) format. It is the third loss to support this format (see docs); the data-format example after this list shows what such a dataset looks like.
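For the evaluator prompts, a minimal sketch (the toy data and the choice of an E5 model, which expects "query: "/"passage: " prefixes, are illustrative assumptions):
from sentence_transformers import SentenceTransformer
from sentence_transformers.evaluation import InformationRetrievalEvaluator

model = SentenceTransformer("intfloat/multilingual-e5-small")

# Toy data; real evaluations use many more queries and documents
queries = {"q1": "What is the capital of France?"}
corpus = {"d1": "Paris is the capital of France", "d2": "The Netherlands has 17 million inhabitants"}
relevant_docs = {"q1": {"d1"}}

evaluator = InformationRetrievalEvaluator(
    queries=queries,
    corpus=corpus,
    relevant_docs=relevant_docs,
    name="toy-ir",
    query_prompt="query: ",     # prepended to every query before encoding
    corpus_prompt="passage: ",  # prepended to every corpus document before encoding
)
results = evaluator(model)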
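For the mining arguments, a hedged sketch (the dataset slice, its "question"/"answer" column names, and the thresholds are illustrative assumptions):
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.util import mine_hard_negatives

model = SentenceTransformer("all-MiniLM-L6-v2")
dataset = load_dataset("sentence-transformers/gooaq", split="train[:10000]")

mined_dataset = mine_hard_negatives(
    dataset,
    model,
    anchor_column_name="question",  # explicitly select the anchor column
    positive_column_name="answer",  # explicitly select the positive column
    min_score=0.3,                  # discard candidate negatives scoring below 0.3
    num_negatives=5,
)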
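And for the data format now supported by CachedGISTEmbedLoss, a sketch with toy rows (the models and example texts are illustrative):
from datasets import Dataset
from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import CachedGISTEmbedLoss

model = SentenceTransformer("microsoft/mpnet-base")
guide = SentenceTransformer("all-MiniLM-L6-v2")

# Each row carries one anchor, one positive, and multiple negatives
train_dataset = Dataset.from_dict({
    "anchor": ["What is the capital of France?"],
    "positive": ["Paris is the capital of France"],
    "negative_1": ["Berlin is the capital of Germany"],
    "negative_2": ["The Netherlands has 17 million inhabitants"],
})
loss = CachedGISTEmbedLoss(model, guide=guide, mini_batch_size=16)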
All changes
- [fix] Only save first module in root if "save_in_root" is specified. by @tomaarsen in #2957
- [feat] Add query prompts to Information Retrieval Evaluator by @ArthurCamara in #2951
- [model cards] Keep evaluation order in training logs if there's multiple evaluators by @tomaarsen in #2963
- Add negatives in CachedGISTEmbedLoss by @daegonYu in #2946
- [ENH] CrossEncoder.rank by @it176131 in #2947
- [feat] Add lightning-fast StaticEmbedding module based on model2vec by @tomaarsen in #2961
- [feat] Add ONNX and OpenVINO backends by @helena-intel and @tomaarsen in #2712
- Refine mine_hard_negatives arguments by @bakrianoo in #2977
New Contributors
- @daegonYu made their first contribution in #2946
- @it176131 made their first contribution in #2947
- @helena-intel made their first contribution in #2712
- @bakrianoo made their first contribution in #2977
Special thanks to @echarlaix for making the new backends possible due to some last-minute changes in optimum and optimum-intel.
Full Changelog: v3.1.1...v3.2.0