Commit: fix ort trocr with kv cache

fxmarty committed Oct 17, 2023
1 parent 262b8e8 commit 36cc7b3
Showing 8 changed files with 100 additions and 75 deletions.
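For context on the fix (a sketch, not part of the diff): TrOCR is a vision-encoder-decoder model whose decoder uses both self-attention and cross-attention, so its per-layer KV cache carries four tensors. The dummy-input machinery changed below has to produce tensors of exactly these shapes for the ONNX export with past key values. All sizes here are illustrative assumptions:

```python
import torch

# Illustrative per-layer KV cache for a seq2seq-style decoder such as TrOCR's,
# following the (batch, num_heads, seq_len, head_dim) convention used by the
# dummy generators in this commit. All sizes below are assumptions.
batch, heads, head_dim = 1, 4, 8
past_len = 10                    # tokens already decoded
enc_len = (224 // 16) ** 2 + 1   # ViT encoder: 196 patches + CLS token = 197

past_key_values = tuple(
    (
        torch.rand(batch, heads, past_len, head_dim),  # decoder self-attention key
        torch.rand(batch, heads, past_len, head_dim),  # decoder self-attention value
        torch.rand(batch, heads, enc_len, head_dim),   # cross-attention key
        torch.rand(batch, heads, enc_len, head_dim),   # cross-attention value
    )
    for _ in range(2)  # one tuple per decoder layer
)
```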
31 changes: 20 additions & 11 deletions optimum/exporters/onnx/model_configs.py
@@ -32,6 +32,7 @@
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
DummyVisionInputGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
MistralDummyPastKeyValuesGenerator,
@@ -41,7 +42,6 @@
NormalizedTextAndVisionConfig,
NormalizedTextConfig,
NormalizedVisionConfig,
TROCRDummyPastKeyValuseGenerator,
logging,
)
from ...utils.normalized_config import NormalizedConfigManager
@@ -624,6 +624,13 @@ class ViTOnnxConfig(VisionOnnxConfig):
def inputs(self) -> Dict[str, Dict[int, str]]:
return {"pixel_values": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"}}

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
common_outputs = super().outputs
if self.task == "feature-extraction":
common_outputs["last_hidden_state"] = {0: "batch_size"}
return common_outputs

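The `outputs` override above marks only the batch dimension of `last_hidden_state` as dynamic: for ViT-like models the sequence length is fully determined by the image and patch sizes, so it does not need a dynamic axis. A quick check with assumed ViT-base values:

```python
# Assumed ViT-base values, for illustration only.
image_size, patch_size = 224, 16
encoder_sequence_length = (image_size // patch_size) ** 2 + 1  # 196 patches + CLS token
assert encoder_sequence_length == 197
```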

class CvTOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 13
@@ -1295,11 +1302,7 @@ class VisionEncoderDecoderOnnxConfig(EncoderDecoderBaseOnnxConfig):
NORMALIZED_CONFIG_CLASS = NormalizedEncoderDecoderConfig
ATOL_FOR_VALIDATION = 1e-3

DUMMY_INPUT_GENERATOR_CLASSES = (
DummyVisionInputGenerator,
TROCRDummyPastKeyValuseGenerator,
DummySeq2SeqPastKeyValuesGenerator,
)
DUMMY_INPUT_GENERATOR_CLASSES = (DummyVisionInputGenerator, DummyVisionEncoderDecoderPastKeyValuesGenerator)

def __init__(
self,
@@ -1323,11 +1326,6 @@ def __init__(
preprocessors=preprocessors,
)

if config.decoder.model_type == "trocr":
self.DUMMY_PKV_GENERATOR_CLASS = TROCRDummyPastKeyValuseGenerator
else:
self.DUMMY_PKV_GENERATOR_CLASS = DummySeq2SeqPastKeyValuesGenerator

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
common_inputs = {}
@@ -1348,6 +1346,17 @@ def inputs(self) -> Dict[str, Dict[int, str]]:

return common_inputs

@property
def outputs(self) -> Dict[str, Dict[int, str]]:
if self._behavior == ConfigBehavior.ENCODER:
# Some encoders have a static sequence length, so it is useful to rely on the encoder ONNX config to grab this information.
return self._encoder_onnx_config.outputs
else:
# Ideally, we would want to use self._decoder_onnx_config.outputs here, which is currently not possible
# as we hard-code the task to feature-extraction, which has the wrong output names (e.g. mbart does not support
# document-question-answering, so we cannot initialize MBartONNXConfig with document-question-answering).
return super().outputs

def patch_model_for_export(
self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
) -> "ModelPatcher":
1 change: 1 addition & 0 deletions optimum/onnxruntime/base.py
@@ -533,6 +533,7 @@ def compute_past_key_values_output_shapes(
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
) -> Dict[str, int]:
batch_size = input_ids.size(0)

num_attention_heads = self.normalized_config.num_attention_heads
embed_size_per_head = self.normalized_config.hidden_size // num_attention_heads

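For reference, a minimal sketch of how the quantities in this hunk typically combine into a cache output shape (the helper name and the +1 growth step are illustrative, not the actual optimum API):

```python
from typing import List

def past_kv_output_shape(
    batch_size: int, num_attention_heads: int, hidden_size: int, past_sequence_length: int
) -> List[int]:
    # Mirrors the arithmetic above: each head stores hidden_size // num_heads
    # features; the cache grows by one position per newly decoded token.
    embed_size_per_head = hidden_size // num_attention_heads
    return [batch_size, num_attention_heads, past_sequence_length + 1, embed_size_per_head]

print(past_kv_output_shape(1, 8, 256, 10))  # [1, 8, 11, 32]
```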
26 changes: 15 additions & 11 deletions optimum/onnxruntime/modeling_seq2seq.py
@@ -446,30 +446,27 @@ def forward(
return BaseModelOutput(last_hidden_state=last_hidden_state)

def compute_encoder_known_output_shapes(self, pixel_values: torch.FloatTensor) -> Dict[str, List[int]]:
if self.normalized_config.model_type == "vit":
# for vit models
encoder_sequence_length = (
self.normalized_config.image_size // self.normalized_config.config.patch_size
) ** 2 + 1 # plus cls token
elif self.normalized_config.config.model_type == "donut-swin":
# for donut-swin models
if self.normalized_config.config.model_type == "donut-swin":
# TODO: kind of weird to export to ONNX with dynamic output shape if it is in fact static...
encoder_sequence_length = (
self.normalized_config.config.image_size[0]
* self.normalized_config.config.image_size[1]
// self.normalized_config.config.hidden_size
)
elif self.normalized_config.config.model_type in ["vit", "deit"]:
return None
else:
raise ValueError(
f"Unsupported encoder model type {self.normalized_config.config.model_type} for ORTForVisionSeq2Seq with IOBinding."
"Currently supported models are vit and donut-swin."
"Currently supported models are vit, donut-swin and deit."
"Please submit a PR to add support for this model type."
)

return {
"last_hidden_state": [
pixel_values.shape[0], # batch_size
encoder_sequence_length, # encoder_sequence_length
self.normalized_config.config.hidden_size, # hidden_size
pixel_values.shape[0], # batch size
encoder_sequence_length,
self.normalized_config.config.hidden_size,
]
}

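A standalone sketch of the dispatch the rewritten method performs, with the model types and formulas taken from the diff; returning None leaves the output shape dynamic instead of pinning it for IOBinding:

```python
from typing import Optional

def known_encoder_sequence_length(model_type: str, config) -> Optional[int]:
    # Sketch only; `config` is assumed to expose the same attributes as the
    # normalized config used above.
    if model_type == "donut-swin":
        return config.image_size[0] * config.image_size[1] // config.hidden_size
    elif model_type in ["vit", "deit"]:
        return None  # shape is not precomputed; handled dynamically
    raise ValueError(f"Unsupported encoder model type {model_type}.")
```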
@@ -1155,6 +1152,7 @@ def __init__(
**kwargs,
)

# The normalized_config initialization in ORTModelPart is unfortunately wrong, as it is done from the top-level config.
if config.model_type == "encoder-decoder":
self.encoder.normalized_config = NormalizedConfigManager.get_normalized_config_class(
config.encoder.model_type
@@ -1489,6 +1487,7 @@ def __init__(
**kwargs,
)

# The normalized_config initialization in ORTModelPart is unfortunately wrong, as it is done from the top-level config.
self.encoder.normalized_config = NormalizedConfigManager.get_normalized_config_class(
config.encoder.model_type
)(config.encoder)
@@ -1497,6 +1496,11 @@ def __init__(
config.decoder.model_type
)(config.decoder)

if self.decoder_with_past is not None:
self.decoder_with_past.normalized_config = NormalizedConfigManager.get_normalized_config_class(
config.decoder.model_type
)(config.decoder)

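These assignments matter because a vision-encoder-decoder config is a composite: attributes such as the number of attention heads live on `config.encoder` and `config.decoder`, not on the top-level object, and the `decoder_with_past` part was previously left with the wrong (top-level) normalized config. A condensed sketch of the pattern:

```python
# Condensed sketch of the per-part pattern above: every ORT model part gets a
# normalized view of its own sub-config (both decoder parts share config.decoder).
from optimum.utils.normalized_config import NormalizedConfigManager

def normalize(sub_config):
    return NormalizedConfigManager.get_normalized_config_class(sub_config.model_type)(sub_config)

# encoder.normalized_config = normalize(config.encoder)
# decoder.normalized_config = normalize(config.decoder)
# decoder_with_past.normalized_config = normalize(config.decoder)
```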
def _initialize_encoder(self, session: ort.InferenceSession) -> ORTEncoder:
return ORTEncoderForVisionEncoderDecoder(session, self)

2 changes: 1 addition & 1 deletion optimum/utils/__init__.py
@@ -58,11 +58,11 @@
DummyTextInputGenerator,
DummyTimestepInputGenerator,
DummyVisionEmbeddingsGenerator,
DummyVisionEncoderDecoderPastKeyValuesGenerator,
DummyVisionInputGenerator,
FalconDummyPastKeyValuesGenerator,
GPTBigCodeDummyPastKeyValuesGenerator,
MistralDummyPastKeyValuesGenerator,
TROCRDummyPastKeyValuseGenerator,
)
from .modeling_utils import recurse_getattr, recurse_setattr
from .normalized_config import (
73 changes: 54 additions & 19 deletions optimum/utils/input_generators.py
@@ -516,28 +516,17 @@ def __init__(
)

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
if isinstance(self.normalized_config, NormalizedEncoderDecoderConfig):
decoder_hidden_size = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.hidden_size
encoder_hidden_size = decoder_hidden_size
decoder_num_attention_heads = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_attention_heads
encoder_num_attention_heads = decoder_num_attention_heads # This is used for cross-attention KV cache.
else:
encoder_hidden_size = self.normalized_config.hidden_size
decoder_hidden_size = self.normalized_config.hidden_size
encoder_num_attention_heads = self.normalized_config.encoder_num_attention_heads
decoder_num_attention_heads = self.normalized_config.decoder_num_attention_heads

encoder_shape = (
self.batch_size,
encoder_num_attention_heads,
self.normalized_config.encoder_num_attention_heads,
self.encoder_sequence_length,
encoder_hidden_size // encoder_num_attention_heads,
self.normalized_config.hidden_size // self.normalized_config.encoder_num_attention_heads,
)
decoder_shape = (
self.batch_size,
decoder_num_attention_heads,
self.normalized_config.decoder_num_attention_heads,
self.sequence_length,
decoder_hidden_size // decoder_num_attention_heads,
self.normalized_config.hidden_size // self.normalized_config.decoder_num_attention_heads,
)
return [
(
@@ -967,7 +956,7 @@ def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
]


class TROCRDummyPastKeyValuseGenerator(DummySeq2SeqPastKeyValuesGenerator):
class DummyVisionEncoderDecoderPastKeyValuesGenerator(DummySeq2SeqPastKeyValuesGenerator):
def __init__(
self,
task: str,
@@ -989,7 +978,53 @@ def __init__(
random_sequence_length_range=random_sequence_length_range,
**kwargs,
)
if normalized_config.model_type == "trocr":
image_size = normalized_config.encoder.image_size
patch_size = normalized_config.encoder.patch_size
self.encoder_sequence_length = (image_size // patch_size) ** 2 + 1

if isinstance(normalized_config.DECODER_NORMALIZED_CONFIG_CLASS, NormalizedSeq2SeqConfig):
# Here, the decoder used in the vision-encoder-decoder comes from a seq2seq model.
self.num_layers = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.decoder_num_layers
self.use_cross_attention = True
else:
self.num_layers = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_layers
self.use_cross_attention = False

def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
decoder_hidden_size = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.hidden_size
decoder_num_attention_heads = self.normalized_config.DECODER_NORMALIZED_CONFIG_CLASS.num_attention_heads
decoder_shape = (
self.batch_size,
decoder_num_attention_heads,
self.sequence_length,
decoder_hidden_size // decoder_num_attention_heads,
)

if not self.use_cross_attention:
return [
(
self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]
else:
encoder_hidden_size = decoder_hidden_size
encoder_num_attention_heads = decoder_num_attention_heads

encoder_shape = (
self.batch_size,
encoder_num_attention_heads,
self.encoder_sequence_length,
encoder_hidden_size // encoder_num_attention_heads,
)
return [
(
self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(decoder_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype),
self.random_float_tensor(encoder_shape, framework=framework, dtype=float_dtype),
)
for _ in range(self.num_layers)
]
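
The net effect of the new generator, sketched below: a decoder borrowed from a seq2seq model (`use_cross_attention=True`) gets four dummy tensors per layer, while a decoder-only-style decoder such as TrOCR's gets two. The helper name and sizes are illustrative:

```python
import torch

def dummy_past_key_values(num_layers, decoder_shape, encoder_shape=None):
    # With cross-attention, each layer carries (self_k, self_v, cross_k, cross_v);
    # without it, just (self_k, self_v), matching the two branches above.
    if encoder_shape is None:
        return [
            (torch.rand(decoder_shape), torch.rand(decoder_shape))
            for _ in range(num_layers)
        ]
    return [
        (
            torch.rand(decoder_shape), torch.rand(decoder_shape),
            torch.rand(encoder_shape), torch.rand(encoder_shape),
        )
        for _ in range(num_layers)
    ]

pkv = dummy_past_key_values(2, (1, 4, 16, 8), encoder_shape=(1, 4, 197, 8))
print(len(pkv), len(pkv[0]))  # 2 layers, 4 tensors per layer
```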
7 changes: 3 additions & 4 deletions optimum/utils/normalized_config.py
@@ -153,11 +153,10 @@ def __getattr__(self, attr_name):
hidden_size="d_model",
)

TrOCRLikeNormalizedTextConfig = NormalizedSeq2SeqConfig.with_args(
decoder_num_layers="decoder_layers",
TrOCRLikeNormalizedTextConfig = NormalizedTextConfig.with_args(
num_layers="decoder_layers",
decoder_num_attention_heads="decoder_attention_heads",
hidden_size="cross_attention_hidden_size",
num_attention_heads="decoder_attention_heads",
hidden_size="hidden_size",
)

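The switch from `NormalizedSeq2SeqConfig` to `NormalizedTextConfig` matters because the decoder's self-attention cache is shaped by its own `hidden_size`, not by `cross_attention_hidden_size`; whenever the two differ, the old mapping produced wrong dummy KV shapes. A toy illustration with assumed values:

```python
# Toy stub with assumed values, not the real TrOCR config.
class DecoderConfigStub:
    decoder_attention_heads = 8
    hidden_size = 256
    cross_attention_hidden_size = 384  # may differ from hidden_size

cfg = DecoderConfigStub()
head_dim_correct = cfg.hidden_size // cfg.decoder_attention_heads                # 32
head_dim_old = cfg.cross_attention_hidden_size // cfg.decoder_attention_heads    # 48: wrong cache shape
```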
SpeechToTextLikeNormalizedTextConfig = NormalizedSeq2SeqConfig.with_args(
2 changes: 1 addition & 1 deletion tests/exporters/exporters_utils.py
@@ -158,7 +158,7 @@
"image-to-text",
"image-to-text-with-past",
],
"microsoft/trocr-small-handwritten": ["image-to-text"],
"microsoft/trocr-small-handwritten": ["image-to-text", "image-to-text-with-past"],
"fxmarty/tiny-doc-qa-vision-encoder-decoder": [
"document-question-answering",
"document-question-answering-with-past",
33 changes: 5 additions & 28 deletions tests/onnxruntime/test_modeling.py
@@ -4022,16 +4022,6 @@ class ORTModelForVision2SeqIntegrationTest(ORTModelTestMixin):
GENERATION_LENGTH = 100
SPEEDUP_CACHE = 1.1

def exclude_trocr_with_cache(params):
if params[0] == "trocr" and params[1] is True:
return None
return params

def update_trocr_with_cache(params):
if params[0] == "trocr" and params[1] is True:
params[1] = False
return params

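Context for the two helpers above (removed by this commit): `grid_parameters` accepts a `filter_params_func` hook that can drop a parameter combination (by returning None) or rewrite it before it becomes a test case. Since the KV cache now works for TrOCR, neither workaround is needed. A minimal illustration of the hook contract, assuming list-shaped params as in the deleted helpers:

```python
from typing import List, Optional

# Assumed contract, based on the deleted helpers: params is [model_arch, use_cache, ...].
def example_filter(params: List) -> Optional[List]:
    if params[0] == "some_arch" and params[1] is True:
        return None  # drop this combination from the test grid
    return params    # keep it unchanged
```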
def _get_sample_image(self):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
@@ -4049,11 +4039,7 @@ def test_load_vanilla_transformers_which_is_not_supported(self):

self.assertIn("Unrecognized configuration class", str(context.exception))

@parameterized.expand(
grid_parameters(
{"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}, filter_params_func=update_trocr_with_cache
)
)
@parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):
model_args = {"test_name": test_name, "model_arch": model_arch, "use_cache": use_cache}
self._setup(model_args)
@@ -4071,7 +4057,7 @@ def test_generate_utils(self, test_name: str, model_arch: str, use_cache: str):

gc.collect()

@parameterized.expand(grid_parameters(FULL_GRID, filter_params_func=exclude_trocr_with_cache))
@parameterized.expand(grid_parameters(FULL_GRID))
def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are incompatible")
@@ -4116,17 +4102,12 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):

extra_inputs = [{}, {}]

if use_cache and False:
# TODO: the dims will fail with other models
fake_pkv = tuple((torch.rand(1, 4, 1, 8), torch.rand(1, 4, 1, 8)) for _ in range(5))
extra_inputs[1]["past_key_values"] = fake_pkv

for extra_inps in extra_inputs:
features = feature_extractor(data, return_tensors="pt")
decoder_inputs = {"decoder_input_ids": torch.ones((1, 1), dtype=torch.long) * decoder_start_token_id}

with torch.no_grad():
transformers_outputs = transformers_model(**features, **decoder_inputs, **extra_inps)
transformers_outputs = transformers_model(**features, **decoder_inputs, **extra_inps, use_cache=True)
for input_type in ["pt", "np"]:
features = feature_extractor(data, return_tensors=input_type)

@@ -4163,7 +4144,7 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):

gc.collect()

@parameterized.expand(grid_parameters(FULL_GRID, filter_params_func=exclude_trocr_with_cache))
@parameterized.expand(grid_parameters(FULL_GRID))
def test_pipeline_image_to_text(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):
if use_cache is False and use_merged is True:
self.skipTest("use_cache=False, use_merged=True are incompatible")
@@ -4194,11 +4175,7 @@ def test_pipeline_image_to_text(self, test_name: str, model_arch: str, use_cache: bool, use_merged: bool):

gc.collect()

@parameterized.expand(
grid_parameters(
{"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}, filter_params_func=update_trocr_with_cache
)
)
@parameterized.expand(grid_parameters({"model_arch": SUPPORTED_ARCHITECTURES, "use_cache": [True]}))
@require_torch_gpu
@pytest.mark.gpu_test
def test_pipeline_on_gpu(self, test_name: str, model_arch: str, use_cache: bool):
