[Feature Request] Jagged batches support (NJT) for Transformer / SDPA inference #22764

Open
vadimkantorov opened this issue Nov 7, 2024 · 4 comments
Labels: feature request (request for unsupported feature or enhancement)

Comments

vadimkantorov commented Nov 7, 2024

Describe the feature request

PyTorch / HF (previously branded as BetterTransformer) now have some support for the NJT (jagged nested tensor) representation.

This enables efficient inference in the context of continuous batching.

Does ORT have such NJT-enabled SDPA kernels built in (i.e., FlashAttention kernels that efficiently support block-diagonal masks)?
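
For reference, the PyTorch-side jagged path looks roughly like this (a minimal sketch assuming PyTorch 2.3+ with the torch.jagged layout; shapes and sizes are made up):

import torch
import torch.nn.functional as F

# A jagged batch of 3 sequences with different lengths, already projected per head:
# each element has shape (seq_len_i, num_heads, head_size).
seqs = [torch.randn(n, 8, 64) for n in (3, 5, 2)]
nt = torch.nested.nested_tensor(seqs, layout=torch.jagged)  # (B, j, H, D), j is the ragged dim

# SDPA expects (B, H, j, D); no padding and no explicit block-diagonal mask is needed.
q = k = v = nt.transpose(1, 2)
out = F.scaled_dot_product_attention(q, k, v)  # picks a varlen/flash kernel on CUDA when dtype/device allow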

Thanks!

Describe scenario use case

N/A

vadimkantorov added the feature request label Nov 7, 2024
@tianleiwu
Contributor

vadimkantorov commented Nov 8, 2024

Is there more documentation anywhere on this packing mode and on the transformation into it? Is there any mapping onto PyTorch's NJT (or examples of exporting NJT-supporting ops to ONNX)? E.g., what is the input format of such a converted BERT — still a padded batch with BOS/EOS tokens and a run of PAD tokens at the end?

Is the transformation done only for the attention nodes? Do the intermediate MLPs still run over the useless padding tokens? Does it ensure that the positional embeddings restart from zero for every batch element?

There is "ragged batching" support in Triton Inference Server https://github.com/triton-inference-server/server/blob/main/docs/user_guide/ragged_batching.md which lets the server itself concat inputs and add input of offsets/lengths. In what format does this BERT accept these extra inputs? Does it rely on BOS/EOS to separate batch elements? In other words, how would one connect the "ragged batching" of Triton Inference Server and "Packing Mode" of ORT / Triton's ORT backend?

Thanks!

tianleiwu commented Nov 8, 2024

There is no detailed documentation about that, but you can find the conversion script here: https://github.com/microsoft/onnxruntime/blob/4d614e15bd9e6949bc3066754791da403e00d66c/onnxruntime/python/tools/transformers/convert_to_packing_mode.py.

The definitions of the PackedMultiHeadAttention and MultiHeadAttention operators:

constexpr const char* PackedMultiHeadAttention_ver1_doc = R"DOC(
This is the packed version of MultiHeadAttention.
Sequences in one batch usually don't have the same length and are padded to the same length,
e.g., below is a batch with 3 sequences and * is padding token.
Sequence_0: 0, 1*, 2*, 3*
Sequence_1: 4, 5, 6*, 7*
Sequence_2: 8, 9, 10, 11
PackedMultiHeadAttention is designed to take packed input, i.e., only the real tokens without padding.
An input as above will be packed into 3 tensors like below:
- query ([q0, q4, q5, q8, q9, q10, q11])
- key ([k0, k4, k5, k8, k9, k10, k11])
- value ([v0, v4, v5, v8, v9, v10, v11])
- token_offset: 0, 4, 5, 8, 9, 10, 11, 1*, 2*, 3*, 6*, 7*
- cumulative_sequence_length: 0, 1, 1+2, 1+2+4
The query, key and value tensors contain the hidden states of the real tokens after the input projections.
token_offset records the offset of each token in the unpacked input.
cumulative_sequence_length records the cumulative sum of the sequence lengths.
The operator currently only supports BERT-like models with right-side padding.
)DOC";
// Shape inference for PackedMultiHeadAttention. Here are the shapes of inputs and output:
// When Q, K and V are not packed:
// Input 'query': (token_count, hidden_size)
// Input 'key': (token_count, hidden_size)
// Input 'value': (token_count, v_hidden_size)
// When Q, K and V are packed:
// Input 'query': (token_count, num_heads, 3, head_size)
// Input 'key': None
// Input 'value': None
// Input 'bias': (hidden_size + hidden_size + v_hidden_size)
// Input 'token_offset': (batch_size, sequence_length)
// Input 'cumulative_sequence_length': (batch_size + 1)
// Input 'attention_bias': (batch_size or 1, num_heads or 1, sequence_length, sequence_length) or None
// Output 'output': (token_count, v_hidden_size)
void PackedMultiHeadAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
  // Type inference
  ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);

  // Shape inference
  if (hasInputShape(ctx, 0)) {
    auto& query_shape = getInputShape(ctx, 0);
    auto& query_dims = query_shape.dim();
    if (query_dims.size() != 2 && query_dims.size() != 4) {
      fail_shape_inference("Inputs 0 (query) shall be 2 or 4 dimensions");
    }

    if (query_dims.size() == 4) {  // packed QKV
      ONNX_NAMESPACE::TensorShapeProto output_shape;
      *output_shape.add_dim() = query_dims[0];
      *output_shape.add_dim() = query_dims[1] * query_dims[3];
      updateOutputShape(ctx, 0, output_shape);
      return;
    }

    if (hasInputShape(ctx, 2)) {
      auto& value_shape = getInputShape(ctx, 2);
      auto& value_dims = value_shape.dim();
      if (value_dims.size() != 2) {
        fail_shape_inference("Inputs 2 (value) shall be 2 dimensions");
      }

      ONNX_NAMESPACE::TensorShapeProto output_shape;
      *output_shape.add_dim() = query_dims[0];
      *output_shape.add_dim() = value_dims[1];
      updateOutputShape(ctx, 0, output_shape);
      return;
    }
  }
}
ONNX_MS_OPERATOR_SET_SCHEMA(
PackedMultiHeadAttention, 1,
OpSchema()
.SetDoc(PackedMultiHeadAttention_ver1_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Attr("scale",
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Input(0,
"query",
"Query with shape (token_count, hidden_size) or packed qkv with shape (token_count, num_heads, 3, head_size)",
"T")
.Input(1,
"key",
"Key with shape (token_count, hidden_size)",
"T",
OpSchema::Optional)
.Input(2,
"value",
"Value with shape (token_count, v_hidden_size)",
"T",
OpSchema::Optional)
.Input(3,
"bias",
"Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection",
"T",
OpSchema::Optional)
.Input(4,
"token_offset",
"Offset of each token before packing, with shape (batch_size, sequence_length).",
"M")
.Input(5,
"cumulative_sequence_length",
"A tensor with shape (batch_size + 1). It specifies the cumulative sequence length.",
"M")
.Input(6,
"attention_bias",
"It specifies the additional bias to QxK'. "
"The shape is (batch_size or 1, num_heads or 1, sequence_length, sequence_length)",
"T",
OpSchema::Optional)
.Output(0,
"output",
"output tensor with shape (token_count, v_hidden_size)",
"T")
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask, offset and sequence length to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
PackedMultiHeadAttentionTypeAndShapeInference(ctx);
}));
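
To make the packed layout concrete, below is a small NumPy sketch (my own illustration, not taken from the ORT tooling; in the converted model these values are computed inside the graph at runtime) that derives token_offset and cumulative_sequence_length from an ordinary right-padded attention mask, reproducing the three-sequence example in the doc string:

import numpy as np

# Right-padded attention mask for the doc example above: lengths 1, 2, 4, padded to 4.
attention_mask = np.array([[1, 0, 0, 0],
                           [1, 1, 0, 0],
                           [1, 1, 1, 1]], dtype=np.int32)

batch_size, seq_len = attention_mask.shape
flat_ids = np.arange(batch_size * seq_len, dtype=np.int32)  # 0..11 in row-major order
is_token = attention_mask.reshape(-1).astype(bool)

# Real tokens first (in order), then the padding positions.
token_offset = np.concatenate([flat_ids[is_token], flat_ids[~is_token]])
# -> [0, 4, 5, 8, 9, 10, 11, 1, 2, 3, 6, 7]
token_offset = token_offset.reshape(batch_size, seq_len)  # op expects shape (batch_size, sequence_length)

seq_lens = attention_mask.sum(axis=1)
cumulative_sequence_length = np.concatenate([[0], np.cumsum(seq_lens)]).astype(np.int32)
# -> [0, 1, 3, 7]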

constexpr const char* MultiHeadAttention_ver1_doc = R"DOC(
Multi-Head Self/Cross Attention. Bias from input projection is included.
The key padding mask is optional. When its shape is (batch_size, kv_sequence_length), a value of 0
means padding and 1 otherwise. When the key has right-side padding, its shape can be (batch_size): the actual length of
each key sequence excluding padding.
)DOC";
ONNX_MS_OPERATOR_SET_SCHEMA(
MultiHeadAttention, 1,
OpSchema()
.SetDoc(MultiHeadAttention_ver1_doc)
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Attr("mask_filter_value", "The value to be filled in the attention mask. Default value is -10000.0f",
AttributeProto::FLOAT, OPTIONAL_VALUE)
.Attr("scale",
"Custom scale will be used if specified. Default value is 1/sqrt(head_size)",
AttributeProto::FLOAT,
OPTIONAL_VALUE)
.Attr("unidirectional",
"Whether every token can only attend to previous tokens. Default value is 0.",
AttributeProto::INT,
static_cast<int64_t>(0))
.Input(0,
"query",
"Query with shape (batch_size, sequence_length, hidden_size), or packed QKV with shape (batch_size, kv_sequence_length, num_heads, 3, head_size)",
"T")
.Input(1,
"key",
"Key with shape (batch_size, kv_sequence_length, hidden_size), or packed KV with shape (batch_size, kv_sequence_length, num_heads, 2, head_size), "
"or past_key with shape (batch_size, num_heads, kv_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(2,
"value",
"Value with shape (batch_size, kv_sequence_length, v_hidden_size), or past_value with shape (batch_size, num_heads, kv_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(3,
"bias",
"Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection",
"T",
OpSchema::Optional)
.Input(4,
"key_padding_mask",
"Key padding mask with shape (batch_size), (3 * batch_size + 2), (batch_size, kv_sequence_length), (batch_size, total_sequence_length), "
"or (batch_size, sequence_length, total_sequence_length)",
"M",
OpSchema::Optional)
.Input(5,
"attention_bias",
"bias added to QxK' with shape (batch_size or 1, num_heads or 1, sequence_length, total_sequence_length)",
"T",
OpSchema::Optional)
.Input(6,
"past_key",
"past state for self attention key with shape (batch_size, num_heads, past_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Input(7,
"past_value",
"past state for self attention value with shape (batch_size, num_heads, past_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Output(0,
"output",
"3D output tensor with shape (batch_size, sequence_length, v_hidden_size)",
"T")
.Output(1,
"present_key",
"present state for cross attention key with shape (batch_size, num_heads, kv_sequence_length, head_size)"
"or present state for self attention key with shape (batch_size, num_heads, total_sequence_length, head_size)",
"T",
OpSchema::Optional)
.Output(2,
"present_value",
"present state for cross attention value with shape (batch_size, num_heads, kv_sequence_length, head_size)"
"or present state for self attention value with shape (batch_size, num_heads, total_sequence_length, head_size)",
"T",
OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output to float tensors.")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
MultiHeadAttentionTypeAndShapeInference(ctx, 6);
}));

The packing mode follows the idea of Effective Transformer. It assumes that your BERT model takes padded inputs and uses an attention mask to indicate the padding. The conversion script then adds a node to remove the padding before the first Attention or MultiHeadAttention node, and inserts another node to restore the padding after the last LayerNormalization node.

It is slightly different from Triton ragged batching, since Triton expects the inputs with padding already removed.

ORT is capable of supporting ragged batching, since the Attention and MultiHeadAttention operators support passing sequence lengths as the mask index. I could add a ragged batch example later.
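
After conversion the model interface stays the same, so from the client side it would look roughly like below (just a sketch: the file name and the input names input_ids, attention_mask, token_type_ids are placeholders for a typical HF BERT export, and I assume the CUDA execution provider since the packed attention kernels are CUDA ones):

import numpy as np
import onnxruntime as ort

# The packing-mode model keeps the original padded interface; padding removal and
# restoration happen inside the graph around the attention layers.
session = ort.InferenceSession("bert_packed.onnx",  # placeholder file name
                               providers=["CUDAExecutionProvider"])

# Right-padded batch with sequence lengths 1, 2, 4 (0 is the PAD token id here).
input_ids = np.array([[101,   0,   0,   0],
                      [101, 102,   0,   0],
                      [101, 102, 103, 104]], dtype=np.int64)
attention_mask = (input_ids != 0).astype(np.int64)
token_type_ids = np.zeros_like(input_ids)

outputs = session.run(None, {
    "input_ids": input_ids,  # placeholder input names from a typical HF export
    "attention_mask": attention_mask,
    "token_type_ids": token_type_ids,
})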

vadimkantorov commented Nov 8, 2024

> I could add a ragged batch example later.

I think it would be very valuable!

Especially if it could be a complete example: ONNX export from PyTorch of a recent BERT implementation from HF (e.g. one using the MHA module and SDPA, making sure these get pattern-matched to the ONNX Attention and PackedAttention nodes), post-processing of the resulting ONNX file, and then serving it from Triton.
