diff --git a/docs/source/exporters/onnx/usage_guides/export_a_model.mdx b/docs/source/exporters/onnx/usage_guides/export_a_model.mdx index 84c670579c0..e089c6cf17d 100644 --- a/docs/source/exporters/onnx/usage_guides/export_a_model.mdx +++ b/docs/source/exporters/onnx/usage_guides/export_a_model.mdx @@ -388,7 +388,7 @@ class CustomMPTOnnxConfig(TextDecoderOnnxConfig): decoder_sequence_name = "past_sequence_length" name = "past_key_values" else: - decoder_sequence_name = "past_sequence_length + 1" + decoder_sequence_name = "past_sequence_length + sequence_length" name = "present" for i in range(self._normalized_config.num_layers): diff --git a/optimum/exporters/onnx/base.py b/optimum/exporters/onnx/base.py index 8cd94194ffe..8e40f290efc 100644 --- a/optimum/exporters/onnx/base.py +++ b/optimum/exporters/onnx/base.py @@ -717,7 +717,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire decoder_sequence_name = "past_sequence_length" name = "past_key_values" else: - decoder_sequence_name = "past_sequence_length + 1" + decoder_sequence_name = "past_sequence_length + sequence_length" name = "present" for i in range(self._normalized_config.num_layers): diff --git a/optimum/exporters/onnx/config.py b/optimum/exporters/onnx/config.py index 9e808e392b9..ba3ed58984d 100644 --- a/optimum/exporters/onnx/config.py +++ b/optimum/exporters/onnx/config.py @@ -96,7 +96,7 @@ def inputs(self) -> Dict[str, Dict[int, str]]: if self.use_past_in_inputs: common_inputs = {"input_ids": {0: "batch_size", 1: "sequence_length"}} self.add_past_key_values(common_inputs, direction="inputs") - common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} + common_inputs["attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"} else: common_inputs = { "input_ids": {0: "batch_size", 1: "sequence_length"}, diff --git a/optimum/exporters/onnx/model_configs.py b/optimum/exporters/onnx/model_configs.py index cc752779d30..4276aca3654 100644 --- a/optimum/exporters/onnx/model_configs.py +++ b/optimum/exporters/onnx/model_configs.py @@ -373,7 +373,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire decoder_sequence_name = "past_sequence_length" name = "past_key_values" else: - decoder_sequence_name = "past_sequence_length + 1" + decoder_sequence_name = "past_sequence_length + sequence_length" name = "present" for i in range(self._normalized_config.num_layers): @@ -403,7 +403,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire decoder_sequence_name = "past_sequence_length" name = "past_key_values" else: - decoder_sequence_name = "past_sequence_length + 1" + decoder_sequence_name = "past_sequence_length + sequence_length" name = "present" for i in range(self._normalized_config.num_layers): @@ -638,7 +638,7 @@ def inputs_for_causal_lm(self): if self.use_past_in_inputs: common_inputs = { "input_ids": {0: "batch_size", 1: "sequence_length"}, - "attention_mask": {0: "batch_size", 1: "past_sequence_length + 1"}, + "attention_mask": {0: "batch_size", 1: "past_sequence_length + sequence_length"}, } for i in range(self._normalized_config.decoder_num_layers): common_inputs[f"past_key_values.{i}.key"] = { @@ -2216,7 +2216,7 @@ def inputs(self): common_inputs["encoder_outputs"] = {0: "batch_size"} # Contrary to other seq2seq archs as t5 and bart, Pix2Struct DO make use of the decoder_attention_mask input. - common_inputs["decoder_attention_mask"] = {0: "batch_size", 1: "past_sequence_length + 1"} + common_inputs["decoder_attention_mask"] = {0: "batch_size", 1: "past_sequence_length + sequence_length"} return common_inputs diff --git a/tests/exporters/onnx/test_onnx_export.py b/tests/exporters/onnx/test_onnx_export.py index 7671d6cd2e6..e47c99fe009 100644 --- a/tests/exporters/onnx/test_onnx_export.py +++ b/tests/exporters/onnx/test_onnx_export.py @@ -487,7 +487,7 @@ def add_past_key_values(self, inputs_or_outputs: Dict[str, Dict[int, str]], dire decoder_sequence_name = "past_sequence_length" name = "past_key_values" else: - decoder_sequence_name = "past_sequence_length + 1" + decoder_sequence_name = "past_sequence_length + sequence_length" name = "present" for i in range(self._normalized_config.num_layers):