[Feature Request] Jagged batches support (NJT) for Transformer / SDPA inference #22764
Comments
Please take a look at https://github.com/microsoft/onnxruntime/blob/e7987a6b0ba429c0bec248c4a471e1782da4be6c/onnxruntime/python/tools/transformers/notebooks/PyTorch_Bert-Squad_OnnxRuntime_GPU.ipynb There is a section on "Packing Mode (Effective Transformer)".
Is there more documentation anywhere on this packing mode and the transformation into it? Is there any mapping onto NJT from PyTorch (or examples of exporting NJT-supporting ops to ONNX)? E.g. what is the input format of such a converted BERT: still a padded batch with BOS/EOS tokens and a bunch of PAD tokens at the end? Is the transformation only done for the attention node? Are any intermediate MLPs still applied to the useless padding tokens? Does it ensure that positional embeddings restart from zero for every batch element?
There is "ragged batching" support in Triton Inference Server (https://github.com/triton-inference-server/server/blob/main/docs/user_guide/ragged_batching.md), which lets the server itself concatenate inputs and add an input of offsets/lengths. In what format does this BERT accept these extra inputs? Does it rely on BOS/EOS to separate batch elements? In other words, how would one connect the "ragged batching" of Triton Inference Server and the "Packing Mode" of ORT / Triton's ORT backend? Thanks!
There is no detailed documentation about that, but you can find the conversion script here: https://github.com/microsoft/onnxruntime/blob/4d614e15bd9e6949bc3066754791da403e00d66c/onnxruntime/python/tools/transformers/convert_to_packing_mode.py. The PackedMultiHeadAttention and MultiHeadAttention operators are defined in onnxruntime/core/graph/contrib_ops/bert_defs.cc (lines 582 to 709 and lines 957 to 1041 at commit 4d614e1).
The packing mode follows the idea of Effective Transformer. It assumes that your BERT model takes padded inputs and uses the attention mask to indicate the padding. The conversion script then adds a node that removes padding before the first Attention or MultiHeadAttention node, and inserts another node that restores padding after the last LayerNormalization node. It is slightly different from Triton ragged batching, since Triton expects inputs with the padding already removed. ORT is capable of supporting ragged batching, since the Attention and MultiHeadAttention operators support passing sequence lengths as the mask index. I could add a ragged batch example later.
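For intuition, here is a rough NumPy sketch of what those pack/unpack steps do. The function names, layouts, and the cumulative-sequence-length output below are illustrative only, not the exact signatures of the inserted contrib ops:

```python
import numpy as np

def remove_padding(hidden, attention_mask):
    """Pack a padded batch [B, S, H] into [total_tokens, H].

    Conceptually what the node inserted before the first Attention does;
    names and outputs here are illustrative, not the contrib-op spec.
    """
    batch, seqlen, hidden_size = hidden.shape
    valid = attention_mask.astype(bool).reshape(-1)            # [B*S]
    token_offsets = np.flatnonzero(valid)                      # flat index of each real token
    packed = hidden.reshape(-1, hidden_size)[token_offsets]    # [total_tokens, H]
    # Cumulative sequence lengths, e.g. lengths (5, 12, 3) -> [0, 5, 17, 20];
    # this is the kind of per-sequence offset information packed attention needs.
    cum_seqlens = np.concatenate(([0], np.cumsum(attention_mask.sum(axis=1)))).astype(np.int32)
    return packed, token_offsets, cum_seqlens

def restore_padding(packed, token_offsets, batch, seqlen):
    """Scatter packed tokens back to a padded [B, S, H] tensor
    (what the node inserted after the last LayerNormalization does)."""
    hidden_size = packed.shape[-1]
    out = np.zeros((batch * seqlen, hidden_size), dtype=packed.dtype)
    out[token_offsets] = packed
    return out.reshape(batch, seqlen, hidden_size)
```

Between those two nodes, the intermediate MLPs and LayerNorms only ever see total_tokens rows, so no compute is spent on padding.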
I think it would be very valuable! Especially if it could also be a complete example with ONNX export (from PyTorch) of a recent BERT implementation from HF (e.g. one using the MHA module and SDPA, making sure that these get pattern-matched to the ONNX Attention and PackedAttention ops), post-processing of the resulting ONNX file, and then serving from Triton.
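A rough sketch of just the export step such an example might start from (the model name, opset, and file path are placeholders, and the later steps are only outlined in the trailing comment):

```python
import torch
from transformers import BertModel, BertTokenizerFast

model = BertModel.from_pretrained("bert-base-uncased").eval()
tok = BertTokenizerFast.from_pretrained("bert-base-uncased")
batch = tok(["example one", "a second, longer example"],
            padding=True, return_tensors="pt")

dyn = {0: "batch", 1: "sequence"}
torch.onnx.export(
    model,
    (batch["input_ids"], batch["attention_mask"], batch["token_type_ids"]),
    "bert.onnx",
    input_names=["input_ids", "attention_mask", "token_type_ids"],
    output_names=["last_hidden_state", "pooler_output"],
    dynamic_axes={"input_ids": dyn, "attention_mask": dyn,
                  "token_type_ids": dyn, "last_hidden_state": dyn,
                  "pooler_output": {0: "batch"}},
    opset_version=17,
)
# Next steps (not shown): fuse attention with the ORT transformer optimizer so
# Attention/MultiHeadAttention nodes appear in the graph, run
# convert_to_packing_mode.py on the optimized model, and place the result in a
# Triton model repository.
```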
Describe the feature request
PyTorch / HF (previously branded as BetterTransformer) now have some support for the NJT (nested jagged tensor) representation.
This allows for efficient inference in the context of continuous batching.
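For reference, a minimal PyTorch sketch of the NJT path (available in recent PyTorch releases; exact behavior varies by version, and the sizes and head layout below are just placeholders):

```python
import torch
import torch.nn.functional as F

num_heads, head_dim = 8, 64

# Three sequences of different lengths, no padding anywhere.
seqs = [torch.randn(n, num_heads, head_dim) for n in (5, 12, 3)]

def as_njt(tensors):
    # Nested tensor with jagged layout: shape (B, j, num_heads, head_dim),
    # transposed to (B, num_heads, j, head_dim) as SDPA expects.
    return torch.nested.nested_tensor(tensors, layout=torch.jagged).transpose(1, 2)

q, k, v = as_njt(seqs), as_njt(seqs), as_njt(seqs)

# SDPA on jagged nested tensors computes attention per sequence, which is
# equivalent to a block-diagonal mask over the packed tokens, without ever
# materializing padding.
out = F.scaled_dot_product_attention(q, k, v)
print(out.is_nested, out.size(0))  # True 3
```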
Does ORT have such NJT-enabled SDPA kernels built in? (i.e. FlashAttention kernels that efficiently support block-diagonal masks)
Thanks!
Describe scenario use case
N/A